Skip to content

Commit 84b7f1f

Browse files
committed
Store array namespace when Combiner is created
1 parent 606af8d commit 84b7f1f

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

ccdproc/combiner.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,12 @@ class Combiner:
109109
description. If ``None`` it uses ``np.float64``.
110110
Default is ``None``.
111111
112+
xp : array namespace, optional
113+
The array namespace to use for the data. If not provided, it will
114+
be inferred from the first `~astropy.nddata.CCDData` object in
115+
``ccd_iter``.
116+
Default is ``None``.
117+
112118
Raises
113119
------
114120
TypeError
@@ -136,7 +142,7 @@ class Combiner:
136142
[ 0.66666667, 0.66666667, 0.66666667, 0.66666667]]...)
137143
"""
138144

139-
def __init__(self, ccd_iter, dtype=None):
145+
def __init__(self, ccd_iter, dtype=None, xp=None):
140146
if ccd_iter is None:
141147
raise TypeError(
142148
"ccd_iter should be a list or a generator of CCDData objects."
@@ -167,7 +173,8 @@ def __init__(self, ccd_iter, dtype=None):
167173
raise TypeError("CCDData objects don't have the same unit.")
168174

169175
# Set array namespace
170-
xp = array_api_compat.array_namespace(ccd_list[0].data)
176+
xp = xp or array_api_compat.array_namespace(ccd_list[0].data)
177+
self._xp = xp
171178
if dtype is None:
172179
dtype = xp.float64
173180
self.ccd_list = ccd_list
@@ -247,7 +254,7 @@ def scaling(self):
247254

248255
@scaling.setter
249256
def scaling(self, value):
250-
xp = array_api_compat.array_namespace(self.data_arr)
257+
xp = self._xp
251258
if value is None:
252259
self._scaling = value
253260
else:
@@ -316,7 +323,7 @@ def clip_extrema(self, nlow=0, nhigh=0):
316323
.. [0] image.imcombine help text.
317324
http://stsdas.stsci.edu/cgi-bin/gethelp.cgi?imcombine
318325
"""
319-
xp = array_api_compat.array_namespace(self.data_arr)
326+
xp = self._xp
320327
if nlow is None:
321328
nlow = 0
322329
if nhigh is None:
@@ -447,7 +454,7 @@ def _get_scaled_data(self, scale_arg):
447454
return self.data_arr
448455

449456
def _get_nan_substituted_data(self, data):
450-
xp = array_api_compat.array_namespace(self.data_arr)
457+
xp = self._xp
451458

452459
# Get the data as an unmasked array with masked values filled as NaN
453460
if self.data_arr_mask.any():
@@ -462,7 +469,7 @@ def _combination_setup(self, user_func, default_func, scale_to):
462469
Handle the common pieces of image combination data/mask setup.
463470
"""
464471
data = self._get_scaled_data(scale_to)
465-
xp = array_api_compat.array_namespace(data)
472+
xp = self._xp
466473
# Play it safe for now and only do the nan thing if the user is using
467474
# the default combination function.
468475
if user_func is None:
@@ -515,7 +522,7 @@ def median_combine(
515522
The uncertainty currently calculated using the median absolute
516523
deviation does not account for rejected pixels.
517524
"""
518-
xp = array_api_compat.array_namespace(self.data_arr)
525+
xp = self._xp
519526

520527
_default_median_func = _default_median(xp=xp)
521528

@@ -565,12 +572,12 @@ def median_combine(
565572
# return the combined image
566573
return combined_image
567574

568-
def _weighted_sum(self, data, sum_func):
575+
def _weighted_sum(self, data, sum_func, xp=None):
569576
"""
570577
Perform weighted sum, used by both ``sum_combine`` and in some cases
571578
by ``average_combine``.
572579
"""
573-
xp = array_api_compat.array_namespace(data)
580+
xp = xp or array_api_compat.array_namespace(data)
574581
if self.weights.shape != data.shape:
575582
# Add extra axes to the weights for broadcasting
576583
weights = xp.reshape(self.weights, [len(self.weights), 1, 1])
@@ -624,7 +631,7 @@ def average_combine(
624631
combined_image: `~astropy.nddata.CCDData`
625632
CCDData object based on the combined input of CCDData objects.
626633
"""
627-
xp = array_api_compat.array_namespace(self.data_arr)
634+
xp = self._xp
628635

629636
_default_average_func = _default_average(xp=xp)
630637

@@ -641,7 +648,7 @@ def average_combine(
641648
# Do NOT modify data after this -- we need it to be intact when we
642649
# we get to the uncertainty calculation.
643650
if self.weights is not None:
644-
weighted_sum, weights = self._weighted_sum(data, sum_func)
651+
weighted_sum, weights = self._weighted_sum(data, sum_func, xp=xp)
645652
mean = weighted_sum / sum_func(weights, axis=0)
646653
else:
647654
mean = scale_func(data, axis=0)
@@ -707,7 +714,7 @@ def sum_combine(
707714
CCDData object based on the combined input of CCDData objects.
708715
"""
709716

710-
xp = array_api_compat.array_namespace(self.data_arr)
717+
xp = self._xp
711718

712719
_default_sum_func = _default_sum(xp=xp)
713720

@@ -719,7 +726,7 @@ def sum_combine(
719726
)
720727

721728
if self.weights is not None:
722-
summed, weights = self._weighted_sum(data, sum_func)
729+
summed, weights = self._weighted_sum(data, sum_func, xp=xp)
723730
else:
724731
summed = sum_func(data, axis=0)
725732

0 commit comments

Comments
 (0)