@@ -109,6 +109,12 @@ class Combiner:
109
109
description. If ``None`` it uses ``np.float64``.
110
110
Default is ``None``.
111
111
112
+ xp : array namespace, optional
113
+ The array namespace to use for the data. If not provided, it will
114
+ be inferred from the first `~astropy.nddata.CCDData` object in
115
+ ``ccd_iter``.
116
+ Default is ``None``.
117
+
112
118
Raises
113
119
------
114
120
TypeError
@@ -136,7 +142,7 @@ class Combiner:
136
142
[ 0.66666667, 0.66666667, 0.66666667, 0.66666667]]...)
137
143
"""
138
144
139
- def __init__ (self , ccd_iter , dtype = None ):
145
+ def __init__ (self , ccd_iter , dtype = None , xp = None ):
140
146
if ccd_iter is None :
141
147
raise TypeError (
142
148
"ccd_iter should be a list or a generator of CCDData objects."
@@ -167,7 +173,8 @@ def __init__(self, ccd_iter, dtype=None):
167
173
raise TypeError ("CCDData objects don't have the same unit." )
168
174
169
175
# Set array namespace
170
- xp = array_api_compat .array_namespace (ccd_list [0 ].data )
176
+ xp = xp or array_api_compat .array_namespace (ccd_list [0 ].data )
177
+ self ._xp = xp
171
178
if dtype is None :
172
179
dtype = xp .float64
173
180
self .ccd_list = ccd_list
@@ -247,7 +254,7 @@ def scaling(self):
247
254
248
255
@scaling .setter
249
256
def scaling (self , value ):
250
- xp = array_api_compat . array_namespace ( self .data_arr )
257
+ xp = self ._xp
251
258
if value is None :
252
259
self ._scaling = value
253
260
else :
@@ -316,7 +323,7 @@ def clip_extrema(self, nlow=0, nhigh=0):
316
323
.. [0] image.imcombine help text.
317
324
http://stsdas.stsci.edu/cgi-bin/gethelp.cgi?imcombine
318
325
"""
319
- xp = array_api_compat . array_namespace ( self .data_arr )
326
+ xp = self ._xp
320
327
if nlow is None :
321
328
nlow = 0
322
329
if nhigh is None :
@@ -447,7 +454,7 @@ def _get_scaled_data(self, scale_arg):
447
454
return self .data_arr
448
455
449
456
def _get_nan_substituted_data (self , data ):
450
- xp = array_api_compat . array_namespace ( self .data_arr )
457
+ xp = self ._xp
451
458
452
459
# Get the data as an unmasked array with masked values filled as NaN
453
460
if self .data_arr_mask .any ():
@@ -462,7 +469,7 @@ def _combination_setup(self, user_func, default_func, scale_to):
462
469
Handle the common pieces of image combination data/mask setup.
463
470
"""
464
471
data = self ._get_scaled_data (scale_to )
465
- xp = array_api_compat . array_namespace ( data )
472
+ xp = self . _xp
466
473
# Play it safe for now and only do the nan thing if the user is using
467
474
# the default combination function.
468
475
if user_func is None :
@@ -515,7 +522,7 @@ def median_combine(
515
522
The uncertainty currently calculated using the median absolute
516
523
deviation does not account for rejected pixels.
517
524
"""
518
- xp = array_api_compat . array_namespace ( self .data_arr )
525
+ xp = self ._xp
519
526
520
527
_default_median_func = _default_median (xp = xp )
521
528
@@ -565,12 +572,12 @@ def median_combine(
565
572
# return the combined image
566
573
return combined_image
567
574
568
- def _weighted_sum (self , data , sum_func ):
575
+ def _weighted_sum (self , data , sum_func , xp = None ):
569
576
"""
570
577
Perform weighted sum, used by both ``sum_combine`` and in some cases
571
578
by ``average_combine``.
572
579
"""
573
- xp = array_api_compat .array_namespace (data )
580
+ xp = xp or array_api_compat .array_namespace (data )
574
581
if self .weights .shape != data .shape :
575
582
# Add extra axes to the weights for broadcasting
576
583
weights = xp .reshape (self .weights , [len (self .weights ), 1 , 1 ])
@@ -624,7 +631,7 @@ def average_combine(
624
631
combined_image: `~astropy.nddata.CCDData`
625
632
CCDData object based on the combined input of CCDData objects.
626
633
"""
627
- xp = array_api_compat . array_namespace ( self .data_arr )
634
+ xp = self ._xp
628
635
629
636
_default_average_func = _default_average (xp = xp )
630
637
@@ -641,7 +648,7 @@ def average_combine(
641
648
# Do NOT modify data after this -- we need it to be intact when we
642
649
# we get to the uncertainty calculation.
643
650
if self .weights is not None :
644
- weighted_sum , weights = self ._weighted_sum (data , sum_func )
651
+ weighted_sum , weights = self ._weighted_sum (data , sum_func , xp = xp )
645
652
mean = weighted_sum / sum_func (weights , axis = 0 )
646
653
else :
647
654
mean = scale_func (data , axis = 0 )
@@ -707,7 +714,7 @@ def sum_combine(
707
714
CCDData object based on the combined input of CCDData objects.
708
715
"""
709
716
710
- xp = array_api_compat . array_namespace ( self .data_arr )
717
+ xp = self ._xp
711
718
712
719
_default_sum_func = _default_sum (xp = xp )
713
720
@@ -719,7 +726,7 @@ def sum_combine(
719
726
)
720
727
721
728
if self .weights is not None :
722
- summed , weights = self ._weighted_sum (data , sum_func )
729
+ summed , weights = self ._weighted_sum (data , sum_func , xp = xp )
723
730
else :
724
731
summed = sum_func (data , axis = 0 )
725
732
0 commit comments