
ENH: Use DIPY's parallelization #142

Draft · wants to merge 3 commits into main
Changes from all commits

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -21,13 +21,14 @@ license = "Apache-2.0"
requires-python = ">=3.10"
dependencies = [
"attrs",
"dipy>=1.5.0",
"dipy>=1.10.0",
"joblib",
"nipype>= 1.5.1,<2.0",
"nitransforms>=22.0.0,<24",
"nireports",
"numpy>=1.21.3",
"nest-asyncio>=1.5.1",
"ray",
Contributor:

Are we requiring both joblib and ray? I am not sure how we can make these individual choices that users pick according to their preference. We should probably raise a warning when no parallelization backend is available, so that users become aware of why the run is so slow.

As an additional note, DIPY also offers dask as another parallelization backend, e.g.
https://github.com/dipy/dipy/blob/master/dipy/utils/tests/test_parallel.py#L12-L18
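
A minimal sketch of the suggested warning, assuming a hypothetical `_resolve_backend` helper (not part of this PR) that probes which engines are importable:

```python
import importlib.util
import warnings


def _resolve_backend(preferred=("ray", "joblib", "dask")):
    """Return the first importable parallelization backend, or None."""
    for name in preferred:
        if importlib.util.find_spec(name) is not None:
            return name
    warnings.warn(
        "No parallelization backend (ray/joblib/dask) is installed; "
        "model fitting will run serially and may be very slow.",
        stacklevel=2,
    )
    return None
```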

Contributor:

In our experiments, ray performs very well, even relative to joblib, so giving people this option would be great (maybe even as the default?). I would stay away from dask: in our experiments it often choked and performed worse than a serial baseline.

For some of the details: https://nrdg.github.io/2024-dipy-parallelization/
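
For context, DIPY exposes the engine choice through its `paramap` helper (the function exercised in the test linked above); a rough sketch of switching engines, assuming that signature, with `fit_voxel` as a hypothetical stand-in for the per-voxel work:

```python
import numpy as np
from dipy.utils.parallel import paramap


def fit_voxel(signal):
    # Hypothetical per-voxel workload (a real model fit in practice).
    return signal.mean()


signals = [np.random.default_rng(i).random(60) for i in range(1000)]

# engine may be "joblib", "ray", or "dask"; ray was the best performer
# in the benchmarks linked above.
means = paramap(fit_voxel, signals, engine="ray", n_jobs=8)
```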

"scikit-image>=0.15.0",
"scikit_learn>=1.3.0",
"scipy>=1.8.0",
2 changes: 2 additions & 0 deletions src/nifreeze/cli/parser.py
@@ -91,6 +91,8 @@ def build_parser() -> ArgumentParser:
)
parser.add_argument(
"--nthreads",
"--omp-nthreads",
"--ncpus",
action="store",
type=int,
default=None,
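
One note on the aliases: argparse derives the destination from the first long option, so `--omp-nthreads` and `--ncpus` both populate `args.nthreads`. A self-contained sketch of that behavior:

```python
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument(
    "--nthreads",
    "--omp-nthreads",
    "--ncpus",
    action="store",
    type=int,
    default=None,
)

# Every alias writes to the same destination, derived from "--nthreads":
assert parser.parse_args(["--omp-nthreads", "4"]).nthreads == 4
assert parser.parse_args(["--ncpus", "2"]).nthreads == 2
assert parser.parse_args([]).nthreads is None
```
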
10 changes: 10 additions & 0 deletions src/nifreeze/data/base.py
@@ -97,6 +97,16 @@

return self.dataobj.shape[-1]

@property
def shape3d(self):
"""Get the shape of the 3D volume."""
return self.dataobj.shape[:3]


@property
def size3d(self):
"""Get the number of voxels in the 3D volume."""
return np.prod(self.dataobj.shape[:3])


def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[Unpack[Ts]]:
return () # type: ignore[return-value]

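
The two new properties just split the spatial part off the 4D shape; a quick illustration of their semantics with a mock `dataobj` (the `_Mock` class is only for demonstration):

```python
import numpy as np


class _Mock:
    """Stand-in for a BaseDataset with a 4D data object."""

    dataobj = np.zeros((64, 64, 40, 120))  # x, y, z, volumes

    @property
    def shape3d(self):
        return self.dataobj.shape[:3]

    @property
    def size3d(self):
        return np.prod(self.dataobj.shape[:3])


mock = _Mock()
assert mock.shape3d == (64, 64, 40)
assert mock.size3d == 64 * 64 * 40  # 163,840 voxels
```
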
36 changes: 31 additions & 5 deletions src/nifreeze/estimator.py
@@ -27,6 +27,7 @@
from os import cpu_count
from pathlib import Path
from tempfile import TemporaryDirectory
from timeit import default_timer as timer
from typing import TypeVar

from tqdm import tqdm
@@ -42,7 +43,9 @@

DatasetT = TypeVar("DatasetT", bound=BaseDataset)

DEFAULT_CHUNK_SIZE: int = int(1e6)
FIT_MSG = "Fit&predict"
PRE_MSG = "Predicted"
REG_MSG = "Realign"


@@ -109,13 +112,20 @@
dataset = result # type: ignore[assignment]

n_jobs = kwargs.pop("n_jobs", None) or min(cpu_count() or 1, 8)
n_threads = kwargs.pop("omp_nthreads", None) or ((cpu_count() or 2) - 1)

num_voxels = dataset.brainmask.sum() if dataset.brainmask is not None else dataset.size3d
chunk_size = DEFAULT_CHUNK_SIZE * (n_threads or 1)

# Prepare iterator
iterfunc = getattr(iterators, f"{self._strategy}_iterator")
index_iter = iterfunc(len(dataset), seed=kwargs.get("seed", None))

# Initialize model
if isinstance(self._model, str):
if self._model.endswith("dti"):
self._model_kwargs["step"] = chunk_size


# Factory creates the appropriate model and pipes arguments
model = ModelFactory.init(
model=self._model,
@@ -125,10 +135,25 @@
else:
model = self._model

fit_pred_kwargs = {
"n_jobs": n_jobs,
"omp_nthreads": n_threads,
}

if model.__class__.__name__ == "DTIModel":
fit_pred_kwargs["step"] = chunk_size


print(f"Dataset size: {num_voxels}x{len(dataset)}.")
print(f"Parallel execution: {fit_pred_kwargs}.")
print(f"Model: {model}.")

if self._single_fit:
model.fit_predict(None, n_jobs=n_jobs)
print("Fitting 'single' model started ...")
start = timer()
model.fit_predict(None, **fit_pred_kwargs)
print(f"Fitting 'single' model finished, elapsed {timer() - start}s.")


kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None)
kwargs["num_threads"] = n_threads
kwargs = self._align_kwargs | kwargs

dataset_length = len(dataset)
@@ -151,15 +176,16 @@
pbar.set_description_str(f"{FIT_MSG: <16} vol. <{i}>")

# fit the model
test_set = dataset[i]
predicted = model.fit_predict( # type: ignore[union-attr]
i,
n_jobs=n_jobs,
**fit_pred_kwargs,
)

pbar.set_description_str(f"{PRE_MSG: <16} vol. <{i}>")

# prepare data for running ANTs
predicted_path, volume_path, init_path = _prepare_registration_data(
test_set[0],
dataset[i][0], # Access the target volume
predicted,
dataset.affine,
i,
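
On the sizing logic above: `DEFAULT_CHUNK_SIZE` is 10**6, so the `step` forwarded to the DTI model grows with the available threads. A sketch of the arithmetic under the defaults, assuming an 8-core machine and no CLI overrides:

```python
from os import cpu_count

DEFAULT_CHUNK_SIZE = int(1e6)

# Defaults from the diff above (no kwargs passed in):
n_jobs = min(cpu_count() or 1, 8)        # workers, capped at 8
n_threads = (cpu_count() or 2) - 1       # threads, leaving one core free
chunk_size = DEFAULT_CHUNK_SIZE * n_threads

# On an 8-core machine: n_jobs == 8, n_threads == 7,
# chunk_size == 7_000_000 voxel-fits per chunk.
print(n_jobs, n_threads, chunk_size)
```
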
6 changes: 6 additions & 0 deletions src/nifreeze/model/_dipy.py
@@ -29,6 +29,7 @@
import numpy as np
from dipy.core.gradients import GradientTable
from dipy.reconst.base import ReconstModel
from dipy.reconst.multi_voxel import multi_voxel_fit
from sklearn.gaussian_process import GaussianProcessRegressor

from nifreeze.model.gpr import (
@@ -38,6 +39,11 @@
)


@multi_voxel_fit
def multi_fit(obj, data, *, mask=None, **kwargs):
return obj.fit(data, *obj.args, mask=mask)



def gp_prediction(
model: GaussianProcessRegressor,
gtab: GradientTable | np.ndarray,
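
`multi_voxel_fit` is DIPY's decorator for lifting a single-voxel fit over a masked volume, which is what the `multi_fit` helper above relies on. A dependency-free sketch of that contract (`fit_one` and `naive_multi_voxel_fit` are illustrative stand-ins, not the DIPY implementation):

```python
import numpy as np


def fit_one(signal):
    """Hypothetical single-voxel fit; a real model returns a fit object."""
    return signal.mean()


def naive_multi_voxel_fit(data, mask):
    """Rough stand-in for dipy.reconst.multi_voxel.multi_voxel_fit:
    apply the single-voxel fit at every voxel selected by ``mask``."""
    out = np.zeros(data.shape[:3])
    for ijk in zip(*np.nonzero(mask)):
        out[ijk] = fit_one(data[ijk])
    return out


data = np.random.default_rng(0).random((4, 4, 3, 60))  # x, y, z, directions
mask = np.ones(data.shape[:3], dtype=bool)
fitted = naive_multi_voxel_fit(data, mask)
assert fitted.shape == (4, 4, 3)
```
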
84 changes: 29 additions & 55 deletions src/nifreeze/model/dmri.py
@@ -25,7 +25,6 @@

import numpy as np
from dipy.core.gradients import gradient_table_from_bvals_bvecs
from joblib import Parallel, delayed

from nifreeze.data.dmri import (
DEFAULT_CLIP_PERCENTILE,
@@ -38,16 +37,6 @@
B_MIN = 50


def _exec_fit(model, data, chunk=None):
retval = model.fit(data)
return retval, chunk


def _exec_predict(model, chunk=None, **kwargs):
"""Propagate model parameters and call predict."""
return np.squeeze(model.predict(**kwargs)), chunk


class BaseDWIModel(BaseModel):
"""Interface and default methods for DWI models."""

@@ -57,7 +46,7 @@
"_S0": "The S0 (b=0 reference signal) that will be fed into DIPY models",
"_model_class": "Defining a model class, DIPY models are instantiated automagically",
"_modelargs": "Arguments acceptable by the underlying DIPY-like model.",
"_models": "List with one or more (if parallel execution) model instances",
"_model_fit": "Fitted model",
}

def __init__(self, dataset: DWI, max_b: float | int | None = None, **kwargs):
@@ -104,11 +93,9 @@

super().__init__(dataset, **kwargs)

def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
def _fit(self, index: int | None = None, n_jobs=None, omp_nthreads=None, **kwargs):
"""Fit the model chunk-by-chunk asynchronously"""

n_jobs = n_jobs or 1

if self._locked_fit is not None:
return n_jobs

@@ -136,25 +123,21 @@
class_name,
)(gtab, **kwargs)

# One single CPU - linear execution (full model)
if n_jobs == 1:
_modelfit, _ = _exec_fit(model, data)
self._models = [_modelfit]
return 1
fitargs = {"engine": "ray", "n_jobs": n_jobs} if n_jobs > 1 else {}


# Split data into chunks of group of slices
data_chunks = np.array_split(data, n_jobs)
if "step" in kwargs:
fitargs["step"] = kwargs["step"]


self._models = [None] * n_jobs
try:
self._model_fit = model.fit(data, **fitargs)
except TypeError:
from nifreeze.model._dipy import multi_fit


# Parallelize process with joblib
with Parallel(n_jobs=n_jobs) as executor:
results = executor(
delayed(_exec_fit)(model, dchunk, i) for i, dchunk in enumerate(data_chunks)
self._model_fit = multi_fit(
model,
data,
**fitargs,
)
for submodel, rindex in results:
self._models[rindex] = submodel

return n_jobs

def fit_predict(self, index: int | None = None, **kwargs):
@@ -168,44 +151,35 @@

"""

n_models = self._fit(
omp_nthreads = kwargs.pop("omp_nthreads", None)
n_jobs = kwargs.pop("n_jobs", None)


brainmask = self._dataset.brainmask
self._fit(
index,
n_jobs=kwargs.pop("n_jobs"),
n_jobs=n_jobs,
omp_nthreads=omp_nthreads,
**kwargs,
)

if index is None:
self._locked_fit = True

return None

# Prepare gradient(s) for simulation
gradient = self._dataset.gradients[:, index]

if "dipy" in getattr(self, "_model_class", ""):
gradient = gradient_table_from_bvals_bvecs(
gradient[np.newaxis, -1], gradient[np.newaxis, :-1]
)

if n_models == 1:
predicted, _ = _exec_predict(
self._models[0], **(kwargs | {"gtab": gradient, "S0": self._S0})
# Prediction
predicted = np.squeeze(
self._model_fit.predict(
gtab=gradient,
S0=self._S0,
**kwargs,
)
else:
predicted = [None] * n_models
S0 = np.array_split(self._S0, n_models)

# Parallelize process with joblib
with Parallel(n_jobs=n_models) as executor:
results = executor(
delayed(_exec_predict)(
model,
chunk=i,
**(kwargs | {"gtab": gradient, "S0": S0[i]}),
)
for i, model in enumerate(self._models)
)
for subprediction, index in results:
predicted[index] = subprediction

predicted = np.hstack(predicted)
)

retval = np.zeros_like(self._data_mask, dtype=self._dataset.dataobj.dtype)
retval[self._data_mask, ...] = predicted
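
A note on the gradient handling in `fit_predict` above: the slicing assumes `dataset.gradients` is a (4, N) array with the b-vector components in the first three rows and the b-value in the last, from which a single-gradient table is rebuilt. A sketch under that assumption:

```python
import numpy as np
from dipy.core.gradients import gradient_table_from_bvals_bvecs

# Hypothetical (4, N) gradient array: rows 0-2 are b-vectors, row 3 is b-values.
gradients = np.array(
    [
        [1.0, 0.0, 0.0],           # x components
        [0.0, 1.0, 0.0],           # y components
        [0.0, 0.0, 1.0],           # z components
        [1000.0, 1000.0, 2000.0],  # b-values
    ]
)

index = 1
gradient = gradients[:, index]        # one column: [0., 1., 0., 1000.]
gtab = gradient_table_from_bvals_bvecs(
    gradient[np.newaxis, -1],         # b-values, shape (1,)
    gradient[np.newaxis, :-1],        # b-vectors, shape (1, 3)
)
assert gtab.bvals.shape == (1,)
```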