@@ -28,15 +28,13 @@ use datafusion_common::{
28
28
JoinConstraint , Result ,
29
29
} ;
30
30
use datafusion_expr:: expr_rewriter:: replace_col;
31
- use datafusion_expr:: logical_plan:: {
32
- CrossJoin , Join , JoinType , LogicalPlan , TableScan , Union ,
33
- } ;
31
+ use datafusion_expr:: logical_plan:: { CrossJoin , Join , JoinType , LogicalPlan , Union } ;
34
32
use datafusion_expr:: utils:: {
35
33
conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
36
34
} ;
37
35
use datafusion_expr:: {
38
36
and, build_join_schema, or, BinaryExpr , Expr , Filter , LogicalPlanBuilder , Operator ,
39
- Projection , TableProviderFilterPushDown ,
37
+ Projection , TableProviderFilterPushDown , TableScan ,
40
38
} ;
41
39
42
40
use crate :: optimizer:: ApplyOrder ;
@@ -897,23 +895,106 @@ impl OptimizerRule for PushDownFilter {
897
895
. map ( |( pred, _) | pred) ;
898
896
let new_scan_filters: Vec < Expr > =
899
897
new_scan_filters. unique ( ) . cloned ( ) . collect ( ) ;
898
+
899
+ let source_schema = scan. source . schema ( ) ;
900
+ let mut additional_projection = HashSet :: new ( ) ;
900
901
let new_predicate: Vec < Expr > = zip
901
- . filter ( |( _, res) | res != & TableProviderFilterPushDown :: Exact )
902
+ . filter ( |( expr, res) | {
903
+ if * res == TableProviderFilterPushDown :: Exact {
904
+ return false ;
905
+ }
906
+ expr. apply ( |expr| {
907
+ if let Expr :: Column ( column) = expr {
908
+ if let Ok ( idx) = source_schema. index_of ( column. name ( ) ) {
909
+ if scan
910
+ . projection
911
+ . as_ref ( )
912
+ . is_some_and ( |p| !p. contains ( & idx) )
913
+ {
914
+ additional_projection. insert ( idx) ;
915
+ }
916
+ }
917
+ }
918
+ Ok ( TreeNodeRecursion :: Continue )
919
+ } )
920
+ . unwrap ( ) ;
921
+ true
922
+ } )
902
923
. map ( |( pred, _) | pred. clone ( ) )
903
924
. collect ( ) ;
904
925
905
- let new_scan = LogicalPlan :: TableScan ( TableScan {
906
- filters : new_scan_filters,
907
- ..scan
908
- } ) ;
926
+ let scan_source = Arc :: clone ( & scan. source ) ;
927
+ let scan_table_name = scan. table_name . clone ( ) ;
909
928
910
- Transformed :: yes ( new_scan) . transform_data ( |new_scan| {
911
- if let Some ( predicate) = conjunction ( new_predicate) {
912
- make_filter ( predicate, Arc :: new ( new_scan) ) . map ( Transformed :: yes)
929
+ // Wraps with a filter if some filters are not supported exactly.
930
+ let filtered = move |plan| {
931
+ if let Some ( new_predicate) = conjunction ( new_predicate) {
932
+ Filter :: try_new ( new_predicate, Arc :: new ( plan) )
933
+ . map ( LogicalPlan :: Filter )
913
934
} else {
914
- Ok ( Transformed :: no ( new_scan ) )
935
+ Ok ( plan )
915
936
}
916
- } )
937
+ } ;
938
+
939
+ if additional_projection. is_empty ( ) {
940
+ // No additional projection is required.
941
+ let new_scan = LogicalPlan :: TableScan ( TableScan {
942
+ filters : new_scan_filters,
943
+ ..scan
944
+ } ) ;
945
+ return filtered ( new_scan) . map ( Transformed :: yes) ;
946
+ }
947
+
948
+ let new_scan = filtered (
949
+ LogicalPlanBuilder :: scan_with_filters_fetch (
950
+ scan_table_name. clone ( ) ,
951
+ scan. source ,
952
+ scan. projection . clone ( ) . map ( |mut projection| {
953
+ // Extend a projection.
954
+ projection. extend ( additional_projection) ;
955
+ projection
956
+ } ) ,
957
+ new_scan_filters,
958
+ scan. fetch ,
959
+ ) ?
960
+ . build ( ) ?,
961
+ ) ?;
962
+
963
+ // Project fields required by the initial projection.
964
+ let source_schema = scan_source. schema ( ) ;
965
+ let new_plan = LogicalPlan :: Projection ( Projection :: try_new_with_schema (
966
+ scan. projection
967
+ . as_ref ( )
968
+ . map ( |projection| {
969
+ projection
970
+ . into_iter ( )
971
+ . cloned ( )
972
+ . map ( |idx| {
973
+ Expr :: Column ( Column :: new (
974
+ Some ( scan_table_name. clone ( ) ) ,
975
+ source_schema. field ( idx) . name ( ) ,
976
+ ) )
977
+ } )
978
+ . collect ( )
979
+ } )
980
+ . unwrap_or_else ( || {
981
+ source_schema
982
+ . fields ( )
983
+ . iter ( )
984
+ . map ( |field| {
985
+ Expr :: Column ( Column :: new (
986
+ Some ( scan_table_name. clone ( ) ) ,
987
+ field. name ( ) ,
988
+ ) )
989
+ } )
990
+ . collect ( )
991
+ } ) ,
992
+ Arc :: new ( new_scan) ,
993
+ // Preserve a projected schema.
994
+ scan. projected_schema ,
995
+ ) ?) ;
996
+
997
+ Ok ( Transformed :: yes ( new_plan) )
917
998
}
918
999
LogicalPlan :: Extension ( extension_plan) => {
919
1000
let prevent_cols =
@@ -1206,8 +1287,8 @@ mod tests {
1206
1287
use datafusion_expr:: logical_plan:: table_scan;
1207
1288
use datafusion_expr:: {
1208
1289
col, in_list, in_subquery, lit, ColumnarValue , Extension , ScalarUDF ,
1209
- ScalarUDFImpl , Signature , TableSource , TableType , UserDefinedLogicalNodeCore ,
1210
- Volatility ,
1290
+ ScalarUDFImpl , Signature , TableScan , TableSource , TableType ,
1291
+ UserDefinedLogicalNodeCore , Volatility ,
1211
1292
} ;
1212
1293
1213
1294
use crate :: optimizer:: Optimizer ;
@@ -2452,6 +2533,33 @@ mod tests {
2452
2533
. build ( )
2453
2534
}
2454
2535
2536
+ #[ test]
2537
+ fn projection_is_updated_when_filter_becomes_unsupported ( ) -> Result < ( ) > {
2538
+ let test_provider = PushDownProvider {
2539
+ filter_support : TableProviderFilterPushDown :: Unsupported ,
2540
+ } ;
2541
+
2542
+ let projeted_schema = test_provider. schema ( ) . project ( & [ 0 ] ) ?;
2543
+ let table_scan = LogicalPlan :: TableScan ( TableScan {
2544
+ table_name : "test" . into ( ) ,
2545
+ // Emulate that there were pushed filters but now
2546
+ // provider cannot support it.
2547
+ filters : vec ! [ col( "b" ) . eq( lit( 1i64 ) ) ] ,
2548
+ projected_schema : Arc :: new ( DFSchema :: try_from ( projeted_schema) ?) ,
2549
+ projection : Some ( vec ! [ 0 ] ) ,
2550
+ source : Arc :: new ( test_provider) ,
2551
+ fetch : None ,
2552
+ } ) ;
2553
+
2554
+ let plan = LogicalPlanBuilder :: from ( table_scan)
2555
+ . filter ( col ( "a" ) . eq ( lit ( 1i64 ) ) ) ?
2556
+ . build ( ) ?;
2557
+
2558
+ let expected = "\
2559
+ Projection: test.a\n Filter: a = Int64(1) AND b = Int64(1)\n TableScan: test projection=[a, b]";
2560
+ assert_optimized_plan_eq ( plan, expected)
2561
+ }
2562
+
2455
2563
#[ test]
2456
2564
fn filter_with_table_provider_exact ( ) -> Result < ( ) > {
2457
2565
let plan = table_scan_with_pushdown_provider ( TableProviderFilterPushDown :: Exact ) ?;
@@ -2514,7 +2622,7 @@ mod tests {
2514
2622
projected_schema : Arc :: new ( DFSchema :: try_from (
2515
2623
( * test_provider. schema ( ) ) . clone ( ) ,
2516
2624
) ?) ,
2517
- projection : Some ( vec ! [ 0 ] ) ,
2625
+ projection : Some ( vec ! [ 0 , 1 ] ) ,
2518
2626
source : Arc :: new ( test_provider) ,
2519
2627
fetch : None ,
2520
2628
} ) ;
0 commit comments