diff --git a/1 b/1 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/diesel/src/backend.rs b/diesel/src/backend.rs index 3847453f025b..ee1353098292 100644 --- a/diesel/src/backend.rs +++ b/diesel/src/backend.rs @@ -317,6 +317,20 @@ pub trait SqlDialect: self::private::TrustedBackend { doc = "See [`sql_dialect::alias_syntax`] for provided default implementations" )] type AliasSyntax; + + /// Configures how this backend support the `GROUP` frame unit for window functions + #[cfg_attr( + feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes", + doc = "See [`sql_dialect::window_frame_clause_group_support`] for provided default implementations" + )] + type WindowFrameClauseGroupSupport; + + /// Configures how this backend supports aggregate function expressions + #[cfg_attr( + feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes", + doc = "See [`sql_dialect::window_frame_clause_group_support`] for provided default implementations" + )] + type AggregateFunctionExpressions; } /// This module contains all options provided by diesel to configure the [`SqlDialect`] trait. @@ -539,6 +553,34 @@ pub(crate) mod sql_dialect { #[derive(Debug, Copy, Clone)] pub struct AsAliasSyntax; } + + /// This module contains all reusable options to configure [`SqlDialect::WindowFrameClauseGroupSupport`] + #[diesel_derives::__diesel_public_if( + feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes" + )] + pub mod window_frame_clause_group_support { + /// Indicates that this backend does not support the `GROUPS` frame unit + #[derive(Debug, Copy, Clone)] + pub struct NoGroupWindowFrameUnit; + + /// Indicates that this backend does support the `GROUPS` frame unit as specified by the standard + #[derive(Debug, Copy, Clone)] + pub struct IsoGroupWindowFrameUnit; + } + + /// This module contains all reusable options to configure [`SqlDialect::AggregateFunctionExpressions`] + #[diesel_derives::__diesel_public_if( + feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes" + )] + pub mod aggregate_function_expressions { + /// Indicates that this backend does not support aggregate function expressions + #[derive(Debug, Copy, Clone)] + pub struct NoAggregateFunctionExpressions; + + /// Indicates that this backend supports aggregate function expressions similar to PostgreSQL + #[derive(Debug, Copy, Clone)] + pub struct PostgresLikeAggregateFunctionExpressions; + } } // These traits are not part of the public API diff --git a/diesel/src/connection/instrumentation.rs b/diesel/src/connection/instrumentation.rs index 9643c4d10f80..6bb2c03d837f 100644 --- a/diesel/src/connection/instrumentation.rs +++ b/diesel/src/connection/instrumentation.rs @@ -356,6 +356,12 @@ impl DynInstrumentation { #[diesel_derives::__diesel_public_if( feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes" )] + #[cfg(any( + feature = "postgres", + feature = "sqlite", + feature = "mysql", + feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes" + ))] pub(crate) fn default_instrumentation() -> Self { Self { inner: get_default_instrumentation(), @@ -367,6 +373,12 @@ impl DynInstrumentation { #[diesel_derives::__diesel_public_if( feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes" )] + #[cfg(any( + feature = "postgres", + feature = "sqlite", + feature = "mysql", + feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes" + ))] pub(crate) fn none() -> Self { Self { inner: None, @@ -378,6 +390,12 @@ impl DynInstrumentation { #[diesel_derives::__diesel_public_if( feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes" )] + #[cfg(any( + feature = "postgres", + feature = "sqlite", + feature = "mysql", + feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes" + ))] pub(crate) fn on_connection_event(&mut self, event: InstrumentationEvent<'_>) { // This implementation is not necessary to be able to call this method on this object // because of the already existing Deref impl. diff --git a/diesel/src/expression/count.rs b/diesel/src/expression/count.rs index e3512b22995c..c946e63f4ae2 100644 --- a/diesel/src/expression/count.rs +++ b/diesel/src/expression/count.rs @@ -29,6 +29,7 @@ define_sql_function! { /// # } /// ``` #[aggregate] + #[window] fn count(expr: T) -> BigInt; } diff --git a/diesel/src/expression/functions/aggregate_expressions.rs b/diesel/src/expression/functions/aggregate_expressions.rs new file mode 100644 index 000000000000..f1cb7c093815 --- /dev/null +++ b/diesel/src/expression/functions/aggregate_expressions.rs @@ -0,0 +1,264 @@ +use crate::backend::Backend; +use crate::expression::{AsExpression, ValidGrouping}; +use crate::query_builder::{AstPass, NotSpecialized, QueryFragment, QueryId}; +use crate::sql_types::Bool; +use crate::{AppearsOnTable, Expression, QueryResult, SelectableExpression}; + +macro_rules! empty_clause { + ($name: ident) => { + #[derive(Debug, Clone, Copy, QueryId)] + pub struct $name; + + impl crate::query_builder::QueryFragment for $name + where + DB: crate::backend::Backend + crate::backend::DieselReserveSpecialization, + { + fn walk_ast<'b>( + &'b self, + _pass: crate::query_builder::AstPass<'_, 'b, DB>, + ) -> crate::QueryResult<()> { + Ok(()) + } + } + }; +} + +mod aggregate_filter; +mod aggregate_order; +pub(crate) mod frame_clause; +mod over_clause; +mod partition_by; +mod prefix; +mod within_group; + +use self::aggregate_filter::{FilterDsl, NoFilter}; +use self::aggregate_order::{NoOrder, OrderAggregateDsl, OrderWindowDsl}; +use self::frame_clause::{FrameDsl, NoFrame}; +pub use self::over_clause::OverClause; +use self::over_clause::{NoWindow, OverDsl}; +use self::partition_by::PartitionByDsl; +use self::prefix::{All, AllDsl, DistinctDsl, NoPrefix}; +use self::within_group::{NoWithin, WithinGroupDsl}; + +#[derive(QueryId, Debug)] +pub struct AggregateExpression< + Fn, + Prefix = NoPrefix, + Order = NoOrder, + Filter = NoFilter, + Within = NoWithin, + Window = NoWindow, +> { + prefix: Prefix, + function: Fn, + order: Order, + filter: Filter, + within_group: Within, + window: Window, +} + +impl QueryFragment + for AggregateExpression +where + DB: crate::backend::Backend + crate::backend::DieselReserveSpecialization, + Fn: FunctionFragment, + Prefix: QueryFragment, + Order: QueryFragment, + Filter: QueryFragment, + Within: QueryFragment, + Window: QueryFragment + WindowFunctionFragment, +{ + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, DB>) -> QueryResult<()> { + pass.push_sql(Fn::FUNCTION_NAME); + pass.push_sql("("); + self.prefix.walk_ast(pass.reborrow())?; + self.function.walk_arguments(pass.reborrow())?; + self.order.walk_ast(pass.reborrow())?; + pass.push_sql(")"); + self.within_group.walk_ast(pass.reborrow())?; + self.filter.walk_ast(pass.reborrow())?; + self.window.walk_ast(pass.reborrow())?; + Ok(()) + } +} + +impl ValidGrouping + for AggregateExpression +where + Fn: ValidGrouping, +{ + type IsAggregate = >::IsAggregate; +} + +impl ValidGrouping + for AggregateExpression< + Fn, + Prefix, + Order, + Filter, + Within, + OverClause, + > +where + Fn: IsWindowFunction, + Fn::ArgTypes: ValidGrouping, +{ + // not sure about that, check this + type IsAggregate = >::IsAggregate; +} + +impl Expression + for AggregateExpression +where + Fn: Expression, +{ + type SqlType = ::SqlType; +} + +impl AppearsOnTable + for AggregateExpression +where + Self: Expression, + Fn: AppearsOnTable, +{ +} + +impl SelectableExpression + for AggregateExpression +where + Self: Expression, + Fn: SelectableExpression, +{ +} + +/// A helper marker trait that this function is a window function +/// This is only used to provide the gate the `WindowExpressionMethods` +/// trait onto, not to check if the construct is valid for a given backend +/// This check is postponed to building the query via `QueryFragment` +/// (We have access to the DB type there) +pub trait IsWindowFunction { + /// A tuple of all arg types + type ArgTypes; +} + +/// A helper marker trait that this function is a valid window function +/// for the given backend +/// this trait is used to transport information that +/// a certain function can be used as window function for a specific +/// backend +/// We allow to specialize this function for different SQL dialects +pub trait WindowFunctionFragment {} + +/// A helper marker trait that this function as a aggregate function +/// This is only used to provide the gate the `AggregateExpressionMethods` +/// trait onto, not to check if the construct is valid for a given backend +/// This check is postponed to building the query via `QueryFragment` +/// (We have access to the DB type there) +pub trait IsAggregateFunction {} + +/// A specialized QueryFragment helper trait that allows us to walk the function name +/// and the function arguments in seperate steps +pub trait FunctionFragment { + /// The name of the sql function + const FUNCTION_NAME: &'static str; + + /// Walk the function argument part (everything between ()) + fn walk_arguments<'b>(&'b self, pass: AstPass<'_, 'b, DB>) -> QueryResult<()>; +} + +// TODO: write helper types for all functions +// TODO: write doc tests for all functions +/// Expression methods to build aggregate function expressions +pub trait AggregateExpressionMethods: Sized { + /// `DISTINCT` modifier + fn distinct(self) -> Self::Output + where + Self: DistinctDsl, + { + ::distinct(self) + } + + /// `ALL` modifier + fn all(self) -> Self::Output + where + Self: AllDsl, + { + ::all(self) + } + + /// Add an aggregate filter + fn filter_aggregate

(self, f: P) -> Self::Output + where + P: AsExpression, + Self: FilterDsl, + { + >::filter(self, f.as_expression()) + } + + /// Add an aggregate order + fn order_aggregate(self, o: O) -> Self::Output + where + Self: OrderAggregateDsl, + { + >::order(self, o) + } + + // todo: restrict this to order set aggregates + // (we don't have any in diesel yet) + #[doc(hidden)] // for now + fn within_group(self, o: O) -> Self::Output + where + Self: WithinGroupDsl, + { + >::within_group(self, o) + } +} + +impl AggregateExpressionMethods for T {} + +/// Methods to construct a window function call +pub trait WindowExpressionMethods: Sized { + /// Turn a function call into a window function call + fn over(self) -> Self::Output + where + Self: OverDsl, + { + ::over(self) + } + + /// Add a filter to the current window function + // todo: do we want `or_filter` as well? + fn filter_window

