From b4b267ae4b2ad326d207609538add9f0f9ead506 Mon Sep 17 00:00:00 2001 From: UBarney Date: Tue, 24 Dec 2024 10:13:40 +0800 Subject: [PATCH] Support 1 or 3 arg in generate_series() UDTF (#13856) * Support 1 or 3 args in generate_series() UDTF * address comment --- .../functions-table/src/generate_series.rs | 168 +++++++++++------- .../test_files/table_functions.slt | 63 ++++++- 2 files changed, 154 insertions(+), 77 deletions(-) diff --git a/datafusion/functions-table/src/generate_series.rs b/datafusion/functions-table/src/generate_series.rs index ced43ea8f00c..887daa71ec55 100644 --- a/datafusion/functions-table/src/generate_series.rs +++ b/datafusion/functions-table/src/generate_series.rs @@ -22,7 +22,7 @@ use async_trait::async_trait; use datafusion_catalog::Session; use datafusion_catalog::TableFunctionImpl; use datafusion_catalog::TableProvider; -use datafusion_common::{not_impl_err, plan_err, Result, ScalarValue}; +use datafusion_common::{plan_err, Result, ScalarValue}; use datafusion_expr::{Expr, TableType}; use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec}; use datafusion_physical_plan::ExecutionPlan; @@ -30,28 +30,45 @@ use parking_lot::RwLock; use std::fmt; use std::sync::Arc; -/// Table that generates a series of integers from `start`(inclusive) to `end`(inclusive) +/// Indicates the arguments used for generating a series. +#[derive(Debug, Clone)] +enum GenSeriesArgs { + /// ContainsNull signifies that at least one argument(start, end, step) was null, thus no series will be generated. + ContainsNull, + /// AllNotNullArgs holds the start, end, and step values for generating the series when all arguments are not null. + AllNotNullArgs { start: i64, end: i64, step: i64 }, +} + +/// Table that generates a series of integers from `start`(inclusive) to `end`(inclusive), incrementing by step #[derive(Debug, Clone)] struct GenerateSeriesTable { schema: SchemaRef, - // None if input is Null - start: Option, - // None if input is Null - end: Option, + args: GenSeriesArgs, } -/// Table state that generates a series of integers from `start`(inclusive) to `end`(inclusive) +/// Table state that generates a series of integers from `start`(inclusive) to `end`(inclusive), incrementing by step #[derive(Debug, Clone)] struct GenerateSeriesState { schema: SchemaRef, start: i64, // Kept for display end: i64, + step: i64, batch_size: usize, /// Tracks current position when generating table current: i64, } +impl GenerateSeriesState { + fn reach_end(&self, val: i64) -> bool { + if self.step > 0 { + return val > self.end; + } + + val < self.end + } +} + /// Detail to display for 'Explain' plan impl fmt::Display for GenerateSeriesState { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -65,19 +82,19 @@ impl fmt::Display for GenerateSeriesState { impl LazyBatchGenerator for GenerateSeriesState { fn generate_next_batch(&mut self) -> Result> { - // Check if we've reached the end - if self.current > self.end { + let mut buf = Vec::with_capacity(self.batch_size); + while buf.len() < self.batch_size && !self.reach_end(self.current) { + buf.push(self.current); + self.current += self.step; + } + let array = Int64Array::from(buf); + + if array.is_empty() { return Ok(None); } - // Construct batch - let batch_end = (self.current + self.batch_size as i64 - 1).min(self.end); - let array = Int64Array::from_iter_values(self.current..=batch_end); let batch = RecordBatch::try_new(self.schema.clone(), vec![Arc::new(array)])?; - // Update current position for next batch - self.current = batch_end + 1; - Ok(Some(batch)) } } @@ -104,39 +121,31 @@ impl TableProvider for GenerateSeriesTable { _limit: Option, ) -> Result> { let batch_size = state.config_options().execution.batch_size; - match (self.start, self.end) { - (Some(start), Some(end)) => { - if start > end { - return plan_err!( - "End value must be greater than or equal to start value" - ); - } - - Ok(Arc::new(LazyMemoryExec::try_new( - self.schema.clone(), - vec![Arc::new(RwLock::new(GenerateSeriesState { - schema: self.schema.clone(), - start, - end, - current: start, - batch_size, - }))], - )?)) - } - _ => { - // Either start or end is None, return a generator that outputs 0 rows - Ok(Arc::new(LazyMemoryExec::try_new( - self.schema.clone(), - vec![Arc::new(RwLock::new(GenerateSeriesState { - schema: self.schema.clone(), - start: 0, - end: 0, - current: 1, - batch_size, - }))], - )?)) - } - } + + let state = match self.args { + // if args have null, then return 0 row + GenSeriesArgs::ContainsNull => GenerateSeriesState { + schema: self.schema.clone(), + start: 0, + end: 0, + step: 1, + current: 1, + batch_size, + }, + GenSeriesArgs::AllNotNullArgs { start, end, step } => GenerateSeriesState { + schema: self.schema.clone(), + start, + end, + step, + current: start, + batch_size, + }, + }; + + Ok(Arc::new(LazyMemoryExec::try_new( + self.schema.clone(), + vec![Arc::new(RwLock::new(state))], + )?)) } } @@ -144,37 +153,58 @@ impl TableProvider for GenerateSeriesTable { pub struct GenerateSeriesFunc {} impl TableFunctionImpl for GenerateSeriesFunc { - // Check input `exprs` type and number. Input validity check (e.g. start <= end) - // will be performed in `TableProvider::scan` fn call(&self, exprs: &[Expr]) -> Result> { - // TODO: support 1 or 3 arguments following DuckDB: - // - if exprs.len() == 3 || exprs.len() == 1 { - return not_impl_err!("generate_series does not support 1 or 3 arguments"); + if exprs.is_empty() || exprs.len() > 3 { + return plan_err!("generate_series function requires 1 to 3 arguments"); } - if exprs.len() != 2 { - return plan_err!("generate_series expects 2 arguments"); + let mut normalize_args = Vec::new(); + for expr in exprs { + match expr { + Expr::Literal(ScalarValue::Null) => {} + Expr::Literal(ScalarValue::Int64(Some(n))) => normalize_args.push(*n), + _ => return plan_err!("First argument must be an integer literal"), + }; } - let start = match &exprs[0] { - Expr::Literal(ScalarValue::Null) => None, - Expr::Literal(ScalarValue::Int64(Some(n))) => Some(*n), - _ => return plan_err!("First argument must be an integer literal"), - }; - - let end = match &exprs[1] { - Expr::Literal(ScalarValue::Null) => None, - Expr::Literal(ScalarValue::Int64(Some(n))) => Some(*n), - _ => return plan_err!("Second argument must be an integer literal"), - }; - let schema = Arc::new(Schema::new(vec![Field::new( "value", DataType::Int64, false, )])); - Ok(Arc::new(GenerateSeriesTable { schema, start, end })) + if normalize_args.len() != exprs.len() { + // contain null + return Ok(Arc::new(GenerateSeriesTable { + schema, + args: GenSeriesArgs::ContainsNull, + })); + } + + let (start, end, step) = match &normalize_args[..] { + [end] => (0, *end, 1), + [start, end] => (*start, *end, 1), + [start, end, step] => (*start, *end, *step), + _ => { + return plan_err!("generate_series function requires 1 to 3 arguments"); + } + }; + + if start > end && step > 0 { + return plan_err!("start is bigger than end, but increment is positive: cannot generate infinite series"); + } + + if start < end && step < 0 { + return plan_err!("start is smaller than end, but increment is negative: cannot generate infinite series"); + } + + if step == 0 { + return plan_err!("step cannot be zero"); + } + + Ok(Arc::new(GenerateSeriesTable { + schema, + args: GenSeriesArgs::AllNotNullArgs { start, end, step }, + })) } } diff --git a/datafusion/sqllogictest/test_files/table_functions.slt b/datafusion/sqllogictest/test_files/table_functions.slt index 79294993dded..2769da03b8bb 100644 --- a/datafusion/sqllogictest/test_files/table_functions.slt +++ b/datafusion/sqllogictest/test_files/table_functions.slt @@ -16,6 +16,18 @@ # under the License. # Test generate_series table function +query I +SELECT * FROM generate_series(6) +---- +0 +1 +2 +3 +4 +5 +6 + + query I rowsort SELECT * FROM generate_series(1, 5) @@ -39,11 +51,35 @@ SELECT * FROM generate_series(3, 6) 5 6 +# #generated_data > batch_size +query I +SELECT count(v1) FROM generate_series(-66666,66666) t1(v1) +---- +133333 + + + + query I rowsort SELECT SUM(v1) FROM generate_series(1, 5) t1(v1) ---- 15 +query I +SELECT * FROM generate_series(6, -1, -2) +---- +6 +4 +2 +0 + +query I +SELECT * FROM generate_series(6, 66, 666) +---- +6 + + + # Test generate_series with WHERE clause query I rowsort SELECT * FROM generate_series(1, 10) t1(v1) WHERE v1 % 2 = 0 @@ -93,6 +129,10 @@ ON a.v1 = b.v1 - 1 2 3 3 4 +# +# Test generate_series with null arguments +# + query I SELECT * FROM generate_series(NULL, 5) ---- @@ -105,6 +145,11 @@ query I SELECT * FROM generate_series(NULL, NULL) ---- +query I +SELECT * FROM generate_series(1, 5, NULL) +---- + + query TT EXPLAIN SELECT * FROM generate_series(1, 5) ---- @@ -115,20 +160,22 @@ physical_plan LazyMemoryExec: partitions=1, batch_generators=[generate_series: s # Test generate_series with invalid arguments # -query error DataFusion error: Error during planning: End value must be greater than or equal to start value +query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series SELECT * FROM generate_series(5, 1) -statement error DataFusion error: This feature is not implemented: generate_series does not support 1 or 3 arguments -SELECT * FROM generate_series(1, 5, NULL) +query error DataFusion error: Error during planning: start is smaller than end, but increment is negative: cannot generate infinite series +SELECT * FROM generate_series(-6, 6, -1) + +query error DataFusion error: Error during planning: step cannot be zero +SELECT * FROM generate_series(-6, 6, 0) + +query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series +SELECT * FROM generate_series(6, -6, 1) -statement error DataFusion error: This feature is not implemented: generate_series does not support 1 or 3 arguments -SELECT * FROM generate_series(1) -statement error DataFusion error: Error during planning: generate_series expects 2 arguments +statement error DataFusion error: Error during planning: generate_series function requires 1 to 3 arguments SELECT * FROM generate_series(1, 2, 3, 4) -statement error DataFusion error: Error during planning: Second argument must be an integer literal -SELECT * FROM generate_series(1, '2') statement error DataFusion error: Error during planning: First argument must be an integer literal SELECT * FROM generate_series('foo', 'bar')