From 956e6a920e1bd1f0ec007c6f0ef5b61c3c585a22 Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Fri, 1 Aug 2025 14:51:44 +0200 Subject: [PATCH 1/2] updated image/pad image to fft and the corresponding unittest --- deeptrack/image.py | 44 +++++++++++++++++++++++++---------- deeptrack/tests/test_image.py | 33 +++++++++++++++++++++++++- 2 files changed, 64 insertions(+), 13 deletions(-) diff --git a/deeptrack/image.py b/deeptrack/image.py index 6a221a3fd..8706a8f9b 100644 --- a/deeptrack/image.py +++ b/deeptrack/image.py @@ -96,11 +96,17 @@ class is central to DeepTrack2, acting as a container for numerical data import operator as ops from typing import Any, Callable, Iterable +import array_api_compat as apc import numpy as np +from numpy.typing import NDArray +from deeptrack.backend import config, TORCH_AVAILABLE, xp from deeptrack.properties import Property from deeptrack.types import NumberLike +if TORCH_AVAILABLE: + import torch + #TODO ***??*** revise _binary_method - typing, docstring, unit test def _binary_method( @@ -1694,11 +1700,10 @@ def coerce( _FASTEST_SIZES = np.sort(_FASTEST_SIZES) -#TODO ***??*** revise pad_image_to_fft - typing, docstring, unit test def pad_image_to_fft( - image: Image | np.ndarray | np.ndarray, + image: Image | NDArray | torch.Tensor, axes: Iterable[int] = (0, 1), -) -> Image | np.ndarray: +) -> Image | NDArray | torch.Tensor: """Pads an image to optimize Fast Fourier Transform (FFT) performance. This function pads an image by adding zeros to the end of specified axes @@ -1707,7 +1712,7 @@ def pad_image_to_fft( Parameters ---------- - image: Image | np.ndarray + image: Image | np.ndarray | torch.tensor The input image to pad. It should be an instance of the `Image` class or any array-like structure compatible with FFT operations. axes: Iterable[int], optional @@ -1715,7 +1720,7 @@ def pad_image_to_fft( Returns ------- - Image | np.ndarray + Image | np.ndarray | torch.tensor The padded image with dimensions optimized for FFT performance. Raises @@ -1729,26 +1734,30 @@ def pad_image_to_fft( >>> from deeptrack.image import Image, pad_image_to_fft Pad an Image object: - - >>> img = Image(np.zeros((7, 13))) + >>> img = Image(np.ones((7, 13))) >>> padded_img = pad_image_to_fft(img) >>> print(padded_img.shape) (8, 16) Pad a NumPy array: - - >>> img = np.zeros((5, 11))) + >>> img = np.ones((5, 11)) >>> padded_img = pad_image_to_fft(img) >>> print(padded_img.shape) (6, 12) + Pad a PyTorch tensor: + >>> img = torch.ones(7, 11) + >>> padded_img = pad_image_to_fft(img) + >>> print(padded_img.shape) + (8, 12) + """ def _closest( dim: int, ) -> int: - # Returns the smallest value frin _FASTEST_SIZES larger than dim. + # Returns the smallest value from _FASTEST_SIZES that is >= dim. for size in _FASTEST_SIZES: if size >= dim: return size @@ -1763,7 +1772,18 @@ def _closest( new_shape[axis] = _closest(new_shape[axis]) # Calculate the padding for each axis. - pad_width = [(0, increase) for increase in np.array(new_shape) - image.shape] + pad_width = [ + (0, increase) + for increase in np.array(new_shape) - np.array(image.shape) + ] + + # Apply zero-padding with torch.nn.functional.pad if the input is a + # PyTorch tensor + if apc.is_torch_array(image): + pad = [] + for before, after in reversed(pad_width): + pad.extend([before, after]) + return torch.nn.functional.pad(image, pad, mode="constant", value=0) - # Pad the image using constant mode (add zeros). + # Apply zero-padding with np.pad if the input is a NumPy array or an Image return np.pad(image, pad_width, mode="constant") diff --git a/deeptrack/tests/test_image.py b/deeptrack/tests/test_image.py index d413c8da5..5d4901a82 100644 --- a/deeptrack/tests/test_image.py +++ b/deeptrack/tests/test_image.py @@ -12,7 +12,10 @@ import numpy as np -from deeptrack import features, image +from deeptrack import features, image, TORCH_AVAILABLE + +if TORCH_AVAILABLE: + import torch class TestImage(unittest.TestCase): @@ -389,6 +392,7 @@ def test_Image__view(self): def test_pad_image_to_fft(self): + # Test with dt.Image input_image = image.Image(np.zeros((7, 25))) padded_image = image.pad_image_to_fft(input_image) self.assertEqual(padded_image.shape, (8, 27)) @@ -401,6 +405,33 @@ def test_pad_image_to_fft(self): padded_image = image.pad_image_to_fft(input_image) self.assertEqual(padded_image.shape, (324, 432)) + # Test with NumPy array + input_image = np.ones((7, 13)) + padded_image = image.pad_image_to_fft(input_image) + self.assertEqual(padded_image.shape, (8, 16)) + + input_image = np.ones((5,)) + padded_image = image.pad_image_to_fft(input_image, axes=(0,)) + self.assertEqual(padded_image.shape, (6,)) + + ### Test with PyTorch tensor (if available) + if TORCH_AVAILABLE: + input_image = torch.ones(3, 5) + padded_image = image.pad_image_to_fft(input_image) + self.assertEqual(padded_image.shape, (3, 6)) + self.assertIsInstance(padded_image, torch.Tensor) + + input_image = torch.ones(5, 7, 11, 13) + padded_image = image.pad_image_to_fft(input_image, axes=(0, 1, 3)) + padded_image_np = image.pad_image_to_fft( + input_image.numpy(), axes=(0, 1, 3) + ) + self.assertEqual(padded_image.shape, (6, 8, 11, 16)) + self.assertIsInstance(padded_image, torch.Tensor) + np.testing.assert_allclose( + padded_image.numpy(), padded_image_np, atol=1e-6 + ) + if __name__ == "__main__": unittest.main() \ No newline at end of file From 9a9df609d541e9388a0373f479aa60e7e0895540 Mon Sep 17 00:00:00 2001 From: mirjagranfors Date: Fri, 1 Aug 2025 14:53:05 +0200 Subject: [PATCH 2/2] minor change --- deeptrack/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeptrack/image.py b/deeptrack/image.py index 8706a8f9b..a98482102 100644 --- a/deeptrack/image.py +++ b/deeptrack/image.py @@ -1704,7 +1704,7 @@ def pad_image_to_fft( image: Image | NDArray | torch.Tensor, axes: Iterable[int] = (0, 1), ) -> Image | NDArray | torch.Tensor: - """Pads an image to optimize Fast Fourier Transform (FFT) performance. + """Pad an image to optimize Fast Fourier Transform (FFT) performance. This function pads an image by adding zeros to the end of specified axes so that their lengths match the nearest larger size in `_FASTEST_SIZES`.