第5章:充分利用 UDF 和 UDTF#
在大型数据处理中,通常需要进行定制以扩展 Spark 的原生能力。Python 用户定义函数 (UDF) 和 用户定义表函数 (UDTF) 提供了一种使用 Python 执行复杂转换和计算的方法,并将它们无缝集成到 Spark 的分布式环境中。
在本节中,我们将探讨如何在 Python 中编写和使用 UDF 和 UDTF,利用 PySpark 执行超出 Spark 内置功能的复杂数据转换。
Python UDF#
Python UDF 的类别#
PySpark 支持两种主要的 UDF 类别:标量 Python UDF 和 Pandas UDF。
标量 Python UDF 是用户定义的标量函数,它们接受或返回通过 pickle 或 Arrow 序列化/反序列化的 Python 对象,并且一次操作一行。
Pandas UDF(又称矢量化 UDF)是接受/返回由 Apache Arrow 序列化/反序列化的 pandas Series 或 DataFrame,并逐块操作的 UDF。Pandas UDF 根据用法分为几种变体,具有特定的输入和输出类型:Series 到 Series、Series 到 Scalar,以及 Iterator 到 Iterator。
基于 Pandas UDF 的实现,还有 Pandas 函数 API:Map(即 mapInPandas
)和 (Co)Grouped Map(即 applyInPandas
),以及一个 Arrow 函数 API - mapInArrow
。
创建标量 Python UDF#
在下面的代码中,我们创建了一个简单的标量 Python UDF。
[6]:
from pyspark.sql.functions import udf
@udf(returnType='int')
def slen(s: str):
return len(s)
Arrow 优化#
标量 Python UDF 依赖 cloudpickle 进行序列化和反序列化,并且会遇到性能瓶颈,尤其是在处理大量数据输入和输出时。我们引入 Arrow 优化的 Python UDF 以显著提高性能。
这种优化的核心在于 Apache Arrow,它是一种标准化的跨语言列式内存数据表示。通过利用 Arrow,这些 UDF 绕过了传统、较慢的数据(反)序列化方法,从而实现了 JVM 和 Python 进程之间快速的数据交换。凭借 Apache Arrow 丰富的类型系统,这些优化后的 UDF 提供了一种更一致、更标准化的方式来处理类型强制转换。
我们可以通过使用 functions.udf
的布尔参数 useArrow
来控制是否为单个 UDF 启用 Arrow 优化。示例如下:
from pyspark.sql.functions import udf
@udf(returnType='int', useArrow=True) # An Arrow Python UDF
def arrow_slen(s: str):
...
此外,我们还可以通过 Spark 配置 spark.sql.execution.pythonUDF.arrow.enabled
为整个 SparkSession 的所有 UDF 启用 Arrow 优化,如下所示:
spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", True)
@udf(returnType='int') # An Arrow Python UDF
def arrow_slen(s: str):
...
使用标量 Python UDF#
在 Python 中,我们可以直接对列调用 UDF,就像调用 Spark 内置函数一样,如下所示。
[7]:
data = [("Alice",), ("Bob",), ("Charlie",)]
df = spark.createDataFrame(data, ["name"])
df.withColumn("name_length", slen(df["name"])).show()
+-------+-----------+
| name|name_length|
+-------+-----------+
| Alice| 5|
| Bob| 3|
|Charlie| 7|
+-------+-----------+
创建 Pandas UDF#
在下面的代码中,我们创建了一个 Pandas UDF,它接受一个 pandas.Series
并输出一个 pandas.Series
。
[8]:
import pandas as pd
from pyspark.sql.functions import pandas_udf
@pandas_udf("string")
def to_upper(s: pd.Series) -> pd.Series:
return s.str.upper()
df = spark.createDataFrame([("John Doe",)], ("name",))
df.select(to_upper("name")).show()
+--------------+
|to_upper(name)|
+--------------+
| JOHN DOE|
+--------------+
使用 Pandas UDF#
与标量 Python UDF 类似,我们也可以直接对列调用 Pandas UDF。
[9]:
data = [("Alice",), ("Bob",), ("Charlie",)]
df = spark.createDataFrame(data, ["name"])
df.withColumn("name_length", to_upper(df["name"])).show()
+-------+-----------+
| name|name_length|
+-------+-----------+
| Alice| ALICE|
| Bob| BOB|
|Charlie| CHARLIE|
+-------+-----------+
更多示例#
示例1:使用 Python UDF 处理包含字符串和列表列的 DataFrame#
[10]:
from pyspark.sql.types import ArrayType, IntegerType, StringType
from pyspark.sql.functions import udf
data = [
("Hello World", [1, 2, 3]),
("PySpark is Fun", [4, 5, 6]),
("PySpark Rocks", [7, 8, 9])
]
df = spark.createDataFrame(data, ["text_column", "list_column"])
@udf(returnType="string")
def process_row(text: str, numbers):
vowels_count = sum(1 for char in text if char in "aeiouAEIOU")
doubled = [x * 2 for x in numbers]
return f"Vowels: {vowels_count}, Doubled: {doubled}"
df.withColumn("process_row", process_row(df["text_column"], df["list_column"])).show(truncate=False)
+--------------+-----------+--------------------------------+
|text_column |list_column|process_row |
+--------------+-----------+--------------------------------+
|Hello World |[1, 2, 3] |Vowels: 3, Doubled: [2, 4, 6] |
|PySpark is Fun|[4, 5, 6] |Vowels: 3, Doubled: [8, 10, 12] |
|PySpark Rocks |[7, 8, 9] |Vowels: 2, Doubled: [14, 16, 18]|
+--------------+-----------+--------------------------------+
示例2:用于统计计算和复杂转换的 Pandas UDF#
[11]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StructType, StructField, DoubleType, StringType
import pandas as pd
data = [
(10.0, "Spark"),
(20.0, "Big Data"),
(30.0, "AI"),
(40.0, "Machine Learning"),
(50.0, "Deep Learning")
]
df = spark.createDataFrame(data, ["numeric_column", "text_column"])
# Schema for the result
schema = StructType([
StructField("mean_value", DoubleType(), True),
StructField("sum_value", DoubleType(), True),
StructField("processed_text", StringType(), True)
])
@pandas_udf(schema)
def compute_stats_and_transform_string(numeric_col: pd.Series, text_col: pd.Series) -> pd.DataFrame:
mean_value = numeric_col.mean()
sum_value = numeric_col.sum()
# Reverse the string if its length is greater than 5, otherwise capitalize it
processed_text = text_col.apply(lambda x: x[::-1] if len(x) > 5 else x.upper())
result_df = pd.DataFrame({
"mean_value": [mean_value] * len(text_col),
"sum_value": [sum_value] * len(text_col),
"processed_text": processed_text
})
return result_df
df.withColumn("result", compute_stats_and_transform_string(df["numeric_column"], df["text_column"])).show(truncate=False)
+--------------+----------------+------------------------------+
|numeric_column|text_column |result |
+--------------+----------------+------------------------------+
|10.0 |Spark |{10.0, 10.0, SPARK} |
|20.0 |Big Data |{20.0, 20.0, ataD giB} |
|30.0 |AI |{30.0, 30.0, AI} |
|40.0 |Machine Learning|{40.0, 40.0, gninraeL enihcaM}|
|50.0 |Deep Learning |{50.0, 50.0, gninraeL peeD} |
+--------------+----------------+------------------------------+
Python UDTF#
Python 用户定义表函数 (UDTF) 是一种新型函数,它返回一个表作为输出,而不是单个标量结果值。一旦注册,它们就可以出现在 SQL 查询的 FROM 子句中。
何时使用 Python UDTF#
简而言之,如果您需要一个生成多行和多列的函数,并且希望利用丰富的 Python 生态系统,那么 Python UDTF 是您的不二选择。
Python UDTF 与 Python UDF 比较:Spark 中的 Python UDF 旨在每个接受零个或多个标量值作为输入,并返回单个值作为输出,而 UDTF 则提供了更大的灵活性。它们可以返回多行和多列,扩展了 UDF 的功能。以下是 UDTF 特别有用的几种场景:
展开数组或结构体等嵌套数据类型,将其转换为多行
处理需要拆分为多个部分(每个部分表示为单独的行或多列)的字符串数据
根据输入范围生成行,例如创建数字序列、时间戳或不同日期的记录
Python UDTF 与 SQL UDTF 比较:SQL UDTF 效率高且用途广泛,但 Python 提供了更丰富的库和工具集。与 SQL 相比,Python 提供了实现高级转换或计算(例如统计函数或机器学习推理)的工具。
创建 Python UDTF#
在下面的代码中,我们创建了一个简单的 UDTF,它接受两个整数作为输入,并生成两列作为输出:原始数字及其平方。
请注意 yield
语句的使用;Python UDTF 要求返回类型为元组或 Row 对象,以便正确处理结果。
另请注意,返回类型必须是 StructType
,带有块格式,或者是代表 Spark 中带有块格式的 StructType
的 DDL 字符串。
[12]:
from pyspark.sql.functions import udtf
@udtf(returnType="num: int, squared: int")
class SquareNumbers:
def eval(self, start: int, end: int):
for num in range(start, end + 1):
yield (num, num * num)
Arrow 优化#
Apache Arrow 是一种内存列式数据格式,允许在 Java 和 Python 进程之间高效传输数据。当 UDTF 输出多行时,它可以显著提升性能。Arrow 优化可以通过使用 useArrow=True
启用,例如:
from pyspark.sql.functions import udtf
@udtf(returnType="num: int, squared: int", useArrow=True)
class SquareNumbers:
...
使用 Python UDTF#
在 Python 中,我们可以直接使用类名调用 UDTF,如下所示。
[13]:
from pyspark.sql.functions import lit
SquareNumbers(lit(1), lit(3)).show()
+---+-------+
|num|squared|
+---+-------+
| 1| 1|
| 2| 4|
| 3| 9|
+---+-------+
在 SQL 中,我们可以注册 Python UDTF,然后将其作为表值函数在 SQL 查询的 FROM 子句中使用。
spark.sql("SELECT * FROM square_numbers(1, 3)").show()
更多示例#
示例1:为给定范围生成数字、它们的平方、立方和阶乘#
[14]:
from pyspark.sql.functions import lit, udtf
import math
@udtf(returnType="num: int, square: int, cube: int, factorial: int")
class GenerateComplexNumbers:
def eval(self, start: int, end: int):
for num in range(start, end + 1):
yield (num, num ** 2, num ** 3, math.factorial(num))
GenerateComplexNumbers(lit(1), lit(5)).show()
+---+------+----+---------+
|num|square|cube|factorial|
+---+------+----+---------+
| 1| 1| 1| 1|
| 2| 4| 8| 2|
| 3| 9| 27| 6|
| 4| 16| 64| 24|
| 5| 25| 125| 120|
+---+------+----+---------+
示例2:将句子拆分为单词并执行多项操作#
[15]:
from pyspark.sql.functions import lit, udtf
@udtf(returnType="word: string, length: int, is_palindrome: boolean")
class ProcessWords:
def eval(self, sentence: str):
words = sentence.split() # Split sentence into words
for word in words:
is_palindrome = word == word[::-1] # Check if the word is a palindrome
yield (word, len(word), is_palindrome)
ProcessWords(lit("hello world")).show()
+-----+------+-------------+
| word|length|is_palindrome|
+-----+------+-------------+
|hello| 5| false|
|world| 5| false|
+-----+------+-------------+
示例3:将 JSON 字符串解析为带有数据类型的键值对#
[16]:
import json
from pyspark.sql.functions import lit, udtf
@udtf(returnType="key: string, value: string, value_type: string")
class ParseJSON:
def eval(self, json_str: str):
try:
json_data = json.loads(json_str)
for key, value in json_data.items():
value_type = type(value).__name__
yield (key, str(value), value_type)
except json.JSONDecodeError:
yield ("Invalid JSON", "", "")
ParseJSON(lit('{"name": "Alice", "age": 25, "is_student": false}')).show()
+----------+-----+----------+
| key|value|value_type|
+----------+-----+----------+
| name|Alice| str|
| age| 25| int|
|is_student|False| bool|
+----------+-----+----------+