Skip to content

Commit 36fcb59

Browse files
committed
sum
1 parent 3942864 commit 36fcb59

File tree

3 files changed

+41
-13
lines changed

3 files changed

+41
-13
lines changed

src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -750,7 +750,7 @@ mod tests {
750750

751751
fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> {
752752
let state = place.get::<u64>();
753-
if let Some(ScalarRef::Number(NumberScalar::UInt64(value))) = data.get(0) {
753+
if let Some(ScalarRef::Number(NumberScalar::UInt64(value))) = data.first() {
754754
*state += value;
755755
}
756756
Ok(())
@@ -1031,7 +1031,7 @@ mod tests {
10311031
{
10321032
let values = vec![10u64, 20u64, 30u64, 40u64];
10331033
let flags = vec![true, false, true, true];
1034-
let filter_bits = vec![true, false, true, false]; // Only process indices 0 and 2
1034+
let filter_bits = [true, false, true, false]; // Only process indices 0 and 2
10351035
let filter = Bitmap::from_iter(filter_bits.iter().copied());
10361036

10371037
test_both_strategies(&values, &flags, Some(&filter))?;
@@ -1040,7 +1040,7 @@ mod tests {
10401040
{
10411041
let values = vec![100u64, 200u64, 300u64, 400u64];
10421042
let flags = vec![true, true, false, true]; // Skip index 2
1043-
let filter_bits = vec![true, false, true, true]; // Skip index 1
1043+
let filter_bits = [true, false, true, true]; // Skip index 1
10441044
let filter = Bitmap::from_iter(filter_bits.iter().copied());
10451045

10461046
test_both_strategies(&values, &flags, Some(&filter))?;

src/query/functions/src/aggregates/aggregate_sum.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ use databend_common_expression::AggregateFunctionRef;
3232
use databend_common_expression::BlockEntry;
3333
use databend_common_expression::ColumnBuilder;
3434
use databend_common_expression::Scalar;
35+
use databend_common_expression::ScalarRef;
3536
use databend_common_expression::StateAddr;
37+
use databend_common_expression::StateSerdeItem;
3638
use databend_common_expression::SELECTIVITY_THRESHOLD;
3739
use num_traits::AsPrimitive;
3840

@@ -76,14 +78,14 @@ pub trait SumState: BorshSerialize + BorshDeserialize + Send + Sync + Default +
7678

7779
#[derive(BorshSerialize, BorshDeserialize)]
7880
pub struct NumberSumState<N>
79-
where N: ValueType
81+
where N: ArgType
8082
{
8183
pub value: N::Scalar,
8284
}
8385

8486
impl<N> Default for NumberSumState<N>
8587
where
86-
N: ValueType,
88+
N: ArgType,
8789
N::Scalar: Number + AsPrimitive<f64> + BorshSerialize + BorshDeserialize + std::ops::AddAssign,
8890
{
8991
fn default() -> Self {
@@ -130,7 +132,7 @@ where
130132
impl<T, N> UnaryState<T, N> for NumberSumState<N>
131133
where
132134
T: ArgType + Sync + Send,
133-
N: ValueType,
135+
N: ArgType,
134136
T::Scalar: Number + AsPrimitive<N::Scalar>,
135137
N::Scalar: Number + AsPrimitive<f64> + BorshSerialize + BorshDeserialize + std::ops::AddAssign,
136138
for<'a> T::ScalarRef<'a>: Number + AsPrimitive<N::Scalar>,
@@ -169,6 +171,20 @@ where
169171
builder.push_item(N::to_scalar_ref(&self.value));
170172
Ok(())
171173
}
174+
175+
fn serialize_type() -> Vec<StateSerdeItem> {
176+
std::vec![StateSerdeItem::DataType(T::data_type())]
177+
}
178+
179+
fn serialize_state(&self, builders: &mut [ColumnBuilder]) -> Result<()> {
180+
N::downcast_builder(&mut builders[0]).push_item(N::to_scalar_ref(&self.value));
181+
Ok(())
182+
}
183+
184+
fn deserialize_state(data: &[ScalarRef]) -> Result<Self> {
185+
let value = N::to_owned_scalar(N::try_downcast_scalar(&data[0]).unwrap());
186+
Ok(Self { value })
187+
}
172188
}
173189

174190
#[derive(BorshDeserialize, BorshSerialize)]

src/query/functions/src/aggregates/aggregate_unary.rs

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,22 @@ where
8080
builder: R::ColumnBuilderMut<'_>,
8181
function_data: Option<&dyn FunctionData>,
8282
) -> Result<()>;
83+
84+
fn serialize_type() -> Vec<StateSerdeItem> {
85+
vec![StateSerdeItem::Binary(None)]
86+
}
87+
88+
fn serialize_state(&self, builders: &mut [ColumnBuilder]) -> Result<()> {
89+
let binary_builder = builders[0].as_binary_mut().unwrap();
90+
self.serialize(&mut binary_builder.data)?;
91+
binary_builder.commit_row();
92+
Ok(())
93+
}
94+
95+
fn deserialize_state(data: &[ScalarRef]) -> Result<Self> {
96+
let mut binary = *data[0].as_binary().unwrap();
97+
Ok(Self::deserialize_reader(&mut binary)?)
98+
}
8399
}
84100

85101
pub trait FunctionData: Send + Sync {
@@ -230,21 +246,17 @@ where
230246
}
231247

232248
fn serialize_type(&self) -> Vec<StateSerdeItem> {
233-
vec![StateSerdeItem::Binary(None)]
249+
S::serialize_type()
234250
}
235251

236252
fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> {
237-
let binary_builder = builders[0].as_binary_mut().unwrap();
238253
let state: &mut S = place.get::<S>();
239-
state.serialize(&mut binary_builder.data)?;
240-
binary_builder.commit_row();
241-
Ok(())
254+
state.serialize_state(builders)
242255
}
243256

244257
fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> {
245-
let mut binary = *data[0].as_binary().unwrap();
246258
let state: &mut S = place.get::<S>();
247-
let rhs = S::deserialize_reader(&mut binary)?;
259+
let rhs = S::deserialize_state(data)?;
248260
state.merge(&rhs)
249261
}
250262

0 commit comments

Comments
 (0)