diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 890d6d62..bced8a4d 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -7,7 +7,7 @@ use sqlparser::ast::{ BinaryOperator, CharLengthUnits, DataType, Expr, Function, FunctionArg, FunctionArgExpr, Ident, Query, UnaryOperator, Value, }; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::slice; use std::sync::Arc; @@ -293,6 +293,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T self.args, Some(self), ); + let mut sub_query = binder.bind_query(subquery)?; let sub_query_schema = sub_query.output_schema(); @@ -368,7 +369,8 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T try_default!(&full_name.0, full_name.1); } if let Some(table) = full_name.0.or(bind_table_name) { - let source = self.context.bind_source(&table)?; + let source = self.context.bind_source::<A>(self.parent, &table)?; + let schema_buf = self.table_schema_buf.entry(Arc::new(table)).or_default(); Ok(ScalarExpression::ColumnRef( diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 9be10c73..307c8201 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -26,7 +26,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use crate::catalog::view::View; -use crate::catalog::{ColumnRef, TableCatalog, TableName}; +use crate::catalog::{ColumnCatalog, ColumnRef, TableCatalog, TableName}; use crate::db::{ScalaFunctions, TableFunctions}; use crate::errors::DatabaseError; use crate::expression::ScalarExpression; @@ -276,12 +276,18 @@ impl<'a, T: Transaction> BinderContext<'a, T> { Ok(source) } - pub fn bind_source<'b: 'a>(&self, table_name: &str) -> Result<&Source, DatabaseError> { + pub fn bind_source<'b: 'a, A: AsRef<[(&'static str, DataValue)]>>( + &self, + parent: Option<&'a Binder<'a, 'b, T, A>>, + table_name: &str, + ) -> Result<&'b Source, DatabaseError> { if let Some(source) = self.bind_table.iter().find(|((t, alias, _), _)| 
{ t.as_str() == table_name || matches!(alias.as_ref().map(|a| a.as_str() == table_name), Some(true)) }) { Ok(source.1) + } else if let Some(binder) = parent { + binder.context.bind_source(binder.parent, table_name) } else { Err(DatabaseError::InvalidTable(table_name.into())) } diff --git a/src/db.rs b/src/db.rs index ff08a4fa..ca60c0eb 100644 --- a/src/db.rs +++ b/src/db.rs @@ -173,6 +173,11 @@ impl State { pub(crate) fn default_optimizer(source_plan: LogicalPlan) -> HepOptimizer { HepOptimizer::new(source_plan) + .batch( + "Correlated Subquery".to_string(), + HepBatchStrategy::once_topdown(), + vec![NormalizationRuleImpl::CorrelateSubquery], + ) .batch( "Column Pruning".to_string(), HepBatchStrategy::once_topdown(), diff --git a/src/optimizer/heuristic/graph.rs b/src/optimizer/heuristic/graph.rs index f6de8c61..6695bfdc 100644 --- a/src/optimizer/heuristic/graph.rs +++ b/src/optimizer/heuristic/graph.rs @@ -79,7 +79,7 @@ impl HepGraph { source_id: HepNodeId, children_option: Option<HepNodeId>, new_node: Operator, - ) { + ) -> HepNodeId { let new_index = self.graph.add_node(new_node); let mut order = self.graph.edges(source_id).count(); @@ -95,6 +95,7 @@ impl HepGraph { self.graph.add_edge(source_id, new_index, order); self.version += 1; + new_index } pub fn replace_node(&mut self, source_id: HepNodeId, new_node: Operator) { diff --git a/src/optimizer/rule/normalization/correlated_subquery.rs b/src/optimizer/rule/normalization/correlated_subquery.rs new file mode 100644 index 00000000..30e44502 --- /dev/null +++ b/src/optimizer/rule/normalization/correlated_subquery.rs @@ -0,0 +1,211 @@ +use crate::catalog::{ColumnRef, TableName}; +use crate::errors::DatabaseError; +use crate::expression::visitor::Visitor; +use crate::expression::HasCountStar; +use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate}; +use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; +use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; +use 
crate::planner::operator::join::{JoinCondition, JoinOperator, JoinType}; +use crate::planner::operator::table_scan::TableScanOperator; +use crate::planner::operator::Operator; +use crate::planner::operator::Operator::{Join, TableScan}; +use crate::types::index::IndexInfo; +use crate::types::ColumnId; +use itertools::Itertools; +use std::collections::BTreeMap; +use std::collections::{HashMap, HashSet}; +use std::sync::{Arc, LazyLock}; + +static CORRELATED_SUBQUERY_RULE: LazyLock<Pattern> = LazyLock::new(|| Pattern { + predicate: |op| matches!(op, Join(_)), + children: PatternChildrenPredicate::None, +}); + +#[derive(Clone)] +pub struct CorrelatedSubquery; + +macro_rules! trans_references { + ($columns:expr) => {{ + let mut column_references = HashSet::with_capacity($columns.len()); + for column in $columns { + column_references.insert(column); + } + column_references + }}; +} + +impl CorrelatedSubquery { + fn _apply( + column_references: HashSet<&ColumnRef>, + mut used_scan: HashMap<TableName, TableScanOperator>, + node_id: HepNodeId, + graph: &mut HepGraph, + ) -> Result<HashMap<TableName, TableScanOperator>, DatabaseError> { + let operator = &graph.operator(node_id).clone(); + + match operator { + Operator::Aggregate(op) => { + let is_distinct = op.is_distinct; + let referenced_columns = operator.referenced_columns(true); + let mut new_column_references = trans_references!(&referenced_columns); + // on distinct + if is_distinct { + for summary in column_references { + new_column_references.insert(summary); + } + } + + Self::recollect_apply(new_column_references, used_scan, node_id, graph) + } + Operator::Project(op) => { + let referenced_columns = operator.referenced_columns(true); + let new_column_references = trans_references!(&referenced_columns); + + Self::recollect_apply(new_column_references, used_scan, node_id, graph) + } + TableScan(op) => { + let table_columns: HashSet<&ColumnRef> = op.columns.values().collect(); + let mut parent_scan_to_added = HashMap::new(); + for col in column_references { + if table_columns.contains(col) { 
continue; + } + if let Some(table_name) = col.table_name() { + if !used_scan.contains_key(table_name) { + continue; + } + parent_scan_to_added + .entry(table_name) + .or_insert(HashSet::new()) + .insert(col); + } + } + for (table_name, table_columns) in parent_scan_to_added { + let op = used_scan.get(table_name).unwrap(); + let left_operator = graph.operator(node_id).clone(); + let right_operator = TableScan(TableScanOperator { + table_name: table_name.clone(), + primary_keys: op.primary_keys.clone(), + columns: op + .columns + .iter() + .filter(|(_, column)| table_columns.contains(column)) + .map(|(i, col)| (*i, col.clone())) + .collect(), + limit: (None, None), + index_infos: op.index_infos.clone(), + with_pk: false, + }); + let join_operator = Join(JoinOperator { + on: JoinCondition::None, + join_type: JoinType::Cross, + }); + + match &left_operator { + TableScan(_) => { + graph.replace_node(node_id, join_operator); + graph.add_node(node_id, None, left_operator); + graph.add_node(node_id, None, right_operator); + } + Join(_) => { + let left_id = graph.eldest_child_at(node_id).unwrap(); + let left_id = graph.add_node(node_id, Some(left_id), join_operator); + graph.add_node(left_id, None, right_operator); + } + _ => unreachable!(), + } + } + used_scan.insert(op.table_name.clone(), op.clone()); + Ok(used_scan) + } + Operator::Sort(_) | Operator::Limit(_) | Operator::Filter(_) | Operator::Union(_) => { + let temp_columns = operator.referenced_columns(true); + let mut column_references = column_references; + for column in temp_columns.iter() { + column_references.insert(column); + } + Self::recollect_apply(column_references, used_scan, node_id, graph) + } + Join(_) => { + let used_scan = + Self::recollect_apply(column_references.clone(), used_scan, node_id, graph)?; + let temp_columns = operator.referenced_columns(true); + let mut column_references = column_references; + for column in temp_columns.iter() { + column_references.insert(column); + } + Ok(used_scan) + 
//todo Supplemental testing is required + } + // Last Operator + Operator::Dummy | Operator::Values(_) | Operator::FunctionScan(_) => Ok(used_scan), + Operator::Explain => { + if let Some(child_id) = graph.eldest_child_at(node_id) { + Self::_apply(column_references, used_scan, child_id, graph) + } else { + unreachable!() + } + } + // DDL Based on Other Plan + Operator::Insert(_) + | Operator::Update(_) + | Operator::Delete(_) + | Operator::Analyze(_) => { + let referenced_columns = operator.referenced_columns(true); + let new_column_references = trans_references!(&referenced_columns); + + if let Some(child_id) = graph.eldest_child_at(node_id) { + Self::recollect_apply(new_column_references, used_scan, child_id, graph) + } else { + unreachable!(); + } + } + // DDL Single Plan + Operator::CreateTable(_) + | Operator::CreateIndex(_) + | Operator::CreateView(_) + | Operator::DropTable(_) + | Operator::DropView(_) + | Operator::DropIndex(_) + | Operator::Truncate(_) + | Operator::ShowTable + | Operator::ShowView + | Operator::CopyFromFile(_) + | Operator::CopyToFile(_) + | Operator::AddColumn(_) + | Operator::DropColumn(_) + | Operator::Describe(_) => Ok(used_scan), + } + } + + fn recollect_apply( + referenced_columns: HashSet<&ColumnRef>, + mut used_scan: HashMap<TableName, TableScanOperator>, + node_id: HepNodeId, + graph: &mut HepGraph, + ) -> Result<HashMap<TableName, TableScanOperator>, DatabaseError> { + for child_id in graph.children_at(node_id).collect_vec() { + let copy_references = referenced_columns.clone(); + let copy_scan = used_scan.clone(); + let scan = Self::_apply(copy_references, copy_scan, child_id, graph)?; + used_scan.extend(scan); + } + Ok(used_scan) + } +} + +impl MatchPattern for CorrelatedSubquery { + fn pattern(&self) -> &Pattern { + &CORRELATED_SUBQUERY_RULE + } +} + +impl NormalizationRule for CorrelatedSubquery { + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { + Self::_apply(HashSet::new(), HashMap::new(), node_id, graph)?; + // mark changed to skip this rule batch 
+ graph.version += 1; + + Ok(()) + } +} diff --git a/src/optimizer/rule/normalization/mod.rs b/src/optimizer/rule/normalization/mod.rs index 8c30fd4a..b7fc6d4b 100644 --- a/src/optimizer/rule/normalization/mod.rs +++ b/src/optimizer/rule/normalization/mod.rs @@ -10,6 +10,7 @@ use crate::optimizer::rule::normalization::combine_operators::{ use crate::optimizer::rule::normalization::compilation_in_advance::{ EvaluatorBind, ExpressionRemapper, }; +use crate::optimizer::rule::normalization::correlated_subquery::CorrelatedSubquery; use crate::optimizer::rule::normalization::pushdown_limit::{ LimitProjectTranspose, PushLimitIntoScan, PushLimitThroughJoin, }; @@ -21,6 +22,7 @@ use crate::optimizer::rule::normalization::simplification::SimplifyFilter; mod column_pruning; mod combine_operators; mod compilation_in_advance; +mod correlated_subquery; mod pushdown_limit; mod pushdown_predicates; mod simplification; @@ -32,6 +34,7 @@ pub enum NormalizationRuleImpl { CollapseProject, CollapseGroupByAgg, CombineFilter, + CorrelateSubquery, // PushDown limit LimitProjectTranspose, PushLimitThroughJoin, @@ -55,6 +58,7 @@ impl MatchPattern for NormalizationRuleImpl { NormalizationRuleImpl::CollapseProject => CollapseProject.pattern(), NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.pattern(), NormalizationRuleImpl::CombineFilter => CombineFilter.pattern(), + NormalizationRuleImpl::CorrelateSubquery => CorrelatedSubquery.pattern(), NormalizationRuleImpl::LimitProjectTranspose => LimitProjectTranspose.pattern(), NormalizationRuleImpl::PushLimitThroughJoin => PushLimitThroughJoin.pattern(), NormalizationRuleImpl::PushLimitIntoTableScan => PushLimitIntoScan.pattern(), @@ -75,6 +79,7 @@ impl NormalizationRule for NormalizationRuleImpl { NormalizationRuleImpl::CollapseProject => CollapseProject.apply(node_id, graph), NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.apply(node_id, graph), NormalizationRuleImpl::CombineFilter => CombineFilter.apply(node_id, 
graph), + NormalizationRuleImpl::CorrelateSubquery => CorrelatedSubquery.apply(node_id, graph), NormalizationRuleImpl::LimitProjectTranspose => { LimitProjectTranspose.apply(node_id, graph) } diff --git a/src/optimizer/rule/normalization/simplification.rs b/src/optimizer/rule/normalization/simplification.rs index a220c065..02d66ef5 100644 --- a/src/optimizer/rule/normalization/simplification.rs +++ b/src/optimizer/rule/normalization/simplification.rs @@ -298,7 +298,7 @@ mod test { op: BinaryOperator::Plus, left_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from( c1_col - ))), + ),)), right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(1))), evaluator: None, ty: LogicalType::Integer, @@ -306,7 +306,7 @@ mod test { evaluator: None, ty: LogicalType::Integer, }), - right_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from(c2_col))), + right_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from(c2_col),)), evaluator: None, ty: LogicalType::Boolean, } diff --git a/tests/slt/correlated_subquery.slt b/tests/slt/correlated_subquery.slt new file mode 100644 index 00000000..1a6c0d62 --- /dev/null +++ b/tests/slt/correlated_subquery.slt @@ -0,0 +1,112 @@ +statement ok +CREATE TABLE t1 (id INT PRIMARY KEY, v1 VARCHAR(50), v2 INT); + +statement ok +CREATE TABLE t2 (id INT PRIMARY KEY, v1 VARCHAR(50), v2 INT); + +statement ok +CREATE TABLE t3 (id INT PRIMARY KEY, v1 INT, v2 INT); + +statement ok +insert into t1(id, v1, v2) values (1,'a',9) + +statement ok +insert into t1(id, v1, v2) values (2,'b',6) + +statement ok +insert into t1(id, v1, v2) values (3,'c',11) + +statement ok +insert into t2(id, v1, v2) values (1,'A',10) + +statement ok +insert into t2(id, v1, v2) values (2,'B',11) + +statement ok +insert into t2(id, v1, v2) values (3,'C',9) + +statement ok +insert into t3(id, v1, v2) values (1,6,10) + +statement ok +insert into t3(id, v1, v2) values (2,5,10) + +statement ok +insert into t3(id, v1, v2) values (3,4,10) + +query IT rowsort +SELECT 
id, v1 FROM t1 WHERE id IN ( SELECT t2.id FROM t2 WHERE t2.v2 < t1.v2 ) +---- +1 a +3 c + +query I rowsort +SELECT v1 FROM t1 WHERE EXISTS ( SELECT 1 FROM t3 WHERE t3.id = t1.id ) +---- +a +b +c + +query TT rowsort +SELECT t1.v1, t2.v1 FROM t1 JOIN t2 ON t1.id = t2.id WHERE t2.v2 > ( SELECT AVG(v2) FROM t1 ) +---- +a A +b B +c C + +query IT rowsort +SELECT id, v1 FROM t1 WHERE NOT EXISTS ( SELECT 1 FROM t2 WHERE t2.id = t1.id AND t2.v2 = t1.v2 ) +---- +1 a +2 b +3 c + +query IT rowsort +SELECT id, v1 FROM t1 WHERE v2 > ( SELECT MIN(v2) FROM t2 WHERE t2.id = t1.id ) +---- +3 c + +query IT rowsort +SELECT id, v1 FROM t1 WHERE EXISTS ( SELECT 1 FROM t2 WHERE t2.id = t1.id AND EXISTS ( SELECT 1 FROM t3 WHERE t3.id = t1.id ) ) +---- +1 a +2 b +3 c + +query IT rowsort +SELECT id, v1 FROM t1 WHERE v2 - 5 > ( SELECT AVG(v1) FROM t3 WHERE t3.id <= t1.id ) +---- +3 c + +query IT rowsort +SELECT id, v1 FROM t1 WHERE id NOT IN ( SELECT t2.id FROM t2 WHERE t2.v2 > t1.v2 ) +---- + + +query IT rowsort +SELECT id, v1 FROM t1 WHERE EXISTS ( SELECT 1 FROM t2 WHERE t2.id = t1.id AND t2.v2 + t1.v2 > 15 ) +---- +1 a +2 b +3 c + +query IT rowsort +SELECT id, v1 FROM t1 WHERE v2 = ( SELECT MAX(v2) FROM t2 WHERE t2.id <= t1.id ) ORDER BY id +---- +3 c + +query IT rowsort +SELECT id, v1 FROM t1 WHERE ( SELECT COUNT(*) FROM t2 WHERE t2.v2 = t1.v2 ) = 2 +---- +1 a +2 b +3 c + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; + +statement ok +DROP TABLE t3; \ No newline at end of file diff --git a/tests/slt/crdb/update.slt b/tests/slt/crdb/update.slt index e87066a7..de96e271 100644 --- a/tests/slt/crdb/update.slt +++ b/tests/slt/crdb/update.slt @@ -88,8 +88,8 @@ query I select a from t1 order by a; ---- 1 -2 3 +4 8 # sqlparser-rs not support