
Commit 48c3453

enh: improve general handling of parallelization
Brings in the parallelization-management improvements from #142, so that they are preserved even if we ultimately decide against #142. In particular, this opens up the implementation to set the ``step`` parameter of DTI models. It also extends the base data object with two convenience properties for retrieving the 3D shape of the data and the number of voxels per volume. X-References: #142.
Parent: d16c97c

5 files changed, +69 -15 lines


src/nifreeze/cli/parser.py

Lines changed: 2 additions & 0 deletions
@@ -91,6 +91,8 @@ def build_parser() -> ArgumentParser:
     )
     parser.add_argument(
         "--nthreads",
+        "--omp-nthreads",
+        "--ncpus",
         action="store",
         type=int,
         default=None,
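
The three option strings feed a single argparse destination; because ``--nthreads`` is listed first, the value lands in ``args.nthreads`` no matter which alias the user types. A minimal, standalone sketch replicating just this option (not the full nifreeze CLI):

from argparse import ArgumentParser

# Replicates only the option added in this hunk: argparse derives the
# destination from the first long option, so all three spellings are
# stored under ``args.nthreads``.
parser = ArgumentParser()
parser.add_argument(
    "--nthreads",
    "--omp-nthreads",
    "--ncpus",
    action="store",
    type=int,
    default=None,
)
args = parser.parse_args(["--omp-nthreads", "8"])
print(args.nthreads)  # 8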

src/nifreeze/data/base.py

Lines changed: 10 additions & 0 deletions
@@ -127,6 +127,16 @@ def __getitem__(
         affine = self.motion_affines[idx] if self.motion_affines is not None else None
         return self.dataobj[..., idx], affine, *self._getextra(idx)
 
+    @property
+    def shape3d(self):
+        """Get the shape of the 3D volume."""
+        return self.dataobj.shape[:3]
+
+    @property
+    def size3d(self):
+        """Get the number of voxels in the 3D volume."""
+        return np.prod(self.dataobj.shape[:3])
+
     @classmethod
     def from_filename(cls, filename: Path | str) -> Self:
         """

src/nifreeze/estimator.py

Lines changed: 32 additions & 5 deletions
@@ -27,6 +27,7 @@
 from os import cpu_count
 from pathlib import Path
 from tempfile import TemporaryDirectory
+from timeit import default_timer as timer
 from typing import TypeVar
 
 from tqdm import tqdm
@@ -42,6 +43,7 @@
 
 DatasetT = TypeVar("DatasetT", bound=BaseDataset)
 
+DEFAULT_CHUNK_SIZE: int = int(1e6)
 FIT_MSG = "Fit&predict"
 REG_MSG = "Realign"
 
@@ -109,13 +111,24 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
             dataset = result  # type: ignore[assignment]
 
         n_jobs = kwargs.pop("n_jobs", None) or min(cpu_count() or 1, 8)
+        n_threads = kwargs.pop("omp_nthreads", None) or ((cpu_count() or 2) - 1)
+
+        num_voxels = dataset.brainmask.sum() if dataset.brainmask is not None else dataset.size3d
+        chunk_size = DEFAULT_CHUNK_SIZE * (n_threads or 1)
 
         # Prepare iterator
         iterfunc = getattr(iterators, f"{self._strategy}_iterator")
         index_iter = iterfunc(len(dataset), seed=kwargs.get("seed", None))
 
         # Initialize model
         if isinstance(self._model, str):
+            if self._model.endswith("dti"):
+                self._model_kwargs["step"] = chunk_size
+
+            # Example: change model parameters only for DKI
+            # if self._model.endswith("dki"):
+            #     self._model_kwargs["fit_model"] = "CWLS"
+
             # Factory creates the appropriate model and pipes arguments
             model = ModelFactory.init(
                 model=self._model,
@@ -125,10 +138,25 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
         else:
             model = self._model
 
+        # Prepare fit/predict keyword arguments
+        fit_pred_kwargs = {
+            "n_jobs": n_jobs,
+            "omp_nthreads": n_threads,
+        }
+        if model.__class__.__name__ == "DTIModel":
+            fit_pred_kwargs["step"] = chunk_size
+
+        print(f"Dataset size: {num_voxels}x{len(dataset)}.")
+        print(f"Parallel execution: {fit_pred_kwargs}.")
+        print(f"Model: {model}.")
+
         if self._single_fit:
-            model.fit_predict(None, n_jobs=n_jobs)
+            print("Fitting 'single' model started ...")
+            start = timer()
+            model.fit_predict(None, **fit_pred_kwargs)
+            print(f"Fitting 'single' model finished, elapsed {timer() - start}s.")
 
-        kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None)
+        kwargs["num_threads"] = n_threads
         kwargs = self._align_kwargs | kwargs
 
         dataset_length = len(dataset)
@@ -151,15 +179,14 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
                     pbar.set_description_str(f"{FIT_MSG: <16} vol. <{i}>")
 
                     # fit the model
-                    test_set = dataset[i]
                     predicted = model.fit_predict(  # type: ignore[union-attr]
                         i,
-                        n_jobs=n_jobs,
+                        **fit_pred_kwargs,
                     )
 
                     # prepare data for running ANTs
                     predicted_path, volume_path, init_path = _prepare_registration_data(
-                        test_set[0],
+                        dataset[i][0],  # Access the target volume
                         predicted,
                         dataset.affine,
                         i,
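
The ``step`` value handed to DTI models is simply ``DEFAULT_CHUNK_SIZE`` scaled by the thread count, and it only enters ``fit_pred_kwargs`` when the model class is named ``DTIModel``. A standalone sketch of that arithmetic (numbers are illustrative; the class-name check mirrors the comparison in the hunk above):

from os import cpu_count

DEFAULT_CHUNK_SIZE = int(1e6)

# Mirrors Estimator.run(): jobs are capped at 8, threads default to
# "CPUs minus one", and the chunk size grows with the thread count.
n_jobs = min(cpu_count() or 1, 8)
n_threads = (cpu_count() or 2) - 1
chunk_size = DEFAULT_CHUNK_SIZE * (n_threads or 1)

fit_pred_kwargs = {"n_jobs": n_jobs, "omp_nthreads": n_threads}
model_class_name = "DTIModel"  # stand-in for model.__class__.__name__
if model_class_name == "DTIModel":
    fit_pred_kwargs["step"] = chunk_size
print(fit_pred_kwargs)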

src/nifreeze/model/dmri.py

Lines changed: 17 additions & 7 deletions
@@ -38,9 +38,8 @@
 B_MIN = 50
 
 
-def _exec_fit(model, data, chunk=None):
-    retval = model.fit(data)
-    return retval, chunk
+def _exec_fit(model, data, chunk=None, **kwargs):
+    return model.fit(data, **kwargs), chunk
 
 
 def _exec_predict(model, chunk=None, **kwargs):
@@ -104,7 +103,7 @@ def __init__(self, dataset: DWI, max_b: float | int | None = None, **kwargs):
 
         super().__init__(dataset, **kwargs)
 
-    def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
+    def _fit(self, index: int | None = None, n_jobs: int | None = None, **kwargs):
         """Fit the model chunk-by-chunk asynchronously"""
 
         n_jobs = n_jobs or 1
@@ -136,9 +135,18 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
             class_name,
         )(gtab, **kwargs)
 
+        fit_kwargs = {}  # Add here keyword arguments
+
+        is_dki = model_str == "dipy.reconst.dki.DiffusionKurtosisModel"
+
         # One single CPU - linear execution (full model)
-        if n_jobs == 1:
-            _modelfit, _ = _exec_fit(model, data)
+        # DKI model does not allow parallelization as implemented here
+        if n_jobs == 1 or is_dki:
+            _modelfit, _ = _exec_fit(model, data, **fit_kwargs)
+            self._models = [_modelfit]
+            return 1
+        elif is_dki:
+            _modelfit = model.multi_fit(data, **fit_kwargs)
             self._models = [_modelfit]
             return 1
 
@@ -150,7 +158,8 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
         # Parallelize process with joblib
         with Parallel(n_jobs=n_jobs) as executor:
             results = executor(
-                delayed(_exec_fit)(model, dchunk, i) for i, dchunk in enumerate(data_chunks)
+                delayed(_exec_fit)(model, dchunk, i, **fit_kwargs)
+                for i, dchunk in enumerate(data_chunks)
             )
         for submodel, rindex in results:
             self._models[rindex] = submodel
@@ -168,6 +177,7 @@ def fit_predict(self, index: int | None = None, **kwargs):
 
         """
 
+        kwargs.pop("omp_nthreads", None)  # Drop omp_nthreads
         n_models = self._fit(
             index,
             n_jobs=kwargs.pop("n_jobs"),
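
The chunked fit fans ``_exec_fit`` out over data chunks with joblib and then reassembles the sub-models by chunk index. A self-contained sketch of the same pattern with a dummy model (the model class and data here are illustrative, not part of the package):

import numpy as np
from joblib import Parallel, delayed


class DummyModel:
    """Stand-in for a DIPY reconstruction model (illustrative only)."""

    def fit(self, data, **kwargs):
        return float(data.mean())  # pretend "fit" result


def _exec_fit(model, data, chunk=None, **kwargs):
    # Same shape as the helper in dmri.py: return the fit and its chunk index.
    return model.fit(data, **kwargs), chunk


model = DummyModel()
data = np.random.rand(1000, 30)
n_jobs = 4
data_chunks = np.array_split(data, n_jobs, axis=0)
fit_kwargs = {}  # extra keyword arguments forwarded to every chunk fit

submodels = [None] * len(data_chunks)
with Parallel(n_jobs=n_jobs) as executor:
    results = executor(
        delayed(_exec_fit)(model, dchunk, i, **fit_kwargs)
        for i, dchunk in enumerate(data_chunks)
    )
for submodel, rindex in results:
    submodels[rindex] = submodel
print(submodels)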

test/test_data_base.py

Lines changed: 8 additions & 3 deletions
@@ -32,13 +32,16 @@
 
 from nifreeze.data import NFDH5_EXT, BaseDataset, load
 
+DEFAULT_RANDOM_DATASET_SHAPE = (32, 32, 32, 5)
+DEFAULT_RANDOM_DATASET_SIZE = int(np.prod(DEFAULT_RANDOM_DATASET_SHAPE[:3]))
+
 
 @pytest.fixture
-def random_dataset(request) -> BaseDataset:
+def random_dataset(request, size=DEFAULT_RANDOM_DATASET_SHAPE) -> BaseDataset:
     """Create a BaseDataset with random data for testing."""
 
     rng = request.node.rng
-    data = rng.random((32, 32, 32, 5)).astype(np.float32)
+    data = rng.random(size).astype(np.float32)
     affine = np.eye(4, dtype=np.float32)
     return BaseDataset(dataobj=data, affine=affine)
 
@@ -47,8 +50,10 @@ def test_base_dataset_init(random_dataset: BaseDataset):
     """Test that the BaseDataset can be initialized with random data."""
     assert random_dataset.dataobj is not None
    assert random_dataset.affine is not None
-    assert random_dataset.dataobj.shape == (32, 32, 32, 5)
+    assert random_dataset.dataobj.shape == DEFAULT_RANDOM_DATASET_SHAPE
     assert random_dataset.affine.shape == (4, 4)
+    assert random_dataset.size3d == DEFAULT_RANDOM_DATASET_SIZE
+    assert random_dataset.shape3d == DEFAULT_RANDOM_DATASET_SHAPE[:3]
 
 
 def test_len(random_dataset: BaseDataset):
