Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use sqlparser::ast::{
BinaryOperator, CharLengthUnits, DataType, Expr, Function, FunctionArg, FunctionArgExpr, Ident,
Query, UnaryOperator, Value,
};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::slice;
use std::sync::Arc;

Expand Down Expand Up @@ -293,6 +293,7 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
self.args,
Some(self),
);

let mut sub_query = binder.bind_query(subquery)?;
let sub_query_schema = sub_query.output_schema();

Expand Down Expand Up @@ -368,7 +369,8 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
try_default!(&full_name.0, full_name.1);
}
if let Some(table) = full_name.0.or(bind_table_name) {
let source = self.context.bind_source(&table)?;
let source = self.context.bind_source::<A>(self.parent, &table)?;

let schema_buf = self.table_schema_buf.entry(Arc::new(table)).or_default();

Ok(ScalarExpression::ColumnRef(
Expand Down
10 changes: 8 additions & 2 deletions src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use crate::catalog::view::View;
use crate::catalog::{ColumnRef, TableCatalog, TableName};
use crate::catalog::{ColumnCatalog, ColumnRef, TableCatalog, TableName};
use crate::db::{ScalaFunctions, TableFunctions};
use crate::errors::DatabaseError;
use crate::expression::ScalarExpression;
Expand Down Expand Up @@ -276,12 +276,18 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
Ok(source)
}

pub fn bind_source<'b: 'a>(&self, table_name: &str) -> Result<&Source, DatabaseError> {
pub fn bind_source<'b: 'a, A: AsRef<[(&'static str, DataValue)]>>(
&self,
parent: Option<&'a Binder<'a, 'b, T, A>>,
table_name: &str,
) -> Result<&'b Source, DatabaseError> {
if let Some(source) = self.bind_table.iter().find(|((t, alias, _), _)| {
t.as_str() == table_name
|| matches!(alias.as_ref().map(|a| a.as_str() == table_name), Some(true))
}) {
Ok(source.1)
} else if let Some(binder) = parent {
binder.context.bind_source(binder.parent, table_name)
} else {
Err(DatabaseError::InvalidTable(table_name.into()))
}
Expand Down
5 changes: 5 additions & 0 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ impl<S: Storage> State<S> {

pub(crate) fn default_optimizer(source_plan: LogicalPlan) -> HepOptimizer {
HepOptimizer::new(source_plan)
.batch(
"Correlated Subquery".to_string(),
HepBatchStrategy::once_topdown(),
vec![NormalizationRuleImpl::CorrelateSubquery],
)
.batch(
"Column Pruning".to_string(),
HepBatchStrategy::once_topdown(),
Expand Down
3 changes: 2 additions & 1 deletion src/optimizer/heuristic/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl HepGraph {
source_id: HepNodeId,
children_option: Option<HepNodeId>,
new_node: Operator,
) {
) -> HepNodeId {
let new_index = self.graph.add_node(new_node);
let mut order = self.graph.edges(source_id).count();

Expand All @@ -95,6 +95,7 @@ impl HepGraph {

self.graph.add_edge(source_id, new_index, order);
self.version += 1;
new_index
}

pub fn replace_node(&mut self, source_id: HepNodeId, new_node: Operator) {
Expand Down
211 changes: 211 additions & 0 deletions src/optimizer/rule/normalization/correlated_subquery.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
use crate::catalog::{ColumnRef, TableName};
use crate::errors::DatabaseError;
use crate::expression::visitor::Visitor;
use crate::expression::HasCountStar;
use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate};
use crate::optimizer::core::rule::{MatchPattern, NormalizationRule};
use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId};
use crate::planner::operator::join::{JoinCondition, JoinOperator, JoinType};
use crate::planner::operator::table_scan::TableScanOperator;
use crate::planner::operator::Operator;
use crate::planner::operator::Operator::{Join, TableScan};
use crate::types::index::IndexInfo;
use crate::types::ColumnId;
use itertools::Itertools;
use std::collections::BTreeMap;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, LazyLock};

static CORRELATED_SUBQUERY_RULE: LazyLock<Pattern> = LazyLock::new(|| Pattern {
predicate: |op| matches!(op, Join(_)),
children: PatternChildrenPredicate::None,
});

#[derive(Clone)]
pub struct CorrelatedSubquery;

macro_rules! trans_references {
($columns:expr) => {{
let mut column_references = HashSet::with_capacity($columns.len());
for column in $columns {
column_references.insert(column);
}
column_references
}};
}

impl CorrelatedSubquery {
fn _apply(
column_references: HashSet<&ColumnRef>,
mut used_scan: HashMap<TableName, TableScanOperator>,
node_id: HepNodeId,
graph: &mut HepGraph,
) -> Result<HashMap<TableName, TableScanOperator>, DatabaseError> {
let operator = &graph.operator(node_id).clone();

match operator {
Operator::Aggregate(op) => {
let is_distinct = op.is_distinct;
let referenced_columns = operator.referenced_columns(true);
let mut new_column_references = trans_references!(&referenced_columns);
// on distinct
if is_distinct {
for summary in column_references {
new_column_references.insert(summary);
}
}

Self::recollect_apply(new_column_references, used_scan, node_id, graph)
}
Operator::Project(op) => {
let referenced_columns = operator.referenced_columns(true);
let new_column_references = trans_references!(&referenced_columns);

Self::recollect_apply(new_column_references, used_scan, node_id, graph)
}
TableScan(op) => {
let table_columns: HashSet<&ColumnRef> = op.columns.values().collect();
let mut parent_scan_to_added = HashMap::new();
for col in column_references {
if table_columns.contains(col) {
continue;
}
if let Some(table_name) = col.table_name() {
if !used_scan.contains_key(table_name) {
continue;
}
parent_scan_to_added
.entry(table_name)
.or_insert(HashSet::new())
.insert(col);
}
}
for (table_name, table_columns) in parent_scan_to_added {
let op = used_scan.get(table_name).unwrap();
let left_operator = graph.operator(node_id).clone();
let right_operator = TableScan(TableScanOperator {
table_name: table_name.clone(),
primary_keys: op.primary_keys.clone(),
columns: op
.columns
.iter()
.filter(|(_, column)| table_columns.contains(column))
.map(|(i, col)| (*i, col.clone()))
.collect(),
limit: (None, None),
index_infos: op.index_infos.clone(),
with_pk: false,
});
let join_operator = Join(JoinOperator {
on: JoinCondition::None,
join_type: JoinType::Cross,
});

match &left_operator {
TableScan(_) => {
graph.replace_node(node_id, join_operator);
graph.add_node(node_id, None, left_operator);
graph.add_node(node_id, None, right_operator);
}
Join(_) => {
let left_id = graph.eldest_child_at(node_id).unwrap();
let left_id = graph.add_node(node_id, Some(left_id), join_operator);
graph.add_node(left_id, None, right_operator);
}
_ => unreachable!(),
}
}
used_scan.insert(op.table_name.clone(), op.clone());
Ok(used_scan)
}
Operator::Sort(_) | Operator::Limit(_) | Operator::Filter(_) | Operator::Union(_) => {
let temp_columns = operator.referenced_columns(true);
let mut column_references = column_references;
for column in temp_columns.iter() {
column_references.insert(column);
}
Self::recollect_apply(column_references, used_scan, node_id, graph)
}
Join(_) => {
let used_scan =
Self::recollect_apply(column_references.clone(), used_scan, node_id, graph)?;
let temp_columns = operator.referenced_columns(true);
let mut column_references = column_references;
for column in temp_columns.iter() {
column_references.insert(column);
}
Ok(used_scan)
//todo Supplemental testing is required
}
// Last Operator
Operator::Dummy | Operator::Values(_) | Operator::FunctionScan(_) => Ok(used_scan),
Operator::Explain => {
if let Some(child_id) = graph.eldest_child_at(node_id) {
Self::_apply(column_references, used_scan, child_id, graph)
} else {
unreachable!()
}
}
// DDL Based on Other Plan
Operator::Insert(_)
| Operator::Update(_)
| Operator::Delete(_)
| Operator::Analyze(_) => {
let referenced_columns = operator.referenced_columns(true);
let new_column_references = trans_references!(&referenced_columns);

if let Some(child_id) = graph.eldest_child_at(node_id) {
Self::recollect_apply(new_column_references, used_scan, child_id, graph)
} else {
unreachable!();
}
}
// DDL Single Plan
Operator::CreateTable(_)
| Operator::CreateIndex(_)
| Operator::CreateView(_)
| Operator::DropTable(_)
| Operator::DropView(_)
| Operator::DropIndex(_)
| Operator::Truncate(_)
| Operator::ShowTable
| Operator::ShowView
| Operator::CopyFromFile(_)
| Operator::CopyToFile(_)
| Operator::AddColumn(_)
| Operator::DropColumn(_)
| Operator::Describe(_) => Ok(used_scan),
}
}

fn recollect_apply(
referenced_columns: HashSet<&ColumnRef>,
mut used_scan: HashMap<TableName, TableScanOperator>,
node_id: HepNodeId,
graph: &mut HepGraph,
) -> Result<HashMap<TableName, TableScanOperator>, DatabaseError> {
for child_id in graph.children_at(node_id).collect_vec() {
let copy_references = referenced_columns.clone();
let copy_scan = used_scan.clone();
let scan = Self::_apply(copy_references, copy_scan, child_id, graph)?;
used_scan.extend(scan);
}
Ok(used_scan)
}
}

impl MatchPattern for CorrelatedSubquery {
fn pattern(&self) -> &Pattern {
&CORRELATED_SUBQUERY_RULE
}
}

impl NormalizationRule for CorrelatedSubquery {
fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> {
Self::_apply(HashSet::new(), HashMap::new(), node_id, graph)?;
// mark changed to skip this rule batch
graph.version += 1;

Ok(())
}
}
5 changes: 5 additions & 0 deletions src/optimizer/rule/normalization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::optimizer::rule::normalization::combine_operators::{
use crate::optimizer::rule::normalization::compilation_in_advance::{
EvaluatorBind, ExpressionRemapper,
};
use crate::optimizer::rule::normalization::correlated_subquery::CorrelatedSubquery;
use crate::optimizer::rule::normalization::pushdown_limit::{
LimitProjectTranspose, PushLimitIntoScan, PushLimitThroughJoin,
};
Expand All @@ -21,6 +22,7 @@ use crate::optimizer::rule::normalization::simplification::SimplifyFilter;
mod column_pruning;
mod combine_operators;
mod compilation_in_advance;
mod correlated_subquery;
mod pushdown_limit;
mod pushdown_predicates;
mod simplification;
Expand All @@ -32,6 +34,7 @@ pub enum NormalizationRuleImpl {
CollapseProject,
CollapseGroupByAgg,
CombineFilter,
CorrelateSubquery,
// PushDown limit
LimitProjectTranspose,
PushLimitThroughJoin,
Expand All @@ -55,6 +58,7 @@ impl MatchPattern for NormalizationRuleImpl {
NormalizationRuleImpl::CollapseProject => CollapseProject.pattern(),
NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.pattern(),
NormalizationRuleImpl::CombineFilter => CombineFilter.pattern(),
NormalizationRuleImpl::CorrelateSubquery => CorrelatedSubquery.pattern(),
NormalizationRuleImpl::LimitProjectTranspose => LimitProjectTranspose.pattern(),
NormalizationRuleImpl::PushLimitThroughJoin => PushLimitThroughJoin.pattern(),
NormalizationRuleImpl::PushLimitIntoTableScan => PushLimitIntoScan.pattern(),
Expand All @@ -75,6 +79,7 @@ impl NormalizationRule for NormalizationRuleImpl {
NormalizationRuleImpl::CollapseProject => CollapseProject.apply(node_id, graph),
NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.apply(node_id, graph),
NormalizationRuleImpl::CombineFilter => CombineFilter.apply(node_id, graph),
NormalizationRuleImpl::CorrelateSubquery => CorrelatedSubquery.apply(node_id, graph),
NormalizationRuleImpl::LimitProjectTranspose => {
LimitProjectTranspose.apply(node_id, graph)
}
Expand Down
4 changes: 2 additions & 2 deletions src/optimizer/rule/normalization/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,15 +298,15 @@ mod test {
op: BinaryOperator::Plus,
left_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from(
c1_col
))),
),)),
right_expr: Box::new(ScalarExpression::Constant(DataValue::Int32(1))),
evaluator: None,
ty: LogicalType::Integer,
}),
evaluator: None,
ty: LogicalType::Integer,
}),
right_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from(c2_col))),
right_expr: Box::new(ScalarExpression::ColumnRef(ColumnRef::from(c2_col),)),
evaluator: None,
ty: LogicalType::Boolean,
}
Expand Down
Loading
Loading