Skip to content

Commit 24922d3

Browse files
nyoungbqimikejackson
authored andcommitted
Move IterativeClosestPoint to algorithm class
1 parent 0dbdc15 commit 24922d3

File tree

4 files changed

+262
-172
lines changed

4 files changed

+262
-172
lines changed

src/Plugins/SimplnxCore/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ set(AlgorithmList
205205
FlyingEdges3D
206206
IdentifyDuplicateVertices
207207
InitializeData
208+
IterativeClosestPoint
208209
LabelTriangleGeometry
209210
LaplacianSmoothing
210211
NearestPointFuseRegularGrids
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
#include "IterativeClosestPoint.hpp"
2+
3+
#include "SimplnxCore/utils/nanoflann.hpp"
4+
5+
#include "simplnx/DataStructure/DataArray.hpp"
6+
#include "simplnx/DataStructure/Geometry/VertexGeom.hpp"
7+
8+
#include <Eigen/Geometry>
9+
10+
using namespace nx::core;
11+
12+
namespace
13+
{
14+
constexpr int32 k_MissingVertices = -4503;
15+
constexpr int32 k_EmptyVertices = -4505;
16+
17+
template <typename Derived>
18+
struct VertexGeomAdaptor
19+
{
20+
const Derived& obj;
21+
AbstractDataStore<INodeGeometry0D::SharedVertexList::value_type>* verts;
22+
size_t m_NumComponents = 0;
23+
size_t m_NumTuples = 0;
24+
25+
explicit VertexGeomAdaptor(const Derived& obj_)
26+
: obj(obj_)
27+
{
28+
// These values never change for the lifetime of this object so cache them now.
29+
verts = derived()->getVertices()->getDataStore();
30+
m_NumComponents = verts->getNumberOfComponents();
31+
m_NumTuples = verts->getNumberOfTuples();
32+
}
33+
34+
[[nodiscard]] const Derived& derived() const
35+
{
36+
return obj;
37+
}
38+
39+
[[nodiscard]] usize kdtree_get_point_count() const
40+
{
41+
return m_NumTuples;
42+
}
43+
44+
[[nodiscard]] float kdtree_get_pt(const usize idx, const usize dim) const
45+
{
46+
auto offset = idx * m_NumComponents;
47+
return verts->getValue(offset + dim);
48+
}
49+
50+
template <class BBOX>
51+
bool kdtree_get_bbox(BBOX& /*bb*/) const
52+
{
53+
return false;
54+
}
55+
};
56+
} // namespace
57+
58+
// -----------------------------------------------------------------------------
59+
IterativeClosestPoint::IterativeClosestPoint(DataStructure& dataStructure, const IFilter::MessageHandler& mesgHandler, const std::atomic_bool& shouldCancel,
60+
IterativeClosestPointInputValues* inputValues)
61+
: m_DataStructure(dataStructure)
62+
, m_InputValues(inputValues)
63+
, m_ShouldCancel(shouldCancel)
64+
, m_MessageHandler(mesgHandler)
65+
{
66+
}
67+
68+
// -----------------------------------------------------------------------------
69+
IterativeClosestPoint::~IterativeClosestPoint() noexcept = default;
70+
71+
// -----------------------------------------------------------------------------
72+
void IterativeClosestPoint::updateProgress(const std::string& message)
73+
{
74+
m_MessageHandler(IFilter::Message::Type::Info, message);
75+
}
76+
77+
// -----------------------------------------------------------------------------
78+
const std::atomic_bool& IterativeClosestPoint::getCancel()
79+
{
80+
return m_ShouldCancel;
81+
}
82+
83+
// -----------------------------------------------------------------------------
84+
Result<> IterativeClosestPoint::operator()()
85+
{
86+
auto movingVertexGeom = m_DataStructure.getDataAs<VertexGeom>(m_InputValues->MovingVertexPath);
87+
auto targetVertexGeom = m_DataStructure.getDataAs<VertexGeom>(m_InputValues->TargetVertexPath);
88+
89+
if(movingVertexGeom == nullptr)
90+
{
91+
return MakeErrorResult(k_MissingVertices, fmt::format("Moving Vertex Geometry not found at path '{}'", m_InputValues->MovingVertexPath.toString()));
92+
}
93+
if(targetVertexGeom == nullptr)
94+
{
95+
return MakeErrorResult(k_MissingVertices, fmt::format("Target Vertex Geometry not found at path '{}'", m_InputValues->TargetVertexPath.toString()));
96+
}
97+
98+
if(movingVertexGeom->getVertices() == nullptr)
99+
{
100+
return MakeErrorResult(k_MissingVertices, fmt::format("Moving Vertex Geometry does not contain a vertex array"));
101+
}
102+
if(targetVertexGeom->getVertices() == nullptr)
103+
{
104+
return MakeErrorResult(k_MissingVertices, fmt::format("Target Vertex Geometry does not contain a vertex array"));
105+
}
106+
107+
Float32AbstractDataStore& movingStore = movingVertexGeom->getVertices()->getDataStoreRef();
108+
if(movingStore.getNumberOfTuples() == 0)
109+
{
110+
return MakeErrorResult(k_EmptyVertices, fmt::format("Moving Vertex Geometry does not contain any vertices"));
111+
}
112+
Float32AbstractDataStore& targetStore = targetVertexGeom->getVertices()->getDataStoreRef();
113+
if(targetStore.getNumberOfTuples() == 0)
114+
{
115+
return MakeErrorResult(k_EmptyVertices, fmt::format("Target Vertex Geometry does not contain any vertices"));
116+
}
117+
118+
std::vector<float32> movingVector(movingStore.begin(), movingStore.end());
119+
float32* movingCopyPtr = movingVector.data();
120+
DataStructure tmp;
121+
122+
usize numMovingVerts = movingVertexGeom->getNumberOfVertices();
123+
std::vector<float32> dynTarget(numMovingVerts * 3, 0.0F);
124+
float* dynTargetPtr = dynTarget.data();
125+
126+
using Adaptor = VertexGeomAdaptor<VertexGeom*>;
127+
const Adaptor adaptor(targetVertexGeom);
128+
129+
m_MessageHandler("Building kd-tree index...");
130+
131+
using KDtree = nanoflann::KDTreeSingleIndexAdaptor<nanoflann::L2_Adaptor<float32, Adaptor>, Adaptor, 3>;
132+
KDtree index(3, adaptor, nanoflann::KDTreeSingleIndexAdaptorParams(30));
133+
index.buildIndex();
134+
135+
const usize nn = 1;
136+
137+
typedef Eigen::Matrix<float, 3, Eigen::Dynamic, Eigen::ColMajor> PointCloud;
138+
typedef Eigen::Matrix<float, 4, 4, Eigen::ColMajor> UmeyamaTransform;
139+
140+
UmeyamaTransform globalTransform;
141+
globalTransform << 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1;
142+
143+
auto start = std::chrono::steady_clock::now();
144+
for(usize i = 0; i < m_InputValues->NumIterations; i++)
145+
{
146+
if(m_ShouldCancel)
147+
{
148+
return {};
149+
}
150+
151+
for(usize j = 0; j < numMovingVerts; j++)
152+
{
153+
usize identifier;
154+
float dist;
155+
nanoflann::KNNResultSet<float> results(nn);
156+
results.init(&identifier, &dist);
157+
index.findNeighbors(results, movingCopyPtr + (3 * j), nanoflann::SearchParams());
158+
dynTargetPtr[3 * j + 0] = targetStore[3 * identifier + 0];
159+
dynTargetPtr[3 * j + 1] = targetStore[3 * identifier + 1];
160+
dynTargetPtr[3 * j + 2] = targetStore[3 * identifier + 2];
161+
}
162+
163+
Eigen::Map<PointCloud> moving_(movingCopyPtr, 3, numMovingVerts);
164+
Eigen::Map<PointCloud> target_(dynTargetPtr, 3, numMovingVerts);
165+
166+
UmeyamaTransform transform = Eigen::umeyama(moving_, target_, false);
167+
168+
for(usize j = 0; j < numMovingVerts; j++)
169+
{
170+
Eigen::Vector4f position(movingCopyPtr[3 * j + 0], movingCopyPtr[3 * j + 1], movingCopyPtr[3 * j + 2], 1);
171+
Eigen::Vector4f transformedPosition = transform * position;
172+
std::memcpy(movingCopyPtr + (3 * j), transformedPosition.data(), sizeof(float) * 3);
173+
}
174+
// Update the global transform
175+
globalTransform = transform * globalTransform;
176+
177+
auto now = std::chrono::steady_clock::now();
178+
if(std::chrono::duration_cast<std::chrono::milliseconds>(now - start).count() > 1000)
179+
{
180+
m_MessageHandler(fmt::format("Performing Registration Iterations || {}% Completed", static_cast<int64>((static_cast<float>(i) / m_InputValues->NumIterations) * 100.0f)));
181+
start = now;
182+
}
183+
}
184+
185+
auto& transformStore = m_DataStructure.getDataAs<Float32Array>(m_InputValues->TransformArrayPath)->getDataStoreRef();
186+
187+
if(m_InputValues->ApplyTransformation)
188+
{
189+
for(usize j = 0; j < numMovingVerts; j++)
190+
{
191+
Eigen::Vector4f position(movingStore[3 * j + 0], movingStore[3 * j + 1], movingStore[3 * j + 2], 1);
192+
Eigen::Vector4f transformedPosition = globalTransform * position;
193+
for(usize k = 0; k < 3; k++)
194+
{
195+
movingStore[3 * j + k] = transformedPosition.data()[k];
196+
}
197+
}
198+
}
199+
200+
globalTransform.transposeInPlace();
201+
for(usize j = 0; j < 16; j++)
202+
{
203+
transformStore[j] = globalTransform.data()[j];
204+
}
205+
206+
return {};
207+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#pragma once
2+
3+
#include "SimplnxCore/SimplnxCore_export.hpp"
4+
5+
#include "simplnx/DataStructure/DataPath.hpp"
6+
#include "simplnx/DataStructure/DataStructure.hpp"
7+
#include "simplnx/Filter/IFilter.hpp"
8+
9+
namespace nx::core
10+
{
11+
12+
struct SIMPLNXCORE_EXPORT IterativeClosestPointInputValues
13+
{
14+
bool ApplyTransformation;
15+
uint64 NumIterations;
16+
DataPath MovingVertexPath;
17+
DataPath TargetVertexPath;
18+
DataPath TransformArrayPath;
19+
};
20+
21+
/**
22+
* @class
23+
*/
24+
class SIMPLNXCORE_EXPORT IterativeClosestPoint
25+
{
26+
public:
27+
IterativeClosestPoint(DataStructure& dataStructure, const IFilter::MessageHandler& mesgHandler, const std::atomic_bool& shouldCancel, IterativeClosestPointInputValues* inputValues);
28+
~IterativeClosestPoint() noexcept;
29+
30+
IterativeClosestPoint(const IterativeClosestPoint&) = delete;
31+
IterativeClosestPoint(IterativeClosestPoint&&) noexcept = delete;
32+
IterativeClosestPoint& operator=(const IterativeClosestPoint&) = delete;
33+
IterativeClosestPoint& operator=(IterativeClosestPoint&&) noexcept = delete;
34+
35+
Result<> operator()();
36+
void updateProgress(const std::string& message);
37+
const std::atomic_bool& getCancel();
38+
39+
private:
40+
DataStructure& m_DataStructure;
41+
const IterativeClosestPointInputValues* m_InputValues = nullptr;
42+
const std::atomic_bool& m_ShouldCancel;
43+
const IFilter::MessageHandler& m_MessageHandler;
44+
};
45+
46+
} // namespace nx::core

0 commit comments

Comments
 (0)