Skip to content

Commit 7ff9337

Browse files
update
1 parent dc691ac commit 7ff9337

File tree

9 files changed

+126
-126
lines changed

9 files changed

+126
-126
lines changed

example/define_custom_local_operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33

44

5-
class CustomLocalOperator(Htool.LocalOperator):
5+
class CustomLocalOperator(Htool.RestrictedGlobalToLocalOperator):
66
def __init__(
77
self,
88
generator: Htool.VirtualGenerator,

example/use_local_hmatrix_compression.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,9 @@
102102
)
103103

104104
if local_operator_1:
105-
distributed_operator.add_local_operator(local_operator_1)
105+
distributed_operator.add_global_to_local_operator(local_operator_1)
106106
if local_operator_2:
107-
distributed_operator.add_local_operator(local_operator_2)
107+
distributed_operator.add_global_to_local_operator(local_operator_2)
108108

109109
# Test matrix vector product
110110
np.random.seed(0)

lib/htool

src/htool/distributed_operator/distributed_operator.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ void declare_distributed_operator(py::module &m, const std::string &class_name)
1515

1616
py::class_<Class> py_class(m, class_name.c_str());
1717
py_class.def(py::init<VirtualPartition<CoefficientPrecision> &, VirtualPartition<CoefficientPrecision> &, MPI_Comm_wrapper>(), py::keep_alive<1, 2>(), py::keep_alive<1, 3>());
18-
py_class.def("add_local_operator", &Class::add_local_operator, py::keep_alive<1, 2>());
18+
py_class.def("add_global_to_local_operator", &Class::add_global_to_local_operator, py::keep_alive<1, 2>());
19+
py_class.def("add_local_to_local_operator", &Class::add_local_to_local_operator, py::keep_alive<1, 2>());
1920

