From 028584eaa94e4b787f81d804df47094e5d72fe4c Mon Sep 17 00:00:00 2001 From: coldWater Date: Mon, 21 Jul 2025 18:15:38 +0800 Subject: [PATCH 01/22] trait --- .../src/aggregate/aggregate_function.rs | 37 +++++++++++++----- .../src/aggregate/aggregate_function_state.rs | 38 +++++++++++++++---- .../src/aggregate/aggregate_hashtable.rs | 3 +- .../expression/src/aggregate/payload_flush.rs | 9 +---- .../adaptors/aggregate_null_adaptor.rs | 17 +++++---- .../adaptors/aggregate_ornull_adaptor.rs | 9 +++-- .../adaptors/aggregate_sort_adaptor.rs | 4 +- .../src/aggregates/aggregate_arg_min_max.rs | 4 +- .../src/aggregates/aggregate_array_agg.rs | 4 +- .../src/aggregates/aggregate_array_moving.rs | 8 ++-- .../src/aggregates/aggregate_bitmap.rs | 12 +++--- .../aggregate_combinator_distinct.rs | 4 +- .../src/aggregates/aggregate_combinator_if.rs | 8 ++-- .../aggregates/aggregate_combinator_state.rs | 10 ++--- .../src/aggregates/aggregate_count.rs | 4 +- .../src/aggregates/aggregate_covariance.rs | 4 +- .../aggregates/aggregate_json_array_agg.rs | 4 +- .../aggregates/aggregate_json_object_agg.rs | 4 +- .../src/aggregates/aggregate_markov_tarin.rs | 4 +- .../src/aggregates/aggregate_null_result.rs | 4 +- .../aggregates/aggregate_quantile_tdigest.rs | 4 +- .../aggregate_quantile_tdigest_weighted.rs | 4 +- .../src/aggregates/aggregate_retention.rs | 4 +- .../src/aggregates/aggregate_st_collect.rs | 4 +- .../src/aggregates/aggregate_string_agg.rs | 4 +- .../src/aggregates/aggregate_unary.rs | 4 +- .../src/aggregates/aggregate_window_funnel.rs | 4 +- .../src/aggregates/aggregator_common.rs | 4 +- .../aggregator/transform_single_key.rs | 9 +---- .../transforms/aggregator/udaf_script.rs | 4 +- 30 files changed, 134 insertions(+), 102 deletions(-) diff --git a/src/query/expression/src/aggregate/aggregate_function.rs b/src/query/expression/src/aggregate/aggregate_function.rs index fe446c3ed1cf4..d937fe61905e6 100755 --- a/src/query/expression/src/aggregate/aggregate_function.rs +++ b/src/query/expression/src/aggregate/aggregate_function.rs @@ -22,13 +22,13 @@ use super::AggrState; use super::AggrStateLoc; use super::AggrStateRegistry; use super::StateAddr; -use crate::types::BinaryType; use crate::types::DataType; +use crate::AggrStateSerdeType; use crate::BlockEntry; use crate::ColumnBuilder; -use crate::ColumnView; use crate::ProjectedBlock; use crate::Scalar; +use crate::ScalarRef; pub type AggregateFunctionRef = Arc; @@ -69,32 +69,49 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { // Used in aggregate_null_adaptor fn accumulate_row(&self, place: AggrState, columns: ProjectedBlock, row: usize) -> Result<()>; - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()>; + fn serialize_type(&self) -> Vec { + vec![AggrStateSerdeType::Binary(self.serialize_size_per_row())] + } + + fn serialize(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { + let binary_builder = builder.as_tuple_mut().unwrap()[0].as_binary_mut().unwrap(); + self.serialize_binary(place, &mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) + } + + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()>; fn serialize_size_per_row(&self) -> Option { None } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()>; + fn merge(&self, place: AggrState, data: ScalarRef) -> Result<()> { + let mut binary = *data.as_tuple().unwrap()[0].as_binary().unwrap(); + self.merge_binary(place, &mut binary) + } + + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()>; /// Batch merge and deserialize the state from binary array fn batch_merge( &self, places: &[StateAddr], loc: &[AggrStateLoc], - state: &ColumnView, + state: &BlockEntry, ) -> Result<()> { - for (place, mut data) in places.iter().zip(state.iter()) { - self.merge(AggrState::new(*place, loc), &mut data)?; + let column = state.to_column(); + for (place, data) in places.iter().zip(column.iter()) { + self.merge(AggrState::new(*place, loc), data)?; } Ok(()) } fn batch_merge_single(&self, place: AggrState, state: &BlockEntry) -> Result<()> { - let view = state.downcast::().unwrap(); - for mut data in view.iter() { - self.merge(place, &mut data)?; + let column = state.to_column(); + for data in column.iter() { + self.merge(place, data)?; } Ok(()) } diff --git a/src/query/expression/src/aggregate/aggregate_function_state.rs b/src/query/expression/src/aggregate/aggregate_function_state.rs index 5d1fe21ac4379..409278532dd06 100644 --- a/src/query/expression/src/aggregate/aggregate_function_state.rs +++ b/src/query/expression/src/aggregate/aggregate_function_state.rs @@ -20,6 +20,8 @@ use enum_as_inner::EnumAsInner; use super::AggregateFunctionRef; use crate::types::binary::BinaryColumnBuilder; +use crate::types::DataType; +use crate::ColumnBuilder; #[derive(Clone, Copy, Debug)] pub struct StateAddr { @@ -113,11 +115,11 @@ impl From for usize { pub fn get_states_layout(funcs: &[AggregateFunctionRef]) -> Result { let mut registry = AggrStateRegistry::default(); - let mut serialize_size = Vec::with_capacity(funcs.len()); + let mut serialize_type = Vec::with_capacity(funcs.len()); for func in funcs { func.register_state(&mut registry); registry.commit(); - serialize_size.push(func.serialize_size_per_row()); + serialize_type.push(func.serialize_type().into_boxed_slice()); } let AggrStateRegistry { states, offsets } = registry; @@ -132,7 +134,7 @@ pub fn get_states_layout(funcs: &[AggregateFunctionRef]) -> Result Ok(StatesLayout { layout, states_loc, - serialize_size, + serialize_type, }) } @@ -195,14 +197,30 @@ impl AggrStateLoc { pub struct StatesLayout { pub layout: Layout, pub states_loc: Vec>, - serialize_size: Vec>, + serialize_type: Vec>, } impl StatesLayout { - pub fn serialize_builders(&self, num_rows: usize) -> Vec { - self.serialize_size + pub fn serialize_builders(&self, num_rows: usize) -> Vec { + self.serialize_type .iter() - .map(|size| BinaryColumnBuilder::with_capacity(num_rows, num_rows * size.unwrap_or(0))) + .map(|item| { + let builder = item + .iter() + .map(|serde_type| match serde_type { + AggrStateSerdeType::Bool => { + ColumnBuilder::with_capacity(&DataType::Boolean, num_rows) + } + AggrStateSerdeType::Binary(size) => { + ColumnBuilder::Binary(BinaryColumnBuilder::with_capacity( + num_rows, + num_rows * size.unwrap_or(0), + )) + } + }) + .collect(); + ColumnBuilder::Tuple(builder) + }) .collect() } } @@ -288,6 +306,12 @@ pub enum AggrStateType { Custom(Layout), } +#[derive(Debug, Clone, Copy)] +pub enum AggrStateSerdeType { + Bool, + Binary(Option), +} + #[cfg(test)] mod tests { use proptest::prelude::*; diff --git a/src/query/expression/src/aggregate/aggregate_hashtable.rs b/src/query/expression/src/aggregate/aggregate_hashtable.rs index 11907059ef90e..a605227b7f5ce 100644 --- a/src/query/expression/src/aggregate/aggregate_hashtable.rs +++ b/src/query/expression/src/aggregate/aggregate_hashtable.rs @@ -27,7 +27,6 @@ use crate::aggregate::payload_row::row_match_columns; use crate::group_hash_columns; use crate::new_sel; use crate::read; -use crate::types::BinaryType; use crate::types::DataType; use crate::AggregateFunctionRef; use crate::BlockEntry; @@ -219,7 +218,7 @@ impl AggregateHashTable { .zip(agg_states.iter()) .zip(states_layout.states_loc.iter()) { - func.batch_merge(state_places, loc, &state.downcast::().unwrap())?; + func.batch_merge(state_places, loc, state)?; } } } diff --git a/src/query/expression/src/aggregate/payload_flush.rs b/src/query/expression/src/aggregate/payload_flush.rs index 43dabd3b519c2..e0d15864a3a8e 100644 --- a/src/query/expression/src/aggregate/payload_flush.rs +++ b/src/query/expression/src/aggregate/payload_flush.rs @@ -150,17 +150,12 @@ impl Payload { { { let builder = &mut builders[idx]; - func.serialize(AggrState::new(*place, loc), &mut builder.data)?; - builder.commit_row(); + func.serialize(AggrState::new(*place, loc), builder)?; } } } - entries.extend( - builders - .into_iter() - .map(|builder| Column::Binary(builder.build()).into()), - ); + entries.extend(builders.into_iter().map(|builder| builder.build().into())); } entries.extend_from_slice(&state.take_group_columns()); diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs index 18af505c1241a..b100564223d3a 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs @@ -183,11 +183,11 @@ impl AggregateFunction for AggregateNullUnaryAdapto .accumulate_row(place, not_null_columns, validity, row) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { self.0.serialize(place, writer) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { self.0.merge(place, reader) } @@ -308,11 +308,11 @@ impl AggregateFunction .accumulate_row(place, not_null_columns, validity, row) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { self.0.serialize(place, writer) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { self.0.merge(place, reader) } @@ -498,17 +498,18 @@ impl CommonNullAdaptor { fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { if !NULLABLE_RESULT { - return self.nested.serialize(place, writer); + return self.nested.serialize_binary(place, writer); } - self.nested.serialize(place.remove_last_loc(), writer)?; + self.nested + .serialize_binary(place.remove_last_loc(), writer)?; let flag = get_flag(place); writer.write_scalar(&flag) } fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { if !NULLABLE_RESULT { - return self.nested.merge(place, reader); + return self.nested.merge_binary(place, reader); } let flag = reader[reader.len() - 1]; @@ -522,7 +523,7 @@ impl CommonNullAdaptor { } set_flag(place, true); self.nested - .merge(place.remove_last_loc(), &mut &reader[..reader.len() - 1]) + .merge_binary(place.remove_last_loc(), &mut &reader[..reader.len() - 1]) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs index c7f9d4c4677d4..3fd0cbcd70ccc 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs @@ -178,17 +178,18 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { } #[inline] - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - self.inner.serialize(place.remove_last_loc(), writer)?; + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + self.inner + .serialize_binary(place.remove_last_loc(), writer)?; let flag = get_flag(place) as u8; writer.write_scalar(&flag) } #[inline] - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let flag = get_flag(place) || reader[reader.len() - 1] > 0; self.inner - .merge(place.remove_last_loc(), &mut &reader[..reader.len() - 1])?; + .merge_binary(place.remove_last_loc(), &mut &reader[..reader.len() - 1])?; set_flag(place, flag); Ok(()) } diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs index 4894fc24a9c84..20b11214d15fc 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs @@ -121,12 +121,12 @@ impl AggregateFunction for AggregateFunctionSortAdaptor { Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = Self::get_state(place); Ok(state.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = Self::get_state(place); let rhs = SortAggState::deserialize(reader)?; diff --git a/src/query/functions/src/aggregates/aggregate_arg_min_max.rs b/src/query/functions/src/aggregates/aggregate_arg_min_max.rs index 98506ab12b8b5..5311335b6b229 100644 --- a/src/query/functions/src/aggregates/aggregate_arg_min_max.rs +++ b/src/query/functions/src/aggregates/aggregate_arg_min_max.rs @@ -270,12 +270,12 @@ where Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: State = borsh_partial_deserialize(reader)?; state.merge_from(rhs) diff --git a/src/query/functions/src/aggregates/aggregate_array_agg.rs b/src/query/functions/src/aggregates/aggregate_array_agg.rs index 7083b886996df..ddeba25f9c86f 100644 --- a/src/query/functions/src/aggregates/aggregate_array_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_array_agg.rs @@ -548,12 +548,12 @@ where Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs = State::deserialize_reader(reader)?; diff --git a/src/query/functions/src/aggregates/aggregate_array_moving.rs b/src/query/functions/src/aggregates/aggregate_array_moving.rs index 6ef407c1402e2..0070924a974df 100644 --- a/src/query/functions/src/aggregates/aggregate_array_moving.rs +++ b/src/query/functions/src/aggregates/aggregate_array_moving.rs @@ -447,12 +447,12 @@ where State: SumState state.accumulate_row(&columns[0], row) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs = State::deserialize_reader(reader)?; @@ -616,12 +616,12 @@ where State: SumState state.accumulate_row(&columns[0], row) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs = State::deserialize_reader(reader)?; diff --git a/src/query/functions/src/aggregates/aggregate_bitmap.rs b/src/query/functions/src/aggregates/aggregate_bitmap.rs index cce9492c709d8..3f56ff133296b 100644 --- a/src/query/functions/src/aggregates/aggregate_bitmap.rs +++ b/src/query/functions/src/aggregates/aggregate_bitmap.rs @@ -287,7 +287,7 @@ where Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); // flag indicate where bitmap is none let flag: u8 = if state.rb.is_some() { 1 } else { 0 }; @@ -298,7 +298,7 @@ where Ok(()) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let flag = reader[0]; @@ -482,12 +482,12 @@ where Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - self.inner.serialize(place, writer) + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + self.inner.serialize_binary(place, writer) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - self.inner.merge(place, reader) + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + self.inner.merge_binary(place, reader) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs b/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs index 77a11ec9b44f8..d74b9a32f661a 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs @@ -106,12 +106,12 @@ where State: DistinctStateFunc state.add(columns, row) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = Self::get_state(place); state.serialize(writer) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = Self::get_state(place); let rhs = State::deserialize(reader)?; diff --git a/src/query/functions/src/aggregates/aggregate_combinator_if.rs b/src/query/functions/src/aggregates/aggregate_combinator_if.rs index ea83b8028d23e..03bbdf625e4df 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_if.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_if.rs @@ -155,12 +155,12 @@ impl AggregateFunction for AggregateIfCombinator { Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - self.nested.serialize(place, writer) + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + self.nested.serialize_binary(place, writer) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - self.nested.merge(place, reader) + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + self.nested.merge_binary(place, reader) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_combinator_state.rs b/src/query/functions/src/aggregates/aggregate_combinator_state.rs index 486707ec7af61..e83d78f24164c 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_state.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_state.rs @@ -106,13 +106,13 @@ impl AggregateFunction for AggregateStateCombinator { self.nested.accumulate_row(place, columns, row) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - self.nested.serialize(place, writer) + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + self.nested.serialize_binary(place, writer) } #[inline] - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - self.nested.merge(place, reader) + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + self.nested.merge_binary(place, reader) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { @@ -121,7 +121,7 @@ impl AggregateFunction for AggregateStateCombinator { fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let builder = builder.as_binary_mut().unwrap(); - self.nested.serialize(place, &mut builder.data)?; + self.nested.serialize_binary(place, &mut builder.data)?; builder.commit_row(); Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_count.rs b/src/query/functions/src/aggregates/aggregate_count.rs index b822b832896ee..9e1db7e50a989 100644 --- a/src/query/functions/src/aggregates/aggregate_count.rs +++ b/src/query/functions/src/aggregates/aggregate_count.rs @@ -160,12 +160,12 @@ impl AggregateFunction for AggregateCountFunction { Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.count.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let other: u64 = borsh_partial_deserialize(reader)?; state.count += other; diff --git a/src/query/functions/src/aggregates/aggregate_covariance.rs b/src/query/functions/src/aggregates/aggregate_covariance.rs index 8af1216f56567..c69d242ea811a 100644 --- a/src/query/functions/src/aggregates/aggregate_covariance.rs +++ b/src/query/functions/src/aggregates/aggregate_covariance.rs @@ -231,12 +231,12 @@ where Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: AggregateCovarianceState = borsh_partial_deserialize(reader)?; state.merge(&rhs); diff --git a/src/query/functions/src/aggregates/aggregate_json_array_agg.rs b/src/query/functions/src/aggregates/aggregate_json_array_agg.rs index f1ba8d35807fa..acd867e4f2a8e 100644 --- a/src/query/functions/src/aggregates/aggregate_json_array_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_json_array_agg.rs @@ -240,12 +240,12 @@ where Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: State = borsh_partial_deserialize(reader)?; diff --git a/src/query/functions/src/aggregates/aggregate_json_object_agg.rs b/src/query/functions/src/aggregates/aggregate_json_object_agg.rs index 9efe5fc0f6b90..45179f529cee2 100644 --- a/src/query/functions/src/aggregates/aggregate_json_object_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_json_object_agg.rs @@ -293,12 +293,12 @@ where Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: State = borsh_partial_deserialize(reader)?; diff --git a/src/query/functions/src/aggregates/aggregate_markov_tarin.rs b/src/query/functions/src/aggregates/aggregate_markov_tarin.rs index 0524caa523e39..285965bcbafa8 100644 --- a/src/query/functions/src/aggregates/aggregate_markov_tarin.rs +++ b/src/query/functions/src/aggregates/aggregate_markov_tarin.rs @@ -113,12 +113,12 @@ impl AggregateFunction for MarkovTarin { Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let mut rhs = borsh_partial_deserialize::(reader)?; state.merge(&mut rhs); diff --git a/src/query/functions/src/aggregates/aggregate_null_result.rs b/src/query/functions/src/aggregates/aggregate_null_result.rs index 03f9b6a1ce041..74e2942e9bceb 100644 --- a/src/query/functions/src/aggregates/aggregate_null_result.rs +++ b/src/query/functions/src/aggregates/aggregate_null_result.rs @@ -86,11 +86,11 @@ impl AggregateFunction for AggregateNullResultFunction { Ok(()) } - fn serialize(&self, _place: AggrState, _writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, _place: AggrState, _writer: &mut Vec) -> Result<()> { Ok(()) } - fn merge(&self, _place: AggrState, _reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, _place: AggrState, _reader: &mut &[u8]) -> Result<()> { Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs b/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs index 5aadd1887b5db..d982f8b80ca14 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs @@ -362,12 +362,12 @@ where for<'a> T: AccessType = F64> + Send + Sync }); Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let mut rhs: QuantileTDigestState = borsh_partial_deserialize(reader)?; state.merge(&mut rhs) diff --git a/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs b/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs index 7ac562153fe53..d4aa266fd1a5e 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs @@ -143,12 +143,12 @@ where }); Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let mut rhs: QuantileTDigestState = borsh_partial_deserialize(reader)?; state.merge(&mut rhs) diff --git a/src/query/functions/src/aggregates/aggregate_retention.rs b/src/query/functions/src/aggregates/aggregate_retention.rs index 5f42198bec4c0..5b904f234e740 100644 --- a/src/query/functions/src/aggregates/aggregate_retention.rs +++ b/src/query/functions/src/aggregates/aggregate_retention.rs @@ -143,12 +143,12 @@ impl AggregateFunction for AggregateRetentionFunction { Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: AggregateRetentionState = borsh_partial_deserialize(reader)?; state.merge(&rhs); diff --git a/src/query/functions/src/aggregates/aggregate_st_collect.rs b/src/query/functions/src/aggregates/aggregate_st_collect.rs index afdb3d88a373e..bfe960471c5ee 100644 --- a/src/query/functions/src/aggregates/aggregate_st_collect.rs +++ b/src/query/functions/src/aggregates/aggregate_st_collect.rs @@ -306,12 +306,12 @@ where Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: State = borsh_partial_deserialize(reader)?; diff --git a/src/query/functions/src/aggregates/aggregate_string_agg.rs b/src/query/functions/src/aggregates/aggregate_string_agg.rs index 34a530c0253f7..743c2592d0b9b 100644 --- a/src/query/functions/src/aggregates/aggregate_string_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_string_agg.rs @@ -147,12 +147,12 @@ impl AggregateFunction for AggregateStringAggFunction { Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs: StringAggState = borsh_partial_deserialize(reader)?; state.values.push_str(&rhs.values); diff --git a/src/query/functions/src/aggregates/aggregate_unary.rs b/src/query/functions/src/aggregates/aggregate_unary.rs index 20b9fc5f85a70..be20d15157534 100644 --- a/src/query/functions/src/aggregates/aggregate_unary.rs +++ b/src/query/functions/src/aggregates/aggregate_unary.rs @@ -227,12 +227,12 @@ where Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state: &mut S = place.get::(); Ok(state.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state: &mut S = place.get::(); let rhs = S::deserialize_reader(reader)?; state.merge(&rhs) diff --git a/src/query/functions/src/aggregates/aggregate_window_funnel.rs b/src/query/functions/src/aggregates/aggregate_window_funnel.rs index 3d29c1734bff8..864e5d0eaeea5 100644 --- a/src/query/functions/src/aggregates/aggregate_window_funnel.rs +++ b/src/query/functions/src/aggregates/aggregate_window_funnel.rs @@ -275,12 +275,12 @@ where Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::>(); Ok(state.serialize(writer)?) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::>(); let mut rhs: AggregateWindowFunnelState = borsh_partial_deserialize(reader)?; state.merge(&mut rhs); diff --git a/src/query/functions/src/aggregates/aggregator_common.rs b/src/query/functions/src/aggregates/aggregator_common.rs index 6430bcf3cc647..8d4255a991983 100644 --- a/src/query/functions/src/aggregates/aggregator_common.rs +++ b/src/query/functions/src/aggregates/aggregator_common.rs @@ -187,9 +187,9 @@ pub fn eval_aggr_for_test( func.accumulate(state, entries.into(), None, rows)?; if with_serialize { let mut buf = vec![]; - func.serialize(state, &mut buf)?; + func.serialize_binary(state, &mut buf)?; func.init_state(state); - func.merge(state, &mut buf.as_slice())?; + func.merge_binary(state, &mut buf.as_slice())?; } let mut builder = ColumnBuilder::with_capacity(&data_type, 1024); func.merge_result(state, &mut builder)?; diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs index a43003ba291b4..9b5fba779986a 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs @@ -24,7 +24,6 @@ use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::AggrState; use databend_common_expression::BlockMetaInfoDowncast; -use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::DataBlock; use databend_common_expression::ProjectedBlock; @@ -156,14 +155,10 @@ impl AccumulatingTransform for PartialSingleStateAggregator { ) .zip(builders.iter_mut()) { - func.serialize(place, &mut builder.data)?; - builder.commit_row(); + func.serialize(place, builder)?; } - let columns = builders - .into_iter() - .map(|b| Column::Binary(b.build())) - .collect(); + let columns = builders.into_iter().map(|b| b.build()).collect(); vec![DataBlock::new_from_columns(columns)] } else { vec![] diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs index d92ace4e212fd..fdfc1dcbe6d39 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs @@ -98,14 +98,14 @@ impl AggregateFunction for AggregateUdfScript { Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); state .serialize(writer) .map_err(|e| ErrorCode::Internal(format!("state failed to serialize: {e}"))) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { let state = place.get::(); let rhs = UdfAggState::deserialize(reader).map_err(|e| ErrorCode::Internal(e.to_string()))?; From d64d2ec5246b973b0c028339924fe15de065af9f Mon Sep 17 00:00:00 2001 From: coldWater Date: Tue, 22 Jul 2025 15:29:07 +0800 Subject: [PATCH 02/22] fix --- .../base/src/runtime/profile/profile.rs | 2 +- src/common/storage/src/copy.rs | 8 +- src/common/storage/src/merge.rs | 2 +- .../src/statistics/data_cache_statistics.rs | 2 +- .../src/aggregate/aggregate_function.rs | 6 +- .../src/aggregate/aggregate_function_state.rs | 44 +++-- src/query/expression/src/aggregate/payload.rs | 11 +- .../pipeline/core/src/processors/processor.rs | 39 ++++- .../pipeline/core/src/processors/profile.rs | 6 +- .../serde/transform_aggregate_serializer.rs | 89 ++++------ .../serde/transform_deserializer.rs | 163 ++++++++---------- .../serde/transform_exchange_async_barrier.rs | 40 +++-- .../servers/flight/v1/packets/packet_data.rs | 1 + .../physical_aggregate_partial.rs | 40 ++++- 14 files changed, 245 insertions(+), 208 deletions(-) diff --git a/src/common/base/src/runtime/profile/profile.rs b/src/common/base/src/runtime/profile/profile.rs index 58cfb503be56a..1dc288e7303f0 100644 --- a/src/common/base/src/runtime/profile/profile.rs +++ b/src/common/base/src/runtime/profile/profile.rs @@ -23,7 +23,7 @@ use crate::runtime::metrics::ScopedRegistry; use crate::runtime::profile::ProfileStatisticsName; use crate::runtime::ThreadTracker; -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct ProfileLabel { pub name: String, pub value: Vec, diff --git a/src/common/storage/src/copy.rs b/src/common/storage/src/copy.rs index 927835fdb496a..a3464e33ed4fe 100644 --- a/src/common/storage/src/copy.rs +++ b/src/common/storage/src/copy.rs @@ -20,7 +20,7 @@ use serde::Deserialize; use serde::Serialize; use thiserror::Error; -#[derive(Default, Clone, Serialize, Deserialize)] +#[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct CopyStatus { /// Key is file path. pub files: DashMap, @@ -45,7 +45,7 @@ impl CopyStatus { } } -#[derive(Default, Clone, Serialize, Deserialize)] +#[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct FileStatus { pub num_rows_loaded: usize, pub error: Option, @@ -79,7 +79,7 @@ impl FileStatus { } } -#[derive(Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct FileErrorsInfo { pub num_errors: usize, pub first_error: FileParseErrorAtLine, @@ -156,7 +156,7 @@ impl FileParseError { } } -#[derive(Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct FileParseErrorAtLine { pub error: FileParseError, pub line: usize, diff --git a/src/common/storage/src/merge.rs b/src/common/storage/src/merge.rs index 9f4fe94080bc4..50125e8823c90 100644 --- a/src/common/storage/src/merge.rs +++ b/src/common/storage/src/merge.rs @@ -15,7 +15,7 @@ use serde::Deserialize; use serde::Serialize; -#[derive(Default, Clone, Serialize, Deserialize)] +#[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct MutationStatus { pub insert_rows: u64, pub deleted_rows: u64, diff --git a/src/query/catalog/src/statistics/data_cache_statistics.rs b/src/query/catalog/src/statistics/data_cache_statistics.rs index 04ae435204eb5..6838a93cc212d 100644 --- a/src/query/catalog/src/statistics/data_cache_statistics.rs +++ b/src/query/catalog/src/statistics/data_cache_statistics.rs @@ -24,7 +24,7 @@ pub struct DataCacheMetrics { bytes_from_memory: AtomicUsize, } -#[derive(Default, Clone, Serialize, Deserialize)] +#[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct DataCacheMetricValues { pub bytes_from_remote_disk: usize, pub bytes_from_local_disk: usize, diff --git a/src/query/expression/src/aggregate/aggregate_function.rs b/src/query/expression/src/aggregate/aggregate_function.rs index d937fe61905e6..e4eb3ea51f047 100755 --- a/src/query/expression/src/aggregate/aggregate_function.rs +++ b/src/query/expression/src/aggregate/aggregate_function.rs @@ -23,12 +23,12 @@ use super::AggrStateLoc; use super::AggrStateRegistry; use super::StateAddr; use crate::types::DataType; -use crate::AggrStateSerdeType; use crate::BlockEntry; use crate::ColumnBuilder; use crate::ProjectedBlock; use crate::Scalar; use crate::ScalarRef; +use crate::StateSerdeItem; pub type AggregateFunctionRef = Arc; @@ -69,8 +69,8 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { // Used in aggregate_null_adaptor fn accumulate_row(&self, place: AggrState, columns: ProjectedBlock, row: usize) -> Result<()>; - fn serialize_type(&self) -> Vec { - vec![AggrStateSerdeType::Binary(self.serialize_size_per_row())] + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(self.serialize_size_per_row())] } fn serialize(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { diff --git a/src/query/expression/src/aggregate/aggregate_function_state.rs b/src/query/expression/src/aggregate/aggregate_function_state.rs index 409278532dd06..791ee68542ff7 100644 --- a/src/query/expression/src/aggregate/aggregate_function_state.rs +++ b/src/query/expression/src/aggregate/aggregate_function_state.rs @@ -119,7 +119,7 @@ pub fn get_states_layout(funcs: &[AggregateFunctionRef]) -> Result for func in funcs { func.register_state(&mut registry); registry.commit(); - serialize_type.push(func.serialize_type().into_boxed_slice()); + serialize_type.push(StateSerdeType(func.serialize_type().into())); } let AggrStateRegistry { states, offsets } = registry; @@ -193,25 +193,49 @@ impl AggrStateLoc { } } +#[derive(Debug, Clone, Copy)] +pub enum StateSerdeItem { + Bool, + Binary(Option), +} + +#[derive(Debug, Clone)] +pub struct StateSerdeType(Box<[StateSerdeItem]>); + +impl StateSerdeType { + pub fn data_type(&self) -> DataType { + DataType::Tuple( + self.0 + .iter() + .map(|item| match item { + StateSerdeItem::Bool => DataType::Boolean, + StateSerdeItem::Binary(_) => DataType::Binary, + }) + .collect(), + ) + } +} + #[derive(Debug, Clone)] pub struct StatesLayout { pub layout: Layout, pub states_loc: Vec>, - serialize_type: Vec>, + pub(super) serialize_type: Vec, } impl StatesLayout { pub fn serialize_builders(&self, num_rows: usize) -> Vec { self.serialize_type .iter() - .map(|item| { - let builder = item + .map(|serde_type| { + let builder = serde_type + .0 .iter() - .map(|serde_type| match serde_type { - AggrStateSerdeType::Bool => { + .map(|item| match item { + StateSerdeItem::Bool => { ColumnBuilder::with_capacity(&DataType::Boolean, num_rows) } - AggrStateSerdeType::Binary(size) => { + StateSerdeItem::Binary(size) => { ColumnBuilder::Binary(BinaryColumnBuilder::with_capacity( num_rows, num_rows * size.unwrap_or(0), @@ -306,12 +330,6 @@ pub enum AggrStateType { Custom(Layout), } -#[derive(Debug, Clone, Copy)] -pub enum AggrStateSerdeType { - Bool, - Binary(Option), -} - #[cfg(test)] mod tests { use proptest::prelude::*; diff --git a/src/query/expression/src/aggregate/payload.rs b/src/query/expression/src/aggregate/payload.rs index aafe2921c8532..836e299f16d47 100644 --- a/src/query/expression/src/aggregate/payload.rs +++ b/src/query/expression/src/aggregate/payload.rs @@ -423,9 +423,14 @@ impl Payload { pub fn empty_block(&self, fake_rows: Option) -> DataBlock { let fake_rows = fake_rows.unwrap_or(0); - let entries = (0..self.aggrs.len()) - .map(|_| { - ColumnBuilder::repeat_default(&DataType::Binary, fake_rows) + let entries = self + .states_layout + .as_ref() + .unwrap() + .serialize_type + .iter() + .map(|serde_type| { + ColumnBuilder::repeat_default(&serde_type.data_type(), fake_rows) .build() .into() }) diff --git a/src/query/pipeline/core/src/processors/processor.rs b/src/query/pipeline/core/src/processors/processor.rs index ce70053b80ded..bb71f493a2061 100644 --- a/src/query/pipeline/core/src/processors/processor.rs +++ b/src/query/pipeline/core/src/processors/processor.rs @@ -161,12 +161,24 @@ impl ProcessorPtr { /// # Safety pub unsafe fn process(&self) -> Result<()> { - let mut name = self.name(); - name.push_str("::process"); - let _span = LocalSpan::enter_with_local_parent(name) + let span = LocalSpan::enter_with_local_parent(format!("{}::process", self.name())) .with_property(|| ("graph-node-id", self.id().index().to_string())); - (*self.inner.get()).process() + match (*self.inner.get()).process() { + Ok(_) => Ok(()), + Err(err) => { + let _ = span + .with_property(|| ("error", "true")) + .with_properties(|| { + [ + ("error.type", err.code().to_string()), + ("error.message", err.display_text()), + ] + }); + log::info!(error = err.to_string(); "[PIPELINE-EXECUTOR] Error in process"); + Err(err) + } + } } /// # Safety @@ -190,10 +202,23 @@ impl ProcessorPtr { async move { let span = Span::enter_with_local_parent(name) .with_property(|| ("graph-node-id", id.index().to_string())); - task.in_span(span).await?; - drop(inner); - Ok(()) + match task.await { + Ok(_) => { + drop(inner); + Ok(()) + } + Err(err) => { + span.with_property(|| ("error", "true")).add_properties(|| { + [ + ("error.type", err.code().to_string()), + ("error.message", err.display_text()), + ] + }); + log::info!(error = err.to_string(); "[PIPELINE-EXECUTOR] Error in process"); + Err(err) + } + } } .boxed() } diff --git a/src/query/pipeline/core/src/processors/profile.rs b/src/query/pipeline/core/src/processors/profile.rs index 7fb5af4f4e8be..58b0fd98169ab 100644 --- a/src/query/pipeline/core/src/processors/profile.rs +++ b/src/query/pipeline/core/src/processors/profile.rs @@ -43,7 +43,7 @@ impl Drop for PlanScopeGuard { } } -#[derive(serde::Serialize, serde::Deserialize)] +#[derive(Debug, serde::Serialize, serde::Deserialize)] pub struct ErrorInfoDesc { message: String, detail: String, @@ -60,7 +60,7 @@ impl ErrorInfoDesc { } } -#[derive(serde::Serialize, serde::Deserialize)] +#[derive(Debug, serde::Serialize, serde::Deserialize)] pub enum ErrorInfo { Other(ErrorInfoDesc), IoError(ErrorInfoDesc), @@ -68,7 +68,7 @@ pub enum ErrorInfo { CalculationError(ErrorInfoDesc), } -#[derive(Clone, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct PlanProfile { pub id: Option, pub name: Option, diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_aggregate_serializer.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_aggregate_serializer.rs index 096485fa98fcc..03b45b75d6607 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_aggregate_serializer.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_aggregate_serializer.rs @@ -121,26 +121,19 @@ impl Processor for TransformAggregateSerializer { impl TransformAggregateSerializer { fn transform_input_data(&mut self, mut data_block: DataBlock) -> Result { debug_assert!(data_block.is_empty()); - if let Some(block_meta) = data_block.take_meta() { - if let Some(block_meta) = AggregateMeta::downcast_from(block_meta) { - match block_meta { - AggregateMeta::Spilled(_) => unreachable!(), - AggregateMeta::Serialized(_) => unreachable!(), - AggregateMeta::BucketSpilled(_) => unreachable!(), - AggregateMeta::Partitioned { .. } => unreachable!(), - AggregateMeta::AggregateSpilling(_) => unreachable!(), - AggregateMeta::AggregatePayload(p) => { - self.input_data = Some(SerializeAggregateStream::create( - &self.params, - SerializePayload::AggregatePayload(p), - )); - return Ok(Event::Sync); - } - } - } - } - unreachable!() + let Some(AggregateMeta::AggregatePayload(p)) = data_block + .take_meta() + .and_then(AggregateMeta::downcast_from) + else { + unreachable!() + }; + + self.input_data = Some(SerializeAggregateStream::create( + &self.params, + SerializePayload::AggregatePayload(p), + )); + Ok(Event::Sync) } } @@ -218,41 +211,29 @@ impl SerializeAggregateStream { return Ok(None); } - match self.payload.as_ref().get_ref() { - SerializePayload::AggregatePayload(p) => { - let block = p.payload.aggregate_flush(&mut self.flush_state)?; - - if block.is_none() { - self.end_iter = true; - } - - match block { - Some(block) => { - self.nums += 1; - Ok(Some(block.add_meta(Some( - AggregateSerdeMeta::create_agg_payload( - p.bucket, - p.max_partition_count, - false, - ), - ))?)) - } - None => { - // always return at least one block - if self.nums == 0 { - self.nums += 1; - let block = p.payload.empty_block(Some(1)); - Ok(Some(block.add_meta(Some( - AggregateSerdeMeta::create_agg_payload( - p.bucket, - p.max_partition_count, - true, - ), - ))?)) - } else { - Ok(None) - } - } + let SerializePayload::AggregatePayload(p) = self.payload.as_ref().get_ref(); + match p.payload.aggregate_flush(&mut self.flush_state)? { + Some(block) => { + self.nums += 1; + Ok(Some(block.add_meta(Some( + AggregateSerdeMeta::create_agg_payload(p.bucket, p.max_partition_count, false), + ))?)) + } + None => { + self.end_iter = true; + // always return at least one block + if self.nums == 0 { + self.nums += 1; + let block = p.payload.empty_block(Some(1)); + Ok(Some(block.add_meta(Some( + AggregateSerdeMeta::create_agg_payload( + p.bucket, + p.max_partition_count, + true, + ), + ))?)) + } else { + Ok(None) } } } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs index 09725ee3bb173..9f632ec0b187c 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs @@ -79,94 +79,77 @@ impl TransformDeserializer { return Ok(DataBlock::new_with_meta(vec![], 0, meta)); } - let data_block = match &meta { - None => { - deserialize_block(dict, fragment_data, &self.schema, self.arrow_schema.clone())? - } - Some(meta) => match AggregateSerdeMeta::downcast_ref_from(meta) { - None => { - deserialize_block(dict, fragment_data, &self.schema, self.arrow_schema.clone())? + let Some(meta) = meta + .as_ref() + .and_then(AggregateSerdeMeta::downcast_ref_from) + else { + let data_block = + deserialize_block(dict, fragment_data, &self.schema, self.arrow_schema.clone())?; + return match data_block.num_columns() == 0 { + true => Ok(DataBlock::new_with_meta(vec![], row_count as usize, meta)), + false => data_block.add_meta(meta), + }; + }; + + match meta.typ { + BUCKET_TYPE => { + let mut block = deserialize_block( + dict, + fragment_data, + &self.schema, + self.arrow_schema.clone(), + )?; + + if meta.is_empty { + block = block.slice(0..0); } - Some(meta) => { - return match meta.typ == BUCKET_TYPE { - true => { - let mut block = deserialize_block( - dict, - fragment_data, - &self.schema, - self.arrow_schema.clone(), - )?; - - if meta.is_empty { - block = block.slice(0..0); - } - - Ok(DataBlock::empty_with_meta( - AggregateMeta::create_serialized( - meta.bucket, - block, - meta.max_partition_count, - ), - )) - } - false => { - let data_schema = Arc::new(exchange_defines::spilled_schema()); - let arrow_schema = Arc::new(exchange_defines::spilled_arrow_schema()); - let data_block = deserialize_block( - dict, - fragment_data, - &data_schema, - arrow_schema.clone(), - )?; - - let columns = data_block - .columns() - .iter() - .map(|c| c.as_column().unwrap().clone()) - .collect::>(); - - let buckets = - NumberType::::try_downcast_column(&columns[0]).unwrap(); - let data_range_start = - NumberType::::try_downcast_column(&columns[1]).unwrap(); - let data_range_end = - NumberType::::try_downcast_column(&columns[2]).unwrap(); - let columns_layout = - ArrayType::::try_downcast_column(&columns[3]).unwrap(); - - let columns_layout_data = columns_layout.values().as_slice(); - let columns_layout_offsets = columns_layout.offsets(); - - let mut buckets_payload = Vec::with_capacity(data_block.num_rows()); - for index in 0..data_block.num_rows() { - unsafe { - buckets_payload.push(BucketSpilledPayload { - bucket: *buckets.get_unchecked(index) as isize, - location: meta.location.clone().unwrap(), - data_range: *data_range_start.get_unchecked(index) - ..*data_range_end.get_unchecked(index), - columns_layout: columns_layout_data[columns_layout_offsets - [index] - as usize - ..columns_layout_offsets[index + 1] as usize] - .to_vec(), - max_partition_count: meta.max_partition_count, - }); - } - } - - Ok(DataBlock::empty_with_meta(AggregateMeta::create_spilled( - buckets_payload, - ))) - } - }; + + Ok(DataBlock::empty_with_meta( + AggregateMeta::create_serialized(meta.bucket, block, meta.max_partition_count), + )) + } + _ => { + let data_schema = Arc::new(exchange_defines::spilled_schema()); + let arrow_schema = Arc::new(exchange_defines::spilled_arrow_schema()); + let data_block = + deserialize_block(dict, fragment_data, &data_schema, arrow_schema.clone())?; + + let columns = data_block + .columns() + .iter() + .map(|c| c.as_column().unwrap().clone()) + .collect::>(); + + let buckets = NumberType::::try_downcast_column(&columns[0]).unwrap(); + let data_range_start = NumberType::::try_downcast_column(&columns[1]).unwrap(); + let data_range_end = NumberType::::try_downcast_column(&columns[2]).unwrap(); + let columns_layout = + ArrayType::::try_downcast_column(&columns[3]).unwrap(); + + let columns_layout_data = columns_layout.values().as_slice(); + let columns_layout_offsets = columns_layout.offsets(); + + let mut buckets_payload = Vec::with_capacity(data_block.num_rows()); + for index in 0..data_block.num_rows() { + unsafe { + buckets_payload.push(BucketSpilledPayload { + bucket: *buckets.get_unchecked(index) as isize, + location: meta.location.clone().unwrap(), + data_range: *data_range_start.get_unchecked(index) + ..*data_range_end.get_unchecked(index), + columns_layout: columns_layout_data[columns_layout_offsets[index] + as usize + ..columns_layout_offsets[index + 1] as usize] + .to_vec(), + max_partition_count: meta.max_partition_count, + }); + } } - }, - }; - match data_block.num_columns() == 0 { - true => Ok(DataBlock::new_with_meta(vec![], row_count as usize, meta)), - false => data_block.add_meta(meta), + Ok(DataBlock::empty_with_meta(AggregateMeta::create_spilled( + buckets_payload, + ))) + } } } } @@ -177,15 +160,9 @@ impl BlockMetaTransform for TransformDeserializer { fn transform(&mut self, mut meta: ExchangeDeserializeMeta) -> Result> { match meta.packet.pop().unwrap() { - DataPacket::ErrorCode(v) => Err(v), - DataPacket::Dictionary(_) => unreachable!(), - DataPacket::QueryProfiles(_) => unreachable!(), - DataPacket::SerializeProgress { .. } => unreachable!(), - DataPacket::CopyStatus { .. } => unreachable!(), - DataPacket::MutationStatus { .. } => unreachable!(), - DataPacket::DataCacheMetrics(_) => unreachable!(), DataPacket::FragmentData(v) => Ok(vec![self.recv_data(meta.packet, v)?]), - DataPacket::QueryPerf(_) => unreachable!(), + DataPacket::ErrorCode(err) => Err(err), + _ => unreachable!(), } } } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_async_barrier.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_async_barrier.rs index 1628bc9af5beb..eaebda0543e6d 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_async_barrier.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_async_barrier.rs @@ -45,29 +45,31 @@ impl AsyncTransform for TransformExchangeAsyncBarrier { const NAME: &'static str = "TransformExchangeAsyncBarrier"; async fn transform(&mut self, mut data: DataBlock) -> Result { - if let Some(meta) = data + let Some(meta) = data .take_meta() .and_then(FlightSerializedMeta::downcast_from) - { - let mut futures = Vec::with_capacity(meta.serialized_blocks.len()); + else { + return Err(ErrorCode::Internal("")); + }; - for serialized_block in meta.serialized_blocks { - futures.push(databend_common_base::runtime::spawn(async move { - match serialized_block { - FlightSerialized::DataBlock(v) => Ok(v), - FlightSerialized::Future(f) => f.await, - } - })); - } - - return match futures::future::try_join_all(futures).await { - Err(_) => Err(ErrorCode::TokioError("Cannot join tokio job")), - Ok(spilled_data) => Ok(DataBlock::empty_with_meta(ExchangeShuffleMeta::create( - spilled_data.into_iter().collect::>>()?, - ))), - }; + let mut futures = Vec::with_capacity(meta.serialized_blocks.len()); + for serialized_block in meta.serialized_blocks { + futures.push(databend_common_base::runtime::spawn(async move { + match serialized_block { + FlightSerialized::DataBlock(v) => Ok(v), + FlightSerialized::Future(f) => f.await, + } + })); } - Err(ErrorCode::Internal("")) + match futures::future::try_join_all(futures).await { + Err(_) => Err(ErrorCode::TokioError("Cannot join tokio job")), + Ok(spilled_data) => { + let blocks = spilled_data.into_iter().collect::>()?; + Ok(DataBlock::empty_with_meta(ExchangeShuffleMeta::create( + blocks, + ))) + } + } } } diff --git a/src/query/service/src/servers/flight/v1/packets/packet_data.rs b/src/query/service/src/servers/flight/v1/packets/packet_data.rs index 2b5407ffb2946..81226501438f8 100644 --- a/src/query/service/src/servers/flight/v1/packets/packet_data.rs +++ b/src/query/service/src/servers/flight/v1/packets/packet_data.rs @@ -54,6 +54,7 @@ impl Debug for FragmentData { } } +#[derive(Debug)] pub enum DataPacket { ErrorCode(ErrorCode), Dictionary(FlightData), diff --git a/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs b/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs index a8a73071aaa57..5026f7ba38f9a 100644 --- a/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs +++ b/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs @@ -14,11 +14,11 @@ use databend_common_exception::Result; use databend_common_expression::types::DataType; -#[allow(unused_imports)] -use databend_common_expression::DataBlock; use databend_common_expression::DataField; use databend_common_expression::DataSchemaRef; use databend_common_expression::DataSchemaRefExt; +use databend_common_expression::StateSerdeItem; +use databend_common_functions::aggregates::AggregateFunctionFactory; use super::SortDesc; use crate::executor::explain::PlanStatsInfo; @@ -47,11 +47,39 @@ impl AggregatePartial { let input_schema = self.input.output_schema()?; let mut fields = Vec::with_capacity(self.agg_funcs.len() + self.group_by.len()); + let factory = AggregateFunctionFactory::instance(); - fields.extend(self.agg_funcs.iter().map(|func| { - let name = func.output_column.to_string(); - DataField::new(&name, DataType::Binary) - })); + for desc in &self.agg_funcs { + let name = desc.output_column.to_string(); + + if desc.sig.udaf.is_some() { + fields.push(DataField::new( + &name, + DataType::Tuple(vec![DataType::Binary]), + )); + continue; + } + + let func = factory + .get( + &desc.sig.name, + desc.sig.params.clone(), + desc.sig.args.clone(), + desc.sig.sort_descs.clone(), + ) + .unwrap(); + + let tuple = func + .serialize_type() + .iter() + .map(|serde_type| match serde_type { + StateSerdeItem::Bool => DataType::Boolean, + StateSerdeItem::Binary(_) => DataType::Binary, + }) + .collect(); + + fields.push(DataField::new(&name, DataType::Tuple(tuple))) + } for (idx, field) in self.group_by.iter().zip( self.group_by From f303632aeb96ea402ca9ae37aa4064c2848331f1 Mon Sep 17 00:00:00 2001 From: coldWater Date: Tue, 22 Jul 2025 17:35:21 +0800 Subject: [PATCH 03/22] fix --- src/query/expression/src/aggregate/payload.rs | 4 ++-- .../src/pipelines/executor/pipeline_executor.rs | 7 +++---- .../pipelines/executor/query_pipeline_executor.rs | 15 ++++++++++----- .../transform_exchange_aggregate_serializer.rs | 9 +++------ 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/query/expression/src/aggregate/payload.rs b/src/query/expression/src/aggregate/payload.rs index 836e299f16d47..bd225a6a9f785 100644 --- a/src/query/expression/src/aggregate/payload.rs +++ b/src/query/expression/src/aggregate/payload.rs @@ -423,12 +423,12 @@ impl Payload { pub fn empty_block(&self, fake_rows: Option) -> DataBlock { let fake_rows = fake_rows.unwrap_or(0); + assert_eq!(self.aggrs.is_empty(), self.states_layout.is_none()); let entries = self .states_layout .as_ref() - .unwrap() - .serialize_type .iter() + .flat_map(|x| x.serialize_type.iter()) .map(|serde_type| { ColumnBuilder::repeat_default(&serde_type.data_type(), fake_rows) .build() diff --git a/src/query/service/src/pipelines/executor/pipeline_executor.rs b/src/query/service/src/pipelines/executor/pipeline_executor.rs index 53a82eba8da10..7117dadb04634 100644 --- a/src/query/service/src/pipelines/executor/pipeline_executor.rs +++ b/src/query/service/src/pipelines/executor/pipeline_executor.rs @@ -146,6 +146,7 @@ impl PipelineExecutor { } } + #[fastrace::trace(name = "PipelineExecutor::init")] fn init(on_init_callback: &Mutex>, query_id: &Arc) -> Result<()> { // TODO: the on init callback cannot be killed. { @@ -158,10 +159,8 @@ impl PipelineExecutor { } } - info!( - "[PIPELINE-EXECUTOR] Pipeline initialized successfully for query {}, elapsed: {:?}", - query_id, - instant.elapsed() + info!(query_id, elapsed:? = instant.elapsed(); + "[PIPELINE-EXECUTOR] Pipeline initialized successfully", ); } Ok(()) diff --git a/src/query/service/src/pipelines/executor/query_pipeline_executor.rs b/src/query/service/src/pipelines/executor/query_pipeline_executor.rs index b45e707f30cc8..88aa39a5a51fa 100644 --- a/src/query/service/src/pipelines/executor/query_pipeline_executor.rs +++ b/src/query/service/src/pipelines/executor/query_pipeline_executor.rs @@ -266,6 +266,7 @@ impl QueryPipelineExecutor { Ok(()) } + #[fastrace::trace(name = "QueryPipelineExecutor::init")] fn init(self: &Arc, graph: Arc) -> Result<()> { unsafe { // TODO: the on init callback cannot be killed. @@ -286,10 +287,8 @@ impl QueryPipelineExecutor { } } - info!( - "[PIPELINE-EXECUTOR] Pipeline initialized successfully for query {}, elapsed: {:?}", - self.settings.query_id, - instant.elapsed() + info!(query_id = self.settings.query_id, elapsed:? = instant.elapsed(); + "[PIPELINE-EXECUTOR] Pipeline initialized successfully", ); } @@ -358,7 +357,7 @@ impl QueryPipelineExecutor { } } - let span = Span::enter_with_local_parent(func_path!()) + let span = Span::enter_with_local_parent("QueryPipelineExecutor::execute_threads") .with_property(|| ("thread_name", name.clone())); thread_join_handles.push(Thread::named_spawn(Some(name), move || unsafe { let _g = span.set_local_parent(); @@ -367,6 +366,12 @@ impl QueryPipelineExecutor { // finish the pipeline executor when has error or panic if let Err(cause) = try_result.flatten() { + span.with_property(|| ("error", "true")).add_properties(|| { + [ + ("error.type", cause.code().to_string()), + ("error.message", cause.display_text()), + ] + }); this.finish(Some(cause)); } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_aggregate_serializer.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_aggregate_serializer.rs index ffd80674bdab4..118f8119fed5e 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_aggregate_serializer.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_aggregate_serializer.rs @@ -125,12 +125,7 @@ impl BlockMetaTransform for TransformExchangeAggregateSeria continue; } - match AggregateMeta::downcast_from(block.take_meta().unwrap()) { - None => unreachable!(), - Some(AggregateMeta::Spilled(_)) => unreachable!(), - Some(AggregateMeta::Serialized(_)) => unreachable!(), - Some(AggregateMeta::BucketSpilled(_)) => unreachable!(), - Some(AggregateMeta::Partitioned { .. }) => unreachable!(), + match block.take_meta().and_then(AggregateMeta::downcast_from) { Some(AggregateMeta::AggregateSpilling(payload)) => { serialized_blocks.push(FlightSerialized::Future( match index == self.local_pos { @@ -172,6 +167,8 @@ impl BlockMetaTransform for TransformExchangeAggregateSeria let c = serialize_block(block_number, c, &self.options)?; serialized_blocks.push(FlightSerialized::DataBlock(c)); } + + _ => unreachable!(), }; } From 64d74db68b17b9dbf800b005bc0ff185130bb5f9 Mon Sep 17 00:00:00 2001 From: coldWater Date: Tue, 22 Jul 2025 18:20:24 +0800 Subject: [PATCH 04/22] fix --- src/query/expression/src/aggregate/aggregate_function.rs | 5 ----- src/query/expression/src/aggregate/payload.rs | 5 ++--- src/query/expression/src/aggregate/payload_flush.rs | 4 ++-- .../src/aggregates/adaptors/aggregate_null_adaptor.rs | 8 -------- .../src/aggregates/adaptors/aggregate_ornull_adaptor.rs | 4 ---- .../src/pipelines/executor/query_pipeline_executor.rs | 1 - .../aggregator/serde/transform_aggregate_serializer.rs | 2 +- 7 files changed, 5 insertions(+), 24 deletions(-) diff --git a/src/query/expression/src/aggregate/aggregate_function.rs b/src/query/expression/src/aggregate/aggregate_function.rs index e4eb3ea51f047..6ba5b3f569c89 100755 --- a/src/query/expression/src/aggregate/aggregate_function.rs +++ b/src/query/expression/src/aggregate/aggregate_function.rs @@ -166,9 +166,4 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { fn get_if_condition(&self, _columns: ProjectedBlock) -> Option { None } - - // some features - fn convert_const_to_full(&self) -> bool { - true - } } diff --git a/src/query/expression/src/aggregate/payload.rs b/src/query/expression/src/aggregate/payload.rs index bd225a6a9f785..925293beb214c 100644 --- a/src/query/expression/src/aggregate/payload.rs +++ b/src/query/expression/src/aggregate/payload.rs @@ -421,14 +421,13 @@ impl Payload { true } - pub fn empty_block(&self, fake_rows: Option) -> DataBlock { - let fake_rows = fake_rows.unwrap_or(0); + pub fn empty_block(&self, fake_rows: usize) -> DataBlock { assert_eq!(self.aggrs.is_empty(), self.states_layout.is_none()); let entries = self .states_layout .as_ref() .iter() - .flat_map(|x| x.serialize_type.iter()) + .flat_map(|layout| layout.serialize_type.iter()) .map(|serde_type| { ColumnBuilder::repeat_default(&serde_type.data_type(), fake_rows) .build() diff --git a/src/query/expression/src/aggregate/payload_flush.rs b/src/query/expression/src/aggregate/payload_flush.rs index e0d15864a3a8e..5863b533e31b9 100644 --- a/src/query/expression/src/aggregate/payload_flush.rs +++ b/src/query/expression/src/aggregate/payload_flush.rs @@ -125,7 +125,7 @@ impl Payload { } if blocks.is_empty() { - return Ok(self.empty_block(None)); + return Ok(self.empty_block(0)); } DataBlock::concat(&blocks) } @@ -172,7 +172,7 @@ impl Payload { } if blocks.is_empty() { - return Ok(self.empty_block(None)); + return Ok(self.empty_block(0)); } DataBlock::concat(&blocks) diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs index b100564223d3a..548327ecdaff5 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs @@ -207,10 +207,6 @@ impl AggregateFunction for AggregateNullUnaryAdapto self.0.drop_state(place); } - fn convert_const_to_full(&self) -> bool { - self.0.nested.convert_const_to_full() - } - fn get_if_condition(&self, columns: ProjectedBlock) -> Option { self.0.nested.get_if_condition(columns) } @@ -332,10 +328,6 @@ impl AggregateFunction self.0.drop_state(place); } - fn convert_const_to_full(&self) -> bool { - self.0.nested.convert_const_to_full() - } - fn get_if_condition(&self, columns: ProjectedBlock) -> Option { self.0.nested.get_if_condition(columns) } diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs index 3fd0cbcd70ccc..677bb7fc6f814 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs @@ -237,10 +237,6 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { unsafe fn drop_state(&self, place: AggrState) { self.inner.drop_state(place.remove_last_loc()) } - - fn convert_const_to_full(&self) -> bool { - self.inner.convert_const_to_full() - } } impl fmt::Display for AggregateFunctionOrNullAdaptor { diff --git a/src/query/service/src/pipelines/executor/query_pipeline_executor.rs b/src/query/service/src/pipelines/executor/query_pipeline_executor.rs index 88aa39a5a51fa..125d18929ea91 100644 --- a/src/query/service/src/pipelines/executor/query_pipeline_executor.rs +++ b/src/query/service/src/pipelines/executor/query_pipeline_executor.rs @@ -33,7 +33,6 @@ use databend_common_pipeline_core::FinishedCallbackChain; use databend_common_pipeline_core::LockGuard; use databend_common_pipeline_core::Pipeline; use databend_common_pipeline_core::PlanProfile; -use fastrace::func_path; use fastrace::prelude::*; use futures::future::select; use futures_util::future::Either; diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_aggregate_serializer.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_aggregate_serializer.rs index 03b45b75d6607..3032176fe5de8 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_aggregate_serializer.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_aggregate_serializer.rs @@ -224,7 +224,7 @@ impl SerializeAggregateStream { // always return at least one block if self.nums == 0 { self.nums += 1; - let block = p.payload.empty_block(Some(1)); + let block = p.payload.empty_block(1); Ok(Some(block.add_meta(Some( AggregateSerdeMeta::create_agg_payload( p.bucket, From de267414e88bd3740bcf8fdbe6b62080759f179d Mon Sep 17 00:00:00 2001 From: coldWater Date: Tue, 22 Jul 2025 19:39:29 +0800 Subject: [PATCH 05/22] null adaptor --- .../src/aggregate/aggregate_function.rs | 15 +-- .../expression/src/aggregate/payload_flush.rs | 4 +- .../adaptors/aggregate_null_adaptor.rs | 78 ++++++++++---- .../adaptors/aggregate_ornull_adaptor.rs | 100 ++++++++++++++++-- .../aggregator/transform_single_key.rs | 3 +- 5 files changed, 161 insertions(+), 39 deletions(-) diff --git a/src/query/expression/src/aggregate/aggregate_function.rs b/src/query/expression/src/aggregate/aggregate_function.rs index 6ba5b3f569c89..08e5ac3759337 100755 --- a/src/query/expression/src/aggregate/aggregate_function.rs +++ b/src/query/expression/src/aggregate/aggregate_function.rs @@ -73,8 +73,8 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { vec![StateSerdeItem::Binary(self.serialize_size_per_row())] } - fn serialize(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { - let binary_builder = builder.as_tuple_mut().unwrap()[0].as_binary_mut().unwrap(); + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); self.serialize_binary(place, &mut binary_builder.data)?; binary_builder.commit_row(); Ok(()) @@ -86,8 +86,8 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { None } - fn merge(&self, place: AggrState, data: ScalarRef) -> Result<()> { - let mut binary = *data.as_tuple().unwrap()[0].as_binary().unwrap(); + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); self.merge_binary(place, &mut binary) } @@ -102,7 +102,10 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { ) -> Result<()> { let column = state.to_column(); for (place, data) in places.iter().zip(column.iter()) { - self.merge(AggrState::new(*place, loc), data)?; + self.merge( + AggrState::new(*place, loc), + data.as_tuple().unwrap().as_slice(), + )?; } Ok(()) @@ -111,7 +114,7 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { fn batch_merge_single(&self, place: AggrState, state: &BlockEntry) -> Result<()> { let column = state.to_column(); for data in column.iter() { - self.merge(place, data)?; + self.merge(place, data.as_tuple().unwrap().as_slice())?; } Ok(()) } diff --git a/src/query/expression/src/aggregate/payload_flush.rs b/src/query/expression/src/aggregate/payload_flush.rs index 5863b533e31b9..f6aa0d65d5772 100644 --- a/src/query/expression/src/aggregate/payload_flush.rs +++ b/src/query/expression/src/aggregate/payload_flush.rs @@ -149,8 +149,8 @@ impl Payload { .enumerate() { { - let builder = &mut builders[idx]; - func.serialize(AggrState::new(*place, loc), builder)?; + let builders = builders[idx].as_tuple_mut().unwrap().as_mut_slice(); + func.serialize(AggrState::new(*place, loc), builders)?; } } } diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs index 548327ecdaff5..6a9ef13527f69 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs @@ -23,7 +23,8 @@ use databend_common_expression::utils::column_merge_validity; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_io::prelude::BinaryWrite; +use databend_common_expression::ScalarRef; +use databend_common_expression::StateSerdeItem; use super::AggrState; use super::AggrStateLoc; @@ -183,12 +184,24 @@ impl AggregateFunction for AggregateNullUnaryAdapto .accumulate_row(place, not_null_columns, validity, row) } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - self.0.serialize(place, writer) + fn serialize_type(&self) -> Vec { + self.0.serialize_type() } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - self.0.merge(place, reader) + fn serialize(&self, place: AggrState, builder: &mut [ColumnBuilder]) -> Result<()> { + self.0.serialize(place, builder) + } + + fn serialize_binary(&self, _: AggrState, _: &mut Vec) -> Result<()> { + unreachable!() + } + + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + self.0.merge(place, data) + } + + fn merge_binary(&self, _: AggrState, _: &mut &[u8]) -> Result<()> { + unreachable!() } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { @@ -304,12 +317,20 @@ impl AggregateFunction .accumulate_row(place, not_null_columns, validity, row) } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - self.0.serialize(place, writer) + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + self.0.serialize(place, builders) + } + + fn serialize_binary(&self, _: AggrState, _: &mut Vec) -> Result<()> { + unreachable!() + } + + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + self.0.merge(place, data) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - self.0.merge(place, reader) + fn merge_binary(&self, _: AggrState, _: &mut &[u8]) -> Result<()> { + unreachable!() } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { @@ -488,24 +509,42 @@ impl CommonNullAdaptor { .accumulate_row(place.remove_last_loc(), not_null_columns, row) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_type(&self) -> Vec { if !NULLABLE_RESULT { - return self.nested.serialize_binary(place, writer); + return self.nested.serialize_type(); } - self.nested - .serialize_binary(place.remove_last_loc(), writer)?; + .serialize_type() + .into_iter() + .chain(Some(StateSerdeItem::Bool)) + .collect() + } + + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + if !NULLABLE_RESULT { + return self.nested.serialize(place, builders); + } + let n = builders.len(); + debug_assert_eq!(self.nested.serialize_type().len() + 1, n); + let flag = get_flag(place); - writer.write_scalar(&flag) + builders + .last_mut() + .and_then(ColumnBuilder::as_boolean_mut) + .unwrap() + .push(flag); + self.nested + .serialize(place.remove_last_loc(), &mut builders[..(n - 1)]) } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { if !NULLABLE_RESULT { - return self.nested.merge_binary(place, reader); + return self.nested.merge(place, data); } - let flag = reader[reader.len() - 1]; - if flag == 0 { + let n = data.len(); + let flag = *data.last().and_then(ScalarRef::as_boolean).unwrap(); + if !flag { return Ok(()); } @@ -514,8 +553,7 @@ impl CommonNullAdaptor { self.init_state(place); } set_flag(place, true); - self.nested - .merge_binary(place.remove_last_loc(), &mut &reader[..reader.len() - 1]) + self.nested.merge(place.remove_last_loc(), &data[..n - 1]) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs index 677bb7fc6f814..4c6ee6e25b62b 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs @@ -23,7 +23,8 @@ use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_io::prelude::BinaryWrite; +use databend_common_expression::ScalarRef; +use databend_common_expression::StateSerdeItem; use super::AggrState; use super::AggrStateLoc; @@ -177,23 +178,45 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { Ok(()) } - #[inline] - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize_type(&self) -> Vec { self.inner - .serialize_binary(place.remove_last_loc(), writer)?; - let flag = get_flag(place) as u8; - writer.write_scalar(&flag) + .serialize_type() + .into_iter() + .chain(Some(StateSerdeItem::Bool)) + .collect() } - #[inline] - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - let flag = get_flag(place) || reader[reader.len() - 1] > 0; + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let n = builders.len(); + debug_assert_eq!(self.inner.serialize_type().len() + 1, n); + + let flag = get_flag(place); + builders + .last_mut() + .and_then(ColumnBuilder::as_boolean_mut) + .unwrap() + .push(flag); + + self.inner + .serialize(place.remove_last_loc(), &mut builders[..n - 1]) + } + + fn serialize_binary(&self, _: AggrState, _: &mut Vec) -> Result<()> { + unreachable!() + } + + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let flag = get_flag(place) || *data.last().and_then(ScalarRef::as_boolean).unwrap(); self.inner - .merge_binary(place.remove_last_loc(), &mut &reader[..reader.len() - 1])?; + .merge(place.remove_last_loc(), &data[..data.len() - 1])?; set_flag(place, flag); Ok(()) } + fn merge_binary(&self, _: AggrState, _: &mut &[u8]) -> Result<()> { + unreachable!() + } + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { self.inner .merge_states(place.remove_last_loc(), rhs.remove_last_loc())?; @@ -237,6 +260,63 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { unsafe fn drop_state(&self, place: AggrState) { self.inner.drop_state(place.remove_last_loc()) } + + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &databend_common_expression::BlockEntry, + ) -> Result<()> { + let column = state.to_column(); + for (place, data) in places.iter().zip(column.iter()) { + self.merge( + AggrState::new(*place, loc), + data.as_tuple().unwrap().as_slice(), + )?; + } + + Ok(()) + } + + fn batch_merge_single( + &self, + place: AggrState, + state: &databend_common_expression::BlockEntry, + ) -> Result<()> { + let column = state.to_column(); + for data in column.iter() { + self.merge(place, data.as_tuple().unwrap().as_slice())?; + } + Ok(()) + } + + fn batch_merge_states( + &self, + places: &[StateAddr], + rhses: &[StateAddr], + loc: &[AggrStateLoc], + ) -> Result<()> { + for (place, rhs) in places.iter().zip(rhses.iter()) { + self.merge_states(AggrState::new(*place, loc), AggrState::new(*rhs, loc))?; + } + Ok(()) + } + + fn batch_merge_result( + &self, + places: &[StateAddr], + loc: Box<[AggrStateLoc]>, + builder: &mut ColumnBuilder, + ) -> Result<()> { + for place in places { + self.merge_result(AggrState::new(*place, &loc), builder)?; + } + Ok(()) + } + + fn get_if_condition(&self, _columns: ProjectedBlock) -> Option { + None + } } impl fmt::Display for AggregateFunctionOrNullAdaptor { diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs index 9b5fba779986a..549a2ebc9765b 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs @@ -155,7 +155,8 @@ impl AccumulatingTransform for PartialSingleStateAggregator { ) .zip(builders.iter_mut()) { - func.serialize(place, builder)?; + let builders = builder.as_tuple_mut().unwrap().as_mut_slice(); + func.serialize(place, builders)?; } let columns = builders.into_iter().map(|b| b.build()).collect(); From dfec718ca9d0027ba567fcf00a5233b5e8acdba9 Mon Sep 17 00:00:00 2001 From: coldWater Date: Wed, 23 Jul 2025 12:19:03 +0800 Subject: [PATCH 06/22] serialize_type --- .../src/aggregate/aggregate_function.rs | 8 +-- .../src/aggregate/aggregate_function_state.rs | 10 +-- .../adaptors/aggregate_null_adaptor.rs | 24 +++---- .../adaptors/aggregate_ornull_adaptor.rs | 63 +------------------ .../adaptors/aggregate_sort_adaptor.rs | 5 ++ .../src/aggregates/aggregate_arg_min_max.rs | 5 ++ .../src/aggregates/aggregate_array_agg.rs | 5 ++ .../src/aggregates/aggregate_array_moving.rs | 9 +++ .../src/aggregates/aggregate_bitmap.rs | 9 +++ .../aggregate_combinator_distinct.rs | 5 ++ .../src/aggregates/aggregate_combinator_if.rs | 5 ++ .../aggregates/aggregate_combinator_state.rs | 5 ++ .../src/aggregates/aggregate_count.rs | 5 ++ .../src/aggregates/aggregate_covariance.rs | 5 ++ .../aggregates/aggregate_json_array_agg.rs | 5 ++ .../aggregates/aggregate_json_object_agg.rs | 5 ++ .../src/aggregates/aggregate_markov_tarin.rs | 5 ++ .../src/aggregates/aggregate_null_result.rs | 5 ++ .../aggregates/aggregate_quantile_tdigest.rs | 6 ++ .../aggregate_quantile_tdigest_weighted.rs | 6 ++ .../src/aggregates/aggregate_retention.rs | 5 ++ .../src/aggregates/aggregate_st_collect.rs | 5 ++ .../src/aggregates/aggregate_string_agg.rs | 5 ++ .../src/aggregates/aggregate_unary.rs | 5 ++ .../src/aggregates/aggregate_window_funnel.rs | 5 ++ .../transforms/aggregator/udaf_script.rs | 5 ++ .../physical_aggregate_partial.rs | 2 +- 27 files changed, 135 insertions(+), 92 deletions(-) diff --git a/src/query/expression/src/aggregate/aggregate_function.rs b/src/query/expression/src/aggregate/aggregate_function.rs index 08e5ac3759337..40f213d4deaf5 100755 --- a/src/query/expression/src/aggregate/aggregate_function.rs +++ b/src/query/expression/src/aggregate/aggregate_function.rs @@ -69,9 +69,7 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { // Used in aggregate_null_adaptor fn accumulate_row(&self, place: AggrState, columns: ProjectedBlock, row: usize) -> Result<()>; - fn serialize_type(&self) -> Vec { - vec![StateSerdeItem::Binary(self.serialize_size_per_row())] - } + fn serialize_type(&self) -> Vec; fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); @@ -82,10 +80,6 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()>; - fn serialize_size_per_row(&self) -> Option { - None - } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { let mut binary = *data[0].as_binary().unwrap(); self.merge_binary(place, &mut binary) diff --git a/src/query/expression/src/aggregate/aggregate_function_state.rs b/src/query/expression/src/aggregate/aggregate_function_state.rs index 791ee68542ff7..dc58aa8f0f5b2 100644 --- a/src/query/expression/src/aggregate/aggregate_function_state.rs +++ b/src/query/expression/src/aggregate/aggregate_function_state.rs @@ -193,9 +193,9 @@ impl AggrStateLoc { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] pub enum StateSerdeItem { - Bool, + DataType(DataType), Binary(Option), } @@ -208,7 +208,7 @@ impl StateSerdeType { self.0 .iter() .map(|item| match item { - StateSerdeItem::Bool => DataType::Boolean, + StateSerdeItem::DataType(data_type) => data_type.clone(), StateSerdeItem::Binary(_) => DataType::Binary, }) .collect(), @@ -232,8 +232,8 @@ impl StatesLayout { .0 .iter() .map(|item| match item { - StateSerdeItem::Bool => { - ColumnBuilder::with_capacity(&DataType::Boolean, num_rows) + StateSerdeItem::DataType(data_type) => { + ColumnBuilder::with_capacity(data_type, num_rows) } StateSerdeItem::Binary(size) => { ColumnBuilder::Binary(BinaryColumnBuilder::with_capacity( diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs index 6a9ef13527f69..3eb1a017553f1 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs @@ -131,10 +131,6 @@ impl AggregateFunction for AggregateNullUnaryAdapto self.0.init_state(place); } - fn serialize_size_per_row(&self) -> Option { - self.0.serialize_size_per_row() - } - fn register_state(&self, registry: &mut AggrStateRegistry) { self.0.register_state(registry); } @@ -188,8 +184,8 @@ impl AggregateFunction for AggregateNullUnaryAdapto self.0.serialize_type() } - fn serialize(&self, place: AggrState, builder: &mut [ColumnBuilder]) -> Result<()> { - self.0.serialize(place, builder) + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + self.0.serialize(place, builders) } fn serialize_binary(&self, _: AggrState, _: &mut Vec) -> Result<()> { @@ -257,10 +253,6 @@ impl AggregateFunction self.0.init_state(place); } - fn serialize_size_per_row(&self) -> Option { - self.0.serialize_size_per_row() - } - fn register_state(&self, registry: &mut AggrStateRegistry) { self.0.register_state(registry); } @@ -317,6 +309,10 @@ impl AggregateFunction .accumulate_row(place, not_null_columns, validity, row) } + fn serialize_type(&self) -> Vec { + self.0.serialize_type() + } + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { self.0.serialize(place, builders) } @@ -384,12 +380,6 @@ impl CommonNullAdaptor { self.nested.init_state(place.remove_last_loc()); } - fn serialize_size_per_row(&self) -> Option { - self.nested - .serialize_size_per_row() - .map(|row| if NULLABLE_RESULT { row + 1 } else { row }) - } - fn register_state(&self, registry: &mut AggrStateRegistry) { self.nested.register_state(registry); if NULLABLE_RESULT { @@ -516,7 +506,7 @@ impl CommonNullAdaptor { self.nested .serialize_type() .into_iter() - .chain(Some(StateSerdeItem::Bool)) + .chain(Some(StateSerdeItem::DataType(DataType::Boolean))) .collect() } diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs index 4c6ee6e25b62b..513ffaec97f3e 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs @@ -90,10 +90,6 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { self.inner.init_state(place.remove_last_loc()) } - fn serialize_size_per_row(&self) -> Option { - self.inner.serialize_size_per_row().map(|row| row + 1) - } - fn register_state(&self, registry: &mut AggrStateRegistry) { self.inner.register_state(registry); registry.register(AggrStateType::Bool); @@ -182,7 +178,7 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { self.inner .serialize_type() .into_iter() - .chain(Some(StateSerdeItem::Bool)) + .chain(Some(StateSerdeItem::DataType(DataType::Boolean))) .collect() } @@ -260,63 +256,6 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { unsafe fn drop_state(&self, place: AggrState) { self.inner.drop_state(place.remove_last_loc()) } - - fn batch_merge( - &self, - places: &[StateAddr], - loc: &[AggrStateLoc], - state: &databend_common_expression::BlockEntry, - ) -> Result<()> { - let column = state.to_column(); - for (place, data) in places.iter().zip(column.iter()) { - self.merge( - AggrState::new(*place, loc), - data.as_tuple().unwrap().as_slice(), - )?; - } - - Ok(()) - } - - fn batch_merge_single( - &self, - place: AggrState, - state: &databend_common_expression::BlockEntry, - ) -> Result<()> { - let column = state.to_column(); - for data in column.iter() { - self.merge(place, data.as_tuple().unwrap().as_slice())?; - } - Ok(()) - } - - fn batch_merge_states( - &self, - places: &[StateAddr], - rhses: &[StateAddr], - loc: &[AggrStateLoc], - ) -> Result<()> { - for (place, rhs) in places.iter().zip(rhses.iter()) { - self.merge_states(AggrState::new(*place, loc), AggrState::new(*rhs, loc))?; - } - Ok(()) - } - - fn batch_merge_result( - &self, - places: &[StateAddr], - loc: Box<[AggrStateLoc]>, - builder: &mut ColumnBuilder, - ) -> Result<()> { - for place in places { - self.merge_result(AggrState::new(*place, &loc), builder)?; - } - Ok(()) - } - - fn get_if_condition(&self, _columns: ProjectedBlock) -> Option { - None - } } impl fmt::Display for AggregateFunctionOrNullAdaptor { diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs index 20b11214d15fc..5c0c37500a423 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs @@ -34,6 +34,7 @@ use databend_common_expression::DataBlock; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; use databend_common_expression::SortColumnDescription; +use databend_common_expression::StateSerdeItem; use itertools::Itertools; use crate::aggregates::AggregateFunctionSortDesc; @@ -121,6 +122,10 @@ impl AggregateFunction for AggregateFunctionSortAdaptor { Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = Self::get_state(place); Ok(state.serialize(writer)?) diff --git a/src/query/functions/src/aggregates/aggregate_arg_min_max.rs b/src/query/functions/src/aggregates/aggregate_arg_min_max.rs index 5311335b6b229..a9043c479cda8 100644 --- a/src/query/functions/src/aggregates/aggregate_arg_min_max.rs +++ b/src/query/functions/src/aggregates/aggregate_arg_min_max.rs @@ -31,6 +31,7 @@ use databend_common_expression::ColumnBuilder; use databend_common_expression::ColumnView; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use super::aggregate_function_factory::AggregateFunctionDescription; use super::aggregate_function_factory::AggregateFunctionSortDesc; @@ -270,6 +271,10 @@ where Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) diff --git a/src/query/functions/src/aggregates/aggregate_array_agg.rs b/src/query/functions/src/aggregates/aggregate_array_agg.rs index ddeba25f9c86f..ce44a9d3dc691 100644 --- a/src/query/functions/src/aggregates/aggregate_array_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_array_agg.rs @@ -48,6 +48,7 @@ use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; use databend_common_expression::ScalarRef; +use databend_common_expression::StateSerdeItem; use super::aggregate_function_factory::AggregateFunctionDescription; use super::aggregate_scalar_state::ScalarStateFunc; @@ -548,6 +549,10 @@ where Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) diff --git a/src/query/functions/src/aggregates/aggregate_array_moving.rs b/src/query/functions/src/aggregates/aggregate_array_moving.rs index 0070924a974df..423af8229e265 100644 --- a/src/query/functions/src/aggregates/aggregate_array_moving.rs +++ b/src/query/functions/src/aggregates/aggregate_array_moving.rs @@ -45,6 +45,7 @@ use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; use databend_common_expression::ScalarRef; +use databend_common_expression::StateSerdeItem; use num_traits::AsPrimitive; use super::aggregate_function::AggregateFunction; @@ -447,6 +448,10 @@ where State: SumState state.accumulate_row(&columns[0], row) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) @@ -616,6 +621,10 @@ where State: SumState state.accumulate_row(&columns[0], row) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) diff --git a/src/query/functions/src/aggregates/aggregate_bitmap.rs b/src/query/functions/src/aggregates/aggregate_bitmap.rs index 3f56ff133296b..348749834070a 100644 --- a/src/query/functions/src/aggregates/aggregate_bitmap.rs +++ b/src/query/functions/src/aggregates/aggregate_bitmap.rs @@ -36,6 +36,7 @@ use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use databend_common_io::prelude::BinaryWrite; use roaring::RoaringTreemap; @@ -287,6 +288,10 @@ where Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); // flag indicate where bitmap is none @@ -482,6 +487,10 @@ where Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { self.inner.serialize_binary(place, writer) } diff --git a/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs b/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs index d74b9a32f661a..4f6ec3fafefd9 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs @@ -28,6 +28,7 @@ use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use super::aggregate_distinct_state::AggregateDistinctNumberState; use super::aggregate_distinct_state::AggregateDistinctState; @@ -106,6 +107,10 @@ where State: DistinctStateFunc state.add(columns, row) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = Self::get_state(place); state.serialize(writer) diff --git a/src/query/functions/src/aggregates/aggregate_combinator_if.rs b/src/query/functions/src/aggregates/aggregate_combinator_if.rs index 03bbdf625e4df..38ef9fd660af5 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_if.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_if.rs @@ -26,6 +26,7 @@ use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use super::StateAddr; use crate::aggregates::aggregate_function_factory::AggregateFunctionCreator; @@ -155,6 +156,10 @@ impl AggregateFunction for AggregateIfCombinator { Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { self.nested.serialize_binary(place, writer) } diff --git a/src/query/functions/src/aggregates/aggregate_combinator_state.rs b/src/query/functions/src/aggregates/aggregate_combinator_state.rs index e83d78f24164c..74a02da005941 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_state.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_state.rs @@ -22,6 +22,7 @@ use databend_common_expression::AggrStateRegistry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use super::AggregateFunctionFactory; use super::StateAddr; @@ -106,6 +107,10 @@ impl AggregateFunction for AggregateStateCombinator { self.nested.accumulate_row(place, columns, row) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { self.nested.serialize_binary(place, writer) } diff --git a/src/query/functions/src/aggregates/aggregate_count.rs b/src/query/functions/src/aggregates/aggregate_count.rs index 9e1db7e50a989..64c1cb850f8ca 100644 --- a/src/query/functions/src/aggregates/aggregate_count.rs +++ b/src/query/functions/src/aggregates/aggregate_count.rs @@ -29,6 +29,7 @@ use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use super::aggregate_function::AggregateFunction; use super::aggregate_function_factory::AggregateFunctionDescription; @@ -160,6 +161,10 @@ impl AggregateFunction for AggregateCountFunction { Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.count.serialize(writer)?) diff --git a/src/query/functions/src/aggregates/aggregate_covariance.rs b/src/query/functions/src/aggregates/aggregate_covariance.rs index c69d242ea811a..774c955679255 100644 --- a/src/query/functions/src/aggregates/aggregate_covariance.rs +++ b/src/query/functions/src/aggregates/aggregate_covariance.rs @@ -34,6 +34,7 @@ use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use num_traits::AsPrimitive; use super::borsh_partial_deserialize; @@ -231,6 +232,10 @@ where Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) diff --git a/src/query/functions/src/aggregates/aggregate_json_array_agg.rs b/src/query/functions/src/aggregates/aggregate_json_array_agg.rs index acd867e4f2a8e..69cc6b44ca7c3 100644 --- a/src/query/functions/src/aggregates/aggregate_json_array_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_json_array_agg.rs @@ -34,6 +34,7 @@ use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; use databend_common_expression::ScalarRef; +use databend_common_expression::StateSerdeItem; use jiff::tz::TimeZone; use jsonb::OwnedJsonb; use jsonb::RawJsonb; @@ -240,6 +241,10 @@ where Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) diff --git a/src/query/functions/src/aggregates/aggregate_json_object_agg.rs b/src/query/functions/src/aggregates/aggregate_json_object_agg.rs index 45179f529cee2..e9a0d72f9de98 100644 --- a/src/query/functions/src/aggregates/aggregate_json_object_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_json_object_agg.rs @@ -36,6 +36,7 @@ use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; use databend_common_expression::ScalarRef; +use databend_common_expression::StateSerdeItem; use jiff::tz::TimeZone; use jsonb::OwnedJsonb; use jsonb::RawJsonb; @@ -293,6 +294,10 @@ where Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) diff --git a/src/query/functions/src/aggregates/aggregate_markov_tarin.rs b/src/query/functions/src/aggregates/aggregate_markov_tarin.rs index 285965bcbafa8..e0b485983bf26 100644 --- a/src/query/functions/src/aggregates/aggregate_markov_tarin.rs +++ b/src/query/functions/src/aggregates/aggregate_markov_tarin.rs @@ -40,6 +40,7 @@ use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use super::aggregate_function_factory::AggregateFunctionDescription; use super::assert_unary_arguments; @@ -113,6 +114,10 @@ impl AggregateFunction for MarkovTarin { Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) diff --git a/src/query/functions/src/aggregates/aggregate_null_result.rs b/src/query/functions/src/aggregates/aggregate_null_result.rs index 74e2942e9bceb..5333ba8910445 100644 --- a/src/query/functions/src/aggregates/aggregate_null_result.rs +++ b/src/query/functions/src/aggregates/aggregate_null_result.rs @@ -25,6 +25,7 @@ use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; +use databend_common_expression::StateSerdeItem; use super::aggregate_function::AggregateFunction; use super::StateAddr; @@ -86,6 +87,10 @@ impl AggregateFunction for AggregateNullResultFunction { Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, _place: AggrState, _writer: &mut Vec) -> Result<()> { Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs b/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs index d982f8b80ca14..3e3dae0759936 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs @@ -35,6 +35,7 @@ use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; use databend_common_expression::ScalarRef; +use databend_common_expression::StateSerdeItem; use itertools::Itertools; use super::borsh_partial_deserialize; @@ -362,6 +363,11 @@ where for<'a> T: AccessType = F64> + Send + Sync }); Ok(()) } + + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) diff --git a/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs b/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs index d4aa266fd1a5e..dfc03b876c35a 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs @@ -31,6 +31,7 @@ use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use num_traits::AsPrimitive; use super::borsh_partial_deserialize; @@ -143,6 +144,11 @@ where }); Ok(()) } + + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) diff --git a/src/query/functions/src/aggregates/aggregate_retention.rs b/src/query/functions/src/aggregates/aggregate_retention.rs index 5b904f234e740..a5d6780e73dc9 100644 --- a/src/query/functions/src/aggregates/aggregate_retention.rs +++ b/src/query/functions/src/aggregates/aggregate_retention.rs @@ -29,6 +29,7 @@ use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use super::aggregate_function::AggregateFunction; use super::aggregate_function::AggregateFunctionRef; @@ -143,6 +144,10 @@ impl AggregateFunction for AggregateRetentionFunction { Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) diff --git a/src/query/functions/src/aggregates/aggregate_st_collect.rs b/src/query/functions/src/aggregates/aggregate_st_collect.rs index bfe960471c5ee..4657955f0dbb4 100644 --- a/src/query/functions/src/aggregates/aggregate_st_collect.rs +++ b/src/query/functions/src/aggregates/aggregate_st_collect.rs @@ -34,6 +34,7 @@ use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; use databend_common_expression::ScalarRef; +use databend_common_expression::StateSerdeItem; use databend_common_io::ewkb_to_geo; use databend_common_io::geo_to_ewkb; use geo::Geometry; @@ -306,6 +307,10 @@ where Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) diff --git a/src/query/functions/src/aggregates/aggregate_string_agg.rs b/src/query/functions/src/aggregates/aggregate_string_agg.rs index 743c2592d0b9b..4bf05651aeae7 100644 --- a/src/query/functions/src/aggregates/aggregate_string_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_string_agg.rs @@ -34,6 +34,7 @@ use databend_common_expression::Evaluator; use databend_common_expression::FunctionContext; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use super::aggregate_function_factory::AggregateFunctionDescription; use super::borsh_partial_deserialize; @@ -147,6 +148,10 @@ impl AggregateFunction for AggregateStringAggFunction { Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); Ok(state.serialize(writer)?) diff --git a/src/query/functions/src/aggregates/aggregate_unary.rs b/src/query/functions/src/aggregates/aggregate_unary.rs index be20d15157534..a18ba2871ede3 100644 --- a/src/query/functions/src/aggregates/aggregate_unary.rs +++ b/src/query/functions/src/aggregates/aggregate_unary.rs @@ -32,6 +32,7 @@ use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; use databend_common_expression::StateAddr; +use databend_common_expression::StateSerdeItem; use crate::aggregates::AggrState; use crate::aggregates::AggrStateLoc; @@ -227,6 +228,10 @@ where Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state: &mut S = place.get::(); Ok(state.serialize(writer)?) diff --git a/src/query/functions/src/aggregates/aggregate_window_funnel.rs b/src/query/functions/src/aggregates/aggregate_window_funnel.rs index 864e5d0eaeea5..d87dc3a5c0683 100644 --- a/src/query/functions/src/aggregates/aggregate_window_funnel.rs +++ b/src/query/functions/src/aggregates/aggregate_window_funnel.rs @@ -40,6 +40,7 @@ use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use num_traits::AsPrimitive; use super::borsh_partial_deserialize; @@ -275,6 +276,10 @@ where Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::>(); Ok(state.serialize(writer)?) diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs index fdfc1dcbe6d39..748e0d89b9252 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs @@ -39,6 +39,7 @@ use databend_common_expression::DataBlock; use databend_common_expression::DataField; use databend_common_expression::DataSchema; use databend_common_expression::ProjectedBlock; +use databend_common_expression::StateSerdeItem; use databend_common_functions::aggregates::AggregateFunction; use databend_common_sql::plans::UDFLanguage; use databend_common_sql::plans::UDFScriptCode; @@ -98,6 +99,10 @@ impl AggregateFunction for AggregateUdfScript { Ok(()) } + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { let state = place.get::(); state diff --git a/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs b/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs index 5026f7ba38f9a..9c8e4a3ff4ea3 100644 --- a/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs +++ b/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs @@ -73,7 +73,7 @@ impl AggregatePartial { .serialize_type() .iter() .map(|serde_type| match serde_type { - StateSerdeItem::Bool => DataType::Boolean, + StateSerdeItem::DataType(data_type) => data_type.clone(), StateSerdeItem::Binary(_) => DataType::Binary, }) .collect(); From 2d76884bd333fa07b4c750044a23ee7412537d75 Mon Sep 17 00:00:00 2001 From: coldWater Date: Wed, 23 Jul 2025 15:03:18 +0800 Subject: [PATCH 07/22] ee --- .../src/aggregate/aggregate_function.rs | 6 ++ .../src/aggregate/aggregate_function_state.rs | 4 ++ .../aggregates/aggregate_combinator_state.rs | 28 +++++---- .../physical_aggregate_partial.rs | 12 +--- .../operator/aggregate/stats_aggregate.rs | 5 ++ .../rule/agg_rules/agg_index/query_rewrite.rs | 57 ++++++++++++++----- .../02_0001_agg_index_projected_scan.test | 37 +++++++++++- 7 files changed, 113 insertions(+), 36 deletions(-) diff --git a/src/query/expression/src/aggregate/aggregate_function.rs b/src/query/expression/src/aggregate/aggregate_function.rs index 40f213d4deaf5..57b4095abd148 100755 --- a/src/query/expression/src/aggregate/aggregate_function.rs +++ b/src/query/expression/src/aggregate/aggregate_function.rs @@ -29,6 +29,7 @@ use crate::ProjectedBlock; use crate::Scalar; use crate::ScalarRef; use crate::StateSerdeItem; +use crate::StateSerdeType; pub type AggregateFunctionRef = Arc; @@ -71,6 +72,11 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { fn serialize_type(&self) -> Vec; + fn serialize_data_type(&self) -> DataType { + let serde_type = StateSerdeType::new(self.serialize_type()); + serde_type.data_type() + } + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); self.serialize_binary(place, &mut binary_builder.data)?; diff --git a/src/query/expression/src/aggregate/aggregate_function_state.rs b/src/query/expression/src/aggregate/aggregate_function_state.rs index dc58aa8f0f5b2..0c07ee1aaab33 100644 --- a/src/query/expression/src/aggregate/aggregate_function_state.rs +++ b/src/query/expression/src/aggregate/aggregate_function_state.rs @@ -203,6 +203,10 @@ pub enum StateSerdeItem { pub struct StateSerdeType(Box<[StateSerdeItem]>); impl StateSerdeType { + pub fn new(items: impl Into>) -> Self { + StateSerdeType(items.into()) + } + pub fn data_type(&self) -> DataType { DataType::Tuple( self.0 diff --git a/src/query/functions/src/aggregates/aggregate_combinator_state.rs b/src/query/functions/src/aggregates/aggregate_combinator_state.rs index 74a02da005941..f0ebbb1d4f411 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_state.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_state.rs @@ -22,6 +22,7 @@ use databend_common_expression::AggrStateRegistry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::AggregateFunctionFactory; @@ -71,7 +72,7 @@ impl AggregateFunction for AggregateStateCombinator { } fn return_type(&self) -> Result { - Ok(DataType::Binary) + Ok(self.nested.serialize_data_type()) } fn init_state(&self, place: AggrState) { @@ -108,16 +109,23 @@ impl AggregateFunction for AggregateStateCombinator { } fn serialize_type(&self) -> Vec { - vec![StateSerdeItem::Binary(None)] + self.nested.serialize_type() } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - self.nested.serialize_binary(place, writer) + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + self.nested.serialize(place, builders) } - #[inline] - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - self.nested.merge_binary(place, reader) + fn serialize_binary(&self, _: AggrState, _: &mut Vec) -> Result<()> { + unreachable!() + } + + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + self.nested.merge(place, data) + } + + fn merge_binary(&self, _: AggrState, _: &mut &[u8]) -> Result<()> { + unreachable!() } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { @@ -125,10 +133,8 @@ impl AggregateFunction for AggregateStateCombinator { } fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { - let builder = builder.as_binary_mut().unwrap(); - self.nested.serialize_binary(place, &mut builder.data)?; - builder.commit_row(); - Ok(()) + let builders = builder.as_tuple_mut().unwrap().as_mut_slice(); + self.nested.serialize(place, builders) } fn need_manual_drop_state(&self) -> bool { diff --git a/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs b/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs index 9c8e4a3ff4ea3..ab915b3c727e5 100644 --- a/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs +++ b/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs @@ -17,7 +17,6 @@ use databend_common_expression::types::DataType; use databend_common_expression::DataField; use databend_common_expression::DataSchemaRef; use databend_common_expression::DataSchemaRefExt; -use databend_common_expression::StateSerdeItem; use databend_common_functions::aggregates::AggregateFunctionFactory; use super::SortDesc; @@ -69,16 +68,7 @@ impl AggregatePartial { ) .unwrap(); - let tuple = func - .serialize_type() - .iter() - .map(|serde_type| match serde_type { - StateSerdeItem::DataType(data_type) => data_type.clone(), - StateSerdeItem::Binary(_) => DataType::Binary, - }) - .collect(); - - fields.push(DataField::new(&name, DataType::Tuple(tuple))) + fields.push(DataField::new(&name, func.serialize_data_type())) } for (idx, field) in self.group_by.iter().zip( diff --git a/src/query/sql/src/planner/optimizer/optimizers/operator/aggregate/stats_aggregate.rs b/src/query/sql/src/planner/optimizer/optimizers/operator/aggregate/stats_aggregate.rs index 03a604ff86ebc..cff4c1fddaee3 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/operator/aggregate/stats_aggregate.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/operator/aggregate/stats_aggregate.rs @@ -129,6 +129,11 @@ impl RuleStatsAggregateOptimizer { for (need_rewrite_agg, agg) in need_rewrite_aggs.iter().zip(agg.aggregate_functions.iter()) { + if matches!(agg.scalar, ScalarExpr::UDAFCall(_)) { + agg_results.push(agg.clone()); + continue; + } + let agg_func = AggregateFunction::try_from(agg.scalar.clone())?; if let Some((col_id, name)) = need_rewrite_agg { diff --git a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/query_rewrite.rs b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/query_rewrite.rs index 22885188c78ee..09e75a0a9f594 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/query_rewrite.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/query_rewrite.rs @@ -25,6 +25,7 @@ use databend_common_expression::types::DataType; use databend_common_expression::Scalar; use databend_common_expression::TableField; use databend_common_expression::TableSchemaRefExt; +use databend_common_functions::aggregates::AggregateFunctionFactory; use itertools::Itertools; use log::info; @@ -306,6 +307,12 @@ impl QueryInfo { } return Ok(None); } + ScalarExpr::UDAFCall(udaf) => { + for arg in &udaf.arguments { + self.check_output_cols(arg, index_output_cols, new_selection_set)?; + } + return Ok(None); + } ScalarExpr::UDFCall(udf) => { let mut valid = true; let mut new_args = Vec::with_capacity(udf.arguments.len()); @@ -361,23 +368,38 @@ impl ViewInfo { // query can use those columns to compute expressions. let mut index_fields = Vec::with_capacity(query_info.output_cols.len()); let mut index_output_cols = HashMap::with_capacity(query_info.output_cols.len()); + let factory = AggregateFunctionFactory::instance(); for (index, item) in query_info.output_cols.iter().enumerate() { let display_name = format_scalar(&item.scalar, &query_info.column_map); - let mut is_agg = false; - if let Some(ref aggregate) = query_info.aggregate { - for agg_func in &aggregate.aggregate_functions { - if item.index == agg_func.index { - is_agg = true; - break; - } + let aggr_scalar_item = query_info.aggregate.as_ref().and_then(|aggregate| { + aggregate + .aggregate_functions + .iter() + .find(|agg_func| agg_func.index == item.index) + }); + + let (data_type, is_agg) = match aggr_scalar_item { + Some(item) => { + let func = match &item.scalar { + ScalarExpr::AggregateFunction(func) => func, + _ => unreachable!(), + }; + let func = factory.get( + &func.func_name, + func.params.clone(), + func.args + .iter() + .map(|arg| arg.data_type()) + .collect::>()?, + func.sort_descs + .iter() + .map(|desc| desc.try_into()) + .collect::>()?, + )?; + (func.serialize_data_type(), true) } - } - // we store the value of aggregate function as binary data. - let data_type = if is_agg { - DataType::Binary - } else { - item.scalar.data_type().unwrap() + None => (item.scalar.data_type().unwrap(), false), }; let name = format!("{index}"); @@ -1299,6 +1321,15 @@ fn format_scalar(scalar: &ScalarExpr, column_map: &HashMap { + let args = udaf + .arguments + .iter() + .map(|arg| format_scalar(arg, column_map)) + .collect::>() + .join(", "); + format!("{}({})", &udaf.name, args) + } ScalarExpr::UDFCall(udf) => format!( "{}({})", &udf.handler, diff --git a/tests/sqllogictests/suites/ee/02_ee_aggregating_index/02_0001_agg_index_projected_scan.test b/tests/sqllogictests/suites/ee/02_ee_aggregating_index/02_0001_agg_index_projected_scan.test index d9e8286a48e69..c5525dd6f40a4 100644 --- a/tests/sqllogictests/suites/ee/02_ee_aggregating_index/02_0001_agg_index_projected_scan.test +++ b/tests/sqllogictests/suites/ee/02_ee_aggregating_index/02_0001_agg_index_projected_scan.test @@ -75,7 +75,42 @@ SELECT b, SUM(a) from t WHERE c > 1 GROUP BY b ORDER BY b 1 1 2 3 -query IIT +query IIR SELECT MAX(a), MIN(b), AVG(c) from t ---- 2 1 3.5 + +statement ok +CREATE or REPLACE FUNCTION weighted_avg (INT, INT) STATE {sum INT, weight INT} RETURNS FLOAT +LANGUAGE javascript AS $$ +export function create_state() { + return {sum: 0, weight: 0}; +} +export function accumulate(state, value, weight) { + state.sum += value * weight; + state.weight += weight; + return state; +} +export function retract(state, value, weight) { + state.sum -= value * weight; + state.weight -= weight; + return state; +} +export function merge(state1, state2) { + state1.sum += state2.sum; + state1.weight += state2.weight; + return state1; +} +export function finish(state) { + return state.sum / state.weight; +} +$$; + +query IIR +SELECT MAX(a), MIN(b), weighted_avg(a,b) from t group by b; +---- +2 2 1.3333334 +1 1 1.0 + +# fix me +# SELECT MAX(a), MIN(b), weighted_avg(a,b) from t; From 5b16525b1972a960786d0cc702c09e8b089a4b42 Mon Sep 17 00:00:00 2001 From: coldWater Date: Wed, 23 Jul 2025 18:01:53 +0800 Subject: [PATCH 08/22] fix --- .../src/aggregates/aggregator_common.rs | 17 ++++++++++++++--- .../storages/fuse/src/operations/analyze.rs | 4 ++-- .../02_0000_function_aggregate_state.test | 10 +++++----- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/query/functions/src/aggregates/aggregator_common.rs b/src/query/functions/src/aggregates/aggregator_common.rs index 8d4255a991983..76f4264e59684 100644 --- a/src/query/functions/src/aggregates/aggregator_common.rs +++ b/src/query/functions/src/aggregates/aggregator_common.rs @@ -29,6 +29,7 @@ use databend_common_expression::ColumnBuilder; use databend_common_expression::Constant; use databend_common_expression::FunctionContext; use databend_common_expression::Scalar; +use databend_common_expression::ScalarRef; use databend_common_expression::StateAddr; use super::get_states_layout; @@ -186,10 +187,20 @@ pub fn eval_aggr_for_test( let state = AggrState::new(eval.addr, &eval.state_layout.states_loc[0]); func.accumulate(state, entries.into(), None, rows)?; if with_serialize { - let mut buf = vec![]; - func.serialize_binary(state, &mut buf)?; + let data_type = func.serialize_data_type(); + let mut builder = ColumnBuilder::with_capacity(&data_type, 1); + let builders = builder.as_tuple_mut().unwrap().as_mut_slice(); + func.serialize(state, builders)?; func.init_state(state); - func.merge_binary(state, &mut buf.as_slice())?; + let column = builder.build(); + let data = column.index(0); + func.merge( + state, + data.as_ref() + .and_then(ScalarRef::as_tuple) + .unwrap() + .as_slice(), + )?; } let mut builder = ColumnBuilder::with_capacity(&data_type, 1024); func.merge_result(state, &mut builder)?; diff --git a/src/query/storages/fuse/src/operations/analyze.rs b/src/query/storages/fuse/src/operations/analyze.rs index b6d951a84545b..84dc9817826c5 100644 --- a/src/query/storages/fuse/src/operations/analyze.rs +++ b/src/query/storages/fuse/src/operations/analyze.rs @@ -186,8 +186,8 @@ impl SinkAnalyzeState { let index: u32 = name.strip_prefix("ndv_").unwrap().parse().unwrap(); let col = col.index(0).unwrap(); - let col = col.as_binary().unwrap(); - let hll: MetaHLL = borsh_deserialize_from_slice(col)?; + let data = col.as_tuple().unwrap()[0].as_binary().unwrap(); + let hll: MetaHLL = borsh_deserialize_from_slice(data)?; if !is_full { ndv_states diff --git a/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_state.test b/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_state.test index bd2c249067d25..e425b986c7262 100644 --- a/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_state.test +++ b/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_state.test @@ -1,9 +1,9 @@ query IT -select length(max_state(number)), typeof(max_state(number)) from numbers(100); +select length(max_state(number).1), typeof(max_state(number)) from numbers(100); ---- -10 BINARY +9 TUPLE(BINARY, BOOLEAN) -query I -select length(sum_state(number)), typeof(max_state(number)) from numbers(10000); +query IT +select length(sum_state(number).1), typeof(max_state(number)) from numbers(10000); ---- -9 BINARY +8 TUPLE(BINARY, BOOLEAN) From 828da4ad2d672b70524334575800787521e97a2e Mon Sep 17 00:00:00 2001 From: coldWater Date: Wed, 23 Jul 2025 20:52:18 +0800 Subject: [PATCH 09/22] if --- .../src/aggregates/aggregate_combinator_if.rs | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/query/functions/src/aggregates/aggregate_combinator_if.rs b/src/query/functions/src/aggregates/aggregate_combinator_if.rs index 38ef9fd660af5..eb737aeb18948 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_if.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_if.rs @@ -26,6 +26,7 @@ use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::StateAddr; @@ -54,20 +55,18 @@ impl AggregateIfCombinator { sort_descs: Vec, nested_creator: &AggregateFunctionCreator, ) -> Result { - let name = format!("IfCombinator({})", nested_name); + let name = format!("IfCombinator({nested_name})"); let argument_len = arguments.len(); if argument_len == 0 { return Err(ErrorCode::NumberArgumentsNotMatch(format!( - "{} expect to have more than one argument", - name + "{name} expect to have more than one argument", ))); } if !matches!(&arguments[argument_len - 1], DataType::Boolean) { return Err(ErrorCode::BadArguments(format!( - "The type of the last argument for {} must be boolean type, but got {:?}", - name, + "The type of the last argument for {name} must be boolean type, but got {:?}", &arguments[argument_len - 1] ))); } @@ -157,15 +156,23 @@ impl AggregateFunction for AggregateIfCombinator { } fn serialize_type(&self) -> Vec { - vec![StateSerdeItem::Binary(None)] + self.nested.serialize_type() } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - self.nested.serialize_binary(place, writer) + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + self.nested.serialize(place, builders) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - self.nested.merge_binary(place, reader) + fn serialize_binary(&self, _: AggrState, _: &mut Vec) -> Result<()> { + unreachable!() + } + + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + self.nested.merge(place, data) + } + + fn merge_binary(&self, _: AggrState, _: &mut &[u8]) -> Result<()> { + unreachable!() } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { From 72f3a2900b2da1c5eada07b1409d25ab951813a3 Mon Sep 17 00:00:00 2001 From: coldWater Date: Thu, 24 Jul 2025 11:15:35 +0800 Subject: [PATCH 10/22] fix --- src/query/catalog/src/plan/agg_index.rs | 4 +- .../src/aggregate/aggregate_function_state.rs | 4 + .../aggregator/transform_aggregate_partial.rs | 2 +- .../aggregator/transform_single_key.rs | 15 ++-- .../operator/aggregate/stats_aggregate.rs | 77 ++++++++++--------- .../02_0001_agg_index_projected_scan.test | 6 +- 6 files changed, 60 insertions(+), 48 deletions(-) diff --git a/src/query/catalog/src/plan/agg_index.rs b/src/query/catalog/src/plan/agg_index.rs index 12efe2937be56..65038a419ca8f 100644 --- a/src/query/catalog/src/plan/agg_index.rs +++ b/src/query/catalog/src/plan/agg_index.rs @@ -51,7 +51,7 @@ pub struct AggIndexInfo { } /// This meta just indicate the block is from aggregating index. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct AggIndexMeta { pub is_agg: bool, // Number of aggregation functions. @@ -75,6 +75,6 @@ local_block_meta_serde!(AggIndexMeta); #[typetag::serde(name = "agg_index_meta")] impl BlockMetaInfo for AggIndexMeta { fn clone_self(&self) -> Box { - Box::new(self.clone()) + Box::new(*self) } } diff --git a/src/query/expression/src/aggregate/aggregate_function_state.rs b/src/query/expression/src/aggregate/aggregate_function_state.rs index 0c07ee1aaab33..dafd4321b0c25 100644 --- a/src/query/expression/src/aggregate/aggregate_function_state.rs +++ b/src/query/expression/src/aggregate/aggregate_function_state.rs @@ -228,6 +228,10 @@ pub struct StatesLayout { } impl StatesLayout { + pub fn num_aggr_func(&self) -> usize { + self.states_loc.len() + } + pub fn serialize_builders(&self, num_rows: usize) -> Vec { self.serialize_type .iter() diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs index e40e7c80e2cd8..f856c17bbe135 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs @@ -150,7 +150,7 @@ impl TransformPartialAggregate { .params .states_layout .as_ref() - .map(|layout| layout.states_loc.len()) + .map(|layout| layout.num_aggr_func()) .unwrap_or(0); ( vec![], diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs index 549a2ebc9765b..12992aadf4c1e 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs @@ -93,18 +93,19 @@ impl AccumulatingTransform for PartialSingleStateAggregator { self.first_block_start = Some(Instant::now()); } - let is_agg_index_block = block + let meta = block .get_meta() .and_then(AggIndexMeta::downcast_ref_from) - .map(|index| index.is_agg) - .unwrap_or_default(); + .copied(); let block = block.consume_convert_to_full(); - if is_agg_index_block { + if let Some(meta) = meta + && meta.is_agg + { + assert_eq!(self.states_layout.num_aggr_func(), meta.num_agg_funcs); // Aggregation states are in the back of the block. - let states_indices = (block.num_columns() - self.states_layout.states_loc.len() - ..block.num_columns()) - .collect::>(); + let start = block.num_columns() - self.states_layout.num_aggr_func(); + let states_indices = (start..block.num_columns()).collect::>(); let states = ProjectedBlock::project(&states_indices, &block); for ((place, func), state) in self diff --git a/src/query/sql/src/planner/optimizer/optimizers/operator/aggregate/stats_aggregate.rs b/src/query/sql/src/planner/optimizer/optimizers/operator/aggregate/stats_aggregate.rs index cff4c1fddaee3..9f8eb81c06b31 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/operator/aggregate/stats_aggregate.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/operator/aggregate/stats_aggregate.rs @@ -129,37 +129,48 @@ impl RuleStatsAggregateOptimizer { for (need_rewrite_agg, agg) in need_rewrite_aggs.iter().zip(agg.aggregate_functions.iter()) { - if matches!(agg.scalar, ScalarExpr::UDAFCall(_)) { - agg_results.push(agg.clone()); - continue; - } - - let agg_func = AggregateFunction::try_from(agg.scalar.clone())?; - - if let Some((col_id, name)) = need_rewrite_agg { - if let Some(stat) = stats.get(col_id) { - let value_bound = if name.eq_ignore_ascii_case("min") { - &stat.min - } else { - &stat.max - }; - if !value_bound.may_be_truncated { - let scalar = ScalarExpr::ConstantExpr(ConstantExpr { - span: agg.scalar.span(), - value: value_bound.value.clone(), - }); - - let scalar = - scalar.unify_to_data_type(agg_func.return_type.as_ref()); - - eval_scalar_results.push(ScalarItem { - index: agg.index, - scalar, - }); - continue; + let column = if let ScalarExpr::UDAFCall(udaf) = &agg.scalar { + ColumnBindingBuilder::new( + udaf.display_name.clone(), + agg.index, + udaf.return_type.clone(), + Visibility::Visible, + ) + .build() + } else { + let agg_func = AggregateFunction::try_from(agg.scalar.clone())?; + if let Some((col_id, name)) = need_rewrite_agg { + if let Some(stat) = stats.get(col_id) { + let value_bound = if name.eq_ignore_ascii_case("min") { + &stat.min + } else { + &stat.max + }; + if !value_bound.may_be_truncated { + let scalar = ScalarExpr::ConstantExpr(ConstantExpr { + span: agg.scalar.span(), + value: value_bound.value.clone(), + }); + + let scalar = scalar + .unify_to_data_type(agg_func.return_type.as_ref()); + + eval_scalar_results.push(ScalarItem { + index: agg.index, + scalar, + }); + continue; + } } } - } + ColumnBindingBuilder::new( + agg_func.display_name.clone(), + agg.index, + agg_func.return_type.clone(), + Visibility::Visible, + ) + .build() + }; // Add other aggregate functions as derived column, // this will be used in aggregating index rewrite. @@ -167,13 +178,7 @@ impl RuleStatsAggregateOptimizer { index: agg.index, scalar: ScalarExpr::BoundColumnRef(BoundColumnRef { span: agg.scalar.span(), - column: ColumnBindingBuilder::new( - agg_func.display_name.clone(), - agg.index, - agg_func.return_type.clone(), - Visibility::Visible, - ) - .build(), + column, }), }); agg_results.push(agg.clone()); diff --git a/tests/sqllogictests/suites/ee/02_ee_aggregating_index/02_0001_agg_index_projected_scan.test b/tests/sqllogictests/suites/ee/02_ee_aggregating_index/02_0001_agg_index_projected_scan.test index c5525dd6f40a4..f7b09f81240b3 100644 --- a/tests/sqllogictests/suites/ee/02_ee_aggregating_index/02_0001_agg_index_projected_scan.test +++ b/tests/sqllogictests/suites/ee/02_ee_aggregating_index/02_0001_agg_index_projected_scan.test @@ -112,5 +112,7 @@ SELECT MAX(a), MIN(b), weighted_avg(a,b) from t group by b; 2 2 1.3333334 1 1 1.0 -# fix me -# SELECT MAX(a), MIN(b), weighted_avg(a,b) from t; +query IIR +SELECT MAX(a), MIN(b), weighted_avg(a,b) from t; +---- +2 1 1.2857143 From f16744f0dcff1eb134ec9aea1114aacf0fb2daff Mon Sep 17 00:00:00 2001 From: coldWater Date: Thu, 24 Jul 2025 11:51:34 +0800 Subject: [PATCH 11/22] fix --- .../adaptors/aggregate_null_adaptor.rs | 17 +++++++++++++---- .../adaptors/aggregate_ornull_adaptor.rs | 9 +++++++-- .../src/aggregates/aggregate_combinator_if.rs | 8 ++++++-- .../aggregates/aggregate_combinator_state.rs | 9 +++++++-- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs index 3eb1a017553f1..c1e8babd47e5b 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs @@ -15,6 +15,7 @@ use std::fmt; use std::sync::Arc; +use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; @@ -189,7 +190,9 @@ impl AggregateFunction for AggregateNullUnaryAdapto } fn serialize_binary(&self, _: AggrState, _: &mut Vec) -> Result<()> { - unreachable!() + Err(ErrorCode::Internal( + "Calls to serialize_binary should be refactored to calls to serialize", + )) } fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { @@ -197,7 +200,9 @@ impl AggregateFunction for AggregateNullUnaryAdapto } fn merge_binary(&self, _: AggrState, _: &mut &[u8]) -> Result<()> { - unreachable!() + Err(ErrorCode::Internal( + "Calls to merge_binary should be refactored to calls to merge", + )) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { @@ -318,7 +323,9 @@ impl AggregateFunction } fn serialize_binary(&self, _: AggrState, _: &mut Vec) -> Result<()> { - unreachable!() + Err(ErrorCode::Internal( + "Calls to serialize_binary should be refactored to calls to serialize", + )) } fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { @@ -326,7 +333,9 @@ impl AggregateFunction } fn merge_binary(&self, _: AggrState, _: &mut &[u8]) -> Result<()> { - unreachable!() + Err(ErrorCode::Internal( + "Calls to merge_binary should be refactored to calls to merge", + )) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs index 513ffaec97f3e..3f90570d1c5a4 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs @@ -15,6 +15,7 @@ use std::fmt; use std::sync::Arc; +use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; @@ -198,7 +199,9 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { } fn serialize_binary(&self, _: AggrState, _: &mut Vec) -> Result<()> { - unreachable!() + Err(ErrorCode::Internal( + "Calls to serialize_binary should be refactored to calls to serialize", + )) } fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { @@ -210,7 +213,9 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { } fn merge_binary(&self, _: AggrState, _: &mut &[u8]) -> Result<()> { - unreachable!() + Err(ErrorCode::Internal( + "Calls to merge_binary should be refactored to calls to merge", + )) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_combinator_if.rs b/src/query/functions/src/aggregates/aggregate_combinator_if.rs index eb737aeb18948..622d057b92d48 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_if.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_if.rs @@ -164,7 +164,9 @@ impl AggregateFunction for AggregateIfCombinator { } fn serialize_binary(&self, _: AggrState, _: &mut Vec) -> Result<()> { - unreachable!() + Err(ErrorCode::Internal( + "Calls to serialize_binary should be refactored to calls to serialize", + )) } fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { @@ -172,7 +174,9 @@ impl AggregateFunction for AggregateIfCombinator { } fn merge_binary(&self, _: AggrState, _: &mut &[u8]) -> Result<()> { - unreachable!() + Err(ErrorCode::Internal( + "Calls to merge_binary should be refactored to calls to merge", + )) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_combinator_state.rs b/src/query/functions/src/aggregates/aggregate_combinator_state.rs index f0ebbb1d4f411..c756326e46d57 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_state.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_state.rs @@ -15,6 +15,7 @@ use std::fmt; use std::sync::Arc; +use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; @@ -117,7 +118,9 @@ impl AggregateFunction for AggregateStateCombinator { } fn serialize_binary(&self, _: AggrState, _: &mut Vec) -> Result<()> { - unreachable!() + Err(ErrorCode::Internal( + "Calls to serialize_binary should be refactored to calls to serialize", + )) } fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { @@ -125,7 +128,9 @@ impl AggregateFunction for AggregateStateCombinator { } fn merge_binary(&self, _: AggrState, _: &mut &[u8]) -> Result<()> { - unreachable!() + Err(ErrorCode::Internal( + "Calls to merge_binary should be refactored to calls to merge", + )) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { From 8cac94069b43140205a9810a66ef5b3ae81ea814 Mon Sep 17 00:00:00 2001 From: coldWater Date: Thu, 24 Jul 2025 13:14:07 +0800 Subject: [PATCH 12/22] refine --- .../src/aggregate/aggregate_function.rs | 16 ++---------- .../adaptors/aggregate_null_adaptor.rs | 25 ------------------ .../adaptors/aggregate_ornull_adaptor.rs | 13 ---------- .../adaptors/aggregate_sort_adaptor.rs | 16 +++++++++--- .../src/aggregates/aggregate_arg_min_max.rs | 13 +++++++--- .../src/aggregates/aggregate_array_agg.rs | 12 ++++++--- .../src/aggregates/aggregate_array_moving.rs | 24 +++++++++++------ .../src/aggregates/aggregate_bitmap.rs | 26 +++++++++++-------- .../aggregate_combinator_distinct.rs | 13 +++++++--- .../src/aggregates/aggregate_combinator_if.rs | 12 --------- .../aggregates/aggregate_combinator_state.rs | 13 ---------- .../src/aggregates/aggregate_count.rs | 13 +++++++--- .../src/aggregates/aggregate_covariance.rs | 13 +++++++--- .../aggregates/aggregate_json_array_agg.rs | 12 ++++++--- .../aggregates/aggregate_json_object_agg.rs | 12 ++++++--- .../src/aggregates/aggregate_markov_tarin.rs | 13 +++++++--- .../src/aggregates/aggregate_null_result.rs | 8 ++++-- .../aggregates/aggregate_quantile_tdigest.rs | 12 ++++++--- .../aggregate_quantile_tdigest_weighted.rs | 13 +++++++--- .../src/aggregates/aggregate_retention.rs | 13 +++++++--- .../src/aggregates/aggregate_st_collect.rs | 12 ++++++--- .../src/aggregates/aggregate_string_agg.rs | 13 +++++++--- .../src/aggregates/aggregate_unary.rs | 13 +++++++--- .../src/aggregates/aggregate_window_funnel.rs | 14 +++++++--- .../transforms/aggregator/udaf_script.rs | 17 +++++++----- 25 files changed, 193 insertions(+), 168 deletions(-) diff --git a/src/query/expression/src/aggregate/aggregate_function.rs b/src/query/expression/src/aggregate/aggregate_function.rs index 57b4095abd148..3e778832174ba 100755 --- a/src/query/expression/src/aggregate/aggregate_function.rs +++ b/src/query/expression/src/aggregate/aggregate_function.rs @@ -77,21 +77,9 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { serde_type.data_type() } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { - let binary_builder = builders[0].as_binary_mut().unwrap(); - self.serialize_binary(place, &mut binary_builder.data)?; - binary_builder.commit_row(); - Ok(()) - } - - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()>; - - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - self.merge_binary(place, &mut binary) - } + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()>; - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()>; + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()>; /// Batch merge and deserialize the state from binary array fn batch_merge( diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs index c1e8babd47e5b..ace936d1e6a02 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs @@ -15,7 +15,6 @@ use std::fmt; use std::sync::Arc; -use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; @@ -189,22 +188,10 @@ impl AggregateFunction for AggregateNullUnaryAdapto self.0.serialize(place, builders) } - fn serialize_binary(&self, _: AggrState, _: &mut Vec) -> Result<()> { - Err(ErrorCode::Internal( - "Calls to serialize_binary should be refactored to calls to serialize", - )) - } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { self.0.merge(place, data) } - fn merge_binary(&self, _: AggrState, _: &mut &[u8]) -> Result<()> { - Err(ErrorCode::Internal( - "Calls to merge_binary should be refactored to calls to merge", - )) - } - fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { self.0.merge_states(place, rhs) } @@ -322,22 +309,10 @@ impl AggregateFunction self.0.serialize(place, builders) } - fn serialize_binary(&self, _: AggrState, _: &mut Vec) -> Result<()> { - Err(ErrorCode::Internal( - "Calls to serialize_binary should be refactored to calls to serialize", - )) - } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { self.0.merge(place, data) } - fn merge_binary(&self, _: AggrState, _: &mut &[u8]) -> Result<()> { - Err(ErrorCode::Internal( - "Calls to merge_binary should be refactored to calls to merge", - )) - } - fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { self.0.merge_states(place, rhs) } diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs index 3f90570d1c5a4..0b393b8810af4 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs @@ -15,7 +15,6 @@ use std::fmt; use std::sync::Arc; -use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; @@ -198,12 +197,6 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { .serialize(place.remove_last_loc(), &mut builders[..n - 1]) } - fn serialize_binary(&self, _: AggrState, _: &mut Vec) -> Result<()> { - Err(ErrorCode::Internal( - "Calls to serialize_binary should be refactored to calls to serialize", - )) - } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { let flag = get_flag(place) || *data.last().and_then(ScalarRef::as_boolean).unwrap(); self.inner @@ -212,12 +205,6 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { Ok(()) } - fn merge_binary(&self, _: AggrState, _: &mut &[u8]) -> Result<()> { - Err(ErrorCode::Internal( - "Calls to merge_binary should be refactored to calls to merge", - )) - } - fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { self.inner .merge_states(place.remove_last_loc(), rhs.remove_last_loc())?; diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs index 5c0c37500a423..babf42ad2bce4 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs @@ -126,14 +126,22 @@ impl AggregateFunction for AggregateFunctionSortAdaptor { vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = Self::get_state(place); - Ok(state.serialize(writer)?) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge( + &self, + place: AggrState, + data: &[databend_common_expression::ScalarRef], + ) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = Self::get_state(place); - let rhs = SortAggState::deserialize(reader)?; + let rhs = SortAggState::deserialize(&mut binary)?; Self::merge_states_inner(state, &rhs); Ok(()) diff --git a/src/query/functions/src/aggregates/aggregate_arg_min_max.rs b/src/query/functions/src/aggregates/aggregate_arg_min_max.rs index a9043c479cda8..f66bd2ec729d4 100644 --- a/src/query/functions/src/aggregates/aggregate_arg_min_max.rs +++ b/src/query/functions/src/aggregates/aggregate_arg_min_max.rs @@ -31,6 +31,7 @@ use databend_common_expression::ColumnBuilder; use databend_common_expression::ColumnView; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::aggregate_function_factory::AggregateFunctionDescription; @@ -275,14 +276,18 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); - Ok(state.serialize(writer)?) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let rhs: State = borsh_partial_deserialize(reader)?; + let rhs: State = borsh_partial_deserialize(&mut binary)?; state.merge_from(rhs) } diff --git a/src/query/functions/src/aggregates/aggregate_array_agg.rs b/src/query/functions/src/aggregates/aggregate_array_agg.rs index ce44a9d3dc691..98935efdee83e 100644 --- a/src/query/functions/src/aggregates/aggregate_array_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_array_agg.rs @@ -553,14 +553,18 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); - Ok(state.serialize(writer)?) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let rhs = State::deserialize_reader(reader)?; + let rhs = State::deserialize_reader(&mut binary)?; state.merge(&rhs) } diff --git a/src/query/functions/src/aggregates/aggregate_array_moving.rs b/src/query/functions/src/aggregates/aggregate_array_moving.rs index 423af8229e265..c1742dd501842 100644 --- a/src/query/functions/src/aggregates/aggregate_array_moving.rs +++ b/src/query/functions/src/aggregates/aggregate_array_moving.rs @@ -452,14 +452,18 @@ where State: SumState vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); - Ok(state.serialize(writer)?) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let rhs = State::deserialize_reader(reader)?; + let rhs = State::deserialize_reader(&mut binary)?; state.merge(&rhs) } @@ -625,14 +629,18 @@ where State: SumState vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); - Ok(state.serialize(writer)?) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let rhs = State::deserialize_reader(reader)?; + let rhs = State::deserialize_reader(&mut binary)?; state.merge(&rhs) } diff --git a/src/query/functions/src/aggregates/aggregate_bitmap.rs b/src/query/functions/src/aggregates/aggregate_bitmap.rs index 348749834070a..d3c0f331be0aa 100644 --- a/src/query/functions/src/aggregates/aggregate_bitmap.rs +++ b/src/query/functions/src/aggregates/aggregate_bitmap.rs @@ -36,6 +36,7 @@ use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use databend_common_io::prelude::BinaryWrite; use roaring::RoaringTreemap; @@ -292,24 +293,27 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); // flag indicate where bitmap is none let flag: u8 = if state.rb.is_some() { 1 } else { 0 }; - writer.write_scalar(&flag)?; + binary_builder.data.write_scalar(&flag)?; if let Some(rb) = &state.rb { - rb.serialize_into(writer)?; + rb.serialize_into(&mut binary_builder.data)?; } + binary_builder.commit_row(); Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let flag = reader[0]; - reader.consume(1); + let flag = binary[0]; + binary.consume(1); if flag == 1 { - let rb = deserialize_bitmap(reader)?; + let rb = deserialize_bitmap(binary)?; state.add::(rb); } Ok(()) @@ -491,12 +495,12 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - self.inner.serialize_binary(place, writer) + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + self.inner.serialize(place, builders) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - self.inner.merge_binary(place, reader) + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + self.inner.merge(place, data) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs b/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs index 4f6ec3fafefd9..989964ae1babd 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs @@ -28,6 +28,7 @@ use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::aggregate_distinct_state::AggregateDistinctNumberState; @@ -111,14 +112,18 @@ where State: DistinctStateFunc vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = Self::get_state(place); - state.serialize(writer) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = Self::get_state(place); - let rhs = State::deserialize(reader)?; + let rhs = State::deserialize(&mut binary)?; state.merge(&rhs) } diff --git a/src/query/functions/src/aggregates/aggregate_combinator_if.rs b/src/query/functions/src/aggregates/aggregate_combinator_if.rs index 622d057b92d48..00c959a29a108 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_if.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_if.rs @@ -163,22 +163,10 @@ impl AggregateFunction for AggregateIfCombinator { self.nested.serialize(place, builders) } - fn serialize_binary(&self, _: AggrState, _: &mut Vec) -> Result<()> { - Err(ErrorCode::Internal( - "Calls to serialize_binary should be refactored to calls to serialize", - )) - } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { self.nested.merge(place, data) } - fn merge_binary(&self, _: AggrState, _: &mut &[u8]) -> Result<()> { - Err(ErrorCode::Internal( - "Calls to merge_binary should be refactored to calls to merge", - )) - } - fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { self.nested.merge_states(place, rhs) } diff --git a/src/query/functions/src/aggregates/aggregate_combinator_state.rs b/src/query/functions/src/aggregates/aggregate_combinator_state.rs index c756326e46d57..1183bc20e9dc9 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_state.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_state.rs @@ -15,7 +15,6 @@ use std::fmt; use std::sync::Arc; -use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; @@ -117,22 +116,10 @@ impl AggregateFunction for AggregateStateCombinator { self.nested.serialize(place, builders) } - fn serialize_binary(&self, _: AggrState, _: &mut Vec) -> Result<()> { - Err(ErrorCode::Internal( - "Calls to serialize_binary should be refactored to calls to serialize", - )) - } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { self.nested.merge(place, data) } - fn merge_binary(&self, _: AggrState, _: &mut &[u8]) -> Result<()> { - Err(ErrorCode::Internal( - "Calls to merge_binary should be refactored to calls to merge", - )) - } - fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { self.nested.merge_states(place, rhs) } diff --git a/src/query/functions/src/aggregates/aggregate_count.rs b/src/query/functions/src/aggregates/aggregate_count.rs index 64c1cb850f8ca..280932b89f2b7 100644 --- a/src/query/functions/src/aggregates/aggregate_count.rs +++ b/src/query/functions/src/aggregates/aggregate_count.rs @@ -29,6 +29,7 @@ use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::aggregate_function::AggregateFunction; @@ -165,14 +166,18 @@ impl AggregateFunction for AggregateCountFunction { vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); - Ok(state.count.serialize(writer)?) + state.count.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let other: u64 = borsh_partial_deserialize(reader)?; + let other: u64 = borsh_partial_deserialize(&mut binary)?; state.count += other; Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_covariance.rs b/src/query/functions/src/aggregates/aggregate_covariance.rs index 774c955679255..fd30fb84cb6c8 100644 --- a/src/query/functions/src/aggregates/aggregate_covariance.rs +++ b/src/query/functions/src/aggregates/aggregate_covariance.rs @@ -34,6 +34,7 @@ use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use num_traits::AsPrimitive; @@ -236,14 +237,18 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); - Ok(state.serialize(writer)?) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let rhs: AggregateCovarianceState = borsh_partial_deserialize(reader)?; + let rhs: AggregateCovarianceState = borsh_partial_deserialize(&mut binary)?; state.merge(&rhs); Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_json_array_agg.rs b/src/query/functions/src/aggregates/aggregate_json_array_agg.rs index 69cc6b44ca7c3..642c15a81fff6 100644 --- a/src/query/functions/src/aggregates/aggregate_json_array_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_json_array_agg.rs @@ -245,14 +245,18 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); - Ok(state.serialize(writer)?) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let rhs: State = borsh_partial_deserialize(reader)?; + let rhs: State = borsh_partial_deserialize(&mut binary)?; state.merge(&rhs) } diff --git a/src/query/functions/src/aggregates/aggregate_json_object_agg.rs b/src/query/functions/src/aggregates/aggregate_json_object_agg.rs index e9a0d72f9de98..2d5782c75b6a5 100644 --- a/src/query/functions/src/aggregates/aggregate_json_object_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_json_object_agg.rs @@ -298,14 +298,18 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); - Ok(state.serialize(writer)?) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let rhs: State = borsh_partial_deserialize(reader)?; + let rhs: State = borsh_partial_deserialize(&mut binary)?; state.merge(&rhs) } diff --git a/src/query/functions/src/aggregates/aggregate_markov_tarin.rs b/src/query/functions/src/aggregates/aggregate_markov_tarin.rs index e0b485983bf26..1293c512aab88 100644 --- a/src/query/functions/src/aggregates/aggregate_markov_tarin.rs +++ b/src/query/functions/src/aggregates/aggregate_markov_tarin.rs @@ -40,6 +40,7 @@ use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::aggregate_function_factory::AggregateFunctionDescription; @@ -118,14 +119,18 @@ impl AggregateFunction for MarkovTarin { vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); - Ok(state.serialize(writer)?) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let mut rhs = borsh_partial_deserialize::(reader)?; + let mut rhs = borsh_partial_deserialize::(&mut binary)?; state.merge(&mut rhs); Ok(()) diff --git a/src/query/functions/src/aggregates/aggregate_null_result.rs b/src/query/functions/src/aggregates/aggregate_null_result.rs index 5333ba8910445..66e5a385e4b41 100644 --- a/src/query/functions/src/aggregates/aggregate_null_result.rs +++ b/src/query/functions/src/aggregates/aggregate_null_result.rs @@ -25,6 +25,7 @@ use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; +use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::aggregate_function::AggregateFunction; @@ -91,11 +92,14 @@ impl AggregateFunction for AggregateNullResultFunction { vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, _place: AggrState, _writer: &mut Vec) -> Result<()> { + fn serialize(&self, _place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); + binary_builder.commit_row(); Ok(()) } - fn merge_binary(&self, _place: AggrState, _reader: &mut &[u8]) -> Result<()> { + fn merge(&self, _place: AggrState, data: &[ScalarRef]) -> Result<()> { + let _binary = *data[0].as_binary().unwrap(); Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs b/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs index 3e3dae0759936..e09146cfb2917 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs @@ -368,14 +368,18 @@ where for<'a> T: AccessType = F64> + Send + Sync vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); - Ok(state.serialize(writer)?) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let mut rhs: QuantileTDigestState = borsh_partial_deserialize(reader)?; + let mut rhs: QuantileTDigestState = borsh_partial_deserialize(&mut binary)?; state.merge(&mut rhs) } diff --git a/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs b/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs index dfc03b876c35a..46afac0dd7b54 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs @@ -31,6 +31,7 @@ use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use num_traits::AsPrimitive; @@ -149,14 +150,18 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); - Ok(state.serialize(writer)?) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let mut rhs: QuantileTDigestState = borsh_partial_deserialize(reader)?; + let mut rhs: QuantileTDigestState = borsh_partial_deserialize(&mut binary)?; state.merge(&mut rhs) } diff --git a/src/query/functions/src/aggregates/aggregate_retention.rs b/src/query/functions/src/aggregates/aggregate_retention.rs index a5d6780e73dc9..59d334f3bae8a 100644 --- a/src/query/functions/src/aggregates/aggregate_retention.rs +++ b/src/query/functions/src/aggregates/aggregate_retention.rs @@ -29,6 +29,7 @@ use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::aggregate_function::AggregateFunction; @@ -148,14 +149,18 @@ impl AggregateFunction for AggregateRetentionFunction { vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); - Ok(state.serialize(writer)?) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let rhs: AggregateRetentionState = borsh_partial_deserialize(reader)?; + let rhs: AggregateRetentionState = borsh_partial_deserialize(&mut binary)?; state.merge(&rhs); Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_st_collect.rs b/src/query/functions/src/aggregates/aggregate_st_collect.rs index 4657955f0dbb4..28ce7f30959e7 100644 --- a/src/query/functions/src/aggregates/aggregate_st_collect.rs +++ b/src/query/functions/src/aggregates/aggregate_st_collect.rs @@ -311,14 +311,18 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); - Ok(state.serialize(writer)?) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let rhs: State = borsh_partial_deserialize(reader)?; + let rhs: State = borsh_partial_deserialize(&mut binary)?; state.merge(&rhs) } diff --git a/src/query/functions/src/aggregates/aggregate_string_agg.rs b/src/query/functions/src/aggregates/aggregate_string_agg.rs index 4bf05651aeae7..10b4350cd32f3 100644 --- a/src/query/functions/src/aggregates/aggregate_string_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_string_agg.rs @@ -34,6 +34,7 @@ use databend_common_expression::Evaluator; use databend_common_expression::FunctionContext; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::aggregate_function_factory::AggregateFunctionDescription; @@ -152,14 +153,18 @@ impl AggregateFunction for AggregateStringAggFunction { vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); - Ok(state.serialize(writer)?) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let rhs: StringAggState = borsh_partial_deserialize(reader)?; + let rhs: StringAggState = borsh_partial_deserialize(&mut binary)?; state.values.push_str(&rhs.values); Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_unary.rs b/src/query/functions/src/aggregates/aggregate_unary.rs index a18ba2871ede3..ab08ea35eb08a 100644 --- a/src/query/functions/src/aggregates/aggregate_unary.rs +++ b/src/query/functions/src/aggregates/aggregate_unary.rs @@ -31,6 +31,7 @@ use databend_common_expression::AggregateFunctionRef; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::ScalarRef; use databend_common_expression::StateAddr; use databend_common_expression::StateSerdeItem; @@ -232,14 +233,18 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state: &mut S = place.get::(); - Ok(state.serialize(writer)?) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state: &mut S = place.get::(); - let rhs = S::deserialize_reader(reader)?; + let rhs = S::deserialize_reader(&mut binary)?; state.merge(&rhs) } diff --git a/src/query/functions/src/aggregates/aggregate_window_funnel.rs b/src/query/functions/src/aggregates/aggregate_window_funnel.rs index d87dc3a5c0683..616ec6e0e26fc 100644 --- a/src/query/functions/src/aggregates/aggregate_window_funnel.rs +++ b/src/query/functions/src/aggregates/aggregate_window_funnel.rs @@ -40,6 +40,7 @@ use databend_common_expression::AggrStateType; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use num_traits::AsPrimitive; @@ -280,14 +281,19 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::>(); - Ok(state.serialize(writer)?) + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::>(); - let mut rhs: AggregateWindowFunnelState = borsh_partial_deserialize(reader)?; + let mut rhs: AggregateWindowFunnelState = + borsh_partial_deserialize(&mut binary)?; state.merge(&mut rhs); Ok(()) } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs index 748e0d89b9252..eac893078c0d9 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs @@ -39,6 +39,7 @@ use databend_common_expression::DataBlock; use databend_common_expression::DataField; use databend_common_expression::DataSchema; use databend_common_expression::ProjectedBlock; +use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use databend_common_functions::aggregates::AggregateFunction; use databend_common_sql::plans::UDFLanguage; @@ -103,17 +104,21 @@ impl AggregateFunction for AggregateUdfScript { vec![StateSerdeItem::Binary(None)] } - fn serialize_binary(&self, place: AggrState, writer: &mut Vec) -> Result<()> { + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); state - .serialize(writer) - .map_err(|e| ErrorCode::Internal(format!("state failed to serialize: {e}"))) + .serialize(&mut binary_builder.data) + .map_err(|e| ErrorCode::Internal(format!("state failed to serialize: {e}")))?; + binary_builder.commit_row(); + Ok(()) } - fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let rhs = - UdfAggState::deserialize(reader).map_err(|e| ErrorCode::Internal(e.to_string()))?; + let rhs = UdfAggState::deserialize(&mut binary) + .map_err(|e| ErrorCode::Internal(e.to_string()))?; let states = arrow_select::concat::concat(&[&state.0, &rhs.0])?; let state = self .runtime From 34756db23a9f81d055c4cd3bff6bac52a52393f9 Mon Sep 17 00:00:00 2001 From: coldWater Date: Thu, 24 Jul 2025 15:42:52 +0800 Subject: [PATCH 13/22] tuple --- .../src/aggregate/aggregate_function.rs | 82 +++- src/query/expression/src/types.rs | 2 + src/query/expression/src/types/tuple.rs | 411 ++++++++++++++++++ 3 files changed, 485 insertions(+), 10 deletions(-) create mode 100644 src/query/expression/src/types/tuple.rs diff --git a/src/query/expression/src/aggregate/aggregate_function.rs b/src/query/expression/src/aggregate/aggregate_function.rs index 3e778832174ba..527790071f1be 100755 --- a/src/query/expression/src/aggregate/aggregate_function.rs +++ b/src/query/expression/src/aggregate/aggregate_function.rs @@ -22,6 +22,10 @@ use super::AggrState; use super::AggrStateLoc; use super::AggrStateRegistry; use super::StateAddr; +use crate::types::AnyPairType; +use crate::types::AnyQuaternaryType; +use crate::types::AnyTernaryType; +use crate::types::AnyUnaryType; use crate::types::DataType; use crate::BlockEntry; use crate::ColumnBuilder; @@ -81,28 +85,86 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()>; - /// Batch merge and deserialize the state from binary array + /// Batch deserialize the state and merge fn batch_merge( &self, places: &[StateAddr], loc: &[AggrStateLoc], state: &BlockEntry, ) -> Result<()> { - let column = state.to_column(); - for (place, data) in places.iter().zip(column.iter()) { - self.merge( - AggrState::new(*place, loc), - data.as_tuple().unwrap().as_slice(), - )?; + match state.data_type().as_tuple().unwrap().len() { + 1 => { + let view = state.downcast::().unwrap(); + for (place, data) in places.iter().zip(view.iter()) { + self.merge(AggrState::new(*place, loc), &[data])?; + } + } + 2 => { + let view = state.downcast::().unwrap(); + for (place, data) in places.iter().zip(view.iter()) { + self.merge(AggrState::new(*place, loc), &[data.0, data.1])?; + } + } + 3 => { + let view = state.downcast::().unwrap(); + for (place, data) in places.iter().zip(view.iter()) { + self.merge(AggrState::new(*place, loc), &[data.0, data.1, data.2])?; + } + } + 4 => { + let view = state.downcast::().unwrap(); + for (place, data) in places.iter().zip(view.iter()) { + self.merge(AggrState::new(*place, loc), &[ + data.0, data.1, data.2, data.3, + ])?; + } + } + _ => { + let state = state.to_column(); + for (place, data) in places.iter().zip(state.iter()) { + self.merge( + AggrState::new(*place, loc), + data.as_tuple().unwrap().as_slice(), + )?; + } + } } Ok(()) } fn batch_merge_single(&self, place: AggrState, state: &BlockEntry) -> Result<()> { - let column = state.to_column(); - for data in column.iter() { - self.merge(place, data.as_tuple().unwrap().as_slice())?; + match state.data_type().as_tuple().unwrap().len() { + 1 => { + let view = state.downcast::().unwrap(); + for data in view.iter() { + self.merge(place, &[data])?; + } + } + 2 => { + let view = state.downcast::().unwrap(); + for data in view.iter() { + self.merge(place, &[data.0, data.1])?; + } + } + 3 => { + let view = state.downcast::().unwrap(); + for data in view.iter() { + self.merge(place, &[data.0, data.1, data.2])?; + } + } + 4 => { + let view = state.downcast::().unwrap(); + for data in view.iter() { + self.merge(place, &[data.0, data.1, data.2, data.3])?; + } + } + _ => { + let state = state.to_column(); + for data in state.iter() { + self.merge(place, data.as_tuple().unwrap().as_slice())?; + } + } } Ok(()) } diff --git a/src/query/expression/src/types.rs b/src/query/expression/src/types.rs index a0d2cf87eef19..4c69ecd6aa5ba 100755 --- a/src/query/expression/src/types.rs +++ b/src/query/expression/src/types.rs @@ -34,6 +34,7 @@ pub mod number_class; pub mod simple_type; pub mod string; pub mod timestamp; +pub mod tuple; pub mod variant; pub mod vector; pub mod zero_size_type; @@ -81,6 +82,7 @@ use self::simple_type::*; pub use self::string::StringColumn; pub use self::string::StringType; pub use self::timestamp::TimestampType; +pub use self::tuple::*; pub use self::variant::VariantType; pub use self::vector::VectorColumn; pub use self::vector::VectorColumnBuilder; diff --git a/src/query/expression/src/types/tuple.rs b/src/query/expression/src/types/tuple.rs new file mode 100644 index 0000000000000..732f8436917d0 --- /dev/null +++ b/src/query/expression/src/types/tuple.rs @@ -0,0 +1,411 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::cmp::Ordering; +use std::iter::Map; +use std::iter::Zip; +use std::marker::PhantomData; + +use super::AccessType; +use super::AnyType; +use crate::Column; +use crate::Domain; +use crate::ScalarRef; + +type Iter<'a, T> = ::ColumnIterator<'a>; + +type SR<'a, T> = ::ScalarRef<'a>; + +pub type AnyUnaryType = UnaryType; +pub type AnyPairType = PairType; +pub type AnyTernaryType = TernaryType; +pub type AnyQuaternaryType = QuaternaryType; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct UnaryType(PhantomData); + +impl AccessType for UnaryType +where T: AccessType +{ + type Scalar = T::Scalar; + type ScalarRef<'a> = T::ScalarRef<'a>; + type Column = T::Column; + type Domain = T::Domain; + type ColumnIterator<'a> = T::ColumnIterator<'a>; + + fn to_owned_scalar(scalar: Self::ScalarRef<'_>) -> Self::Scalar { + T::to_owned_scalar(scalar) + } + + fn to_scalar_ref(scalar: &Self::Scalar) -> Self::ScalarRef<'_> { + T::to_scalar_ref(scalar) + } + + fn try_downcast_scalar<'a>(scalar: &ScalarRef<'a>) -> Option> { + T::try_downcast_scalar(scalar) + } + + fn try_downcast_domain(domain: &Domain) -> Option { + T::try_downcast_domain(domain) + } + + fn try_downcast_column(col: &Column) -> Option { + T::try_downcast_column(col) + } + + fn column_len(col: &Self::Column) -> usize { + T::column_len(col) + } + + fn index_column(col: &Self::Column, index: usize) -> Option> { + T::index_column(col, index) + } + + unsafe fn index_column_unchecked(col: &Self::Column, index: usize) -> Self::ScalarRef<'_> { + T::index_column_unchecked(col, index) + } + + fn slice_column(col: &Self::Column, range: std::ops::Range) -> Self::Column { + T::slice_column(col, range) + } + + fn iter_column(col: &Self::Column) -> Self::ColumnIterator<'_> { + T::iter_column(col) + } + + fn compare(lhs: Self::ScalarRef<'_>, rhs: Self::ScalarRef<'_>) -> Ordering { + T::compare(lhs, rhs) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PairType(PhantomData<(A, B)>); + +impl AccessType for PairType +where + A: AccessType, + B: AccessType, +{ + type Scalar = (A::Scalar, B::Scalar); + type ScalarRef<'a> = (A::ScalarRef<'a>, B::ScalarRef<'a>); + type Column = (A::Column, B::Column); + type Domain = (A::Domain, B::Domain); + type ColumnIterator<'a> = Zip, Iter<'a, B>>; + + fn to_owned_scalar((a, b): Self::ScalarRef<'_>) -> Self::Scalar { + (A::to_owned_scalar(a), B::to_owned_scalar(b)) + } + + fn to_scalar_ref((a, b): &Self::Scalar) -> Self::ScalarRef<'_> { + (A::to_scalar_ref(a), B::to_scalar_ref(b)) + } + + fn try_downcast_scalar<'a>(scalar: &ScalarRef<'a>) -> Option> { + let [a, b] = scalar.as_tuple()?.as_slice() else { + return None; + }; + Some((A::try_downcast_scalar(a)?, B::try_downcast_scalar(b)?)) + } + + fn try_downcast_domain(domain: &Domain) -> Option { + let [a, b] = domain.as_tuple()?.as_slice() else { + return None; + }; + Some((A::try_downcast_domain(a)?, B::try_downcast_domain(b)?)) + } + + fn try_downcast_column(col: &Column) -> Option { + let [a, b] = col.as_tuple()?.as_slice() else { + return None; + }; + Some((A::try_downcast_column(a)?, B::try_downcast_column(b)?)) + } + + fn column_len((a, _b): &Self::Column) -> usize { + debug_assert_eq!(A::column_len(a), B::column_len(_b)); + A::column_len(a) + } + + fn index_column((a, b): &Self::Column, index: usize) -> Option> { + Some((A::index_column(a, index)?, B::index_column(b, index)?)) + } + + unsafe fn index_column_unchecked((a, b): &Self::Column, index: usize) -> Self::ScalarRef<'_> { + ( + A::index_column_unchecked(a, index), + B::index_column_unchecked(b, index), + ) + } + + fn slice_column((a, b): &Self::Column, range: std::ops::Range) -> Self::Column { + (A::slice_column(a, range.clone()), B::slice_column(b, range)) + } + + fn iter_column((a, b): &Self::Column) -> Self::ColumnIterator<'_> { + A::iter_column(a).zip(B::iter_column(b)) + } + + fn compare( + (lhs_a, lhs_b): Self::ScalarRef<'_>, + (rhs_a, rhs_b): Self::ScalarRef<'_>, + ) -> Ordering { + A::compare(lhs_a, rhs_a).then_with(|| B::compare(lhs_b, rhs_b)) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TernaryType(PhantomData<(A, B, C)>); + +impl AccessType for TernaryType +where + A: AccessType, + B: AccessType, + C: AccessType, +{ + type Scalar = (A::Scalar, B::Scalar, C::Scalar); + type ScalarRef<'a> = (A::ScalarRef<'a>, B::ScalarRef<'a>, C::ScalarRef<'a>); + type Column = (A::Column, B::Column, C::Column); + type Domain = (A::Domain, B::Domain, C::Domain); + type ColumnIterator<'a> = Map< + Zip, Iter<'a, B>>, Iter<'a, C>>, + fn(((SR<'a, A>, SR<'a, B>), SR<'a, C>)) -> Self::ScalarRef<'a>, + >; + + fn to_owned_scalar((a, b, c): Self::ScalarRef<'_>) -> Self::Scalar { + ( + A::to_owned_scalar(a), + B::to_owned_scalar(b), + C::to_owned_scalar(c), + ) + } + + fn to_scalar_ref((a, b, c): &Self::Scalar) -> Self::ScalarRef<'_> { + ( + A::to_scalar_ref(a), + B::to_scalar_ref(b), + C::to_scalar_ref(c), + ) + } + + fn try_downcast_scalar<'a>(scalar: &ScalarRef<'a>) -> Option> { + let [a, b, c] = scalar.as_tuple()?.as_slice() else { + return None; + }; + Some(( + A::try_downcast_scalar(a)?, + B::try_downcast_scalar(b)?, + C::try_downcast_scalar(c)?, + )) + } + + fn try_downcast_domain(domain: &Domain) -> Option { + let [a, b, c] = domain.as_tuple()?.as_slice() else { + return None; + }; + Some(( + A::try_downcast_domain(a)?, + B::try_downcast_domain(b)?, + C::try_downcast_domain(c)?, + )) + } + + fn try_downcast_column(col: &Column) -> Option { + let [a, b, c] = col.as_tuple()?.as_slice() else { + return None; + }; + Some(( + A::try_downcast_column(a)?, + B::try_downcast_column(b)?, + C::try_downcast_column(c)?, + )) + } + + fn column_len((a, _b, _c): &Self::Column) -> usize { + debug_assert_eq!(A::column_len(a), B::column_len(_b)); + debug_assert_eq!(A::column_len(a), C::column_len(_c)); + A::column_len(a) + } + + fn index_column((a, b, c): &Self::Column, index: usize) -> Option> { + Some(( + A::index_column(a, index)?, + B::index_column(b, index)?, + C::index_column(c, index)?, + )) + } + + unsafe fn index_column_unchecked( + (a, b, c): &Self::Column, + index: usize, + ) -> Self::ScalarRef<'_> { + ( + A::index_column_unchecked(a, index), + B::index_column_unchecked(b, index), + C::index_column_unchecked(c, index), + ) + } + + fn slice_column((a, b, c): &Self::Column, range: std::ops::Range) -> Self::Column { + ( + A::slice_column(a, range.clone()), + B::slice_column(b, range.clone()), + C::slice_column(c, range), + ) + } + + fn iter_column((a, b, c): &Self::Column) -> Self::ColumnIterator<'_> { + A::iter_column(a) + .zip(B::iter_column(b)) + .zip(C::iter_column(c)) + .map(|((a, b), c)| (a, b, c)) + } + + fn compare( + (lhs_a, lhs_b, lhs_c): Self::ScalarRef<'_>, + (rhs_a, rhs_b, rhs_c): Self::ScalarRef<'_>, + ) -> Ordering { + A::compare(lhs_a, rhs_a) + .then_with(|| B::compare(lhs_b, rhs_b)) + .then_with(|| C::compare(lhs_c, rhs_c)) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QuaternaryType(PhantomData<(A, B, C, D)>); + +impl AccessType for QuaternaryType +where + A: AccessType, + B: AccessType, + C: AccessType, + D: AccessType, +{ + type Scalar = (A::Scalar, B::Scalar, C::Scalar, D::Scalar); + type ScalarRef<'a> = (SR<'a, A>, SR<'a, B>, SR<'a, C>, SR<'a, D>); + type Column = (A::Column, B::Column, C::Column, D::Column); + type Domain = (A::Domain, B::Domain, C::Domain, D::Domain); + type ColumnIterator<'a> = Map< + Zip, Iter<'a, B>>, Iter<'a, C>>, Iter<'a, D>>, + fn((((SR<'a, A>, SR<'a, B>), SR<'a, C>), SR<'a, D>)) -> Self::ScalarRef<'a>, + >; + + fn to_owned_scalar((a, b, c, d): Self::ScalarRef<'_>) -> Self::Scalar { + ( + A::to_owned_scalar(a), + B::to_owned_scalar(b), + C::to_owned_scalar(c), + D::to_owned_scalar(d), + ) + } + + fn to_scalar_ref((a, b, c, d): &Self::Scalar) -> Self::ScalarRef<'_> { + ( + A::to_scalar_ref(a), + B::to_scalar_ref(b), + C::to_scalar_ref(c), + D::to_scalar_ref(d), + ) + } + + fn try_downcast_scalar<'a>(scalar: &ScalarRef<'a>) -> Option> { + let [a, b, c, d] = scalar.as_tuple()?.as_slice() else { + return None; + }; + Some(( + A::try_downcast_scalar(a)?, + B::try_downcast_scalar(b)?, + C::try_downcast_scalar(c)?, + D::try_downcast_scalar(d)?, + )) + } + + fn try_downcast_domain(domain: &Domain) -> Option { + let [a, b, c, d] = domain.as_tuple()?.as_slice() else { + return None; + }; + Some(( + A::try_downcast_domain(a)?, + B::try_downcast_domain(b)?, + C::try_downcast_domain(c)?, + D::try_downcast_domain(d)?, + )) + } + + fn try_downcast_column(col: &Column) -> Option { + let [a, b, c, d] = col.as_tuple()?.as_slice() else { + return None; + }; + Some(( + A::try_downcast_column(a)?, + B::try_downcast_column(b)?, + C::try_downcast_column(c)?, + D::try_downcast_column(d)?, + )) + } + + fn column_len((a, _b, _c, _d): &Self::Column) -> usize { + debug_assert_eq!(A::column_len(a), B::column_len(_b)); + debug_assert_eq!(A::column_len(a), C::column_len(_c)); + debug_assert_eq!(A::column_len(a), D::column_len(_d)); + A::column_len(a) + } + + fn index_column((a, b, c, d): &Self::Column, index: usize) -> Option> { + Some(( + A::index_column(a, index)?, + B::index_column(b, index)?, + C::index_column(c, index)?, + D::index_column(d, index)?, + )) + } + + unsafe fn index_column_unchecked( + (a, b, c, d): &Self::Column, + index: usize, + ) -> Self::ScalarRef<'_> { + ( + A::index_column_unchecked(a, index), + B::index_column_unchecked(b, index), + C::index_column_unchecked(c, index), + D::index_column_unchecked(d, index), + ) + } + + fn slice_column((a, b, c, d): &Self::Column, range: std::ops::Range) -> Self::Column { + ( + A::slice_column(a, range.clone()), + B::slice_column(b, range.clone()), + C::slice_column(c, range.clone()), + D::slice_column(d, range), + ) + } + + fn iter_column((a, b, c, d): &Self::Column) -> Self::ColumnIterator<'_> { + A::iter_column(a) + .zip(B::iter_column(b)) + .zip(C::iter_column(c)) + .zip(D::iter_column(d)) + .map(|(((a, b), c), d)| (a, b, c, d)) + } + + fn compare( + (lhs_a, lhs_b, lhs_c, lhs_d): Self::ScalarRef<'_>, + (rhs_a, rhs_b, rhs_c, rhs_d): Self::ScalarRef<'_>, + ) -> Ordering { + A::compare(lhs_a, rhs_a) + .then_with(|| B::compare(lhs_b, rhs_b)) + .then_with(|| C::compare(lhs_c, rhs_c)) + .then_with(|| D::compare(lhs_d, rhs_d)) + } +} From 564a034e76d3fa5a37a25279fe8f6adb26275249 Mon Sep 17 00:00:00 2001 From: coldWater Date: Thu, 24 Jul 2025 19:54:03 +0800 Subject: [PATCH 14/22] batch_merge --- .../adaptors/aggregate_null_adaptor.rs | 63 ++++++++++++++++++- .../adaptors/aggregate_ornull_adaptor.rs | 40 +++++++++++- .../aggregates/aggregate_combinator_state.rs | 14 +++++ 3 files changed, 113 insertions(+), 4 deletions(-) diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs index ace936d1e6a02..9c42359dec6fc 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs @@ -20,6 +20,8 @@ use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; use databend_common_expression::utils::column_merge_validity; +use databend_common_expression::BlockEntry; +use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; @@ -313,6 +315,15 @@ impl AggregateFunction self.0.merge(place, data) } + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + ) -> Result<()> { + self.0.batch_merge(places, loc, state) + } + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { self.0.merge_states(place, rhs) } @@ -516,7 +527,6 @@ impl CommonNullAdaptor { return self.nested.merge(place, data); } - let n = data.len(); let flag = *data.last().and_then(ScalarRef::as_boolean).unwrap(); if !flag { return Ok(()); @@ -527,7 +537,56 @@ impl CommonNullAdaptor { self.init_state(place); } set_flag(place, true); - self.nested.merge(place.remove_last_loc(), &data[..n - 1]) + self.nested + .merge(place.remove_last_loc(), &data[..data.len() - 1]) + } + + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + ) -> Result<()> { + if !NULLABLE_RESULT { + return self.nested.batch_merge(places, loc, state); + } + + match state { + BlockEntry::Column(Column::Tuple(tuple)) => { + let nested_state = tuple[0..tuple.len() - 1].to_vec(); + let flag = tuple.last().unwrap().as_boolean().unwrap(); + + let places = places + .iter() + .zip(flag.iter()) + .filter_map(|(place, flag)| { + if flag { + let addr = AggrState::new(*place, loc); + if !get_flag(addr) { + // initial the state to remove the dirty stats + self.init_state(AggrState::new(*place, loc)); + } + set_flag(addr, true); + } + flag.then_some(*place) + }) + .collect::>(); + + let nested_state = Column::Tuple(nested_state).filter(flag).into(); + self.nested + .batch_merge(&places, &loc[..loc.len() - 1], &nested_state) + } + _ => { + let state = state.to_column(); + for (place, data) in places.iter().zip(state.iter()) { + self.merge( + AggrState::new(*place, loc), + data.as_tuple().unwrap().as_slice(), + )?; + } + Ok(()) + } + } } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs index 0b393b8810af4..8e07879b3a9d3 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs @@ -20,6 +20,8 @@ use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; +use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; @@ -70,6 +72,11 @@ pub fn get_flag(place: AggrState) -> bool { *c != 0 } +fn merge_flag(place: AggrState, other: bool) { + let flag = other || get_flag(place); + set_flag(place, flag); +} + fn flag_offset(place: AggrState) -> usize { *place.loc.last().unwrap().as_bool().unwrap().1 } @@ -198,10 +205,39 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { } fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let flag = get_flag(place) || *data.last().and_then(ScalarRef::as_boolean).unwrap(); + merge_flag(place, *data.last().and_then(ScalarRef::as_boolean).unwrap()); self.inner .merge(place.remove_last_loc(), &data[..data.len() - 1])?; - set_flag(place, flag); + Ok(()) + } + + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + ) -> Result<()> { + match state { + BlockEntry::Column(Column::Tuple(tuple)) => { + let flag = tuple.last().unwrap().as_boolean().unwrap(); + for (place, flag) in places.iter().zip(flag.iter()) { + merge_flag(AggrState::new(*place, loc), flag); + } + let inner_state = Column::Tuple(tuple[0..tuple.len() - 1].to_vec()).into(); + self.inner + .batch_merge(places, &loc[0..loc.len() - 1], &inner_state)?; + } + _ => { + let state = state.to_column(); + for (place, data) in places.iter().zip(state.iter()) { + self.merge( + AggrState::new(*place, loc), + data.as_tuple().unwrap().as_slice(), + )?; + } + } + } + Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_combinator_state.rs b/src/query/functions/src/aggregates/aggregate_combinator_state.rs index 1183bc20e9dc9..0a65760be88f1 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_state.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_state.rs @@ -19,6 +19,7 @@ use databend_common_exception::Result; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::AggrStateRegistry; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; @@ -120,6 +121,19 @@ impl AggregateFunction for AggregateStateCombinator { self.nested.merge(place, data) } + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + ) -> Result<()> { + self.nested.batch_merge(places, loc, state) + } + + fn batch_merge_single(&self, place: AggrState, state: &BlockEntry) -> Result<()> { + self.nested.batch_merge_single(place, state) + } + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { self.nested.merge_states(place, rhs) } From 5eb78bfda69b61a70cd4083058618e0d92c9ae38 Mon Sep 17 00:00:00 2001 From: coldWater Date: Thu, 24 Jul 2025 20:36:23 +0800 Subject: [PATCH 15/22] fix --- .../src/aggregate/aggregate_function.rs | 80 +++++++++++++++---- .../src/aggregate/aggregate_hashtable.rs | 2 +- src/query/expression/src/types/tuple.rs | 45 ++++++----- .../adaptors/aggregate_null_adaptor.rs | 48 ++++++----- .../adaptors/aggregate_ornull_adaptor.rs | 36 +++++++-- .../aggregates/aggregate_combinator_state.rs | 3 +- 6 files changed, 150 insertions(+), 64 deletions(-) diff --git a/src/query/expression/src/aggregate/aggregate_function.rs b/src/query/expression/src/aggregate/aggregate_function.rs index 527790071f1be..abcfee98379ea 100755 --- a/src/query/expression/src/aggregate/aggregate_function.rs +++ b/src/query/expression/src/aggregate/aggregate_function.rs @@ -25,6 +25,7 @@ use super::StateAddr; use crate::types::AnyPairType; use crate::types::AnyQuaternaryType; use crate::types::AnyTernaryType; +use crate::types::AnyType; use crate::types::AnyUnaryType; use crate::types::DataType; use crate::BlockEntry; @@ -91,45 +92,90 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { places: &[StateAddr], loc: &[AggrStateLoc], state: &BlockEntry, + filter: Option<&Bitmap>, ) -> Result<()> { match state.data_type().as_tuple().unwrap().len() { 1 => { let view = state.downcast::().unwrap(); - for (place, data) in places.iter().zip(view.iter()) { - self.merge(AggrState::new(*place, loc), &[data])?; + let iter = places.iter().zip(view.iter()); + if let Some(filter) = filter { + for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) + { + self.merge(AggrState::new(*place, loc), &[data])?; + } + } else { + for (place, data) in iter { + self.merge(AggrState::new(*place, loc), &[data])?; + } } } 2 => { let view = state.downcast::().unwrap(); - for (place, data) in places.iter().zip(view.iter()) { - self.merge(AggrState::new(*place, loc), &[data.0, data.1])?; + let iter = places.iter().zip(view.iter()); + if let Some(filter) = filter { + for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) + { + self.merge(AggrState::new(*place, loc), &[data.0, data.1])?; + } + } else { + for (place, data) in iter { + self.merge(AggrState::new(*place, loc), &[data.0, data.1])?; + } } } 3 => { let view = state.downcast::().unwrap(); - for (place, data) in places.iter().zip(view.iter()) { - self.merge(AggrState::new(*place, loc), &[data.0, data.1, data.2])?; + let iter = places.iter().zip(view.iter()); + if let Some(filter) = filter { + for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) + { + self.merge(AggrState::new(*place, loc), &[data.0, data.1, data.2])?; + } + } else { + for (place, data) in iter { + self.merge(AggrState::new(*place, loc), &[data.0, data.1, data.2])?; + } } } 4 => { let view = state.downcast::().unwrap(); - for (place, data) in places.iter().zip(view.iter()) { - self.merge(AggrState::new(*place, loc), &[ - data.0, data.1, data.2, data.3, - ])?; + let iter = places.iter().zip(view.iter()); + if let Some(filter) = filter { + for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) + { + self.merge(AggrState::new(*place, loc), &[ + data.0, data.1, data.2, data.3, + ])?; + } + } else { + for (place, data) in iter { + self.merge(AggrState::new(*place, loc), &[ + data.0, data.1, data.2, data.3, + ])?; + } } } _ => { - let state = state.to_column(); - for (place, data) in places.iter().zip(state.iter()) { - self.merge( - AggrState::new(*place, loc), - data.as_tuple().unwrap().as_slice(), - )?; + let view = state.downcast::().unwrap(); + let iter = places.iter().zip(view.iter()); + if let Some(filter) = filter { + for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) + { + self.merge( + AggrState::new(*place, loc), + data.as_tuple().unwrap().as_slice(), + )?; + } + } else { + for (place, data) in iter { + self.merge( + AggrState::new(*place, loc), + data.as_tuple().unwrap().as_slice(), + )?; + } } } } - Ok(()) } diff --git a/src/query/expression/src/aggregate/aggregate_hashtable.rs b/src/query/expression/src/aggregate/aggregate_hashtable.rs index a605227b7f5ce..2956ae9026d6d 100644 --- a/src/query/expression/src/aggregate/aggregate_hashtable.rs +++ b/src/query/expression/src/aggregate/aggregate_hashtable.rs @@ -218,7 +218,7 @@ impl AggregateHashTable { .zip(agg_states.iter()) .zip(states_layout.states_loc.iter()) { - func.batch_merge(state_places, loc, state)?; + func.batch_merge(state_places, loc, state, None)?; } } } diff --git a/src/query/expression/src/types/tuple.rs b/src/query/expression/src/types/tuple.rs index 732f8436917d0..8100da50504f6 100644 --- a/src/query/expression/src/types/tuple.rs +++ b/src/query/expression/src/types/tuple.rs @@ -35,57 +35,66 @@ pub type AnyQuaternaryType = QuaternaryType; #[derive(Debug, Clone, PartialEq, Eq)] pub struct UnaryType(PhantomData); -impl AccessType for UnaryType -where T: AccessType +impl AccessType for UnaryType +where A: AccessType { - type Scalar = T::Scalar; - type ScalarRef<'a> = T::ScalarRef<'a>; - type Column = T::Column; - type Domain = T::Domain; - type ColumnIterator<'a> = T::ColumnIterator<'a>; + type Scalar = A::Scalar; + type ScalarRef<'a> = A::ScalarRef<'a>; + type Column = A::Column; + type Domain = A::Domain; + type ColumnIterator<'a> = A::ColumnIterator<'a>; fn to_owned_scalar(scalar: Self::ScalarRef<'_>) -> Self::Scalar { - T::to_owned_scalar(scalar) + A::to_owned_scalar(scalar) } fn to_scalar_ref(scalar: &Self::Scalar) -> Self::ScalarRef<'_> { - T::to_scalar_ref(scalar) + A::to_scalar_ref(scalar) } fn try_downcast_scalar<'a>(scalar: &ScalarRef<'a>) -> Option> { - T::try_downcast_scalar(scalar) + let [a] = scalar.as_tuple()?.as_slice() else { + return None; + }; + A::try_downcast_scalar(a) } fn try_downcast_domain(domain: &Domain) -> Option { - T::try_downcast_domain(domain) + let [a] = domain.as_tuple()?.as_slice() else { + return None; + }; + A::try_downcast_domain(a) } fn try_downcast_column(col: &Column) -> Option { - T::try_downcast_column(col) + let [a] = col.as_tuple()?.as_slice() else { + return None; + }; + A::try_downcast_column(a) } fn column_len(col: &Self::Column) -> usize { - T::column_len(col) + A::column_len(col) } fn index_column(col: &Self::Column, index: usize) -> Option> { - T::index_column(col, index) + A::index_column(col, index) } unsafe fn index_column_unchecked(col: &Self::Column, index: usize) -> Self::ScalarRef<'_> { - T::index_column_unchecked(col, index) + A::index_column_unchecked(col, index) } fn slice_column(col: &Self::Column, range: std::ops::Range) -> Self::Column { - T::slice_column(col, range) + A::slice_column(col, range) } fn iter_column(col: &Self::Column) -> Self::ColumnIterator<'_> { - T::iter_column(col) + A::iter_column(col) } fn compare(lhs: Self::ScalarRef<'_>, rhs: Self::ScalarRef<'_>) -> Ordering { - T::compare(lhs, rhs) + A::compare(lhs, rhs) } } diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs index 9c42359dec6fc..0865d87cb3adc 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs @@ -16,6 +16,7 @@ use std::fmt; use std::sync::Arc; use databend_common_exception::Result; +use databend_common_expression::types::AnyType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; @@ -320,8 +321,9 @@ impl AggregateFunction places: &[StateAddr], loc: &[AggrStateLoc], state: &BlockEntry, + filter: Option<&Bitmap>, ) -> Result<()> { - self.0.batch_merge(places, loc, state) + self.0.batch_merge(places, loc, state, filter) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { @@ -546,38 +548,46 @@ impl CommonNullAdaptor { places: &[StateAddr], loc: &[AggrStateLoc], state: &BlockEntry, + filter: Option<&Bitmap>, ) -> Result<()> { if !NULLABLE_RESULT { - return self.nested.batch_merge(places, loc, state); + return self.nested.batch_merge(places, loc, state, filter); } match state { BlockEntry::Column(Column::Tuple(tuple)) => { let nested_state = tuple[0..tuple.len() - 1].to_vec(); let flag = tuple.last().unwrap().as_boolean().unwrap(); + let flag = match filter { + Some(filter) => filter & flag, + None => flag.clone(), + }; + if flag.null_count() == 0 { + return self.nested.batch_merge( + places, + loc, + &Column::Tuple(nested_state).into(), + filter, + ); + } - let places = places - .iter() - .zip(flag.iter()) - .filter_map(|(place, flag)| { - if flag { - let addr = AggrState::new(*place, loc); - if !get_flag(addr) { - // initial the state to remove the dirty stats - self.init_state(AggrState::new(*place, loc)); - } - set_flag(addr, true); + for (place, flag) in places.iter().zip(flag.iter()) { + if flag { + let addr = AggrState::new(*place, loc); + if !get_flag(addr) { + // initial the state to remove the dirty stats + self.init_state(AggrState::new(*place, loc)); } - flag.then_some(*place) - }) - .collect::>(); + set_flag(addr, true); + } + } - let nested_state = Column::Tuple(nested_state).filter(flag).into(); + let nested_state = Column::Tuple(nested_state).into(); self.nested - .batch_merge(&places, &loc[..loc.len() - 1], &nested_state) + .batch_merge(places, &loc[..loc.len() - 1], &nested_state, Some(&flag)) } _ => { - let state = state.to_column(); + let state = state.downcast::().unwrap(); for (place, data) in places.iter().zip(state.iter()) { self.merge( AggrState::new(*place, loc), diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs index 8e07879b3a9d3..f1f1692228a57 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs @@ -216,24 +216,44 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { places: &[StateAddr], loc: &[AggrStateLoc], state: &BlockEntry, + filter: Option<&Bitmap>, ) -> Result<()> { match state { BlockEntry::Column(Column::Tuple(tuple)) => { let flag = tuple.last().unwrap().as_boolean().unwrap(); - for (place, flag) in places.iter().zip(flag.iter()) { - merge_flag(AggrState::new(*place, loc), flag); + let iter = places.iter().zip(flag.iter()); + if let Some(filter) = filter { + for (place, flag) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) + { + merge_flag(AggrState::new(*place, loc), flag); + } + } else { + for (place, flag) in iter { + merge_flag(AggrState::new(*place, loc), flag); + } } let inner_state = Column::Tuple(tuple[0..tuple.len() - 1].to_vec()).into(); self.inner - .batch_merge(places, &loc[0..loc.len() - 1], &inner_state)?; + .batch_merge(places, &loc[0..loc.len() - 1], &inner_state, filter)?; } _ => { let state = state.to_column(); - for (place, data) in places.iter().zip(state.iter()) { - self.merge( - AggrState::new(*place, loc), - data.as_tuple().unwrap().as_slice(), - )?; + let iter = places.iter().zip(state.iter()); + if let Some(filter) = filter { + for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) + { + self.merge( + AggrState::new(*place, loc), + data.as_tuple().unwrap().as_slice(), + )?; + } + } else { + for (place, data) in iter { + self.merge( + AggrState::new(*place, loc), + data.as_tuple().unwrap().as_slice(), + )?; + } } } } diff --git a/src/query/functions/src/aggregates/aggregate_combinator_state.rs b/src/query/functions/src/aggregates/aggregate_combinator_state.rs index 0a65760be88f1..346ae3973186e 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_state.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_state.rs @@ -126,8 +126,9 @@ impl AggregateFunction for AggregateStateCombinator { places: &[StateAddr], loc: &[AggrStateLoc], state: &BlockEntry, + filter: Option<&Bitmap>, ) -> Result<()> { - self.nested.batch_merge(places, loc, state) + self.nested.batch_merge(places, loc, state, filter) } fn batch_merge_single(&self, place: AggrState, state: &BlockEntry) -> Result<()> { From 1edcb4318dcdb76aa9e47167c40a36b3f3e2238c Mon Sep 17 00:00:00 2001 From: coldWater Date: Fri, 25 Jul 2025 10:18:28 +0800 Subject: [PATCH 16/22] fix --- Cargo.lock | 1 + src/query/expression/Cargo.toml | 1 + .../src/aggregate/aggregate_hashtable.rs | 1 + .../adaptors/aggregate_null_adaptor.rs | 52 +++++++++---------- .../transforms/aggregator/aggregate_meta.rs | 1 + 5 files changed, 29 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2af0887a305e0..fee355a5110fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3350,6 +3350,7 @@ dependencies = [ "either", "enum-as-inner", "ethnum", + "fastrace", "futures", "geo", "geozero", diff --git a/src/query/expression/Cargo.toml b/src/query/expression/Cargo.toml index 6cb93c1420d0a..a0e1d699dfc52 100644 --- a/src/query/expression/Cargo.toml +++ b/src/query/expression/Cargo.toml @@ -33,6 +33,7 @@ educe = { workspace = true } either = { workspace = true } enum-as-inner = { workspace = true } ethnum = { workspace = true, features = ["serde", "macros"] } +fastrace = { workspace = true } futures = { workspace = true } geo = { workspace = true } geozero = { workspace = true } diff --git a/src/query/expression/src/aggregate/aggregate_hashtable.rs b/src/query/expression/src/aggregate/aggregate_hashtable.rs index 2956ae9026d6d..24646b092f81c 100644 --- a/src/query/expression/src/aggregate/aggregate_hashtable.rs +++ b/src/query/expression/src/aggregate/aggregate_hashtable.rs @@ -167,6 +167,7 @@ impl AggregateHashTable { } } + #[fastrace::trace(name = "AggregateHashTable::add_groups_inner")] // Add new groups and combine the states fn add_groups_inner( &mut self, diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs index 0865d87cb3adc..3bf1a6889bf98 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs @@ -534,11 +534,7 @@ impl CommonNullAdaptor { return Ok(()); } - if !get_flag(place) { - // initial the state to remove the dirty stats - self.init_state(place); - } - set_flag(place, true); + self.update_flag(place); self.nested .merge(place.remove_last_loc(), &data[..data.len() - 1]) } @@ -556,35 +552,29 @@ impl CommonNullAdaptor { match state { BlockEntry::Column(Column::Tuple(tuple)) => { - let nested_state = tuple[0..tuple.len() - 1].to_vec(); + let nested_state = Column::Tuple(tuple[0..tuple.len() - 1].to_vec()); let flag = tuple.last().unwrap().as_boolean().unwrap(); let flag = match filter { Some(filter) => filter & flag, None => flag.clone(), }; - if flag.null_count() == 0 { - return self.nested.batch_merge( - places, - loc, - &Column::Tuple(nested_state).into(), - filter, - ); - } - - for (place, flag) in places.iter().zip(flag.iter()) { - if flag { - let addr = AggrState::new(*place, loc); - if !get_flag(addr) { - // initial the state to remove the dirty stats - self.init_state(AggrState::new(*place, loc)); - } - set_flag(addr, true); + let filter = if flag.null_count() == 0 { + for place in places.iter() { + self.update_flag(AggrState::new(*place, loc)); } - } - - let nested_state = Column::Tuple(nested_state).into(); + None + } else { + for place in places + .iter() + .zip(flag.iter()) + .filter_map(|(place, flag)| flag.then_some(place)) + { + self.update_flag(AggrState::new(*place, loc)); + } + Some(&flag) + }; self.nested - .batch_merge(places, &loc[..loc.len() - 1], &nested_state, Some(&flag)) + .batch_merge(places, &loc[..loc.len() - 1], &nested_state.into(), filter) } _ => { let state = state.downcast::().unwrap(); @@ -643,6 +633,14 @@ impl CommonNullAdaptor { self.nested.drop_state(place.remove_last_loc()) } } + + fn update_flag(&self, place: AggrState) { + if !get_flag(place) { + // initial the state to remove the dirty stats + self.init_state(place); + } + set_flag(place, true); + } } fn set_flag(place: AggrState, flag: bool) { diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs index b95066fc57d68..905df88ccdcd5 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs @@ -45,6 +45,7 @@ impl SerializedPayload { entry.as_column().unwrap() } + #[fastrace::trace(name = "SerializedPayload::convert_to_aggregate_table")] pub fn convert_to_aggregate_table( &self, group_types: Vec, From 394286472140bbe79bfd2a8ffbb8ec9e390b84a4 Mon Sep 17 00:00:00 2001 From: coldWater Date: Fri, 25 Jul 2025 14:59:36 +0800 Subject: [PATCH 17/22] test --- .../src/aggregate/aggregate_hashtable.rs | 1 + .../adaptors/aggregate_null_adaptor.rs | 435 +++++++++++++++++- 2 files changed, 415 insertions(+), 21 deletions(-) diff --git a/src/query/expression/src/aggregate/aggregate_hashtable.rs b/src/query/expression/src/aggregate/aggregate_hashtable.rs index 24646b092f81c..98bb0cfe10476 100644 --- a/src/query/expression/src/aggregate/aggregate_hashtable.rs +++ b/src/query/expression/src/aggregate/aggregate_hashtable.rs @@ -382,6 +382,7 @@ impl AggregateHashTable { Ok(()) } + #[fastrace::trace(name = "AggregateHashTable::combine_payload")] pub fn combine_payload( &mut self, payload: &Payload, diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs index 3bf1a6889bf98..d64d551ac3528 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs @@ -195,6 +195,16 @@ impl AggregateFunction for AggregateNullUnaryAdapto self.0.merge(place, data) } + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + self.0.batch_merge(places, loc, state, filter) + } + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { self.0.merge_states(place, rhs) } @@ -233,6 +243,20 @@ impl AggregateNullVariadicAdaptor } } +impl AggregateNullVariadicAdaptor { + fn merge_validity( + columns: ProjectedBlock, + mut validity: Option, + ) -> (Vec, Option) { + let mut not_null_columns = Vec::with_capacity(columns.len()); + for entry in columns.iter() { + validity = column_merge_validity(&entry.clone(), validity); + not_null_columns.push(entry.clone().remove_nullable()); + } + (not_null_columns, validity) + } +} + impl AggregateFunction for AggregateNullVariadicAdaptor { @@ -260,14 +284,8 @@ impl AggregateFunction validity: Option<&Bitmap>, input_rows: usize, ) -> Result<()> { - let mut not_null_columns = Vec::with_capacity(columns.len()); - let mut validity = validity.cloned(); - for entry in columns.iter() { - validity = column_merge_validity(&entry.clone(), validity); - not_null_columns.push(entry.clone().remove_nullable()); - } + let (not_null_columns, validity) = Self::merge_validity(columns, validity.cloned()); let not_null_columns = (¬_null_columns).into(); - self.0 .accumulate(place, not_null_columns, validity, input_rows) } @@ -279,27 +297,15 @@ impl AggregateFunction columns: ProjectedBlock, input_rows: usize, ) -> Result<()> { - let mut not_null_columns = Vec::with_capacity(columns.len()); - let mut validity = None; - for entry in columns.iter() { - validity = column_merge_validity(&entry.clone(), validity); - not_null_columns.push(entry.clone().remove_nullable()); - } + let (not_null_columns, validity) = Self::merge_validity(columns, None); let not_null_columns = (¬_null_columns).into(); - self.0 .accumulate_keys(addrs, loc, not_null_columns, validity, input_rows) } fn accumulate_row(&self, place: AggrState, columns: ProjectedBlock, row: usize) -> Result<()> { - let mut not_null_columns = Vec::with_capacity(columns.len()); - let mut validity = None; - for entry in columns.iter() { - validity = column_merge_validity(&entry.clone(), validity); - not_null_columns.push(entry.clone().remove_nullable()); - } + let (not_null_columns, validity) = Self::merge_validity(columns, None); let not_null_columns = (¬_null_columns).into(); - self.0 .accumulate_row(place, not_null_columns, validity, row) } @@ -656,3 +662,390 @@ fn get_flag(place: AggrState) -> bool { fn flag_offset(place: AggrState) -> usize { *place.loc.last().unwrap().as_bool().unwrap().1 } + +#[cfg(test)] +mod tests { + + use std::sync::Arc; + + use databend_common_expression::types::*; + use databend_common_expression::*; + + use super::*; + + struct TestStateBuffer { + buffer: Vec, + } + + impl TestStateBuffer { + fn new(layout: &std::alloc::Layout) -> Self { + let buffer = vec![1u8; layout.size()]; + Self { buffer } + } + + fn addr(&self) -> StateAddr { + StateAddr::new(self.buffer.as_ptr() as usize) + } + } + + #[derive(Clone)] + struct MockAggregateFunction { + return_type: DataType, + serialize_items: Vec, + } + + impl fmt::Display for MockAggregateFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "MockAggregateFunction") + } + } + + impl AggregateFunction for MockAggregateFunction { + fn name(&self) -> &str { + "mock" + } + + fn return_type(&self) -> Result { + Ok(self.return_type.clone()) + } + + fn init_state(&self, place: AggrState) { + place.write(|| 3_u64); + } + + fn register_state(&self, registry: &mut AggrStateRegistry) { + registry.register(AggrStateType::Custom(std::alloc::Layout::new::())); + } + + fn accumulate( + &self, + _place: AggrState, + _columns: ProjectedBlock, + _validity: Option<&Bitmap>, + _input_rows: usize, + ) -> Result<()> { + Ok(()) + } + + fn accumulate_row( + &self, + _place: AggrState, + _columns: ProjectedBlock, + _row: usize, + ) -> Result<()> { + Ok(()) + } + + fn serialize_type(&self) -> Vec { + self.serialize_items.clone() + } + + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + let state = place.get::(); + if let ColumnBuilder::Number(NumberColumnBuilder::UInt64(builder)) = &mut builders[0] { + builder.push(*state); + } + Ok(()) + } + + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + let state = place.get::(); + if let Some(ScalarRef::Number(NumberScalar::UInt64(value))) = data.get(0) { + *state += value; + } + Ok(()) + } + + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { + let lhs_state = place.get::(); + let rhs_state = rhs.get::(); + *lhs_state += *rhs_state; + Ok(()) + } + + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { + let state = place.get::(); + if let ColumnBuilder::Number(NumberColumnBuilder::UInt64(builder)) = builder { + builder.push(*state); + } + Ok(()) + } + } + + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub enum BatchMergeStrategy { + Reference, + Optimized, + } + + pub struct TestNullFunction { + adaptor: CommonNullAdaptor, + strategy: BatchMergeStrategy, + } + + impl TestNullFunction { + pub fn new(nested: AggregateFunctionRef, strategy: BatchMergeStrategy) -> Self { + Self { + adaptor: CommonNullAdaptor { nested }, + strategy, + } + } + + fn reference_batch_merge_impl( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::().unwrap(); + let iter = places.iter().zip(view.iter()); + if let Some(filter) = filter { + for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + self.merge( + AggrState::new(*place, loc), + data.as_tuple().unwrap().as_slice(), + )?; + } + } else { + for (place, data) in iter { + self.merge( + AggrState::new(*place, loc), + data.as_tuple().unwrap().as_slice(), + )?; + } + } + + Ok(()) + } + } + + impl fmt::Display for TestNullFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "ConfigurableNull") + } + } + + impl AggregateFunction for TestNullFunction { + fn name(&self) -> &str { + "test_null" + } + + fn return_type(&self) -> Result { + self.adaptor.return_type() + } + + fn init_state(&self, place: AggrState) { + self.adaptor.init_state(place) + } + + fn register_state(&self, registry: &mut AggrStateRegistry) { + self.adaptor.register_state(registry) + } + + fn accumulate( + &self, + place: AggrState, + columns: ProjectedBlock, + validity: Option<&Bitmap>, + input_rows: usize, + ) -> Result<()> { + self.adaptor + .accumulate(place, columns, validity.cloned(), input_rows) + } + + fn accumulate_row( + &self, + place: AggrState, + columns: ProjectedBlock, + row: usize, + ) -> Result<()> { + self.adaptor.accumulate_row(place, columns, None, row) + } + + fn serialize_type(&self) -> Vec { + self.adaptor.serialize_type() + } + + fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + self.adaptor.serialize(place, builders) + } + + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { + self.adaptor.merge(place, data) + } + + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { + self.adaptor.merge_states(place, rhs) + } + + fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { + self.adaptor.merge_result(place, builder) + } + + /// Custom batch_merge implementation that switches between reference and optimized strategies + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + match self.strategy { + BatchMergeStrategy::Reference => { + self.reference_batch_merge_impl(places, loc, state, filter) + } + BatchMergeStrategy::Optimized => { + self.adaptor.batch_merge(places, loc, state, filter) + } + } + } + } + + fn create_configurable(strategy: BatchMergeStrategy) -> TestNullFunction { + let nested = Arc::new(MockAggregateFunction { + return_type: DataType::Number(NumberDataType::UInt64), + serialize_items: vec![StateSerdeItem::DataType(DataType::Number( + NumberDataType::UInt64, + ))], + }); + TestNullFunction::new(nested, strategy) + } + + fn create_test_state_entry(values: &[u64], flags: &[bool]) -> Result { + // Build the state column as a tuple + let mut value_builder = + ColumnBuilder::with_capacity(&DataType::Number(NumberDataType::UInt64), values.len()); + let mut flag_builder = ColumnBuilder::with_capacity(&DataType::Boolean, flags.len()); + + for &value in values { + if let ColumnBuilder::Number(NumberColumnBuilder::UInt64(builder)) = &mut value_builder + { + builder.push(value); + } + } + + for &flag in flags { + if let ColumnBuilder::Boolean(builder) = &mut flag_builder { + builder.push(flag); + } + } + + let value_column = value_builder.build(); + let flag_column = flag_builder.build(); + let tuple_column = Column::Tuple(vec![value_column, flag_column]); + + Ok(BlockEntry::new(Value::Column(tuple_column), || { + ( + DataType::Tuple(vec![ + DataType::Number(NumberDataType::UInt64), + DataType::Boolean, + ]), + values.len(), + ) + })) + } + + fn run_batch_merge_with_strategy( + strategy: BatchMergeStrategy, + buffers: &[TestStateBuffer], + loc: &[AggrStateLoc], + state_entry: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let func = create_configurable(strategy); + let places: Vec = buffers.iter().map(|b| b.addr()).collect(); + + // Initialize states + for buffer in buffers { + let place = AggrState::new(buffer.addr(), loc); + func.init_state(place); + } + + // Run batch_merge + func.batch_merge(&places, loc, state_entry, filter) + } + + fn test_both_strategies(values: &[u64], flags: &[bool], filter: Option<&Bitmap>) -> Result<()> { + let state_entry = create_test_state_entry(values, flags)?; + let (reference_buffers, optimized_buffers, loc) = { + let buffer_count = values.len(); + let states_layout = + get_states_layout(&[Arc::new(create_configurable(BatchMergeStrategy::Reference))])?; + let layout = states_layout.layout; + let loc = states_layout.states_loc[0].to_vec(); + + let reference_buffers: Vec = (0..buffer_count) + .map(|_| TestStateBuffer::new(&layout)) + .collect(); + let optimized_buffers: Vec = (0..buffer_count) + .map(|_| TestStateBuffer::new(&layout)) + .collect(); + + (reference_buffers, optimized_buffers, loc) + }; + + // Run both strategies + run_batch_merge_with_strategy( + BatchMergeStrategy::Reference, + &reference_buffers, + &loc, + &state_entry, + filter, + )?; + + run_batch_merge_with_strategy( + BatchMergeStrategy::Optimized, + &optimized_buffers, + &loc, + &state_entry, + filter, + )?; + + for i in 0..reference_buffers.len() { + assert_eq!( + reference_buffers[i].buffer, optimized_buffers[i].buffer, + "Buffer contents should match at index {}", + i + ); + } + Ok(()) + } + + #[test] + fn test_batch_merge_equivalence() -> Result<()> { + { + let values = vec![10u64, 20u64, 30u64]; + let flags = vec![true, false, true]; + + test_both_strategies(&values, &flags, None)?; + } + + { + let values = vec![100u64, 200u64, 300u64, 400u64]; + let flags = vec![true, false, true, true]; // Skip index 1 + + test_both_strategies(&values, &flags, None)?; + } + + { + let values = vec![10u64, 20u64, 30u64, 40u64]; + let flags = vec![true, false, true, true]; + let filter_bits = vec![true, false, true, false]; // Only process indices 0 and 2 + let filter = Bitmap::from_iter(filter_bits.iter().copied()); + + test_both_strategies(&values, &flags, Some(&filter))?; + } + + { + let values = vec![100u64, 200u64, 300u64, 400u64]; + let flags = vec![true, true, false, true]; // Skip index 2 + let filter_bits = vec![true, false, true, true]; // Skip index 1 + let filter = Bitmap::from_iter(filter_bits.iter().copied()); + + test_both_strategies(&values, &flags, Some(&filter))?; + } + + Ok(()) + } +} From 36fcb59f777a5625ad7eea843c302b67843cd6b2 Mon Sep 17 00:00:00 2001 From: coldWater Date: Fri, 25 Jul 2025 16:25:27 +0800 Subject: [PATCH 18/22] sum --- .../adaptors/aggregate_null_adaptor.rs | 6 ++--- .../functions/src/aggregates/aggregate_sum.rs | 22 +++++++++++++--- .../src/aggregates/aggregate_unary.rs | 26 ++++++++++++++----- 3 files changed, 41 insertions(+), 13 deletions(-) diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs index d64d551ac3528..31f22f50583d4 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs @@ -750,7 +750,7 @@ mod tests { fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { let state = place.get::(); - if let Some(ScalarRef::Number(NumberScalar::UInt64(value))) = data.get(0) { + if let Some(ScalarRef::Number(NumberScalar::UInt64(value))) = data.first() { *state += value; } Ok(()) @@ -1031,7 +1031,7 @@ mod tests { { let values = vec![10u64, 20u64, 30u64, 40u64]; let flags = vec![true, false, true, true]; - let filter_bits = vec![true, false, true, false]; // Only process indices 0 and 2 + let filter_bits = [true, false, true, false]; // Only process indices 0 and 2 let filter = Bitmap::from_iter(filter_bits.iter().copied()); test_both_strategies(&values, &flags, Some(&filter))?; @@ -1040,7 +1040,7 @@ mod tests { { let values = vec![100u64, 200u64, 300u64, 400u64]; let flags = vec![true, true, false, true]; // Skip index 2 - let filter_bits = vec![true, false, true, true]; // Skip index 1 + let filter_bits = [true, false, true, true]; // Skip index 1 let filter = Bitmap::from_iter(filter_bits.iter().copied()); test_both_strategies(&values, &flags, Some(&filter))?; diff --git a/src/query/functions/src/aggregates/aggregate_sum.rs b/src/query/functions/src/aggregates/aggregate_sum.rs index 1aeff4f5101b0..94596dc3c61b7 100644 --- a/src/query/functions/src/aggregates/aggregate_sum.rs +++ b/src/query/functions/src/aggregates/aggregate_sum.rs @@ -32,7 +32,9 @@ use databend_common_expression::AggregateFunctionRef; use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::Scalar; +use databend_common_expression::ScalarRef; use databend_common_expression::StateAddr; +use databend_common_expression::StateSerdeItem; use databend_common_expression::SELECTIVITY_THRESHOLD; use num_traits::AsPrimitive; @@ -76,14 +78,14 @@ pub trait SumState: BorshSerialize + BorshDeserialize + Send + Sync + Default + #[derive(BorshSerialize, BorshDeserialize)] pub struct NumberSumState -where N: ValueType +where N: ArgType { pub value: N::Scalar, } impl Default for NumberSumState where - N: ValueType, + N: ArgType, N::Scalar: Number + AsPrimitive + BorshSerialize + BorshDeserialize + std::ops::AddAssign, { fn default() -> Self { @@ -130,7 +132,7 @@ where impl UnaryState for NumberSumState where T: ArgType + Sync + Send, - N: ValueType, + N: ArgType, T::Scalar: Number + AsPrimitive, N::Scalar: Number + AsPrimitive + BorshSerialize + BorshDeserialize + std::ops::AddAssign, for<'a> T::ScalarRef<'a>: Number + AsPrimitive, @@ -169,6 +171,20 @@ where builder.push_item(N::to_scalar_ref(&self.value)); Ok(()) } + + fn serialize_type() -> Vec { + std::vec![StateSerdeItem::DataType(T::data_type())] + } + + fn serialize_state(&self, builders: &mut [ColumnBuilder]) -> Result<()> { + N::downcast_builder(&mut builders[0]).push_item(N::to_scalar_ref(&self.value)); + Ok(()) + } + + fn deserialize_state(data: &[ScalarRef]) -> Result { + let value = N::to_owned_scalar(N::try_downcast_scalar(&data[0]).unwrap()); + Ok(Self { value }) + } } #[derive(BorshDeserialize, BorshSerialize)] diff --git a/src/query/functions/src/aggregates/aggregate_unary.rs b/src/query/functions/src/aggregates/aggregate_unary.rs index ab08ea35eb08a..629156d90f14b 100644 --- a/src/query/functions/src/aggregates/aggregate_unary.rs +++ b/src/query/functions/src/aggregates/aggregate_unary.rs @@ -80,6 +80,22 @@ where builder: R::ColumnBuilderMut<'_>, function_data: Option<&dyn FunctionData>, ) -> Result<()>; + + fn serialize_type() -> Vec { + vec![StateSerdeItem::Binary(None)] + } + + fn serialize_state(&self, builders: &mut [ColumnBuilder]) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); + self.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + Ok(()) + } + + fn deserialize_state(data: &[ScalarRef]) -> Result { + let mut binary = *data[0].as_binary().unwrap(); + Ok(Self::deserialize_reader(&mut binary)?) + } } pub trait FunctionData: Send + Sync { @@ -230,21 +246,17 @@ where } fn serialize_type(&self) -> Vec { - vec![StateSerdeItem::Binary(None)] + S::serialize_type() } fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { - let binary_builder = builders[0].as_binary_mut().unwrap(); let state: &mut S = place.get::(); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); - Ok(()) + state.serialize_state(builders) } fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); let state: &mut S = place.get::(); - let rhs = S::deserialize_reader(&mut binary)?; + let rhs = S::deserialize_state(data)?; state.merge(&rhs) } From 45a3f7bb90afd4c3e3a62269c70023f893dcf33e Mon Sep 17 00:00:00 2001 From: coldWater Date: Fri, 25 Jul 2025 18:28:10 +0800 Subject: [PATCH 19/22] batch_serialize --- Cargo.lock | 1 - src/query/expression/Cargo.toml | 1 - .../src/aggregate/aggregate_function.rs | 12 ++++ .../src/aggregate/aggregate_hashtable.rs | 2 - .../expression/src/aggregate/payload_flush.rs | 21 +++--- .../adaptors/aggregate_null_adaptor.rs | 41 +++++++++++ .../adaptors/aggregate_ornull_adaptor.rs | 72 ++++++++++++------- .../src/aggregates/aggregate_combinator_if.rs | 19 +++++ .../aggregates/aggregate_combinator_state.rs | 9 +++ .../functions/src/aggregates/aggregate_sum.rs | 2 +- .../02_0000_function_aggregate_state.test | 4 +- 11 files changed, 138 insertions(+), 46 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fee355a5110fe..2af0887a305e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3350,7 +3350,6 @@ dependencies = [ "either", "enum-as-inner", "ethnum", - "fastrace", "futures", "geo", "geozero", diff --git a/src/query/expression/Cargo.toml b/src/query/expression/Cargo.toml index a0e1d699dfc52..6cb93c1420d0a 100644 --- a/src/query/expression/Cargo.toml +++ b/src/query/expression/Cargo.toml @@ -33,7 +33,6 @@ educe = { workspace = true } either = { workspace = true } enum-as-inner = { workspace = true } ethnum = { workspace = true, features = ["serde", "macros"] } -fastrace = { workspace = true } futures = { workspace = true } geo = { workspace = true } geozero = { workspace = true } diff --git a/src/query/expression/src/aggregate/aggregate_function.rs b/src/query/expression/src/aggregate/aggregate_function.rs index abcfee98379ea..700d9a2c77f11 100755 --- a/src/query/expression/src/aggregate/aggregate_function.rs +++ b/src/query/expression/src/aggregate/aggregate_function.rs @@ -84,6 +84,18 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()>; + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + for place in places { + self.serialize(AggrState::new(*place, loc), builders)?; + } + Ok(()) + } + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()>; /// Batch deserialize the state and merge diff --git a/src/query/expression/src/aggregate/aggregate_hashtable.rs b/src/query/expression/src/aggregate/aggregate_hashtable.rs index 98bb0cfe10476..2956ae9026d6d 100644 --- a/src/query/expression/src/aggregate/aggregate_hashtable.rs +++ b/src/query/expression/src/aggregate/aggregate_hashtable.rs @@ -167,7 +167,6 @@ impl AggregateHashTable { } } - #[fastrace::trace(name = "AggregateHashTable::add_groups_inner")] // Add new groups and combine the states fn add_groups_inner( &mut self, @@ -382,7 +381,6 @@ impl AggregateHashTable { Ok(()) } - #[fastrace::trace(name = "AggregateHashTable::combine_payload")] pub fn combine_payload( &mut self, payload: &Payload, diff --git a/src/query/expression/src/aggregate/payload_flush.rs b/src/query/expression/src/aggregate/payload_flush.rs index f6aa0d65d5772..317737fca0960 100644 --- a/src/query/expression/src/aggregate/payload_flush.rs +++ b/src/query/expression/src/aggregate/payload_flush.rs @@ -18,7 +18,6 @@ use databend_common_io::prelude::bincode_deserialize_from_slice; use super::partitioned_payload::PartitionedPayload; use super::payload::Payload; use super::probe_state::ProbeState; -use super::AggrState; use crate::read; use crate::types::binary::BinaryColumn; use crate::types::binary::BinaryColumnBuilder; @@ -141,18 +140,14 @@ impl Payload { if let Some(state_layout) = self.states_layout.as_ref() { let mut builders = state_layout.serialize_builders(row_count); - for place in state.state_places.as_slice()[0..row_count].iter() { - for (idx, (loc, func)) in state_layout - .states_loc - .iter() - .zip(self.aggrs.iter()) - .enumerate() - { - { - let builders = builders[idx].as_tuple_mut().unwrap().as_mut_slice(); - func.serialize(AggrState::new(*place, loc), builders)?; - } - } + for ((loc, func), builder) in state_layout + .states_loc + .iter() + .zip(self.aggrs.iter()) + .zip(builders.iter_mut()) + { + let builders = builder.as_tuple_mut().unwrap().as_mut_slice(); + func.batch_serialize(&state.state_places.as_slice()[0..row_count], loc, builders)?; } entries.extend(builders.into_iter().map(|builder| builder.build().into())); diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs index 31f22f50583d4..b438fb0edc455 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs @@ -191,6 +191,15 @@ impl AggregateFunction for AggregateNullUnaryAdapto self.0.serialize(place, builders) } + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + self.0.batch_serialize(places, loc, builders) + } + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { self.0.merge(place, data) } @@ -318,6 +327,15 @@ impl AggregateFunction self.0.serialize(place, builders) } + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + self.0.batch_serialize(places, loc, builders) + } + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { self.0.merge(place, data) } @@ -530,6 +548,29 @@ impl CommonNullAdaptor { .serialize(place.remove_last_loc(), &mut builders[..(n - 1)]) } + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + if !NULLABLE_RESULT { + return self.nested.batch_serialize(places, loc, builders); + } + let n = builders.len(); + debug_assert_eq!(self.nested.serialize_type().len() + 1, n); + let flag_builder = builders + .last_mut() + .and_then(ColumnBuilder::as_boolean_mut) + .unwrap(); + for place in places { + let place = AggrState::new(*place, loc); + flag_builder.push(get_flag(place)); + } + self.nested + .batch_serialize(places, &loc[..loc.len() - 1], &mut builders[..(n - 1)]) + } + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { if !NULLABLE_RESULT { return self.nested.merge(place, data); diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs index f1f1692228a57..be39946e94aba 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs @@ -40,23 +40,23 @@ use super::StateAddr; /// Use a single additional byte of data after the nested function data: /// 0 means there was no input, 1 means there was some. pub struct AggregateFunctionOrNullAdaptor { - inner: AggregateFunctionRef, + nested: AggregateFunctionRef, inner_nullable: bool, } impl AggregateFunctionOrNullAdaptor { pub fn create( - inner: AggregateFunctionRef, + nested: AggregateFunctionRef, features: AggregateFunctionFeatures, ) -> Result { // count/count distinct should not be nullable for empty set, just return zero - let inner_return_type = inner.return_type()?; + let inner_return_type = nested.return_type()?; if features.returns_default_when_only_null || inner_return_type == DataType::Null { - return Ok(inner); + return Ok(nested); } Ok(Arc::new(AggregateFunctionOrNullAdaptor { - inner, + nested, inner_nullable: inner_return_type.is_nullable(), })) } @@ -83,22 +83,22 @@ fn flag_offset(place: AggrState) -> usize { impl AggregateFunction for AggregateFunctionOrNullAdaptor { fn name(&self) -> &str { - self.inner.name() + self.nested.name() } fn return_type(&self) -> Result { - Ok(self.inner.return_type()?.wrap_nullable()) + Ok(self.nested.return_type()?.wrap_nullable()) } #[inline] fn init_state(&self, place: AggrState) { let c = place.addr.next(flag_offset(place)).get::(); *c = 0; - self.inner.init_state(place.remove_last_loc()) + self.nested.init_state(place.remove_last_loc()) } fn register_state(&self, registry: &mut AggrStateRegistry) { - self.inner.register_state(registry); + self.nested.register_state(registry); registry.register(AggrStateType::Bool); } @@ -114,7 +114,7 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { return Ok(()); } - let if_cond = self.inner.get_if_condition(columns); + let if_cond = self.nested.get_if_condition(columns); let validity = match (if_cond, validity) { (None, None) => None, @@ -129,7 +129,7 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { .unwrap_or(true) { set_flag(place, true); - self.inner.accumulate( + self.nested.accumulate( place.remove_last_loc(), columns, validity.as_ref(), @@ -146,9 +146,9 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { columns: ProjectedBlock, input_rows: usize, ) -> Result<()> { - self.inner + self.nested .accumulate_keys(places, &loc[..loc.len() - 1], columns, input_rows)?; - let if_cond = self.inner.get_if_condition(columns); + let if_cond = self.nested.get_if_condition(columns); match if_cond { Some(v) if v.null_count() > 0 => { @@ -175,14 +175,14 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { #[inline] fn accumulate_row(&self, place: AggrState, columns: ProjectedBlock, row: usize) -> Result<()> { - self.inner + self.nested .accumulate_row(place.remove_last_loc(), columns, row)?; set_flag(place, true); Ok(()) } fn serialize_type(&self) -> Vec { - self.inner + self.nested .serialize_type() .into_iter() .chain(Some(StateSerdeItem::DataType(DataType::Boolean))) @@ -191,7 +191,7 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { let n = builders.len(); - debug_assert_eq!(self.inner.serialize_type().len() + 1, n); + debug_assert_eq!(self.nested.serialize_type().len() + 1, n); let flag = get_flag(place); builders @@ -200,13 +200,33 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { .unwrap() .push(flag); - self.inner + self.nested .serialize(place.remove_last_loc(), &mut builders[..n - 1]) } + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + let n = builders.len(); + debug_assert_eq!(self.nested.serialize_type().len() + 1, n); + let flag_builder = builders + .last_mut() + .and_then(ColumnBuilder::as_boolean_mut) + .unwrap(); + for place in places { + let place = AggrState::new(*place, loc); + flag_builder.push(get_flag(place)); + } + self.nested + .batch_serialize(places, &loc[..loc.len() - 1], &mut builders[..(n - 1)]) + } + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { merge_flag(place, *data.last().and_then(ScalarRef::as_boolean).unwrap()); - self.inner + self.nested .merge(place.remove_last_loc(), &data[..data.len() - 1])?; Ok(()) } @@ -233,7 +253,7 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { } } let inner_state = Column::Tuple(tuple[0..tuple.len() - 1].to_vec()).into(); - self.inner + self.nested .batch_merge(places, &loc[0..loc.len() - 1], &inner_state, filter)?; } _ => { @@ -262,7 +282,7 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { - self.inner + self.nested .merge_states(place.remove_last_loc(), rhs.remove_last_loc())?; let flag = get_flag(place) || get_flag(rhs); set_flag(place, flag); @@ -275,9 +295,9 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { if !get_flag(place) { inner_mut.push_null(); } else if self.inner_nullable { - self.inner.merge_result(place.remove_last_loc(), builder)?; + self.nested.merge_result(place.remove_last_loc(), builder)?; } else { - self.inner + self.nested .merge_result(place.remove_last_loc(), &mut inner_mut.builder)?; inner_mut.validity.push(true); } @@ -293,21 +313,21 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { params: Vec, arguments: Vec, ) -> Result> { - self.inner + self.nested .get_own_null_adaptor(nested_function, params, arguments) } fn need_manual_drop_state(&self) -> bool { - self.inner.need_manual_drop_state() + self.nested.need_manual_drop_state() } unsafe fn drop_state(&self, place: AggrState) { - self.inner.drop_state(place.remove_last_loc()) + self.nested.drop_state(place.remove_last_loc()) } } impl fmt::Display for AggregateFunctionOrNullAdaptor { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.inner) + write!(f, "{}", self.nested) } } diff --git a/src/query/functions/src/aggregates/aggregate_combinator_if.rs b/src/query/functions/src/aggregates/aggregate_combinator_if.rs index 00c959a29a108..47b9ffee31f73 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_if.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_if.rs @@ -163,10 +163,29 @@ impl AggregateFunction for AggregateIfCombinator { self.nested.serialize(place, builders) } + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + self.nested.batch_serialize(places, loc, builders) + } + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { self.nested.merge(place, data) } + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &databend_common_expression::BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + self.nested.batch_merge(places, loc, state, filter) + } + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { self.nested.merge_states(place, rhs) } diff --git a/src/query/functions/src/aggregates/aggregate_combinator_state.rs b/src/query/functions/src/aggregates/aggregate_combinator_state.rs index 346ae3973186e..4f720331b7347 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_state.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_state.rs @@ -117,6 +117,15 @@ impl AggregateFunction for AggregateStateCombinator { self.nested.serialize(place, builders) } + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + self.nested.batch_serialize(places, loc, builders) + } + fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { self.nested.merge(place, data) } diff --git a/src/query/functions/src/aggregates/aggregate_sum.rs b/src/query/functions/src/aggregates/aggregate_sum.rs index 94596dc3c61b7..2b67315f66684 100644 --- a/src/query/functions/src/aggregates/aggregate_sum.rs +++ b/src/query/functions/src/aggregates/aggregate_sum.rs @@ -173,7 +173,7 @@ where } fn serialize_type() -> Vec { - std::vec![StateSerdeItem::DataType(T::data_type())] + std::vec![StateSerdeItem::DataType(N::data_type())] } fn serialize_state(&self, builders: &mut [ColumnBuilder]) -> Result<()> { diff --git a/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_state.test b/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_state.test index e425b986c7262..823f71a00e9ff 100644 --- a/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_state.test +++ b/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_state.test @@ -4,6 +4,6 @@ select length(max_state(number).1), typeof(max_state(number)) from numbers(100); 9 TUPLE(BINARY, BOOLEAN) query IT -select length(sum_state(number).1), typeof(max_state(number)) from numbers(10000); +select sum_state(number).1, typeof(sum_state(number)) from numbers(10000); ---- -8 TUPLE(BINARY, BOOLEAN) +49995000 TUPLE(UINT64, BOOLEAN) From a8743eeea0f02a8de8f90ca933c6584e5634cfdd Mon Sep 17 00:00:00 2001 From: coldWater Date: Sat, 26 Jul 2025 08:13:39 +0800 Subject: [PATCH 20/22] count --- .../src/aggregates/aggregate_count.rs | 54 ++++++++++++++++--- 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/src/query/functions/src/aggregates/aggregate_count.rs b/src/query/functions/src/aggregates/aggregate_count.rs index 280932b89f2b7..2a7fa0ad45226 100644 --- a/src/query/functions/src/aggregates/aggregate_count.rs +++ b/src/query/functions/src/aggregates/aggregate_count.rs @@ -16,15 +16,20 @@ use std::alloc::Layout; use std::fmt; use std::sync::Arc; -use borsh::BorshSerialize; use databend_common_exception::Result; use databend_common_expression::types::number::NumberColumnBuilder; +use databend_common_expression::types::AccessType; +use databend_common_expression::types::ArgType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; +use databend_common_expression::types::UInt64Type; +use databend_common_expression::types::UnaryType; +use databend_common_expression::types::ValueType; use databend_common_expression::utils::column_merge_validity; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; @@ -37,7 +42,6 @@ use super::aggregate_function_factory::AggregateFunctionDescription; use super::aggregate_function_factory::AggregateFunctionSortDesc; use super::StateAddr; use crate::aggregates::aggregator_common::assert_variadic_arguments; -use crate::aggregates::borsh_partial_deserialize; use crate::aggregates::AggrState; use crate::aggregates::AggrStateLoc; @@ -163,25 +167,59 @@ impl AggregateFunction for AggregateCountFunction { } fn serialize_type(&self) -> Vec { - vec![StateSerdeItem::Binary(None)] + vec![StateSerdeItem::DataType(UInt64Type::data_type())] } fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { - let binary_builder = builders[0].as_binary_mut().unwrap(); let state = place.get::(); - state.count.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + UInt64Type::downcast_builder(&mut builders[0]).push(state.count); + Ok(()) + } + + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + let mut builder = UInt64Type::downcast_builder(&mut builders[0]); + for place in places { + let state = AggrState::new(*place, loc).get::(); + builder.push(state.count); + } Ok(()) } fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); let state = place.get::(); - let other: u64 = borsh_partial_deserialize(&mut binary)?; + let other = UInt64Type::try_downcast_scalar(&data[0]).unwrap(); state.count += other; Ok(()) } + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + if let Some(filter) = filter { + for (place, other) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + state.count += other; + } + } else { + for (place, other) in iter { + let state = AggrState::new(*place, loc).get::(); + state.count += other; + } + } + Ok(()) + } + fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { let state = place.get::(); let other = rhs.get::(); From db81bfe3379f48b122c9ee9a873e7261770ac132 Mon Sep 17 00:00:00 2001 From: coldWater Date: Sat, 26 Jul 2025 16:29:47 +0800 Subject: [PATCH 21/22] clean up --- .../src/aggregate/aggregate_function.rs | 139 +---- .../adaptors/aggregate_null_adaptor.rs | 476 ++---------------- .../adaptors/aggregate_ornull_adaptor.rs | 53 +- .../adaptors/aggregate_sort_adaptor.rs | 49 +- .../src/aggregates/aggregate_arg_min_max.rs | 46 +- .../src/aggregates/aggregate_array_agg.rs | 47 +- .../src/aggregates/aggregate_array_moving.rs | 92 +++- .../src/aggregates/aggregate_bitmap.rs | 84 +++- .../aggregate_combinator_distinct.rs | 49 +- .../src/aggregates/aggregate_combinator_if.rs | 9 - .../aggregates/aggregate_combinator_state.rs | 16 +- .../src/aggregates/aggregate_count.rs | 15 - .../src/aggregates/aggregate_covariance.rs | 47 +- .../aggregates/aggregate_json_array_agg.rs | 46 +- .../aggregates/aggregate_json_object_agg.rs | 44 +- .../src/aggregates/aggregate_markov_tarin.rs | 49 +- .../src/aggregates/aggregate_null_result.rs | 22 +- .../aggregates/aggregate_quantile_tdigest.rs | 45 +- .../aggregate_quantile_tdigest_weighted.rs | 46 +- .../src/aggregates/aggregate_retention.rs | 47 +- .../src/aggregates/aggregate_st_collect.rs | 48 +- .../src/aggregates/aggregate_string_agg.rs | 46 +- .../functions/src/aggregates/aggregate_sum.rs | 43 +- .../src/aggregates/aggregate_unary.rs | 64 ++- .../src/aggregates/aggregate_window_funnel.rs | 52 +- .../src/aggregates/aggregator_common.rs | 12 +- .../aggregator/transform_single_key.rs | 45 +- .../transforms/aggregator/udaf_script.rs | 71 ++- 28 files changed, 877 insertions(+), 925 deletions(-) diff --git a/src/query/expression/src/aggregate/aggregate_function.rs b/src/query/expression/src/aggregate/aggregate_function.rs index 700d9a2c77f11..4f880267ea69e 100755 --- a/src/query/expression/src/aggregate/aggregate_function.rs +++ b/src/query/expression/src/aggregate/aggregate_function.rs @@ -22,17 +22,11 @@ use super::AggrState; use super::AggrStateLoc; use super::AggrStateRegistry; use super::StateAddr; -use crate::types::AnyPairType; -use crate::types::AnyQuaternaryType; -use crate::types::AnyTernaryType; -use crate::types::AnyType; -use crate::types::AnyUnaryType; use crate::types::DataType; use crate::BlockEntry; use crate::ColumnBuilder; use crate::ProjectedBlock; use crate::Scalar; -use crate::ScalarRef; use crate::StateSerdeItem; use crate::StateSerdeType; @@ -82,21 +76,12 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { serde_type.data_type() } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()>; - fn batch_serialize( &self, places: &[StateAddr], loc: &[AggrStateLoc], builders: &mut [ColumnBuilder], - ) -> Result<()> { - for place in places { - self.serialize(AggrState::new(*place, loc), builders)?; - } - Ok(()) - } - - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()>; + ) -> Result<()>; /// Batch deserialize the state and merge fn batch_merge( @@ -105,127 +90,7 @@ pub trait AggregateFunction: fmt::Display + Sync + Send { loc: &[AggrStateLoc], state: &BlockEntry, filter: Option<&Bitmap>, - ) -> Result<()> { - match state.data_type().as_tuple().unwrap().len() { - 1 => { - let view = state.downcast::().unwrap(); - let iter = places.iter().zip(view.iter()); - if let Some(filter) = filter { - for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) - { - self.merge(AggrState::new(*place, loc), &[data])?; - } - } else { - for (place, data) in iter { - self.merge(AggrState::new(*place, loc), &[data])?; - } - } - } - 2 => { - let view = state.downcast::().unwrap(); - let iter = places.iter().zip(view.iter()); - if let Some(filter) = filter { - for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) - { - self.merge(AggrState::new(*place, loc), &[data.0, data.1])?; - } - } else { - for (place, data) in iter { - self.merge(AggrState::new(*place, loc), &[data.0, data.1])?; - } - } - } - 3 => { - let view = state.downcast::().unwrap(); - let iter = places.iter().zip(view.iter()); - if let Some(filter) = filter { - for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) - { - self.merge(AggrState::new(*place, loc), &[data.0, data.1, data.2])?; - } - } else { - for (place, data) in iter { - self.merge(AggrState::new(*place, loc), &[data.0, data.1, data.2])?; - } - } - } - 4 => { - let view = state.downcast::().unwrap(); - let iter = places.iter().zip(view.iter()); - if let Some(filter) = filter { - for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) - { - self.merge(AggrState::new(*place, loc), &[ - data.0, data.1, data.2, data.3, - ])?; - } - } else { - for (place, data) in iter { - self.merge(AggrState::new(*place, loc), &[ - data.0, data.1, data.2, data.3, - ])?; - } - } - } - _ => { - let view = state.downcast::().unwrap(); - let iter = places.iter().zip(view.iter()); - if let Some(filter) = filter { - for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) - { - self.merge( - AggrState::new(*place, loc), - data.as_tuple().unwrap().as_slice(), - )?; - } - } else { - for (place, data) in iter { - self.merge( - AggrState::new(*place, loc), - data.as_tuple().unwrap().as_slice(), - )?; - } - } - } - } - Ok(()) - } - - fn batch_merge_single(&self, place: AggrState, state: &BlockEntry) -> Result<()> { - match state.data_type().as_tuple().unwrap().len() { - 1 => { - let view = state.downcast::().unwrap(); - for data in view.iter() { - self.merge(place, &[data])?; - } - } - 2 => { - let view = state.downcast::().unwrap(); - for data in view.iter() { - self.merge(place, &[data.0, data.1])?; - } - } - 3 => { - let view = state.downcast::().unwrap(); - for data in view.iter() { - self.merge(place, &[data.0, data.1, data.2])?; - } - } - 4 => { - let view = state.downcast::().unwrap(); - for data in view.iter() { - self.merge(place, &[data.0, data.1, data.2, data.3])?; - } - } - _ => { - let state = state.to_column(); - for data in state.iter() { - self.merge(place, data.as_tuple().unwrap().as_slice())?; - } - } - } - Ok(()) - } + ) -> Result<()>; fn batch_merge_states( &self, diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs index b438fb0edc455..a800a0d95413c 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs @@ -16,7 +16,6 @@ use std::fmt; use std::sync::Arc; use databend_common_exception::Result; -use databend_common_expression::types::AnyType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; @@ -26,7 +25,6 @@ use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::AggrState; @@ -187,10 +185,6 @@ impl AggregateFunction for AggregateNullUnaryAdapto self.0.serialize_type() } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { - self.0.serialize(place, builders) - } - fn batch_serialize( &self, places: &[StateAddr], @@ -200,10 +194,6 @@ impl AggregateFunction for AggregateNullUnaryAdapto self.0.batch_serialize(places, loc, builders) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - self.0.merge(place, data) - } - fn batch_merge( &self, places: &[StateAddr], @@ -323,10 +313,6 @@ impl AggregateFunction self.0.serialize_type() } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { - self.0.serialize(place, builders) - } - fn batch_serialize( &self, places: &[StateAddr], @@ -336,10 +322,6 @@ impl AggregateFunction self.0.batch_serialize(places, loc, builders) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - self.0.merge(place, data) - } - fn batch_merge( &self, places: &[StateAddr], @@ -531,23 +513,6 @@ impl CommonNullAdaptor { .collect() } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { - if !NULLABLE_RESULT { - return self.nested.serialize(place, builders); - } - let n = builders.len(); - debug_assert_eq!(self.nested.serialize_type().len() + 1, n); - - let flag = get_flag(place); - builders - .last_mut() - .and_then(ColumnBuilder::as_boolean_mut) - .unwrap() - .push(flag); - self.nested - .serialize(place.remove_last_loc(), &mut builders[..(n - 1)]) - } - fn batch_serialize( &self, places: &[StateAddr], @@ -571,21 +536,6 @@ impl CommonNullAdaptor { .batch_serialize(places, &loc[..loc.len() - 1], &mut builders[..(n - 1)]) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - if !NULLABLE_RESULT { - return self.nested.merge(place, data); - } - - let flag = *data.last().and_then(ScalarRef::as_boolean).unwrap(); - if !flag { - return Ok(()); - } - - self.update_flag(place); - self.nested - .merge(place.remove_last_loc(), &data[..data.len() - 1]) - } - fn batch_merge( &self, places: &[StateAddr], @@ -623,16 +573,37 @@ impl CommonNullAdaptor { self.nested .batch_merge(places, &loc[..loc.len() - 1], &nested_state.into(), filter) } - _ => { - let state = state.downcast::().unwrap(); - for (place, data) in places.iter().zip(state.iter()) { - self.merge( - AggrState::new(*place, loc), - data.as_tuple().unwrap().as_slice(), - )?; - } - Ok(()) + BlockEntry::Const(Scalar::Tuple(tuple), DataType::Tuple(data_type), num_rows) => { + let flag = *tuple.last().and_then(Scalar::as_boolean).unwrap(); + let flag = Bitmap::new_constant(flag, *num_rows); + let flag = match filter { + Some(filter) => filter & &flag, + None => flag, + }; + let filter = if flag.null_count() == 0 { + for place in places.iter() { + self.update_flag(AggrState::new(*place, loc)); + } + None + } else { + for place in places + .iter() + .zip(flag.iter()) + .filter_map(|(place, flag)| flag.then_some(place)) + { + self.update_flag(AggrState::new(*place, loc)); + } + Some(&flag) + }; + let nested_state = BlockEntry::new_const_column( + DataType::Tuple(data_type[0..data_type.len() - 1].to_vec()), + Scalar::Tuple(tuple[0..tuple.len() - 1].to_vec()), + *num_rows, + ); + self.nested + .batch_merge(places, &loc[..loc.len() - 1], &nested_state, filter) } + _ => unreachable!(), } } @@ -703,390 +674,3 @@ fn get_flag(place: AggrState) -> bool { fn flag_offset(place: AggrState) -> usize { *place.loc.last().unwrap().as_bool().unwrap().1 } - -#[cfg(test)] -mod tests { - - use std::sync::Arc; - - use databend_common_expression::types::*; - use databend_common_expression::*; - - use super::*; - - struct TestStateBuffer { - buffer: Vec, - } - - impl TestStateBuffer { - fn new(layout: &std::alloc::Layout) -> Self { - let buffer = vec![1u8; layout.size()]; - Self { buffer } - } - - fn addr(&self) -> StateAddr { - StateAddr::new(self.buffer.as_ptr() as usize) - } - } - - #[derive(Clone)] - struct MockAggregateFunction { - return_type: DataType, - serialize_items: Vec, - } - - impl fmt::Display for MockAggregateFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "MockAggregateFunction") - } - } - - impl AggregateFunction for MockAggregateFunction { - fn name(&self) -> &str { - "mock" - } - - fn return_type(&self) -> Result { - Ok(self.return_type.clone()) - } - - fn init_state(&self, place: AggrState) { - place.write(|| 3_u64); - } - - fn register_state(&self, registry: &mut AggrStateRegistry) { - registry.register(AggrStateType::Custom(std::alloc::Layout::new::())); - } - - fn accumulate( - &self, - _place: AggrState, - _columns: ProjectedBlock, - _validity: Option<&Bitmap>, - _input_rows: usize, - ) -> Result<()> { - Ok(()) - } - - fn accumulate_row( - &self, - _place: AggrState, - _columns: ProjectedBlock, - _row: usize, - ) -> Result<()> { - Ok(()) - } - - fn serialize_type(&self) -> Vec { - self.serialize_items.clone() - } - - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { - let state = place.get::(); - if let ColumnBuilder::Number(NumberColumnBuilder::UInt64(builder)) = &mut builders[0] { - builder.push(*state); - } - Ok(()) - } - - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let state = place.get::(); - if let Some(ScalarRef::Number(NumberScalar::UInt64(value))) = data.first() { - *state += value; - } - Ok(()) - } - - fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { - let lhs_state = place.get::(); - let rhs_state = rhs.get::(); - *lhs_state += *rhs_state; - Ok(()) - } - - fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { - let state = place.get::(); - if let ColumnBuilder::Number(NumberColumnBuilder::UInt64(builder)) = builder { - builder.push(*state); - } - Ok(()) - } - } - - #[derive(Debug, Clone, Copy, PartialEq, Eq)] - pub enum BatchMergeStrategy { - Reference, - Optimized, - } - - pub struct TestNullFunction { - adaptor: CommonNullAdaptor, - strategy: BatchMergeStrategy, - } - - impl TestNullFunction { - pub fn new(nested: AggregateFunctionRef, strategy: BatchMergeStrategy) -> Self { - Self { - adaptor: CommonNullAdaptor { nested }, - strategy, - } - } - - fn reference_batch_merge_impl( - &self, - places: &[StateAddr], - loc: &[AggrStateLoc], - state: &BlockEntry, - filter: Option<&Bitmap>, - ) -> Result<()> { - let view = state.downcast::().unwrap(); - let iter = places.iter().zip(view.iter()); - if let Some(filter) = filter { - for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { - self.merge( - AggrState::new(*place, loc), - data.as_tuple().unwrap().as_slice(), - )?; - } - } else { - for (place, data) in iter { - self.merge( - AggrState::new(*place, loc), - data.as_tuple().unwrap().as_slice(), - )?; - } - } - - Ok(()) - } - } - - impl fmt::Display for TestNullFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "ConfigurableNull") - } - } - - impl AggregateFunction for TestNullFunction { - fn name(&self) -> &str { - "test_null" - } - - fn return_type(&self) -> Result { - self.adaptor.return_type() - } - - fn init_state(&self, place: AggrState) { - self.adaptor.init_state(place) - } - - fn register_state(&self, registry: &mut AggrStateRegistry) { - self.adaptor.register_state(registry) - } - - fn accumulate( - &self, - place: AggrState, - columns: ProjectedBlock, - validity: Option<&Bitmap>, - input_rows: usize, - ) -> Result<()> { - self.adaptor - .accumulate(place, columns, validity.cloned(), input_rows) - } - - fn accumulate_row( - &self, - place: AggrState, - columns: ProjectedBlock, - row: usize, - ) -> Result<()> { - self.adaptor.accumulate_row(place, columns, None, row) - } - - fn serialize_type(&self) -> Vec { - self.adaptor.serialize_type() - } - - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { - self.adaptor.serialize(place, builders) - } - - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - self.adaptor.merge(place, data) - } - - fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { - self.adaptor.merge_states(place, rhs) - } - - fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { - self.adaptor.merge_result(place, builder) - } - - /// Custom batch_merge implementation that switches between reference and optimized strategies - fn batch_merge( - &self, - places: &[StateAddr], - loc: &[AggrStateLoc], - state: &BlockEntry, - filter: Option<&Bitmap>, - ) -> Result<()> { - match self.strategy { - BatchMergeStrategy::Reference => { - self.reference_batch_merge_impl(places, loc, state, filter) - } - BatchMergeStrategy::Optimized => { - self.adaptor.batch_merge(places, loc, state, filter) - } - } - } - } - - fn create_configurable(strategy: BatchMergeStrategy) -> TestNullFunction { - let nested = Arc::new(MockAggregateFunction { - return_type: DataType::Number(NumberDataType::UInt64), - serialize_items: vec![StateSerdeItem::DataType(DataType::Number( - NumberDataType::UInt64, - ))], - }); - TestNullFunction::new(nested, strategy) - } - - fn create_test_state_entry(values: &[u64], flags: &[bool]) -> Result { - // Build the state column as a tuple - let mut value_builder = - ColumnBuilder::with_capacity(&DataType::Number(NumberDataType::UInt64), values.len()); - let mut flag_builder = ColumnBuilder::with_capacity(&DataType::Boolean, flags.len()); - - for &value in values { - if let ColumnBuilder::Number(NumberColumnBuilder::UInt64(builder)) = &mut value_builder - { - builder.push(value); - } - } - - for &flag in flags { - if let ColumnBuilder::Boolean(builder) = &mut flag_builder { - builder.push(flag); - } - } - - let value_column = value_builder.build(); - let flag_column = flag_builder.build(); - let tuple_column = Column::Tuple(vec![value_column, flag_column]); - - Ok(BlockEntry::new(Value::Column(tuple_column), || { - ( - DataType::Tuple(vec![ - DataType::Number(NumberDataType::UInt64), - DataType::Boolean, - ]), - values.len(), - ) - })) - } - - fn run_batch_merge_with_strategy( - strategy: BatchMergeStrategy, - buffers: &[TestStateBuffer], - loc: &[AggrStateLoc], - state_entry: &BlockEntry, - filter: Option<&Bitmap>, - ) -> Result<()> { - let func = create_configurable(strategy); - let places: Vec = buffers.iter().map(|b| b.addr()).collect(); - - // Initialize states - for buffer in buffers { - let place = AggrState::new(buffer.addr(), loc); - func.init_state(place); - } - - // Run batch_merge - func.batch_merge(&places, loc, state_entry, filter) - } - - fn test_both_strategies(values: &[u64], flags: &[bool], filter: Option<&Bitmap>) -> Result<()> { - let state_entry = create_test_state_entry(values, flags)?; - let (reference_buffers, optimized_buffers, loc) = { - let buffer_count = values.len(); - let states_layout = - get_states_layout(&[Arc::new(create_configurable(BatchMergeStrategy::Reference))])?; - let layout = states_layout.layout; - let loc = states_layout.states_loc[0].to_vec(); - - let reference_buffers: Vec = (0..buffer_count) - .map(|_| TestStateBuffer::new(&layout)) - .collect(); - let optimized_buffers: Vec = (0..buffer_count) - .map(|_| TestStateBuffer::new(&layout)) - .collect(); - - (reference_buffers, optimized_buffers, loc) - }; - - // Run both strategies - run_batch_merge_with_strategy( - BatchMergeStrategy::Reference, - &reference_buffers, - &loc, - &state_entry, - filter, - )?; - - run_batch_merge_with_strategy( - BatchMergeStrategy::Optimized, - &optimized_buffers, - &loc, - &state_entry, - filter, - )?; - - for i in 0..reference_buffers.len() { - assert_eq!( - reference_buffers[i].buffer, optimized_buffers[i].buffer, - "Buffer contents should match at index {}", - i - ); - } - Ok(()) - } - - #[test] - fn test_batch_merge_equivalence() -> Result<()> { - { - let values = vec![10u64, 20u64, 30u64]; - let flags = vec![true, false, true]; - - test_both_strategies(&values, &flags, None)?; - } - - { - let values = vec![100u64, 200u64, 300u64, 400u64]; - let flags = vec![true, false, true, true]; // Skip index 1 - - test_both_strategies(&values, &flags, None)?; - } - - { - let values = vec![10u64, 20u64, 30u64, 40u64]; - let flags = vec![true, false, true, true]; - let filter_bits = [true, false, true, false]; // Only process indices 0 and 2 - let filter = Bitmap::from_iter(filter_bits.iter().copied()); - - test_both_strategies(&values, &flags, Some(&filter))?; - } - - { - let values = vec![100u64, 200u64, 300u64, 400u64]; - let flags = vec![true, true, false, true]; // Skip index 2 - let filter_bits = [true, false, true, true]; // Skip index 1 - let filter = Bitmap::from_iter(filter_bits.iter().copied()); - - test_both_strategies(&values, &flags, Some(&filter))?; - } - - Ok(()) - } -} diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs index be39946e94aba..318cd88e0c522 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs @@ -25,7 +25,6 @@ use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::AggrState; @@ -189,21 +188,6 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { .collect() } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { - let n = builders.len(); - debug_assert_eq!(self.nested.serialize_type().len() + 1, n); - - let flag = get_flag(place); - builders - .last_mut() - .and_then(ColumnBuilder::as_boolean_mut) - .unwrap() - .push(flag); - - self.nested - .serialize(place.remove_last_loc(), &mut builders[..n - 1]) - } - fn batch_serialize( &self, places: &[StateAddr], @@ -224,13 +208,6 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { .batch_serialize(places, &loc[..loc.len() - 1], &mut builders[..(n - 1)]) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - merge_flag(place, *data.last().and_then(ScalarRef::as_boolean).unwrap()); - self.nested - .merge(place.remove_last_loc(), &data[..data.len() - 1])?; - Ok(()) - } - fn batch_merge( &self, places: &[StateAddr], @@ -256,26 +233,30 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor { self.nested .batch_merge(places, &loc[0..loc.len() - 1], &inner_state, filter)?; } - _ => { - let state = state.to_column(); - let iter = places.iter().zip(state.iter()); + BlockEntry::Const(Scalar::Tuple(tuple), DataType::Tuple(data_type), num_rows) => { + let flag = *tuple.last().unwrap().as_boolean().unwrap(); if let Some(filter) = filter { - for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) + for place in places + .iter() + .zip(filter.iter()) + .filter_map(|(v, b)| b.then_some(v)) { - self.merge( - AggrState::new(*place, loc), - data.as_tuple().unwrap().as_slice(), - )?; + merge_flag(AggrState::new(*place, loc), flag); } } else { - for (place, data) in iter { - self.merge( - AggrState::new(*place, loc), - data.as_tuple().unwrap().as_slice(), - )?; + for place in places { + merge_flag(AggrState::new(*place, loc), flag); } } + let inner_state = BlockEntry::new_const_column( + DataType::Tuple(data_type[0..data_type.len() - 1].to_vec()), + Scalar::Tuple(tuple[0..tuple.len() - 1].to_vec()), + *num_rows, + ); + self.nested + .batch_merge(places, &loc[0..loc.len() - 1], &inner_state, filter)?; } + _ => unreachable!(), } Ok(()) diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs index babf42ad2bce4..0cfd393c776ce 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs @@ -23,17 +23,22 @@ use borsh::BorshSerialize; use databend_common_column::bitmap::Bitmap; use databend_common_exception::Result; use databend_common_expression::types::AnyType; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::DataType; +use databend_common_expression::types::UnaryType; use databend_common_expression::AggrState; +use databend_common_expression::AggrStateLoc; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; use databend_common_expression::AggregateFunction; use databend_common_expression::AggregateFunctionRef; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::DataBlock; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; use databend_common_expression::SortColumnDescription; +use databend_common_expression::StateAddr; use databend_common_expression::StateSerdeItem; use itertools::Itertools; @@ -126,24 +131,44 @@ impl AggregateFunction for AggregateFunctionSortAdaptor { vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = Self::get_state(place); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state = Self::get_state(AggrState::new(*place, loc)); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn merge( + fn batch_merge( &self, - place: AggrState, - data: &[databend_common_expression::ScalarRef], + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, ) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = Self::get_state(place); - let rhs = SortAggState::deserialize(&mut binary)?; - - Self::merge_states_inner(state, &rhs); + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = Self::get_state(AggrState::new(*place, loc)); + let rhs = SortAggState::deserialize(&mut data)?; + Self::merge_states_inner(state, &rhs); + } + } else { + for (place, mut data) in iter { + let state = Self::get_state(AggrState::new(*place, loc)); + let rhs = SortAggState::deserialize(&mut data)?; + Self::merge_states_inner(state, &rhs); + } + } Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_arg_min_max.rs b/src/query/functions/src/aggregates/aggregate_arg_min_max.rs index f66bd2ec729d4..a1076cbf8e039 100644 --- a/src/query/functions/src/aggregates/aggregate_arg_min_max.rs +++ b/src/query/functions/src/aggregates/aggregate_arg_min_max.rs @@ -27,11 +27,11 @@ use databend_common_expression::types::*; use databend_common_expression::with_number_mapped_type; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ColumnView; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::aggregate_function_factory::AggregateFunctionDescription; @@ -276,19 +276,45 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = place.get::(); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = place.get::(); - let rhs: State = borsh_partial_deserialize(&mut binary)?; - state.merge_from(rhs) + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let rhs: State = borsh_partial_deserialize(&mut data)?; + state.merge_from(rhs)?; + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let rhs: State = borsh_partial_deserialize(&mut data)?; + state.merge_from(rhs)?; + } + } + Ok(()) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_array_agg.rs b/src/query/functions/src/aggregates/aggregate_array_agg.rs index 98935efdee83e..b48dd221695c4 100644 --- a/src/query/functions/src/aggregates/aggregate_array_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_array_agg.rs @@ -43,6 +43,7 @@ use databend_common_expression::with_decimal_mapped_type; use databend_common_expression::with_number_mapped_type; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; @@ -553,20 +554,46 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = place.get::(); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = place.get::(); - let rhs = State::deserialize_reader(&mut binary)?; - - state.merge(&rhs) + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let rhs = State::deserialize_reader(&mut data)?; + state.merge(&rhs)?; + } + } else { + for (place, data) in iter { + let mut binary = data; + let state = AggrState::new(*place, loc).get::(); + let rhs = State::deserialize_reader(&mut binary)?; + state.merge(&rhs)?; + } + } + Ok(()) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_array_moving.rs b/src/query/functions/src/aggregates/aggregate_array_moving.rs index c1742dd501842..ac4a7b50add8c 100644 --- a/src/query/functions/src/aggregates/aggregate_array_moving.rs +++ b/src/query/functions/src/aggregates/aggregate_array_moving.rs @@ -26,6 +26,7 @@ use databend_common_expression::types::i256; use databend_common_expression::types::number::Number; use databend_common_expression::types::AccessType; use databend_common_expression::types::ArgType; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::Buffer; use databend_common_expression::types::DataType; @@ -33,6 +34,7 @@ use databend_common_expression::types::Float64Type; use databend_common_expression::types::Int8Type; use databend_common_expression::types::NumberDataType; use databend_common_expression::types::NumberType; +use databend_common_expression::types::UnaryType; use databend_common_expression::types::F64; use databend_common_expression::utils::arithmetics_type::ResultTypeOfUnary; use databend_common_expression::with_decimal_mapped_type; @@ -452,20 +454,45 @@ where State: SumState vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = place.get::(); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = place.get::(); - let rhs = State::deserialize_reader(&mut binary)?; - - state.merge(&rhs) + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let rhs = State::deserialize_reader(&mut data)?; + state.merge(&rhs)?; + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let rhs = State::deserialize_reader(&mut data)?; + state.merge(&rhs)?; + } + } + Ok(()) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { @@ -629,20 +656,45 @@ where State: SumState vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = place.get::(); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = place.get::(); - let rhs = State::deserialize_reader(&mut binary)?; - - state.merge(&rhs) + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let rhs = State::deserialize_reader(&mut data)?; + state.merge(&rhs)?; + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let rhs = State::deserialize_reader(&mut data)?; + state.merge(&rhs)?; + } + } + Ok(()) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_bitmap.rs b/src/query/functions/src/aggregates/aggregate_bitmap.rs index d3c0f331be0aa..a5cf18c1e3d4d 100644 --- a/src/query/functions/src/aggregates/aggregate_bitmap.rs +++ b/src/query/functions/src/aggregates/aggregate_bitmap.rs @@ -33,10 +33,10 @@ use databend_common_expression::with_decimal_mapped_type; use databend_common_expression::with_number_mapped_type; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use databend_common_io::prelude::BinaryWrite; use roaring::RoaringTreemap; @@ -293,28 +293,57 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = place.get::(); - // flag indicate where bitmap is none - let flag: u8 = if state.rb.is_some() { 1 } else { 0 }; - binary_builder.data.write_scalar(&flag)?; - if let Some(rb) = &state.rb { - rb.serialize_into(&mut binary_builder.data)?; + for place in places { + let state = AggrState::new(*place, loc).get::(); + // flag indicate where bitmap is none + let flag: u8 = if state.rb.is_some() { 1 } else { 0 }; + binary_builder.data.write_scalar(&flag)?; + if let Some(rb) = &state.rb { + rb.serialize_into(&mut binary_builder.data)?; + } + binary_builder.commit_row(); } - binary_builder.commit_row(); Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = place.get::(); - - let flag = binary[0]; - binary.consume(1); - if flag == 1 { - let rb = deserialize_bitmap(binary)?; - state.add::(rb); + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let flag = data[0]; + data.consume(1); + if flag == 1 { + let rb = deserialize_bitmap(data)?; + state.add::(rb); + } + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + + let flag = data[0]; + data.consume(1); + if flag == 1 { + let rb = deserialize_bitmap(data)?; + state.add::(rb); + } + } } Ok(()) } @@ -495,12 +524,23 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { - self.inner.serialize(place, builders) + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + self.inner.batch_serialize(places, loc, builders) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - self.inner.merge(place, data) + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + self.inner.batch_merge(places, loc, state, filter) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs b/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs index 989964ae1babd..da9994b82624f 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs @@ -19,16 +19,18 @@ use std::sync::Arc; use databend_common_exception::Result; use databend_common_expression::types::number::NumberColumnBuilder; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; +use databend_common_expression::types::UnaryType; use databend_common_expression::with_number_mapped_type; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::aggregate_distinct_state::AggregateDistinctNumberState; @@ -44,6 +46,8 @@ use super::aggregate_function_factory::CombinatorDescription; use super::aggregator_common::assert_variadic_arguments; use super::AggregateCountFunction; use crate::aggregates::AggrState; +use crate::aggregates::AggrStateLoc; +use crate::aggregates::StateAddr; #[derive(Clone)] pub struct AggregateDistinctCombinator { @@ -112,20 +116,45 @@ where State: DistinctStateFunc vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = Self::get_state(place); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state = Self::get_state(AggrState::new(*place, loc)); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = Self::get_state(place); - let rhs = State::deserialize(&mut binary)?; + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); - state.merge(&rhs) + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = Self::get_state(AggrState::new(*place, loc)); + let rhs = State::deserialize(&mut data)?; + state.merge(&rhs)?; + } + } else { + for (place, mut data) in iter { + let state = Self::get_state(AggrState::new(*place, loc)); + let rhs = State::deserialize(&mut data)?; + state.merge(&rhs)?; + } + } + Ok(()) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_combinator_if.rs b/src/query/functions/src/aggregates/aggregate_combinator_if.rs index 47b9ffee31f73..4555f9b2cc85b 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_if.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_if.rs @@ -26,7 +26,6 @@ use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::StateAddr; @@ -159,10 +158,6 @@ impl AggregateFunction for AggregateIfCombinator { self.nested.serialize_type() } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { - self.nested.serialize(place, builders) - } - fn batch_serialize( &self, places: &[StateAddr], @@ -172,10 +167,6 @@ impl AggregateFunction for AggregateIfCombinator { self.nested.batch_serialize(places, loc, builders) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - self.nested.merge(place, data) - } - fn batch_merge( &self, places: &[StateAddr], diff --git a/src/query/functions/src/aggregates/aggregate_combinator_state.rs b/src/query/functions/src/aggregates/aggregate_combinator_state.rs index 4f720331b7347..4c2bd220856fb 100644 --- a/src/query/functions/src/aggregates/aggregate_combinator_state.rs +++ b/src/query/functions/src/aggregates/aggregate_combinator_state.rs @@ -23,7 +23,6 @@ use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::AggregateFunctionFactory; @@ -113,10 +112,6 @@ impl AggregateFunction for AggregateStateCombinator { self.nested.serialize_type() } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { - self.nested.serialize(place, builders) - } - fn batch_serialize( &self, places: &[StateAddr], @@ -126,10 +121,6 @@ impl AggregateFunction for AggregateStateCombinator { self.nested.batch_serialize(places, loc, builders) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - self.nested.merge(place, data) - } - fn batch_merge( &self, places: &[StateAddr], @@ -140,17 +131,14 @@ impl AggregateFunction for AggregateStateCombinator { self.nested.batch_merge(places, loc, state, filter) } - fn batch_merge_single(&self, place: AggrState, state: &BlockEntry) -> Result<()> { - self.nested.batch_merge_single(place, state) - } - fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { self.nested.merge_states(place, rhs) } fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> { let builders = builder.as_tuple_mut().unwrap().as_mut_slice(); - self.nested.serialize(place, builders) + self.nested + .batch_serialize(&[place.addr], place.loc, builders) } fn need_manual_drop_state(&self) -> bool { diff --git a/src/query/functions/src/aggregates/aggregate_count.rs b/src/query/functions/src/aggregates/aggregate_count.rs index 2a7fa0ad45226..fb4ed78280fc9 100644 --- a/src/query/functions/src/aggregates/aggregate_count.rs +++ b/src/query/functions/src/aggregates/aggregate_count.rs @@ -18,7 +18,6 @@ use std::sync::Arc; use databend_common_exception::Result; use databend_common_expression::types::number::NumberColumnBuilder; -use databend_common_expression::types::AccessType; use databend_common_expression::types::ArgType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; @@ -34,7 +33,6 @@ use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::aggregate_function::AggregateFunction; @@ -170,12 +168,6 @@ impl AggregateFunction for AggregateCountFunction { vec![StateSerdeItem::DataType(UInt64Type::data_type())] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { - let state = place.get::(); - UInt64Type::downcast_builder(&mut builders[0]).push(state.count); - Ok(()) - } - fn batch_serialize( &self, places: &[StateAddr], @@ -190,13 +182,6 @@ impl AggregateFunction for AggregateCountFunction { Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let state = place.get::(); - let other = UInt64Type::try_downcast_scalar(&data[0]).unwrap(); - state.count += other; - Ok(()) - } - fn batch_merge( &self, places: &[StateAddr], diff --git a/src/query/functions/src/aggregates/aggregate_covariance.rs b/src/query/functions/src/aggregates/aggregate_covariance.rs index fd30fb84cb6c8..0b4e7a52d4b06 100644 --- a/src/query/functions/src/aggregates/aggregate_covariance.rs +++ b/src/query/functions/src/aggregates/aggregate_covariance.rs @@ -23,18 +23,20 @@ use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::number::Number; use databend_common_expression::types::number::F64; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; use databend_common_expression::types::NumberType; +use databend_common_expression::types::UnaryType; use databend_common_expression::types::ValueType; use databend_common_expression::with_number_mapped_type; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use num_traits::AsPrimitive; @@ -237,19 +239,44 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = place.get::(); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = place.get::(); - let rhs: AggregateCovarianceState = borsh_partial_deserialize(&mut binary)?; - state.merge(&rhs); + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let rhs: AggregateCovarianceState = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs); + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let rhs: AggregateCovarianceState = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs); + } + } Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_json_array_agg.rs b/src/query/functions/src/aggregates/aggregate_json_array_agg.rs index 642c15a81fff6..0f7ab2609dc56 100644 --- a/src/query/functions/src/aggregates/aggregate_json_array_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_json_array_agg.rs @@ -29,6 +29,7 @@ use databend_common_expression::types::ValueType; use databend_common_expression::types::*; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; @@ -245,20 +246,45 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = place.get::(); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = place.get::(); - let rhs: State = borsh_partial_deserialize(&mut binary)?; - - state.merge(&rhs) + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let rhs: State = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs)?; + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let rhs: State = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs)?; + } + } + Ok(()) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_json_object_agg.rs b/src/query/functions/src/aggregates/aggregate_json_object_agg.rs index 2d5782c75b6a5..ad9910d21a12e 100644 --- a/src/query/functions/src/aggregates/aggregate_json_object_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_json_object_agg.rs @@ -31,6 +31,7 @@ use databend_common_expression::types::ValueType; use databend_common_expression::types::*; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; @@ -298,20 +299,45 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = place.get::(); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = place.get::(); - let rhs: State = borsh_partial_deserialize(&mut binary)?; + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); - state.merge(&rhs) + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let rhs: State = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs)?; + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let rhs: State = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs)?; + } + } + Ok(()) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_markov_tarin.rs b/src/query/functions/src/aggregates/aggregate_markov_tarin.rs index 1293c512aab88..ccce62a529840 100644 --- a/src/query/functions/src/aggregates/aggregate_markov_tarin.rs +++ b/src/query/functions/src/aggregates/aggregate_markov_tarin.rs @@ -25,11 +25,13 @@ use databend_common_base::obfuscator::NGramHash; use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::ArgType; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::MapType; use databend_common_expression::types::StringType; use databend_common_expression::types::UInt32Type; +use databend_common_expression::types::UnaryType; use databend_common_expression::types::ValueType; use databend_common_expression::types::F64; use databend_common_expression::AggrState; @@ -40,7 +42,6 @@ use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::aggregate_function_factory::AggregateFunctionDescription; @@ -48,6 +49,8 @@ use super::assert_unary_arguments; use super::borsh_partial_deserialize; use super::extract_number_param; use super::AggregateFunction; +use super::StateAddr; +use crate::aggregates::AggrStateLoc; pub struct MarkovTarin { display_name: String, @@ -119,20 +122,44 @@ impl AggregateFunction for MarkovTarin { vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = place.get::(); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = place.get::(); - let mut rhs = borsh_partial_deserialize::(&mut binary)?; - state.merge(&mut rhs); - + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let mut rhs = borsh_partial_deserialize::(&mut data)?; + state.merge(&mut rhs); + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let mut rhs = borsh_partial_deserialize::(&mut data)?; + state.merge(&mut rhs); + } + } Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_null_result.rs b/src/query/functions/src/aggregates/aggregate_null_result.rs index 66e5a385e4b41..92efe72ba8f51 100644 --- a/src/query/functions/src/aggregates/aggregate_null_result.rs +++ b/src/query/functions/src/aggregates/aggregate_null_result.rs @@ -23,9 +23,9 @@ use databend_common_expression::types::DataType; use databend_common_expression::types::ValueType; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; -use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::aggregate_function::AggregateFunction; @@ -92,14 +92,26 @@ impl AggregateFunction for AggregateNullResultFunction { vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, _place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + _loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - binary_builder.commit_row(); + for _ in places { + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, _place: AggrState, data: &[ScalarRef]) -> Result<()> { - let _binary = *data[0].as_binary().unwrap(); + fn batch_merge( + &self, + _: &[StateAddr], + _: &[AggrStateLoc], + _: &BlockEntry, + _: Option<&Bitmap>, + ) -> Result<()> { Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs b/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs index e09146cfb2917..03422b26df1ea 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs @@ -31,6 +31,7 @@ use databend_common_expression::with_decimal_mapped_type; use databend_common_expression::with_number_mapped_type; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; @@ -368,19 +369,45 @@ where for<'a> T: AccessType = F64> + Send + Sync vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = place.get::(); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = place.get::(); - let mut rhs: QuantileTDigestState = borsh_partial_deserialize(&mut binary)?; - state.merge(&mut rhs) + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let mut rhs: QuantileTDigestState = borsh_partial_deserialize(&mut data)?; + state.merge(&mut rhs)?; + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let mut rhs: QuantileTDigestState = borsh_partial_deserialize(&mut data)?; + state.merge(&mut rhs)?; + } + } + Ok(()) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs b/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs index 46afac0dd7b54..0f56b5c89fd06 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs @@ -28,10 +28,10 @@ use databend_common_expression::with_number_mapped_type; use databend_common_expression::with_unsigned_integer_mapped_type; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use num_traits::AsPrimitive; @@ -150,19 +150,45 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = place.get::(); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = place.get::(); - let mut rhs: QuantileTDigestState = borsh_partial_deserialize(&mut binary)?; - state.merge(&mut rhs) + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let mut rhs: QuantileTDigestState = borsh_partial_deserialize(&mut data)?; + state.merge(&mut rhs)?; + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let mut rhs: QuantileTDigestState = borsh_partial_deserialize(&mut data)?; + state.merge(&mut rhs)?; + } + } + Ok(()) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_retention.rs b/src/query/functions/src/aggregates/aggregate_retention.rs index 59d334f3bae8a..44b76bc932cc0 100644 --- a/src/query/functions/src/aggregates/aggregate_retention.rs +++ b/src/query/functions/src/aggregates/aggregate_retention.rs @@ -20,16 +20,18 @@ use borsh::BorshDeserialize; use borsh::BorshSerialize; use databend_common_exception::ErrorCode; use databend_common_exception::Result; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::BooleanType; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; +use databend_common_expression::types::UnaryType; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::aggregate_function::AggregateFunction; @@ -149,19 +151,44 @@ impl AggregateFunction for AggregateRetentionFunction { vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = place.get::(); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = place.get::(); - let rhs: AggregateRetentionState = borsh_partial_deserialize(&mut binary)?; - state.merge(&rhs); + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let rhs: AggregateRetentionState = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs); + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let rhs: AggregateRetentionState = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs); + } + } Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_st_collect.rs b/src/query/functions/src/aggregates/aggregate_st_collect.rs index 28ce7f30959e7..9356a10c2f6d7 100644 --- a/src/query/functions/src/aggregates/aggregate_st_collect.rs +++ b/src/query/functions/src/aggregates/aggregate_st_collect.rs @@ -23,12 +23,15 @@ use borsh::BorshSerialize; use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::ArgType; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::GeometryType; +use databend_common_expression::types::UnaryType; use databend_common_expression::types::ValueType; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; @@ -311,20 +314,45 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = place.get::(); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = place.get::(); - let rhs: State = borsh_partial_deserialize(&mut binary)?; - - state.merge(&rhs) + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let rhs: State = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs)?; + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let rhs: State = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs)?; + } + } + Ok(()) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_string_agg.rs b/src/query/functions/src/aggregates/aggregate_string_agg.rs index 10b4350cd32f3..4687fd599157c 100644 --- a/src/query/functions/src/aggregates/aggregate_string_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_string_agg.rs @@ -20,9 +20,11 @@ use borsh::BorshDeserialize; use borsh::BorshSerialize; use databend_common_exception::ErrorCode; use databend_common_exception::Result; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::StringType; +use databend_common_expression::types::UnaryType; use databend_common_expression::types::ValueType; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; @@ -34,7 +36,6 @@ use databend_common_expression::Evaluator; use databend_common_expression::FunctionContext; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use super::aggregate_function_factory::AggregateFunctionDescription; @@ -153,19 +154,44 @@ impl AggregateFunction for AggregateStringAggFunction { vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = place.get::(); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = place.get::(); - let rhs: StringAggState = borsh_partial_deserialize(&mut binary)?; - state.values.push_str(&rhs.values); + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let rhs: StringAggState = borsh_partial_deserialize(&mut data)?; + state.values.push_str(&rhs.values); + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let rhs: StringAggState = borsh_partial_deserialize(&mut data)?; + state.values.push_str(&rhs.values); + } + } Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_sum.rs b/src/query/functions/src/aggregates/aggregate_sum.rs index 2b67315f66684..0b94e24fd52f7 100644 --- a/src/query/functions/src/aggregates/aggregate_sum.rs +++ b/src/query/functions/src/aggregates/aggregate_sum.rs @@ -28,11 +28,11 @@ use databend_common_expression::types::*; use databend_common_expression::utils::arithmetics_type::ResultTypeOfUnary; use databend_common_expression::with_decimal_mapped_type; use databend_common_expression::with_number_mapped_type; +use databend_common_expression::AggrState; use databend_common_expression::AggregateFunctionRef; use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateAddr; use databend_common_expression::StateSerdeItem; use databend_common_expression::SELECTIVITY_THRESHOLD; @@ -176,14 +176,45 @@ where std::vec![StateSerdeItem::DataType(N::data_type())] } - fn serialize_state(&self, builders: &mut [ColumnBuilder]) -> Result<()> { - N::downcast_builder(&mut builders[0]).push_item(N::to_scalar_ref(&self.value)); + fn batch_serialize( + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + let mut builder = N::downcast_builder(&mut builders[0]); + for place in places { + let state: &mut Self = AggrState::new(*place, loc).get(); + builder.push_item(N::to_scalar_ref(&state.value)); + } Ok(()) } - fn deserialize_state(data: &[ScalarRef]) -> Result { - let value = N::to_owned_scalar(N::try_downcast_scalar(&data[0]).unwrap()); - Ok(Self { value }) + fn batch_merge( + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + if let Some(filter) = filter { + for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let rhs = Self { + value: N::to_owned_scalar(data), + }; + let state: &mut Self = AggrState::new(*place, loc).get(); + >::merge(state, &rhs)?; + } + } else { + for (place, data) in iter { + let rhs = Self { + value: N::to_owned_scalar(data), + }; + let state: &mut Self = AggrState::new(*place, loc).get(); + >::merge(state, &rhs)?; + } + } + Ok(()) } } diff --git a/src/query/functions/src/aggregates/aggregate_unary.rs b/src/query/functions/src/aggregates/aggregate_unary.rs index 629156d90f14b..0b880acb2830c 100644 --- a/src/query/functions/src/aggregates/aggregate_unary.rs +++ b/src/query/functions/src/aggregates/aggregate_unary.rs @@ -21,17 +21,19 @@ use std::sync::Arc; use databend_common_exception::Result; use databend_common_expression::types::AccessType; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; +use databend_common_expression::types::UnaryType; use databend_common_expression::types::ValueType; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; use databend_common_expression::AggregateFunction; use databend_common_expression::AggregateFunctionRef; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateAddr; use databend_common_expression::StateSerdeItem; @@ -85,16 +87,42 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize_state(&self, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - self.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state: &mut Self = AggrState::new(*place, loc).get(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn deserialize_state(data: &[ScalarRef]) -> Result { - let mut binary = *data[0].as_binary().unwrap(); - Ok(Self::deserialize_reader(&mut binary)?) + fn batch_merge( + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let rhs = Self::deserialize_reader(&mut data)?; + let state: &mut Self = AggrState::new(*place, loc).get(); + state.merge(&rhs)?; + } + } else { + for (place, mut data) in iter { + let rhs = Self::deserialize_reader(&mut data)?; + let state: &mut Self = AggrState::new(*place, loc).get(); + state.merge(&rhs)?; + } + } + Ok(()) } } @@ -249,15 +277,23 @@ where S::serialize_type() } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { - let state: &mut S = place.get::(); - state.serialize_state(builders) + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + S::batch_serialize(places, loc, builders) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let state: &mut S = place.get::(); - let rhs = S::deserialize_state(data)?; - state.merge(&rhs) + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + S::batch_merge(places, loc, state, filter) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_window_funnel.rs b/src/query/functions/src/aggregates/aggregate_window_funnel.rs index 616ec6e0e26fc..ece70c912870d 100644 --- a/src/query/functions/src/aggregates/aggregate_window_funnel.rs +++ b/src/query/functions/src/aggregates/aggregate_window_funnel.rs @@ -26,6 +26,7 @@ use databend_common_exception::Result; use databend_common_expression::types::number::Number; use databend_common_expression::types::number::UInt8Type; use databend_common_expression::types::ArgType; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::BooleanType; use databend_common_expression::types::DataType; @@ -33,14 +34,15 @@ use databend_common_expression::types::DateType; use databend_common_expression::types::NumberDataType; use databend_common_expression::types::NumberType; use databend_common_expression::types::TimestampType; +use databend_common_expression::types::UnaryType; use databend_common_expression::types::ValueType; use databend_common_expression::with_integer_mapped_type; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; use num_traits::AsPrimitive; @@ -281,20 +283,48 @@ where vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = place.get::>(); - state.serialize(&mut binary_builder.data)?; - binary_builder.commit_row(); + for place in places { + let state = AggrState::new(*place, loc).get::>(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = place.get::>(); - let mut rhs: AggregateWindowFunnelState = - borsh_partial_deserialize(&mut binary)?; - state.merge(&mut rhs); + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = + AggrState::new(*place, loc).get::>(); + let mut rhs: AggregateWindowFunnelState = + borsh_partial_deserialize(&mut data)?; + state.merge(&mut rhs); + } + } else { + for (place, mut data) in iter { + let state = + AggrState::new(*place, loc).get::>(); + let mut rhs: AggregateWindowFunnelState = + borsh_partial_deserialize(&mut data)?; + state.merge(&mut rhs); + } + } Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregator_common.rs b/src/query/functions/src/aggregates/aggregator_common.rs index 76f4264e59684..5fa214d4f89ce 100644 --- a/src/query/functions/src/aggregates/aggregator_common.rs +++ b/src/query/functions/src/aggregates/aggregator_common.rs @@ -29,7 +29,6 @@ use databend_common_expression::ColumnBuilder; use databend_common_expression::Constant; use databend_common_expression::FunctionContext; use databend_common_expression::Scalar; -use databend_common_expression::ScalarRef; use databend_common_expression::StateAddr; use super::get_states_layout; @@ -190,17 +189,10 @@ pub fn eval_aggr_for_test( let data_type = func.serialize_data_type(); let mut builder = ColumnBuilder::with_capacity(&data_type, 1); let builders = builder.as_tuple_mut().unwrap().as_mut_slice(); - func.serialize(state, builders)?; + func.batch_serialize(&[eval.addr], state.loc, builders)?; func.init_state(state); let column = builder.build(); - let data = column.index(0); - func.merge( - state, - data.as_ref() - .and_then(ScalarRef::as_tuple) - .unwrap() - .as_slice(), - )?; + func.batch_merge(&[eval.addr], state.loc, &column.into(), None)?; } let mut builder = ColumnBuilder::with_capacity(&data_type, 1024); func.merge_result(state, &mut builder)?; diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs index 12992aadf4c1e..19acf639becc0 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs @@ -108,15 +108,14 @@ impl AccumulatingTransform for PartialSingleStateAggregator { let states_indices = (start..block.num_columns()).collect::>(); let states = ProjectedBlock::project(&states_indices, &block); - for ((place, func), state) in self + for ((loc, func), state) in self .states_layout .states_loc .iter() - .map(|loc| AggrState::new(self.addr, loc)) .zip(self.funcs.iter()) .zip(states.iter()) { - func.batch_merge_single(place, state)?; + func.batch_merge(&[self.addr], loc, state, None)?; } } else { for ((place, columns), func) in self @@ -145,19 +144,14 @@ impl AccumulatingTransform for PartialSingleStateAggregator { let blocks = if generate_data { let mut builders = self.states_layout.serialize_builders(1); - for ((func, place), builder) in self + for ((func, loc), builder) in self .funcs .iter() - .zip( - self.states_layout - .states_loc - .iter() - .map(|loc| AggrState::new(self.addr, loc)), - ) + .zip(self.states_layout.states_loc.iter()) .zip(builders.iter_mut()) { let builders = builder.as_tuple_mut().unwrap().as_mut_slice(); - func.serialize(place, builders)?; + func.batch_serialize(&[self.addr], loc, builders)?; } let columns = builders.into_iter().map(|b| b.build()).collect(); @@ -246,20 +240,9 @@ impl AccumulatingTransform for FinalSingleStateAggregator { let main_addr: StateAddr = self.arena.alloc_layout(self.states_layout.layout).into(); - let main_places = self - .funcs - .iter() - .zip( - self.states_layout - .states_loc - .iter() - .map(|loc| AggrState::new(main_addr, loc)), - ) - .map(|(func, place)| { - func.init_state(place); - place - }) - .collect::>(); + for (func, loc) in self.funcs.iter().zip(self.states_layout.states_loc.iter()) { + func.init_state(AggrState::new(main_addr, loc)); + } let mut result_builders = self .funcs @@ -267,25 +250,25 @@ impl AccumulatingTransform for FinalSingleStateAggregator { .map(|f| Ok(ColumnBuilder::with_capacity(&f.return_type()?, 1))) .collect::>>()?; - for (idx, ((func, place), builder)) in self + for (idx, ((func, loc), builder)) in self .funcs .iter() - .zip(main_places.iter().copied()) + .zip(self.states_layout.states_loc.iter()) .zip(result_builders.iter_mut()) .enumerate() { for block in self.to_merge_data.iter() { - func.batch_merge_single(place, block.get_by_offset(idx))?; + func.batch_merge(&[main_addr], loc, block.get_by_offset(idx), None)?; } - func.merge_result(place, builder)?; + func.merge_result(AggrState::new(main_addr, loc), builder)?; } let columns = result_builders.into_iter().map(|b| b.build()).collect(); // destroy states - for (place, func) in main_places.iter().copied().zip(self.funcs.iter()) { + for (func, loc) in self.funcs.iter().zip(self.states_layout.states_loc.iter()) { if func.need_manual_drop_state() { - unsafe { func.drop_state(place) } + unsafe { func.drop_state(AggrState::new(main_addr, loc)) } } } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs index eac893078c0d9..5ddf677b0cd7c 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs @@ -28,20 +28,24 @@ use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::converts::arrow::ARROW_EXT_TYPE_VARIANT; use databend_common_expression::converts::arrow::EXTENSION_KEY; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; +use databend_common_expression::types::UnaryType; use databend_common_expression::AggrState; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::DataBlock; use databend_common_expression::DataField; use databend_common_expression::DataSchema; use databend_common_expression::ProjectedBlock; -use databend_common_expression::ScalarRef; use databend_common_expression::StateSerdeItem; +use databend_common_functions::aggregates::AggrStateLoc; use databend_common_functions::aggregates::AggregateFunction; +use databend_common_functions::aggregates::StateAddr; use databend_common_sql::plans::UDFLanguage; use databend_common_sql::plans::UDFScriptCode; @@ -104,27 +108,58 @@ impl AggregateFunction for AggregateUdfScript { vec![StateSerdeItem::Binary(None)] } - fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> { + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { let binary_builder = builders[0].as_binary_mut().unwrap(); - let state = place.get::(); - state - .serialize(&mut binary_builder.data) - .map_err(|e| ErrorCode::Internal(format!("state failed to serialize: {e}")))?; - binary_builder.commit_row(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state + .serialize(&mut binary_builder.data) + .map_err(|e| ErrorCode::Internal(format!("state failed to serialize: {e}")))?; + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> { - let mut binary = *data[0].as_binary().unwrap(); - let state = place.get::(); - let rhs = UdfAggState::deserialize(&mut binary) - .map_err(|e| ErrorCode::Internal(e.to_string()))?; - let states = arrow_select::concat::concat(&[&state.0, &rhs.0])?; - let state = self - .runtime - .merge(&states) - .map_err(|e| ErrorCode::UDFRuntimeError(format!("failed to merge: {e}")))?; - place.write(|| state); + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let rhs = UdfAggState::deserialize(&mut data) + .map_err(|e| ErrorCode::Internal(e.to_string()))?; + let states = arrow_select::concat::concat(&[&state.0, &rhs.0])?; + let merged_state = self + .runtime + .merge(&states) + .map_err(|e| ErrorCode::UDFRuntimeError(format!("failed to merge: {e}")))?; + AggrState::new(*place, loc).write(|| merged_state); + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let rhs = UdfAggState::deserialize(&mut data) + .map_err(|e| ErrorCode::Internal(e.to_string()))?; + let states = arrow_select::concat::concat(&[&state.0, &rhs.0])?; + let merged_state = self + .runtime + .merge(&states) + .map_err(|e| ErrorCode::UDFRuntimeError(format!("failed to merge: {e}")))?; + AggrState::new(*place, loc).write(|| merged_state); + } + } Ok(()) } From 3b18b074656ba32a20ddbbfb11b64984ec035d1d Mon Sep 17 00:00:00 2001 From: coldWater Date: Sat, 26 Jul 2025 17:21:13 +0800 Subject: [PATCH 22/22] fix --- .../suites/query/functions/02_0000_function_aggregate_mix.test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_mix.test b/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_mix.test index cf535847b467f..85166f60d5b24 100644 --- a/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_mix.test +++ b/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_mix.test @@ -578,7 +578,7 @@ INSERT INTO convert_string (c1, c2, c3) VALUES ('a', 2, 2), ('b', 1, 2); -query IT +query IT rowsort select c1, string_agg(json_object('col', c2, 'no2', c3), ',') from convert_string group by c1; ---- a {"col":1,"no2":1},{"col":2,"no2":1},{"col":1,"no2":1},{"col":1,"no2":2},{"col":2,"no2":2}