Python User-defined Table Functions (UDTFs)#
Spark 3.5 introduces the Python user-defined table function (UDTF), a new kind of user-defined function. Unlike scalar functions, which return a single result value per call, a UDTF is invoked in the FROM
clause of a query and returns an entire table as output. Each UDTF call can accept zero or more arguments. These arguments can be scalar expressions or table arguments that represent entire input tables.
Implementing a Python UDTF#
To implement a Python UDTF, you first need to define a class that implements the following methods:
class PythonUDTF:
def __init__(self) -> None:
"""
Initializes the user-defined table function (UDTF). This is optional.
This method serves as the default constructor and is called once when the
UDTF is instantiated on the executor side.
Any class fields assigned in this method will be available for subsequent
calls to the `eval` and `terminate` methods. This class instance will remain
alive until all rows in the current partition have been consumed by the `eval`
method.
Notes
-----
- You cannot create or reference the Spark session within the UDTF. Any
attempt to do so will result in a serialization error.
- If the below `analyze` method is implemented, it is also possible to define this
method as: `__init__(self, analyze_result: AnalyzeResult)`. In this case, the result
of the `analyze` method is passed into all future instantiations of this UDTF class.
In this way, the UDTF may inspect the schema and metadata of the output table as
needed during execution of other methods in this class. Note that it is possible to
create a subclass of the `AnalyzeResult` class if desired for purposes of passing
custom information generated just once during UDTF analysis to other method calls;
this can be especially useful if this initialization is expensive.
"""
...
@staticmethod
def analyze(*args: AnalyzeArgument) -> AnalyzeResult:
"""
Static method to compute the output schema of a particular call to this function in
response to the arguments provided.
This method is optional and only needed if the registration of the UDTF did not provide
a static output schema to be used for all calls to the function. In this context,
`output schema` refers to the ordered list of the names and types of the columns in the
function's result table.
This method accepts zero or more parameters mapping 1:1 with the arguments provided to
the particular UDTF call under consideration. Each parameter is an instance of the
`AnalyzeArgument` class.
`AnalyzeArgument` fields
------------------------
dataType: DataType
Indicates the type of the provided input argument to this particular UDTF call.
For input table arguments, this is a StructType representing the table's columns.
value: Optional[Any]
The value of the provided input argument to this particular UDTF call. This is
`None` for table arguments, or for scalar arguments that are not constant-foldable.
isTable: bool
This is true if the provided input argument to this particular UDTF call is a
table argument.
isConstantExpression: bool
This is true if the provided input argument to this particular UDTF call is either a
literal or other constant-foldable scalar expression.
This method returns an instance of the `AnalyzeResult` class which includes the result
table's schema as a StructType. If the UDTF accepts an input table argument, then the
`AnalyzeResult` can also include a requested way to partition and order the rows of
the input table across several UDTF calls. See below for more information about UDTF
table arguments and how to call them in SQL queries, including the WITH SINGLE
PARTITION clause (corresponding to the `withSinglePartition` field here), PARTITION BY
clause (corresponding to the `partitionBy` field here), ORDER BY clause (corresponding
to the `orderBy` field here), and passing table subqueries as arguments (corresponding
to the `select` field here).
`AnalyzeResult` fields
----------------------
schema: StructType
The schema of the result table.
withSinglePartition: bool = False
If True, the query planner will arrange a repartitioning operation from the previous
execution stage such that all rows of the input table are consumed by the `eval`
method from exactly one instance of the UDTF class.
partitionBy: Sequence[PartitioningColumn] = field(default_factory=tuple)
If non-empty, the query planner will arrange a repartitioning such that all rows
with each unique combination of values of the partitioning expressions are consumed
by a separate unique instance of the UDTF class.
orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
If non-empty, this specifies the requested ordering of rows within each partition.
select: Sequence[SelectedColumn] = field(default_factory=tuple)
If non-empty, this is a sequence of expressions that the UDTF is specifying for
Catalyst to evaluate against the columns in the input TABLE argument. The UDTF then
receives one input attribute for each name in the list, in the order they are
listed.
Notes
-----
- It is possible for the `analyze` method to accept the exact arguments expected,
mapping 1:1 with the arguments provided to the UDTF call.
- The `analyze` method can instead choose to accept positional arguments if desired
(using `*args`) or keyword arguments (using `**kwargs`).
Examples
--------
This is an `analyze` implementation that returns one output column for each word in the
input string argument.
>>> @staticmethod
... def analyze(text: AnalyzeArgument) -> AnalyzeResult:
... schema = StructType()
... for index, word in enumerate(text.value.split(" ")):
... schema = schema.add(f"word_{index}", StringType())
... return AnalyzeResult(schema=schema)
Same as above, but using *args to accept the arguments.
>>> @staticmethod
... def analyze(*args) -> AnalyzeResult:
... assert len(args) == 1, "This function accepts one argument only"
... assert args[0].dataType == StringType(), "Only string arguments are supported"
... text = args[0].value
... schema = StructType()
... for index, word in enumerate(text.split(" ")):
... schema = schema.add(f"word_{index}", StringType())
... return AnalyzeResult(schema=schema)
Same as above, but using **kwargs to accept the arguments.
>>> @staticmethod
... def analyze(**kwargs) -> AnalyzeResult:
... assert len(kwargs) == 1, "This function accepts one argument only"
... assert "text" in kwargs, "An argument named 'text' is required"
... assert kwargs["text"].dataType == StringType(), "Only strings are supported"
... text = kwargs["text"].value
... schema = StructType()
... for index, word in enumerate(text.split(" ")):
... schema = schema.add(f"word_{index}", StringType())
... return AnalyzeResult(schema=schema)
This is an `analyze` implementation that returns a constant output schema, but adds
custom information in the result metadata to be consumed by future __init__ method
calls:
>>> @staticmethod
... def analyze(text: AnalyzeArgument) -> AnalyzeResult:
... @dataclass
... class AnalyzeResultWithOtherMetadata(AnalyzeResult):
... num_words: int
... num_articles: int
... words = text.value.split(" ")
... return AnalyzeResultWithOtherMetadata(
... schema=StructType()
... .add("word", StringType())
... .add("total", IntegerType()),
... num_words=len(words),
... num_articles=len([
... word for word in words
... if word in ("a", "an", "the")]))
This is an `analyze` implementation that returns a constant output schema, and also
requests to select a subset of columns from the input table and for the input table to
be partitioned across several UDTF calls based on the values of the `date` column.
A SQL query may call this UDTF passing a table argument like "SELECT * FROM udtf(TABLE(t))".
Then this `analyze` method specifies additional constraints on the input table:
(1) The input table must be partitioned across several UDTF calls based on the month of
each value in the `date` column.
(2) The rows within each partition will arrive ordered by the `date` column.
(3) The UDTF will only receive the `date` and `word` columns from the input table.
>>> @staticmethod
... def analyze(*args) -> AnalyzeResult:
... assert len(args) == 1, "This function accepts one argument only"
... assert args[0].isTable, "Only table arguments are supported"
... return AnalyzeResult(
... schema=StructType()
... .add("month", DateType())
... .add("longest_word", IntegerType()),
... partitionBy=[
... PartitioningColumn("extract(month from date)")],
... orderBy=[
... OrderingColumn("date")],
... select=[
... SelectedColumn("date"),
... SelectedColumn(
... name="length(word),
... alias="length_word")])
"""
...
def eval(self, *args: Any) -> Iterator[Any]:
"""
Evaluates the function using the given input arguments.
This method is required and must be implemented.
Argument Mapping:
- Each provided scalar expression maps to exactly one value in the
`*args` list.
- Each provided table argument maps to a pyspark.sql.Row object containing
the columns in the order they appear in the provided input table,
and with the names computed by the query analyzer.
This method is called on every input row, and can produce zero or more
output rows. Each element in the output tuple corresponds to one column
specified in the return type of the UDTF.
Parameters
----------
*args : Any
Arbitrary positional arguments representing the input to the UDTF.
Yields
------
tuple
A tuple, list, or pyspark.sql.Row object representing a single row in the UDTF
result table. Yield as many times as needed to produce multiple rows.
Notes
-----
- It is also possible for UDTFs to accept the exact arguments expected, along with
their types.
- UDTFs can instead accept keyword arguments during the function call if needed.
- The `eval` method can raise a `SkipRestOfInputTableException` to indicate that the
UDTF wants to skip consuming all remaining rows from the current partition of the
input table. This will cause the UDTF to proceed directly to the `terminate` method.
- The `eval` method can raise any other exception to indicate that the UDTF should be
aborted entirely. This will cause the UDTF to skip the `terminate` method and proceed
directly to the `cleanup` method, and then the exception will be propagated to the
query processor causing the invoking query to fail.
Examples
--------
This `eval` method returns one row and one column for each input.
>>> def eval(self, x: int):
... yield (x, )
This `eval` method returns two rows and two columns for each input.
>>> def eval(self, x: int, y: int):
... yield (x + y, x - y)
... yield (y + x, y - x)
Same as above, but using *args to accept the arguments:
>>> def eval(self, *args):
... assert len(args) == 2, "This function accepts two integer arguments only"
... x = args[0]
... y = args[1]
... yield (x + y, x - y)
... yield (y + x, y - x)
Same as above, but using **kwargs to accept the arguments:
>>> def eval(self, **kwargs):
... assert len(kwargs) == 2, "This function accepts two integer arguments only"
... x = kwargs["x"]
... y = kwargs["y"]
... yield (x + y, x - y)
... yield (y + x, y - x)
"""
...
def terminate(self) -> Iterator[Any]:
"""
Called when the UDTF has successfully processed all input rows.
This method is optional to implement and is useful for performing any
finalization operations after the UDTF has finished processing
all rows. It can also be used to yield additional rows if needed.
Table functions that consume all rows in the entire input partition
and then compute and return the entire output table can do so from
this method as well (please be mindful of memory usage when doing
this).
If any exceptions occur during input row processing, this method
won't be called.
Yields
------
tuple
A tuple representing a single row in the UDTF result table.
Yield this if you want to return additional rows during termination.
Examples
--------
>>> def terminate(self) -> Iterator[Any]:
... yield "done", None
"""
...
def cleanup(self) -> None:
"""
Invoked after the UDTF completes processing input rows.
This method is optional to implement and is useful for final cleanup
regardless of whether the UDTF processed all input rows successfully
or was aborted due to exceptions.
Examples
--------
>>> def cleanup(self) -> None:
... self.conn.close()
"""
...
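Putting these methods together, here is a minimal runnable sketch of a UDTF whose output schema is computed by the analyze method rather than declared statically. The class and column names are illustrative, and the analysis classes are assumed to be importable from pyspark.sql.udtf as in recent Spark releases.
from pyspark.sql.functions import udtf
from pyspark.sql.types import StringType, StructType
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult

@udtf
class WordColumns:
    @staticmethod
    def analyze(text: AnalyzeArgument) -> AnalyzeResult:
        # Produce one string column per word in the constant input argument.
        schema = StructType()
        for index in range(len(text.value.split(" "))):
            schema = schema.add(f"word_{index}", StringType())
        return AnalyzeResult(schema=schema)

    def eval(self, text: str):
        yield tuple(text.split(" "))

spark.udtf.register("word_columns", WordColumns)
spark.sql("SELECT * FROM word_columns('hello from spark')").show()
# Expected: one row with columns word_0, word_1, and word_2.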
Defining the Output Schema#
The return type of the UDTF defines the schema of the table it outputs.
You can specify it either in the @udtf decorator or as the result of the analyze method.
It must be either a StructType, for example:
StructType().add("c1", StringType())
or a DDL string representing a struct type, for example:
c1: string
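For example, here is a minimal sketch declaring the same one-column output schema in both ways (the class names are illustrative):
from pyspark.sql.functions import udtf
from pyspark.sql.types import StringType, StructType

# Both declarations below describe the same output table with a single string column `c1`.
@udtf(returnType=StructType().add("c1", StringType()))
class EchoStructType:
    def eval(self, v: str):
        yield (v,)

@udtf(returnType="c1: string")
class EchoDDLString:
    def eval(self, v: str):
        yield (v,)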
Emitting Output Rows#
The eval and terminate methods emit zero or more output rows conforming to this schema by yielding tuples, lists, or pyspark.sql.Row objects.
For example, here we return one row by providing a tuple of three elements:
def eval(self, x, y, z):
    yield (x, y, z)
It is also acceptable to omit the parentheses:
def eval(self, x, y, z):
    yield x, y, z
Remember to add a trailing comma if returning only one column!
def eval(self, x, y, z):
    yield x,
It is also possible to yield a pyspark.sql.Row object.
def eval(self, x, y, z):
    from pyspark.sql.types import Row
    yield Row(x, y, z)
Here is an example of yielding output rows from the terminate method using a Python list. Generally, it makes sense for this purpose to store state inside the class during earlier steps of UDTF evaluation.
def terminate(self):
    yield [self.x, self.y, self.z]
Registering and Using Python UDTFs in SQL#
Python UDTFs can be registered and used in SQL queries.
from pyspark.sql.functions import udtf
@udtf(returnType="word: string")
class WordSplitter:
    def eval(self, text: str):
        for word in text.split(" "):
            yield (word.strip(),)
# Register the UDTF for use in Spark SQL.
spark.udtf.register("split_words", WordSplitter)
# Example: Using the UDTF in SQL.
spark.sql("SELECT * FROM split_words('hello world')").show()
# +-----+
# | word|
# +-----+
# |hello|
# |world|
# +-----+
# Example: Using the UDTF with a lateral join in SQL.
# The lateral join allows us to reference the columns and aliases
# in the previous FROM clause items as inputs to the UDTF.
spark.sql(
"SELECT * FROM VALUES ('Hello World'), ('Apache Spark') t(text), "
"LATERAL split_words(text)"
).show()
# +------------+------+
# | text| word|
# +------------+------+
# | Hello World| Hello|
# | Hello World| World|
# |Apache Spark|Apache|
# |Apache Spark| Spark|
# +------------+------+
Arrow Optimization#
Apache Arrow is an in-memory columnar data format used in Spark to efficiently transfer data between Java and Python processes. Apache Arrow is disabled by default for Python UDTFs.
Arrow can improve performance when each input row generates a large result table from the UDTF.
To enable Arrow optimization, set the spark.sql.execution.pythonUDTF.arrow.enabled
configuration to true
. You can also enable it by specifying the useArrow
argument when declaring the UDTF.
from pyspark.sql.functions import udtf
@udtf(returnType="c1: int, c2: int", useArrow=True)
class PlusOne:
    def eval(self, x: int):
        yield x, x + 1
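Alternatively, here is a minimal sketch of enabling the optimization for the whole session through the configuration key mentioned above (assuming an active SparkSession named spark):
# Enable Arrow optimization for all Python UDTFs in this session.
spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true")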
For more details, see Apache Arrow in PySpark.
UDTF Examples with Scalar Arguments#
Below is a simple example of a UDTF class implementation:
# Define the UDTF class and implement the required `eval` method.
class SquareNumbers:
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num * num)
To use the UDTF, first create it with the udtf function:
from pyspark.sql.functions import lit, udtf
# Create a UDTF using the class definition and the `udtf` function.
square_num = udtf(SquareNumbers, returnType="num: int, squared: int")
# Invoke the UDTF in PySpark.
square_num(lit(1), lit(3)).show()
# +---+-------+
# |num|squared|
# +---+-------+
# | 1| 1|
# | 2| 4|
# | 3| 9|
# +---+-------+
Another way to create a UDTF is to apply the udtf() function as a decorator:
from pyspark.sql.functions import lit, udtf
# Define a UDTF using the `udtf` decorator directly on the class.
@udtf(returnType="num: int, squared: int")
class SquareNumbers:
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num * num)
# Invoke the UDTF in PySpark using the SquareNumbers class directly.
SquareNumbers(lit(1), lit(3)).show()
# +---+-------+
# |num|squared|
# +---+-------+
# | 1| 1|
# | 2| 4|
# | 3| 9|
# +---+-------+
Here is a Python UDTF that expands a date range into individual dates:
from datetime import datetime, timedelta
from pyspark.sql.functions import lit, udtf
@udtf(returnType="date: string")
class DateExpander:
    def eval(self, start_date: str, end_date: str):
        current = datetime.strptime(start_date, '%Y-%m-%d')
        end = datetime.strptime(end_date, '%Y-%m-%d')
        while current <= end:
            yield (current.strftime('%Y-%m-%d'),)
            current += timedelta(days=1)
DateExpander(lit("2023-02-25"), lit("2023-03-01")).show()
# +----------+
# | date|
# +----------+
# |2023-02-25|
# |2023-02-26|
# |2023-02-27|
# |2023-02-28|
# |2023-03-01|
# +----------+
Here is a Python UDTF with __init__ and terminate:
from pyspark.sql.functions import udtf
@udtf(returnType="cnt: int")
class CountUDTF:
    def __init__(self):
        # Initialize the counter to 0 when an instance of the class is created.
        self.count = 0
    def eval(self, x: int):
        # Increment the counter by 1 for each input value received.
        self.count += 1
    def terminate(self):
        # Yield the final count when the UDTF is done processing.
        yield self.count,
spark.udtf.register("count_udtf", CountUDTF)
spark.sql("SELECT * FROM range(0, 10, 1, 1), LATERAL count_udtf(id)").show()
# +---+---+
# | id|cnt|
# +---+---+
# | 9| 10|
# +---+---+
spark.sql("SELECT * FROM range(0, 10, 1, 2), LATERAL count_udtf(id)").show()
# +---+---+
# | id|cnt|
# +---+---+
# | 4| 5|
# | 9| 5|
# +---+---+
Accepting an Input Table Argument#
The UDTF examples above show functions that accept scalar input arguments, such as integers or strings.
However, any Python UDTF can also accept an input table as an argument, and this can work in conjunction with scalar input arguments in the same function definition. Only one such table argument is allowed per call.
Any SQL query can then provide an input table using the TABLE
keyword followed by parentheses surrounding an appropriate table identifier, such as TABLE(t)
. Alternatively, you can pass a table subquery, such as TABLE(SELECT a, b, c FROM t)
or TABLE(SELECT t1.a, t2.b FROM t1 INNER JOIN t2 USING (key))
.
The input table argument is then represented as a pyspark.sql.Row
argument to the eval
method, with one call to the eval
method for each row in the input table.
For example:
from pyspark.sql.functions import udtf
from pyspark.sql.types import Row
@udtf(returnType="id: int")
class FilterUDTF:
    def eval(self, row: Row):
        if row["id"] > 5:
            yield row["id"],
spark.udtf.register("filter_udtf", FilterUDTF)
spark.sql("SELECT * FROM filter_udtf(TABLE(SELECT * FROM range(10)))").show()
# +---+
# | id|
# +---+
# | 6|
# | 7|
# | 8|
# | 9|
# +---+
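As noted above, a table argument can be combined with scalar arguments in the same function definition. Here is a minimal sketch of that pattern; the function name and the threshold parameter are illustrative.
from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

@udtf(returnType="id: int")
class FilterAboveUDTF:
    def eval(self, row: Row, threshold: int):
        # Keep only rows whose `id` value exceeds the scalar `threshold` argument.
        if row["id"] > threshold:
            yield row["id"],

spark.udtf.register("filter_above_udtf", FilterAboveUDTF)
spark.sql("SELECT * FROM filter_above_udtf(TABLE(SELECT * FROM range(10)), 7)").show()
# Expected to show the rows with id values 8 and 9.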
When calling a UDTF with a table argument, any SQL query can request that the input table be partitioned across several UDTF calls based on the values of one or more columns of the input table. To do so, specify the PARTITION BY
clause in the function call after the TABLE
argument. This guarantees that all input rows with each unique combination of values of the partitioning columns will be consumed by exactly one instance of the UDTF class.
Note that in addition to simple column references, the PARTITION BY
clause also accepts arbitrary expressions based on columns of the input table. For example, you can specify the LENGTH
of a string, extract a month from a date, or concatenate two values.
It is also possible to specify WITH SINGLE PARTITION
instead of PARTITION BY
to request only one partition, in which all input rows must be consumed by exactly one instance of the UDTF class.
Within each partition, you can optionally specify a required ordering of the input rows as the UDTF's eval
method consumes them. To do so, provide an ORDER BY
clause after the PARTITION BY
or WITH SINGLE PARTITION
clause described above.
For example:
from pyspark.sql.functions import udtf
from pyspark.sql.types import Row
# Define and register a UDTF.
@udtf(returnType="a: string, b: int")
class FilterUDTF:
    def __init__(self):
        self.key = ""
        self.max = 0
    def eval(self, row: Row):
        self.key = row["a"]
        self.max = max(self.max, row["b"])
    def terminate(self):
        yield self.key, self.max
spark.udtf.register("filter_udtf", FilterUDTF)
# Create an input table with some example values.
spark.sql("DROP TABLE IF EXISTS values_table")
spark.sql("CREATE TABLE values_table (a STRING, b INT)")
spark.sql("INSERT INTO values_table VALUES ('abc', 2), ('abc', 4), ('def', 6), ('def', 8)")
spark.table("values_table").show()
# +-------+----+
# | a | b |
# +-------+----+
# | "abc" | 2 |
# | "abc" | 4 |
# | "def" | 6 |
# | "def" | 8 |
# +-------+----+
# Query the UDTF with the input table as an argument, and a directive to partition the input
# rows such that all rows with each unique value of the `a` column are processed by the same
# instance of the UDTF class. Within each partition, the rows are ordered by the `b` column.
spark.sql("""
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY a ORDER BY b) ORDER BY 1
""").show()
# +-------+----+
# | a | b |
# +-------+----+
# | "abc" | 4 |
# | "def" | 8 |
# +-------+----+
# Query the UDTF with the input table as an argument, and a directive to partition the input
# rows such that all rows with each unique result of evaluating the "LENGTH(a)" expression are
# processed by the same instance of the UDTF class. Within each partition, the rows are ordered
# by the `b` column.
spark.sql("""
SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY LENGTH(a) ORDER BY b) ORDER BY 1
""").show()
# +-------+---+
# | a | b |
# +-------+---+
# | "def" | 8 |
# +-------+---+
# Query the UDTF with the input table as an argument, and a directive to consider all the input
# rows in one single partition such that exactly one instance of the UDTF class consumes all of
# the input rows. Within each partition, the rows are ordered by the `b` column.
spark.sql("""
SELECT * FROM filter_udtf(TABLE(values_table) WITH SINGLE PARTITION ORDER BY b) ORDER BY 1
""").show()
# +-------+----+
# | a | b |
# +-------+----+
# | "def" | 8 |
# +-------+----+
# Clean up.
spark.sql("DROP TABLE values_table")
Note that when calling UDTFs in SQL queries, for each of these ways of partitioning the input table, there is a corresponding way for the UDTF's analyze
method to specify the same partitioning method automatically instead.
For example, instead of calling a UDTF as SELECT * FROM udtf(TABLE(t) PARTITION BY a)
, you can update the analyze
method to set the field partitionBy=[PartitioningColumn("a")]
and then simply call the function with SELECT * FROM udtf(TABLE(t))
.
By the same token, instead of specifying TABLE(t) WITH SINGLE PARTITION
in the SQL query, you can make analyze
set the field withSinglePartition=True
and then just pass TABLE(t)
.
Instead of passing TABLE(t) ORDER BY b
in the SQL query, you can make analyze
set orderBy=[OrderingColumn("b")]
and then just pass TABLE(t)
.
Instead of passing TABLE(SELECT a FROM t)
in the SQL query, you can make analyze
set select=[SelectedColumn("a")]
and then just pass TABLE(t)
.
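For instance, here is a minimal sketch of a UDTF whose analyze method requests the partitioning and ordering itself, so that callers can simply pass TABLE(t). It assumes a table t with a string column a and an integer column b; the class and column names are illustrative, and the analysis helper classes are assumed to be importable from pyspark.sql.udtf as in recent Spark releases.
from pyspark.sql.functions import udtf
from pyspark.sql.types import IntegerType, StringType, StructType
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult, OrderingColumn, PartitioningColumn

@udtf
class MaxPerKeyUDTF:
    @staticmethod
    def analyze(table_arg: AnalyzeArgument) -> AnalyzeResult:
        assert table_arg.isTable, "This function requires a TABLE argument"
        # Request PARTITION BY a ORDER BY b on the caller's behalf.
        return AnalyzeResult(
            schema=StructType().add("a", StringType()).add("max_b", IntegerType()),
            partitionBy=[PartitioningColumn("a")],
            orderBy=[OrderingColumn("b")])

    def __init__(self):
        self.key = None
        self.max = 0

    def eval(self, row):
        self.key = row["a"]
        self.max = max(self.max, row["b"])

    def terminate(self):
        yield self.key, self.max

spark.udtf.register("max_per_key_udtf", MaxPerKeyUDTF)
# Callers no longer need to write PARTITION BY or ORDER BY in the query.
spark.sql("SELECT * FROM max_per_key_udtf(TABLE(t)) ORDER BY 1").show()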