(self, f: P) -> Self::Output + where + P: AsExpression, + Self: FilterDsl, + { + >::filter(self, f.as_expression()) + } + + /// Add a partition clause to the current window function + fn partition_by(self, expr: E) -> Self::Output + where + Self: PartitionByDsl, + { + >::partition_by(self, expr) + } + + /// Add a order clause to the current window function + fn window_order(self, expr: E) -> Self::Output + where + Self: OrderWindowDsl, + { + >::order(self, expr) + } + + /// Add a frame clause to the current window function + fn frame_by(self, expr: E) -> Self::Output + where + Self: FrameDsl, + { + >::frame(self, expr) + } +} + +impl WindowExpressionMethods for T {} diff --git a/diesel/src/expression/functions/aggregate_expressions/aggregate_filter.rs b/diesel/src/expression/functions/aggregate_expressions/aggregate_filter.rs new file mode 100644 index 000000000000..6363c3293e43 --- /dev/null +++ b/diesel/src/expression/functions/aggregate_expressions/aggregate_filter.rs @@ -0,0 +1,123 @@ +use super::aggregate_order::NoOrder; +use super::prefix::NoPrefix; +use super::AggregateExpression; +use super::IsAggregateFunction; +use super::NoWindow; +use super::NoWithin; +use crate::backend::{sql_dialect, Backend, SqlDialect}; +use crate::query_builder::where_clause::NoWhereClause; +use crate::query_builder::where_clause::WhereAnd; +use crate::query_builder::QueryFragment; +use crate::query_builder::{AstPass, QueryId}; +use crate::sql_types::BoolOrNullableBool; +use crate::Expression; +use crate::QueryResult; + +empty_clause!(NoFilter); + +#[derive(QueryId, Copy, Clone, Debug)] +pub struct Filter

(P); + +impl QueryFragment for Filter

+where + Self: QueryFragment, + DB: Backend, +{ + fn walk_ast<'b>(&'b self, pass: AstPass<'_, 'b, DB>) -> QueryResult<()> { + >::walk_ast(self, pass) + } +} + +impl + QueryFragment< + DB, + sql_dialect::aggregate_function_expressions::PostgresLikeAggregateFunctionExpressions, + > for Filter

+where + P: QueryFragment, + DB: Backend + SqlDialect, +{ + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, DB>) -> QueryResult<()> { + pass.push_sql(" FILTER ("); + self.0.walk_ast(pass.reborrow())?; + pass.push_sql(")"); + Ok(()) + } +} + +pub trait FilterDsl

{ + type Output; + + fn filter(self, f: P) -> Self::Output; +} + +impl FilterDsl

for T +where + T: IsAggregateFunction, + P: Expression, + ST: BoolOrNullableBool, +{ + type Output = + AggregateExpression>::Output>>; + + fn filter(self, f: P) -> Self::Output { + AggregateExpression { + prefix: NoPrefix, + function: self, + order: NoOrder, + filter: Filter(NoWhereClause.and(f)), + within_group: NoWithin, + window: NoWindow, + } + } +} + +impl FilterDsl

+ for AggregateExpression, Within, Window> +where + P: Expression, + ST: BoolOrNullableBool, + F: WhereAnd

, +{ + type Output = + AggregateExpression>::Output>, Within, Window>; + + fn filter(self, f: P) -> Self::Output { + AggregateExpression { + prefix: self.prefix, + function: self.function, + order: self.order, + filter: Filter(WhereAnd::

::and(self.filter.0, f)), + within_group: self.within_group, + window: self.window, + } + } +} + +impl FilterDsl

+ for AggregateExpression +where + P: Expression, + ST: BoolOrNullableBool, + NoWhereClause: WhereAnd

