From 1409ec7418bd2c07e738de0058e2868bc0a37856 Mon Sep 17 00:00:00 2001 From: wszhdshys <1925792291@qq.com> Date: Thu, 24 Jul 2025 11:14:28 +0800 Subject: [PATCH 1/2] Using the RBO method, the related subquery is implemented --- src/binder/expr.rs | 13 +- src/binder/mod.rs | 15 +- src/db.rs | 5 + src/optimizer/heuristic/graph.rs | 3 +- .../rule/normalization/correlated_subquery.rs | 245 ++++++++++++++++++ src/optimizer/rule/normalization/mod.rs | 5 + .../rule/normalization/simplification.rs | 4 +- tests/slt/correlated_subquery.slt | 112 ++++++++ tests/slt/crdb/update.slt | 2 +- 9 files changed, 395 insertions(+), 9 deletions(-) create mode 100644 src/optimizer/rule/normalization/correlated_subquery.rs create mode 100644 tests/slt/correlated_subquery.slt diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 890d6d62..a7bd9ebf 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -7,7 +7,7 @@ use sqlparser::ast::{ BinaryOperator, CharLengthUnits, DataType, Expr, Function, FunctionArg, FunctionArgExpr, Ident, Query, UnaryOperator, Value, }; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::slice; use std::sync::Arc; @@ -293,6 +293,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T self.args, Some(self), ); + let mut sub_query = binder.bind_query(subquery)?; let sub_query_schema = sub_query.output_schema(); @@ -368,7 +369,15 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T try_default!(&full_name.0, full_name.1); } if let Some(table) = full_name.0.or(bind_table_name) { - let source = self.context.bind_source(&table)?; + let (source, is_parent) = self.context.bind_source::(self.parent, &table, false)?; + + if is_parent { + self.parent_table_col + .entry(Arc::new(table.clone())) + .or_default() + .insert(full_name.1.clone()); + } + let schema_buf = self.table_schema_buf.entry(Arc::new(table)).or_default(); Ok(ScalarExpression::ColumnRef( diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 9be10c73..b132f474 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -26,7 +26,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use crate::catalog::view::View; -use crate::catalog::{ColumnRef, TableCatalog, TableName}; +use crate::catalog::{ColumnCatalog, ColumnRef, TableCatalog, TableName}; use crate::db::{ScalaFunctions, TableFunctions}; use crate::errors::DatabaseError; use crate::expression::ScalarExpression; @@ -276,12 +276,19 @@ impl<'a, T: Transaction> BinderContext<'a, T> { Ok(source) } - pub fn bind_source<'b: 'a>(&self, table_name: &str) -> Result<&Source, DatabaseError> { + pub fn bind_source<'b: 'a, A: AsRef<[(&'static str, DataValue)]>>( + &self, + parent: Option<&'a Binder<'a, 'b, T, A>>, + table_name: &str, + is_parent: bool, + ) -> Result<(&'b Source, bool), DatabaseError> { if let Some(source) = self.bind_table.iter().find(|((t, alias, _), _)| { t.as_str() == table_name || matches!(alias.as_ref().map(|a| a.as_str() == table_name), Some(true)) }) { - Ok(source.1) + Ok((source.1, is_parent)) + } else if let Some(binder) = parent { + binder.context.bind_source(binder.parent, table_name, true) } else { Err(DatabaseError::InvalidTable(table_name.into())) } @@ -323,6 +330,7 @@ pub struct Binder<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> args: &'a A, with_pk: Option, pub(crate) parent: Option<&'b Binder<'a, 'b, T, A>>, + pub(crate) parent_table_col: HashMap>, } impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, 'b, T, A> { @@ -337,6 +345,7 @@ impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, ' args, with_pk: None, parent, + parent_table_col: Default::default(), } } diff --git a/src/db.rs b/src/db.rs index ff08a4fa..ca60c0eb 100644 --- a/src/db.rs +++ b/src/db.rs @@ -173,6 +173,11 @@ impl State { pub(crate) fn default_optimizer(source_plan: LogicalPlan) -> HepOptimizer { HepOptimizer::new(source_plan) + .batch( + "Correlated Subquery".to_string(), + HepBatchStrategy::once_topdown(), + vec![NormalizationRuleImpl::CorrelateSubquery], + ) .batch( "Column Pruning".to_string(), HepBatchStrategy::once_topdown(), diff --git a/src/optimizer/heuristic/graph.rs b/src/optimizer/heuristic/graph.rs index f6de8c61..6695bfdc 100644 --- a/src/optimizer/heuristic/graph.rs +++ b/src/optimizer/heuristic/graph.rs @@ -79,7 +79,7 @@ impl HepGraph { source_id: HepNodeId, children_option: Option, new_node: Operator, - ) { + ) -> HepNodeId { let new_index = self.graph.add_node(new_node); let mut order = self.graph.edges(source_id).count(); @@ -95,6 +95,7 @@ impl HepGraph { self.graph.add_edge(source_id, new_index, order); self.version += 1; + new_index } pub fn replace_node(&mut self, source_id: HepNodeId, new_node: Operator) { diff --git a/src/optimizer/rule/normalization/correlated_subquery.rs b/src/optimizer/rule/normalization/correlated_subquery.rs new file mode 100644 index 00000000..32eaca39 --- /dev/null +++ b/src/optimizer/rule/normalization/correlated_subquery.rs @@ -0,0 +1,245 @@ +use crate::catalog::{ColumnRef, TableName}; +use crate::errors::DatabaseError; +use crate::expression::visitor::Visitor; +use crate::expression::HasCountStar; +use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; +use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; +use crate::planner::operator::join::{JoinCondition, JoinOperator, JoinType}; +use crate::planner::operator::table_scan::TableScanOperator; +use crate::planner::operator::Operator; +use crate::planner::operator::Operator::{Join, TableScan}; +use crate::types::index::IndexInfo; +use crate::types::ColumnId; +use itertools::Itertools; +use std::collections::BTreeMap; +use std::collections::{HashMap, HashSet}; +use std::sync::{Arc, LazyLock}; + +static CORRELATED_SUBQUERY_RULE: LazyLock = LazyLock::new(|| Pattern { + predicate: |op| matches!(op, Join(_)), + children: PatternChildrenPredicate::None, +}); + +#[derive(Clone)] +pub struct CorrelatedSubquery; + +macro_rules! trans_references { + ($columns:expr) => {{ + let mut column_references = HashSet::with_capacity($columns.len()); + for column in $columns { + column_references.insert(column); + } + column_references + }}; +} + +impl CorrelatedSubquery { + fn _apply( + column_references: HashSet<&ColumnRef>, + scan_columns: HashMap, HashMap, Vec)>, + node_id: HepNodeId, + graph: &mut HepGraph, + ) -> Result< + HashMap, HashMap, Vec)>, + DatabaseError, + > { + let operator = &graph.operator(node_id).clone(); + + match operator { + Operator::Aggregate(op) => { + let is_distinct = op.is_distinct; + let referenced_columns = operator.referenced_columns(false); + let mut new_column_references = trans_references!(&referenced_columns); + // on distinct + if is_distinct { + for summary in column_references { + new_column_references.insert(summary); + } + } + + Self::recollect_apply(new_column_references, scan_columns, node_id, graph) + } + Operator::Project(op) => { + let mut has_count_star = HasCountStar::default(); + for expr in &op.exprs { + has_count_star.visit(expr)?; + } + let referenced_columns = operator.referenced_columns(false); + let new_column_references = trans_references!(&referenced_columns); + + Self::recollect_apply(new_column_references, scan_columns, node_id, graph) + } + Operator::TableScan(op) => { + let table_column: HashSet<&ColumnRef> = op.columns.values().collect(); + let mut new_scan_columns = scan_columns.clone(); + new_scan_columns.insert( + op.table_name.clone(), + ( + op.primary_keys.clone(), + op.columns + .iter() + .map(|(num, col)| (col.id().unwrap(), *num)) + .collect(), + op.index_infos.clone(), + ), + ); + let mut parent_col = HashMap::new(); + for col in column_references { + match ( + table_column.contains(col), + scan_columns.get(col.table_name().unwrap_or(&Arc::new("".to_string()))), + ) { + (false, Some(..)) => { + parent_col + .entry(col.table_name().unwrap()) + .or_insert(HashSet::new()) + .insert(col); + } + _ => continue, + } + } + for (table_name, table_columns) in parent_col { + let table_columns = table_columns.into_iter().collect_vec(); + let (primary_keys, columns, index_infos) = + scan_columns.get(table_name).unwrap(); + let map: BTreeMap = table_columns + .into_iter() + .map(|col| (*columns.get(&col.id().unwrap()).unwrap(), col.clone())) + .collect(); + let left_operator = graph.operator(node_id).clone(); + let right_operator = TableScan(TableScanOperator { + table_name: table_name.clone(), + primary_keys: primary_keys.clone(), + columns: map, + limit: (None, None), + index_infos: index_infos.clone(), + with_pk: false, + }); + let join_operator = Join(JoinOperator { + on: JoinCondition::None, + join_type: JoinType::Cross, + }); + + match &left_operator { + TableScan(_) => { + graph.replace_node(node_id, join_operator); + graph.add_node(node_id, None, left_operator); + graph.add_node(node_id, None, right_operator); + } + Join(_) => { + let left_id = graph.eldest_child_at(node_id).unwrap(); + let left_id = graph.add_node(node_id, Some(left_id), join_operator); + graph.add_node(left_id, None, right_operator); + } + _ => unreachable!(), + } + } + Ok(new_scan_columns) + } + Operator::Sort(_) | Operator::Limit(_) | Operator::Filter(_) | Operator::Union(_) => { + let mut new_scan_columns = scan_columns.clone(); + let temp_columns = operator.referenced_columns(false); + // why? + let mut column_references = column_references; + for column in temp_columns.iter() { + column_references.insert(column); + } + for child_id in graph.children_at(node_id).collect_vec() { + let copy_references = column_references.clone(); + let copy_scan = scan_columns.clone(); + if let Ok(scan) = Self::_apply(copy_references, copy_scan, child_id, graph) { + new_scan_columns.extend(scan); + }; + } + Ok(new_scan_columns) + } + Operator::Join(_) => { + let mut new_scan_columns = scan_columns.clone(); + for child_id in graph.children_at(node_id).collect_vec() { + let copy_references = column_references.clone(); + let copy_scan = new_scan_columns.clone(); + if let Ok(scan) = Self::_apply(copy_references, copy_scan, child_id, graph) { + new_scan_columns.extend(scan); + }; + } + Ok(new_scan_columns) + } + // Last Operator + Operator::Dummy | Operator::Values(_) | Operator::FunctionScan(_) => Ok(scan_columns), + Operator::Explain => { + if let Some(child_id) = graph.eldest_child_at(node_id) { + Self::_apply(column_references, scan_columns, child_id, graph) + } else { + unreachable!() + } + } + // DDL Based on Other Plan + Operator::Insert(_) + | Operator::Update(_) + | Operator::Delete(_) + | Operator::Analyze(_) => { + let referenced_columns = operator.referenced_columns(false); + let new_column_references = trans_references!(&referenced_columns); + + if let Some(child_id) = graph.eldest_child_at(node_id) { + Self::recollect_apply(new_column_references, scan_columns, child_id, graph) + } else { + unreachable!(); + } + } + // DDL Single Plan + Operator::CreateTable(_) + | Operator::CreateIndex(_) + | Operator::CreateView(_) + | Operator::DropTable(_) + | Operator::DropView(_) + | Operator::DropIndex(_) + | Operator::Truncate(_) + | Operator::ShowTable + | Operator::ShowView + | Operator::CopyFromFile(_) + | Operator::CopyToFile(_) + | Operator::AddColumn(_) + | Operator::DropColumn(_) + | Operator::Describe(_) => Ok(scan_columns), + } + } + + fn recollect_apply( + referenced_columns: HashSet<&ColumnRef>, + scan_columns: HashMap, HashMap, Vec)>, + node_id: HepNodeId, + graph: &mut HepGraph, + ) -> Result< + HashMap, HashMap, Vec)>, + DatabaseError, + > { + let mut new_scan_columns = scan_columns.clone(); + for child_id in graph.children_at(node_id).collect_vec() { + let copy_references = referenced_columns.clone(); + let copy_scan = scan_columns.clone(); + + if let Ok(scan) = Self::_apply(copy_references, copy_scan, child_id, graph) { + new_scan_columns.extend(scan); + }; + } + Ok(new_scan_columns) + } +} + +impl MatchPattern for CorrelatedSubquery { + fn pattern(&self) -> &Pattern { + &CORRELATED_SUBQUERY_RULE + } +} + +impl NormalizationRule for CorrelatedSubquery { + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { + Self::_apply(HashSet::new(), HashMap::new(), node_id, graph)?; + // mark changed to skip this rule batch + graph.version += 1; + + Ok(()) + } +} diff --git a/src/optimizer/rule/normalization/mod.rs b/src/optimizer/rule/normalization/mod.rs index 8c30fd4a..b7fc6d4b 100644 --- a/src/optimizer/rule/normalization/mod.rs +++ b/src/optimizer/rule/normalization/mod.rs @@ -10,6 +10,7 @@ use crate::optimizer::rule::normalization::combine_operators::{ use crate::optimizer::rule::normalization::compilation_in_advance::{ EvaluatorBind, ExpressionRemapper, }; +use crate::optimizer::rule::normalization::correlated_subquery::CorrelatedSubquery; use crate::optimizer::rule::normalization::pushdown_limit::{ LimitProjectTranspose, PushLimitIntoScan, PushLimitThroughJoin, }; @@ -21,6 +22,7 @@ use crate::optimizer::rule::normalization::simplification::SimplifyFilter; mod column_pruning; mod combine_operators; mod compilation_in_advance; +mod correlated_subquery; mod pushdown_limit; mod pushdown_predicates; mod simplification; @@ -32,6 +34,7 @@ pub enum NormalizationRuleImpl { CollapseProject, CollapseGroupByAgg, CombineFilter, + CorrelateSubquery, // PushDown limit LimitProjectTranspose, PushLimitThroughJoin, @@ -55,6 +58,7 @@ impl MatchPattern for NormalizationRuleImpl { NormalizationRuleImpl::CollapseProject => CollapseProject.pattern(), NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.pattern(), NormalizationRuleImpl::CombineFilter => CombineFilter.pattern(), + NormalizationRuleImpl::CorrelateSubquery => CorrelatedSubquery.pattern(), NormalizationRuleImpl::LimitProjectTranspose => LimitProjectTranspose.pattern(), NormalizationRuleImpl::PushLimitThroughJoin => PushLimitThroughJoin.pattern(), NormalizationRuleImpl::PushLimitIntoTableScan => PushLimitIntoScan.pattern(), @@ -75,6 +79,7 @@ impl NormalizationRule for NormalizationRuleImpl { NormalizationRuleImpl::CollapseProject => CollapseProject.apply(node_id, graph), NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.apply(node_id, graph), NormalizationRuleImpl::CombineFilter => CombineFilter.apply(node_id, graph), + NormalizationRuleImpl::CorrelateSubquery => CorrelatedSubquery.apply(node_id, graph), NormalizationRuleImpl::LimitProjectTranspose => { LimitProjectTranspose.apply(node_id, graph) } diff --git a/src/optimizer/rule/normalization/simplification.rs b/src/optimizer/rule/normalization/simplification.rs index a220c065..02d66ef5 100644 --- a/src/optimizer/rule/normalization/simplification.rs +++ b/src/optimizer/rule/normalization/simplification.rs @@ -298,7 +298,7 @@ mod test { op: BinaryOperator::Plus, left_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from( c1_col - ))), + ),)), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(1))), evaluator: None, ty: LogicalType::Integer, @@ -306,7 +306,7 @@ mod test { evaluator: None, ty: LogicalType::Integer, }), - right_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from(c2_col))), + right_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from(c2_col),)), evaluator: None, ty: LogicalType::Boolean, } diff --git a/tests/slt/correlated_subquery.slt b/tests/slt/correlated_subquery.slt new file mode 100644 index 00000000..1a6c0d62 --- /dev/null +++ b/tests/slt/correlated_subquery.slt @@ -0,0 +1,112 @@ +statement ok +CREATE TABLE t1 (id INT PRIMARY KEY, v1 VARCHAR(50), v2 INT); + +statement ok +CREATE TABLE t2 (id INT PRIMARY KEY, v1 VARCHAR(50), v2 INT); + +statement ok +CREATE TABLE t3 (id INT PRIMARY KEY, v1 INT, v2 INT); + +statement ok +insert into t1(id, v1, v2) values (1,'a',9) + +statement ok +insert into t1(id, v1, v2) values (2,'b',6) + +statement ok +insert into t1(id, v1, v2) values (3,'c',11) + +statement ok +insert into t2(id, v1, v2) values (1,'A',10) + +statement ok +insert into t2(id, v1, v2) values (2,'B',11) + +statement ok +insert into t2(id, v1, v2) values (3,'C',9) + +statement ok +insert into t3(id, v1, v2) values (1,6,10) + +statement ok +insert into t3(id, v1, v2) values (2,5,10) + +statement ok +insert into t3(id, v1, v2) values (3,4,10) + +query IT rowsort +SELECT id, v1 FROM t1 WHERE id IN ( SELECT t2.id FROM t2 WHERE t2.v2 < t1.v2 ) +---- +1 a +3 c + +query I rowsort +SELECT v1 FROM t1 WHERE EXISTS ( SELECT 1 FROM t3 WHERE t3.id = t1.id ) +---- +a +b +c + +query TT rowsort +SELECT t1.v1, t2.v1 FROM t1 JOIN t2 ON t1.id = t2.id WHERE t2.v2 > ( SELECT AVG(v2) FROM t1 ) +---- +a A +b B +c C + +query IT rowsort +SELECT id, v1 FROM t1 WHERE NOT EXISTS ( SELECT 1 FROM t2 WHERE t2.id = t1.id AND t2.v2 = t1.v2 ) +---- +1 a +2 b +3 c + +query IT rowsort +SELECT id, v1 FROM t1 WHERE v2 > ( SELECT MIN(v2) FROM t2 WHERE t2.id = t1.id ) +---- +3 c + +query IT rowsort +SELECT id, v1 FROM t1 WHERE EXISTS ( SELECT 1 FROM t2 WHERE t2.id = t1.id AND EXISTS ( SELECT 1 FROM t3 WHERE t3.id = t1.id ) ) +---- +1 a +2 b +3 c + +query IT rowsort +SELECT id, v1 FROM t1 WHERE v2 - 5 > ( SELECT AVG(v1) FROM t3 WHERE t3.id <= t1.id ) +---- +3 c + +query IT rowsort +SELECT id, v1 FROM t1 WHERE id NOT IN ( SELECT t2.id FROM t2 WHERE t2.v2 > t1.v2 ) +---- + + +query IT rowsort +SELECT id, v1 FROM t1 WHERE EXISTS ( SELECT 1 FROM t2 WHERE t2.id = t1.id AND t2.v2 + t1.v2 > 15 ) +---- +1 a +2 b +3 c + +query IT rowsort +SELECT id, v1 FROM t1 WHERE v2 = ( SELECT MAX(v2) FROM t2 WHERE t2.id <= t1.id ) ORDER BY id +---- +3 c + +query IT rowsort +SELECT id, v1 FROM t1 WHERE ( SELECT COUNT(*) FROM t2 WHERE t2.v2 = t1.v2 ) = 2 +---- +1 a +2 b +3 c + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; + +statement ok +DROP TABLE t3; \ No newline at end of file diff --git a/tests/slt/crdb/update.slt b/tests/slt/crdb/update.slt index e87066a7..de96e271 100644 --- a/tests/slt/crdb/update.slt +++ b/tests/slt/crdb/update.slt @@ -88,8 +88,8 @@ query I select a from t1 order by a; ---- 1 -2 3 +4 8 # sqlparser-rs not support From 68888d86dd5ada1a62854e504e2ff222696d5f70 Mon Sep 17 00:00:00 2001 From: wszhdshys <1925792291@qq.com> Date: Thu, 24 Jul 2025 21:58:00 +0800 Subject: [PATCH 2/2] Changed the logic and naming --- src/binder/expr.rs | 9 +- src/binder/mod.rs | 9 +- .../rule/normalization/correlated_subquery.rs | 140 +++++++----------- 3 files changed, 57 insertions(+), 101 deletions(-) diff --git a/src/binder/expr.rs b/src/binder/expr.rs index a7bd9ebf..bced8a4d 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -369,14 +369,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T try_default!(&full_name.0, full_name.1); } if let Some(table) = full_name.0.or(bind_table_name) { - let (source, is_parent) = self.context.bind_source::(self.parent, &table, false)?; - - if is_parent { - self.parent_table_col - .entry(Arc::new(table.clone())) - .or_default() - .insert(full_name.1.clone()); - } + let source = self.context.bind_source::(self.parent, &table)?; let schema_buf = self.table_schema_buf.entry(Arc::new(table)).or_default(); diff --git a/src/binder/mod.rs b/src/binder/mod.rs index b132f474..307c8201 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -280,15 +280,14 @@ impl<'a, T: Transaction> BinderContext<'a, T> { &self, parent: Option<&'a Binder<'a, 'b, T, A>>, table_name: &str, - is_parent: bool, - ) -> Result<(&'b Source, bool), DatabaseError> { + ) -> Result<&'b Source, DatabaseError> { if let Some(source) = self.bind_table.iter().find(|((t, alias, _), _)| { t.as_str() == table_name || matches!(alias.as_ref().map(|a| a.as_str() == table_name), Some(true)) }) { - Ok((source.1, is_parent)) + Ok(source.1) } else if let Some(binder) = parent { - binder.context.bind_source(binder.parent, table_name, true) + binder.context.bind_source(binder.parent, table_name) } else { Err(DatabaseError::InvalidTable(table_name.into())) } @@ -330,7 +329,6 @@ pub struct Binder<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> args: &'a A, with_pk: Option, pub(crate) parent: Option<&'b Binder<'a, 'b, T, A>>, - pub(crate) parent_table_col: HashMap>, } impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, 'b, T, A> { @@ -345,7 +343,6 @@ impl<'a, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, ' args, with_pk: None, parent, - parent_table_col: Default::default(), } } diff --git a/src/optimizer/rule/normalization/correlated_subquery.rs b/src/optimizer/rule/normalization/correlated_subquery.rs index 32eaca39..30e44502 100644 --- a/src/optimizer/rule/normalization/correlated_subquery.rs +++ b/src/optimizer/rule/normalization/correlated_subquery.rs @@ -37,19 +37,16 @@ macro_rules! trans_references { impl CorrelatedSubquery { fn _apply( column_references: HashSet<&ColumnRef>, - scan_columns: HashMap, HashMap, Vec)>, + mut used_scan: HashMap, node_id: HepNodeId, graph: &mut HepGraph, - ) -> Result< - HashMap, HashMap, Vec)>, - DatabaseError, - > { + ) -> Result, DatabaseError> { let operator = &graph.operator(node_id).clone(); match operator { Operator::Aggregate(op) => { let is_distinct = op.is_distinct; - let referenced_columns = operator.referenced_columns(false); + let referenced_columns = operator.referenced_columns(true); let mut new_column_references = trans_references!(&referenced_columns); // on distinct if is_distinct { @@ -58,62 +55,45 @@ impl CorrelatedSubquery { } } - Self::recollect_apply(new_column_references, scan_columns, node_id, graph) + Self::recollect_apply(new_column_references, used_scan, node_id, graph) } Operator::Project(op) => { - let mut has_count_star = HasCountStar::default(); - for expr in &op.exprs { - has_count_star.visit(expr)?; - } - let referenced_columns = operator.referenced_columns(false); + let referenced_columns = operator.referenced_columns(true); let new_column_references = trans_references!(&referenced_columns); - Self::recollect_apply(new_column_references, scan_columns, node_id, graph) + Self::recollect_apply(new_column_references, used_scan, node_id, graph) } - Operator::TableScan(op) => { - let table_column: HashSet<&ColumnRef> = op.columns.values().collect(); - let mut new_scan_columns = scan_columns.clone(); - new_scan_columns.insert( - op.table_name.clone(), - ( - op.primary_keys.clone(), - op.columns - .iter() - .map(|(num, col)| (col.id().unwrap(), *num)) - .collect(), - op.index_infos.clone(), - ), - ); - let mut parent_col = HashMap::new(); + TableScan(op) => { + let table_columns: HashSet<&ColumnRef> = op.columns.values().collect(); + let mut parent_scan_to_added = HashMap::new(); for col in column_references { - match ( - table_column.contains(col), - scan_columns.get(col.table_name().unwrap_or(&Arc::new("".to_string()))), - ) { - (false, Some(..)) => { - parent_col - .entry(col.table_name().unwrap()) - .or_insert(HashSet::new()) - .insert(col); + if table_columns.contains(col) { + continue; + } + if let Some(table_name) = col.table_name() { + if !used_scan.contains_key(table_name) { + continue; } - _ => continue, + parent_scan_to_added + .entry(table_name) + .or_insert(HashSet::new()) + .insert(col); } } - for (table_name, table_columns) in parent_col { - let table_columns = table_columns.into_iter().collect_vec(); - let (primary_keys, columns, index_infos) = - scan_columns.get(table_name).unwrap(); - let map: BTreeMap = table_columns - .into_iter() - .map(|col| (*columns.get(&col.id().unwrap()).unwrap(), col.clone())) - .collect(); + for (table_name, table_columns) in parent_scan_to_added { + let op = used_scan.get(table_name).unwrap(); let left_operator = graph.operator(node_id).clone(); let right_operator = TableScan(TableScanOperator { table_name: table_name.clone(), - primary_keys: primary_keys.clone(), - columns: map, + primary_keys: op.primary_keys.clone(), + columns: op + .columns + .iter() + .filter(|(_, column)| table_columns.contains(column)) + .map(|(i, col)| (*i, col.clone())) + .collect(), limit: (None, None), - index_infos: index_infos.clone(), + index_infos: op.index_infos.clone(), with_pk: false, }); let join_operator = Join(JoinOperator { @@ -135,41 +115,33 @@ impl CorrelatedSubquery { _ => unreachable!(), } } - Ok(new_scan_columns) + used_scan.insert(op.table_name.clone(), op.clone()); + Ok(used_scan) } Operator::Sort(_) | Operator::Limit(_) | Operator::Filter(_) | Operator::Union(_) => { - let mut new_scan_columns = scan_columns.clone(); - let temp_columns = operator.referenced_columns(false); - // why? + let temp_columns = operator.referenced_columns(true); let mut column_references = column_references; for column in temp_columns.iter() { column_references.insert(column); } - for child_id in graph.children_at(node_id).collect_vec() { - let copy_references = column_references.clone(); - let copy_scan = scan_columns.clone(); - if let Ok(scan) = Self::_apply(copy_references, copy_scan, child_id, graph) { - new_scan_columns.extend(scan); - }; - } - Ok(new_scan_columns) + Self::recollect_apply(column_references, used_scan, node_id, graph) } - Operator::Join(_) => { - let mut new_scan_columns = scan_columns.clone(); - for child_id in graph.children_at(node_id).collect_vec() { - let copy_references = column_references.clone(); - let copy_scan = new_scan_columns.clone(); - if let Ok(scan) = Self::_apply(copy_references, copy_scan, child_id, graph) { - new_scan_columns.extend(scan); - }; + Join(_) => { + let used_scan = + Self::recollect_apply(column_references.clone(), used_scan, node_id, graph)?; + let temp_columns = operator.referenced_columns(true); + let mut column_references = column_references; + for column in temp_columns.iter() { + column_references.insert(column); } - Ok(new_scan_columns) + Ok(used_scan) + //todo Supplemental testing is required } // Last Operator - Operator::Dummy | Operator::Values(_) | Operator::FunctionScan(_) => Ok(scan_columns), + Operator::Dummy | Operator::Values(_) | Operator::FunctionScan(_) => Ok(used_scan), Operator::Explain => { if let Some(child_id) = graph.eldest_child_at(node_id) { - Self::_apply(column_references, scan_columns, child_id, graph) + Self::_apply(column_references, used_scan, child_id, graph) } else { unreachable!() } @@ -179,11 +151,11 @@ impl CorrelatedSubquery { | Operator::Update(_) | Operator::Delete(_) | Operator::Analyze(_) => { - let referenced_columns = operator.referenced_columns(false); + let referenced_columns = operator.referenced_columns(true); let new_column_references = trans_references!(&referenced_columns); if let Some(child_id) = graph.eldest_child_at(node_id) { - Self::recollect_apply(new_column_references, scan_columns, child_id, graph) + Self::recollect_apply(new_column_references, used_scan, child_id, graph) } else { unreachable!(); } @@ -202,29 +174,23 @@ impl CorrelatedSubquery { | Operator::CopyToFile(_) | Operator::AddColumn(_) | Operator::DropColumn(_) - | Operator::Describe(_) => Ok(scan_columns), + | Operator::Describe(_) => Ok(used_scan), } } fn recollect_apply( referenced_columns: HashSet<&ColumnRef>, - scan_columns: HashMap, HashMap, Vec)>, + mut used_scan: HashMap, node_id: HepNodeId, graph: &mut HepGraph, - ) -> Result< - HashMap, HashMap, Vec)>, - DatabaseError, - > { - let mut new_scan_columns = scan_columns.clone(); + ) -> Result, DatabaseError> { for child_id in graph.children_at(node_id).collect_vec() { let copy_references = referenced_columns.clone(); - let copy_scan = scan_columns.clone(); - - if let Ok(scan) = Self::_apply(copy_references, copy_scan, child_id, graph) { - new_scan_columns.extend(scan); - }; + let copy_scan = used_scan.clone(); + let scan = Self::_apply(copy_references, copy_scan, child_id, graph)?; + used_scan.extend(scan); } - Ok(new_scan_columns) + Ok(used_scan) } }