Skip to content

Commit 72f3a29

Browse files
committed
fix
1 parent 828da4a commit 72f3a29

File tree

6 files changed

+60
-48
lines changed

6 files changed

+60
-48
lines changed

src/query/catalog/src/plan/agg_index.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ pub struct AggIndexInfo {
5151
}
5252

5353
/// This meta just indicate the block is from aggregating index.
54-
#[derive(Debug, Clone)]
54+
#[derive(Debug, Clone, Copy)]
5555
pub struct AggIndexMeta {
5656
pub is_agg: bool,
5757
// Number of aggregation functions.
@@ -75,6 +75,6 @@ local_block_meta_serde!(AggIndexMeta);
7575
#[typetag::serde(name = "agg_index_meta")]
7676
impl BlockMetaInfo for AggIndexMeta {
7777
fn clone_self(&self) -> Box<dyn BlockMetaInfo> {
78-
Box::new(self.clone())
78+
Box::new(*self)
7979
}
8080
}

src/query/expression/src/aggregate/aggregate_function_state.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,10 @@ pub struct StatesLayout {
228228
}
229229

230230
impl StatesLayout {
231+
pub fn num_aggr_func(&self) -> usize {
232+
self.states_loc.len()
233+
}
234+
231235
pub fn serialize_builders(&self, num_rows: usize) -> Vec<ColumnBuilder> {
232236
self.serialize_type
233237
.iter()

src/query/service/src/pipelines/processors/transforms/aggregator/transform_aggregate_partial.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ impl TransformPartialAggregate {
150150
.params
151151
.states_layout
152152
.as_ref()
153-
.map(|layout| layout.states_loc.len())
153+
.map(|layout| layout.num_aggr_func())
154154
.unwrap_or(0);
155155
(
156156
vec![],

src/query/service/src/pipelines/processors/transforms/aggregator/transform_single_key.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,18 +93,19 @@ impl AccumulatingTransform for PartialSingleStateAggregator {
9393
self.first_block_start = Some(Instant::now());
9494
}
9595

96-
let is_agg_index_block = block
96+
let meta = block
9797
.get_meta()
9898
.and_then(AggIndexMeta::downcast_ref_from)
99-
.map(|index| index.is_agg)
100-
.unwrap_or_default();
99+
.copied();
101100

102101
let block = block.consume_convert_to_full();
103-
if is_agg_index_block {
102+
if let Some(meta) = meta
103+
&& meta.is_agg
104+
{
105+
assert_eq!(self.states_layout.num_aggr_func(), meta.num_agg_funcs);
104106
// Aggregation states are in the back of the block.
105-
let states_indices = (block.num_columns() - self.states_layout.states_loc.len()
106-
..block.num_columns())
107-
.collect::<Vec<_>>();
107+
let start = block.num_columns() - self.states_layout.num_aggr_func();
108+
let states_indices = (start..block.num_columns()).collect::<Vec<_>>();
108109
let states = ProjectedBlock::project(&states_indices, &block);
109110

110111
for ((place, func), state) in self

src/query/sql/src/planner/optimizer/optimizers/operator/aggregate/stats_aggregate.rs

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -129,51 +129,56 @@ impl RuleStatsAggregateOptimizer {
129129
for (need_rewrite_agg, agg) in
130130
need_rewrite_aggs.iter().zip(agg.aggregate_functions.iter())
131131
{
132-
if matches!(agg.scalar, ScalarExpr::UDAFCall(_)) {
133-
agg_results.push(agg.clone());
134-
continue;
135-
}
136-
137-
let agg_func = AggregateFunction::try_from(agg.scalar.clone())?;
138-
139-
if let Some((col_id, name)) = need_rewrite_agg {
140-
if let Some(stat) = stats.get(col_id) {
141-
let value_bound = if name.eq_ignore_ascii_case("min") {
142-
&stat.min
143-
} else {
144-
&stat.max
145-
};
146-
if !value_bound.may_be_truncated {
147-
let scalar = ScalarExpr::ConstantExpr(ConstantExpr {
148-
span: agg.scalar.span(),
149-
value: value_bound.value.clone(),
150-
});
151-
152-
let scalar =
153-
scalar.unify_to_data_type(agg_func.return_type.as_ref());
154-
155-
eval_scalar_results.push(ScalarItem {
156-
index: agg.index,
157-
scalar,
158-
});
159-
continue;
132+
let column = if let ScalarExpr::UDAFCall(udaf) = &agg.scalar {
133+
ColumnBindingBuilder::new(
134+
udaf.display_name.clone(),
135+
agg.index,
136+
udaf.return_type.clone(),
137+
Visibility::Visible,
138+
)
139+
.build()
140+
} else {
141+
let agg_func = AggregateFunction::try_from(agg.scalar.clone())?;
142+
if let Some((col_id, name)) = need_rewrite_agg {
143+
if let Some(stat) = stats.get(col_id) {
144+
let value_bound = if name.eq_ignore_ascii_case("min") {
145+
&stat.min
146+
} else {
147+
&stat.max
148+
};
149+
if !value_bound.may_be_truncated {
150+
let scalar = ScalarExpr::ConstantExpr(ConstantExpr {
151+
span: agg.scalar.span(),
152+
value: value_bound.value.clone(),
153+
});
154+
155+
let scalar = scalar
156+
.unify_to_data_type(agg_func.return_type.as_ref());
157+
158+
eval_scalar_results.push(ScalarItem {
159+
index: agg.index,
160+
scalar,
161+
});
162+
continue;
163+
}
160164
}
161165
}
162-
}
166+
ColumnBindingBuilder::new(
167+
agg_func.display_name.clone(),
168+
agg.index,
169+
agg_func.return_type.clone(),
170+
Visibility::Visible,
171+
)
172+
.build()
173+
};
163174

164175
// Add other aggregate functions as derived column,
165176
// this will be used in aggregating index rewrite.
166177
eval_scalar_results.push(ScalarItem {
167178
index: agg.index,
168179
scalar: ScalarExpr::BoundColumnRef(BoundColumnRef {
169180
span: agg.scalar.span(),
170-
column: ColumnBindingBuilder::new(
171-
agg_func.display_name.clone(),
172-
agg.index,
173-
agg_func.return_type.clone(),
174-
Visibility::Visible,
175-
)
176-
.build(),
181+
column,
177182
}),
178183
});
179184
agg_results.push(agg.clone());

tests/sqllogictests/suites/ee/02_ee_aggregating_index/02_0001_agg_index_projected_scan.test

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,5 +112,7 @@ SELECT MAX(a), MIN(b), weighted_avg(a,b) from t group by b;
112112
2 2 1.3333334
113113
1 1 1.0
114114

115-
# fix me
116-
# SELECT MAX(a), MIN(b), weighted_avg(a,b) from t;
115+
query IIR
116+
SELECT MAX(a), MIN(b), weighted_avg(a,b) from t;
117+
----
118+
2 1 1.2857143

0 commit comments

Comments
 (0)