From e5f833a3802e6b0f4df7942ad53eba9db23eab35 Mon Sep 17 00:00:00 2001
From: Idlir Shkurti
Date: Mon, 8 Apr 2024 11:34:39 +0200
Subject: [PATCH] Add model manager as BackgroundService for handling machine learning models

Signed-off-by: Idlir Shkurti
---
 src/frequenz/sdk/ml/__init__.py       |  10 ++
 src/frequenz/sdk/ml/_model_manager.py | 170 ++++++++++++++++++++++++++++++++++++++++++++++++++
 tests/model/__init__.py               |   4 +
 tests/model/test_model_manager.py     | 118 +++++++++++++++++++++++++++++++++++
 4 files changed, 302 insertions(+)
 create mode 100644 src/frequenz/sdk/ml/__init__.py
 create mode 100644 src/frequenz/sdk/ml/_model_manager.py
 create mode 100644 tests/model/__init__.py
 create mode 100644 tests/model/test_model_manager.py

diff --git a/src/frequenz/sdk/ml/__init__.py b/src/frequenz/sdk/ml/__init__.py
new file mode 100644
index 000000000..a42245d0f
--- /dev/null
+++ b/src/frequenz/sdk/ml/__init__.py
@@ -0,0 +1,10 @@
+# License: MIT
+# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
+
+"""Model interface."""
+
+from ._model_manager import ModelManager
+
+__all__ = [
+    "ModelManager",
+]
diff --git a/src/frequenz/sdk/ml/_model_manager.py b/src/frequenz/sdk/ml/_model_manager.py
new file mode 100644
index 000000000..b22420fcf
--- /dev/null
+++ b/src/frequenz/sdk/ml/_model_manager.py
@@ -0,0 +1,170 @@
+# License: MIT
+# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
+
+"""Load, update, monitor and retrieve machine learning models."""
+
+import asyncio
+import logging
+import pickle
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Generic, TypeVar, cast
+
+from frequenz.channels.file_watcher import EventType, FileWatcher
+from typing_extensions import override
+
+from frequenz.sdk.actor import BackgroundService
+
+_logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+@dataclass
+class _Model(Generic[T]):
+    """Represent a machine learning model."""
+
+    data: T
+    path: Path
+
+
+class ModelNotFoundError(Exception):
+    """Exception raised when a model is not found."""
+
+    def __init__(self, key: str) -> None:
+        """Initialize the exception with the specified model key.
+
+        Args:
+            key: The key of the model that was not found.
+        """
+        super().__init__(f"Model with key '{key}' is not found.")
+
+
+class ModelManager(BackgroundService, Generic[T]):
+    """Load, update, monitor and retrieve machine learning models.
+
+    All configured models are loaded when the manager is created. Once the
+    service is started, the model files are watched and a model is reloaded
+    whenever a create or modify event is reported for its file.
+    """
+
+    def __init__(self, model_paths: dict[str, Path], *, name: str | None = None):
+        """Initialize the model manager with the specified model paths.
+
+        Args:
+            model_paths: A dictionary of model keys and their corresponding file paths.
+            name: The name of the model manager service.
+
+        Raises:
+            ModelNotFoundError: If any of the model files does not exist.
+        """
+        super().__init__(name=name)
+        self._models: dict[str, _Model[T]] = {}
+        self.model_paths = model_paths
+        self.load_models()
+
+    def load_models(self) -> None:
+        """Load the models from the specified paths.
+
+        Raises:
+            ModelNotFoundError: If any of the model files does not exist.
+        """
+        for key, path in self.model_paths.items():
+            self._models[key] = _Model(data=self._load(path), path=path)
+
+    @staticmethod
+    def _load(path: Path) -> T:
+        """Load the model from the specified path.
+
+        Args:
+            path: The path to the model file.
+
+        Returns:
+            T: The loaded model data.
+
+        Raises:
+            ModelNotFoundError: If the model file does not exist.
+        """
+        try:
+            with path.open("rb") as file:
+                # NOTE(security): unpickling can execute arbitrary code, so
+                # model files must only come from a trusted source.
+                return cast(T, pickle.load(file))
+        except FileNotFoundError as exc:
+            raise ModelNotFoundError(str(path)) from exc
+
+    @override
+    def start(self) -> None:
+        """Start the model monitoring service by creating a background task."""
+        if not self.is_running:
+            task = asyncio.create_task(self._monitor_paths())
+            self._tasks.add(task)
+            _logger.info(
+                "%s: Started ModelManager service with task %s",
+                self.name,
+                task,
+            )
+
+    async def _monitor_paths(self) -> None:
+        """Monitor model file paths and reload models as necessary."""
+        file_watcher = FileWatcher(
+            paths=[model.path for model in self._models.values()],
+            event_types=[EventType.CREATE, EventType.MODIFY],
+        )
+        _logger.info("%s: Monitoring model paths for changes.", self.name)
+        async for event in file_watcher:
+            _logger.info(
+                "%s: Reloading model from file %s due to a %s event...",
+                self.name,
+                event.path,
+                event.type.name,
+            )
+            self.reload_model(Path(event.path))
+
+    def reload_model(self, path: Path) -> None:
+        """Reload every model whose file path matches the specified path.
+
+        A failure to load is logged and the previously loaded model data is kept.
+
+        Args:
+            path: The path to the model file.
+        """
+        # Compare absolute paths so that a model configured with a relative
+        # path still matches when the same file is given by its absolute path.
+        # NOTE(review): presumably the file watcher can report absolute paths
+        # for relative watch paths; confirm against the watcher backend.
+        changed_path = path.absolute()
+        for model in self._models.values():
+            if model.path.absolute() != changed_path:
+                continue
+            try:
+                model.data = self._load(path)
+            except Exception:  # pylint: disable=broad-except
+                _logger.exception(
+                    "%s: Failed to reload model from %s",
+                    self.name,
+                    path,
+                )
+            else:
+                _logger.info(
+                    "%s: Successfully reloaded model from %s",
+                    self.name,
+                    path,
+                )
+
+    def get_model(self, key: str) -> T:
+        """Retrieve a loaded model by key.
+
+        Args:
+            key: The key of the model to retrieve.
+
+        Returns:
+            The loaded model data.
+
+        Raises:
+            KeyError: If the model with the specified key is not found.
+        """
+        try:
+            return self._models[key].data
+        except KeyError as exc:
+            raise KeyError(f"Model with key '{key}' is not found.") from exc
diff --git a/tests/model/__init__.py b/tests/model/__init__.py
new file mode 100644
index 000000000..00cd013d3
--- /dev/null
+++ b/tests/model/__init__.py
@@ -0,0 +1,4 @@
+# License: MIT
+# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
+
+"""Tests for the model package."""
diff --git a/tests/model/test_model_manager.py b/tests/model/test_model_manager.py
new file mode 100644
index 000000000..2eaca7c33
--- /dev/null
+++ b/tests/model/test_model_manager.py
@@ -0,0 +1,118 @@
+# License: MIT
+# Copyright © 2024 Frequenz Energy-as-a-Service GmbH
+
+"""Tests for machine learning model manager."""
+
+import pickle
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+from unittest.mock import AsyncMock, MagicMock, mock_open, patch
+
+import pytest
+
+from frequenz.sdk.ml import ModelManager
+
+
+@dataclass
+class MockModel:
+    """Mock model for unit testing purposes."""
+
+    data: int | str
+
+    def predict(self) -> int | str:
+        """Make a prediction based on the model data."""
+        return self.data
+
+
+async def test_model_manager_loading() -> None:
+    """Test loading models using ModelManager with direct configuration."""
+    model1 = MockModel("Model 1 Data")
+    model2 = MockModel("Model 2 Data")
+    pickled_model1 = pickle.dumps(model1)
+    pickled_model2 = pickle.dumps(model2)
+
+    model_paths = {
+        "model1": Path("path/to/model1.pkl"),
+        "model2": Path("path/to/model2.pkl"),
+    }
+
+    mock_files = {
+        "path/to/model1.pkl": mock_open(read_data=pickled_model1)(),
+        "path/to/model2.pkl": mock_open(read_data=pickled_model2)(),
+    }
+
+    def mock_open_func(file_path: Path, *__args: Any, **__kwargs: Any) -> Any:
+        """Mock open function to return the correct mock file object.
+
+        Args:
+            file_path: The path to the file to open.
+            *__args: Variable length argument list. This can be used to pass additional
+                positional parameters typically used in file opening operations,
+                such as `mode` or `buffering`.
+            **__kwargs: Arbitrary keyword arguments. This can include parameters like
+                `encoding` and `errors`, common in file opening operations.
+
+        Returns:
+            Any: The mock file object.
+
+        Raises:
+            FileNotFoundError: If the file path is not in the mock files dictionary.
+        """
+        file_path_str = str(file_path)
+        if file_path_str in mock_files:
+            file_handle = MagicMock()
+            file_handle.__enter__.return_value = mock_files[file_path_str]
+            return file_handle
+        raise FileNotFoundError(f"No mock setup for {file_path_str}")
+
+    with patch("pathlib.Path.open", new=mock_open_func):
+        with patch.object(Path, "exists", return_value=True):
+            model_manager: ModelManager[MockModel] = ModelManager(
+                model_paths=model_paths
+            )
+
+            with patch(
+                "frequenz.sdk.ml._model_manager.FileWatcher", new_callable=AsyncMock
+            ):
+                model_manager.start()  # Start the service
+
+            assert isinstance(model_manager.get_model("model1"), MockModel)
+            assert model_manager.get_model("model1").data == "Model 1 Data"
+            assert model_manager.get_model("model2").data == "Model 2 Data"
+
+            with pytest.raises(KeyError):
+                model_manager.get_model("key3")
+
+            await model_manager.stop()  # Stop the service to clean up
+
+
+async def test_model_manager_update() -> None:
+    """Test updating a model in ModelManager."""
+    original_model = MockModel("Original Data")
+    updated_model = MockModel("Updated Data")
+    pickled_original_model = pickle.dumps(original_model)
+    pickled_updated_model = pickle.dumps(updated_model)
+
+    model_paths = {"model1": Path("path/to/model1.pkl")}
+
+    mock_file = mock_open(read_data=pickled_original_model)
+    with (
+        patch("pathlib.Path.open", mock_file),
+        patch.object(Path, "exists", return_value=True),
+    ):
+        model_manager = ModelManager[MockModel](model_paths=model_paths)
+        with patch(
+            "frequenz.sdk.ml._model_manager.FileWatcher", new_callable=AsyncMock
+        ):
+            model_manager.start()  # Start the service
+
+        assert model_manager.get_model("model1").data == "Original Data"
+
+        # Simulate updating the model file
+        mock_file.return_value.read.return_value = pickled_updated_model
+        with patch("pathlib.Path.open", mock_file):
+            model_manager.reload_model(Path("path/to/model1.pkl"))
+            assert model_manager.get_model("model1").data == "Updated Data"
+
+    await model_manager.stop()  # Stop the service to clean up