Skip to content

Commit d513d05

Browse files
committed
enh: improve general handling of parallelization
Brings in the improvements to parallelization management from #142, so that they are kept even if we ultimately decide against #142. In particular, it makes it possible to set the ``step`` parameter of DTI models. It also extends the base data object with two convenience properties for retrieving the 3D shape of the data and the number of voxels per volume. X-References: #142.
1 parent d16c97c commit d513d05

File tree

5 files changed

+70
-15
lines changed

5 files changed

+70
-15
lines changed

src/nifreeze/cli/parser.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ def build_parser() -> ArgumentParser:
9191
)
9292
parser.add_argument(
9393
"--nthreads",
94+
"--omp-nthreads",
95+
"--ncpus",
9496
action="store",
9597
type=int,
9698
default=None,

src/nifreeze/data/base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,16 @@ def __getitem__(
127127
affine = self.motion_affines[idx] if self.motion_affines is not None else None
128128
return self.dataobj[..., idx], affine, *self._getextra(idx)
129129

130+
@property
def shape3d(self):
    """
    Shape of a single 3D volume.

    Returns
    -------
    The first three entries of ``self.dataobj.shape``; any trailing
    axes are dropped.

    """
    spatial_shape = self.dataobj.shape[:3]
    return spatial_shape
134+
135+
@property
def size3d(self) -> int:
    """
    Number of voxels in a single 3D volume.

    Returns
    -------
    int
        Product of the first three entries of ``self.dataobj.shape``.

    """
    # ``np.prod`` yields a NumPy scalar (and a float ``1.0`` for an empty
    # shape); coerce so callers always receive a plain Python ``int``.
    return int(np.prod(self.dataobj.shape[:3]))
139+
130140
@classmethod
131141
def from_filename(cls, filename: Path | str) -> Self:
132142
"""

src/nifreeze/estimator.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from os import cpu_count
2828
from pathlib import Path
2929
from tempfile import TemporaryDirectory
30+
from timeit import default_timer as timer
3031
from typing import TypeVar
3132

3233
from tqdm import tqdm
@@ -42,6 +43,7 @@
4243

4344
DatasetT = TypeVar("DatasetT", bound=BaseDataset)
4445

46+
DEFAULT_CHUNK_SIZE: int = int(1e6)
4547
FIT_MSG = "Fit&predict"
4648
REG_MSG = "Realign"
4749

@@ -109,13 +111,24 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
109111
dataset = result # type: ignore[assignment]
110112

111113
n_jobs = kwargs.pop("n_jobs", None) or min(cpu_count() or 1, 8)
114+
n_threads = kwargs.pop("omp_nthreads", None) or ((cpu_count() or 2) - 1)
115+
116+
num_voxels = dataset.brainmask.sum() if dataset.brainmask is not None else dataset.size3d
117+
chunk_size = DEFAULT_CHUNK_SIZE * (n_threads or 1)
112118

113119
# Prepare iterator
114120
iterfunc = getattr(iterators, f"{self._strategy}_iterator")
115121
index_iter = iterfunc(len(dataset), seed=kwargs.get("seed", None))
116122

117123
# Initialize model
118124
if isinstance(self._model, str):
125+
if self._model.endswith("dti"):
126+
self._model_kwargs["step"] = chunk_size
127+
128+
# Example: change model parameters only for DKI
129+
# if self._model.endswith("dki"):
130+
# self._model_kwargs["fit_model"] = "CWLS"
131+
119132
# Factory creates the appropriate model and pipes arguments
120133
model = ModelFactory.init(
121134
model=self._model,
@@ -125,10 +138,25 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
125138
else:
126139
model = self._model
127140

141+
# Prepare fit/predict keyword arguments
142+
fit_pred_kwargs = {
143+
"n_jobs": n_jobs,
144+
"omp_nthreads": n_threads,
145+
}
146+
if model.__class__.__name__ == "DTIModel":
147+
fit_pred_kwargs["step"] = chunk_size
148+
149+
print(f"Dataset size: {num_voxels}x{len(dataset)}.")
150+
print(f"Parallel execution: {fit_pred_kwargs}.")
151+
print(f"Model: {model}.")
152+
128153
if self._single_fit:
129-
model.fit_predict(None, n_jobs=n_jobs)
154+
print("Fitting 'single' model started ...")
155+
start = timer()
156+
model.fit_predict(None, **fit_pred_kwargs)
157+
print(f"Fitting 'single' model finished, elapsed {timer() - start}s.")
130158

131-
kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None)
159+
kwargs["num_threads"] = n_threads
132160
kwargs = self._align_kwargs | kwargs
133161

134162
dataset_length = len(dataset)
@@ -151,15 +179,14 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
151179
pbar.set_description_str(f"{FIT_MSG: <16} vol. <{i}>")
152180

153181
# fit the model
154-
test_set = dataset[i]
155182
predicted = model.fit_predict( # type: ignore[union-attr]
156183
i,
157-
n_jobs=n_jobs,
184+
**fit_pred_kwargs,
158185
)
159186

160187
# prepare data for running ANTs
161188
predicted_path, volume_path, init_path = _prepare_registration_data(
162-
test_set[0],
189+
dataset[i][0], # Access the target volume
163190
predicted,
164191
dataset.affine,
165192
i,

src/nifreeze/model/dmri.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#
2323

2424
from importlib import import_module
25+
from typing import Any
2526

2627
import numpy as np
2728
from dipy.core.gradients import gradient_table_from_bvals_bvecs
@@ -38,9 +39,8 @@
3839
B_MIN = 50
3940

4041

41-
def _exec_fit(model, data, chunk=None):
42-
retval = model.fit(data)
43-
return retval, chunk
42+
def _exec_fit(model, data, chunk=None, **kwargs):
    """
    Fit ``model`` on ``data`` and tag the result with ``chunk``.

    Parameters
    ----------
    model
        Object exposing a ``fit(data, **kwargs)`` method.
    data
        Data forwarded as the first argument of ``model.fit``.
    chunk
        Passed through unchanged; lets callers that fit several chunks
        match each result back to the chunk it came from.
    **kwargs
        Extra keyword arguments forwarded to ``model.fit``.

    Returns
    -------
    tuple
        ``(fit_result, chunk)``.

    """
    fitted = model.fit(data, **kwargs)
    return fitted, chunk
4444

4545

4646
def _exec_predict(model, chunk=None, **kwargs):
@@ -104,7 +104,7 @@ def __init__(self, dataset: DWI, max_b: float | int | None = None, **kwargs):
104104

105105
super().__init__(dataset, **kwargs)
106106

107-
def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
107+
def _fit(self, index: int | None = None, n_jobs: int | None = None, **kwargs):
108108
"""Fit the model chunk-by-chunk asynchronously"""
109109

110110
n_jobs = n_jobs or 1
@@ -136,9 +136,18 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
136136
class_name,
137137
)(gtab, **kwargs)
138138

139+
fit_kwargs: dict[str, Any] = {} # Add here keyword arguments
140+
141+
is_dki = model_str == "dipy.reconst.dki.DiffusionKurtosisModel"
142+
139143
# One single CPU - linear execution (full model)
140-
if n_jobs == 1:
141-
_modelfit, _ = _exec_fit(model, data)
144+
# DKI model does not allow parallelization as implemented here
145+
if n_jobs == 1 or is_dki:
146+
_modelfit, _ = _exec_fit(model, data, **fit_kwargs)
147+
self._models = [_modelfit]
148+
return 1
149+
elif is_dki:
150+
_modelfit = model.multi_fit(data, **fit_kwargs)
142151
self._models = [_modelfit]
143152
return 1
144153

@@ -150,7 +159,8 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
150159
# Parallelize process with joblib
151160
with Parallel(n_jobs=n_jobs) as executor:
152161
results = executor(
153-
delayed(_exec_fit)(model, dchunk, i) for i, dchunk in enumerate(data_chunks)
162+
delayed(_exec_fit)(model, dchunk, i, **fit_kwargs)
163+
for i, dchunk in enumerate(data_chunks)
154164
)
155165
for submodel, rindex in results:
156166
self._models[rindex] = submodel
@@ -168,6 +178,7 @@ def fit_predict(self, index: int | None = None, **kwargs):
168178
169179
"""
170180

181+
kwargs.pop("omp_nthreads", None) # Drop omp_nthreads
171182
n_models = self._fit(
172183
index,
173184
n_jobs=kwargs.pop("n_jobs"),

test/test_data_base.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,16 @@
3232

3333
from nifreeze.data import NFDH5_EXT, BaseDataset, load
3434

35+
DEFAULT_RANDOM_DATASET_SHAPE = (32, 32, 32, 5)
36+
DEFAULT_RANDOM_DATASET_SIZE = int(np.prod(DEFAULT_RANDOM_DATASET_SHAPE[:3]))
37+
3538

3639
@pytest.fixture
37-
def random_dataset(request) -> BaseDataset:
40+
def random_dataset(request, size=DEFAULT_RANDOM_DATASET_SHAPE) -> BaseDataset:
3841
"""Create a BaseDataset with random data for testing."""
3942

4043
rng = request.node.rng
41-
data = rng.random((32, 32, 32, 5)).astype(np.float32)
44+
data = rng.random(size).astype(np.float32)
4245
affine = np.eye(4, dtype=np.float32)
4346
return BaseDataset(dataobj=data, affine=affine)
4447

@@ -47,8 +50,10 @@ def test_base_dataset_init(random_dataset: BaseDataset):
4750
"""Test that the BaseDataset can be initialized with random data."""
4851
assert random_dataset.dataobj is not None
4952
assert random_dataset.affine is not None
50-
assert random_dataset.dataobj.shape == (32, 32, 32, 5)
53+
assert random_dataset.dataobj.shape == DEFAULT_RANDOM_DATASET_SHAPE
5154
assert random_dataset.affine.shape == (4, 4)
55+
assert random_dataset.size3d == DEFAULT_RANDOM_DATASET_SIZE
56+
assert random_dataset.shape3d == DEFAULT_RANDOM_DATASET_SHAPE[:3]
5257

5358

5459
def test_len(random_dataset: BaseDataset):

0 commit comments

Comments
 (0)