Skip to content

Commit 8a2c1c5

Browse files
committed
REF: Inherit base model class from ABC to enforece native abstraction
Inherit base model class from `ABC` to enforece native abstraction. Fixes: ``` Abstract methods are allowed in classes whose metaclass is 'ABCMeta' ``` raised locally by the IDE.
1 parent 19af931 commit 8a2c1c5

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

src/nifreeze/model/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#
2323
"""Base infrastructure for nifreeze's models."""
2424

25-
from abc import abstractmethod
25+
from abc import ABC, ABCMeta, abstractmethod
2626
from typing import Union
2727
from warnings import warn
2828

@@ -77,7 +77,7 @@ def init(model: str | None = None, **kwargs):
7777
raise NotImplementedError(f"Unsupported model <{model}>.")
7878

7979

80-
class BaseModel:
80+
class BaseModel(ABC):
8181
"""
8282
Defines the interface and default methods.
8383
@@ -88,6 +88,8 @@ class BaseModel:
8888
8989
"""
9090

91+
__metaclass__ = ABCMeta
92+
9193
__slots__ = ("_dataset", "_locked_fit")
9294

9395
def __init__(self, dataset, **kwargs):
@@ -116,7 +118,7 @@ def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, N
116118
If ``None``, no prediction will be executed.
117119
118120
"""
119-
raise NotImplementedError("Cannot call fit_predict() on a BaseModel instance.")
121+
return None
120122

121123

122124
class TrivialModel(BaseModel):

test/test_model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,16 @@
3535
from nifreeze.testing import simulations as _sim
3636

3737

38+
def test_base_model():
39+
from nifreeze.model.base import BaseModel
40+
41+
with pytest.raises(
42+
TypeError,
43+
match="Can't instantiate abstract class BaseModel with abstract method fit_predict",
44+
):
45+
BaseModel(None)
46+
47+
3848
@pytest.mark.parametrize("use_mask", (False, True))
3949
def test_trivial_model(request, use_mask):
4050
"""Check the implementation of the trivial B0 model."""

0 commit comments

Comments
 (0)