Skip to content

Commit fae5b7a

Browse files
author
Simon Klix
committed
added python bindings for edge features
1 parent 07b8022 commit fae5b7a

File tree

1 file changed

+174
-0
lines changed

1 file changed

+174
-0
lines changed

plugins/machine_learning/python/python_bindings.cpp

+174
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ namespace hal
3030
#endif // ifdef PYBIND11_MODULE
3131

3232
// Define submodules for namespaces
33+
py::module py_edge_feature = m.def_submodule("edge_feature");
3334
py::module py_gate_feature = m.def_submodule("gate_feature");
3435
py::module py_gate_pair_feature = m.def_submodule("gate_pair_feature");
3536
py::module py_gate_pair_label = m.def_submodule("gate_pair_label");
@@ -280,6 +281,12 @@ Get the gates of the context.
280281
:type: tuple[list[int], list[int]]
281282
)");
282283

284+
py_netlist_graph.def_readwrite("edge_features", &machine_learning::NetlistGraph::edge_features, R"(
285+
List of features corresponding to the edge list
286+
287+
:type: list[list[FEATURE_TYPE]]
288+
)");
289+
283290
py_netlist_graph.def_readwrite("direction", &machine_learning::NetlistGraph::direction, R"(
284291
Direction of the graph.
285292
@@ -2384,6 +2391,173 @@ Convert the NetNameKeyWord labeler to a string.
23842391
23852392
:returns: The string representation.
23862393
:rtype: str
2394+
)");
2395+
2396+
py::class_<machine_learning::edge_feature::EdgeFeature, RawPtrWrapper<machine_learning::edge_feature::EdgeFeature>> py_edge_feature_class(py_edge_feature,
2397+
"EdgeFeature",
2398+
R"(
2399+
Abstract base class representing a feature for an edge between two endpoints.
2400+
)");
2401+
2402+
py_edge_feature_class.def(
2403+
"calculate_feature",
2404+
[](machine_learning::edge_feature::EdgeFeature& self, machine_learning::Context& ctx, const Endpoint* source, const Endpoint* destination) -> std::optional<std::vector<FEATURE_TYPE>> {
2405+
auto res = self.calculate_feature(ctx, source, destination);
2406+
if (res.is_ok())
2407+
{
2408+
return res.get();
2409+
}
2410+
else
2411+
{
2412+
log_error("python_context", "Error in EdgeFeature::calculate_feature:\n{}", res.get_error().get());
2413+
return std::nullopt;
2414+
}
2415+
},
2416+
py::arg("ctx"),
2417+
py::arg("source"),
2418+
py::arg("destination"),
2419+
R"(
2420+
Calculate the feature vector for the given edge.
2421+
2422+
:param hal_py.Context ctx: The context.
2423+
:param hal_py.Endpoint source: The source endpoint.
2424+
:param hal_py.Endpoint destination: The destination endpoint.
2425+
:returns: The feature vector or None.
2426+
:rtype: list[hal_py.FEATURE_TYPE] or None
2427+
)");
2428+
2429+
py_edge_feature_class.def("to_string",
2430+
&machine_learning::edge_feature::EdgeFeature::to_string,
2431+
R"(
2432+
Get a string representation of this feature.
2433+
2434+
:returns: A string.
2435+
:rtype: str
2436+
)");
2437+
2438+
py::class_<machine_learning::edge_feature::PinTypesOnehot, RawPtrWrapper<machine_learning::edge_feature::PinTypesOnehot>, machine_learning::edge_feature::EdgeFeature> py_pin_types_onehot(
2439+
py_edge_feature,
2440+
"PinTypesOnehot",
2441+
R"(
2442+
One-hot encoding feature for pin types on an edge.
2443+
)");
2444+
2445+
py_pin_types_onehot.def(py::init<>(), R"(
2446+
Construct a PinTypesOnehot feature.
2447+
)");
2448+
2449+
py_pin_types_onehot.def(
2450+
"calculate_feature",
2451+
[](machine_learning::edge_feature::PinTypesOnehot& self, machine_learning::Context& ctx, const Endpoint* source, const Endpoint* destination) -> std::optional<std::vector<FEATURE_TYPE>> {
2452+
auto res = self.calculate_feature(ctx, source, destination);
2453+
if (res.is_ok())
2454+
{
2455+
return res.get();
2456+
}
2457+
else
2458+
{
2459+
log_error("python_context", "Error in PinTypesOnehot::calculate_feature:\n{}", res.get_error().get());
2460+
return std::nullopt;
2461+
}
2462+
},
2463+
py::arg("ctx"),
2464+
py::arg("source"),
2465+
py::arg("destination"),
2466+
R"(
2467+
One-hot encoding feature vector for pin types.
2468+
2469+
:param hal_py.Context ctx: The context.
2470+
:param hal_py.Endpoint source: The source endpoint.
2471+
:param hal_py.Endpoint destination: The destination endpoint.
2472+
:returns: The feature vector or None.
2473+
:rtype: list[hal_py.FEATURE_TYPE] or None
2474+
)");
2475+
2476+
py_pin_types_onehot.def("to_string",
2477+
&machine_learning::edge_feature::PinTypesOnehot::to_string,
2478+
R"(
2479+
Get a string representation.
2480+
2481+
:returns: A string.
2482+
:rtype: str
2483+
)");
2484+
2485+
py::class_<machine_learning::edge_feature::PinDirectionOnehot, RawPtrWrapper<machine_learning::edge_feature::PinDirectionOnehot>, machine_learning::edge_feature::EdgeFeature>
2486+
py_pin_direction_onehot(py_edge_feature,
2487+
"PinDirectionOnehot",
2488+
R"(
2489+
One-hot encoding feature for pin directions on an edge.
2490+
)");
2491+
2492+
py_pin_direction_onehot.def(py::init<>(), R"(
2493+
Construct a PinDirectionOnehot feature.
2494+
)");
2495+
2496+
py_pin_direction_onehot.def(
2497+
"calculate_feature",
2498+
[](machine_learning::edge_feature::PinDirectionOnehot& self, machine_learning::Context& ctx, const Endpoint* source, const Endpoint* destination)
2499+
-> std::optional<std::vector<FEATURE_TYPE>> {
2500+
auto res = self.calculate_feature(ctx, source, destination);
2501+
if (res.is_ok())
2502+
{
2503+
return res.get();
2504+
}
2505+
else
2506+
{
2507+
log_error("python_context", "Error in PinDirectionOnehot::calculate_feature:\n{}", res.get_error().get());
2508+
return std::nullopt;
2509+
}
2510+
},
2511+
py::arg("ctx"),
2512+
py::arg("source"),
2513+
py::arg("destination"),
2514+
R"(
2515+
One-hot encoding feature vector for pin directions.
2516+
2517+
:param hal_py.Context ctx: The context.
2518+
:param hal_py.Endpoint source: The source endpoint.
2519+
:param hal_py.Endpoint destination: The destination endpoint.
2520+
:returns: The feature vector or None.
2521+
:rtype: list[hal_py.FEATURE_TYPE] or None
2522+
)");
2523+
2524+
py_pin_direction_onehot.def("to_string",
2525+
&machine_learning::edge_feature::PinDirectionOnehot::to_string,
2526+
R"(
2527+
Get a string representation.
2528+
2529+
:returns: A string.
2530+
:rtype: str
2531+
)");
2532+
2533+
py_edge_feature.def(
2534+
"build_feature_vec",
2535+
[](machine_learning::Context& ctx, const std::vector<const machine_learning::edge_feature::EdgeFeature*>& features, const Endpoint* source, const Endpoint* destination)
2536+
-> std::optional<std::vector<FEATURE_TYPE>> {
2537+
auto res = machine_learning::edge_feature::build_feature_vec(ctx, features, source, destination);
2538+
if (res.is_ok())
2539+
{
2540+
return res.get();
2541+
}
2542+
else
2543+
{
2544+
log_error("python_context", "Error in build_feature_vec:\n{}", res.get_error().get());
2545+
return std::nullopt;
2546+
}
2547+
},
2548+
py::arg("ctx"),
2549+
py::arg("features"),
2550+
py::arg("source"),
2551+
py::arg("destination"),
2552+
R"(
2553+
Build a combined feature vector from multiple features.
2554+
2555+
:param hal_py.Context ctx: The context.
2556+
:param list[hal_py.edge_feature.EdgeFeature] features: A list of edge features.
2557+
:param hal_py.Endpoint source: The source endpoint.
2558+
:param hal_py.Endpoint destination: The destination endpoint.
2559+
:returns: The combined feature vector or None.
2560+
:rtype: list[hal_py.FEATURE_TYPE] or None
23872561
)");
23882562

23892563
#ifndef PYBIND11_MODULE

0 commit comments

Comments
 (0)