Skip to content

Commit b177b65

Browse files
ueshin authored and HyukjinKwon committed
[SPARK-53431][PYTHON] Fix Python UDTF with named table arguments in DataFrame API
### What changes were proposed in this pull request? Fixes Python UDTF with named table arguments in DataFrame API. ### Why are the changes needed? Named table arguments fail with the following error: ```py >>> from pyspark.sql.functions import * >>> >>> udtf(returnType="x string") ... class TestUDTF: ... def eval(self, x): ... yield str(x), ... >>> TestUDTF(x=spark.range(10).asTable()).show() Traceback (most recent call last): ... py4j.Py4JException: Method namedArgumentExpression([class java.lang.String, class org.apache.spark.sql.TableArg]) does not exist ... ``` Also, Spark Connect doesn't recognize table arguments in `analyze`. ### Does this PR introduce _any_ user-facing change? Yes, named table arguments will be available in DataFrame API. ### How was this patch tested? Added the related tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #52171 from ueshin/issues/SPARK-53431/named_table_arguments. Authored-by: Takuya Ueshin <[email protected]> Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 5bf4a29 commit b177b65

File tree

3 files changed

+97
-8
lines changed

3 files changed

+97
-8
lines changed

python/pyspark/sql/tests/test_udtf.py

Lines changed: 90 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,26 @@ def eval(self, a: int, b: int) -> Iterator:
170170

171171
self.spark.udtf.register("testUDTF", TestUDTF)
172172

173-
assertDataFrameEqual(
174-
self.spark.sql("values (0, 1), (1, 2) t(a, b)").lateralJoin(
175-
TestUDTF(col("a").outer(), col("b").outer())
176-
),
177-
self.spark.sql("SELECT * FROM values (0, 1), (1, 2) t(a, b), LATERAL testUDTF(a, b)"),
178-
)
173+
for i, df in enumerate(
174+
[
175+
self.spark.sql("values (0, 1), (1, 2) t(a, b)").lateralJoin(
176+
TestUDTF(col("a").outer(), col("b").outer())
177+
),
178+
self.spark.sql("values (0, 1), (1, 2) t(a, b)").lateralJoin(
179+
TestUDTF(a=col("a").outer(), b=col("b").outer())
180+
),
181+
self.spark.sql("values (0, 1), (1, 2) t(a, b)").lateralJoin(
182+
TestUDTF(b=col("b").outer(), a=col("a").outer())
183+
),
184+
]
185+
):
186+
with self.subTest(query_no=i):
187+
assertDataFrameEqual(
188+
df,
189+
self.spark.sql(
190+
"SELECT * FROM values (0, 1), (1, 2) t(a, b), LATERAL testUDTF(a, b)"
191+
),
192+
)
179193

180194
@udtf(returnType="a: int")
181195
class TestUDTF:
@@ -2118,6 +2132,25 @@ def eval(self, a, b):
21182132
with self.subTest(query_no=i):
21192133
assertDataFrameEqual(df, [Row(a=10)])
21202134

2135+
def test_udtf_with_named_table_arguments(self):
2136+
@udtf(returnType="a: int")
2137+
class TestUDTF:
2138+
def eval(self, a, b):
2139+
yield a.id,
2140+
2141+
self.spark.udtf.register("test_udtf", TestUDTF)
2142+
2143+
for i, df in enumerate(
2144+
[
2145+
self.spark.sql("SELECT * FROM test_udtf(a => TABLE(FROM range(3)), b => 'x')"),
2146+
self.spark.sql("SELECT * FROM test_udtf(b => 'x', a => TABLE(FROM range(3)))"),
2147+
TestUDTF(a=self.spark.range(3).asTable(), b=lit("x")),
2148+
TestUDTF(b=lit("x"), a=self.spark.range(3).asTable()),
2149+
]
2150+
):
2151+
with self.subTest(query_no=i):
2152+
assertDataFrameEqual(df, [Row(a=i) for i in range(3)])
2153+
21212154
def test_udtf_with_named_arguments_negative(self):
21222155
@udtf(returnType="a: int")
21232156
class TestUDTF:
@@ -2170,6 +2203,25 @@ def eval(self, **kwargs):
21702203
with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
21712204
self.spark.sql("SELECT * FROM test_udtf(a => 10, 'x')").show()
21722205

2206+
def test_udtf_with_table_argument_and_kwargs(self):
2207+
@udtf(returnType="a: int, b: string")
2208+
class TestUDTF:
2209+
def eval(self, **kwargs):
2210+
yield kwargs["a"].id, kwargs["b"]
2211+
2212+
self.spark.udtf.register("test_udtf", TestUDTF)
2213+
2214+
for i, df in enumerate(
2215+
[
2216+
self.spark.sql("SELECT * FROM test_udtf(a => TABLE(FROM range(3)), b => 'x')"),
2217+
self.spark.sql("SELECT * FROM test_udtf(b => 'x', a => TABLE(FROM range(3)))"),
2218+
TestUDTF(a=self.spark.range(3).asTable(), b=lit("x")),
2219+
TestUDTF(b=lit("x"), a=self.spark.range(3).asTable()),
2220+
]
2221+
):
2222+
with self.subTest(query_no=i):
2223+
assertDataFrameEqual(df, [Row(a=i, b="x") for i in range(3)])
2224+
21732225
def test_udtf_with_analyze_kwargs(self):
21742226
@udtf
21752227
class TestUDTF:
@@ -2204,6 +2256,38 @@ def eval(self, **kwargs):
22042256
with self.subTest(query_no=i):
22052257
assertDataFrameEqual(df, [Row(a=10, b="x")])
22062258

