Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 95 additions & 50 deletions datafusion/sql/src/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::query::to_order_by_exprs_with_select;
use crate::utils::{
check_columns_satisfy_exprs, extract_aliases, rebase_expr, resolve_aliases_to_exprs,
resolve_columns, resolve_positions_to_exprs, rewrite_recursive_unnests_bottom_up,
CheckColumnsSatisfyExprsPurpose,
CheckColumnsMustReferenceAggregatePurpose, CheckColumnsSatisfyExprsPurpose,
};

use datafusion_common::error::DataFusionErrorBuilder;
Expand Down Expand Up @@ -84,6 +84,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
// Handle named windows before processing the projection expression
check_conflicting_windows(&select.named_window)?;
self.match_window_definitions(&mut select.projection, &select.named_window)?;

// Process the SELECT expressions
let select_exprs = self.prepare_select_exprs(
&base_plan,
Expand Down Expand Up @@ -146,39 +147,6 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
})
.transpose()?;

// Optionally the QUALIFY expression.
let qualify_expr_opt = select
.qualify
.map::<Result<Expr>, _>(|qualify_expr| {
let qualify_expr = self.sql_expr_to_logical_expr(
qualify_expr,
&combined_schema,
planner_context,
)?;
// This step "dereferences" any aliases in the QUALIFY clause.
//
// This is how we support queries with QUALIFY expressions that
// refer to aliased columns.
//
// For example:
//
// select row_number() over (PARTITION BY id) as rk from users qualify rk > 1;
//
// are rewritten as, respectively:
//
// select row_number() over (PARTITION BY id) as rk from users qualify row_number() over (PARTITION BY id) > 1;
//
let qualify_expr = resolve_aliases_to_exprs(qualify_expr, &alias_map)?;
normalize_col(qualify_expr, &projected_plan)
})
.transpose()?;

// The outer expressions we will search through for aggregates.
// Aggregates may be sourced from the SELECT list or from the HAVING expression.
let aggr_expr_haystack = select_exprs.iter().chain(having_expr_opt.iter());
// All of the aggregate expressions (deduplicated).
let aggr_exprs = find_aggregate_exprs(aggr_expr_haystack);

// All of the group by expressions
let group_by_exprs = if let GroupByExpr::Expressions(exprs, _) = select.group_by {
exprs
Expand Down Expand Up @@ -223,22 +191,61 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
.collect()
};

// Optionally the QUALIFY expression.
let qualify_expr_opt = select
.qualify
.map::<Result<Expr>, _>(|qualify_expr| {
let qualify_expr = self.sql_expr_to_logical_expr(
qualify_expr,
&combined_schema,
planner_context,
)?;
// This step "dereferences" any aliases in the QUALIFY clause.
//
// This is how we support queries with QUALIFY expressions that
// refer to aliased columns.
//
// For example:
//
// select row_number() over (PARTITION BY id) as rk from users qualify rk > 1;
//
// is rewritten as:
//
// select row_number() over (PARTITION BY id) as rk from users qualify row_number() over (PARTITION BY id) > 1;
//
let qualify_expr = resolve_aliases_to_exprs(qualify_expr, &alias_map)?;
normalize_col(qualify_expr, &projected_plan)
})
.transpose()?;

// The outer expressions we will search through for aggregates.
// Aggregates may be sourced from the SELECT list or from the HAVING expression.
let aggr_expr_haystack = select_exprs
.iter()
.chain(having_expr_opt.iter())
.chain(qualify_expr_opt.iter());
// All of the aggregate expressions (deduplicated).
let aggr_exprs = find_aggregate_exprs(aggr_expr_haystack);

