Skip to content

Commit 5ff1516

Browse files
niksirbiadamltyson
andcommitted
Add log_to_attrs decorator to scale function (#604)
* added log_to_attrs decorator to scale function and documented its effects * adapted tests for the scale function * fixed mistakes in docstring formatting * use monospace consistently in docstring * updated docstring to follow new log style * adapt tests for the serialized attrs.log * more specific name for test helper function --------- Co-authored-by: Adam Tyson <[email protected]>
1 parent c24e78c commit 5ff1516

File tree

2 files changed

+88
-7
lines changed

2 files changed

+88
-7
lines changed

movement/transforms.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import xarray as xr
55
from numpy.typing import ArrayLike
66

7+
from movement.utils.logging import log_to_attrs
78
from movement.validators.arrays import validate_dims_coords
89

910

11+
@log_to_attrs
1012
def scale(
1113
data: xr.DataArray,
1214
factor: ArrayLike | float = 1.0,
@@ -27,8 +29,8 @@ def scale(
2729
broadcasted.
2830
space_unit : str or None
2931
The unit of the scaled data stored as a property in
30-
xarray.DataArray.attrs['space_unit']. In case of the default (``None``)
31-
the ``space_unit`` attribute is dropped.
32+
``xarray.DataArray.attrs['space_unit']``. In case of the default
33+
(``None``) the ``space_unit`` attribute is dropped.
3234
3335
Returns
3436
-------
@@ -37,9 +39,50 @@ def scale(
3739
3840
Notes
3941
-----
40-
When scale is used multiple times on the same xarray.DataArray,
41-
xarray.DataArray.attrs["space_unit"] is overwritten each time or is dropped
42-
if ``None`` is passed by default or explicitly.
42+
This function makes two changes to the resulting data array's attributes
43+
(``xarray.DataArray.attrs``) each time it is called:
44+
45+
- It sets the ``space_unit`` attribute to the value of the parameter
46+
with the same name, or removes it if ``space_unit=None``.
47+
- It adds a new entry to the ``log`` attribute of the data array, which
48+
contains a record of the operations performed, including the
49+
parameters used, as well as the datetime of the function call.
50+
51+
Examples
52+
--------
53+
Let's imagine a camera viewing a 2D plane from the top, with an
54+
estimated resolution of 10 pixels per cm. We can scale down
55+
position data by a factor of 1/10 to express it in cm units.
56+
57+
>>> from movement.transforms import scale
58+
>>> ds["position"] = scale(ds["position"], factor=1 / 10, space_unit="cm")
59+
>>> print(ds["position"].space_unit)
60+
cm
61+
>>> print(ds["position"].log)
62+
[
63+
{
64+
"operation": "scale",
65+
"datetime": "2025-06-05 15:08:16.919947",
66+
"factor": "0.1",
67+
"space_unit": "'cm'"
68+
}
69+
]
70+
71+
Note that the attributes of the scaled data array now contain the assigned
72+
``space_unit`` as well as a ``log`` entry with the arguments passed to
73+
the function.
74+
75+
We can also scale the two spatial dimensions by different factors.
76+
77+
>>> ds["position"] = scale(ds["position"], factor=[10, 20])
78+
79+
The second scale operation restored the x axis to its original scale,
80+
and scaled up the y axis to twice its original size.
81+
The log will now contain two entries, but the ``space_unit`` attribute
82+
has been removed, as it was not provided in the second function call.
83+
84+
>>> "space_unit" in ds["position"].attrs
85+
False
4386
4487
"""
4588
if len(data.coords["space"]) == 2:

tests/test_unit/test_transforms.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
from typing import Any
23

34
import numpy as np
@@ -32,6 +33,16 @@ def data_array_with_dims_and_coords(
3233
)
3334

3435

36+
def drop_attrs_log(attrs: dict) -> dict:
37+
"""Drop the log string from attrs to faclitate testing.
38+
The log string will never exactly match, because datetimes differ.
39+
"""
40+
attrs_copy = attrs.copy()
41+
if "log" in attrs:
42+
attrs_copy.pop("log", None)
43+
return attrs_copy
44+
45+
3546
@pytest.fixture
3647
def sample_data_2d() -> xr.DataArray:
3748
"""Turn the nparray_0_to_23 into a DataArray."""
@@ -103,7 +114,7 @@ def test_scale(
103114
"""Test scaling with different factors and space_units."""
104115
scaled_data = scale(sample_data_2d, **optional_arguments)
105116
xr.testing.assert_equal(scaled_data, expected_output)
106-
assert scaled_data.attrs == expected_output.attrs
117+
assert drop_attrs_log(scaled_data.attrs) == expected_output.attrs
107118

108119

109120
@pytest.mark.parametrize(
@@ -180,7 +191,7 @@ def test_scale_twice(
180191
**optional_arguments_2,
181192
)
182193
xr.testing.assert_equal(output_data_array, expected_output)
183-
assert output_data_array.attrs == expected_output.attrs
194+
assert drop_attrs_log(output_data_array.attrs) == expected_output.attrs
184195

185196

186197
@pytest.mark.parametrize(
@@ -241,3 +252,30 @@ def test_scale_invalid_3d_space(factor):
241252
assert str(error.value) == (
242253
"Input data must contain ['z'] in the 'space' coordinates.\n"
243254
)
255+
256+
257+
def test_scale_log(sample_data_2d: xr.DataArray):
258+
"""Test that the log attribute is correctly populated
259+
in the scaled data array.
260+
"""
261+
262+
def verify_log_entry(entry, expected_factor, expected_space_unit):
263+
"""Verify each scale log entry."""
264+
assert entry["factor"] == expected_factor
265+
assert entry["space_unit"] == expected_space_unit
266+
assert entry["operation"] == "scale"
267+
assert "datetime" in entry
268+
269+
# scale data twice
270+
scaled_data = scale(
271+
scale(sample_data_2d, factor=2, space_unit="elephants"),
272+
factor=[1, 2],
273+
space_unit="crabs",
274+
)
275+
276+
# verify the log attribute
277+
assert "log" in scaled_data.attrs
278+
log_entries = json.loads(scaled_data.attrs["log"])
279+
assert len(log_entries) == 2
280+
verify_log_entry(log_entries[0], "2", "'elephants'")
281+
verify_log_entry(log_entries[1], "[1, 2]", "'crabs'")

0 commit comments

Comments
 (0)