2259+
def test_udtf_with_table_argument_and_analyze_kwargs(self):
2260+
@udtf
2261+
class TestUDTF:
2262+
@staticmethod
2263+
def analyze(**kwargs: AnalyzeArgument) -> AnalyzeResult:
2264+
assert isinstance(kwargs["a"].dataType, StructType)
2265+
assert kwargs["a"].isTable is True
2266+
assert isinstance(kwargs["b"].dataType, StringType)
2267+
assert kwargs["b"].value == "x"
2268+
assert not kwargs["b"].isTable
2269+
return AnalyzeResult(
2270+
StructType(
2271+
[StructField(key, arg.dataType) for key, arg in sorted(kwargs.items())]
2272+
)
2273+
)
2274+
2275+
def eval(self, **kwargs):
2276+
yield tuple(value for _, value in sorted(kwargs.items()))
2277+
2278+
self.spark.udtf.register("test_udtf", TestUDTF)
2279+
2280+
for i, df in enumerate(
2281+
[
2282+
self.spark.sql("SELECT * FROM test_udtf(a => TABLE(FROM range(3)), b => 'x')"),
2283+
self.spark.sql("SELECT * FROM test_udtf(b => 'x', a => TABLE(FROM range(3)))"),
2284+
TestUDTF(a=self.spark.range(3).asTable(), b=lit("x")),
2285+
TestUDTF(b=lit("x"), a=self.spark.range(3).asTable()),
2286+
]
2287+
):
2288+
with self.subTest(query_no=i):
2289+
assertDataFrameEqual(df, [Row(a=Row(id=i), b="x") for i in range(3)])
2290+
22072291
def test_udtf_with_named_arguments_lateral_join(self):
22082292
@udtf
22092293
class TestUDTF:

sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.api.python.DechunkedInputStream
2626
import org.apache.spark.internal.Logging
2727
import org.apache.spark.internal.LogKeys.CLASS_LOADER
2828
import org.apache.spark.security.SocketAuthServer
29-
import org.apache.spark.sql.{internal, Column, DataFrame, Row, SparkSession}
29+
import org.apache.spark.sql.{internal, Column, DataFrame, Row, SparkSession, TableArg}
3030
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
3131
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry}
3232
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -182,6 +182,9 @@ private[sql] object PythonSQLUtils extends Logging {
182182
def namedArgumentExpression(name: String, e: Column): Column =
183183
Column(NamedArgumentExpression(name, expression(e)))
184184

185+
def namedArgumentExpression(name: String, e: TableArg): Column =
186+
Column(NamedArgumentExpression(name, e.expression))
187+
185188
@scala.annotation.varargs
186189
def fn(name: String, arguments: Column*): Column = Column.fn(name, arguments: _*)
187190

sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import net.razorvine.pickle.Pickler
2525

2626
import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonWorkerUtils, SpecialLengths}
2727
import org.apache.spark.sql.{Column, TableArg}
28-
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Expression, FunctionTableSubqueryArgumentExpression, NamedArgumentExpression, NullsFirst, NullsLast, PythonUDAF, PythonUDF, PythonUDTF, PythonUDTFAnalyzeResult, PythonUDTFSelectedExpression, SortOrder, UnresolvedPolymorphicPythonUDTF}
28+
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Expression, FunctionTableSubqueryArgumentExpression, NamedArgumentExpression, NullsFirst, NullsLast, PythonUDAF, PythonUDF, PythonUDTF, PythonUDTFAnalyzeResult, PythonUDTFSelectedExpression, SortOrder, UnresolvedPolymorphicPythonUDTF, UnresolvedTableArgPlanId}
2929
import org.apache.spark.sql.catalyst.parser.ParserInterface
3030
import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, NamedParametersSupport, OneRowRelation}
3131
import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession}
@@ -127,7 +127,9 @@ case class UserDefinedPythonTableFunction(
127127
// `UnresolvedAttribute` to construct lateral join.
128128
val tableArgs = exprs.map {
129129
case _: FunctionTableSubqueryArgumentExpression => true
130+
case _: UnresolvedTableArgPlanId => true
130131
case NamedArgumentExpression(_, _: FunctionTableSubqueryArgumentExpression) => true
132+
case NamedArgumentExpression(_, _: UnresolvedTableArgPlanId) => true
131133
case _ => false
132134
}
133135

0 commit comments

Comments (0)