, +{ + type Output = AggregateExpression< + Fn, + Prefix, + Order, + Filter<>::Output>, + Within, + Window, + >; + + fn filter(self, f: P) -> Self::Output { + AggregateExpression { + prefix: self.prefix, + function: self.function, + order: self.order, + filter: Filter(WhereAnd::

::and(NoWhereClause, f)), + within_group: self.within_group, + window: self.window, + } + } +} diff --git a/diesel/src/expression/functions/aggregate_expressions/aggregate_order.rs b/diesel/src/expression/functions/aggregate_expressions/aggregate_order.rs new file mode 100644 index 000000000000..cc7f9b1603a5 --- /dev/null +++ b/diesel/src/expression/functions/aggregate_expressions/aggregate_order.rs @@ -0,0 +1,124 @@ +use super::IsAggregateFunction; +use super::NoFilter; +use super::NoPrefix; +use super::NoWindow; +use super::NoWithin; +use super::{over_clause::OverClause, AggregateExpression}; +use crate::backend::{sql_dialect, Backend, SqlDialect}; +use crate::query_builder::order_clause::OrderClause; +use crate::query_builder::QueryFragment; +use crate::query_builder::{AstPass, QueryId}; +use crate::{Expression, QueryResult}; + +empty_clause!(NoOrder); + +#[derive(QueryId, Copy, Clone, Debug)] +pub struct Order(OrderClause); + +impl QueryFragment for Order +where + Self: QueryFragment, + DB: Backend, +{ + fn walk_ast<'b>(&'b self, pass: AstPass<'_, 'b, DB>) -> QueryResult<()> { + >::walk_ast(self, pass) + } +} + +impl + QueryFragment< + DB, + sql_dialect::aggregate_function_expressions::PostgresLikeAggregateFunctionExpressions, + > for Order +where + OrderClause: QueryFragment, + DB: Backend +SqlDialect, +{ + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, DB>) -> QueryResult<()> { + self.0.walk_ast(pass.reborrow())?; + Ok(()) + } +} + +pub trait OrderAggregateDsl { + type Output; + + fn order(self, expr: E) -> Self::Output; +} + +impl OrderAggregateDsl for T +where + T: IsAggregateFunction, + E: Expression, +{ + type Output = AggregateExpression>; + + fn order(self, expr: E) -> Self::Output { + AggregateExpression { + prefix: NoPrefix, + function: self, + order: Order(OrderClause(expr)), + filter: NoFilter, + within_group: NoWithin, + window: NoWindow, + } + } +} + +impl OrderAggregateDsl + for AggregateExpression +{ + type Output = AggregateExpression, Filter>; + + fn order(self, expr: O) -> Self::Output { + AggregateExpression { + prefix: self.prefix, + function: self.function, + order: Order(OrderClause(expr)), + filter: self.filter, + within_group: self.within_group, + window: NoWindow, + } + } +} + +pub trait OrderWindowDsl { + type Output; + + fn order(self, expr: O) -> Self::Output; +} + +impl OrderWindowDsl + for AggregateExpression< + Fn, + NoPrefix, + NoOrder, + Filter, + NoWithin, + OverClause, + > +{ + type Output = AggregateExpression< + Fn, + NoPrefix, + NoOrder, + Filter, + NoWithin, + OverClause, Frame>, + >; + + fn order(self, expr: E) -> Self::Output { + AggregateExpression { + prefix: NoPrefix, + function: self.function, + order: NoOrder, + filter: self.filter, + within_group: NoWithin, + window: OverClause { + partition_by: self.window.partition_by, + order: Order(OrderClause(expr)), + frame_clause: self.window.frame_clause, + }, + } + } +} diff --git a/diesel/src/expression/functions/aggregate_expressions/frame_clause.rs b/diesel/src/expression/functions/aggregate_expressions/frame_clause.rs new file mode 100644 index 000000000000..26907ef6ac19 --- /dev/null +++ b/diesel/src/expression/functions/aggregate_expressions/frame_clause.rs @@ -0,0 +1,349 @@ +use crate::backend::sql_dialect; +use crate::query_builder::{QueryFragment, QueryId}; +use crate::serialize::ToSql; +use crate::sql_types::BigInt; + +use super::aggregate_order::NoOrder; +use super::over_clause::OverClause; +use super::prefix::NoPrefix; +use super::within_group::NoWithin; +use super::AggregateExpression; + +empty_clause!(NoFrame); + +#[derive(QueryId, Copy, Clone, Debug)] +pub struct FrameClause(F); + +impl QueryFragment for FrameClause +where + F: QueryFragment, + DB: crate::backend::Backend, +{ + fn walk_ast<'b>( + &'b self, + pass: crate::query_builder::AstPass<'_, 'b, DB>, + ) -> crate::QueryResult<()> { + self.0.walk_ast(pass)?; + Ok(()) + } +} + +macro_rules! simple_frame_expr { + ($name: ident, $kind: expr) => { + #[derive(QueryId, Clone, Copy, Debug)] + #[doc(hidden)] + pub struct $name; + + impl QueryFragment for $name + where + DB: crate::backend::Backend, + { + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, DB>, + ) -> crate::QueryResult<()> { + pass.push_sql($kind); + Ok(()) + } + } + }; +} + +// kinds +simple_frame_expr!(Range, " RANGE "); +simple_frame_expr!(Rows, " ROWS "); + +#[derive(QueryId, Clone, Copy, Debug)] +#[doc(hidden)] +pub struct Groups; + +impl QueryFragment for Groups +where + DB: crate::backend::Backend, + Self: QueryFragment, +{ + fn walk_ast<'b>( + &'b self, + pass: crate::query_builder::AstPass<'_, 'b, DB>, + ) -> crate::QueryResult<()> { + >::walk_ast(self, pass) + } +} +impl QueryFragment + for Groups +where + DB: crate::backend::Backend, +{ + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, DB>, + ) -> crate::QueryResult<()> { + pass.push_sql(" GROUPS "); + Ok(()) + } +} + +// start & end +simple_frame_expr!(UnboundedPreceding, "UNBOUNDED PRECEDING "); +simple_frame_expr!(CurrentRow, "CURRENT ROW "); +simple_frame_expr!(UnboundedFollowing, "UNBOUNDED FOLLOWING "); + +// exclusion +simple_frame_expr!(ExcludeCurrentRow, "EXCLUDE CURRENT ROW "); +simple_frame_expr!(ExcludeGroup, "EXCLUDE GROUP "); +simple_frame_expr!(ExcludeTies, "EXCLUDE TIES "); +simple_frame_expr!(ExcludeNoOthers, "EXCLUDE NO OTHERS "); + +#[derive(QueryId, Clone, Copy, Debug)] +pub struct OffsetPreceding(i64); + +impl QueryFragment for OffsetPreceding +where + DB: crate::backend::Backend, + i64: ToSql, +{ + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, DB>, + ) -> crate::QueryResult<()> { + pass.push_bind_param::(&self.0)?; + pass.push_sql(" PRECEDING "); + Ok(()) + } +} + +#[derive(QueryId, Clone, Copy, Debug)] +pub struct OffsetFollowing(i64); + +impl QueryFragment for OffsetFollowing +where + DB: crate::backend::Backend, + i64: ToSql, +{ + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, DB>, + ) -> crate::QueryResult<()> { + pass.push_bind_param::(&self.0)?; + pass.push_sql(" FOLLOWING "); + Ok(()) + } +} + +pub trait FrameDsl { + type Output; + + fn frame(self, expr: F) -> Self::Output; +} + +impl FrameDsl + for AggregateExpression< + Fn, + NoPrefix, + NoOrder, + Filter, + NoWithin, + OverClause, + > +where + E: FrameClauseExpression, +{ + type Output = AggregateExpression< + Fn, + NoPrefix, + NoOrder, + Filter, + NoWithin, + OverClause>, + >; + + fn frame(self, expr: E) -> Self::Output { + AggregateExpression { + prefix: self.prefix, + function: self.function, + order: self.order, + filter: self.filter, + within_group: self.within_group, + window: OverClause { + partition_by: self.window.partition_by, + order: self.window.order, + frame_clause: FrameClause(expr), + }, + } + } +} + +pub trait FrameClauseExpression {} + +pub trait FrameClauseStartBound {} +pub trait FrameClauseEndBound {} + +impl FrameClauseEndBound for UnboundedFollowing {} +impl FrameClauseStartBound for UnboundedPreceding {} +impl FrameClauseEndBound for CurrentRow {} +impl FrameClauseStartBound for CurrentRow {} +impl FrameClauseStartBound for OffsetFollowing {} +impl FrameClauseEndBound for OffsetFollowing {} +impl FrameClauseStartBound for OffsetPreceding {} +impl FrameClauseEndBound for OffsetPreceding {} + +pub trait FrameCauseExclusion {} + +impl FrameCauseExclusion for ExcludeGroup {} +impl FrameCauseExclusion for ExcludeNoOthers {} +impl FrameCauseExclusion for ExcludeTies {} +impl FrameCauseExclusion for ExcludeCurrentRow {} + +/// Construct a frame clause for window functions from an integer +pub trait FrameBoundDsl { + /// Use the preceding frame clause specification + fn preceding(self) -> OffsetPreceding; + + /// Use the following frame clause specification + fn following(self) -> OffsetFollowing; +} + +impl FrameBoundDsl for i64 { + fn preceding(self) -> OffsetPreceding { + OffsetPreceding(self) + } + + fn following(self) -> OffsetFollowing { + OffsetFollowing(self) + } +} + +empty_clause!(NoExclusion); + +#[derive(QueryId, Copy, Clone, Debug)] +pub struct StartFrame { + kind: Kind, + start: Start, + exclusion: Exclusion, +} + +impl QueryFragment for StartFrame +where + Kind: QueryFragment, + Start: QueryFragment, + Exclusion: QueryFragment, + DB: crate::backend::Backend, +{ + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, DB>, + ) -> crate::QueryResult<()> { + self.kind.walk_ast(pass.reborrow())?; + self.start.walk_ast(pass.reborrow())?; + self.exclusion.walk_ast(pass.reborrow())?; + Ok(()) + } +} + +impl FrameClauseExpression for StartFrame {} + +#[derive(QueryId, Copy, Clone, Debug)] +pub struct BetweenFrame { + kind: Kind, + start: Start, + end: End, + exclusion: Exclusion, +} + +impl QueryFragment + for BetweenFrame +where + Kind: QueryFragment, + Start: QueryFragment, + End: QueryFragment, + Exclusion: QueryFragment, + DB: crate::backend::Backend, +{ + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, DB>, + ) -> crate::QueryResult<()> { + self.kind.walk_ast(pass.reborrow())?; + pass.push_sql(" BETWEEN "); + self.start.walk_ast(pass.reborrow())?; + pass.push_sql(" AND "); + self.end.walk_ast(pass.reborrow())?; + self.exclusion.walk_ast(pass.reborrow())?; + Ok(()) + } +} + +impl FrameClauseExpression + for BetweenFrame +{ +} + +pub trait FrameClauseDslHelper: Sized {} + +/// Construct a frame clause for window functions +pub trait FrameClauseDsl: FrameClauseDslHelper { + /// Construct a frame clause with a starting bound + fn start_with(self, start: E) -> StartFrame + where + E: FrameClauseStartBound, + { + StartFrame { + kind: self, + start, + exclusion: NoExclusion, + } + } + + /// Construct a frame clause with a starting bound and an exclusion condition + fn start_with_exclusion(self, start: E1, exclusion: E2) -> StartFrame + where + E1: FrameClauseStartBound, + E2: FrameCauseExclusion, + { + StartFrame { + kind: self, + start, + exclusion, + } + } + + /// Construct a between frame clause with a starting and end bound + fn between(self, start: E1, end: E2) -> BetweenFrame + where + E1: FrameClauseStartBound, + E2: FrameClauseEndBound, + { + BetweenFrame { + kind: self, + start, + end, + exclusion: NoExclusion, + } + } + + /// Construct a between frame clause with a starting and end bound with an exclusion condition + fn between_with_exclusion( + self, + start: E1, + end: E2, + exclusion: E3, + ) -> BetweenFrame + where + E1: FrameClauseStartBound, + E2: FrameClauseEndBound, + E3: FrameCauseExclusion, + { + BetweenFrame { + kind: self, + start, + end, + exclusion, + } + } +} + +impl FrameClauseDsl for T where T: FrameClauseDslHelper {} + +impl FrameClauseDslHelper for Range {} +impl FrameClauseDslHelper for Rows {} +impl FrameClauseDslHelper for Groups {} diff --git a/diesel/src/expression/functions/aggregate_expressions/over_clause.rs b/diesel/src/expression/functions/aggregate_expressions/over_clause.rs new file mode 100644 index 000000000000..a8eab14d569c --- /dev/null +++ b/diesel/src/expression/functions/aggregate_expressions/over_clause.rs @@ -0,0 +1,90 @@ +use super::aggregate_filter::NoFilter; +use super::aggregate_order::NoOrder; +use super::partition_by::NoPartition; +use super::prefix::NoPrefix; +use super::within_group::NoWithin; +use super::AggregateExpression; +use super::IsWindowFunction; +use super::NoFrame; +use super::WindowFunctionFragment; +use crate::query_builder::QueryFragment; +use crate::query_builder::{AstPass, QueryId}; +use crate::QueryResult; + +empty_clause!(NoWindow); + +impl WindowFunctionFragment for NoWindow where DB: crate::backend::Backend {} + +/// TODO +#[derive(Clone, Copy, QueryId, Debug)] +pub struct OverClause { + pub(crate) partition_by: Partition, + pub(crate) order: Order, + pub(crate) frame_clause: Frame, +} + +impl QueryFragment for OverClause +where + Partition: QueryFragment, + Order: QueryFragment, + Frame: QueryFragment, + DB: crate::backend::Backend, +{ + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, DB>) -> QueryResult<()> { + pass.push_sql(" OVER ("); + self.partition_by.walk_ast(pass.reborrow())?; + self.order.walk_ast(pass.reborrow())?; + self.frame_clause.walk_ast(pass.reborrow())?; + pass.push_sql(")"); + Ok(()) + } +} + +pub trait OverDsl { + type Output; + + fn over(self) -> Self::Output; +} + +impl OverDsl for F +where + F: IsWindowFunction, +{ + type Output = AggregateExpression; + + fn over(self) -> Self::Output { + AggregateExpression { + prefix: NoPrefix, + function: self, + order: NoOrder, + filter: NoFilter, + within_group: NoWithin, + window: OverClause { + partition_by: NoPartition, + order: NoOrder, + frame_clause: NoFrame, + }, + } + } +} + +impl OverDsl + for AggregateExpression +{ + type Output = AggregateExpression; + + fn over(self) -> Self::Output { + AggregateExpression { + prefix: NoPrefix, + function: self.function, + order: NoOrder, + filter: self.filter, + within_group: NoWithin, + window: OverClause { + partition_by: NoPartition, + order: NoOrder, + frame_clause: NoFrame, + }, + } + } +} diff --git a/diesel/src/expression/functions/aggregate_expressions/partition_by.rs b/diesel/src/expression/functions/aggregate_expressions/partition_by.rs new file mode 100644 index 000000000000..471759e0c8be --- /dev/null +++ b/diesel/src/expression/functions/aggregate_expressions/partition_by.rs @@ -0,0 +1,66 @@ +use super::aggregate_order::NoOrder; +use super::over_clause::OverClause; +use super::prefix::NoPrefix; +use super::within_group::NoWithin; +use super::AggregateExpression; +use crate::query_builder::QueryFragment; +use crate::query_builder::{AstPass, QueryId}; +use crate::QueryResult; + +empty_clause!(NoPartition); + +#[derive(QueryId, Clone, Copy, Debug)] +pub struct PartitionBy(T); + +impl QueryFragment for PartitionBy +where + T: QueryFragment, + DB: crate::backend::Backend, +{ + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, DB>) -> QueryResult<()> { + pass.push_sql(" PARTITION BY "); + self.0.walk_ast(pass.reborrow())?; + Ok(()) + } +} + +pub trait PartitionByDsl { + type Output; + + fn partition_by(self, expr: E) -> Self::Output; +} + +impl PartitionByDsl + for AggregateExpression< + Fn, + NoPrefix, + NoOrder, + Filter, + NoWithin, + OverClause, + > +{ + type Output = AggregateExpression< + Fn, + NoPrefix, + NoOrder, + Filter, + NoWithin, + OverClause, Order, Frame>, + >; + + fn partition_by(self, expr: E) -> Self::Output { + AggregateExpression { + prefix: NoPrefix, + function: self.function, + order: NoOrder, + filter: self.filter, + within_group: NoWithin, + window: OverClause { + partition_by: PartitionBy(expr), + order: self.window.order, + frame_clause: self.window.frame_clause, + }, + } + } +} diff --git a/diesel/src/expression/functions/aggregate_expressions/prefix.rs b/diesel/src/expression/functions/aggregate_expressions/prefix.rs new file mode 100644 index 000000000000..61ba4891587f --- /dev/null +++ b/diesel/src/expression/functions/aggregate_expressions/prefix.rs @@ -0,0 +1,122 @@ +use super::AggregateExpression; +use super::IsAggregateFunction; +use super::NoFilter; +use super::NoOrder; +use super::NoWindow; +use super::NoWithin; +use crate::query_builder::{AstPass, QueryFragment, QueryId}; +use crate::QueryResult; + +empty_clause!(NoPrefix); + +#[derive(Debug, Clone, Copy, QueryId)] +pub struct All; + +impl QueryFragment for All +where + DB: crate::backend::Backend, +{ + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, DB>) -> QueryResult<()> { + pass.push_sql(" ALL "); + Ok(()) + } +} + +#[derive(Debug, Clone, Copy, QueryId)] +pub struct Distinct; + +impl QueryFragment for Distinct +where + DB: crate::backend::Backend, +{ + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, DB>) -> QueryResult<()> { + pass.push_sql(" DISTINCT "); + Ok(()) + } +} + +pub trait DistinctDsl { + type Output; + + fn distinct(self) -> Self::Output; +} + +impl DistinctDsl for T +where + T: IsAggregateFunction, +{ + type Output = AggregateExpression; + + fn distinct(self) -> Self::Output { + AggregateExpression { + prefix: Distinct, + function: self, + order: NoOrder, + filter: NoFilter, + within_group: NoWithin, + window: NoWindow, + } + } +} + +impl DistinctDsl + for AggregateExpression +where + T: IsAggregateFunction, +{ + type Output = AggregateExpression; + + fn distinct(self) -> Self::Output { + AggregateExpression { + prefix: Distinct, + function: self.function, + order: self.order, + filter: self.filter, + within_group: self.within_group, + window: self.window, + } + } +} + +pub trait AllDsl { + type Output; + + fn all(self) -> Self::Output; +} + +impl AllDsl for T +where + T: IsAggregateFunction, +{ + type Output = AggregateExpression; + + fn all(self) -> Self::Output { + AggregateExpression { + prefix: All, + function: self, + order: NoOrder, + filter: NoFilter, + within_group: NoWithin, + window: NoWindow, + } + } +} + +impl AllDsl + for AggregateExpression +where + T: IsAggregateFunction, +{ + type Output = AggregateExpression; + + fn all(self) -> Self::Output { + AggregateExpression { + prefix: All, + function: self.function, + order: self.order, + filter: self.filter, + within_group: self.within_group, + window: self.window, + } + } +} diff --git a/diesel/src/expression/functions/aggregate_expressions/within_group.rs b/diesel/src/expression/functions/aggregate_expressions/within_group.rs new file mode 100644 index 000000000000..3644d0264abd --- /dev/null +++ b/diesel/src/expression/functions/aggregate_expressions/within_group.rs @@ -0,0 +1,90 @@ +use super::AggregateExpression; +use super::All; +use super::IsAggregateFunction; +use super::NoFilter; +use super::NoOrder; +use super::NoPrefix; +use super::NoWindow; +use crate::query_builder::order_clause::OrderClause; +use crate::query_builder::QueryFragment; +use crate::query_builder::{AstPass, QueryId}; +use crate::Expression; +use crate::QueryResult; + +empty_clause!(NoWithin); + +#[derive(QueryId, Copy, Clone, Debug)] +pub struct WithinGroup(OrderClause); + +// this clause is only postgres specific +#[cfg(feature = "postgres_backend")] +impl QueryFragment for WithinGroup +where + OrderClause: QueryFragment, +{ + fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, diesel::pg::Pg>) -> QueryResult<()> { + pass.push_sql(" WITHIN GROUP ("); + self.0.walk_ast(pass.reborrow())?; + pass.push_sql(")"); + Ok(()) + } +} + +pub trait WithinGroupDsl { + type Output; + + fn within_group(self, expr: E) -> Self::Output; +} + +impl WithinGroupDsl for T +where + T: IsAggregateFunction, + E: Expression, +{ + type Output = AggregateExpression>; + + fn within_group(self, expr: E) -> Self::Output { + AggregateExpression { + prefix: NoPrefix, + function: self, + order: NoOrder, + filter: NoFilter, + within_group: WithinGroup(OrderClause(expr)), + window: NoWindow, + } + } +} + +impl WithinGroupDsl + for AggregateExpression +{ + type Output = AggregateExpression>; + + fn within_group(self, expr: O) -> Self::Output { + AggregateExpression { + prefix: self.prefix, + function: self.function, + order: self.order, + filter: self.filter, + within_group: WithinGroup(OrderClause(expr)), + window: NoWindow, + } + } +} + +impl WithinGroupDsl + for AggregateExpression +{ + type Output = AggregateExpression>; + + fn within_group(self, expr: O) -> Self::Output { + AggregateExpression { + prefix: self.prefix, + function: self.function, + order: self.order, + filter: self.filter, + within_group: WithinGroup(OrderClause(expr)), + window: NoWindow, + } + } +} diff --git a/diesel/src/expression/functions/mod.rs b/diesel/src/expression/functions/mod.rs index db8f79e7a730..c0fb882c6d1b 100644 --- a/diesel/src/expression/functions/mod.rs +++ b/diesel/src/expression/functions/mod.rs @@ -94,6 +94,7 @@ macro_rules! no_arg_sql_function { }; } +pub(crate) mod aggregate_expressions; pub(crate) mod aggregate_folding; pub(crate) mod aggregate_ordering; pub(crate) mod date_and_time; diff --git a/diesel/src/expression/mod.rs b/diesel/src/expression/mod.rs index 9d0d127eb4b7..a313cc893176 100644 --- a/diesel/src/expression/mod.rs +++ b/diesel/src/expression/mod.rs @@ -88,6 +88,21 @@ pub(crate) mod dsl { #[cfg(feature = "mysql_backend")] pub use crate::mysql::query_builder::DuplicatedKeys; + + pub use super::functions::aggregate_expressions::AggregateExpressionMethods; + pub use super::functions::aggregate_expressions::WindowExpressionMethods; + + pub use super::functions::aggregate_expressions::frame_clause::{ + FrameBoundDsl, FrameClauseDsl, + }; + + /// Different frame clause specifications for window functions + pub mod frame { + pub use super::super::functions::aggregate_expressions::frame_clause::{ + CurrentRow, ExcludeCurrentRow, ExcludeGroup, ExcludeNoOthers, ExcludeTies, Groups, + Range, Rows, UnboundedFollowing, UnboundedPreceding, + }; + } } #[doc(inline)] diff --git a/diesel/src/internal/mod.rs b/diesel/src/internal/mod.rs index 5ba2e38905ff..e0451a3c135b 100644 --- a/diesel/src/internal/mod.rs +++ b/diesel/src/internal/mod.rs @@ -6,4 +6,5 @@ pub mod alias_macro; pub mod derives; pub mod operators_macro; +pub mod sql_functions; pub mod table_macro; diff --git a/diesel/src/internal/sql_functions.rs b/diesel/src/internal/sql_functions.rs new file mode 100644 index 000000000000..441035b04014 --- /dev/null +++ b/diesel/src/internal/sql_functions.rs @@ -0,0 +1,4 @@ +#[doc(hidden)] +pub use crate::expression::functions::aggregate_expressions::{ + FunctionFragment, IsAggregateFunction, IsWindowFunction, OverClause, WindowFunctionFragment, +}; diff --git a/diesel/src/mysql/backend.rs b/diesel/src/mysql/backend.rs index 290935d8eb38..69f0345e75cb 100644 --- a/diesel/src/mysql/backend.rs +++ b/diesel/src/mysql/backend.rs @@ -4,6 +4,7 @@ use super::query_builder::MysqlQueryBuilder; use super::MysqlValue; use crate::backend::sql_dialect::on_conflict_clause::SupportsOnConflictClause; use crate::backend::*; +use crate::internal::derives::multiconnection::sql_dialect; use crate::query_builder::bind_collector::RawBytesBindCollector; use crate::sql_types::TypeMetadata; @@ -89,6 +90,12 @@ impl SqlDialect for Mysql { type ConcatClause = MysqlConcatClause; type AliasSyntax = sql_dialect::alias_syntax::AsAliasSyntax; + + type WindowFrameClauseGroupSupport = + sql_dialect::window_frame_clause_group_support::NoGroupWindowFrameUnit; + + type AggregateFunctionExpressions = + sql_dialect::aggregate_function_expressions::NoAggregateFunctionExpressions; } impl DieselReserveSpecialization for Mysql {} diff --git a/diesel/src/pg/backend.rs b/diesel/src/pg/backend.rs index db18d681c8d0..f2d94946894e 100644 --- a/diesel/src/pg/backend.rs +++ b/diesel/src/pg/backend.rs @@ -141,6 +141,10 @@ impl SqlDialect for Pg { type ExistsSyntax = sql_dialect::exists_syntax::AnsiSqlExistsSyntax; type ArrayComparison = PgStyleArrayComparison; type AliasSyntax = sql_dialect::alias_syntax::AsAliasSyntax; + type WindowFrameClauseGroupSupport = + sql_dialect::window_frame_clause_group_support::IsoGroupWindowFrameUnit; + type AggregateFunctionExpressions = + sql_dialect::aggregate_function_expressions::PostgresLikeAggregateFunctionExpressions; } impl DieselReserveSpecialization for Pg {} diff --git a/diesel/src/pg/expression/functions.rs b/diesel/src/pg/expression/functions.rs index c8a22d2c15e1..5cec034fb791 100644 --- a/diesel/src/pg/expression/functions.rs +++ b/diesel/src/pg/expression/functions.rs @@ -2126,7 +2126,6 @@ define_sql_function! { /// # Ok(()) /// # } /// ``` - fn jsonb_array_length>(jsonb: E) -> E::Out; } diff --git a/diesel/src/query_builder/mod.rs b/diesel/src/query_builder/mod.rs index 64cf8f851cdb..cd7865494805 100644 --- a/diesel/src/query_builder/mod.rs +++ b/diesel/src/query_builder/mod.rs @@ -136,6 +136,8 @@ use crate::backend::Backend; use crate::result::QueryResult; use std::error::Error; +pub(crate) use self::private::NotSpecialized; + #[doc(hidden)] pub type Binds = Vec>>; /// A specialized Result type used with the query builder. diff --git a/diesel/src/sqlite/backend.rs b/diesel/src/sqlite/backend.rs index 80ca6fc0b348..8e1ede6a0601 100644 --- a/diesel/src/sqlite/backend.rs +++ b/diesel/src/sqlite/backend.rs @@ -67,6 +67,11 @@ impl SqlDialect for Sqlite { type ExistsSyntax = sql_dialect::exists_syntax::AnsiSqlExistsSyntax; type ArrayComparison = sql_dialect::array_comparison::AnsiSqlArrayComparison; type AliasSyntax = sql_dialect::alias_syntax::AsAliasSyntax; + + type WindowFrameClauseGroupSupport = + sql_dialect::window_frame_clause_group_support::IsoGroupWindowFrameUnit; + type AggregateFunctionExpressions = + sql_dialect::aggregate_function_expressions::PostgresLikeAggregateFunctionExpressions; } impl DieselReserveSpecialization for Sqlite {} diff --git a/diesel_derives/src/attrs.rs b/diesel_derives/src/attrs.rs index e196ba6f74ed..99e29a02d8a7 100644 --- a/diesel_derives/src/attrs.rs +++ b/diesel_derives/src/attrs.rs @@ -24,6 +24,7 @@ pub trait MySpanned { fn span(&self) -> Span; } +#[derive(Clone)] pub struct AttributeSpanWrapper { pub item: T, pub attribute_span: Span, diff --git a/diesel_derives/src/lib.rs b/diesel_derives/src/lib.rs index 3b53aca5527b..e57f6f3733c9 100644 --- a/diesel_derives/src/lib.rs +++ b/diesel_derives/src/lib.rs @@ -1325,7 +1325,10 @@ pub fn derive_valid_grouping(input: TokenStream) -> TokenStream { /// ``` #[proc_macro] pub fn define_sql_function(input: TokenStream) -> TokenStream { - sql_function::expand(parse_macro_input!(input), false).into() + match sql_function::expand(parse_macro_input!(input), false) { + Ok(o) => o.into(), + Err(e) => e.into_compile_error().into(), + } } /// A legacy version of [`define_sql_function!`]. @@ -1358,7 +1361,10 @@ pub fn define_sql_function(input: TokenStream) -> TokenStream { #[proc_macro] #[cfg(all(feature = "with-deprecated", not(feature = "without-deprecated")))] pub fn sql_function_proc(input: TokenStream) -> TokenStream { - sql_function::expand(parse_macro_input!(input), true).into() + match sql_function::expand(parse_macro_input!(input), true) { + Ok(o) => o.into(), + Err(e) => e.into_compile_error().into(), + } } /// This is an internal diesel macro that diff --git a/diesel_derives/src/sql_function.rs b/diesel_derives/src/sql_function.rs index c35a3fe95930..2b7dd5c2b185 100644 --- a/diesel_derives/src/sql_function.rs +++ b/diesel_derives/src/sql_function.rs @@ -3,50 +3,63 @@ use quote::quote; use quote::ToTokens; use syn::parse::{Parse, ParseStream, Result}; use syn::punctuated::Punctuated; +use syn::spanned::Spanned; use syn::{ - parenthesized, parse_quote, Attribute, GenericArgument, Generics, Ident, Meta, MetaNameValue, - PathArguments, Token, Type, + parenthesized, parse_quote, Attribute, GenericArgument, Generics, Ident, ImplGenerics, LitStr, + PathArguments, Token, Type, TypeGenerics, }; -pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool) -> TokenStream { +use crate::attrs::{AttributeSpanWrapper, MySpanned}; +use crate::util::parse_eq; + +pub(crate) fn expand( + input: SqlFunctionDecl, + legacy_helper_type_and_module: bool, +) -> Result { let SqlFunctionDecl { - mut attributes, + attributes, fn_token, fn_name, mut generics, - args, + ref args, return_type, } = input; let sql_name = attributes .iter() - .find(|attr| attr.meta.path().is_ident("sql_name")) - .and_then(|attr| { - if let Meta::NameValue(MetaNameValue { - value: - syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Str(ref lit), - .. - }), - .. - }) = attr.meta - { - Some(lit.value()) - } else { - None - } + .find_map(|attr| match attr.item { + SqlFunctionAttribute::SqlName(_, ref value) => Some(value.value()), + _ => None, }) .unwrap_or_else(|| fn_name.to_string()); let is_aggregate = attributes .iter() - .any(|attr| attr.meta.path().is_ident("aggregate")); + .any(|attr| matches!(attr.item, SqlFunctionAttribute::Aggregate(..))); - attributes.retain(|attr| { - !attr.meta.path().is_ident("sql_name") && !attr.meta.path().is_ident("aggregate") - }); + let can_be_called_directly = !function_cannot_be_called_directly(&attributes)?; + + let window = attributes + .iter() + .find(|a| matches!(a.item, SqlFunctionAttribute::Window(..))) + .cloned(); + + let restrictions = attributes + .iter() + .find_map(|a| match a.item { + SqlFunctionAttribute::Restriction(ref r) => Some(r.clone()), + _ => None, + }) + .unwrap_or_default(); + + let attributes = attributes + .into_iter() + .filter_map(|a| match a.item { + SqlFunctionAttribute::Other(a) => Some(a), + _ => None, + }) + .collect::>(); - let args = &args; let (ref arg_name, ref arg_type): (Vec<_>, Vec<_>) = args .iter() .map(|StrictFnArg { name, ty, .. }| (name, ty)) @@ -95,6 +108,13 @@ pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool } let helper_type_doc = format!("The return type of [`{fn_name}()`](super::fn_name)"); + let query_fragment_impl = + can_be_called_directly.then_some(restrictions.generate_all_queryfragment_impls( + generics.clone(), + &ty_generics, + arg_name, + &fn_name, + )); let args_iter = args.iter(); let mut tokens = quote! { @@ -102,6 +122,7 @@ pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool use diesel::expression::{AsExpression, Expression, SelectableExpression, AppearsOnTable, ValidGrouping}; use diesel::query_builder::{QueryFragment, AstPass}; use diesel::sql_types::*; + use diesel::internal::sql_functions::*; use super::*; #[derive(Debug, Clone, Copy, diesel::query_builder::QueryId)] @@ -142,16 +163,16 @@ pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool { } - // __DieselInternal is what we call DB normally - impl #impl_generics_internal QueryFragment<__DieselInternal> + impl #impl_generics_internal FunctionFragment<__DieselInternal> for #fn_name #ty_generics where __DieselInternal: diesel::backend::Backend, #(#arg_name: QueryFragment<__DieselInternal>,)* { + const FUNCTION_NAME: &'static str = #sql_name; + #[allow(unused_assignments)] - fn walk_ast<'__b>(&'__b self, mut out: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()>{ - out.push_sql(concat!(#sql_name, "(")); + fn walk_arguments<'__b>(&'__b self, mut out: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()> { // we unroll the arguments manually here, to prevent borrow check issues let mut needs_comma = false; #( @@ -163,10 +184,11 @@ pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool needs_comma = true; } )* - out.push_sql(")"); Ok(()) } } + + #query_fragment_impl }; let is_supported_on_sqlite = cfg!(feature = "sqlite") @@ -174,222 +196,44 @@ pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool && is_sqlite_type(&return_type) && arg_type.iter().all(|a| is_sqlite_type(a)); - if is_aggregate { - tokens = quote! { - #tokens - - impl #impl_generics_internal ValidGrouping<__DieselInternal> - for #fn_name #ty_generics - { - type IsAggregate = diesel::expression::is_aggregate::Yes; - } - }; - if is_supported_on_sqlite { - tokens = quote! { - #tokens - - use diesel::sqlite::{Sqlite, SqliteConnection}; - use diesel::serialize::ToSql; - use diesel::deserialize::{FromSqlRow, StaticallySizedRow}; - use diesel::sqlite::SqliteAggregateFunction; - use diesel::sql_types::IntoNullable; - }; - - match arg_name.len() { - x if x > 1 => { - tokens = quote! { - #tokens - - #[allow(dead_code)] - /// Registers an implementation for this aggregate function on the given connection - /// - /// This function must be called for every `SqliteConnection` before - /// this SQL function can be used on SQLite. The implementation must be - /// deterministic (returns the same result given the same arguments). - pub fn register_impl( - conn: &mut SqliteConnection - ) -> QueryResult<()> - where - A: SqliteAggregateFunction<(#(#arg_name,)*)> - + Send - + 'static - + ::std::panic::UnwindSafe - + ::std::panic::RefUnwindSafe, - A::Output: ToSql<#return_type, Sqlite>, - (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> + - StaticallySizedRow<(#(#arg_type,)*), Sqlite> + - ::std::panic::UnwindSafe, - { - conn.register_aggregate_function::<(#(#arg_type,)*), #return_type, _, _, A>(#sql_name) - } - }; - } - 1 => { - let arg_name = arg_name[0]; - let arg_type = arg_type[0]; - - tokens = quote! { - #tokens - - #[allow(dead_code)] - /// Registers an implementation for this aggregate function on the given connection - /// - /// This function must be called for every `SqliteConnection` before - /// this SQL function can be used on SQLite. The implementation must be - /// deterministic (returns the same result given the same arguments). - pub fn register_impl( - conn: &mut SqliteConnection - ) -> QueryResult<()> - where - A: SqliteAggregateFunction<#arg_name> - + Send - + 'static - + std::panic::UnwindSafe - + std::panic::RefUnwindSafe, - A::Output: ToSql<#return_type, Sqlite>, - #arg_name: FromSqlRow<#arg_type, Sqlite> + - StaticallySizedRow<#arg_type, Sqlite> + - ::std::panic::UnwindSafe, - { - conn.register_aggregate_function::<#arg_type, #return_type, _, _, A>(#sql_name) - } - }; - } - _ => (), - } - } - } else { - tokens = quote! { - #tokens - - #[derive(ValidGrouping)] - pub struct __Derived<#(#arg_name,)*>(#(#arg_name,)*); - - impl #impl_generics_internal ValidGrouping<__DieselInternal> - for #fn_name #ty_generics - where - __Derived<#(#arg_name,)*>: ValidGrouping<__DieselInternal>, - { - type IsAggregate = <__Derived<#(#arg_name,)*> as ValidGrouping<__DieselInternal>>::IsAggregate; - } - }; - - if is_supported_on_sqlite && !arg_name.is_empty() { - tokens = quote! { - #tokens - - use diesel::sqlite::{Sqlite, SqliteConnection}; - use diesel::serialize::ToSql; - use diesel::deserialize::{FromSqlRow, StaticallySizedRow}; - - #[allow(dead_code)] - /// Registers an implementation for this function on the given connection - /// - /// This function must be called for every `SqliteConnection` before - /// this SQL function can be used on SQLite. The implementation must be - /// deterministic (returns the same result given the same arguments). If - /// the function is nondeterministic, call - /// `register_nondeterministic_impl` instead. - pub fn register_impl( - conn: &mut SqliteConnection, - f: F, - ) -> QueryResult<()> - where - F: Fn(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static, - (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> + - StaticallySizedRow<(#(#arg_type,)*), Sqlite>, - Ret: ToSql<#return_type, Sqlite>, - { - conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>( - #sql_name, - true, - move |(#(#arg_name,)*)| f(#(#arg_name,)*), - ) - } - - #[allow(dead_code)] - /// Registers an implementation for this function on the given connection - /// - /// This function must be called for every `SqliteConnection` before - /// this SQL function can be used on SQLite. - /// `register_nondeterministic_impl` should only be used if your - /// function can return different results with the same arguments (e.g. - /// `random`). If your function is deterministic, you should call - /// `register_impl` instead. - pub fn register_nondeterministic_impl( - conn: &mut SqliteConnection, - mut f: F, - ) -> QueryResult<()> - where - F: FnMut(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static, - (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> + - StaticallySizedRow<(#(#arg_type,)*), Sqlite>, - Ret: ToSql<#return_type, Sqlite>, - { - conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>( - #sql_name, - false, - move |(#(#arg_name,)*)| f(#(#arg_name,)*), - ) - } - }; - } - - if is_supported_on_sqlite && arg_name.is_empty() { - tokens = quote! { - #tokens - - use diesel::sqlite::{Sqlite, SqliteConnection}; - use diesel::serialize::ToSql; - - #[allow(dead_code)] - /// Registers an implementation for this function on the given connection - /// - /// This function must be called for every `SqliteConnection` before - /// this SQL function can be used on SQLite. The implementation must be - /// deterministic (returns the same result given the same arguments). If - /// the function is nondeterministic, call - /// `register_nondeterministic_impl` instead. - pub fn register_impl( - conn: &SqliteConnection, - f: F, - ) -> QueryResult<()> - where - F: Fn() -> Ret + std::panic::UnwindSafe + Send + 'static, - Ret: ToSql<#return_type, Sqlite>, - { - conn.register_noarg_sql_function::<#return_type, _, _>( - #sql_name, - true, - f, - ) - } + if let Some(ref window) = window { + tokens = generate_window_function_tokens( + window, + &impl_generics, + generics.clone(), + &ty_generics, + &fn_name, + arg_name, + tokens, + ); + } - #[allow(dead_code)] - /// Registers an implementation for this function on the given connection - /// - /// This function must be called for every `SqliteConnection` before - /// this SQL function can be used on SQLite. - /// `register_nondeterministic_impl` should only be used if your - /// function can return different results with the same arguments (e.g. - /// `random`). If your function is deterministic, you should call - /// `register_impl` instead. - pub fn register_nondeterministic_impl( - conn: &SqliteConnection, - mut f: F, - ) -> QueryResult<()> - where - F: FnMut() -> Ret + std::panic::UnwindSafe + Send + 'static, - Ret: ToSql<#return_type, Sqlite>, - { - conn.register_noarg_sql_function::<#return_type, _, _>( - #sql_name, - false, - f, - ) - } - }; - } + if is_aggregate { + tokens = generate_tokens_for_aggregate_functions( + tokens, + &impl_generics_internal, + &impl_generics, + &fn_name, + &ty_generics, + arg_name, + arg_type, + is_supported_on_sqlite, + window.as_ref(), + &return_type, + &sql_name, + ); + } else if window.is_none() { + tokens = generate_tokens_for_non_aggregate_functions( + tokens, + &impl_generics_internal, + &fn_name, + &ty_generics, + arg_name, + arg_type, + is_supported_on_sqlite, + &return_type, + &sql_name, + ); } let args_iter = args.iter(); @@ -413,7 +257,7 @@ pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool ) }; - quote! { + Ok(quote! { #(#attributes)* #[allow(non_camel_case_types)] pub #fn_token #fn_name #impl_generics (#(#args_iter,)*) @@ -434,11 +278,301 @@ pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool pub(crate) mod #internals_module_name { #tokens } + }) +} + +fn generate_window_function_tokens( + window: &AttributeSpanWrapper, + impl_generics: &syn::ImplGenerics<'_>, + generics: Generics, + ty_generics: &TypeGenerics<'_>, + fn_name: &Ident, + arg_name: &[&syn::Ident], + tokens: TokenStream, +) -> TokenStream { + let SqlFunctionAttribute::Window(_, ref restrictions) = window.item else { + unreachable!("We filtered for window attributes above") + }; + let window_function_impl = + restrictions.generate_all_window_fragment_impls(generics, ty_generics, fn_name); + quote::quote! { + #tokens + #window_function_impl + impl #impl_generics IsWindowFunction for #fn_name #ty_generics { + type ArgTypes = (#(#arg_name,)*); + } } } +#[allow(clippy::too_many_arguments)] +fn generate_tokens_for_non_aggregate_functions( + mut tokens: TokenStream, + impl_generics_internal: &syn::ImplGenerics<'_>, + fn_name: &syn::Ident, + ty_generics: &syn::TypeGenerics<'_>, + arg_name: &[&syn::Ident], + arg_type: &[&syn::Type], + is_supported_on_sqlite: bool, + return_type: &syn::Type, + sql_name: &str, +) -> TokenStream { + tokens = quote! { + #tokens + + #[derive(ValidGrouping)] + pub struct __Derived<#(#arg_name,)*>(#(#arg_name,)*); + + impl #impl_generics_internal ValidGrouping<__DieselInternal> + for #fn_name #ty_generics + where + __Derived<#(#arg_name,)*>: ValidGrouping<__DieselInternal>, + { + type IsAggregate = <__Derived<#(#arg_name,)*> as ValidGrouping<__DieselInternal>>::IsAggregate; + } + }; + + if is_supported_on_sqlite && !arg_name.is_empty() { + tokens = quote! { + #tokens + + use diesel::sqlite::{Sqlite, SqliteConnection}; + use diesel::serialize::ToSql; + use diesel::deserialize::{FromSqlRow, StaticallySizedRow}; + + #[allow(dead_code)] + /// Registers an implementation for this function on the given connection + /// + /// This function must be called for every `SqliteConnection` before + /// this SQL function can be used on SQLite. The implementation must be + /// deterministic (returns the same result given the same arguments). If + /// the function is nondeterministic, call + /// `register_nondeterministic_impl` instead. + pub fn register_impl( + conn: &mut SqliteConnection, + f: F, + ) -> QueryResult<()> + where + F: Fn(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static, + (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> + + StaticallySizedRow<(#(#arg_type,)*), Sqlite>, + Ret: ToSql<#return_type, Sqlite>, + { + conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>( + #sql_name, + true, + move |(#(#arg_name,)*)| f(#(#arg_name,)*), + ) + } + + #[allow(dead_code)] + /// Registers an implementation for this function on the given connection + /// + /// This function must be called for every `SqliteConnection` before + /// this SQL function can be used on SQLite. + /// `register_nondeterministic_impl` should only be used if your + /// function can return different results with the same arguments (e.g. + /// `random`). If your function is deterministic, you should call + /// `register_impl` instead. + pub fn register_nondeterministic_impl( + conn: &mut SqliteConnection, + mut f: F, + ) -> QueryResult<()> + where + F: FnMut(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static, + (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> + + StaticallySizedRow<(#(#arg_type,)*), Sqlite>, + Ret: ToSql<#return_type, Sqlite>, + { + conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>( + #sql_name, + false, + move |(#(#arg_name,)*)| f(#(#arg_name,)*), + ) + } + }; + } + + if is_supported_on_sqlite && arg_name.is_empty() { + tokens = quote! { + #tokens + + use diesel::sqlite::{Sqlite, SqliteConnection}; + use diesel::serialize::ToSql; + + #[allow(dead_code)] + /// Registers an implementation for this function on the given connection + /// + /// This function must be called for every `SqliteConnection` before + /// this SQL function can be used on SQLite. The implementation must be + /// deterministic (returns the same result given the same arguments). If + /// the function is nondeterministic, call + /// `register_nondeterministic_impl` instead. + pub fn register_impl( + conn: &SqliteConnection, + f: F, + ) -> QueryResult<()> + where + F: Fn() -> Ret + std::panic::UnwindSafe + Send + 'static, + Ret: ToSql<#return_type, Sqlite>, + { + conn.register_noarg_sql_function::<#return_type, _, _>( + #sql_name, + true, + f, + ) + } + + #[allow(dead_code)] + /// Registers an implementation for this function on the given connection + /// + /// This function must be called for every `SqliteConnection` before + /// this SQL function can be used on SQLite. + /// `register_nondeterministic_impl` should only be used if your + /// function can return different results with the same arguments (e.g. + /// `random`). If your function is deterministic, you should call + /// `register_impl` instead. + pub fn register_nondeterministic_impl( + conn: &SqliteConnection, + mut f: F, + ) -> QueryResult<()> + where + F: FnMut() -> Ret + std::panic::UnwindSafe + Send + 'static, + Ret: ToSql<#return_type, Sqlite>, + { + conn.register_noarg_sql_function::<#return_type, _, _>( + #sql_name, + false, + f, + ) + } + }; + } + tokens +} + +#[allow(clippy::too_many_arguments)] +fn generate_tokens_for_aggregate_functions( + mut tokens: TokenStream, + impl_generics_internal: &syn::ImplGenerics<'_>, + impl_generics: &syn::ImplGenerics<'_>, + fn_name: &syn::Ident, + ty_generics: &syn::TypeGenerics<'_>, + arg_name: &[&syn::Ident], + arg_type: &[&syn::Type], + is_supported_on_sqlite: bool, + window: Option<&AttributeSpanWrapper>, + return_type: &syn::Type, + sql_name: &str, +) -> TokenStream { + tokens = quote! { + #tokens + + impl #impl_generics_internal ValidGrouping<__DieselInternal> + for #fn_name #ty_generics + { + type IsAggregate = diesel::expression::is_aggregate::Yes; + } + + impl #impl_generics IsAggregateFunction for #fn_name #ty_generics {} + }; + // we do not support custom window functions for sqlite yet + if is_supported_on_sqlite && window.is_none() { + tokens = quote! { + #tokens + + use diesel::sqlite::{Sqlite, SqliteConnection}; + use diesel::serialize::ToSql; + use diesel::deserialize::{FromSqlRow, StaticallySizedRow}; + use diesel::sqlite::SqliteAggregateFunction; + use diesel::sql_types::IntoNullable; + }; + + match arg_name.len() { + x if x > 1 => { + tokens = quote! { + #tokens + + #[allow(dead_code)] + /// Registers an implementation for this aggregate function on the given connection + /// + /// This function must be called for every `SqliteConnection` before + /// this SQL function can be used on SQLite. The implementation must be + /// deterministic (returns the same result given the same arguments). + pub fn register_impl( + conn: &mut SqliteConnection + ) -> QueryResult<()> + where + A: SqliteAggregateFunction<(#(#arg_name,)*)> + + Send + + 'static + + ::std::panic::UnwindSafe + + ::std::panic::RefUnwindSafe, + A::Output: ToSql<#return_type, Sqlite>, + (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> + + StaticallySizedRow<(#(#arg_type,)*), Sqlite> + + ::std::panic::UnwindSafe, + { + conn.register_aggregate_function::<(#(#arg_type,)*), #return_type, _, _, A>(#sql_name) + } + }; + } + 1 => { + let arg_name = arg_name[0]; + let arg_type = arg_type[0]; + + tokens = quote! { + #tokens + + #[allow(dead_code)] + /// Registers an implementation for this aggregate function on the given connection + /// + /// This function must be called for every `SqliteConnection` before + /// this SQL function can be used on SQLite. The implementation must be + /// deterministic (returns the same result given the same arguments). + pub fn register_impl( + conn: &mut SqliteConnection + ) -> QueryResult<()> + where + A: SqliteAggregateFunction<#arg_name> + + Send + + 'static + + std::panic::UnwindSafe + + std::panic::RefUnwindSafe, + A::Output: ToSql<#return_type, Sqlite>, + #arg_name: FromSqlRow<#arg_type, Sqlite> + + StaticallySizedRow<#arg_type, Sqlite> + + ::std::panic::UnwindSafe, + { + conn.register_aggregate_function::<#arg_type, #return_type, _, _, A>(#sql_name) + } + }; + } + _ => (), + } + } + tokens +} + +fn function_cannot_be_called_directly( + attributes: &[AttributeSpanWrapper], +) -> Result { + let mut has_aggregate = false; + let mut has_window = false; + let mut has_require_within = false; + for attr in attributes { + has_aggregate = has_aggregate || matches!(attr.item, SqlFunctionAttribute::Aggregate(..)); + has_window = has_window || matches!(attr.item, SqlFunctionAttribute::Window(..)); + has_require_within = + has_require_within || matches!(attr.item, SqlFunctionAttribute::RequireWithin(..)); + if has_require_within && (has_aggregate || has_window) { + return Err(syn::Error::new(attr.ident_span, "cannot have `#[require_within]` and `#[aggregate]` or `#[window]` on the same function")); + } + } + Ok(has_require_within || (has_window && !has_aggregate)) +} + pub(crate) struct SqlFunctionDecl { - attributes: Vec, + attributes: Vec>, fn_token: Token![fn], fn_name: Ident, generics: Generics, @@ -449,6 +583,67 @@ pub(crate) struct SqlFunctionDecl { impl Parse for SqlFunctionDecl { fn parse(input: ParseStream) -> Result { let attributes = Attribute::parse_outer(input)?; + + let attributes = attributes + .into_iter() + .map(|attr| match &attr.meta { + syn::Meta::NameValue(syn::MetaNameValue { + path, + value: + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(sql_name), + .. + }), + .. + }) if path.is_ident("sql_name") => Ok(AttributeSpanWrapper { + attribute_span: attr.span(), + ident_span: sql_name.span(), + item: SqlFunctionAttribute::SqlName( + path.require_ident()?.clone(), + sql_name.clone(), + ), + }), + syn::Meta::Path(path) if path.is_ident("aggregate") => Ok(AttributeSpanWrapper { + attribute_span: attr.span(), + ident_span: path.span(), + item: SqlFunctionAttribute::Aggregate(path.require_ident()?.clone()), + }), + syn::Meta::Path(path) if path.is_ident("window") => Ok(AttributeSpanWrapper { + attribute_span: attr.span(), + ident_span: path.span(), + item: SqlFunctionAttribute::Window( + path.require_ident()?.clone(), + BackendRestriction::None, + ), + }), + syn::Meta::Path(path) if path.is_ident("require_within") => { + Ok(AttributeSpanWrapper { + attribute_span: attr.span(), + ident_span: path.span(), + item: SqlFunctionAttribute::RequireWithin(path.require_ident()?.clone()), + }) + } + syn::Meta::NameValue(_) | syn::Meta::Path(_) => Ok(AttributeSpanWrapper { + attribute_span: attr.span(), + ident_span: attr.span(), + item: SqlFunctionAttribute::Other(attr), + }), + syn::Meta::List(_) => { + let name = attr.meta.path().require_ident()?; + let attribute_span = attr.meta.span(); + attr.clone() + .parse_args_with(|input: &syn::parse::ParseBuffer| { + SqlFunctionAttribute::parse_attr( + name.clone(), + input, + attr.clone(), + attribute_span, + ) + }) + } + }) + .collect::>>()?; + let fn_token: Token![fn] = input.parse()?; let fn_name = Ident::parse(input)?; let generics = Generics::parse(input)?; @@ -538,3 +733,342 @@ fn is_sqlite_type(ty: &Type) -> bool { ] .contains(&ident.as_str()) } + +#[derive(Default, Clone, Debug)] +enum BackendRestriction { + #[default] + None, + SqlDialect(syn::Ident, syn::Ident, syn::Path), + BackendBound( + syn::Ident, + syn::punctuated::Punctuated, + ), + Backends( + syn::Ident, + syn::punctuated::Punctuated, + ), +} + +impl BackendRestriction { + fn parse_from(input: &syn::parse::ParseBuffer<'_>) -> Result { + if input.is_empty() { + return Ok(Self::None); + } + Self::parse(input) + } + + fn parse_backends( + input: &syn::parse::ParseBuffer<'_>, + name: Ident, + ) -> Result { + let backends = Punctuated::parse_terminated(input)?; + Ok(Self::Backends(name, backends)) + } + + fn parse_sql_dialect( + content: &syn::parse::ParseBuffer<'_>, + name: Ident, + ) -> Result { + let dialect = content.parse()?; + let _del: syn::Token![,] = content.parse()?; + let dialect_variant = content.parse()?; + + Ok(Self::SqlDialect(name, dialect, dialect_variant)) + } + + fn parse_backend_bounds( + input: &syn::parse::ParseBuffer<'_>, + name: Ident, + ) -> Result { + let restrictions = Punctuated::parse_terminated(input)?; + Ok(Self::BackendBound(name, restrictions)) + } + + fn generate_all_window_fragment_impls( + &self, + mut generics: Generics, + ty_generics: &TypeGenerics<'_>, + fn_name: &syn::Ident, + ) -> TokenStream { + generics.params.push(parse_quote!(__P)); + generics.params.push(parse_quote!(__O)); + generics.params.push(parse_quote!(__F)); + match *self { + BackendRestriction::None => { + generics.params.push(parse_quote!(__DieselInternal)); + let (impl_generics, _, _) = generics.split_for_impl(); + Self::generate_window_fragment_impl( + parse_quote!(__DieselInternal), + Some(parse_quote!(__DieselInternal: diesel::backend::Backend,)), + &impl_generics, + ty_generics, + fn_name, + None, + ) + } + BackendRestriction::SqlDialect(_, ref dialect, ref dialect_type) => { + generics.params.push(parse_quote!(__DieselInternal)); + let (impl_generics, _, _) = generics.split_for_impl(); + let specific_impl = Self::generate_window_fragment_impl( + parse_quote!(__DieselInternal), + Some( + parse_quote!(__DieselInternal: diesel::backend::Backend + diesel::backend::SqlDialect<#dialect = #dialect_type>,), + ), + &impl_generics, + ty_generics, + fn_name, + Some(dialect_type), + ); + quote::quote! { + impl #impl_generics WindowFunctionFragment<__DieselInternal> + for #fn_name #ty_generics + where + Self: WindowFunctionFragment<__DieselInternal, <__DieselInternal as diesel::backend::SqlDialect>::#dialect>, + __DieselInternal: diesel::backend::Backend, + { + } + + #specific_impl + } + } + BackendRestriction::BackendBound(_, ref restriction) => { + generics.params.push(parse_quote!(__DieselInternal)); + let (impl_generics, _, _) = generics.split_for_impl(); + Self::generate_window_fragment_impl( + parse_quote!(__DieselInternal), + Some(parse_quote!(__DieselInternal: diesel::backend::Backend + #restriction,)), + &impl_generics, + ty_generics, + fn_name, + None, + ) + } + BackendRestriction::Backends(_, ref backends) => { + let (impl_generics, _, _) = generics.split_for_impl(); + let backends = backends.iter().map(|b| { + Self::generate_window_fragment_impl( + quote! {#b}, + None, + &impl_generics, + ty_generics, + fn_name, + None, + ) + }); + + parse_quote!(#(#backends)*) + } + } + } + + fn generate_window_fragment_impl( + backend: TokenStream, + backend_bound: Option, + impl_generics: &ImplGenerics<'_>, + ty_generics: &TypeGenerics<'_>, + fn_name: &syn::Ident, + dialect: Option<&syn::Path>, + ) -> TokenStream { + quote::quote! { + impl #impl_generics WindowFunctionFragment<#fn_name #ty_generics, #backend, #dialect> for OverClause<__P, __O, __F> + where #backend_bound + { + + } + } + } + + fn generate_all_queryfragment_impls( + &self, + mut generics: Generics, + ty_generics: &TypeGenerics<'_>, + arg_name: &[&syn::Ident], + fn_name: &syn::Ident, + ) -> proc_macro2::TokenStream { + match *self { + BackendRestriction::None => { + generics.params.push(parse_quote!(__DieselInternal)); + let (impl_generics, _, _) = generics.split_for_impl(); + Self::generate_queryfragment_impl( + parse_quote!(__DieselInternal), + Some(parse_quote!(__DieselInternal: diesel::backend::Backend,)), + &impl_generics, + ty_generics, + arg_name, + fn_name, + None, + ) + } + BackendRestriction::BackendBound(_, ref restriction) => { + generics.params.push(parse_quote!(__DieselInternal)); + let (impl_generics, _, _) = generics.split_for_impl(); + Self::generate_queryfragment_impl( + parse_quote!(__DieselInternal), + Some(parse_quote!(__DieselInternal: diesel::backend::Backend + #restriction,)), + &impl_generics, + ty_generics, + arg_name, + fn_name, + None, + ) + } + BackendRestriction::SqlDialect(_, ref dialect, ref dialect_type) => { + generics.params.push(parse_quote!(__DieselInternal)); + let (impl_generics, _, _) = generics.split_for_impl(); + let specific_impl = Self::generate_queryfragment_impl( + parse_quote!(__DieselInternal), + Some( + parse_quote!(__DieselInternal: diesel::backend::Backend + diesel::backend::SqlDialect<#dialect = #dialect_type>,), + ), + &impl_generics, + ty_generics, + arg_name, + fn_name, + Some(dialect_type), + ); + quote::quote! { + impl #impl_generics QueryFragment<__DieselInternal> + for #fn_name #ty_generics + where + Self: QueryFragment<__DieselInternal, <__DieselInternal as diesel::backend::SqlDialect>::#dialect>, + __DieselInternal: diesel::backend::Backend, + { + fn walk_ast<'__b>(&'__b self, mut out: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()> { + ::#dialect>>::walk_ast(self, out) + } + + } + + #specific_impl + } + } + BackendRestriction::Backends(_, ref backends) => { + let (impl_generics, _, _) = generics.split_for_impl(); + let backends = backends.iter().map(|b| { + Self::generate_queryfragment_impl( + quote! {#b}, + None, + &impl_generics, + ty_generics, + arg_name, + fn_name, + None, + ) + }); + + parse_quote!(#(#backends)*) + } + } + } + + fn generate_queryfragment_impl( + backend: proc_macro2::TokenStream, + backend_bound: Option, + impl_generics: &ImplGenerics<'_>, + ty_generics: &TypeGenerics<'_>, + arg_name: &[&syn::Ident], + fn_name: &syn::Ident, + dialect: Option<&syn::Path>, + ) -> proc_macro2::TokenStream { + quote::quote! { + impl #impl_generics QueryFragment<#backend, #dialect> + for #fn_name #ty_generics + where + #backend_bound + #(#arg_name: QueryFragment<#backend>,)* + { + fn walk_ast<'__b>(&'__b self, mut out: AstPass<'_, '__b, #backend>) -> QueryResult<()>{ + out.push_sql(>::FUNCTION_NAME); + out.push_sql("("); + self.walk_arguments(out.reborrow())?; + out.push_sql(")"); + Ok(()) + } + } + } + } +} + +impl Parse for BackendRestriction { + fn parse(input: ParseStream) -> Result { + let name: syn::Ident = input.parse()?; + let name_str = name.to_string(); + let content; + parenthesized!(content in input); + match &*name_str { + "backends" => Self::parse_backends(&content, name), + "dialect" => Self::parse_sql_dialect(&content, name), + "backend_bounds" => Self::parse_backend_bounds(&content, name), + _ => Err(syn::Error::new( + name.span(), + format!("unexpected option `{name_str}`"), + )), + } + } +} + +#[derive(Debug, Clone)] +enum SqlFunctionAttribute { + Aggregate(Ident), + Window(Ident, BackendRestriction), + SqlName(Ident, LitStr), + Restriction(BackendRestriction), + RequireWithin(Ident), + Other(Attribute), +} + +impl MySpanned for SqlFunctionAttribute { + fn span(&self) -> proc_macro2::Span { + match self { + SqlFunctionAttribute::Restriction(BackendRestriction::Backends(ref ident, ..)) + | SqlFunctionAttribute::Restriction(BackendRestriction::SqlDialect(ref ident, ..)) + | SqlFunctionAttribute::Restriction(BackendRestriction::BackendBound(ref ident, ..)) + | SqlFunctionAttribute::Aggregate(ref ident, ..) + | SqlFunctionAttribute::Window(ref ident, ..) + | SqlFunctionAttribute::RequireWithin(ref ident) + | SqlFunctionAttribute::SqlName(ref ident, ..) => ident.span(), + SqlFunctionAttribute::Restriction(BackendRestriction::None) => { + unreachable!("We do not construct that") + } + SqlFunctionAttribute::Other(ref attribute) => attribute.span(), + } + } +} + +impl SqlFunctionAttribute { + fn parse_attr( + name: Ident, + input: &syn::parse::ParseBuffer<'_>, + attr: Attribute, + attribute_span: proc_macro2::Span, + ) -> Result> { + let name_str = name.to_string(); + let parsed_attr = match &*name_str { + "window" => BackendRestriction::parse_from(input).map(|r| Self::Window(name, r))?, + "sql_name" => parse_eq(input, "sql_name = \"SUM\"").map(|v| Self::SqlName(name, v))?, + "backends" => BackendRestriction::parse_backends(input, name).map(Self::Restriction)?, + "dialect" => { + BackendRestriction::parse_sql_dialect(input, name).map(Self::Restriction)? + } + "backend_bounds" => { + BackendRestriction::parse_backend_bounds(input, name).map(Self::Restriction)? + } + _ => { + // empty the parse buffer otherwise syn will return an error + let _ = input.step(|cursor| { + let mut rest = *cursor; + while let Some((_, next)) = rest.token_tree() { + rest = next; + } + Ok(((), rest)) + }); + SqlFunctionAttribute::Other(attr) + } + }; + Ok(AttributeSpanWrapper { + ident_span: parsed_attr.span(), + item: parsed_attr, + attribute_span, + }) + } +} diff --git a/diesel_tests/tests/aggregate_expressions.rs b/diesel_tests/tests/aggregate_expressions.rs new file mode 100644 index 000000000000..2e590ae059d1 --- /dev/null +++ b/diesel_tests/tests/aggregate_expressions.rs @@ -0,0 +1,74 @@ +#![cfg(feature = "postgres")] // todo +use crate::schema::connection_with_sean_and_tess_in_users_table; +use crate::schema::users; +use diesel::dsl::{ + self, frame, AggregateExpressionMethods, FrameBoundDsl, FrameClauseDsl, WindowExpressionMethods, +}; +use diesel::prelude::*; + +#[test] +fn test1() { + let mut conn = connection_with_sean_and_tess_in_users_table(); + + let query = users::table.select(dsl::count(users::id).filter_aggregate(users::name.eq("Sean"))); + dbg!(diesel::debug_query::(&query)); + + let res = query.get_result::(&mut conn).unwrap(); + assert_eq!(res, 1); + + let query2 = users::table.select( + dsl::count(users::id) + .distinct() + .filter_aggregate(users::name.eq("Sean")), + ); + dbg!(diesel::debug_query::(&query2)); + let res = query2.get_result::(&mut conn).unwrap(); + dbg!(res); + + let query3 = users::table.select( + dsl::count(users::id) + .distinct() + .filter_aggregate(users::name.eq("Sean")) + .order_aggregate(users::id), + ); + dbg!(diesel::debug_query::(&query3)); + let res = query3.get_result::(&mut conn).unwrap(); + dbg!(res); + todo!() +} + +#[test] +fn test2() { + let mut conn = connection_with_sean_and_tess_in_users_table(); + + let query = users::table.select(dsl::count(users::id).over()); + dbg!(diesel::debug_query::(&query)); + + let res = query.get_result::(&mut conn).unwrap(); + assert_eq!(res, 2); + + let query = users::table.select(dsl::count(users::id).over().partition_by(users::name)); + dbg!(diesel::debug_query::(&query)); + + let res = query.get_result::(&mut conn).unwrap(); + assert_eq!(res, 1); + + let query = users::table.select(dsl::count(users::id).over().window_order(users::name)); + dbg!(diesel::debug_query::(&query)); + + let res = query.get_result::(&mut conn).unwrap(); + assert_eq!(res, 1); + + let query = users::table.select( + dsl::count(users::id) + .over() + .window_order(users::name) + .partition_by(users::name) + .frame_by(frame::Rows.start_with(2.preceding())), + ); + dbg!(diesel::debug_query::(&query)); + + let res = query.get_result::(&mut conn).unwrap(); + assert_eq!(res, 1); + todo!() +} diff --git a/diesel_tests/tests/lib.rs b/diesel_tests/tests/lib.rs index 07743347201f..1b7d6341ddf1 100644 --- a/diesel_tests/tests/lib.rs +++ b/diesel_tests/tests/lib.rs @@ -6,6 +6,7 @@ extern crate assert_matches; #[macro_use] extern crate diesel; +mod aggregate_expressions; mod alias; #[cfg(not(feature = "sqlite"))] mod annotations;