
Commit e1f35be

enh: robustify parallelization argument passing

Parent: d615ae1

File tree: 4 files changed (+60, -16 lines)

src/nifreeze/cli/parser.py
src/nifreeze/data/base.py
src/nifreeze/estimator.py
src/nifreeze/model/dmri.py

src/nifreeze/cli/parser.py
Lines changed: 2 additions & 0 deletions

@@ -91,6 +91,8 @@ def build_parser() -> ArgumentParser:
     )
     parser.add_argument(
         "--nthreads",
+        "--omp-nthreads",
+        "--ncpus",
         action="store",
         type=int,
         default=None,

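With several option strings on one argument, argparse derives the destination from the first long flag, so "--omp-nthreads" and "--ncpus" become aliases of "--nthreads". A minimal standalone sketch of that behavior (illustration only, not code from the repository):

from argparse import ArgumentParser

# Minimal sketch: three option strings, one destination.
# argparse takes the attribute name from the first long flag: "--nthreads" -> args.nthreads.
parser = ArgumentParser()
parser.add_argument(
    "--nthreads",
    "--omp-nthreads",
    "--ncpus",
    action="store",
    type=int,
    default=None,
)

assert parser.parse_args(["--ncpus", "4"]).nthreads == 4
assert parser.parse_args(["--omp-nthreads", "8"]).nthreads == 8
assert parser.parse_args([]).nthreads is None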
src/nifreeze/data/base.py
Lines changed: 10 additions & 0 deletions

@@ -97,6 +97,16 @@ def __len__(self) -> int:

         return self.dataobj.shape[-1]

+    @property
+    def shape3d(self):
+        """Get the shape of the 3D volume."""
+        return self.dataobj.shape[:3]
+
+    @property
+    def size3d(self):
+        """Get the number of voxels in the 3D volume."""
+        return np.prod(self.dataobj.shape[:3])
+
     def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[Unpack[Ts]]:
         return ()  # type: ignore[return-value]

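The two new properties only expose the spatial extent of the data array; a quick illustration of what they return (the array shape below is made up for the example):

import numpy as np

# Illustrative stand-in for BaseDataset.dataobj: a 4D array of (x, y, z, volumes).
dataobj = np.zeros((64, 64, 36, 10))

shape3d = dataobj.shape[:3]           # what .shape3d returns -> (64, 64, 36)
size3d = np.prod(dataobj.shape[:3])   # what .size3d returns  -> 147456 voxels

assert size3d == np.prod(shape3d)
assert dataobj.shape[-1] == 10        # __len__ still counts the volumes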
src/nifreeze/estimator.py
Lines changed: 31 additions & 5 deletions

@@ -27,6 +27,7 @@
 from os import cpu_count
 from pathlib import Path
 from tempfile import TemporaryDirectory
+from timeit import default_timer as timer
 from typing import TypeVar

 from tqdm import tqdm

@@ -42,7 +43,9 @@

 DatasetT = TypeVar("DatasetT", bound=BaseDataset)

+DEFAULT_CHUNK_SIZE: int = int(1e6)
 FIT_MSG = "Fit&predict"
+PRE_MSG = "Predicted"
 REG_MSG = "Realign"


@@ -109,13 +112,20 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
             dataset = result  # type: ignore[assignment]

         n_jobs = kwargs.pop("n_jobs", None) or min(cpu_count() or 1, 8)
+        n_threads = kwargs.pop("omp_nthreads", None) or ((cpu_count() or 2) - 1)
+
+        num_voxels = dataset.brainmask.sum() if dataset.brainmask is not None else dataset.size3d
+        chunk_size = DEFAULT_CHUNK_SIZE * (n_threads or 1)

         # Prepare iterator
         iterfunc = getattr(iterators, f"{self._strategy}_iterator")
         index_iter = iterfunc(len(dataset), seed=kwargs.get("seed", None))

         # Initialize model
         if isinstance(self._model, str):
+            if self._model.endswith("dti"):
+                self._model_kwargs["step"] = chunk_size
+
             # Factory creates the appropriate model and pipes arguments
             model = ModelFactory.init(
                 model=self._model,

@@ -125,10 +135,25 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
         else:
             model = self._model

+        fit_pred_kwargs = {
+            "n_jobs": n_jobs,
+            "omp_nthreads": n_threads,
+        }
+
+        if model.__class__.__name__ == "DTIModel":
+            fit_pred_kwargs["step"] = chunk_size
+
+        print(f"Dataset size: {num_voxels}x{len(dataset)}.")
+        print(f"Parallel execution: {fit_pred_kwargs}.")
+        print(f"Model: {model}.")
+
         if self._single_fit:
-            model.fit_predict(None, n_jobs=n_jobs)
+            print("Fitting 'single' model started ...")
+            start = timer()
+            model.fit_predict(None, **fit_pred_kwargs)
+            print(f"Fitting 'single' model finished, elapsed {timer() - start}s.")

-        kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None)
+        kwargs["num_threads"] = n_threads
         kwargs = self._align_kwargs | kwargs

         dataset_length = len(dataset)

@@ -151,15 +176,16 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
             pbar.set_description_str(f"{FIT_MSG: <16} vol. <{i}>")

             # fit the model
-            test_set = dataset[i]
             predicted = model.fit_predict(  # type: ignore[union-attr]
                 i,
-                n_jobs=n_jobs,
+                **fit_pred_kwargs,
             )

+            pbar.set_description_str(f"{PRE_MSG: <16} vol. <{i}>")
+
             # prepare data for running ANTs
             predicted_path, volume_path, init_path = _prepare_registration_data(
-                test_set[0],
+                dataset[i][0],  # Access the target volume
                 predicted,
                 dataset.affine,
                 i,

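Taken together, the hunks above resolve three values: n_jobs for process-level workers, n_threads for thread-level parallelism (also forwarded as num_threads to the alignment step), and chunk_size for how many voxels each DTI fitting chunk receives. A standalone sketch of that resolution logic, restated from the diff (the helper name and the 8-core figures are illustrative, not part of the commit):

from os import cpu_count

DEFAULT_CHUNK_SIZE: int = int(1e6)


def resolve_parallel_args(n_jobs=None, omp_nthreads=None):
    """Restate the resolution logic of Estimator.run() as a standalone helper."""
    n_jobs = n_jobs or min(cpu_count() or 1, 8)
    n_threads = omp_nthreads or ((cpu_count() or 2) - 1)
    chunk_size = DEFAULT_CHUNK_SIZE * (n_threads or 1)
    return n_jobs, n_threads, chunk_size


# On a hypothetical 8-core machine with no explicit arguments:
#   n_jobs     -> min(8, 8) = 8            (process-level workers)
#   n_threads  -> 8 - 1 = 7                (threads for the model and registration)
#   chunk_size -> 1_000_000 * 7 voxels     (passed to the DTI model as "step")
print(resolve_parallel_args())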
src/nifreeze/model/dmri.py
Lines changed: 17 additions & 11 deletions

@@ -93,7 +93,7 @@ def __init__(self, dataset: DWI, max_b: float | int | None = None, **kwargs):

         super().__init__(dataset, **kwargs)

-    def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
+    def _fit(self, index: int | None = None, n_jobs=None, omp_nthreads=None, **kwargs):
         """Fit the model chunk-by-chunk asynchronously"""

         if self._locked_fit is not None:

@@ -123,20 +123,20 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
             class_name,
         )(gtab, **kwargs)

+        fitargs = {"engine": "ray", "n_jobs": n_jobs} if n_jobs > 1 else {}
+
+        if "step" in kwargs:
+            fitargs["step"] = kwargs["step"]
+
         try:
-            self._model_fit = model.fit(
-                data,
-                engine="serial" if n_jobs == 1 else "joblib",
-                n_jobs=n_jobs,
-            )
+            self._model_fit = model.fit(data, **fitargs)
         except TypeError:
             from nifreeze.model._dipy import multi_fit

             self._model_fit = multi_fit(
                 model,
                 data,
-                engine="serial" if n_jobs == 1 else "ray",
-                n_jobs=n_jobs,
+                **fitargs,
             )
         return n_jobs

@@ -151,27 +151,33 @@ def fit_predict(self, index: int | None = None, **kwargs):

         """

+        omp_nthreads = kwargs.pop("omp_nthreads", None)
+        n_jobs = kwargs.pop("n_jobs", None)
+
+        brainmask = self._dataset.brainmask
         self._fit(
             index,
-            n_jobs=kwargs.pop("n_jobs"),
+            n_jobs=n_jobs,
+            omp_nthreads=omp_nthreads,
             **kwargs,
         )

         if index is None:
             self._locked_fit = True
             return None

+        # Prepare gradient(s) for simulation
         gradient = self._dataset.gradients[:, index]
-
         if "dipy" in getattr(self, "_model_class", ""):
             gradient = gradient_table_from_bvals_bvecs(
                 gradient[np.newaxis, -1], gradient[np.newaxis, :-1]
             )
-
+        # Prediction
         predicted = np.squeeze(
             self._model_fit.predict(
                 gtab=gradient,
                 S0=self._S0,
+                **kwargs,
             )
         )

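The fitargs construction replaces the previous per-call engine arguments: serial runs pass no extra keywords, while parallel runs request the "ray" engine and forward the chunk size. A hypothetical standalone restatement for clarity (the helper name is not part of the codebase):

def build_fit_args(n_jobs, step=None):
    """Hypothetical helper restating the fitargs logic in the _fit() hunk above.

    Serial execution adds no keywords; parallel execution requests the "ray"
    engine and forwards the chunk size ("step") when one was provided.
    """
    fitargs = {"engine": "ray", "n_jobs": n_jobs} if n_jobs > 1 else {}
    if step is not None:
        fitargs["step"] = step
    return fitargs


print(build_fit_args(1))                  # {}
print(build_fit_args(8, step=7_000_000))  # {'engine': 'ray', 'n_jobs': 8, 'step': 7000000}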