diff --git a/src/common/base/src/runtime/profile/profile.rs b/src/common/base/src/runtime/profile/profile.rs
index 58cfb503be56a..1dc288e7303f0 100644
--- a/src/common/base/src/runtime/profile/profile.rs
+++ b/src/common/base/src/runtime/profile/profile.rs
@@ -23,7 +23,7 @@ use crate::runtime::metrics::ScopedRegistry;
 use crate::runtime::profile::ProfileStatisticsName;
 use crate::runtime::ThreadTracker;
 
-#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
 pub struct ProfileLabel {
     pub name: String,
     pub value: Vec<String>,
diff --git a/src/common/storage/src/copy.rs b/src/common/storage/src/copy.rs
index 927835fdb496a..a3464e33ed4fe 100644
--- a/src/common/storage/src/copy.rs
+++ b/src/common/storage/src/copy.rs
@@ -20,7 +20,7 @@ use serde::Deserialize;
 use serde::Serialize;
 use thiserror::Error;
 
-#[derive(Default, Clone, Serialize, Deserialize)]
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
 pub struct CopyStatus {
     /// Key is file path.
     pub files: DashMap<String, FileStatus>,
@@ -45,7 +45,7 @@ impl CopyStatus {
     }
 }
 
-#[derive(Default, Clone, Serialize, Deserialize)]
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
 pub struct FileStatus {
     pub num_rows_loaded: usize,
     pub error: Option<FileErrorsInfo>,
@@ -79,7 +79,7 @@ impl FileStatus {
     }
 }
 
-#[derive(Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct FileErrorsInfo {
     pub num_errors: usize,
     pub first_error: FileParseErrorAtLine,
@@ -156,7 +156,7 @@ impl FileParseError {
     }
 }
 
-#[derive(Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct FileParseErrorAtLine {
     pub error: FileParseError,
     pub line: usize,
diff --git a/src/common/storage/src/merge.rs b/src/common/storage/src/merge.rs
index 9f4fe94080bc4..50125e8823c90 100644
--- a/src/common/storage/src/merge.rs
+++ b/src/common/storage/src/merge.rs
@@ -15,7 +15,7 @@ use serde::Deserialize;
 use serde::Serialize;
 
-#[derive(Default, Clone, Serialize, Deserialize)]
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
 pub struct MutationStatus {
     pub insert_rows: u64,
     pub deleted_rows: u64,
diff --git a/src/query/catalog/src/plan/agg_index.rs b/src/query/catalog/src/plan/agg_index.rs
index 12efe2937be56..65038a419ca8f 100644
--- a/src/query/catalog/src/plan/agg_index.rs
+++ b/src/query/catalog/src/plan/agg_index.rs
@@ -51,7 +51,7 @@ pub struct AggIndexInfo {
 }
 
 /// This meta just indicate the block is from aggregating index.
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Copy)]
 pub struct AggIndexMeta {
     pub is_agg: bool,
     // Number of aggregation functions.
@@ -75,6 +75,6 @@ local_block_meta_serde!(AggIndexMeta);
 #[typetag::serde(name = "agg_index_meta")]
 impl BlockMetaInfo for AggIndexMeta {
     fn clone_self(&self) -> Box<dyn BlockMetaInfo> {
-        Box::new(self.clone())
+        Box::new(*self)
     }
 }
diff --git a/src/query/catalog/src/statistics/data_cache_statistics.rs b/src/query/catalog/src/statistics/data_cache_statistics.rs
index 04ae435204eb5..6838a93cc212d 100644
--- a/src/query/catalog/src/statistics/data_cache_statistics.rs
+++ b/src/query/catalog/src/statistics/data_cache_statistics.rs
@@ -24,7 +24,7 @@ pub struct DataCacheMetrics {
     bytes_from_memory: AtomicUsize,
 }
 
-#[derive(Default, Clone, Serialize, Deserialize)]
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
 pub struct DataCacheMetricValues {
     pub bytes_from_remote_disk: usize,
     pub bytes_from_local_disk: usize,
diff --git a/src/query/expression/src/aggregate/aggregate_function.rs b/src/query/expression/src/aggregate/aggregate_function.rs
index fe446c3ed1cf4..4f880267ea69e 100755
--- a/src/query/expression/src/aggregate/aggregate_function.rs
+++ b/src/query/expression/src/aggregate/aggregate_function.rs
@@ -22,13 +22,13 @@ use super::AggrState;
 use super::AggrStateLoc;
 use super::AggrStateRegistry;
 use super::StateAddr;
-use crate::types::BinaryType;
 use crate::types::DataType;
 use crate::BlockEntry;
 use crate::ColumnBuilder;
-use crate::ColumnView;
 use crate::ProjectedBlock;
 use crate::Scalar;
+use crate::StateSerdeItem;
+use crate::StateSerdeType;
 
 pub type AggregateFunctionRef = Arc<dyn AggregateFunction>;
 
@@ -69,35 +69,28 @@ pub trait AggregateFunction: fmt::Display + Sync + Send {
     // Used in aggregate_null_adaptor
     fn accumulate_row(&self, place: AggrState, columns: ProjectedBlock, row: usize) -> Result<()>;
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()>;
+    fn serialize_type(&self) -> Vec<StateSerdeItem>;
 
-    fn serialize_size_per_row(&self) -> Option<usize> {
-        None
+    fn serialize_data_type(&self) -> DataType {
+        let serde_type = StateSerdeType::new(self.serialize_type());
+        serde_type.data_type()
     }
 
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()>;
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()>;
 
-    /// Batch merge and deserialize the state from binary array
+    /// Batch deserialize the state and merge
     fn batch_merge(
         &self,
         places: &[StateAddr],
         loc: &[AggrStateLoc],
-        state: &ColumnView<BinaryType>,
-    ) -> Result<()> {
-        for (place, mut data) in places.iter().zip(state.iter()) {
-            self.merge(AggrState::new(*place, loc), &mut data)?;
-        }
-
-        Ok(())
-    }
-
-    fn batch_merge_single(&self, place: AggrState, state: &BlockEntry) -> Result<()> {
-        let view = state.downcast::<BinaryType>().unwrap();
-        for mut data in view.iter() {
-            self.merge(place, &mut data)?;
-        }
-        Ok(())
-    }
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()>;
 
     fn batch_merge_states(
         &self,
@@ -149,9 +142,4 @@ pub trait AggregateFunction: fmt::Display + Sync + Send {
     fn get_if_condition(&self, _columns: ProjectedBlock) -> Option<Bitmap> {
         None
     }
-
-    // some features
-    fn convert_const_to_full(&self) -> bool {
-        true
-    }
 }
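// Reviewer note (not part of the patch): a minimal sketch of how an implementor
// migrates from the removed per-row serialize()/merge() pair to the new columnar
// API above. `MyFunc` and `MyState` are hypothetical; the trait methods, builder
// calls, and AggrState accessors mirror the implementations later in this patch.

impl AggregateFunction for MyFunc {
    fn serialize_type(&self) -> Vec<StateSerdeItem> {
        // One variable-length binary field per state row.
        vec![StateSerdeItem::Binary(None)]
    }

    fn batch_serialize(
        &self,
        places: &[StateAddr],
        loc: &[AggrStateLoc],
        builders: &mut [ColumnBuilder],
    ) -> Result<()> {
        // `builders` matches serialize_type(): a single binary builder here.
        let binary_builder = builders[0].as_binary_mut().unwrap();
        for place in places {
            let state = AggrState::new(*place, loc).get::<MyState>();
            state.serialize(&mut binary_builder.data)?;
            binary_builder.commit_row();
        }
        Ok(())
    }
    // batch_merge is the inverse: downcast the incoming BlockEntry to a typed
    // view and fold each (optionally filtered) row into its state.
}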
diff --git a/src/query/expression/src/aggregate/aggregate_function_state.rs b/src/query/expression/src/aggregate/aggregate_function_state.rs
index 5d1fe21ac4379..dafd4321b0c25 100644
--- a/src/query/expression/src/aggregate/aggregate_function_state.rs
+++ b/src/query/expression/src/aggregate/aggregate_function_state.rs
@@ -20,6 +20,8 @@ use enum_as_inner::EnumAsInner;
 use super::AggregateFunctionRef;
 use crate::types::binary::BinaryColumnBuilder;
+use crate::types::DataType;
+use crate::ColumnBuilder;
 
 #[derive(Clone, Copy, Debug)]
 pub struct StateAddr {
@@ -113,11 +115,11 @@ impl From<StateAddr> for usize {
 pub fn get_states_layout(funcs: &[AggregateFunctionRef]) -> Result<StatesLayout> {
     let mut registry = AggrStateRegistry::default();
-    let mut serialize_size = Vec::with_capacity(funcs.len());
+    let mut serialize_type = Vec::with_capacity(funcs.len());
     for func in funcs {
         func.register_state(&mut registry);
         registry.commit();
-        serialize_size.push(func.serialize_size_per_row());
+        serialize_type.push(StateSerdeType(func.serialize_type().into()));
     }
 
     let AggrStateRegistry { states, offsets } = registry;
@@ -132,7 +134,7 @@ pub fn get_states_layout(funcs: &[AggregateFunctionRef]) -> Result<StatesLayout> {
     Ok(StatesLayout {
         layout,
         states_loc,
-        serialize_size,
+        serialize_type,
     })
 }
@@ -191,18 +193,66 @@ impl AggrStateLoc {
     }
 }
 
+#[derive(Debug, Clone)]
+pub enum StateSerdeItem {
+    DataType(DataType),
+    Binary(Option<usize>),
+}
+
+#[derive(Debug, Clone)]
+pub struct StateSerdeType(Box<[StateSerdeItem]>);
+
+impl StateSerdeType {
+    pub fn new(items: impl Into<Box<[StateSerdeItem]>>) -> Self {
+        StateSerdeType(items.into())
+    }
+
+    pub fn data_type(&self) -> DataType {
+        DataType::Tuple(
+            self.0
+                .iter()
+                .map(|item| match item {
+                    StateSerdeItem::DataType(data_type) => data_type.clone(),
+                    StateSerdeItem::Binary(_) => DataType::Binary,
+                })
+                .collect(),
+        )
+    }
+}
+
 #[derive(Debug, Clone)]
 pub struct StatesLayout {
     pub layout: Layout,
     pub states_loc: Vec<Box<[AggrStateLoc]>>,
-    serialize_size: Vec<Option<usize>>,
+    pub(super) serialize_type: Vec<StateSerdeType>,
 }
 
 impl StatesLayout {
-    pub fn serialize_builders(&self, num_rows: usize) -> Vec<BinaryColumnBuilder> {
-        self.serialize_size
+    pub fn num_aggr_func(&self) -> usize {
+        self.states_loc.len()
+    }
+
+    pub fn serialize_builders(&self, num_rows: usize) -> Vec<ColumnBuilder> {
+        self.serialize_type
             .iter()
-            .map(|size| BinaryColumnBuilder::with_capacity(num_rows, num_rows * size.unwrap_or(0)))
+            .map(|serde_type| {
+                let builder = serde_type
+                    .0
+                    .iter()
+                    .map(|item| match item {
+                        StateSerdeItem::DataType(data_type) => {
+                            ColumnBuilder::with_capacity(data_type, num_rows)
+                        }
+                        StateSerdeItem::Binary(size) => {
+                            ColumnBuilder::Binary(BinaryColumnBuilder::with_capacity(
+                                num_rows,
+                                num_rows * size.unwrap_or(0),
+                            ))
+                        }
+                    })
+                    .collect();
+                ColumnBuilder::Tuple(builder)
+            })
            .collect()
    }
}
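// Reviewer note (not part of the patch): with this change each aggregate
// function's serialized state travels as one tuple-typed column instead of a
// plain binary column. A hedged sketch of the mapping, assuming a function that
// declares one fixed-size binary field plus a boolean flag:

let serde_type = StateSerdeType::new(vec![
    StateSerdeItem::Binary(Some(8)),             // nested state, ~8 bytes per row
    StateSerdeItem::DataType(DataType::Boolean), // validity flag
]);
// The spill/exchange schema for this function is then:
assert_eq!(
    serde_type.data_type(),
    DataType::Tuple(vec![DataType::Binary, DataType::Boolean])
);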
diff --git a/src/query/expression/src/aggregate/aggregate_hashtable.rs b/src/query/expression/src/aggregate/aggregate_hashtable.rs
index 11907059ef90e..2956ae9026d6d 100644
--- a/src/query/expression/src/aggregate/aggregate_hashtable.rs
+++ b/src/query/expression/src/aggregate/aggregate_hashtable.rs
@@ -27,7 +27,6 @@ use crate::aggregate::payload_row::row_match_columns;
 use crate::group_hash_columns;
 use crate::new_sel;
 use crate::read;
-use crate::types::BinaryType;
 use crate::types::DataType;
 use crate::AggregateFunctionRef;
 use crate::BlockEntry;
@@ -219,7 +218,7 @@ impl AggregateHashTable {
                 .zip(agg_states.iter())
                 .zip(states_layout.states_loc.iter())
             {
-                func.batch_merge(state_places, loc, &state.downcast::<BinaryType>().unwrap())?;
+                func.batch_merge(state_places, loc, state, None)?;
             }
         }
     }
diff --git a/src/query/expression/src/aggregate/payload.rs b/src/query/expression/src/aggregate/payload.rs
index aafe2921c8532..925293beb214c 100644
--- a/src/query/expression/src/aggregate/payload.rs
+++ b/src/query/expression/src/aggregate/payload.rs
@@ -421,11 +421,15 @@ impl Payload {
         true
     }
 
-    pub fn empty_block(&self, fake_rows: Option<usize>) -> DataBlock {
-        let fake_rows = fake_rows.unwrap_or(0);
-        let entries = (0..self.aggrs.len())
-            .map(|_| {
-                ColumnBuilder::repeat_default(&DataType::Binary, fake_rows)
+    pub fn empty_block(&self, fake_rows: usize) -> DataBlock {
+        assert_eq!(self.aggrs.is_empty(), self.states_layout.is_none());
+        let entries = self
+            .states_layout
+            .as_ref()
+            .iter()
+            .flat_map(|layout| layout.serialize_type.iter())
+            .map(|serde_type| {
+                ColumnBuilder::repeat_default(&serde_type.data_type(), fake_rows)
                     .build()
                     .into()
             })
diff --git a/src/query/expression/src/aggregate/payload_flush.rs b/src/query/expression/src/aggregate/payload_flush.rs
index 43dabd3b519c2..317737fca0960 100644
--- a/src/query/expression/src/aggregate/payload_flush.rs
+++ b/src/query/expression/src/aggregate/payload_flush.rs
@@ -18,7 +18,6 @@ use databend_common_io::prelude::bincode_deserialize_from_slice;
 use super::partitioned_payload::PartitionedPayload;
 use super::payload::Payload;
 use super::probe_state::ProbeState;
-use super::AggrState;
 use crate::read;
 use crate::types::binary::BinaryColumn;
 use crate::types::binary::BinaryColumnBuilder;
@@ -125,7 +124,7 @@ impl Payload {
         }
 
         if blocks.is_empty() {
-            return Ok(self.empty_block(None));
+            return Ok(self.empty_block(0));
         }
 
         DataBlock::concat(&blocks)
@@ -141,26 +140,17 @@ impl Payload {
 
         if let Some(state_layout) = self.states_layout.as_ref() {
             let mut builders = state_layout.serialize_builders(row_count);
 
-            for place in state.state_places.as_slice()[0..row_count].iter() {
-                for (idx, (loc, func)) in state_layout
-                    .states_loc
-                    .iter()
-                    .zip(self.aggrs.iter())
-                    .enumerate()
-                {
-                    {
-                        let builder = &mut builders[idx];
-                        func.serialize(AggrState::new(*place, loc), &mut builder.data)?;
-                        builder.commit_row();
-                    }
-                }
+            for ((loc, func), builder) in state_layout
+                .states_loc
+                .iter()
+                .zip(self.aggrs.iter())
+                .zip(builders.iter_mut())
+            {
+                let builders = builder.as_tuple_mut().unwrap().as_mut_slice();
+                func.batch_serialize(&state.state_places.as_slice()[0..row_count], loc, builders)?;
             }
 
-            entries.extend(
-                builders
-                    .into_iter()
-                    .map(|builder| Column::Binary(builder.build()).into()),
-            );
+            entries.extend(builders.into_iter().map(|builder| builder.build().into()));
         }
 
         entries.extend_from_slice(&state.take_group_columns());
@@ -177,7 +167,7 @@ impl Payload {
         }
 
         if blocks.is_empty() {
-            return Ok(self.empty_block(None));
+            return Ok(self.empty_block(0));
         }
 
         DataBlock::concat(&blocks)
diff --git a/src/query/expression/src/types.rs b/src/query/expression/src/types.rs
index 6e19868368e3e..ff414b39ceabc 100755
--- a/src/query/expression/src/types.rs
+++ b/src/query/expression/src/types.rs
@@ -34,6 +34,7 @@ pub mod number_class;
 pub mod simple_type;
 pub mod string;
 pub mod timestamp;
+pub mod tuple;
 pub mod variant;
 pub mod vector;
 pub mod zero_size_type;
@@ -82,6 +83,7 @@ use self::simple_type::*;
 pub use self::string::StringColumn;
 pub use self::string::StringType;
 pub use self::timestamp::TimestampType;
+pub use self::tuple::*;
 pub use self::variant::VariantType;
 pub use self::vector::VectorColumn;
 pub use self::vector::VectorColumnBuilder;
diff --git a/src/query/expression/src/types/tuple.rs b/src/query/expression/src/types/tuple.rs
new file mode 100644
index 0000000000000..8100da50504f6
--- /dev/null
+++ b/src/query/expression/src/types/tuple.rs
@@ -0,0 +1,420 @@
+// Copyright 2021 Datafuse Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::cmp::Ordering;
+use std::iter::Map;
+use std::iter::Zip;
+use std::marker::PhantomData;
+
+use super::AccessType;
+use super::AnyType;
+use crate::Column;
+use crate::Domain;
+use crate::ScalarRef;
+
+type Iter<'a, T> = <T as AccessType>::ColumnIterator<'a>;
+
+type SR<'a, T> = <T as AccessType>::ScalarRef<'a>;
+
+pub type AnyUnaryType = UnaryType<AnyType>;
+pub type AnyPairType = PairType<AnyType, AnyType>;
+pub type AnyTernaryType = TernaryType<AnyType, AnyType, AnyType>;
+pub type AnyQuaternaryType = QuaternaryType<AnyType, AnyType, AnyType, AnyType>;
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct UnaryType<A>(PhantomData<A>);
+
+impl<A> AccessType for UnaryType<A>
+where A: AccessType
+{
+    type Scalar = A::Scalar;
+    type ScalarRef<'a> = A::ScalarRef<'a>;
+    type Column = A::Column;
+    type Domain = A::Domain;
+    type ColumnIterator<'a> = A::ColumnIterator<'a>;
+
+    fn to_owned_scalar(scalar: Self::ScalarRef<'_>) -> Self::Scalar {
+        A::to_owned_scalar(scalar)
+    }
+
+    fn to_scalar_ref(scalar: &Self::Scalar) -> Self::ScalarRef<'_> {
+        A::to_scalar_ref(scalar)
+    }
+
+    fn try_downcast_scalar<'a>(scalar: &ScalarRef<'a>) -> Option<Self::ScalarRef<'a>> {
+        let [a] = scalar.as_tuple()?.as_slice() else {
+            return None;
+        };
+        A::try_downcast_scalar(a)
+    }
+
+    fn try_downcast_domain(domain: &Domain) -> Option<Self::Domain> {
+        let [a] = domain.as_tuple()?.as_slice() else {
+            return None;
+        };
+        A::try_downcast_domain(a)
+    }
+
+    fn try_downcast_column(col: &Column) -> Option<Self::Column> {
+        let [a] = col.as_tuple()?.as_slice() else {
+            return None;
+        };
+        A::try_downcast_column(a)
+    }
+
+    fn column_len(col: &Self::Column) -> usize {
+        A::column_len(col)
+    }
+
+    fn index_column(col: &Self::Column, index: usize) -> Option<Self::ScalarRef<'_>> {
+        A::index_column(col, index)
+    }
+
+    unsafe fn index_column_unchecked(col: &Self::Column, index: usize) -> Self::ScalarRef<'_> {
+        A::index_column_unchecked(col, index)
+    }
+
+    fn slice_column(col: &Self::Column, range: std::ops::Range<usize>) -> Self::Column {
+        A::slice_column(col, range)
+    }
+
+    fn iter_column(col: &Self::Column) -> Self::ColumnIterator<'_> {
+        A::iter_column(col)
+    }
+
+    fn compare(lhs: Self::ScalarRef<'_>, rhs: Self::ScalarRef<'_>) -> Ordering {
+        A::compare(lhs, rhs)
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct PairType<A, B>(PhantomData<(A, B)>);
+
+impl<A, B> AccessType for PairType<A, B>
+where
+    A: AccessType,
+    B: AccessType,
+{
+    type Scalar = (A::Scalar, B::Scalar);
+    type ScalarRef<'a> = (A::ScalarRef<'a>, B::ScalarRef<'a>);
+    type Column = (A::Column, B::Column);
+    type Domain = (A::Domain, B::Domain);
+    type ColumnIterator<'a> = Zip<Iter<'a, A>, Iter<'a, B>>;
+
+    fn to_owned_scalar((a, b): Self::ScalarRef<'_>) -> Self::Scalar {
+        (A::to_owned_scalar(a), B::to_owned_scalar(b))
+    }
+
+    fn to_scalar_ref((a, b): &Self::Scalar) -> Self::ScalarRef<'_> {
+        (A::to_scalar_ref(a), B::to_scalar_ref(b))
+    }
+
+    fn try_downcast_scalar<'a>(scalar: &ScalarRef<'a>) -> Option<Self::ScalarRef<'a>> {
+        let [a, b] = scalar.as_tuple()?.as_slice() else {
+            return None;
+        };
+        Some((A::try_downcast_scalar(a)?, B::try_downcast_scalar(b)?))
+    }
+
+    fn try_downcast_domain(domain: &Domain) -> Option<Self::Domain> {
+        let [a, b] = domain.as_tuple()?.as_slice() else {
+            return None;
+        };
+        Some((A::try_downcast_domain(a)?, B::try_downcast_domain(b)?))
+    }
+
+    fn try_downcast_column(col: &Column) -> Option<Self::Column> {
+        let [a, b] = col.as_tuple()?.as_slice() else {
+            return None;
+        };
+        Some((A::try_downcast_column(a)?, B::try_downcast_column(b)?))
+    }
+
+    fn column_len((a, _b): &Self::Column) -> usize {
+        debug_assert_eq!(A::column_len(a), B::column_len(_b));
+        A::column_len(a)
+    }
+
+    fn index_column((a, b): &Self::Column, index: usize) -> Option<Self::ScalarRef<'_>> {
+        Some((A::index_column(a, index)?, B::index_column(b, index)?))
+    }
+
+    unsafe fn index_column_unchecked((a, b): &Self::Column, index: usize) -> Self::ScalarRef<'_> {
+        (
+            A::index_column_unchecked(a, index),
+            B::index_column_unchecked(b, index),
+        )
+    }
+
+    fn slice_column((a, b): &Self::Column, range: std::ops::Range<usize>) -> Self::Column {
+        (A::slice_column(a, range.clone()), B::slice_column(b, range))
+    }
+
+    fn iter_column((a, b): &Self::Column) -> Self::ColumnIterator<'_> {
+        A::iter_column(a).zip(B::iter_column(b))
+    }
+
+    fn compare(
+        (lhs_a, lhs_b): Self::ScalarRef<'_>,
+        (rhs_a, rhs_b): Self::ScalarRef<'_>,
+    ) -> Ordering {
+        A::compare(lhs_a, rhs_a).then_with(|| B::compare(lhs_b, rhs_b))
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct TernaryType<A, B, C>(PhantomData<(A, B, C)>);
+
+impl<A, B, C> AccessType for TernaryType<A, B, C>
+where
+    A: AccessType,
+    B: AccessType,
+    C: AccessType,
+{
+    type Scalar = (A::Scalar, B::Scalar, C::Scalar);
+    type ScalarRef<'a> = (A::ScalarRef<'a>, B::ScalarRef<'a>, C::ScalarRef<'a>);
+    type Column = (A::Column, B::Column, C::Column);
+    type Domain = (A::Domain, B::Domain, C::Domain);
+    type ColumnIterator<'a> = Map<
+        Zip<Zip<Iter<'a, A>, Iter<'a, B>>, Iter<'a, C>>,
+        fn(((SR<'a, A>, SR<'a, B>), SR<'a, C>)) -> Self::ScalarRef<'a>,
+    >;
+
+    fn to_owned_scalar((a, b, c): Self::ScalarRef<'_>) -> Self::Scalar {
+        (
+            A::to_owned_scalar(a),
+            B::to_owned_scalar(b),
+            C::to_owned_scalar(c),
+        )
+    }
+
+    fn to_scalar_ref((a, b, c): &Self::Scalar) -> Self::ScalarRef<'_> {
+        (
+            A::to_scalar_ref(a),
+            B::to_scalar_ref(b),
+            C::to_scalar_ref(c),
+        )
+    }
+
+    fn try_downcast_scalar<'a>(scalar: &ScalarRef<'a>) -> Option<Self::ScalarRef<'a>> {
+        let [a, b, c] = scalar.as_tuple()?.as_slice() else {
+            return None;
+        };
+        Some((
+            A::try_downcast_scalar(a)?,
+            B::try_downcast_scalar(b)?,
+            C::try_downcast_scalar(c)?,
+        ))
+    }
+
+    fn try_downcast_domain(domain: &Domain) -> Option<Self::Domain> {
+        let [a, b, c] = domain.as_tuple()?.as_slice() else {
+            return None;
+        };
+        Some((
+            A::try_downcast_domain(a)?,
+            B::try_downcast_domain(b)?,
+            C::try_downcast_domain(c)?,
+        ))
+    }
+
+    fn try_downcast_column(col: &Column) -> Option<Self::Column> {
+        let [a, b, c] = col.as_tuple()?.as_slice() else {
+            return None;
+        };
+        Some((
+            A::try_downcast_column(a)?,
+            B::try_downcast_column(b)?,
+            C::try_downcast_column(c)?,
+        ))
+    }
+
+    fn column_len((a, _b, _c): &Self::Column) -> usize {
+        debug_assert_eq!(A::column_len(a), B::column_len(_b));
+        debug_assert_eq!(A::column_len(a), C::column_len(_c));
+        A::column_len(a)
+    }
+
+    fn index_column((a, b, c): &Self::Column, index: usize) -> Option<Self::ScalarRef<'_>> {
+        Some((
+            A::index_column(a, index)?,
+            B::index_column(b, index)?,
+            C::index_column(c, index)?,
+        ))
+    }
+
+    unsafe fn index_column_unchecked(
+        (a, b, c): &Self::Column,
+        index: usize,
+    ) -> Self::ScalarRef<'_> {
+        (
+            A::index_column_unchecked(a, index),
+            B::index_column_unchecked(b, index),
+            C::index_column_unchecked(c, index),
+        )
+    }
+
+    fn slice_column((a, b, c): &Self::Column, range: std::ops::Range<usize>) -> Self::Column {
+        (
+            A::slice_column(a, range.clone()),
+            B::slice_column(b, range.clone()),
+            C::slice_column(c, range),
+        )
+    }
+
+    fn iter_column((a, b, c): &Self::Column) -> Self::ColumnIterator<'_> {
+        A::iter_column(a)
+            .zip(B::iter_column(b))
+            .zip(C::iter_column(c))
+            .map(|((a, b), c)| (a, b, c))
+    }
+
+    fn compare(
+        (lhs_a, lhs_b, lhs_c): Self::ScalarRef<'_>,
+        (rhs_a, rhs_b, rhs_c): Self::ScalarRef<'_>,
+    ) -> Ordering {
+        A::compare(lhs_a, rhs_a)
+            .then_with(|| B::compare(lhs_b, rhs_b))
+            .then_with(|| C::compare(lhs_c, rhs_c))
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct QuaternaryType<A, B, C, D>(PhantomData<(A, B, C, D)>);
+
+impl<A, B, C, D> AccessType for QuaternaryType<A, B, C, D>
+where
+    A: AccessType,
+    B: AccessType,
+    C: AccessType,
+    D: AccessType,
+{
+    type Scalar = (A::Scalar, B::Scalar, C::Scalar, D::Scalar);
+    type ScalarRef<'a> = (SR<'a, A>, SR<'a, B>, SR<'a, C>, SR<'a, D>);
+    type Column = (A::Column, B::Column, C::Column, D::Column);
+    type Domain = (A::Domain, B::Domain, C::Domain, D::Domain);
+    type ColumnIterator<'a> = Map<
+        Zip<Zip<Zip<Iter<'a, A>, Iter<'a, B>>, Iter<'a, C>>, Iter<'a, D>>,
+        fn((((SR<'a, A>, SR<'a, B>), SR<'a, C>), SR<'a, D>)) -> Self::ScalarRef<'a>,
+    >;
+
+    fn to_owned_scalar((a, b, c, d): Self::ScalarRef<'_>) -> Self::Scalar {
+        (
+            A::to_owned_scalar(a),
+            B::to_owned_scalar(b),
+            C::to_owned_scalar(c),
+            D::to_owned_scalar(d),
+        )
+    }
+
+    fn to_scalar_ref((a, b, c, d): &Self::Scalar) -> Self::ScalarRef<'_> {
+        (
+            A::to_scalar_ref(a),
+            B::to_scalar_ref(b),
+            C::to_scalar_ref(c),
+            D::to_scalar_ref(d),
+        )
+    }
+
+    fn try_downcast_scalar<'a>(scalar: &ScalarRef<'a>) -> Option<Self::ScalarRef<'a>> {
+        let [a, b, c, d] = scalar.as_tuple()?.as_slice() else {
+            return None;
+        };
+        Some((
+            A::try_downcast_scalar(a)?,
+            B::try_downcast_scalar(b)?,
+            C::try_downcast_scalar(c)?,
+            D::try_downcast_scalar(d)?,
+        ))
+    }
+
+    fn try_downcast_domain(domain: &Domain) -> Option<Self::Domain> {
+        let [a, b, c, d] = domain.as_tuple()?.as_slice() else {
+            return None;
+        };
+        Some((
+            A::try_downcast_domain(a)?,
+            B::try_downcast_domain(b)?,
+            C::try_downcast_domain(c)?,
+            D::try_downcast_domain(d)?,
+        ))
+    }
+
+    fn try_downcast_column(col: &Column) -> Option<Self::Column> {
+        let [a, b, c, d] = col.as_tuple()?.as_slice() else {
+            return None;
+        };
+        Some((
+            A::try_downcast_column(a)?,
+            B::try_downcast_column(b)?,
+            C::try_downcast_column(c)?,
+            D::try_downcast_column(d)?,
+        ))
+    }
+
+    fn column_len((a, _b, _c, _d): &Self::Column) -> usize {
+        debug_assert_eq!(A::column_len(a), B::column_len(_b));
+        debug_assert_eq!(A::column_len(a), C::column_len(_c));
+        debug_assert_eq!(A::column_len(a), D::column_len(_d));
+        A::column_len(a)
+    }
+
+    fn index_column((a, b, c, d): &Self::Column, index: usize) -> Option<Self::ScalarRef<'_>> {
+        Some((
+            A::index_column(a, index)?,
+            B::index_column(b, index)?,
+            C::index_column(c, index)?,
+            D::index_column(d, index)?,
+        ))
+    }
+
+    unsafe fn index_column_unchecked(
+        (a, b, c, d): &Self::Column,
+        index: usize,
+    ) -> Self::ScalarRef<'_> {
+        (
+            A::index_column_unchecked(a, index),
+            B::index_column_unchecked(b, index),
+            C::index_column_unchecked(c, index),
+            D::index_column_unchecked(d, index),
+        )
+    }
+
+    fn slice_column((a, b, c, d): &Self::Column, range: std::ops::Range<usize>) -> Self::Column {
+        (
+            A::slice_column(a, range.clone()),
+            B::slice_column(b, range.clone()),
+            C::slice_column(c, range.clone()),
+            D::slice_column(d, range),
+        )
+    }
+
+    fn iter_column((a, b, c, d): &Self::Column) -> Self::ColumnIterator<'_> {
+        A::iter_column(a)
+            .zip(B::iter_column(b))
+            .zip(C::iter_column(c))
+            .zip(D::iter_column(d))
+            .map(|(((a, b), c), d)| (a, b, c, d))
+    }
+
+    fn compare(
+        (lhs_a, lhs_b, lhs_c, lhs_d): Self::ScalarRef<'_>,
+        (rhs_a, rhs_b, rhs_c, rhs_d): Self::ScalarRef<'_>,
+    ) -> Ordering {
+        A::compare(lhs_a, rhs_a)
+            .then_with(|| B::compare(lhs_b, rhs_b))
+            .then_with(|| C::compare(lhs_c, rhs_c))
+            .then_with(|| D::compare(lhs_d, rhs_d))
+    }
+}
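// Reviewer note (not part of the patch): these N-ary wrappers give typed,
// allocation-free access to the fields of a tuple column. A hedged sketch,
// assuming a two-field state column of (Binary, Boolean) like the adaptors
// below produce, and the BlockEntry::downcast() view used elsewhere in this
// patch:

fn iterate_state(entry: &BlockEntry) {
    let view = entry
        .downcast::<PairType<BinaryType, BooleanType>>()
        .unwrap();
    for (data, flag) in view.iter() {
        // `data` is &[u8], `flag` is bool; no per-row boxing or cloning.
        let _ = (data, flag);
    }
}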
diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs
index 18af505c1241a..a800a0d95413c 100644
--- a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs
+++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs
@@ -20,10 +20,12 @@ use databend_common_expression::types::Bitmap;
 use databend_common_expression::types::DataType;
 use databend_common_expression::types::NumberDataType;
 use databend_common_expression::utils::column_merge_validity;
+use databend_common_expression::BlockEntry;
+use databend_common_expression::Column;
 use databend_common_expression::ColumnBuilder;
 use databend_common_expression::ProjectedBlock;
 use databend_common_expression::Scalar;
-use databend_common_io::prelude::BinaryWrite;
+use databend_common_expression::StateSerdeItem;
 
 use super::AggrState;
 use super::AggrStateLoc;
@@ -130,10 +132,6 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction for AggregateNullUnaryAdaptor<NULLABLE_RESULT> {
         self.0.init_state(place);
     }
 
-    fn serialize_size_per_row(&self) -> Option<usize> {
-        self.0.serialize_size_per_row()
-    }
-
     fn register_state(&self, registry: &mut AggrStateRegistry) {
         self.0.register_state(registry);
     }
@@ -183,12 +181,27 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction for AggregateNullUnaryAdaptor<NULLABLE_RESULT> {
             .accumulate_row(place, not_null_columns, validity, row)
     }
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        self.0.serialize(place, writer)
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
+        self.0.serialize_type()
+    }
+
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
+        self.0.batch_serialize(places, loc, builders)
     }
 
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        self.0.merge(place, reader)
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        self.0.batch_merge(places, loc, state, filter)
     }
 
     fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> {
@@ -207,10 +220,6 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction for AggregateNullUnaryAdaptor<NULLABLE_RESULT> {
         self.0.drop_state(place);
     }
 
-    fn convert_const_to_full(&self) -> bool {
-        self.0.nested.convert_const_to_full()
-    }
-
     fn get_if_condition(&self, columns: ProjectedBlock) -> Option<Bitmap> {
         self.0.nested.get_if_condition(columns)
     }
@@ -233,6 +242,20 @@ impl<const NULLABLE_RESULT: bool> AggregateNullVariadicAdaptor<NULLABLE_RESULT> {
     }
 }
 
+impl<const NULLABLE_RESULT: bool> AggregateNullVariadicAdaptor<NULLABLE_RESULT> {
+    fn merge_validity(
+        columns: ProjectedBlock,
+        mut validity: Option<Bitmap>,
+    ) -> (Vec<BlockEntry>, Option<Bitmap>) {
+        let mut not_null_columns = Vec::with_capacity(columns.len());
+        for entry in columns.iter() {
+            validity = column_merge_validity(&entry.clone(), validity);
+            not_null_columns.push(entry.clone().remove_nullable());
+        }
+        (not_null_columns, validity)
+    }
+}
+
 impl<const NULLABLE_RESULT: bool> AggregateFunction
     for AggregateNullVariadicAdaptor<NULLABLE_RESULT>
 {
@@ -248,10 +271,6 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction
         self.0.init_state(place);
     }
 
-    fn serialize_size_per_row(&self) -> Option<usize> {
-        self.0.serialize_size_per_row()
-    }
-
     fn register_state(&self, registry: &mut AggrStateRegistry) {
         self.0.register_state(registry);
     }
@@ -264,14 +283,8 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction
         validity: Option<&Bitmap>,
         input_rows: usize,
     ) -> Result<()> {
-        let mut not_null_columns = Vec::with_capacity(columns.len());
-        let mut validity = validity.cloned();
-        for entry in columns.iter() {
-            validity = column_merge_validity(&entry.clone(), validity);
-            not_null_columns.push(entry.clone().remove_nullable());
-        }
+        let (not_null_columns, validity) = Self::merge_validity(columns, validity.cloned());
         let not_null_columns = (&not_null_columns).into();
-
         self.0
             .accumulate(place, not_null_columns, validity, input_rows)
     }
@@ -283,37 +296,40 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction
         columns: ProjectedBlock,
         input_rows: usize,
     ) -> Result<()> {
-        let mut not_null_columns = Vec::with_capacity(columns.len());
-        let mut validity = None;
-        for entry in columns.iter() {
-            validity = column_merge_validity(&entry.clone(), validity);
-            not_null_columns.push(entry.clone().remove_nullable());
-        }
+        let (not_null_columns, validity) = Self::merge_validity(columns, None);
         let not_null_columns = (&not_null_columns).into();
-
         self.0
             .accumulate_keys(addrs, loc, not_null_columns, validity, input_rows)
     }
 
     fn accumulate_row(&self, place: AggrState, columns: ProjectedBlock, row: usize) -> Result<()> {
-        let mut not_null_columns = Vec::with_capacity(columns.len());
-        let mut validity = None;
-        for entry in columns.iter() {
-            validity = column_merge_validity(&entry.clone(), validity);
-            not_null_columns.push(entry.clone().remove_nullable());
-        }
+        let (not_null_columns, validity) = Self::merge_validity(columns, None);
         let not_null_columns = (&not_null_columns).into();
-
         self.0
             .accumulate_row(place, not_null_columns, validity, row)
     }
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        self.0.serialize(place, writer)
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
+        self.0.serialize_type()
+    }
+
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
+        self.0.batch_serialize(places, loc, builders)
    }
 
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        self.0.merge(place, reader)
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        self.0.batch_merge(places, loc, state, filter)
     }
 
     fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> {
@@ -332,10 +348,6 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction
         self.0.drop_state(place);
     }
 
-    fn convert_const_to_full(&self) -> bool {
-        self.0.nested.convert_const_to_full()
-    }
-
     fn get_if_condition(&self, columns: ProjectedBlock) -> Option<Bitmap> {
         self.0.nested.get_if_condition(columns)
     }
@@ -371,12 +383,6 @@ impl<const NULLABLE_RESULT: bool> CommonNullAdaptor<NULLABLE_RESULT> {
         self.nested.init_state(place.remove_last_loc());
     }
 
-    fn serialize_size_per_row(&self) -> Option<usize> {
-        self.nested
-            .serialize_size_per_row()
-            .map(|row| if NULLABLE_RESULT { row + 1 } else { row })
-    }
-
     fn register_state(&self, registry: &mut AggrStateRegistry) {
         self.nested.register_state(registry);
         if NULLABLE_RESULT {
@@ -496,33 +502,109 @@ impl<const NULLABLE_RESULT: bool> CommonNullAdaptor<NULLABLE_RESULT> {
             .accumulate_row(place.remove_last_loc(), not_null_columns, row)
     }
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
         if !NULLABLE_RESULT {
-            return self.nested.serialize(place, writer);
+            return self.nested.serialize_type();
         }
-
-        self.nested.serialize(place.remove_last_loc(), writer)?;
-        let flag = get_flag(place);
-        writer.write_scalar(&flag)
+        self.nested
+            .serialize_type()
+            .into_iter()
+            .chain(Some(StateSerdeItem::DataType(DataType::Boolean)))
+            .collect()
     }
 
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
         if !NULLABLE_RESULT {
-            return self.nested.merge(place, reader);
+            return self.nested.batch_serialize(places, loc, builders);
         }
+        let n = builders.len();
+        debug_assert_eq!(self.nested.serialize_type().len() + 1, n);
+        let flag_builder = builders
+            .last_mut()
+            .and_then(ColumnBuilder::as_boolean_mut)
+            .unwrap();
+        for place in places {
+            let place = AggrState::new(*place, loc);
+            flag_builder.push(get_flag(place));
+        }
+        self.nested
+            .batch_serialize(places, &loc[..loc.len() - 1], &mut builders[..(n - 1)])
+    }
 
-        let flag = reader[reader.len() - 1];
-        if flag == 0 {
-            return Ok(());
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        if !NULLABLE_RESULT {
+            return self.nested.batch_merge(places, loc, state, filter);
         }
 
-        if !get_flag(place) {
-            // initial the state to remove the dirty stats
-            self.init_state(place);
+        match state {
+            BlockEntry::Column(Column::Tuple(tuple)) => {
+                let nested_state = Column::Tuple(tuple[0..tuple.len() - 1].to_vec());
+                let flag = tuple.last().unwrap().as_boolean().unwrap();
+                let flag = match filter {
+                    Some(filter) => filter & flag,
+                    None => flag.clone(),
+                };
+                let filter = if flag.null_count() == 0 {
+                    for place in places.iter() {
+                        self.update_flag(AggrState::new(*place, loc));
+                    }
+                    None
+                } else {
+                    for place in places
+                        .iter()
+                        .zip(flag.iter())
+                        .filter_map(|(place, flag)| flag.then_some(place))
+                    {
+                        self.update_flag(AggrState::new(*place, loc));
+                    }
+                    Some(&flag)
+                };
+                self.nested
+                    .batch_merge(places, &loc[..loc.len() - 1], &nested_state.into(), filter)
+            }
+            BlockEntry::Const(Scalar::Tuple(tuple), DataType::Tuple(data_type), num_rows) => {
+                let flag = *tuple.last().and_then(Scalar::as_boolean).unwrap();
+                let flag = Bitmap::new_constant(flag, *num_rows);
+                let flag = match filter {
+                    Some(filter) => filter & &flag,
+                    None => flag,
+                };
+                let filter = if flag.null_count() == 0 {
+                    for place in places.iter() {
+                        self.update_flag(AggrState::new(*place, loc));
+                    }
+                    None
+                } else {
+                    for place in places
+                        .iter()
+                        .zip(flag.iter())
+                        .filter_map(|(place, flag)| flag.then_some(place))
+                    {
+                        self.update_flag(AggrState::new(*place, loc));
+                    }
+                    Some(&flag)
+                };
+                let nested_state = BlockEntry::new_const_column(
+                    DataType::Tuple(data_type[0..data_type.len() - 1].to_vec()),
+                    Scalar::Tuple(tuple[0..tuple.len() - 1].to_vec()),
+                    *num_rows,
+                );
+                self.nested
+                    .batch_merge(places, &loc[..loc.len() - 1], &nested_state, filter)
+            }
+            _ => unreachable!(),
         }
-        set_flag(place, true);
-        self.nested
-            .merge(place.remove_last_loc(), &mut &reader[..reader.len() - 1])
     }
 
     fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> {
@@ -569,6 +651,14 @@ impl<const NULLABLE_RESULT: bool> CommonNullAdaptor<NULLABLE_RESULT> {
             self.nested.drop_state(place.remove_last_loc())
         }
     }
+
+    fn update_flag(&self, place: AggrState) {
+        if !get_flag(place) {
+            // initial the state to remove the dirty stats
+            self.init_state(place);
+        }
+        set_flag(place, true);
+    }
 }
 
 fn set_flag(place: AggrState, flag: bool) {
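// Reviewer note (not part of the patch): a hedged example of the new filter
// semantics in batch_merge. The trailing Boolean field carries the per-row
// "had input" flag; it is AND-ed with any caller-supplied filter, so only rows
// whose flag is set ever touch their nested states:

let flag = Bitmap::from_iter([true, false, true]); // flags decoded from the state column
let filter: Option<&Bitmap> = None;                // no extra filter from the caller
let effective = match filter {
    Some(f) => f & &flag,
    None => flag.clone(),
};
assert_eq!(effective.null_count(), 1); // row 1 is skipped during the merge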
diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs
index c7f9d4c4677d4..318cd88e0c522 100644
--- a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs
+++ b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs
@@ -20,10 +20,12 @@ use databend_common_expression::types::Bitmap;
 use databend_common_expression::types::DataType;
 use databend_common_expression::AggrStateRegistry;
 use databend_common_expression::AggrStateType;
+use databend_common_expression::BlockEntry;
+use databend_common_expression::Column;
 use databend_common_expression::ColumnBuilder;
 use databend_common_expression::ProjectedBlock;
 use databend_common_expression::Scalar;
-use databend_common_io::prelude::BinaryWrite;
+use databend_common_expression::StateSerdeItem;
 
 use super::AggrState;
 use super::AggrStateLoc;
@@ -37,23 +39,23 @@ use super::StateAddr;
 /// Use a single additional byte of data after the nested function data:
 /// 0 means there was no input, 1 means there was some.
 pub struct AggregateFunctionOrNullAdaptor {
-    inner: AggregateFunctionRef,
+    nested: AggregateFunctionRef,
     inner_nullable: bool,
 }
 
 impl AggregateFunctionOrNullAdaptor {
     pub fn create(
-        inner: AggregateFunctionRef,
+        nested: AggregateFunctionRef,
         features: AggregateFunctionFeatures,
     ) -> Result<AggregateFunctionRef> {
         // count/count distinct should not be nullable for empty set, just return zero
-        let inner_return_type = inner.return_type()?;
+        let inner_return_type = nested.return_type()?;
         if features.returns_default_when_only_null || inner_return_type == DataType::Null {
-            return Ok(inner);
+            return Ok(nested);
         }
 
         Ok(Arc::new(AggregateFunctionOrNullAdaptor {
-            inner,
+            nested,
             inner_nullable: inner_return_type.is_nullable(),
         }))
     }
@@ -69,32 +71,33 @@ pub fn get_flag(place: AggrState) -> bool {
     *c != 0
 }
 
+fn merge_flag(place: AggrState, other: bool) {
+    let flag = other || get_flag(place);
+    set_flag(place, flag);
+}
+
 fn flag_offset(place: AggrState) -> usize {
     *place.loc.last().unwrap().as_bool().unwrap().1
 }
 
 impl AggregateFunction for AggregateFunctionOrNullAdaptor {
     fn name(&self) -> &str {
-        self.inner.name()
+        self.nested.name()
     }
 
     fn return_type(&self) -> Result<DataType> {
-        Ok(self.inner.return_type()?.wrap_nullable())
+        Ok(self.nested.return_type()?.wrap_nullable())
     }
 
     #[inline]
     fn init_state(&self, place: AggrState) {
         let c = place.addr.next(flag_offset(place)).get::<u8>();
         *c = 0;
-        self.inner.init_state(place.remove_last_loc())
-    }
-
-    fn serialize_size_per_row(&self) -> Option<usize> {
-        self.inner.serialize_size_per_row().map(|row| row + 1)
+        self.nested.init_state(place.remove_last_loc())
     }
 
     fn register_state(&self, registry: &mut AggrStateRegistry) {
-        self.inner.register_state(registry);
+        self.nested.register_state(registry);
         registry.register(AggrStateType::Bool);
     }
 
@@ -110,7 +113,7 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor {
             return Ok(());
         }
 
-        let if_cond = self.inner.get_if_condition(columns);
+        let if_cond = self.nested.get_if_condition(columns);
 
         let validity = match (if_cond, validity) {
             (None, None) => None,
@@ -125,7 +128,7 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor {
             .unwrap_or(true)
         {
             set_flag(place, true);
-            self.inner.accumulate(
+            self.nested.accumulate(
                 place.remove_last_loc(),
                 columns,
                 validity.as_ref(),
@@ -142,9 +145,9 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor {
         columns: ProjectedBlock,
         input_rows: usize,
     ) -> Result<()> {
-        self.inner
+        self.nested
             .accumulate_keys(places, &loc[..loc.len() - 1], columns, input_rows)?;
-        let if_cond = self.inner.get_if_condition(columns);
+        let if_cond = self.nested.get_if_condition(columns);
 
         match if_cond {
             Some(v) if v.null_count() > 0 => {
@@ -171,30 +174,96 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor {
 
     #[inline]
     fn accumulate_row(&self, place: AggrState, columns: ProjectedBlock, row: usize) -> Result<()> {
-        self.inner
+        self.nested
             .accumulate_row(place.remove_last_loc(), columns, row)?;
         set_flag(place, true);
         Ok(())
     }
 
-    #[inline]
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        self.inner.serialize(place.remove_last_loc(), writer)?;
-        let flag = get_flag(place) as u8;
-        writer.write_scalar(&flag)
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
+        self.nested
+            .serialize_type()
+            .into_iter()
+            .chain(Some(StateSerdeItem::DataType(DataType::Boolean)))
+            .collect()
     }
 
-    #[inline]
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        let flag = get_flag(place) || reader[reader.len() - 1] > 0;
-        self.inner
-            .merge(place.remove_last_loc(), &mut &reader[..reader.len() - 1])?;
-        set_flag(place, flag);
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
+        let n = builders.len();
+        debug_assert_eq!(self.nested.serialize_type().len() + 1, n);
+        let flag_builder = builders
+            .last_mut()
+            .and_then(ColumnBuilder::as_boolean_mut)
+            .unwrap();
+        for place in places {
+            let place = AggrState::new(*place, loc);
+            flag_builder.push(get_flag(place));
+        }
+        self.nested
+            .batch_serialize(places, &loc[..loc.len() - 1], &mut builders[..(n - 1)])
+    }
+
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        match state {
+            BlockEntry::Column(Column::Tuple(tuple)) => {
+                let flag = tuple.last().unwrap().as_boolean().unwrap();
+                let iter = places.iter().zip(flag.iter());
+                if let Some(filter) = filter {
+                    for (place, flag) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v))
+                    {
+                        merge_flag(AggrState::new(*place, loc), flag);
+                    }
+                } else {
+                    for (place, flag) in iter {
+                        merge_flag(AggrState::new(*place, loc), flag);
+                    }
+                }
+                let inner_state = Column::Tuple(tuple[0..tuple.len() - 1].to_vec()).into();
+                self.nested
+                    .batch_merge(places, &loc[0..loc.len() - 1], &inner_state, filter)?;
+            }
+            BlockEntry::Const(Scalar::Tuple(tuple), DataType::Tuple(data_type), num_rows) => {
+                let flag = *tuple.last().unwrap().as_boolean().unwrap();
+                if let Some(filter) = filter {
+                    for place in places
+                        .iter()
+                        .zip(filter.iter())
+                        .filter_map(|(v, b)| b.then_some(v))
+                    {
+                        merge_flag(AggrState::new(*place, loc), flag);
+                    }
+                } else {
+                    for place in places {
+                        merge_flag(AggrState::new(*place, loc), flag);
+                    }
+                }
+                let inner_state = BlockEntry::new_const_column(
+                    DataType::Tuple(data_type[0..data_type.len() - 1].to_vec()),
+                    Scalar::Tuple(tuple[0..tuple.len() - 1].to_vec()),
+                    *num_rows,
+                );
+                self.nested
+                    .batch_merge(places, &loc[0..loc.len() - 1], &inner_state, filter)?;
+            }
+            _ => unreachable!(),
+        }
         Ok(())
     }
 
     fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> {
-        self.inner
+        self.nested
             .merge_states(place.remove_last_loc(), rhs.remove_last_loc())?;
         let flag = get_flag(place) || get_flag(rhs);
         set_flag(place, flag);
@@ -207,9 +276,9 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor {
         if !get_flag(place) {
             inner_mut.push_null();
         } else if self.inner_nullable {
-            self.inner.merge_result(place.remove_last_loc(), builder)?;
+            self.nested.merge_result(place.remove_last_loc(), builder)?;
         } else {
-            self.inner
+            self.nested
                 .merge_result(place.remove_last_loc(), &mut inner_mut.builder)?;
             inner_mut.validity.push(true);
         }
@@ -225,25 +294,21 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor {
         params: Vec<Scalar>,
         arguments: Vec<DataType>,
     ) -> Result<Option<AggregateFunctionRef>> {
-        self.inner
+        self.nested
            .get_own_null_adaptor(nested_function, params, arguments)
     }
 
     fn need_manual_drop_state(&self) -> bool {
-        self.inner.need_manual_drop_state()
+        self.nested.need_manual_drop_state()
     }
 
     unsafe fn drop_state(&self, place: AggrState) {
-        self.inner.drop_state(place.remove_last_loc())
-    }
-
-    fn convert_const_to_full(&self) -> bool {
-        self.inner.convert_const_to_full()
+        self.nested.drop_state(place.remove_last_loc())
     }
 }
 
 impl fmt::Display for AggregateFunctionOrNullAdaptor {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        write!(f, "{}", self.inner)
+        write!(f, "{}", self.nested)
     }
 }
diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs
index 4894fc24a9c84..0cfd393c776ce 100644
--- a/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs
+++ b/src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs
@@ -23,17 +23,23 @@ use borsh::BorshSerialize;
 use databend_common_column::bitmap::Bitmap;
 use databend_common_exception::Result;
 use databend_common_expression::types::AnyType;
+use databend_common_expression::types::BinaryType;
 use databend_common_expression::types::DataType;
+use databend_common_expression::types::UnaryType;
 use databend_common_expression::AggrState;
+use databend_common_expression::AggrStateLoc;
 use databend_common_expression::AggrStateRegistry;
 use databend_common_expression::AggrStateType;
 use databend_common_expression::AggregateFunction;
 use databend_common_expression::AggregateFunctionRef;
+use databend_common_expression::BlockEntry;
 use databend_common_expression::ColumnBuilder;
 use databend_common_expression::DataBlock;
 use databend_common_expression::ProjectedBlock;
 use databend_common_expression::Scalar;
 use databend_common_expression::SortColumnDescription;
+use databend_common_expression::StateAddr;
+use databend_common_expression::StateSerdeItem;
 use itertools::Itertools;
 
 use crate::aggregates::AggregateFunctionSortDesc;
@@ -121,16 +127,48 @@ impl AggregateFunction for AggregateFunctionSortAdaptor {
         Ok(())
     }
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        let state = Self::get_state(place);
-        Ok(state.serialize(writer)?)
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
+        vec![StateSerdeItem::Binary(None)]
    }
 
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        let state = Self::get_state(place);
-        let rhs = SortAggState::deserialize(reader)?;
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
+        let binary_builder = builders[0].as_binary_mut().unwrap();
+        for place in places {
+            let state = Self::get_state(AggrState::new(*place, loc));
+            state.serialize(&mut binary_builder.data)?;
+            binary_builder.commit_row();
+        }
+        Ok(())
+    }
 
-        Self::merge_states_inner(state, &rhs);
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        let view = state.downcast::<UnaryType<BinaryType>>().unwrap();
+        let iter = places.iter().zip(view.iter());
+
+        if let Some(filter) = filter {
+            for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) {
+                let state = Self::get_state(AggrState::new(*place, loc));
+                let rhs = SortAggState::deserialize(&mut data)?;
+                Self::merge_states_inner(state, &rhs);
+            }
+        } else {
+            for (place, mut data) in iter {
+                let state = Self::get_state(AggrState::new(*place, loc));
+                let rhs = SortAggState::deserialize(&mut data)?;
+                Self::merge_states_inner(state, &rhs);
+            }
+        }
 
         Ok(())
     }
diff --git a/src/query/functions/src/aggregates/aggregate_arg_min_max.rs b/src/query/functions/src/aggregates/aggregate_arg_min_max.rs
index 98506ab12b8b5..a1076cbf8e039 100644
--- a/src/query/functions/src/aggregates/aggregate_arg_min_max.rs
+++ b/src/query/functions/src/aggregates/aggregate_arg_min_max.rs
@@ -27,10 +27,12 @@ use databend_common_expression::types::*;
 use databend_common_expression::with_number_mapped_type;
 use databend_common_expression::AggrStateRegistry;
 use databend_common_expression::AggrStateType;
+use databend_common_expression::BlockEntry;
 use databend_common_expression::ColumnBuilder;
 use databend_common_expression::ColumnView;
 use databend_common_expression::ProjectedBlock;
 use databend_common_expression::Scalar;
+use databend_common_expression::StateSerdeItem;
 
 use super::aggregate_function_factory::AggregateFunctionDescription;
 use super::aggregate_function_factory::AggregateFunctionSortDesc;
@@ -270,15 +272,49 @@ where
         Ok(())
     }
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        let state = place.get::<State>();
-        Ok(state.serialize(writer)?)
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
+        vec![StateSerdeItem::Binary(None)]
    }
 
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        let state = place.get::<State>();
-        let rhs: State = borsh_partial_deserialize(reader)?;
-        state.merge_from(rhs)
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
+        let binary_builder = builders[0].as_binary_mut().unwrap();
+        for place in places {
+            let state = AggrState::new(*place, loc).get::<State>();
+            state.serialize(&mut binary_builder.data)?;
+            binary_builder.commit_row();
+        }
+        Ok(())
+    }
+
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        let view = state.downcast::<UnaryType<BinaryType>>().unwrap();
+        let iter = places.iter().zip(view.iter());
+
+        if let Some(filter) = filter {
+            for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) {
+                let state = AggrState::new(*place, loc).get::<State>();
+                let rhs: State = borsh_partial_deserialize(&mut data)?;
+                state.merge_from(rhs)?;
+            }
+        } else {
+            for (place, mut data) in iter {
+                let state = AggrState::new(*place, loc).get::<State>();
+                let rhs: State = borsh_partial_deserialize(&mut data)?;
+                state.merge_from(rhs)?;
+            }
+        }
+        Ok(())
    }
 
     fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> {
diff --git a/src/query/functions/src/aggregates/aggregate_array_agg.rs b/src/query/functions/src/aggregates/aggregate_array_agg.rs
index 7083b886996df..b48dd221695c4 100644
--- a/src/query/functions/src/aggregates/aggregate_array_agg.rs
+++ b/src/query/functions/src/aggregates/aggregate_array_agg.rs
@@ -43,11 +43,13 @@ use databend_common_expression::with_decimal_mapped_type;
 use databend_common_expression::with_number_mapped_type;
 use databend_common_expression::AggrStateRegistry;
 use databend_common_expression::AggrStateType;
+use databend_common_expression::BlockEntry;
 use databend_common_expression::Column;
 use databend_common_expression::ColumnBuilder;
 use databend_common_expression::ProjectedBlock;
 use databend_common_expression::Scalar;
 use databend_common_expression::ScalarRef;
+use databend_common_expression::StateSerdeItem;
 
 use super::aggregate_function_factory::AggregateFunctionDescription;
 use super::aggregate_scalar_state::ScalarStateFunc;
@@ -548,16 +550,50 @@ where
         Ok(())
     }
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        let state = place.get::<State>();
-        Ok(state.serialize(writer)?)
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
+        vec![StateSerdeItem::Binary(None)]
    }
 
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        let state = place.get::<State>();
-        let rhs = State::deserialize_reader(reader)?;
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
+        let binary_builder = builders[0].as_binary_mut().unwrap();
+        for place in places {
+            let state = AggrState::new(*place, loc).get::<State>();
+            state.serialize(&mut binary_builder.data)?;
+            binary_builder.commit_row();
+        }
+        Ok(())
+    }
 
-        state.merge(&rhs)
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        let view = state.downcast::<UnaryType<BinaryType>>().unwrap();
+        let iter = places.iter().zip(view.iter());
+
+        if let Some(filter) = filter {
+            for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) {
+                let state = AggrState::new(*place, loc).get::<State>();
+                let rhs = State::deserialize_reader(&mut data)?;
+                state.merge(&rhs)?;
+            }
+        } else {
+            for (place, data) in iter {
+                let mut binary = data;
+                let state = AggrState::new(*place, loc).get::<State>();
+                let rhs = State::deserialize_reader(&mut binary)?;
+                state.merge(&rhs)?;
+            }
+        }
+        Ok(())
    }
 
     fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> {
diff --git a/src/query/functions/src/aggregates/aggregate_array_moving.rs b/src/query/functions/src/aggregates/aggregate_array_moving.rs
index 6ef407c1402e2..ac4a7b50add8c 100644
--- a/src/query/functions/src/aggregates/aggregate_array_moving.rs
+++ b/src/query/functions/src/aggregates/aggregate_array_moving.rs
@@ -26,6 +26,7 @@ use databend_common_expression::types::i256;
 use databend_common_expression::types::number::Number;
 use databend_common_expression::types::AccessType;
 use databend_common_expression::types::ArgType;
+use databend_common_expression::types::BinaryType;
 use databend_common_expression::types::Bitmap;
 use databend_common_expression::types::Buffer;
 use databend_common_expression::types::DataType;
@@ -33,6 +34,7 @@ use databend_common_expression::types::Float64Type;
 use databend_common_expression::types::Int8Type;
 use databend_common_expression::types::NumberDataType;
 use databend_common_expression::types::NumberType;
+use databend_common_expression::types::UnaryType;
 use databend_common_expression::types::F64;
 use databend_common_expression::utils::arithmetics_type::ResultTypeOfUnary;
 use databend_common_expression::with_decimal_mapped_type;
@@ -45,6 +47,7 @@ use databend_common_expression::ColumnBuilder;
 use databend_common_expression::ProjectedBlock;
 use databend_common_expression::Scalar;
 use databend_common_expression::ScalarRef;
+use databend_common_expression::StateSerdeItem;
 use num_traits::AsPrimitive;
 
 use super::aggregate_function::AggregateFunction;
@@ -447,16 +450,49 @@ where State: SumState
         state.accumulate_row(&columns[0], row)
     }
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        let state = place.get::<State>();
-        Ok(state.serialize(writer)?)
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
+        vec![StateSerdeItem::Binary(None)]
    }
 
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        let state = place.get::<State>();
-        let rhs = State::deserialize_reader(reader)?;
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
+        let binary_builder = builders[0].as_binary_mut().unwrap();
+        for place in places {
+            let state = AggrState::new(*place, loc).get::<State>();
+            state.serialize(&mut binary_builder.data)?;
+            binary_builder.commit_row();
+        }
+        Ok(())
+    }
 
-        state.merge(&rhs)
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        let view = state.downcast::<UnaryType<BinaryType>>().unwrap();
+        let iter = places.iter().zip(view.iter());
+
+        if let Some(filter) = filter {
+            for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) {
+                let state = AggrState::new(*place, loc).get::<State>();
+                let rhs = State::deserialize_reader(&mut data)?;
+                state.merge(&rhs)?;
+            }
+        } else {
+            for (place, mut data) in iter {
+                let state = AggrState::new(*place, loc).get::<State>();
+                let rhs = State::deserialize_reader(&mut data)?;
+                state.merge(&rhs)?;
+            }
+        }
+        Ok(())
    }
 
     fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> {
@@ -616,16 +652,49 @@ where State: SumState
         state.accumulate_row(&columns[0], row)
     }
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        let state = place.get::<State>();
-        Ok(state.serialize(writer)?)
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
+        vec![StateSerdeItem::Binary(None)]
    }
 
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        let state = place.get::<State>();
-        let rhs = State::deserialize_reader(reader)?;
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
+        let binary_builder = builders[0].as_binary_mut().unwrap();
+        for place in places {
+            let state = AggrState::new(*place, loc).get::<State>();
+            state.serialize(&mut binary_builder.data)?;
+            binary_builder.commit_row();
+        }
+        Ok(())
+    }
 
-        state.merge(&rhs)
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        let view = state.downcast::<UnaryType<BinaryType>>().unwrap();
+        let iter = places.iter().zip(view.iter());
+
+        if let Some(filter) = filter {
+            for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) {
+                let state = AggrState::new(*place, loc).get::<State>();
+                let rhs = State::deserialize_reader(&mut data)?;
+                state.merge(&rhs)?;
+            }
+        } else {
+            for (place, mut data) in iter {
+                let state = AggrState::new(*place, loc).get::<State>();
+                let rhs = State::deserialize_reader(&mut data)?;
+                state.merge(&rhs)?;
+            }
+        }
+        Ok(())
    }
 
     fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> {
diff --git a/src/query/functions/src/aggregates/aggregate_bitmap.rs b/src/query/functions/src/aggregates/aggregate_bitmap.rs
index cce9492c709d8..a5cf18c1e3d4d 100644
--- a/src/query/functions/src/aggregates/aggregate_bitmap.rs
+++ b/src/query/functions/src/aggregates/aggregate_bitmap.rs
@@ -33,9 +33,11 @@ use databend_common_expression::with_decimal_mapped_type;
 use databend_common_expression::with_number_mapped_type;
 use databend_common_expression::AggrStateRegistry;
 use databend_common_expression::AggrStateType;
+use databend_common_expression::BlockEntry;
 use databend_common_expression::ColumnBuilder;
 use databend_common_expression::ProjectedBlock;
 use databend_common_expression::Scalar;
+use databend_common_expression::StateSerdeItem;
 use databend_common_io::prelude::BinaryWrite;
 use roaring::RoaringTreemap;
 
@@ -287,25 +289,61 @@ where
         Ok(())
     }
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        let state = place.get::<BitmapAggState>();
-        // flag indicate where bitmap is none
-        let flag: u8 = if state.rb.is_some() { 1 } else { 0 };
-        writer.write_scalar(&flag)?;
-        if let Some(rb) = &state.rb {
-            rb.serialize_into(writer)?;
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
+        vec![StateSerdeItem::Binary(None)]
+    }
+
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
+        let binary_builder = builders[0].as_binary_mut().unwrap();
+        for place in places {
+            let state = AggrState::new(*place, loc).get::<BitmapAggState>();
+            // flag indicate where bitmap is none
+            let flag: u8 = if state.rb.is_some() { 1 } else { 0 };
+            binary_builder.data.write_scalar(&flag)?;
+            if let Some(rb) = &state.rb {
+                rb.serialize_into(&mut binary_builder.data)?;
+            }
+            binary_builder.commit_row();
         }
         Ok(())
     }
 
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        let state = place.get::<BitmapAggState>();
-
-        let flag = reader[0];
-        reader.consume(1);
-        if flag == 1 {
-            let rb = deserialize_bitmap(reader)?;
-            state.add::<OP>(rb);
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        let view = state.downcast::<UnaryType<BinaryType>>().unwrap();
+        let iter = places.iter().zip(view.iter());
+
+        if let Some(filter) = filter {
+            for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) {
+                let state = AggrState::new(*place, loc).get::<BitmapAggState>();
+                let flag = data[0];
+                data.consume(1);
+                if flag == 1 {
+                    let rb = deserialize_bitmap(data)?;
+                    state.add::<OP>(rb);
+                }
+            }
+        } else {
+            for (place, mut data) in iter {
+                let state = AggrState::new(*place, loc).get::<BitmapAggState>();
+
+                let flag = data[0];
+                data.consume(1);
+                if flag == 1 {
+                    let rb = deserialize_bitmap(data)?;
+                    state.add::<OP>(rb);
+                }
+            }
         }
         Ok(())
     }
@@ -482,12 +520,27 @@ where
         Ok(())
     }
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        self.inner.serialize(place, writer)
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
+        vec![StateSerdeItem::Binary(None)]
    }
 
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        self.inner.merge(place, reader)
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
+        self.inner.batch_serialize(places, loc, builders)
+    }
+
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        self.inner.batch_merge(places, loc, state, filter)
    }
 
     fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> {
diff --git a/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs b/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs
index 77a11ec9b44f8..da9994b82624f 100644
--- a/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs
+++ b/src/query/functions/src/aggregates/aggregate_combinator_distinct.rs
@@ -19,15 +19,19 @@ use std::sync::Arc;
 
 use databend_common_exception::Result;
 use databend_common_expression::types::number::NumberColumnBuilder;
+use databend_common_expression::types::BinaryType;
 use databend_common_expression::types::Bitmap;
 use databend_common_expression::types::DataType;
 use databend_common_expression::types::NumberDataType;
+use databend_common_expression::types::UnaryType;
 use databend_common_expression::with_number_mapped_type;
 use databend_common_expression::AggrStateRegistry;
 use databend_common_expression::AggrStateType;
+use databend_common_expression::BlockEntry;
 use databend_common_expression::ColumnBuilder;
 use databend_common_expression::ProjectedBlock;
 use databend_common_expression::Scalar;
+use databend_common_expression::StateSerdeItem;
 
 use super::aggregate_distinct_state::AggregateDistinctNumberState;
 use super::aggregate_distinct_state::AggregateDistinctState;
@@ -42,6 +46,8 @@ use super::aggregate_function_factory::CombinatorDescription;
 use super::aggregator_common::assert_variadic_arguments;
 use super::AggregateCountFunction;
 use crate::aggregates::AggrState;
+use crate::aggregates::AggrStateLoc;
+use crate::aggregates::StateAddr;
 
 #[derive(Clone)]
 pub struct AggregateDistinctCombinator {
@@ -106,16 +112,49 @@ where State: DistinctStateFunc
         state.add(columns, row)
     }
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        let state = Self::get_state(place);
-        state.serialize(writer)
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
+        vec![StateSerdeItem::Binary(None)]
    }
 
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        let state = Self::get_state(place);
-        let rhs = State::deserialize(reader)?;
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
+        let binary_builder = builders[0].as_binary_mut().unwrap();
+        for place in places {
+            let state = Self::get_state(AggrState::new(*place, loc));
+            state.serialize(&mut binary_builder.data)?;
+            binary_builder.commit_row();
+        }
+        Ok(())
+    }
 
-        state.merge(&rhs)
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        let view = state.downcast::<UnaryType<BinaryType>>().unwrap();
+        let iter = places.iter().zip(view.iter());
+
+        if let Some(filter) = filter {
+            for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) {
+                let state = Self::get_state(AggrState::new(*place, loc));
+                let rhs = State::deserialize(&mut data)?;
+                state.merge(&rhs)?;
+            }
+        } else {
+            for (place, mut data) in iter {
+                let state = Self::get_state(AggrState::new(*place, loc));
+                let rhs = State::deserialize(&mut data)?;
+                state.merge(&rhs)?;
+            }
+        }
+        Ok(())
    }
 
     fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> {
diff --git a/src/query/functions/src/aggregates/aggregate_combinator_if.rs b/src/query/functions/src/aggregates/aggregate_combinator_if.rs
index ea83b8028d23e..4555f9b2cc85b 100644
--- a/src/query/functions/src/aggregates/aggregate_combinator_if.rs
+++ b/src/query/functions/src/aggregates/aggregate_combinator_if.rs
@@ -26,6 +26,7 @@ use databend_common_expression::Column;
 use databend_common_expression::ColumnBuilder;
 use databend_common_expression::ProjectedBlock;
 use databend_common_expression::Scalar;
+use databend_common_expression::StateSerdeItem;
 
 use super::StateAddr;
 use crate::aggregates::aggregate_function_factory::AggregateFunctionCreator;
@@ -53,20 +54,18 @@ impl AggregateIfCombinator {
         sort_descs: Vec<AggregateFunctionSortDesc>,
         nested_creator: &AggregateFunctionCreator,
     ) -> Result<AggregateFunctionRef> {
-        let name = format!("IfCombinator({})", nested_name);
+        let name = format!("IfCombinator({nested_name})");
         let argument_len = arguments.len();
 
         if argument_len == 0 {
             return Err(ErrorCode::NumberArgumentsNotMatch(format!(
-                "{} expect to have more than one argument",
-                name
+                "{name} expect to have more than one argument",
             )));
         }
 
         if !matches!(&arguments[argument_len - 1], DataType::Boolean) {
             return Err(ErrorCode::BadArguments(format!(
-                "The type of the last argument for {} must be boolean type, but got {:?}",
-                name,
+                "The type of the last argument for {name} must be boolean type, but got {:?}",
                 &arguments[argument_len - 1]
             )));
         }
@@ -155,12 +154,27 @@ impl AggregateFunction for AggregateIfCombinator {
         Ok(())
     }
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        self.nested.serialize(place, writer)
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
+        self.nested.serialize_type()
    }
 
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        self.nested.merge(place, reader)
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
+        self.nested.batch_serialize(places, loc, builders)
+    }
+
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &databend_common_expression::BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        self.nested.batch_merge(places, loc, state, filter)
    }
 
     fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> {
diff --git a/src/query/functions/src/aggregates/aggregate_combinator_state.rs b/src/query/functions/src/aggregates/aggregate_combinator_state.rs
index 486707ec7af61..4c2bd220856fb 100644
--- a/src/query/functions/src/aggregates/aggregate_combinator_state.rs
+++ b/src/query/functions/src/aggregates/aggregate_combinator_state.rs
@@ -19,9 +19,11 @@ use databend_common_exception::Result;
 use databend_common_expression::types::Bitmap;
 use databend_common_expression::types::DataType;
 use databend_common_expression::AggrStateRegistry;
+use databend_common_expression::BlockEntry;
 use databend_common_expression::ColumnBuilder;
 use databend_common_expression::ProjectedBlock;
 use databend_common_expression::Scalar;
+use databend_common_expression::StateSerdeItem;
 
 use super::AggregateFunctionFactory;
 use super::StateAddr;
@@ -70,7 +72,7 @@ impl AggregateFunction for AggregateStateCombinator {
     }
 
     fn return_type(&self) -> Result<DataType> {
-        Ok(DataType::Binary)
+        Ok(self.nested.serialize_data_type())
     }
 
     fn init_state(&self, place: AggrState) {
@@ -106,13 +108,27 @@ impl AggregateFunction for AggregateStateCombinator {
         self.nested.accumulate_row(place, columns, row)
     }
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        self.nested.serialize(place, writer)
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
+        self.nested.serialize_type()
    }
 
-    #[inline]
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        self.nested.merge(place, reader)
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
+        self.nested.batch_serialize(places, loc, builders)
+    }
+
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        self.nested.batch_merge(places, loc, state, filter)
    }
 
     fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> {
@@ -120,10 +136,9 @@ impl AggregateFunction for AggregateStateCombinator {
     }
 
     fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> {
-        let builder = builder.as_binary_mut().unwrap();
-        self.nested.serialize(place, &mut builder.data)?;
-        builder.commit_row();
-
Ok(()) + let builders = builder.as_tuple_mut().unwrap().as_mut_slice(); + self.nested + .batch_serialize(&[place.addr], place.loc, builders) } fn need_manual_drop_state(&self) -> bool { diff --git a/src/query/functions/src/aggregates/aggregate_count.rs b/src/query/functions/src/aggregates/aggregate_count.rs index b822b832896ee..fb4ed78280fc9 100644 --- a/src/query/functions/src/aggregates/aggregate_count.rs +++ b/src/query/functions/src/aggregates/aggregate_count.rs @@ -16,26 +16,30 @@ use std::alloc::Layout; use std::fmt; use std::sync::Arc; -use borsh::BorshSerialize; use databend_common_exception::Result; use databend_common_expression::types::number::NumberColumnBuilder; +use databend_common_expression::types::ArgType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; +use databend_common_expression::types::UInt64Type; +use databend_common_expression::types::UnaryType; +use databend_common_expression::types::ValueType; use databend_common_expression::utils::column_merge_validity; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use super::aggregate_function::AggregateFunction; use super::aggregate_function_factory::AggregateFunctionDescription; use super::aggregate_function_factory::AggregateFunctionSortDesc; use super::StateAddr; use crate::aggregates::aggregator_common::assert_variadic_arguments; -use crate::aggregates::borsh_partial_deserialize; use crate::aggregates::AggrState; use crate::aggregates::AggrStateLoc; @@ -160,15 +164,44 @@ impl AggregateFunction for AggregateCountFunction { Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - let state = place.get::(); - Ok(state.count.serialize(writer)?) 
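// [illustrative sketch, not part of the patch] As in the state
// combinator's merge_result above, a single state can now be flushed by
// building the tuple type from serialize_data_type() and fanning its
// field builders out to batch_serialize. The helper name is hypothetical;
// the API calls are the ones this patch uses.
use databend_common_exception::Result;
use databend_common_expression::AggrState;
use databend_common_expression::AggregateFunctionRef;
use databend_common_expression::Column;
use databend_common_expression::ColumnBuilder;

fn flush_one_state(nested: &AggregateFunctionRef, place: AggrState) -> Result<Column> {
    // tuple(one field per StateSerdeItem), as described by serialize_type()
    let data_type = nested.serialize_data_type();
    let mut builder = ColumnBuilder::with_capacity(&data_type, 1);
    let builders = builder.as_tuple_mut().unwrap().as_mut_slice();
    nested.batch_serialize(&[place.addr], place.loc, builders)?;
    Ok(builder.build())
}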
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
+        vec![StateSerdeItem::DataType(UInt64Type::data_type())]
     }
 
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        let state = place.get::<AggregateCountState>();
-        let other: u64 = borsh_partial_deserialize(reader)?;
-        state.count += other;
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
+        let mut builder = UInt64Type::downcast_builder(&mut builders[0]);
+        for place in places {
+            let state = AggrState::new(*place, loc).get::<AggregateCountState>();
+            builder.push(state.count);
+        }
+        Ok(())
+    }
+
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        let view = state.downcast::<UnaryType<UInt64Type>>().unwrap();
+        let iter = places.iter().zip(view.iter());
+        if let Some(filter) = filter {
+            for (place, other) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) {
+                let state = AggrState::new(*place, loc).get::<AggregateCountState>();
+                state.count += other;
+            }
+        } else {
+            for (place, other) in iter {
+                let state = AggrState::new(*place, loc).get::<AggregateCountState>();
+                state.count += other;
+            }
+        }
         Ok(())
     }
diff --git a/src/query/functions/src/aggregates/aggregate_covariance.rs b/src/query/functions/src/aggregates/aggregate_covariance.rs
index 8af1216f56567..0b4e7a52d4b06 100644
--- a/src/query/functions/src/aggregates/aggregate_covariance.rs
+++ b/src/query/functions/src/aggregates/aggregate_covariance.rs
@@ -23,17 +23,21 @@ use databend_common_exception::ErrorCode;
 use databend_common_exception::Result;
 use databend_common_expression::types::number::Number;
 use databend_common_expression::types::number::F64;
+use databend_common_expression::types::BinaryType;
 use databend_common_expression::types::Bitmap;
 use databend_common_expression::types::DataType;
 use databend_common_expression::types::NumberDataType;
 use databend_common_expression::types::NumberType;
+use databend_common_expression::types::UnaryType;
 use databend_common_expression::types::ValueType;
 use databend_common_expression::with_number_mapped_type;
 use databend_common_expression::AggrStateRegistry;
 use databend_common_expression::AggrStateType;
+use databend_common_expression::BlockEntry;
 use databend_common_expression::ColumnBuilder;
 use databend_common_expression::ProjectedBlock;
 use databend_common_expression::Scalar;
+use databend_common_expression::StateSerdeItem;
 use num_traits::AsPrimitive;
 
 use super::borsh_partial_deserialize;
@@ -231,15 +235,48 @@ where
         Ok(())
     }
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        let state = place.get::<AggregateCovarianceState>();
-        Ok(state.serialize(writer)?)
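// [illustrative sketch, not part of the patch] Because the count state
// declares StateSerdeItem::DataType(UInt64Type::data_type()), partial
// counts cross the exchange as a native u64 column rather than opaque
// bytes, and the merge path reduces to plain addition:
fn add_partial_counts(states: &mut [u64], partials: &[u64]) {
    for (state, partial) in states.iter_mut().zip(partials) {
        *state += partial; // no per-row byte decoding on the merge path
    }
}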
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
+        vec![StateSerdeItem::Binary(None)]
     }
 
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        let state = place.get::<AggregateCovarianceState>();
-        let rhs: AggregateCovarianceState = borsh_partial_deserialize(reader)?;
-        state.merge(&rhs);
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
+        let binary_builder = builders[0].as_binary_mut().unwrap();
+        for place in places {
+            let state = AggrState::new(*place, loc).get::<AggregateCovarianceState>();
+            state.serialize(&mut binary_builder.data)?;
+            binary_builder.commit_row();
+        }
+        Ok(())
+    }
+
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        let view = state.downcast::<UnaryType<BinaryType>>().unwrap();
+        let iter = places.iter().zip(view.iter());
+
+        if let Some(filter) = filter {
+            for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) {
+                let state = AggrState::new(*place, loc).get::<AggregateCovarianceState>();
+                let rhs: AggregateCovarianceState = borsh_partial_deserialize(&mut data)?;
+                state.merge(&rhs);
+            }
+        } else {
+            for (place, mut data) in iter {
+                let state = AggrState::new(*place, loc).get::<AggregateCovarianceState>();
+                let rhs: AggregateCovarianceState = borsh_partial_deserialize(&mut data)?;
+                state.merge(&rhs);
+            }
+        }
         Ok(())
     }
diff --git a/src/query/functions/src/aggregates/aggregate_json_array_agg.rs b/src/query/functions/src/aggregates/aggregate_json_array_agg.rs
index f1ba8d35807fa..0f7ab2609dc56 100644
--- a/src/query/functions/src/aggregates/aggregate_json_array_agg.rs
+++ b/src/query/functions/src/aggregates/aggregate_json_array_agg.rs
@@ -29,11 +29,13 @@ use databend_common_expression::types::ValueType;
 use databend_common_expression::types::*;
 use databend_common_expression::AggrStateRegistry;
 use databend_common_expression::AggrStateType;
+use databend_common_expression::BlockEntry;
 use databend_common_expression::Column;
 use databend_common_expression::ColumnBuilder;
 use databend_common_expression::ProjectedBlock;
 use databend_common_expression::Scalar;
 use databend_common_expression::ScalarRef;
+use databend_common_expression::StateSerdeItem;
 use jiff::tz::TimeZone;
 use jsonb::OwnedJsonb;
 use jsonb::RawJsonb;
@@ -240,16 +242,49 @@ where
         Ok(())
     }
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        let state = place.get::<State>();
-        Ok(state.serialize(writer)?)
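// [illustrative sketch, not part of the patch] Every batch_merge above
// repeats the same Some/None split on the filter. The shared idiom could
// be factored into a hypothetical helper like this, keeping only the rows
// whose filter bit is set:
use databend_common_exception::Result;
use databend_common_expression::types::Bitmap;

fn for_each_selected<T>(
    rows: impl Iterator<Item = T>,
    filter: Option<&Bitmap>,
    mut f: impl FnMut(T) -> Result<()>,
) -> Result<()> {
    match filter {
        Some(filter) => rows
            .zip(filter.iter())
            .filter_map(|(row, keep)| keep.then_some(row))
            .try_for_each(&mut f),
        None => rows.try_for_each(&mut f),
    }
}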
+ fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - let state = place.get::(); - let rhs: State = borsh_partial_deserialize(reader)?; + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } + Ok(()) + } - state.merge(&rhs) + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let rhs: State = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs)?; + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let rhs: State = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs)?; + } + } + Ok(()) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_json_object_agg.rs b/src/query/functions/src/aggregates/aggregate_json_object_agg.rs index 9efe5fc0f6b90..ad9910d21a12e 100644 --- a/src/query/functions/src/aggregates/aggregate_json_object_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_json_object_agg.rs @@ -31,11 +31,13 @@ use databend_common_expression::types::ValueType; use databend_common_expression::types::*; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; use databend_common_expression::ScalarRef; +use databend_common_expression::StateSerdeItem; use jiff::tz::TimeZone; use jsonb::OwnedJsonb; use jsonb::RawJsonb; @@ -293,16 +295,49 @@ where Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - let state = place.get::(); - Ok(state.serialize(writer)?) 
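// [illustrative sketch, not part of the patch] The binary serializers
// above write into binary_builder.data and then call commit_row(). The
// contract assumed here, reduced to a toy builder: rows share one
// contiguous buffer, and committing only records where the row ends.
struct MiniBinaryBuilder {
    data: Vec<u8>,
    offsets: Vec<u64>,
}

impl MiniBinaryBuilder {
    fn commit_row(&mut self) {
        // bytes appended since the previous commit become one row
        self.offsets.push(self.data.len() as u64);
    }
}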
+ fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - let state = place.get::(); - let rhs: State = borsh_partial_deserialize(reader)?; + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } + Ok(()) + } - state.merge(&rhs) + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let rhs: State = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs)?; + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let rhs: State = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs)?; + } + } + Ok(()) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_markov_tarin.rs b/src/query/functions/src/aggregates/aggregate_markov_tarin.rs index 0524caa523e39..ccce62a529840 100644 --- a/src/query/functions/src/aggregates/aggregate_markov_tarin.rs +++ b/src/query/functions/src/aggregates/aggregate_markov_tarin.rs @@ -25,11 +25,13 @@ use databend_common_base::obfuscator::NGramHash; use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::ArgType; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::MapType; use databend_common_expression::types::StringType; use databend_common_expression::types::UInt32Type; +use databend_common_expression::types::UnaryType; use databend_common_expression::types::ValueType; use databend_common_expression::types::F64; use databend_common_expression::AggrState; @@ -40,12 +42,15 @@ use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use super::aggregate_function_factory::AggregateFunctionDescription; use super::assert_unary_arguments; use super::borsh_partial_deserialize; use super::extract_number_param; use super::AggregateFunction; +use super::StateAddr; +use crate::aggregates::AggrStateLoc; pub struct MarkovTarin { display_name: String, @@ -113,16 +118,48 @@ impl AggregateFunction for MarkovTarin { Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - let state = place.get::(); - Ok(state.serialize(writer)?) 
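// [illustrative sketch, not part of the patch] The `mut data` rebinding
// in the merge loops works because borsh deserializers consume from the
// front of a &[u8] and leave the remainder in place; this assumes the
// standard BorshDeserialize::deserialize(&mut &[u8]) behavior:
use borsh::BorshDeserialize;

fn read_header(mut data: &[u8]) -> std::io::Result<(u64, u8)> {
    let len = u64::deserialize(&mut data)?; // consumes 8 bytes
    let flag = u8::deserialize(&mut data)?; // consumes the next byte
    Ok((len, flag))
}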
+ fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - let state = place.get::(); - let mut rhs = borsh_partial_deserialize::(reader)?; - state.merge(&mut rhs); + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } + Ok(()) + } + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let mut rhs = borsh_partial_deserialize::(&mut data)?; + state.merge(&mut rhs); + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let mut rhs = borsh_partial_deserialize::(&mut data)?; + state.merge(&mut rhs); + } + } Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_null_result.rs b/src/query/functions/src/aggregates/aggregate_null_result.rs index 03f9b6a1ce041..92efe72ba8f51 100644 --- a/src/query/functions/src/aggregates/aggregate_null_result.rs +++ b/src/query/functions/src/aggregates/aggregate_null_result.rs @@ -23,8 +23,10 @@ use databend_common_expression::types::DataType; use databend_common_expression::types::ValueType; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; +use databend_common_expression::StateSerdeItem; use super::aggregate_function::AggregateFunction; use super::StateAddr; @@ -86,11 +88,30 @@ impl AggregateFunction for AggregateNullResultFunction { Ok(()) } - fn serialize(&self, _place: AggrState, _writer: &mut Vec) -> Result<()> { + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] + } + + fn batch_serialize( + &self, + places: &[StateAddr], + _loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); + for _ in places { + binary_builder.commit_row(); + } Ok(()) } - fn merge(&self, _place: AggrState, _reader: &mut &[u8]) -> Result<()> { + fn batch_merge( + &self, + _: &[StateAddr], + _: &[AggrStateLoc], + _: &BlockEntry, + _: Option<&Bitmap>, + ) -> Result<()> { Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs b/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs index 5aadd1887b5db..03422b26df1ea 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs @@ -31,10 +31,12 @@ use databend_common_expression::with_decimal_mapped_type; use databend_common_expression::with_number_mapped_type; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; use 
databend_common_expression::ScalarRef; +use databend_common_expression::StateSerdeItem; use itertools::Itertools; use super::borsh_partial_deserialize; @@ -362,15 +364,50 @@ where for<'a> T: AccessType = F64> + Send + Sync }); Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - let state = place.get::(); - Ok(state.serialize(writer)?) + + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - let state = place.get::(); - let mut rhs: QuantileTDigestState = borsh_partial_deserialize(reader)?; - state.merge(&mut rhs) + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } + Ok(()) + } + + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let mut rhs: QuantileTDigestState = borsh_partial_deserialize(&mut data)?; + state.merge(&mut rhs)?; + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let mut rhs: QuantileTDigestState = borsh_partial_deserialize(&mut data)?; + state.merge(&mut rhs)?; + } + } + Ok(()) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs b/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs index 7ac562153fe53..0f56b5c89fd06 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_tdigest_weighted.rs @@ -28,9 +28,11 @@ use databend_common_expression::with_number_mapped_type; use databend_common_expression::with_unsigned_integer_mapped_type; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use num_traits::AsPrimitive; use super::borsh_partial_deserialize; @@ -143,15 +145,50 @@ where }); Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - let state = place.get::(); - Ok(state.serialize(writer)?) 
+ + fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - let state = place.get::(); - let mut rhs: QuantileTDigestState = borsh_partial_deserialize(reader)?; - state.merge(&mut rhs) + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } + Ok(()) + } + + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let mut rhs: QuantileTDigestState = borsh_partial_deserialize(&mut data)?; + state.merge(&mut rhs)?; + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let mut rhs: QuantileTDigestState = borsh_partial_deserialize(&mut data)?; + state.merge(&mut rhs)?; + } + } + Ok(()) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_retention.rs b/src/query/functions/src/aggregates/aggregate_retention.rs index 5f42198bec4c0..44b76bc932cc0 100644 --- a/src/query/functions/src/aggregates/aggregate_retention.rs +++ b/src/query/functions/src/aggregates/aggregate_retention.rs @@ -20,15 +20,19 @@ use borsh::BorshDeserialize; use borsh::BorshSerialize; use databend_common_exception::ErrorCode; use databend_common_exception::Result; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::BooleanType; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; +use databend_common_expression::types::UnaryType; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use super::aggregate_function::AggregateFunction; use super::aggregate_function::AggregateFunctionRef; @@ -143,15 +147,48 @@ impl AggregateFunction for AggregateRetentionFunction { Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - let state = place.get::(); - Ok(state.serialize(writer)?) 
+ fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - let state = place.get::(); - let rhs: AggregateRetentionState = borsh_partial_deserialize(reader)?; - state.merge(&rhs); + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } + Ok(()) + } + + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let rhs: AggregateRetentionState = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs); + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let rhs: AggregateRetentionState = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs); + } + } Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_st_collect.rs b/src/query/functions/src/aggregates/aggregate_st_collect.rs index afdb3d88a373e..9356a10c2f6d7 100644 --- a/src/query/functions/src/aggregates/aggregate_st_collect.rs +++ b/src/query/functions/src/aggregates/aggregate_st_collect.rs @@ -23,17 +23,21 @@ use borsh::BorshSerialize; use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::ArgType; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::GeometryType; +use databend_common_expression::types::UnaryType; use databend_common_expression::types::ValueType; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::Column; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; use databend_common_expression::ScalarRef; +use databend_common_expression::StateSerdeItem; use databend_common_io::ewkb_to_geo; use databend_common_io::geo_to_ewkb; use geo::Geometry; @@ -306,16 +310,49 @@ where Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - let state = place.get::(); - Ok(state.serialize(writer)?) 
+ fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - let state = place.get::(); - let rhs: State = borsh_partial_deserialize(reader)?; + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } + Ok(()) + } - state.merge(&rhs) + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let rhs: State = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs)?; + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let rhs: State = borsh_partial_deserialize(&mut data)?; + state.merge(&rhs)?; + } + } + Ok(()) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_string_agg.rs b/src/query/functions/src/aggregates/aggregate_string_agg.rs index ad72f7cc40c88..540e82d03b64a 100644 --- a/src/query/functions/src/aggregates/aggregate_string_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_string_agg.rs @@ -22,9 +22,11 @@ use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::compute_view::StringConvertView; use databend_common_expression::types::AccessType; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; use databend_common_expression::types::StringType; +use databend_common_expression::types::UnaryType; use databend_common_expression::types::ValueType; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; @@ -36,6 +38,7 @@ use databend_common_expression::Evaluator; use databend_common_expression::FunctionContext; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use super::aggregate_function_factory::AggregateFunctionDescription; use super::borsh_partial_deserialize; @@ -150,15 +153,48 @@ impl AggregateFunction for AggregateStringAggFunction { Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - let state = place.get::(); - Ok(state.serialize(writer)?) 
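// [illustrative sketch, not part of the patch] The views obtained from
// downcast::<UnaryType<BinaryType>>() yield one &[u8] per row, and the
// per-row `mut data` binding lets a reader advance that slice in place.
// A toy version of the advance-in-place pattern:
fn take_prefix<'a>(data: &mut &'a [u8], n: usize) -> &'a [u8] {
    let (head, tail) = data.split_at(n);
    *data = tail; // the caller's slice now starts after the consumed bytes
    head
}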
+ fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - let state = place.get::(); - let rhs: StringAggState = borsh_partial_deserialize(reader)?; - state.values.push_str(&rhs.values); + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); + for place in places { + let state = AggrState::new(*place, loc).get::(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } + Ok(()) + } + + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = AggrState::new(*place, loc).get::(); + let rhs: StringAggState = borsh_partial_deserialize(&mut data)?; + state.values.push_str(&rhs.values); + } + } else { + for (place, mut data) in iter { + let state = AggrState::new(*place, loc).get::(); + let rhs: StringAggState = borsh_partial_deserialize(&mut data)?; + state.values.push_str(&rhs.values); + } + } Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregate_sum.rs b/src/query/functions/src/aggregates/aggregate_sum.rs index 1aeff4f5101b0..0b94e24fd52f7 100644 --- a/src/query/functions/src/aggregates/aggregate_sum.rs +++ b/src/query/functions/src/aggregates/aggregate_sum.rs @@ -28,11 +28,13 @@ use databend_common_expression::types::*; use databend_common_expression::utils::arithmetics_type::ResultTypeOfUnary; use databend_common_expression::with_decimal_mapped_type; use databend_common_expression::with_number_mapped_type; +use databend_common_expression::AggrState; use databend_common_expression::AggregateFunctionRef; use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::Scalar; use databend_common_expression::StateAddr; +use databend_common_expression::StateSerdeItem; use databend_common_expression::SELECTIVITY_THRESHOLD; use num_traits::AsPrimitive; @@ -76,14 +78,14 @@ pub trait SumState: BorshSerialize + BorshDeserialize + Send + Sync + Default + #[derive(BorshSerialize, BorshDeserialize)] pub struct NumberSumState -where N: ValueType +where N: ArgType { pub value: N::Scalar, } impl Default for NumberSumState where - N: ValueType, + N: ArgType, N::Scalar: Number + AsPrimitive + BorshSerialize + BorshDeserialize + std::ops::AddAssign, { fn default() -> Self { @@ -130,7 +132,7 @@ where impl UnaryState for NumberSumState where T: ArgType + Sync + Send, - N: ValueType, + N: ArgType, T::Scalar: Number + AsPrimitive, N::Scalar: Number + AsPrimitive + BorshSerialize + BorshDeserialize + std::ops::AddAssign, for<'a> T::ScalarRef<'a>: Number + AsPrimitive, @@ -169,6 +171,51 @@ where builder.push_item(N::to_scalar_ref(&self.value)); Ok(()) } + + fn serialize_type() -> Vec { + std::vec![StateSerdeItem::DataType(N::data_type())] + } + + fn batch_serialize( + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + let mut builder = N::downcast_builder(&mut builders[0]); + for place in places { + let state: &mut Self = AggrState::new(*place, loc).get(); + builder.push_item(N::to_scalar_ref(&state.value)); + } + Ok(()) + } + + fn 
batch_merge( + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + if let Some(filter) = filter { + for (place, data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let rhs = Self { + value: N::to_owned_scalar(data), + }; + let state: &mut Self = AggrState::new(*place, loc).get(); + >::merge(state, &rhs)?; + } + } else { + for (place, data) in iter { + let rhs = Self { + value: N::to_owned_scalar(data), + }; + let state: &mut Self = AggrState::new(*place, loc).get(); + >::merge(state, &rhs)?; + } + } + Ok(()) + } } #[derive(BorshDeserialize, BorshSerialize)] diff --git a/src/query/functions/src/aggregates/aggregate_unary.rs b/src/query/functions/src/aggregates/aggregate_unary.rs index 20b9fc5f85a70..0b880acb2830c 100644 --- a/src/query/functions/src/aggregates/aggregate_unary.rs +++ b/src/query/functions/src/aggregates/aggregate_unary.rs @@ -21,17 +21,21 @@ use std::sync::Arc; use databend_common_exception::Result; use databend_common_expression::types::AccessType; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::DataType; +use databend_common_expression::types::UnaryType; use databend_common_expression::types::ValueType; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; use databend_common_expression::AggregateFunction; use databend_common_expression::AggregateFunctionRef; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; use databend_common_expression::StateAddr; +use databend_common_expression::StateSerdeItem; use crate::aggregates::AggrState; use crate::aggregates::AggrStateLoc; @@ -78,6 +82,48 @@ where builder: R::ColumnBuilderMut<'_>, function_data: Option<&dyn FunctionData>, ) -> Result<()>; + + fn serialize_type() -> Vec { + vec![StateSerdeItem::Binary(None)] + } + + fn batch_serialize( + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); + for place in places { + let state: &mut Self = AggrState::new(*place, loc).get(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } + Ok(()) + } + + fn batch_merge( + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let rhs = Self::deserialize_reader(&mut data)?; + let state: &mut Self = AggrState::new(*place, loc).get(); + state.merge(&rhs)?; + } + } else { + for (place, mut data) in iter { + let rhs = Self::deserialize_reader(&mut data)?; + let state: &mut Self = AggrState::new(*place, loc).get(); + state.merge(&rhs)?; + } + } + Ok(()) + } } pub trait FunctionData: Send + Sync { @@ -227,15 +273,27 @@ where Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - let state: &mut S = place.get::(); - Ok(state.serialize(writer)?) 
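// [illustrative sketch, not part of the patch] NumberSumState above
// rebuilds a transient rhs state from each wire scalar so that columnar
// merging reuses the exact same merge as state-to-state merging. The
// shape, reduced to a toy state type:
struct ToySum {
    value: i64,
}

impl ToySum {
    fn merge(&mut self, rhs: &ToySum) {
        self.value += rhs.value;
    }
}

fn merge_column(states: &mut [ToySum], wire: &[i64]) {
    for (state, value) in states.iter_mut().zip(wire) {
        let rhs = ToySum { value: *value }; // transient rhs per row
        state.merge(&rhs); // same merge used everywhere
    }
}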
+ fn serialize_type(&self) -> Vec { + S::serialize_type() } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - let state: &mut S = place.get::(); - let rhs = S::deserialize_reader(reader)?; - state.merge(&rhs) + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + S::batch_serialize(places, loc, builders) + } + + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + S::batch_merge(places, loc, state, filter) } fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> { diff --git a/src/query/functions/src/aggregates/aggregate_window_funnel.rs b/src/query/functions/src/aggregates/aggregate_window_funnel.rs index 3d29c1734bff8..ece70c912870d 100644 --- a/src/query/functions/src/aggregates/aggregate_window_funnel.rs +++ b/src/query/functions/src/aggregates/aggregate_window_funnel.rs @@ -26,6 +26,7 @@ use databend_common_exception::Result; use databend_common_expression::types::number::Number; use databend_common_expression::types::number::UInt8Type; use databend_common_expression::types::ArgType; +use databend_common_expression::types::BinaryType; use databend_common_expression::types::Bitmap; use databend_common_expression::types::BooleanType; use databend_common_expression::types::DataType; @@ -33,13 +34,16 @@ use databend_common_expression::types::DateType; use databend_common_expression::types::NumberDataType; use databend_common_expression::types::NumberType; use databend_common_expression::types::TimestampType; +use databend_common_expression::types::UnaryType; use databend_common_expression::types::ValueType; use databend_common_expression::with_integer_mapped_type; use databend_common_expression::AggrStateRegistry; use databend_common_expression::AggrStateType; +use databend_common_expression::BlockEntry; use databend_common_expression::ColumnBuilder; use databend_common_expression::ProjectedBlock; use databend_common_expression::Scalar; +use databend_common_expression::StateSerdeItem; use num_traits::AsPrimitive; use super::borsh_partial_deserialize; @@ -275,15 +279,52 @@ where Ok(()) } - fn serialize(&self, place: AggrState, writer: &mut Vec) -> Result<()> { - let state = place.get::>(); - Ok(state.serialize(writer)?) 
+ fn serialize_type(&self) -> Vec { + vec![StateSerdeItem::Binary(None)] } - fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> { - let state = place.get::>(); - let mut rhs: AggregateWindowFunnelState = borsh_partial_deserialize(reader)?; - state.merge(&mut rhs); + fn batch_serialize( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + builders: &mut [ColumnBuilder], + ) -> Result<()> { + let binary_builder = builders[0].as_binary_mut().unwrap(); + for place in places { + let state = AggrState::new(*place, loc).get::>(); + state.serialize(&mut binary_builder.data)?; + binary_builder.commit_row(); + } + Ok(()) + } + + fn batch_merge( + &self, + places: &[StateAddr], + loc: &[AggrStateLoc], + state: &BlockEntry, + filter: Option<&Bitmap>, + ) -> Result<()> { + let view = state.downcast::>().unwrap(); + let iter = places.iter().zip(view.iter()); + + if let Some(filter) = filter { + for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) { + let state = + AggrState::new(*place, loc).get::>(); + let mut rhs: AggregateWindowFunnelState = + borsh_partial_deserialize(&mut data)?; + state.merge(&mut rhs); + } + } else { + for (place, mut data) in iter { + let state = + AggrState::new(*place, loc).get::>(); + let mut rhs: AggregateWindowFunnelState = + borsh_partial_deserialize(&mut data)?; + state.merge(&mut rhs); + } + } Ok(()) } diff --git a/src/query/functions/src/aggregates/aggregator_common.rs b/src/query/functions/src/aggregates/aggregator_common.rs index 6430bcf3cc647..5fa214d4f89ce 100644 --- a/src/query/functions/src/aggregates/aggregator_common.rs +++ b/src/query/functions/src/aggregates/aggregator_common.rs @@ -186,10 +186,13 @@ pub fn eval_aggr_for_test( let state = AggrState::new(eval.addr, &eval.state_layout.states_loc[0]); func.accumulate(state, entries.into(), None, rows)?; if with_serialize { - let mut buf = vec![]; - func.serialize(state, &mut buf)?; + let data_type = func.serialize_data_type(); + let mut builder = ColumnBuilder::with_capacity(&data_type, 1); + let builders = builder.as_tuple_mut().unwrap().as_mut_slice(); + func.batch_serialize(&[eval.addr], state.loc, builders)?; func.init_state(state); - func.merge(state, &mut buf.as_slice())?; + let column = builder.build(); + func.batch_merge(&[eval.addr], state.loc, &column.into(), None)?; } let mut builder = ColumnBuilder::with_capacity(&data_type, 1024); func.merge_result(state, &mut builder)?; diff --git a/src/query/pipeline/core/src/processors/processor.rs b/src/query/pipeline/core/src/processors/processor.rs index ce70053b80ded..bb71f493a2061 100644 --- a/src/query/pipeline/core/src/processors/processor.rs +++ b/src/query/pipeline/core/src/processors/processor.rs @@ -161,12 +161,24 @@ impl ProcessorPtr { /// # Safety pub unsafe fn process(&self) -> Result<()> { - let mut name = self.name(); - name.push_str("::process"); - let _span = LocalSpan::enter_with_local_parent(name) + let span = LocalSpan::enter_with_local_parent(format!("{}::process", self.name())) .with_property(|| ("graph-node-id", self.id().index().to_string())); - (*self.inner.get()).process() + match (*self.inner.get()).process() { + Ok(_) => Ok(()), + Err(err) => { + let _ = span + .with_property(|| ("error", "true")) + .with_properties(|| { + [ + ("error.type", err.code().to_string()), + ("error.message", err.display_text()), + ] + }); + log::info!(error = err.to_string(); "[PIPELINE-EXECUTOR] Error in process"); + Err(err) + } + } } /// # Safety @@ -190,10 +202,23 @@ impl ProcessorPtr { async move { 
let span = Span::enter_with_local_parent(name) .with_property(|| ("graph-node-id", id.index().to_string())); - task.in_span(span).await?; - drop(inner); - Ok(()) + match task.await { + Ok(_) => { + drop(inner); + Ok(()) + } + Err(err) => { + span.with_property(|| ("error", "true")).add_properties(|| { + [ + ("error.type", err.code().to_string()), + ("error.message", err.display_text()), + ] + }); + log::info!(error = err.to_string(); "[PIPELINE-EXECUTOR] Error in process"); + Err(err) + } + } } .boxed() } diff --git a/src/query/pipeline/core/src/processors/profile.rs b/src/query/pipeline/core/src/processors/profile.rs index 7fb5af4f4e8be..58b0fd98169ab 100644 --- a/src/query/pipeline/core/src/processors/profile.rs +++ b/src/query/pipeline/core/src/processors/profile.rs @@ -43,7 +43,7 @@ impl Drop for PlanScopeGuard { } } -#[derive(serde::Serialize, serde::Deserialize)] +#[derive(Debug, serde::Serialize, serde::Deserialize)] pub struct ErrorInfoDesc { message: String, detail: String, @@ -60,7 +60,7 @@ impl ErrorInfoDesc { } } -#[derive(serde::Serialize, serde::Deserialize)] +#[derive(Debug, serde::Serialize, serde::Deserialize)] pub enum ErrorInfo { Other(ErrorInfoDesc), IoError(ErrorInfoDesc), @@ -68,7 +68,7 @@ pub enum ErrorInfo { CalculationError(ErrorInfoDesc), } -#[derive(Clone, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct PlanProfile { pub id: Option, pub name: Option, diff --git a/src/query/service/src/pipelines/executor/pipeline_executor.rs b/src/query/service/src/pipelines/executor/pipeline_executor.rs index a05d163c32f58..76732fb110220 100644 --- a/src/query/service/src/pipelines/executor/pipeline_executor.rs +++ b/src/query/service/src/pipelines/executor/pipeline_executor.rs @@ -147,6 +147,7 @@ impl PipelineExecutor { } } + #[fastrace::trace(name = "PipelineExecutor::init")] fn init(on_init_callback: &Mutex>, query_id: &Arc) -> Result<()> { // TODO: the on init callback cannot be killed. { @@ -159,10 +160,8 @@ impl PipelineExecutor { } } - info!( - "[PIPELINE-EXECUTOR] Pipeline initialized successfully for query {}, elapsed: {:?}", - query_id, - instant.elapsed() + info!(query_id, elapsed:? = instant.elapsed(); + "[PIPELINE-EXECUTOR] Pipeline initialized successfully", ); } Ok(()) diff --git a/src/query/service/src/pipelines/executor/query_pipeline_executor.rs b/src/query/service/src/pipelines/executor/query_pipeline_executor.rs index 9ff2d9fa7eb24..8cb68ab023dc2 100644 --- a/src/query/service/src/pipelines/executor/query_pipeline_executor.rs +++ b/src/query/service/src/pipelines/executor/query_pipeline_executor.rs @@ -34,7 +34,6 @@ use databend_common_pipeline_core::FinishedCallbackChain; use databend_common_pipeline_core::LockGuard; use databend_common_pipeline_core::Pipeline; use databend_common_pipeline_core::PlanProfile; -use fastrace::func_path; use fastrace::prelude::*; use futures::future::select; use futures_util::future::Either; @@ -267,6 +266,7 @@ impl QueryPipelineExecutor { Ok(()) } + #[fastrace::trace(name = "QueryPipelineExecutor::init")] fn init(self: &Arc, graph: Arc) -> Result<()> { unsafe { // TODO: the on init callback cannot be killed. @@ -287,10 +287,8 @@ impl QueryPipelineExecutor { } } - info!( - "[PIPELINE-EXECUTOR] Pipeline initialized successfully for query {}, elapsed: {:?}", - self.settings.query_id, - instant.elapsed() + info!(query_id = self.settings.query_id, elapsed:? 
= instant.elapsed(); + "[PIPELINE-EXECUTOR] Pipeline initialized successfully", ); } @@ -359,7 +357,7 @@ impl QueryPipelineExecutor { } } - let span = Span::enter_with_local_parent(func_path!()) + let span = Span::enter_with_local_parent("QueryPipelineExecutor::execute_threads") .with_property(|| ("thread_name", name.clone())); thread_join_handles.push(Thread::named_spawn(Some(name), move || unsafe { let _g = span.set_local_parent(); @@ -368,6 +366,12 @@ impl QueryPipelineExecutor { // finish the pipeline executor when has error or panic if let Err(cause) = try_result.flatten() { + span.with_property(|| ("error", "true")).add_properties(|| { + [ + ("error.type", cause.code().to_string()), + ("error.message", cause.display_text()), + ] + }); this.finish(Some(cause)); } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs index b95066fc57d68..905df88ccdcd5 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregate_meta.rs @@ -45,6 +45,7 @@ impl SerializedPayload { entry.as_column().unwrap() } + #[fastrace::trace(name = "SerializedPayload::convert_to_aggregate_table")] pub fn convert_to_aggregate_table( &self, group_types: Vec, diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_aggregate_serializer.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_aggregate_serializer.rs index 096485fa98fcc..3032176fe5de8 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_aggregate_serializer.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_aggregate_serializer.rs @@ -121,26 +121,19 @@ impl Processor for TransformAggregateSerializer { impl TransformAggregateSerializer { fn transform_input_data(&mut self, mut data_block: DataBlock) -> Result { debug_assert!(data_block.is_empty()); - if let Some(block_meta) = data_block.take_meta() { - if let Some(block_meta) = AggregateMeta::downcast_from(block_meta) { - match block_meta { - AggregateMeta::Spilled(_) => unreachable!(), - AggregateMeta::Serialized(_) => unreachable!(), - AggregateMeta::BucketSpilled(_) => unreachable!(), - AggregateMeta::Partitioned { .. 
} => unreachable!(), - AggregateMeta::AggregateSpilling(_) => unreachable!(), - AggregateMeta::AggregatePayload(p) => { - self.input_data = Some(SerializeAggregateStream::create( - &self.params, - SerializePayload::AggregatePayload(p), - )); - return Ok(Event::Sync); - } - } - } - } - unreachable!() + let Some(AggregateMeta::AggregatePayload(p)) = data_block + .take_meta() + .and_then(AggregateMeta::downcast_from) + else { + unreachable!() + }; + + self.input_data = Some(SerializeAggregateStream::create( + &self.params, + SerializePayload::AggregatePayload(p), + )); + Ok(Event::Sync) } } @@ -218,41 +211,29 @@ impl SerializeAggregateStream { return Ok(None); } - match self.payload.as_ref().get_ref() { - SerializePayload::AggregatePayload(p) => { - let block = p.payload.aggregate_flush(&mut self.flush_state)?; - - if block.is_none() { - self.end_iter = true; - } - - match block { - Some(block) => { - self.nums += 1; - Ok(Some(block.add_meta(Some( - AggregateSerdeMeta::create_agg_payload( - p.bucket, - p.max_partition_count, - false, - ), - ))?)) - } - None => { - // always return at least one block - if self.nums == 0 { - self.nums += 1; - let block = p.payload.empty_block(Some(1)); - Ok(Some(block.add_meta(Some( - AggregateSerdeMeta::create_agg_payload( - p.bucket, - p.max_partition_count, - true, - ), - ))?)) - } else { - Ok(None) - } - } + let SerializePayload::AggregatePayload(p) = self.payload.as_ref().get_ref(); + match p.payload.aggregate_flush(&mut self.flush_state)? { + Some(block) => { + self.nums += 1; + Ok(Some(block.add_meta(Some( + AggregateSerdeMeta::create_agg_payload(p.bucket, p.max_partition_count, false), + ))?)) + } + None => { + self.end_iter = true; + // always return at least one block + if self.nums == 0 { + self.nums += 1; + let block = p.payload.empty_block(1); + Ok(Some(block.add_meta(Some( + AggregateSerdeMeta::create_agg_payload( + p.bucket, + p.max_partition_count, + true, + ), + ))?)) + } else { + Ok(None) } } } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs index 09725ee3bb173..9f632ec0b187c 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_deserializer.rs @@ -79,94 +79,77 @@ impl TransformDeserializer { return Ok(DataBlock::new_with_meta(vec![], 0, meta)); } - let data_block = match &meta { - None => { - deserialize_block(dict, fragment_data, &self.schema, self.arrow_schema.clone())? - } - Some(meta) => match AggregateSerdeMeta::downcast_ref_from(meta) { - None => { - deserialize_block(dict, fragment_data, &self.schema, self.arrow_schema.clone())? 
+ let Some(meta) = meta + .as_ref() + .and_then(AggregateSerdeMeta::downcast_ref_from) + else { + let data_block = + deserialize_block(dict, fragment_data, &self.schema, self.arrow_schema.clone())?; + return match data_block.num_columns() == 0 { + true => Ok(DataBlock::new_with_meta(vec![], row_count as usize, meta)), + false => data_block.add_meta(meta), + }; + }; + + match meta.typ { + BUCKET_TYPE => { + let mut block = deserialize_block( + dict, + fragment_data, + &self.schema, + self.arrow_schema.clone(), + )?; + + if meta.is_empty { + block = block.slice(0..0); } - Some(meta) => { - return match meta.typ == BUCKET_TYPE { - true => { - let mut block = deserialize_block( - dict, - fragment_data, - &self.schema, - self.arrow_schema.clone(), - )?; - - if meta.is_empty { - block = block.slice(0..0); - } - - Ok(DataBlock::empty_with_meta( - AggregateMeta::create_serialized( - meta.bucket, - block, - meta.max_partition_count, - ), - )) - } - false => { - let data_schema = Arc::new(exchange_defines::spilled_schema()); - let arrow_schema = Arc::new(exchange_defines::spilled_arrow_schema()); - let data_block = deserialize_block( - dict, - fragment_data, - &data_schema, - arrow_schema.clone(), - )?; - - let columns = data_block - .columns() - .iter() - .map(|c| c.as_column().unwrap().clone()) - .collect::>(); - - let buckets = - NumberType::::try_downcast_column(&columns[0]).unwrap(); - let data_range_start = - NumberType::::try_downcast_column(&columns[1]).unwrap(); - let data_range_end = - NumberType::::try_downcast_column(&columns[2]).unwrap(); - let columns_layout = - ArrayType::::try_downcast_column(&columns[3]).unwrap(); - - let columns_layout_data = columns_layout.values().as_slice(); - let columns_layout_offsets = columns_layout.offsets(); - - let mut buckets_payload = Vec::with_capacity(data_block.num_rows()); - for index in 0..data_block.num_rows() { - unsafe { - buckets_payload.push(BucketSpilledPayload { - bucket: *buckets.get_unchecked(index) as isize, - location: meta.location.clone().unwrap(), - data_range: *data_range_start.get_unchecked(index) - ..*data_range_end.get_unchecked(index), - columns_layout: columns_layout_data[columns_layout_offsets - [index] - as usize - ..columns_layout_offsets[index + 1] as usize] - .to_vec(), - max_partition_count: meta.max_partition_count, - }); - } - } - - Ok(DataBlock::empty_with_meta(AggregateMeta::create_spilled( - buckets_payload, - ))) - } - }; + + Ok(DataBlock::empty_with_meta( + AggregateMeta::create_serialized(meta.bucket, block, meta.max_partition_count), + )) + } + _ => { + let data_schema = Arc::new(exchange_defines::spilled_schema()); + let arrow_schema = Arc::new(exchange_defines::spilled_arrow_schema()); + let data_block = + deserialize_block(dict, fragment_data, &data_schema, arrow_schema.clone())?; + + let columns = data_block + .columns() + .iter() + .map(|c| c.as_column().unwrap().clone()) + .collect::>(); + + let buckets = NumberType::::try_downcast_column(&columns[0]).unwrap(); + let data_range_start = NumberType::::try_downcast_column(&columns[1]).unwrap(); + let data_range_end = NumberType::::try_downcast_column(&columns[2]).unwrap(); + let columns_layout = + ArrayType::::try_downcast_column(&columns[3]).unwrap(); + + let columns_layout_data = columns_layout.values().as_slice(); + let columns_layout_offsets = columns_layout.offsets(); + + let mut buckets_payload = Vec::with_capacity(data_block.num_rows()); + for index in 0..data_block.num_rows() { + unsafe { + buckets_payload.push(BucketSpilledPayload { + bucket: 
+                            bucket: *buckets.get_unchecked(index) as isize,
+                            location: meta.location.clone().unwrap(),
+                            data_range: *data_range_start.get_unchecked(index)
+                                ..*data_range_end.get_unchecked(index),
+                            columns_layout: columns_layout_data[columns_layout_offsets[index]
+                                as usize
+                                ..columns_layout_offsets[index + 1] as usize]
+                                .to_vec(),
+                            max_partition_count: meta.max_partition_count,
+                        });
+                    }
+                }
-            },
-        };
 
-        match data_block.num_columns() == 0 {
-            true => Ok(DataBlock::new_with_meta(vec![], row_count as usize, meta)),
-            false => data_block.add_meta(meta),
+
+                Ok(DataBlock::empty_with_meta(AggregateMeta::create_spilled(
+                    buckets_payload,
+                )))
+            }
+        }
     }
 }
@@ -177,15 +160,9 @@ impl BlockMetaTransform<ExchangeDeserializeMeta> for TransformDeserializer {
     fn transform(&mut self, mut meta: ExchangeDeserializeMeta) -> Result<Vec<DataBlock>> {
         match meta.packet.pop().unwrap() {
-            DataPacket::ErrorCode(v) => Err(v),
-            DataPacket::Dictionary(_) => unreachable!(),
-            DataPacket::QueryProfiles(_) => unreachable!(),
-            DataPacket::SerializeProgress { .. } => unreachable!(),
-            DataPacket::CopyStatus { .. } => unreachable!(),
-            DataPacket::MutationStatus { .. } => unreachable!(),
-            DataPacket::DataCacheMetrics(_) => unreachable!(),
             DataPacket::FragmentData(v) => Ok(vec![self.recv_data(meta.packet, v)?]),
-            DataPacket::QueryPerf(_) => unreachable!(),
+            DataPacket::ErrorCode(err) => Err(err),
+            _ => unreachable!(),
         }
     }
 }
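The deserializer rewrite above is mostly control-flow flattening: the nested matches over an optional AggregateSerdeMeta become one `let ... else` early return for the plain-block path, so the bucket and spill paths sit at the top indentation level. A minimal standalone sketch of that shape, with a toy `Meta` type standing in for the meta downcast (nothing here is Databend code):

use std::string::String;

#[derive(Debug)]
enum Meta {
    Serde { bucket: isize },
    Other(&'static str),
}

impl Meta {
    // Stand-in for AggregateSerdeMeta::downcast_ref_from.
    fn as_serde(&self) -> Option<isize> {
        match self {
            Meta::Serde { bucket } => Some(*bucket),
            Meta::Other(_) => None,
        }
    }
}

fn recv(meta: Option<Meta>) -> String {
    // The generic fallback exits immediately; the specialized paths below
    // it are no longer buried two matches deep.
    let Some(bucket) = meta.as_ref().and_then(Meta::as_serde) else {
        return "plain block".to_string();
    };
    format!("aggregate payload for bucket {bucket}")
}

fn main() {
    assert_eq!(recv(Some(Meta::Serde { bucket: 3 })), "aggregate payload for bucket 3");
    assert_eq!(recv(Some(Meta::Other("profile"))), "plain block");
    assert_eq!(recv(None), "plain block");
}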
diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_aggregate_serializer.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_aggregate_serializer.rs
index ffd80674bdab4..118f8119fed5e 100644
--- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_aggregate_serializer.rs
+++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_aggregate_serializer.rs
@@ -125,12 +125,7 @@ impl BlockMetaTransform<ExchangeShuffleMeta> for TransformExchangeAggregateSerializer {
                 continue;
             }
 
-            match AggregateMeta::downcast_from(block.take_meta().unwrap()) {
-                None => unreachable!(),
-                Some(AggregateMeta::Spilled(_)) => unreachable!(),
-                Some(AggregateMeta::Serialized(_)) => unreachable!(),
-                Some(AggregateMeta::BucketSpilled(_)) => unreachable!(),
-                Some(AggregateMeta::Partitioned { .. }) => unreachable!(),
+            match block.take_meta().and_then(AggregateMeta::downcast_from) {
                 Some(AggregateMeta::AggregateSpilling(payload)) => {
                     serialized_blocks.push(FlightSerialized::Future(
                         match index == self.local_pos {
@@ -172,6 +167,8 @@
                     let c = serialize_block(block_number, c, &self.options)?;
                     serialized_blocks.push(FlightSerialized::DataBlock(c));
                 }
+
+                _ => unreachable!(),
             };
         }
 
diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_async_barrier.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_async_barrier.rs
index 1628bc9af5beb..eaebda0543e6d 100644
--- a/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_async_barrier.rs
+++ b/src/query/service/src/pipelines/processors/transforms/aggregator/serde/transform_exchange_async_barrier.rs
@@ -45,29 +45,31 @@ impl AsyncTransform for TransformExchangeAsyncBarrier {
     const NAME: &'static str = "TransformExchangeAsyncBarrier";
 
     async fn transform(&mut self, mut data: DataBlock) -> Result<DataBlock> {
-        if let Some(meta) = data
+        let Some(meta) = data
             .take_meta()
             .and_then(FlightSerializedMeta::downcast_from)
-        {
-            let mut futures = Vec::with_capacity(meta.serialized_blocks.len());
+        else {
+            return Err(ErrorCode::Internal(""));
+        };
 
-            for serialized_block in meta.serialized_blocks {
-                futures.push(databend_common_base::runtime::spawn(async move {
-                    match serialized_block {
-                        FlightSerialized::DataBlock(v) => Ok(v),
-                        FlightSerialized::Future(f) => f.await,
-                    }
-                }));
-            }
-
-            return match futures::future::try_join_all(futures).await {
-                Err(_) => Err(ErrorCode::TokioError("Cannot join tokio job")),
-                Ok(spilled_data) => Ok(DataBlock::empty_with_meta(ExchangeShuffleMeta::create(
-                    spilled_data.into_iter().collect::<Result<Vec<_>>>()?,
-                ))),
-            };
+        let mut futures = Vec::with_capacity(meta.serialized_blocks.len());
+        for serialized_block in meta.serialized_blocks {
+            futures.push(databend_common_base::runtime::spawn(async move {
+                match serialized_block {
+                    FlightSerialized::DataBlock(v) => Ok(v),
+                    FlightSerialized::Future(f) => f.await,
+                }
+            }));
         }
 
-        Err(ErrorCode::Internal(""))
+        match futures::future::try_join_all(futures).await {
+            Err(_) => Err(ErrorCode::TokioError("Cannot join tokio job")),
+            Ok(spilled_data) => {
+                let blocks = spilled_data.into_iter().collect::<Result<Vec<_>>>()?;
+                Ok(DataBlock::empty_with_meta(ExchangeShuffleMeta::create(
+                    blocks,
+                )))
+            }
+        }
     }
 }
diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs
index e40e7c80e2cd8..f856c17bbe135 100644
--- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs
+++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs
@@ -150,7 +150,7 @@ impl TransformPartialAggregate {
                 .params
                 .states_layout
                 .as_ref()
-                .map(|layout| layout.states_loc.len())
+                .map(|layout| layout.num_aggr_func())
                 .unwrap_or(0);
             (
                 vec![],
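The async barrier above keeps its behavior through the refactor: spawn one task per serialized block, join them all, and collapse `Vec<Result<_>>` into `Result<Vec<_>>` so the first failed block aborts the whole batch. A standalone model of that join-then-collect pattern using std threads instead of tokio tasks (every name here is invented for illustration):

use std::thread;

fn main() {
    // Each boxed job stands in for one FlightSerialized block: either already
    // serialized (immediate Ok) or still pending (work done on the worker).
    let jobs: Vec<Box<dyn FnOnce() -> Result<u32, String> + Send>> = vec![
        Box::new(|| Ok(1)),
        Box::new(|| Ok(2 + 2)),
        Box::new(|| Ok(3)),
    ];

    let handles: Vec<_> = jobs.into_iter().map(thread::spawn).collect();

    // A panicked worker maps to the "Cannot join tokio job" arm above; a job's
    // own Err surfaces through the final collect into Result<Vec<_>>.
    let joined: Vec<Result<u32, String>> = handles
        .into_iter()
        .map(|h| h.join().unwrap_or_else(|_| Err("cannot join worker".into())))
        .collect();

    let blocks: Result<Vec<u32>, String> = joined.into_iter().collect();
    println!("{blocks:?}"); // Ok([1, 4, 3])
}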
diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs
index a43003ba291b4..19acf639becc0 100644
--- a/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs
+++ b/src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs
@@ -24,7 +24,6 @@ use databend_common_exception::ErrorCode;
 use databend_common_exception::Result;
 use databend_common_expression::AggrState;
 use databend_common_expression::BlockMetaInfoDowncast;
-use databend_common_expression::Column;
 use databend_common_expression::ColumnBuilder;
 use databend_common_expression::DataBlock;
 use databend_common_expression::ProjectedBlock;
@@ -94,29 +93,29 @@ impl AccumulatingTransform for PartialSingleStateAggregator {
             self.first_block_start = Some(Instant::now());
         }
 
-        let is_agg_index_block = block
+        let meta = block
             .get_meta()
             .and_then(AggIndexMeta::downcast_ref_from)
-            .map(|index| index.is_agg)
-            .unwrap_or_default();
+            .copied();
 
         let block = block.consume_convert_to_full();
-        if is_agg_index_block {
+        if let Some(meta) = meta
+            && meta.is_agg
+        {
+            assert_eq!(self.states_layout.num_aggr_func(), meta.num_agg_funcs);
             // Aggregation states are in the back of the block.
-            let states_indices = (block.num_columns() - self.states_layout.states_loc.len()
-                ..block.num_columns())
-                .collect::<Vec<usize>>();
+            let start = block.num_columns() - self.states_layout.num_aggr_func();
+            let states_indices = (start..block.num_columns()).collect::<Vec<usize>>();
             let states = ProjectedBlock::project(&states_indices, &block);
 
-            for ((place, func), state) in self
+            for ((loc, func), state) in self
                 .states_layout
                 .states_loc
                 .iter()
-                .map(|loc| AggrState::new(self.addr, loc))
                 .zip(self.funcs.iter())
                 .zip(states.iter())
             {
-                func.batch_merge_single(place, state)?;
+                func.batch_merge(&[self.addr], loc, state, None)?;
             }
         } else {
             for ((place, columns), func) in self
@@ -145,25 +144,17 @@ impl AccumulatingTransform for PartialSingleStateAggregator {
         let blocks = if generate_data {
             let mut builders = self.states_layout.serialize_builders(1);
 
-            for ((func, place), builder) in self
+            for ((func, loc), builder) in self
                 .funcs
                 .iter()
-                .zip(
-                    self.states_layout
-                        .states_loc
-                        .iter()
-                        .map(|loc| AggrState::new(self.addr, loc)),
-                )
+                .zip(self.states_layout.states_loc.iter())
                 .zip(builders.iter_mut())
             {
-                func.serialize(place, &mut builder.data)?;
-                builder.commit_row();
+                let builders = builder.as_tuple_mut().unwrap().as_mut_slice();
+                func.batch_serialize(&[self.addr], loc, builders)?;
             }
 
-            let columns = builders
-                .into_iter()
-                .map(|b| Column::Binary(b.build()))
-                .collect();
+            let columns = builders.into_iter().map(|b| b.build()).collect();
             vec![DataBlock::new_from_columns(columns)]
         } else {
             vec![]
@@ -249,20 +240,9 @@ impl AccumulatingTransform for FinalSingleStateAggregator {
             let main_addr: StateAddr = self.arena.alloc_layout(self.states_layout.layout).into();
 
-            let main_places = self
-                .funcs
-                .iter()
-                .zip(
-                    self.states_layout
-                        .states_loc
-                        .iter()
-                        .map(|loc| AggrState::new(main_addr, loc)),
-                )
-                .map(|(func, place)| {
-                    func.init_state(place);
-                    place
-                })
-                .collect::<Vec<_>>();
+            for (func, loc) in self.funcs.iter().zip(self.states_layout.states_loc.iter()) {
+                func.init_state(AggrState::new(main_addr, loc));
+            }
 
             let mut result_builders = self
                 .funcs
@@ -270,25 +250,25 @@ impl AccumulatingTransform for FinalSingleStateAggregator {
                 .map(|f| Ok(ColumnBuilder::with_capacity(&f.return_type()?, 1)))
                 .collect::<Result<Vec<_>>>()?;
 
-            for (idx, ((func, place), builder)) in self
+            for (idx, ((func, loc), builder)) in self
                 .funcs
                 .iter()
-                .zip(main_places.iter().copied())
+                .zip(self.states_layout.states_loc.iter())
                 .zip(result_builders.iter_mut())
                 .enumerate()
             {
                 for block in self.to_merge_data.iter() {
-                    func.batch_merge_single(place, block.get_by_offset(idx))?;
+                    func.batch_merge(&[main_addr], loc, block.get_by_offset(idx), None)?;
                 }
 
-                func.merge_result(place, builder)?;
+                func.merge_result(AggrState::new(main_addr, loc), builder)?;
             }
 
             let columns = result_builders.into_iter().map(|b| b.build()).collect();
 
             // destroy states
-            for (place, func) in main_places.iter().copied().zip(self.funcs.iter()) {
+            for (func, loc) in self.funcs.iter().zip(self.states_layout.states_loc.iter()) {
                 if func.need_manual_drop_state() {
-                    unsafe { func.drop_state(place) }
+                    unsafe { func.drop_state(AggrState::new(main_addr, loc)) }
                 }
             }
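transform_single_key.rs above now drives its states through the columnar serde API, and the udaf_script.rs diff that follows implements the same pair of methods for scripted UDAFs: per-row `serialize(place, &mut Vec<u8>)` and `merge(place, &mut &[u8])` give way to `batch_serialize` over a slice of tuple-field ColumnBuilders and `batch_merge` over a whole column with an optional selection bitmap. A standalone toy model of that flow with a two-field sum state; every name below is invented for illustration and is not a Databend API:

#[derive(Default, Clone, Copy, Debug)]
struct SumState {
    sum: u64,
    valid: bool,
}

/// One typed column per tuple field, appended in lockstep, mirroring the
/// diff's `builder.as_tuple_mut()` slice of builders.
fn batch_serialize(states: &[SumState]) -> (Vec<u64>, Vec<bool>) {
    let mut sums = Vec::with_capacity(states.len());
    let mut valids = Vec::with_capacity(states.len());
    for s in states {
        sums.push(s.sum);
        valids.push(s.valid);
    }
    (sums, valids)
}

/// Merge serialized rows into destination states, skipping rows the optional
/// filter marks false: the shape of the new batch_merge(places, loc, state, filter).
fn batch_merge(dst: &mut [SumState], sums: &[u64], valids: &[bool], filter: Option<&[bool]>) {
    for i in 0..dst.len() {
        if filter.map_or(true, |f| f[i]) && valids[i] {
            dst[i].sum += sums[i];
            dst[i].valid = true;
        }
    }
}

fn main() {
    let upstream = [
        SumState { sum: 3, valid: true },
        SumState { sum: 0, valid: false },
        SumState { sum: 7, valid: true },
    ];
    let (sums, valids) = batch_serialize(&upstream);

    let mut local = [SumState::default(); 3];
    batch_merge(&mut local, &sums, &valids, Some(&[true, true, false]));
    assert_eq!(local[0].sum, 3);
    assert!(!local[2].valid); // row 2 was filtered out
    println!("{local:?}");
}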
diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs
index d92ace4e212fd..5ddf677b0cd7c 100644
--- a/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs
+++ b/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs
@@ -28,18 +28,24 @@
 use databend_common_exception::ErrorCode;
 use databend_common_exception::Result;
 use databend_common_expression::converts::arrow::ARROW_EXT_TYPE_VARIANT;
 use databend_common_expression::converts::arrow::EXTENSION_KEY;
+use databend_common_expression::types::BinaryType;
 use databend_common_expression::types::Bitmap;
 use databend_common_expression::types::DataType;
+use databend_common_expression::types::UnaryType;
 use databend_common_expression::AggrState;
 use databend_common_expression::AggrStateRegistry;
 use databend_common_expression::AggrStateType;
+use databend_common_expression::BlockEntry;
 use databend_common_expression::Column;
 use databend_common_expression::ColumnBuilder;
 use databend_common_expression::DataBlock;
 use databend_common_expression::DataField;
 use databend_common_expression::DataSchema;
 use databend_common_expression::ProjectedBlock;
+use databend_common_expression::StateSerdeItem;
+use databend_common_functions::aggregates::AggrStateLoc;
 use databend_common_functions::aggregates::AggregateFunction;
+use databend_common_functions::aggregates::StateAddr;
 use databend_common_sql::plans::UDFLanguage;
 use databend_common_sql::plans::UDFScriptCode;
@@ -98,23 +104,62 @@ impl AggregateFunction for AggregateUdfScript {
         Ok(())
     }
 
-    fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        let state = place.get::<UdfAggState>();
-        state
-            .serialize(writer)
-            .map_err(|e| ErrorCode::Internal(format!("state failed to serialize: {e}")))
+    fn serialize_type(&self) -> Vec<StateSerdeItem> {
+        vec![StateSerdeItem::Binary(None)]
     }
 
-    fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        let state = place.get::<UdfAggState>();
-        let rhs =
-            UdfAggState::deserialize(reader).map_err(|e| ErrorCode::Internal(e.to_string()))?;
-        let states = arrow_select::concat::concat(&[&state.0, &rhs.0])?;
-        let state = self
-            .runtime
-            .merge(&states)
-            .map_err(|e| ErrorCode::UDFRuntimeError(format!("failed to merge: {e}")))?;
-        place.write(|| state);
+    fn batch_serialize(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        builders: &mut [ColumnBuilder],
+    ) -> Result<()> {
+        let binary_builder = builders[0].as_binary_mut().unwrap();
+        for place in places {
+            let state = AggrState::new(*place, loc).get::<UdfAggState>();
+            state
+                .serialize(&mut binary_builder.data)
+                .map_err(|e| ErrorCode::Internal(format!("state failed to serialize: {e}")))?;
+            binary_builder.commit_row();
+        }
+        Ok(())
+    }
+
+    fn batch_merge(
+        &self,
+        places: &[StateAddr],
+        loc: &[AggrStateLoc],
+        state: &BlockEntry,
+        filter: Option<&Bitmap>,
+    ) -> Result<()> {
+        let view = state.downcast::<UnaryType<BinaryType>>().unwrap();
+        let iter = places.iter().zip(view.iter());
+
+        if let Some(filter) = filter {
+            for (place, mut data) in iter.zip(filter.iter()).filter_map(|(v, b)| b.then_some(v)) {
+                let state = AggrState::new(*place, loc).get::<UdfAggState>();
+                let rhs = UdfAggState::deserialize(&mut data)
+                    .map_err(|e| ErrorCode::Internal(e.to_string()))?;
+                let states = arrow_select::concat::concat(&[&state.0, &rhs.0])?;
+                let merged_state = self
+                    .runtime
+                    .merge(&states)
+                    .map_err(|e| ErrorCode::UDFRuntimeError(format!("failed to merge: {e}")))?;
+                AggrState::new(*place, loc).write(|| merged_state);
+            }
+        } else {
+            for (place, mut data) in iter {
+                let state = AggrState::new(*place, loc).get::<UdfAggState>();
+                let rhs = UdfAggState::deserialize(&mut data)
+                    .map_err(|e| ErrorCode::Internal(e.to_string()))?;
+                let states = arrow_select::concat::concat(&[&state.0, &rhs.0])?;
+                let merged_state = self
+                    .runtime
+                    .merge(&states)
+                    .map_err(|e| ErrorCode::UDFRuntimeError(format!("failed to merge: {e}")))?;
+                AggrState::new(*place, loc).write(|| merged_state);
+            }
+        }
         Ok(())
     }
diff --git a/src/query/service/src/servers/flight/v1/packets/packet_data.rs b/src/query/service/src/servers/flight/v1/packets/packet_data.rs
index 2b5407ffb2946..81226501438f8 100644
--- a/src/query/service/src/servers/flight/v1/packets/packet_data.rs
+++ b/src/query/service/src/servers/flight/v1/packets/packet_data.rs
@@ -54,6 +54,7 @@ impl Debug for FragmentData {
     }
 }
 
+#[derive(Debug)]
 pub enum DataPacket {
     ErrorCode(ErrorCode),
     Dictionary(FlightData),
diff --git a/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs b/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs
index a8a73071aaa57..ab915b3c727e5 100644
--- a/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs
+++ b/src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs
@@ -14,11 +14,10 @@
 use databend_common_exception::Result;
 use databend_common_expression::types::DataType;
-#[allow(unused_imports)]
-use databend_common_expression::DataBlock;
 use databend_common_expression::DataField;
 use databend_common_expression::DataSchemaRef;
 use databend_common_expression::DataSchemaRefExt;
+use databend_common_functions::aggregates::AggregateFunctionFactory;
 
 use super::SortDesc;
 use crate::executor::explain::PlanStatsInfo;
@@ -47,11 +46,30 @@ impl AggregatePartial {
         let input_schema = self.input.output_schema()?;
         let mut fields = Vec::with_capacity(self.agg_funcs.len() + self.group_by.len());
+        let factory = AggregateFunctionFactory::instance();
 
-        fields.extend(self.agg_funcs.iter().map(|func| {
-            let name = func.output_column.to_string();
-            DataField::new(&name, DataType::Binary)
-        }));
+        for desc in &self.agg_funcs {
+            let name = desc.output_column.to_string();
+
+            if desc.sig.udaf.is_some() {
+                fields.push(DataField::new(
+                    &name,
+                    DataType::Tuple(vec![DataType::Binary]),
+                ));
+                continue;
+            }
+
+            let func = factory
+                .get(
+                    &desc.sig.name,
+                    desc.sig.params.clone(),
+                    desc.sig.args.clone(),
+                    desc.sig.sort_descs.clone(),
+                )
+                .unwrap();
+
+            fields.push(DataField::new(&name, func.serialize_data_type()))
+        }
 
         for (idx, field) in self.group_by.iter().zip(
             self.group_by
diff --git a/src/query/sql/src/planner/optimizer/optimizers/operator/aggregate/stats_aggregate.rs b/src/query/sql/src/planner/optimizer/optimizers/operator/aggregate/stats_aggregate.rs
index 03a604ff86ebc..9f8eb81c06b31 100644
---
a/src/query/sql/src/planner/optimizer/optimizers/operator/aggregate/stats_aggregate.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/operator/aggregate/stats_aggregate.rs @@ -129,32 +129,48 @@ impl RuleStatsAggregateOptimizer { for (need_rewrite_agg, agg) in need_rewrite_aggs.iter().zip(agg.aggregate_functions.iter()) { - let agg_func = AggregateFunction::try_from(agg.scalar.clone())?; - - if let Some((col_id, name)) = need_rewrite_agg { - if let Some(stat) = stats.get(col_id) { - let value_bound = if name.eq_ignore_ascii_case("min") { - &stat.min - } else { - &stat.max - }; - if !value_bound.may_be_truncated { - let scalar = ScalarExpr::ConstantExpr(ConstantExpr { - span: agg.scalar.span(), - value: value_bound.value.clone(), - }); - - let scalar = - scalar.unify_to_data_type(agg_func.return_type.as_ref()); - - eval_scalar_results.push(ScalarItem { - index: agg.index, - scalar, - }); - continue; + let column = if let ScalarExpr::UDAFCall(udaf) = &agg.scalar { + ColumnBindingBuilder::new( + udaf.display_name.clone(), + agg.index, + udaf.return_type.clone(), + Visibility::Visible, + ) + .build() + } else { + let agg_func = AggregateFunction::try_from(agg.scalar.clone())?; + if let Some((col_id, name)) = need_rewrite_agg { + if let Some(stat) = stats.get(col_id) { + let value_bound = if name.eq_ignore_ascii_case("min") { + &stat.min + } else { + &stat.max + }; + if !value_bound.may_be_truncated { + let scalar = ScalarExpr::ConstantExpr(ConstantExpr { + span: agg.scalar.span(), + value: value_bound.value.clone(), + }); + + let scalar = scalar + .unify_to_data_type(agg_func.return_type.as_ref()); + + eval_scalar_results.push(ScalarItem { + index: agg.index, + scalar, + }); + continue; + } } } - } + ColumnBindingBuilder::new( + agg_func.display_name.clone(), + agg.index, + agg_func.return_type.clone(), + Visibility::Visible, + ) + .build() + }; // Add other aggregate functions as derived column, // this will be used in aggregating index rewrite. @@ -162,13 +178,7 @@ impl RuleStatsAggregateOptimizer { index: agg.index, scalar: ScalarExpr::BoundColumnRef(BoundColumnRef { span: agg.scalar.span(), - column: ColumnBindingBuilder::new( - agg_func.display_name.clone(), - agg.index, - agg_func.return_type.clone(), - Visibility::Visible, - ) - .build(), + column, }), }); agg_results.push(agg.clone()); diff --git a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/query_rewrite.rs b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/query_rewrite.rs index 22885188c78ee..09e75a0a9f594 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/query_rewrite.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/query_rewrite.rs @@ -25,6 +25,7 @@ use databend_common_expression::types::DataType; use databend_common_expression::Scalar; use databend_common_expression::TableField; use databend_common_expression::TableSchemaRefExt; +use databend_common_functions::aggregates::AggregateFunctionFactory; use itertools::Itertools; use log::info; @@ -306,6 +307,12 @@ impl QueryInfo { } return Ok(None); } + ScalarExpr::UDAFCall(udaf) => { + for arg in &udaf.arguments { + self.check_output_cols(arg, index_output_cols, new_selection_set)?; + } + return Ok(None); + } ScalarExpr::UDFCall(udf) => { let mut valid = true; let mut new_args = Vec::with_capacity(udf.arguments.len()); @@ -361,23 +368,38 @@ impl ViewInfo { // query can use those columns to compute expressions. 
 let mut index_fields = Vec::with_capacity(query_info.output_cols.len());
         let mut index_output_cols = HashMap::with_capacity(query_info.output_cols.len());
+        let factory = AggregateFunctionFactory::instance();
         for (index, item) in query_info.output_cols.iter().enumerate() {
             let display_name = format_scalar(&item.scalar, &query_info.column_map);
-            let mut is_agg = false;
-            if let Some(ref aggregate) = query_info.aggregate {
-                for agg_func in &aggregate.aggregate_functions {
-                    if item.index == agg_func.index {
-                        is_agg = true;
-                        break;
-                    }
+            let aggr_scalar_item = query_info.aggregate.as_ref().and_then(|aggregate| {
+                aggregate
+                    .aggregate_functions
+                    .iter()
+                    .find(|agg_func| agg_func.index == item.index)
+            });
+
+            let (data_type, is_agg) = match aggr_scalar_item {
+                Some(item) => {
+                    let func = match &item.scalar {
+                        ScalarExpr::AggregateFunction(func) => func,
+                        _ => unreachable!(),
+                    };
+                    let func = factory.get(
+                        &func.func_name,
+                        func.params.clone(),
+                        func.args
+                            .iter()
+                            .map(|arg| arg.data_type())
+                            .collect::<Result<Vec<_>>>()?,
+                        func.sort_descs
+                            .iter()
+                            .map(|desc| desc.try_into())
+                            .collect::<Result<Vec<_>>>()?,
+                    )?;
+                    (func.serialize_data_type(), true)
                 }
-            }
-            // we store the value of aggregate function as binary data.
-            let data_type = if is_agg {
-                DataType::Binary
-            } else {
-                item.scalar.data_type().unwrap()
+                None => (item.scalar.data_type().unwrap(), false),
             };
 
             let name = format!("{index}");
@@ -1299,6 +1321,15 @@ fn format_scalar(scalar: &ScalarExpr, column_map: &HashMap<IndexType, ScalarExpr>) -> String {
+        ScalarExpr::UDAFCall(udaf) => {
+            let args = udaf
+                .arguments
+                .iter()
+                .map(|arg| format_scalar(arg, column_map))
+                .collect::<Vec<_>>()
+                .join(", ");
+            format!("{}({})", &udaf.name, args)
+        }
         ScalarExpr::UDFCall(udf) => format!(
             "{}({})",
             &udf.handler,
diff --git a/src/query/storages/fuse/src/operations/analyze.rs b/src/query/storages/fuse/src/operations/analyze.rs
index b6d951a84545b..84dc9817826c5 100644
--- a/src/query/storages/fuse/src/operations/analyze.rs
+++ b/src/query/storages/fuse/src/operations/analyze.rs
@@ -186,8 +186,8 @@ impl SinkAnalyzeState {
             let index: u32 = name.strip_prefix("ndv_").unwrap().parse().unwrap();
 
             let col = col.index(0).unwrap();
-            let col = col.as_binary().unwrap();
-            let hll: MetaHLL = borsh_deserialize_from_slice(col)?;
+            let data = col.as_tuple().unwrap()[0].as_binary().unwrap();
+            let hll: MetaHLL = borsh_deserialize_from_slice(data)?;
 
             if !is_full {
                 ndv_states
diff --git a/tests/sqllogictests/suites/ee/02_ee_aggregating_index/02_0001_agg_index_projected_scan.test b/tests/sqllogictests/suites/ee/02_ee_aggregating_index/02_0001_agg_index_projected_scan.test
index d9e8286a48e69..f7b09f81240b3 100644
--- a/tests/sqllogictests/suites/ee/02_ee_aggregating_index/02_0001_agg_index_projected_scan.test
+++ b/tests/sqllogictests/suites/ee/02_ee_aggregating_index/02_0001_agg_index_projected_scan.test
@@ -75,7 +75,44 @@ SELECT b, SUM(a) from t WHERE c > 1 GROUP BY b ORDER BY b
 ----
 1 1
 2 3
 
-query IIT
+query IIR
 SELECT MAX(a), MIN(b), AVG(c) from t
 ----
 2 1 3.5
+
+statement ok
+CREATE or REPLACE FUNCTION weighted_avg (INT, INT) STATE {sum INT, weight INT} RETURNS FLOAT
+LANGUAGE javascript AS $$
+export function create_state() {
+    return {sum: 0, weight: 0};
+}
+export function accumulate(state, value, weight) {
+    state.sum += value * weight;
+    state.weight += weight;
+    return state;
+}
+export function retract(state, value, weight) {
+    state.sum -= value * weight;
+    state.weight -= weight;
+    return state;
+}
+export function merge(state1, state2) {
+    state1.sum += state2.sum;
+    state1.weight += state2.weight;
+    return state1;
+}
+export
function finish(state) { + return state.sum / state.weight; +} +$$; + +query IIR +SELECT MAX(a), MIN(b), weighted_avg(a,b) from t group by b; +---- +2 2 1.3333334 +1 1 1.0 + +query IIR +SELECT MAX(a), MIN(b), weighted_avg(a,b) from t; +---- +2 1 1.2857143 diff --git a/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_mix.test b/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_mix.test index e865f23303e72..85166f60d5b24 100644 --- a/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_mix.test +++ b/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_mix.test @@ -578,8 +578,8 @@ INSERT INTO convert_string (c1, c2, c3) VALUES ('a', 2, 2), ('b', 1, 2); -query IT -select c1, string_agg(json_object('col', c2, 'no2', c3), ',') from convert_string group by c1 order by c1; +query IT rowsort +select c1, string_agg(json_object('col', c2, 'no2', c3), ',') from convert_string group by c1; ---- a {"col":1,"no2":1},{"col":2,"no2":1},{"col":1,"no2":1},{"col":1,"no2":2},{"col":2,"no2":2} b {"col":1,"no2":1},{"col":2,"no2":1},{"col":1,"no2":2} diff --git a/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_state.test b/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_state.test index bd2c249067d25..823f71a00e9ff 100644 --- a/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_state.test +++ b/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_state.test @@ -1,9 +1,9 @@ query IT -select length(max_state(number)), typeof(max_state(number)) from numbers(100); +select length(max_state(number).1), typeof(max_state(number)) from numbers(100); ---- -10 BINARY +9 TUPLE(BINARY, BOOLEAN) -query I -select length(sum_state(number)), typeof(max_state(number)) from numbers(10000); +query IT +select sum_state(number).1, typeof(sum_state(number)) from numbers(10000); ---- -9 BINARY +49995000 TUPLE(UINT64, BOOLEAN)
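The test expectations above (TUPLE(BINARY, BOOLEAN) for max_state, TUPLE(UINT64, BOOLEAN) for sum_state, and TUPLE(BINARY) for the JavaScript UDAF) follow from the planner change in physical_aggregate_partial.rs: each aggregate's declared StateSerdeItems map to the element types of a tuple column, instead of everything collapsing to one opaque BINARY. A standalone toy model of that mapping; the `DataType`-carrying variant is an assumption for illustration, since the visible hunks only show `StateSerdeItem::Binary(None)`:

#[derive(Debug, Clone, PartialEq)]
enum DataType {
    Binary,
    UInt64,
    Boolean,
    Tuple(Vec<DataType>),
}

#[derive(Debug, Clone)]
enum StateSerdeItem {
    // Optional fixed-size hint for binary states, as in the diff's Binary(None).
    Binary(Option<usize>),
    // Assumed variant for typed state fields; not shown in the visible hunks.
    DataType(DataType),
}

// Toy counterpart of the trait's serialize_data_type default method.
fn serialize_data_type(items: &[StateSerdeItem]) -> DataType {
    DataType::Tuple(
        items
            .iter()
            .map(|item| match item {
                StateSerdeItem::Binary(_) => DataType::Binary,
                StateSerdeItem::DataType(ty) => ty.clone(),
            })
            .collect(),
    )
}

fn main() {
    // A scripted UDAF keeps an opaque binary state: TUPLE(BINARY).
    let udaf = [StateSerdeItem::Binary(None)];
    // sum_state(number) in the test above reports TUPLE(UINT64, BOOLEAN).
    let sum = [
        StateSerdeItem::DataType(DataType::UInt64),
        StateSerdeItem::DataType(DataType::Boolean),
    ];
    println!("{:?}", serialize_data_type(&udaf));
    println!("{:?}", serialize_data_type(&sum));
}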