Skip to content

Commit

Permalink
matrix fucntion changes + removing openmp from arm build
Browse files Browse the repository at this point in the history
  • Loading branch information
peekxc committed Dec 12, 2023
1 parent 75e00ea commit 34563a7
Show file tree
Hide file tree
Showing 11 changed files with 121 additions and 58 deletions.
6 changes: 3 additions & 3 deletions .cirrus.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ windows_task:
windows_container:
image: cirrusci/windowsservercore:2019
only_if: changesInclude('.cirrus.yml', '**.{h,cpp,py}')
env:
PATH: ${PATH}:"C:\python311"
setup_python_script: |
setup_script: |
choco install -y python311
env:
PATH: '%PATH%;C:\ProgramData\chocolatey\bin;C:\Python311'
dependencies_script: |
bash tools/cibw_windows.sh
before_script: |
Expand Down
6 changes: 3 additions & 3 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ project(
)
OS_platform = host_machine.system()
env = environment()
use_openmp = get_option('use_openmp') and OS_platform != 'windows' and host_machine.cpu() == 'x86_64'

use_openmp = get_option('use_openmp')
if use_openmp and OS_platform != 'windows'
if use_openmp
add_global_arguments('-DOMP_MULTITHREADED=1', language : 'cpp')
message('Compiling with OpenMP support')
endif
Expand Down Expand Up @@ -85,7 +85,7 @@ dependency_map = {}
# subdir('blas') # Configure BLAS / LAPACK

## Include OpenMP (mandatory ; but exclude on windows because it's too difficult to link)
if get_option('use_openmp') and OS_platform != 'windows'
if use_openmp
omp = dependency('openmp', required: false)
openmp_flags = []
if omp.found()
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ requires = ['meson-python', 'wheel', 'ninja', 'pybind11', 'numpy'] # 'pythran-op

