Skip to content

Commit 5199113

Browse files
authored
PY: Added NeighborList python bindings (#1314)
* Added NeighborList python bindings * Added example test script * Added loop example * Added dtype property to NeighborList bindings * Added NeighborList python documentation Signed-off-by: Jared Duffey <[email protected]> --------- Signed-off-by: Jared Duffey <[email protected]>
1 parent dce1694 commit 5199113

File tree

4 files changed

+147
-0
lines changed

4 files changed

+147
-0
lines changed

src/Plugins/SimplnxCore/wrapping/python/simplnxpy.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <simplnx/DataStructure/Geometry/TetrahedralGeom.hpp>
2626
#include <simplnx/DataStructure/Geometry/TriangleGeom.hpp>
2727
#include <simplnx/DataStructure/Geometry/VertexGeom.hpp>
28+
#include <simplnx/DataStructure/NeighborList.hpp>
2829
#include <simplnx/DataStructure/StringArray.hpp>
2930
#include <simplnx/Filter/Actions/CopyArrayInstanceAction.hpp>
3031
#include <simplnx/Filter/Actions/CopyDataObjectAction.hpp>
@@ -209,6 +210,35 @@ auto BindDataArray(py::handle scope, const char* name)
209210
#define SIMPLNX_PY_BIND_DATA_STORE(scope, className) BindDataStore<className::value_type>(scope, #className)
210211
#define SIMPLNX_PY_BIND_ABSTRACT_DATA_STORE(scope, className) SIMPLNX_PY_BIND_CLASS_VARIADIC(scope, className, IDataStore, std::shared_ptr<className>)
211212

213+
template <class T>
214+
auto BindNeighborList(py::handle scope, const char* name)
215+
{
216+
using NeighborListType = NeighborList<T>;
217+
218+
auto neighborList = py::class_<NeighborListType, INeighborList, std::shared_ptr<NeighborListType>>(scope, name);
219+
neighborList.def_property_readonly_static("dtype", []([[maybe_unused]] py::object self) { return py::dtype::of<T>(); });
220+
neighborList.def("get_list", &NeighborListType::getList, "grain_id"_a);
221+
neighborList.def("set_list", py::overload_cast<int32, const typename NeighborListType::VectorType&>(&NeighborListType::setList), "grain_id"_a, "neighbor_list"_a);
222+
neighborList.def(
223+
"get_value",
224+
[](const NeighborListType& self, int32 grainId, int32 index) {
225+
bool ok = false;
226+
int32 value = self.getValue(grainId, index, ok);
227+
if(!ok)
228+
{
229+
throw std::out_of_range(fmt::format("NeighborList.get_value called with grain_id = {} and index = {} which was out of range", grainId, index));
230+
}
231+
return value;
232+
},
233+
"grain_id"_a, "index"_a);
234+
neighborList.def("add_entry", &NeighborListType::addEntry, "grain_id"_a, "value"_a);
235+
neighborList.def("get_list_size", &NeighborListType::getListSize, "grain_id"_a);
236+
neighborList.def("get_number_of_lists", &NeighborListType::getNumberOfLists);
237+
return neighborList;
238+
}
239+
240+
#define SIMPLNX_PY_BIND_NEIGHBOR_LIST(scope, className) BindNeighborList<className::value_type>(scope, #className)
241+
212242
template <class GeomT>
213243
auto BindCreateGeometry2DAction(py::handle scope, const char* name)
214244
{
@@ -1026,6 +1056,19 @@ PYBIND11_MODULE(simplnx, mod)
10261056
stringArray.def_property_readonly("values", &StringArray::values);
10271057
stringArray.def("resize_tuples", &StringArray::resizeTuples, "Resize the tuples with the given shape");
10281058

1059+
auto iNeighborList = py::class_<INeighborList, IArray, std::shared_ptr<INeighborList>>(mod, "INeighborList");
1060+
1061+
auto neighborListInt8 = SIMPLNX_PY_BIND_NEIGHBOR_LIST(mod, Int8NeighborList);
1062+
auto neighborListUInt8 = SIMPLNX_PY_BIND_NEIGHBOR_LIST(mod, UInt8NeighborList);
1063+
auto neighborListInt16 = SIMPLNX_PY_BIND_NEIGHBOR_LIST(mod, Int16NeighborList);
1064+
auto neighborListUInt16 = SIMPLNX_PY_BIND_NEIGHBOR_LIST(mod, UInt16NeighborList);
1065+
auto neighborListInt32 = SIMPLNX_PY_BIND_NEIGHBOR_LIST(mod, Int32NeighborList);
1066+
auto neighborListUInt32 = SIMPLNX_PY_BIND_NEIGHBOR_LIST(mod, UInt32NeighborList);
1067+
auto neighborListInt64 = SIMPLNX_PY_BIND_NEIGHBOR_LIST(mod, Int64NeighborList);
1068+
auto neighborListUInt64 = SIMPLNX_PY_BIND_NEIGHBOR_LIST(mod, UInt64NeighborList);
1069+
auto neighborListFloat32 = SIMPLNX_PY_BIND_NEIGHBOR_LIST(mod, Float32NeighborList);
1070+
auto neighborListFloat64 = SIMPLNX_PY_BIND_NEIGHBOR_LIST(mod, Float64NeighborList);
1071+
10291072
auto dataArrayInt8 = SIMPLNX_PY_BIND_DATA_ARRAY(mod, Int8Array);
10301073
auto dataArrayUInt8 = SIMPLNX_PY_BIND_DATA_ARRAY(mod, UInt8Array);
10311074
auto dataArrayInt16 = SIMPLNX_PY_BIND_DATA_ARRAY(mod, Int16Array);

wrapping/python/docs/source/DataObjects.rst

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,75 @@ DataStore Example Usage
366366
# The developer can also just inline the above lines into a single line
367367
npdata = data_structure[output_array_path].store.npview
368368
369+
370+
.. _NeighborList:
371+
372+
NeighborList
373+
-----------
374+
375+
.. py:class:: NeighborList[T]
376+
377+
.. py:property:: tuple_shape
378+
:type: list[int]
379+
380+
The dimensions of the NeighborList from slowest to fastest (C Ordering)
381+
382+
.. py:property:: component_shape
383+
:type: list[int]
384+
385+
The dimensions of the components of the NeighborList from slowest to fastest (C Ordering)
386+
387+
.. py:property:: dtype
388+
:type: numpy.dtype
389+
390+
The type of the NeighborList elements
391+
392+
.. py:method:: get_list(grain_id: int) -> list[T]
393+
394+
Returns the target neighbor list.
395+
396+
:param int grain_id: The grain id of the target list
397+
:return: The target list
398+
:rtype: list[T]
399+
400+
.. py:method:: set_list(grain_id: int, neighbor_list: list[T])
401+
402+
Set the target neighbor list to the given list.
403+
404+
:param int grain_id: The grain id of the target list
405+
:param list[T] neighbor_list: The replacement list
406+
407+
.. py:method:: get_value(grain_id: int, index: int) -> T
408+
409+
Returns the value at the given index in the target neighbor list.
410+
411+
:param int grain_id: The grain id of the target list
412+
:param int index: The index into the target list
413+
:return: The target value
414+
:rtype: T
415+
416+
.. py:method:: add_entry(grain_id: int, value: T)
417+
418+
Appends the given value to the target neighbor list.
419+
420+
:param int grain_id: The grain id of the target list
421+
:param T value: The value to append
422+
423+
.. py:method:: get_list_size(grain_id: int) -> int
424+
425+
Returns the size of the target neighbor list.
426+
427+
:param int grain_id: The grain id of the target list
428+
:return: The target list size
429+
:rtype: int
430+
431+
.. py:method:: get_number_of_lists() -> int
432+
433+
Returns the total number of lists.
434+
435+
:return: The total number of lists
436+
:rtype: int
437+
369438
.. _AttributeMatrix:
370439

371440
AttributeMatrix

wrapping/python/examples/scripts/SourceList.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ set(SIMPLNX_PYTHON_TESTS
1414
"output_file"
1515
"pipeline"
1616
"read_csv_file"
17+
"neighbor_list"
1718
# "read_esprit_data"
1819
)
1920

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import simplnx as nx
2+
3+
ds = nx.DataStructure()
4+
5+
SIZE = 42
6+
PATH = nx.DataPath(['foo'])
7+
8+
action = nx.CreateNeighborListAction(nx.DataType.int32, SIZE, PATH)
9+
10+
assert action.apply(ds, nx.IDataAction.Mode.Execute)
11+
12+
nl: nx.Int32NeighborList = ds[PATH]
13+
assert nl.get_number_of_lists() == SIZE
14+
15+
INDEX = 2
16+
17+
# get_list returns a copy
18+
assert nl.get_list(INDEX) == []
19+
assert nl.get_list_size(INDEX) == 0
20+
21+
VALUE = 4
22+
23+
nl.add_entry(INDEX, VALUE)
24+
25+
assert nl.get_list(INDEX) == [VALUE]
26+
assert nl.get_list_size(INDEX) == 1
27+
28+
for i in range(3):
29+
nl.add_entry(1, i)
30+
31+
for grain_id in range(nl.get_number_of_lists()):
32+
print(f'grain_id={grain_id}')
33+
for value in nl.get_list(grain_id):
34+
print(value)

0 commit comments

Comments
 (0)