Skip to content

Commit 21b7d76

Browse files
committed
enh: robustify parallelization argument passing
1 parent d816f7c commit 21b7d76

File tree

4 files changed

+63
-18
lines changed

4 files changed

+63
-18
lines changed

src/nifreeze/cli/parser.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ def build_parser() -> ArgumentParser:
9191
)
9292
parser.add_argument(
9393
"--nthreads",
94+
"--omp-nthreads",
95+
"--ncpus",
9496
action="store",
9597
type=int,
9698
default=None,

src/nifreeze/data/base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,16 @@ def __len__(self) -> int:
9797

9898
return self.dataobj.shape[-1]
9999

100+
@property
101+
def shape3d(self):
102+
"""Get the shape of the 3D volume."""
103+
return self.dataobj.shape[:3]
104+
105+
@property
106+
def size3d(self):
107+
"""Get the number of voxels in the 3D volume."""
108+
return np.prod(self.dataobj.shape[:3])
109+
100110
def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[Unpack[Ts]]:
101111
return () # type: ignore[return-value]
102112

src/nifreeze/estimator.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from os import cpu_count
2828
from pathlib import Path
2929
from tempfile import TemporaryDirectory
30+
from timeit import default_timer as timer
3031
from typing import TypeVar
3132

3233
from tqdm import tqdm
@@ -42,7 +43,9 @@
4243

4344
DatasetT = TypeVar("DatasetT", bound=BaseDataset)
4445

46+
DEFAULT_CHUNK_SIZE: int = int(1e6)
4547
FIT_MSG = "Fit&predict"
48+
PRE_MSG = "Predicted"
4649
REG_MSG = "Realign"
4750

4851

@@ -109,13 +112,20 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
109112
dataset = result # type: ignore[assignment]
110113

111114
n_jobs = kwargs.pop("n_jobs", None) or min(cpu_count() or 1, 8)
115+
n_threads = kwargs.pop("omp_nthreads", None) or ((cpu_count() or 2) - 1)
116+
117+
num_voxels = dataset.brainmask.sum() if dataset.brainmask is not None else dataset.size3d
118+
chunk_size = DEFAULT_CHUNK_SIZE * (n_threads or 1)
112119

113120
# Prepare iterator
114121
iterfunc = getattr(iterators, f"{self._strategy}_iterator")
115122
index_iter = iterfunc(len(dataset), seed=kwargs.get("seed", None))
116123

117124
# Initialize model
118125
if isinstance(self._model, str):
126+
if self._model.endswith("dti"):
127+
self._model_kwargs["step"] = chunk_size
128+
119129
# Factory creates the appropriate model and pipes arguments
120130
model = ModelFactory.init(
121131
model=self._model,
@@ -125,10 +135,25 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
125135
else:
126136
model = self._model
127137

138+
fit_pred_kwargs = {
139+
"n_jobs": n_jobs,
140+
"omp_nthreads": n_threads,
141+
}
142+
143+
if model.__class__.__name__ == "DTIModel":
144+
fit_pred_kwargs["step"] = chunk_size
145+
146+
print(f"Dataset size: {num_voxels}x{len(dataset)}.")
147+
print(f"Parallel execution: {fit_pred_kwargs}.")
148+
print(f"Model: {model}.")
149+
128150
if self._single_fit:
129-
model.fit_predict(None, n_jobs=n_jobs)
151+
print("Fitting 'single' model started ...")
152+
start = timer()
153+
model.fit_predict(None, **fit_pred_kwargs)
154+
print(f"Fitting 'single' model finished, ellapsed {timer() - start}s.")
130155

131-
kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None)
156+
kwargs["num_threads"] = n_threads
132157
kwargs = self._align_kwargs | kwargs
133158

134159
dataset_length = len(dataset)
@@ -151,15 +176,16 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
151176
pbar.set_description_str(f"{FIT_MSG: <16} vol. <{i}>")
152177

153178
# fit the model
154-
test_set = dataset[i]
155179
predicted = model.fit_predict( # type: ignore[union-attr]
156180
i,
157-
n_jobs=n_jobs,
181+
**fit_pred_kwargs,
158182
)
159183

184+
pbar.set_description_str(f"{PRE_MSG: <16} vol. <{i}>")
185+
160186
# prepare data for running ANTs
161187
predicted_path, volume_path, init_path = _prepare_registration_data(
162-
test_set[0],
188+
dataset[i][0], # Access the target volume
163189
predicted,
164190
dataset.affine,
165191
i,

src/nifreeze/model/dmri.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(self, dataset: DWI, **kwargs):
6767

6868
super().__init__(dataset, **kwargs)
6969

70-
def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
70+
def _fit(self, index: int | None = None, n_jobs=None, omp_nthreads=None, **kwargs):
7171
"""Fit the model chunk-by-chunk asynchronously"""
7272

7373
if self._locked_fit is not None:
@@ -97,20 +97,20 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
9797
class_name,
9898
)(gtab, **kwargs)
9999

100+
fitargs = {"engine": "ray", "n_jobs": n_jobs} if n_jobs > 1 else {}
101+
102+
if "step" in kwargs:
103+
fitargs["step"] = kwargs["step"]
104+
100105
try:
101-
self._model_fit = model.fit(
102-
data,
103-
engine="serial" if n_jobs == 1 else "joblib",
104-
n_jobs=n_jobs,
105-
)
106+
self._model_fit = model.fit(data, **fitargs)
106107
except TypeError:
107108
from nifreeze.model._dipy import multi_fit
108109

109110
self._model_fit = multi_fit(
110111
model,
111112
data,
112-
engine="serial" if n_jobs == 1 else "ray",
113-
n_jobs=n_jobs,
113+
**fitargs,
114114
)
115115
return n_jobs
116116

@@ -125,40 +125,47 @@ def fit_predict(self, index: int | None = None, **kwargs):
125125
126126
"""
127127

128+
omp_nthreads = kwargs.pop("omp_nthreads", None)
129+
n_jobs = kwargs.pop("n_jobs", None)
130+
131+
brainmask = self._dataset.brainmask
128132
self._fit(
129133
index,
130-
n_jobs=kwargs.pop("n_jobs"),
134+
n_jobs=n_jobs,
135+
omp_nthreads=omp_nthreads,
131136
**kwargs,
132137
)
133138

134139
if index is None:
135140
self._locked_fit = True
136141
return None
137142

138-
brainmask = self._dataset.brainmask
143+
# Prepare gradient(s) for simulation
139144
gradient = self._dataset.gradients[:, index]
140-
141145
if "dipy" in getattr(self, "_model_class", ""):
142146
gradient = gradient_table_from_bvals_bvecs(
143147
gradient[np.newaxis, -1], gradient[np.newaxis, :-1]
144148
)
145149

150+
# Prepare the b=0
146151
S0 = self._dataset.bzero
147152
if S0 is not None:
148153
S0 = S0[brainmask, ...] if brainmask is not None else S0.reshape(-1)
149154

155+
# Prediction
150156
predicted = np.squeeze(
151157
self._model_fit.predict(
152158
gtab=gradient,
153159
S0=S0,
160+
**kwargs,
154161
)
155162
)
156163

157164
if brainmask is not None:
158-
retval = np.zeros_like(brainmask, dtype="float32")
165+
retval = np.zeros(self._dataset.shape3d, dtype=self._dataset.dataobj.dtype)
159166
retval[brainmask, ...] = predicted
160167
else:
161-
retval = predicted.reshape(self._dataset.dataobj.shape[:-1])
168+
retval = predicted.reshape(self._dataset.shape3d)
162169

163170
return retval
164171

0 commit comments

Comments
 (0)