[project]
name = "primate"
version = "0.2.6"
version = "0.2.7"
readme = "README.md"
classifiers = [
"Intended Audience :: Science/Research",
Expand Down
17 changes: 12 additions & 5 deletions src/primate/_lanczos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,17 +169,24 @@ void _lanczos_wrapper(py::module& m, const std::string suffix, WrapperFunc wrap
return std::unique_ptr< MatrixFunction< F, WrapperType > >(new MatrixFunction(op, sf, deg, rtol, orth));
}))
.def_property_readonly("shape", &MatrixFunction< F, WrapperType >::shape)
.def_property_readonly("dtype", [](const MatrixFunction< F, WrapperType >& M) -> py::dtype {
auto dtype = pybind11::dtype(pybind11::format_descriptor< F >::format());
return dtype;
})
.def_readonly("deg", &MatrixFunction< F, WrapperType >::deg)
.def_readwrite("rtol", &MatrixFunction< F, WrapperType >::rtol)
.def_readwrite("orth", &MatrixFunction< F, WrapperType >::orth)
.def("matvec", [](const MatrixFunction< F, WrapperType >& m, const py_array< F >& x) -> py_array< F >{
.def("matvec", [](const MatrixFunction< F, WrapperType >& M, const py_array< F >& x) -> py_array< F >{
using VectorF = Eigen::Matrix< F, Dynamic, 1 >;
auto output = static_cast< ArrayF >(VectorF::Zero(m.shape().first));
m.matvec(x.data(), output.data());
auto output = static_cast< ArrayF >(VectorF::Zero(M.shape().first));
M.matvec(x.data(), output.data());
return py::cast(output);
})
.def("matvec", [](const MatrixFunction< F, WrapperType >& m, const py_array< F >& x, py_array< F >& y) -> void {
m.matvec(x.data(), y.mutable_data());
.def("matvec", [](const MatrixFunction< F, WrapperType >& M, const py_array< F >& x, py_array< F >& y) -> void {
if (M.shape().second != x.size() || M.shape().first != y.size()){
throw std::invalid_argument("Input/output dimension mismatch; vector inputs must match shape of the operator.");
}
M.matvec(x.data(), y.mutable_data());
})
// .def_method("__repr__", &MatrixFunction::eval)
;
Expand Down
35 changes: 0 additions & 35 deletions src/primate/_vapprox.cpp

This file was deleted.

2 changes: 2 additions & 0 deletions src/primate/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ python_sources = [
'diagonalize.py',
'plotting.py',
'trace.py',
'sparse.py',
'stats.py',
'special.py',
'__init__.py'
]

Expand Down
40 changes: 40 additions & 0 deletions src/primate/pylinop.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
namespace py = pybind11;

template< typename F >
using py_array = py::array_t< F, py::array::f_style | py::array::forcecast >;

template< std::floating_point F >
struct PyLinearOperator {
using value_type = F;
const py::object _op;
PyLinearOperator(const py::object& op) : _op(op) {
if (!py::hasattr(op, "matvec")) { throw std::invalid_argument("Supplied object is missing 'matvec' attribute."); }
if (!py::hasattr(op, "shape")) { throw std::invalid_argument("Supplied object is missing 'shape' attribute."); }
// if (!op.has_attr("dtype")) { throw std::invalid_argument("Supplied object is missing 'dtype' attribute."); }
}

// Calls the matvec in python, casts the result to py::array_t, and copies through
auto matvec(const F* inp, F* out) const {
py_array< F > input({ static_cast<py::ssize_t>(shape().second) }, inp);
py::object matvec_out = _op.attr("matvec")(input);
py::array_t< F > output = matvec_out.cast< py::array_t< F > >();
std::copy(output.data(), output.data() + output.size(), out);
}

auto matvec(const py_array< F >& input) const -> py_array< F > {
auto out = std::vector< F >(static_cast< size_t >(shape().first), 0);
this->matvec(input.data(), static_cast< F* >(&out[0]));
return py::cast(out);
}

auto shape() const -> pair< size_t, size_t > {
return _op.attr("shape").cast< std::pair< size_t, size_t > >();
}

auto dtype() const -> py::dtype {
auto dtype = pybind11::dtype(pybind11::format_descriptor< F >::format());
return dtype;
}
};
53 changes: 45 additions & 8 deletions src/primate/sparse.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import numpy as np
from typing import Union
from typing import *
from scipy.sparse.linalg import LinearOperator
import _lanczos
from scipy.sparse import issparse

from .special import _builtin_matrix_functions
import _lanczos

def matrix_function(
A: Union[LinearOperator, np.ndarray],
fun: Union[str, Callable] = "identity",
deg: int = 20,
orth: int = 0
rtol: float = 1e-8,
orth: int = 0,
**kwargs
) -> LinearOperator:
"""Constructs an operator approximating the action v |-> f(A)v
Expand All @@ -23,23 +27,56 @@ def matrix_function(
real-valued function defined on the spectrum of `A`.
deg : int, default = 20
Degree of the Krylov expansion.
rtol : float, default = 1e-8
Relative tolerance to consider two Lanczos vectors are numerically orthogonal.
orth: int, default = 0
Number of additional Lanczos vectors to orthogonalize against when building the Krylov basis.
kwargs : dict, optional
additional key-values to parameterize the chosen function 'fun'.
Returns:
--------
operator : LinearOperator
Operator approximating the action of `fun` on the spectrum of `A`
"""
attr_checks = [hasattr(A, "__matmul__"), hasattr(A, "matmul"), hasattr(A, "dot"), hasattr(A, "matvec")]
assert any(attr_checks), "Invalid operator; must have an overloaded 'matvec' or 'matmul' method"
assert hasattr(A, "shape") and len(A.shape) >= 2, "Operator must be at least two dimensional."
assert A.shape[0] == A.shape[1], "This function only works with square, symmetric matrices!"

## Parameterize the type of matrix function
if isinstance(A, np.ndarray):
module_func = "MatrixFunction_dense"
elif issparray(A):
elif issparse(A):
module_func = "MatrixFunction_sparse"
elif isinstance(A, LinearOperator):
module_func = "MatrixFunction_linop"
else:
raise ValueError(f"Invalid type '{type(A)}' supplied for operator A")

#_lanczos.MatrixFunction_sparse(A_sparse, deg, rtol, orth, **dict(function="log"))
#_lanczos.MatrixFunction_sparse(A_sparse, deg, rtol, orth, **dict(function="log"))
## Get the dtype; infer it if it's not available
f_dtype = (A @ np.zeros(A.shape[1])).dtype if not hasattr(A, "dtype") else A.dtype
i_dtype = np.int32
assert f_dtype.type == np.float32 or f_dtype.type == np.float64, "Only 32- or 64-bit floating point numbers are supported."

## Argument checking
lanczos_rtol = np.finfo(f_dtype).eps # if lanczos_rtol is None else f_dtype.type(lanczos_rtol)
orth = int(orth) # Number of additional vectors should be an integer
deg = max(deg, 2) # Should be at least two

## Parameterize the matrix function and trace call
if isinstance(fun, str):
assert fun in _builtin_matrix_functions, "If given as a string, matrix_function be one of the builtin functions."
kwargs["function"] = fun # _builtin_matrix_functions.index(matrix_function)
elif isinstance(fun, Callable):
kwargs["function"] = "generic"
kwargs["matrix_func"] = fun
else:
raise ValueError(f"Invalid matrix function type '{type(fun)}'")

## Construct the instance
M = getattr(_lanczos, module_func)(A, deg, orth, **dict(function=fun))
M = getattr(_lanczos, module_func)(A, deg, rtol, orth, **kwargs)
return M


Expand Down
4 changes: 4 additions & 0 deletions src/primate/special.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import *
import numpy as np

## Natively support matrix functions
_builtin_matrix_functions = ["identity", "sqrt", "exp", "pow", "log", "numrank", "smoothstep", "gaussian"]

def soft_sign(x: np.ndarray = None, q: int = 1):
"""Soft-sign function.
Expand Down
4 changes: 1 addition & 3 deletions src/primate/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@

## Package imports
from .random import _engine_prefixes, _engines
from .special import _builtin_matrix_functions
import _lanczos

## Natively support matrix functions
_builtin_matrix_functions = ["identity", "sqrt", "exp", "pow", "log", "numrank", "smoothstep", "gaussian"]

def sl_trace (
A: Union[LinearOperator, np.ndarray],
fun: Union[str, Callable] = "identity",
Expand Down
10 changes: 10 additions & 0 deletions tests/test_vapprox.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,15 @@ def test_mf_approx():
assert np.all(y_log_test != y_test)

def test_mf_api():
from primate.sparse import matrix_function
np.random.seed(1234)
n = 10
A_sparse = csc_array(symmetric(n, psd = True), dtype=np.float32)
M = matrix_function(A_sparse)
v0 = np.random.normal(size=n)
assert np.max(np.abs(M.matvec(v0) - A_sparse @ v0)) <= 1e-6
assert True

# from scipy.sparse.linalg import LinearOperator
# assert isinstance(M, LinearOperator)

0 comments on commit 34563a7

Please sign in to comment.