Skip to content

Commit adb51aa

Browse files
committed
update pyo3
1 parent 80472e5 commit adb51aa

File tree

6 files changed

+44
-56
lines changed

6 files changed

+44
-56
lines changed

Cargo.lock

Lines changed: 15 additions & 33 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "tiny-solver"
3-
version = "0.4.1"
3+
version = "0.5.0"
44
edition = "2021"
55
authors = ["Powei Lin <[email protected]>"]
66
readme = "README.md"
@@ -19,19 +19,19 @@ faer = "0.18.2"
1919
faer-ext = { version = "0.1.0", features = ["nalgebra"] }
2020
log = "0.4.21"
2121
nalgebra = "0.32.4"
22-
num-dual = "0.8.1"
22+
num-dual = "0.9.0"
2323
num-traits = "0.2.18"
24-
numpy = { version = "0.20.0", features = ["nalgebra"], optional = true }
25-
pyo3 = { version = "0.20.3", features = ["abi3", "abi3-py38"] }
26-
pyo3-log = { version = "0.9.0", optional = true }
24+
numpy = { version = "0.21.0", features = ["nalgebra"], optional = true }
25+
pyo3 = { version = "0.21.0", features = ["abi3", "abi3-py38"] }
26+
# pyo3-log = { version = "0.9.0", optional = true }
2727
rayon = "1.9.0"
2828

2929
[[example]]
3030
name = "m3500_benchmark"
3131
path = "examples/m3500_benchmark.rs"
3232

3333
[features]
34-
python = ["num-dual/python", "numpy", "pyo3-log"]
34+
python = ["num-dual/python", "numpy"]
3535

3636
[dev-dependencies]
3737
env_logger = "0.11.3"

src/python/mod.rs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,36 +10,40 @@ mod py_optimizer;
1010
mod py_problem;
1111
use self::py_factors::*;
1212

13-
fn register_child_module(py: Python<'_>, parent_module: &PyModule) -> PyResult<()> {
13+
fn register_child_module(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
1414
// For factors submodule
15-
let factors_module = PyModule::new(py, "factors")?;
15+
let factors_module = PyModule::new_bound(parent_module.py(), "factors")?;
1616
factors_module.add_class::<BetweenFactorSE2>()?;
1717
factors_module.add_class::<PriorFactor>()?;
1818
factors_module.add_class::<PyFactor>()?;
19-
parent_module.add_submodule(factors_module)?;
20-
py.import("sys")?
19+
parent_module.add_submodule(&factors_module)?;
20+
parent_module
21+
.py()
22+
.import_bound("sys")?
2123
.getattr("modules")?
2224
.set_item("tiny_solver.factors", factors_module)?;
2325

24-
let loss_functions_module = PyModule::new(py, "loss_functions")?;
26+
let loss_functions_module = PyModule::new_bound(parent_module.py(), "loss_functions")?;
2527
loss_functions_module.add_class::<HuberLoss>()?;
26-
parent_module.add_submodule(loss_functions_module)?;
27-
py.import("sys")?
28+
parent_module.add_submodule(&loss_functions_module)?;
29+
parent_module
30+
.py()
31+
.import_bound("sys")?
2832
.getattr("modules")?
2933
.set_item("tiny_solver.loss_functions", loss_functions_module)?;
3034
Ok(())
3135
}
3236

3337
/// A Python module implemented in Rust.
3438
#[pymodule]
35-
pub fn tiny_solver<'py>(_py: Python<'py>, m: &'py PyModule) -> PyResult<()> {
36-
pyo3_log::init();
39+
pub fn tiny_solver(m: &Bound<'_, PyModule>) -> PyResult<()> {
40+
// pyo3_log::init();
3741
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
3842
m.add_class::<Problem>()?;
3943
m.add_class::<LinearSolver>()?;
4044
m.add_class::<OptimizerOptions>()?;
4145
m.add_class::<GaussNewtonOptimizer>()?;
42-
register_child_module(_py, m)?;
46+
register_child_module(m)?;
4347

4448
Ok(())
4549
}

src/python/py_factors.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ impl Factor for PyFactor {
6262
})
6363
.map(|x| x.into_py(py))
6464
.collect();
65-
let args = PyTuple::new(py, py_params);
65+
let args = PyTuple::new_bound(py, py_params);
6666
let result = self.func.call1(py, args);
6767
let residual_py = result.unwrap().extract::<Vec<PyDual64Dyn>>(py);
6868
residual_py

src/python/py_optimizer.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ impl GaussNewtonOptimizer {
2323
&self,
2424
py: Python<'_>,
2525
problem: &Problem,
26-
initial_values: &PyDict,
26+
initial_values: &Bound<'_, PyDict>,
2727
optimizer_options: Option<OptimizerOptions>,
2828
) -> PyResult<HashMap<String, Py<PyArray2<f64>>>> {
2929
let init_values: HashMap<String, PyReadonlyArray1<f64>> = initial_values.extract().unwrap();
@@ -35,7 +35,7 @@ impl GaussNewtonOptimizer {
3535

3636
let output_d: HashMap<String, Py<PyArray2<f64>>> = result
3737
.iter()
38-
.map(|(k, v)| (k.to_string(), v.to_pyarray(py).to_owned().into()))
38+
.map(|(k, v)| (k.to_string(), v.to_pyarray_bound(py).to_owned().into()))
3939
.collect();
4040
Ok(output_d)
4141
}

src/python/py_problem.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::problem::Problem;
66

77
use super::PyFactor;
88

9-
fn convert_pyany_to_factor(py_any: &PyAny) -> PyResult<(bool, Box<dyn Factor + Send>)> {
9+
fn convert_pyany_to_factor(py_any: &Bound<'_, PyAny>) -> PyResult<(bool, Box<dyn Factor + Send>)> {
1010
let factor_name: String = py_any.get_type().getattr("__name__")?.extract()?;
1111
match factor_name.as_str() {
1212
"BetweenFactorSE2" => {
@@ -26,7 +26,9 @@ fn convert_pyany_to_factor(py_any: &PyAny) -> PyResult<(bool, Box<dyn Factor + S
2626
)),
2727
}
2828
}
29-
fn convert_pyany_to_loss_function(py_any: &PyAny) -> PyResult<Option<Box<dyn Loss + Send>>> {
29+
fn convert_pyany_to_loss_function(
30+
py_any: &Bound<'_, PyAny>,
31+
) -> PyResult<Option<Box<dyn Loss + Send>>> {
3032
let factor_name: String = py_any.get_type().getattr("__name__")?.extract()?;
3133
match factor_name.as_str() {
3234
"HuberLoss" => {
@@ -52,8 +54,8 @@ impl Problem {
5254
&mut self,
5355
dim_residual: usize,
5456
variable_key_size_list: Vec<(String, usize)>,
55-
pyfactor: &PyAny,
56-
pyloss_func: &PyAny,
57+
pyfactor: &Bound<'_, PyAny>,
58+
pyloss_func: &Bound<'_, PyAny>,
5759
) -> PyResult<()> {
5860
let (is_pyfactor, factor) = convert_pyany_to_factor(pyfactor).unwrap();
5961
self.add_residual_block(

0 commit comments

Comments
 (0)