Skip to content

Commit b558f55

Browse files
fix: make sure last slice is included in iterate slices (#398)
1 parent 3fb5b6c commit b558f55

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

doc/release_notes.rst

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Release Notes
44
Upcoming Version
55
----------------
66

7+
8+
* IMPORTANT BUGFIX: The last slice of constraints was not correctly written to LP files in case the constraint size was not a multiple of the slice size. This is fixed now.
79
* Solution files that following a different naming scheme of variables and constraints using more than on initial letter in the prefix (e.g. `col123`, `row456`) are now supported.
810

911
Version 0.4.3

linopy/common.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def iterate_slices(
519519
return
520520

521521
# number of slices
522-
n_slices = max(size // slice_size, 1)
522+
n_slices = max((size + slice_size - 1) // slice_size, 1)
523523

524524
# leading dimension (the dimension with the largest size)
525525
sizes = {dim: ds.sizes[dim] for dim in slice_dims}
@@ -533,12 +533,12 @@ def iterate_slices(
533533
if size_of_leading_dim < n_slices:
534534
n_slices = size_of_leading_dim
535535

536-
chunk_size = ds.sizes[leading_dim] // n_slices
536+
chunk_size = (ds.sizes[leading_dim] + n_slices - 1) // n_slices
537537

538538
# Iterate over the Cartesian product of slice indices
539539
for i in range(n_slices):
540540
start = i * chunk_size
541-
end = start + chunk_size
541+
end = min(start + chunk_size, size_of_leading_dim)
542542
slice_dict = {leading_dim: slice(start, end)}
543543
yield ds.isel(slice_dict)
544544

test/test_common.py

+30
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,20 @@ def test_iterate_slices_slice_size_none():
522522
assert ds.equals(s)
523523

524524

525+
def test_iterate_slices_includes_last_slice():
526+
ds = xr.Dataset(
527+
{"var": (("x"), np.random.rand(10))}, # noqa: NPY002
528+
coords={"x": np.arange(10)},
529+
)
530+
slices = list(iterate_slices(ds, slice_size=3, slice_dims=["x"]))
531+
assert len(slices) == 4 # 10 slices for dimension 'x' with size 10
532+
total_elements = sum(s.sizes["x"] for s in slices)
533+
assert total_elements == ds.sizes["x"] # Ensure all elements are included
534+
for s in slices:
535+
assert isinstance(s, xr.Dataset)
536+
assert set(s.dims) == set(ds.dims)
537+
538+
525539
def test_iterate_slices_empty_slice_dims():
526540
ds = xr.Dataset(
527541
{"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002
@@ -542,6 +556,22 @@ def test_iterate_slices_invalid_slice_dims():
542556
list(iterate_slices(ds, slice_size=50, slice_dims=["z"]))
543557

544558

559+
def test_iterate_slices_empty_dataset():
560+
ds = xr.Dataset(
561+
{"var": (("x", "y"), np.array([]).reshape(0, 0))}, coords={"x": [], "y": []}
562+
)
563+
slices = list(iterate_slices(ds, slice_size=10, slice_dims=["x"]))
564+
assert len(slices) == 1
565+
assert ds.equals(slices[0])
566+
567+
568+
def test_iterate_slices_single_element():
569+
ds = xr.Dataset({"var": (("x", "y"), np.array([[1]]))}, coords={"x": [0], "y": [0]})
570+
slices = list(iterate_slices(ds, slice_size=1, slice_dims=["x"]))
571+
assert len(slices) == 1
572+
assert ds.equals(slices[0])
573+
574+
545575
def test_get_dims_with_index_levels():
546576
# Create test data
547577

0 commit comments

Comments
 (0)