// Process group by, aggregation or having
let (plan, mut select_exprs_post_aggr, having_expr_post_aggr) = if !group_by_exprs
.is_empty()
|| !aggr_exprs.is_empty()
{
let (
plan,
mut select_exprs_post_aggr,
having_expr_post_aggr,
qualify_expr_post_aggr,
) = if !group_by_exprs.is_empty() || !aggr_exprs.is_empty() {
self.aggregate(
&base_plan,
&select_exprs,
having_expr_opt.as_ref(),
qualify_expr_opt.as_ref(),
&group_by_exprs,
&aggr_exprs,
)?
} else {
match having_expr_opt {
Some(having_expr) => return plan_err!("HAVING clause references: {having_expr} must appear in the GROUP BY clause or be used in an aggregate function"),
None => (base_plan.clone(), select_exprs.clone(), having_expr_opt)
None => (base_plan.clone(), select_exprs.clone(), having_expr_opt, qualify_expr_opt)
}
};

Expand All @@ -252,11 +259,15 @@ impl<S: ContextProvider> SqlToRel<'_, S> {

// The outer expressions we will search through for window functions.
// Window functions may be sourced from the SELECT list or from the QUALIFY expression.
let windows_expr_haystack =
select_exprs_post_aggr.iter().chain(qualify_expr_opt.iter());
// All of the window expressions (deduplicated).
let windows_expr_haystack = select_exprs_post_aggr
.iter()
.chain(qualify_expr_post_aggr.iter());
// All of the window expressions (deduplicated and rewritten to reference aggregates as
// columns from input).
let window_func_exprs = find_window_exprs(windows_expr_haystack);

// Process window functions after aggregation as they can reference
// aggregate functions in their body
let plan = if window_func_exprs.is_empty() {
plan
} else {
Expand All @@ -273,7 +284,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {

// Process QUALIFY clause after window functions
// QUALIFY filters the results of window functions, similar to how HAVING filters aggregates
let plan = if let Some(qualify_expr) = qualify_expr_opt {
let plan = if let Some(qualify_expr) = qualify_expr_post_aggr {
// Validate that QUALIFY is used with window functions
if window_func_exprs.is_empty() {
return plan_err!(
Expand Down Expand Up @@ -839,36 +850,42 @@ impl<S: ContextProvider> SqlToRel<'_, S> {

/// Create an aggregate plan.
///
/// An aggregate plan consists of grouping expressions, aggregate expressions, and an
/// optional HAVING expression (which is a filter on the output of the aggregate).
/// An aggregate plan consists of grouping expressions, aggregate expressions, an
/// optional HAVING expression (which is a filter on the output of the aggregate),
/// and an optional QUALIFY clause which may reference aggregates.
///
/// # Arguments
///
/// * `input` - The input plan that will be aggregated. The grouping, aggregate,
/// "having", and "qualify" expressions must all be resolvable from this plan.
/// * `select_exprs` - The projection expressions from the SELECT clause.
/// * `having_expr_opt` - Optional HAVING clause.
/// * `qualify_expr_opt` - Optional QUALIFY clause.
/// * `group_by_exprs` - Grouping expressions from the GROUP BY clause. These can be column
/// references or more complex expressions.
/// * `aggr_exprs` - Aggregate expressions, such as `SUM(a)` or `COUNT(1)`.
///
/// # Return
///
/// The return value is a triplet of the following items:
/// The return value is a quadruplet of the following items:
///
/// * `plan` - A [LogicalPlan::Aggregate] plan for the newly created aggregate.
/// * `select_exprs_post_aggr` - The projection expressions rewritten to reference columns from
/// the aggregate
/// * `having_expr_post_aggr` - The "having" expression rewritten to reference a column from
/// the aggregate
/// * `qualify_expr_post_aggr` - The "qualify" expression rewritten to reference a column from
/// the aggregate
#[allow(clippy::type_complexity)]
fn aggregate(
&self,
input: &LogicalPlan,
select_exprs: &[Expr],
having_expr_opt: Option<&Expr>,
qualify_expr_opt: Option<&Expr>,
group_by_exprs: &[Expr],
aggr_exprs: &[Expr],
) -> Result<(LogicalPlan, Vec<Expr>, Option<Expr>)> {
) -> Result<(LogicalPlan, Vec<Expr>, Option<Expr>, Option<Expr>)> {
// create the aggregate plan
let options =
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
Expand Down Expand Up @@ -932,7 +949,9 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
check_columns_satisfy_exprs(
&column_exprs_post_aggr,
&select_exprs_post_aggr,
CheckColumnsSatisfyExprsPurpose::ProjectionMustReferenceAggregate,
CheckColumnsSatisfyExprsPurpose::Aggregate(
CheckColumnsMustReferenceAggregatePurpose::Projection,
),
)?;

// Rewrite the HAVING expression to use the columns produced by the
Expand All @@ -944,15 +963,41 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
check_columns_satisfy_exprs(
&column_exprs_post_aggr,
std::slice::from_ref(&having_expr_post_aggr),
CheckColumnsSatisfyExprsPurpose::HavingMustReferenceAggregate,
CheckColumnsSatisfyExprsPurpose::Aggregate(
CheckColumnsMustReferenceAggregatePurpose::Having,
),
)?;

Some(having_expr_post_aggr)
} else {
None
};

Ok((plan, select_exprs_post_aggr, having_expr_post_aggr))
// Rewrite the QUALIFY expression to use the columns produced by the
// aggregation.
let qualify_expr_post_aggr = if let Some(qualify_expr) = qualify_expr_opt {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this logic is pretty similar to what is used for HAVING. What about extracting the shared logic into a helper function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely similar, although currently we still require multiple distinct function calls to check_columns_satisfy_exprs because for each expression (SELECT, HAVING, QUALIFY), we pass in diagnostic information that indicates which clause the error occurs in (CheckColumnsSatisfyExprsPurpose).

Given this, I'm not quite sure a helper function would help readability/redundancy too much

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we can try it in a follow on PR

let qualify_expr_post_aggr =
rebase_expr(qualify_expr, &aggr_projection_exprs, input)?;

check_columns_satisfy_exprs(
&column_exprs_post_aggr,
std::slice::from_ref(&qualify_expr_post_aggr),
CheckColumnsSatisfyExprsPurpose::Aggregate(
CheckColumnsMustReferenceAggregatePurpose::Qualify,
),
)?;

Some(qualify_expr_post_aggr)
} else {
None
};

Ok((
plan,
select_exprs_post_aggr,
having_expr_post_aggr,
qualify_expr_post_aggr,
))
}

// If the projection is done over a named window, that window
Expand Down
19 changes: 14 additions & 5 deletions datafusion/sql/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,21 +92,30 @@ pub(crate) fn rebase_expr(
.data()
}

/// Identifies which clause of a query an expression came from when verifying
/// that it only references GROUP BY columns or aggregate functions.
///
/// The variant selects the clause named in the resulting error message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CheckColumnsMustReferenceAggregatePurpose {
/// The expression is part of the SELECT list.
Projection,
/// The expression is the HAVING clause.
Having,
/// The expression is the QUALIFY clause.
Qualify,
}

/// The reason a column check is being performed.
///
/// The value is passed to `check_columns_satisfy_exprs` and determines the
/// prefix of the error message produced when the check fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CheckColumnsSatisfyExprsPurpose {
ProjectionMustReferenceAggregate,
HavingMustReferenceAggregate,
/// Columns must be GROUP BY columns or aggregate functions; the payload
/// names the clause (SELECT, HAVING or QUALIFY) being checked.
Aggregate(CheckColumnsMustReferenceAggregatePurpose),
}

impl CheckColumnsSatisfyExprsPurpose {
fn message_prefix(&self) -> &'static str {
match self {
CheckColumnsSatisfyExprsPurpose::ProjectionMustReferenceAggregate => {
Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Projection) => {
"Column in SELECT must be in GROUP BY or an aggregate function"
}
CheckColumnsSatisfyExprsPurpose::HavingMustReferenceAggregate => {
Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Having) => {
"Column in HAVING must be in GROUP BY or an aggregate function"
}
Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Qualify) => {
"Column in QUALIFY must be in GROUP BY or an aggregate function"
}
}
}

Expand Down Expand Up @@ -162,7 +171,7 @@ fn check_column_satisfies_expr(
purpose.diagnostic_message(expr),
expr.spans().and_then(|spans| spans.first()),
)
.with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregare function like ANY_VALUE({expr})"), None);
.with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregate function like ANY_VALUE({expr})"), None);

return plan_err!(
"{}: While expanding wildcard, column \"{}\" must appear in the GROUP BY clause or must be part of an aggregate function, currently only \"{}\" appears in the SELECT clause satisfies this requirement",
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/tests/cases/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ fn test_missing_non_aggregate_in_group_by() -> Result<()> {
let diag = do_query(query);
assert_snapshot!(diag.message, @"'person.first_name' must appear in GROUP BY clause because it's not an aggregate expression");
assert_eq!(diag.span, Some(spans["a"]));
assert_snapshot!(diag.helps[0].message, @"Either add 'person.first_name' to GROUP BY clause, or use an aggregare function like ANY_VALUE(person.first_name)");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

assert_snapshot!(diag.helps[0].message, @"Either add 'person.first_name' to GROUP BY clause, or use an aggregate function like ANY_VALUE(person.first_name)");
Ok(())
}

Expand Down
61 changes: 61 additions & 0 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4202,6 +4202,67 @@ Projection: person.id, row_number() PARTITION BY [person.age] ORDER BY [person.i
);
}

// QUALIFY may reference an aggregate (`SUM(person.age)`) that does not appear
// in the SELECT list, alongside an aliased window function (`rn`).
//
// The expected plan computes the aggregate in the `Aggregate` node, evaluates
// the window function on top of it, and then applies the QUALIFY predicate as
// a `Filter` before the final projection.
#[test]
fn test_select_qualify_aggregate_reference() {
let sql = "
SELECT
person.id,
ROW_NUMBER() OVER (PARTITION BY person.id ORDER BY person.id) as rn
FROM person
GROUP BY
person.id
QUALIFY rn = 1 AND SUM(person.age) > 0";
// Planning must succeed; the snapshot below pins the resulting plan shape.
let plan = logical_plan(sql).unwrap();
assert_snapshot!(
plan,
@r"
Projection: person.id, row_number() PARTITION BY [person.id] ORDER BY [person.id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn
Filter: row_number() PARTITION BY [person.id] ORDER BY [person.id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW = Int64(1) AND sum(person.age) > Int64(0)
WindowAggr: windowExpr=[[row_number() PARTITION BY [person.id] ORDER BY [person.id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]]
TableScan: person
"
);
}

// An aggregate (`SUM(person.age)`) may be used inside the ORDER BY of a window
// function that appears only in the QUALIFY clause, not in the SELECT list.
//
// The expected plan places the `Aggregate` node beneath the `WindowAggr` node
// so the window function can order by the aggregate's output.
#[test]
fn test_select_qualify_aggregate_reference_within_window_function() {
let sql = "
SELECT
person.id
FROM person
GROUP BY
person.id
QUALIFY ROW_NUMBER() OVER (PARTITION BY person.id ORDER BY SUM(person.age) DESC) = 1";
// Planning must succeed; the snapshot below pins the resulting plan shape.
let plan = logical_plan(sql).unwrap();
assert_snapshot!(
plan,
@r"
Projection: person.id
Filter: row_number() PARTITION BY [person.id] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW = Int64(1)
WindowAggr: windowExpr=[[row_number() PARTITION BY [person.id] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]]
TableScan: person
"
);
}

// In a grouped query, a QUALIFY clause that references a column which is
// neither in GROUP BY nor wrapped in an aggregate (`person.age` here) must be
// rejected at planning time with an error that names the QUALIFY clause.
#[test]
fn test_select_qualify_aggregate_invalid_column_reference() {
let sql = "
SELECT
person.id
FROM person
GROUP BY
person.id
QUALIFY ROW_NUMBER() OVER (PARTITION BY person.id ORDER BY person.age DESC) = 1";
// Planning is expected to fail; the snapshot pins the exact error text.
let err = logical_plan(sql).unwrap_err();
assert_snapshot!(
err.strip_backtrace(),
@r#"Error during planning: Column in QUALIFY must be in GROUP BY or an aggregate function: While expanding wildcard, column "person.age" must appear in the GROUP BY clause or must be part of an aggregate function, currently only "person.id" appears in the SELECT clause satisfies this requirement"#
);
}

#[test]
fn test_select_qualify_without_window_function() {
let sql = "SELECT person.id FROM person QUALIFY person.id > 1";
Expand Down
Loading