2021
// Linear algebra
2122
py_class.def(

src/htool/distributed_operator/utility.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ void declare_distributed_operator_utility(py::module &m, std::string prefix = ""
1717
std::string default_local_approximation_name = prefix + "DefaultLocalApproximationBuilder";
1818

1919
py::class_<CustomApproximation> custom_approximation_class(m, custom_approximation_name.c_str());
20-
custom_approximation_class.def(py::init<const Cluster<CoordinatePrecision> &, const Cluster<CoordinatePrecision> &, MPI_Comm_wrapper, const VirtualLocalOperator<CoefficientPrecision> &>());
20+
custom_approximation_class.def(py::init<const Cluster<CoordinatePrecision> &, const Cluster<CoordinatePrecision> &, MPI_Comm_wrapper, const VirtualLocalToLocalOperator<CoefficientPrecision> &>());
21+
custom_approximation_class.def(py::init<const Cluster<CoordinatePrecision> &, const Cluster<CoordinatePrecision> &, MPI_Comm_wrapper, const VirtualGlobalToLocalOperator<CoefficientPrecision> &>());
2122
custom_approximation_class.def_property_readonly(
2223
"distributed_operator", [](const CustomApproximation &self) { return &self.distributed_operator; }, py::return_value_policy::reference_internal);
2324

src/htool/local_operator/local_operator.hpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
#ifndef HTOOL_LOCAL_OPERATOR_CPP
22
#define HTOOL_LOCAL_OPERATOR_CPP
33

4-
#include <htool/distributed_operator/implementations/local_operators/local_operator.hpp>
4+
#include <htool/distributed_operator/implementations/global_to_local_operators/restricted_operator.hpp>
55
#include <pybind11/pybind11.h>
66

77
template <typename CoefficientPrecision>
8-
class LocalOperatorPython : public htool::LocalOperator<CoefficientPrecision> {
8+
class RestrictedGlobalToLocalOperatorPython : public htool::RestrictedGlobalToLocalOperator<CoefficientPrecision> {
99
public:
10-
using htool::LocalOperator<CoefficientPrecision>::LocalOperator;
10+
using htool::RestrictedGlobalToLocalOperator<CoefficientPrecision>::RestrictedGlobalToLocalOperator;
1111

12-
LocalOperatorPython(LocalRenumbering target_local_renumbering, LocalRenumbering source_local_renumbering, bool target_use_permutation_to_mvprod = false, bool source_use_permutation_to_mvprod = false) : LocalOperator<CoefficientPrecision>(target_local_renumbering, source_local_renumbering, target_use_permutation_to_mvprod, source_use_permutation_to_mvprod) {}
12+
RestrictedGlobalToLocalOperatorPython(LocalRenumbering target_local_renumbering, LocalRenumbering source_local_renumbering, bool target_use_permutation_to_mvprod = false, bool source_use_permutation_to_mvprod = false) : RestrictedGlobalToLocalOperator<CoefficientPrecision>(target_local_renumbering, source_local_renumbering, target_use_permutation_to_mvprod, source_use_permutation_to_mvprod) {}
1313

1414
void local_add_vector_product(char trans, CoefficientPrecision alpha, const CoefficientPrecision *in, CoefficientPrecision beta, CoefficientPrecision *out) const override {
1515

@@ -35,16 +35,16 @@ class LocalOperatorPython : public htool::LocalOperator<CoefficientPrecision> {
3535
};
3636

3737
template <typename CoefficientPrecision>
38-
class PyLocalOperator : public LocalOperatorPython<CoefficientPrecision> {
38+
class PyRestrictedGlobalToLocalOperator : public RestrictedGlobalToLocalOperatorPython<CoefficientPrecision> {
3939
public:
40-
using LocalOperatorPython<CoefficientPrecision>::LocalOperatorPython;
40+
using RestrictedGlobalToLocalOperatorPython<CoefficientPrecision>::RestrictedGlobalToLocalOperatorPython;
4141

4242
/* Trampoline (need one for each virtual function) */
4343
virtual void add_vector_product(char trans, CoefficientPrecision alpha, const py::array_t<CoefficientPrecision> &in, CoefficientPrecision beta, py::array_t<CoefficientPrecision> &out) const override {
4444
PYBIND11_OVERRIDE_PURE(
45-
void, /* Return type */
46-
LocalOperatorPython<CoefficientPrecision>, /* Parent class */
47-
add_vector_product, /* Name of function in C++ (must match Python name) */
45+
void, /* Return type */
46+
RestrictedGlobalToLocalOperatorPython<CoefficientPrecision>, /* Parent class */
47+
add_vector_product, /* Name of function in C++ (must match Python name) */
4848
trans,
4949
alpha,
5050
in,
@@ -54,9 +54,9 @@ class PyLocalOperator : public LocalOperatorPython<CoefficientPrecision> {
5454
}
5555
virtual void add_matrix_product_row_major(char trans, CoefficientPrecision alpha, const py::array_t<CoefficientPrecision, py::array::c_style> &in, CoefficientPrecision beta, py::array_t<CoefficientPrecision, py::array::c_style> &out) const override {
5656
PYBIND11_OVERRIDE_PURE(
57-
void, /* Return type */
58-
LocalOperatorPython<CoefficientPrecision>, /* Parent class */
59-
add_matrix_product_row_major, /* Name of function in C++ (must match Python name) */
57+
void, /* Return type */
58+
RestrictedGlobalToLocalOperatorPython<CoefficientPrecision>, /* Parent class */
59+
add_matrix_product_row_major, /* Name of function in C++ (must match Python name) */
6060
trans,
6161
alpha,
6262
in,
@@ -67,15 +67,15 @@ class PyLocalOperator : public LocalOperatorPython<CoefficientPrecision> {
6767
};
6868

6969
template <typename CoefficientPrecision>
70-
void declare_local_operator(py::module &m, const std::string &class_name) {
71-
using VirtualClass = htool::VirtualLocalOperator<CoefficientPrecision>;
72-
py::class_<VirtualClass>(m, ("Virtual" + class_name).c_str());
70+
void declare_global_to_local_operator(py::module &m, const std::string &class_name) {
71+
using VirtualClass = htool::VirtualGlobalToLocalOperator<CoefficientPrecision>;
72+
py::class_<VirtualClass>(m, "IGlobalToLocalOperator");
7373

74-
using BaseClass = LocalOperator<CoefficientPrecision>;
75-
py::class_<BaseClass, VirtualClass> py_base_class(m, ("Base" + class_name).c_str());
74+
using BaseClass = RestrictedGlobalToLocalOperator<CoefficientPrecision>;
75+
py::class_<BaseClass, VirtualClass> py_base_class(m, "IRestrictedGlobalToLocalOperator");
7676

77-
using Class = LocalOperatorPython<CoefficientPrecision>;
78-
py::class_<Class, PyLocalOperator<CoefficientPrecision>, BaseClass> py_class(m, class_name.c_str());
77+
using Class = RestrictedGlobalToLocalOperatorPython<CoefficientPrecision>;
78+
py::class_<Class, PyRestrictedGlobalToLocalOperator<CoefficientPrecision>, BaseClass> py_class(m, class_name.c_str());
7979
py_class.def(py::init<LocalRenumbering, LocalRenumbering, bool, bool>());
8080
py_class.def("add_vector_product", &Class::add_vector_product, py::arg("trans"), py::arg("alpha"), py::arg("in").noconvert(true), py::arg("beta"), py::arg("out").noconvert(true));
8181
py_class.def("add_matrix_product_row_major", &Class::add_matrix_product_row_major);

src/htool/local_operator/virtual_local_operator.hpp

Lines changed: 0 additions & 98 deletions
This file was deleted.

0 commit comments

Comments
 (0)