第5章:充分利用 UDF 和 UDTF#

在大型数据处理中,通常需要进行定制以扩展 Spark 的原生能力。Python 用户定义函数 (UDF)用户定义表函数 (UDTF) 提供了一种使用 Python 执行复杂转换和计算的方法,并将它们无缝集成到 Spark 的分布式环境中。

在本节中,我们将探讨如何在 Python 中编写和使用 UDF 和 UDTF,利用 PySpark 执行超出 Spark 内置功能的复杂数据转换。

Python UDF#

Python UDF 的类别#

PySpark 支持两种主要的 UDF 类别:标量 Python UDF 和 Pandas UDF。

  • 标量 Python UDF 是用户定义的标量函数,它们接受或返回通过 pickleArrow 序列化/反序列化的 Python 对象,并且一次操作一行。

  • Pandas UDF(又称矢量化 UDF)是接受/返回由 Apache Arrow 序列化/反序列化的 pandas Series 或 DataFrame,并逐块操作的 UDF。Pandas UDF 根据用法分为几种变体,具有特定的输入和输出类型:Series 到 Series、Series 到 Scalar,以及 Iterator 到 Iterator。

基于 Pandas UDF 的实现,还有 Pandas 函数 API:Map(即 mapInPandas)和 (Co)Grouped Map(即 applyInPandas),以及一个 Arrow 函数 API - mapInArrow

创建标量 Python UDF#

在下面的代码中,我们创建了一个简单的标量 Python UDF。

[6]:
from pyspark.sql.functions import udf

@udf(returnType='int')
def slen(s: str):
    return len(s)

Arrow 优化#

标量 Python UDF 依赖 cloudpickle 进行序列化和反序列化,并且会遇到性能瓶颈,尤其是在处理大量数据输入和输出时。我们引入 Arrow 优化的 Python UDF 以显著提高性能。

这种优化的核心在于 Apache Arrow,它是一种标准化的跨语言列式内存数据表示。通过利用 Arrow,这些 UDF 绕过了传统、较慢的数据(反)序列化方法,从而实现了 JVM 和 Python 进程之间快速的数据交换。凭借 Apache Arrow 丰富的类型系统,这些优化后的 UDF 提供了一种更一致、更标准化的方式来处理类型强制转换。

我们可以通过使用 functions.udf 的布尔参数 useArrow 来控制是否为单个 UDF 启用 Arrow 优化。示例如下:

from pyspark.sql.functions import udf

@udf(returnType='int', useArrow=True)  # An Arrow Python UDF
def arrow_slen(s: str):
    ...

此外,我们还可以通过 Spark 配置 spark.sql.execution.pythonUDF.arrow.enabled 为整个 SparkSession 的所有 UDF 启用 Arrow 优化,如下所示:

spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", True)

@udf(returnType='int')  # An Arrow Python UDF
def arrow_slen(s: str):
    ...

使用标量 Python UDF#

在 Python 中,我们可以直接对列调用 UDF,就像调用 Spark 内置函数一样,如下所示。

[7]:
data = [("Alice",), ("Bob",), ("Charlie",)]
df = spark.createDataFrame(data, ["name"])
df.withColumn("name_length", slen(df["name"])).show()
+-------+-----------+
|   name|name_length|
+-------+-----------+
|  Alice|          5|
|    Bob|          3|
|Charlie|          7|
+-------+-----------+

创建 Pandas UDF#

在下面的代码中,我们创建了一个 Pandas UDF,它接受一个 pandas.Series 并输出一个 pandas.Series

[8]:
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("string")
def to_upper(s: pd.Series) -> pd.Series:
    return s.str.upper()

df = spark.createDataFrame([("John Doe",)], ("name",))
df.select(to_upper("name")).show()

+--------------+
|to_upper(name)|
+--------------+
|      JOHN DOE|
+--------------+

使用 Pandas UDF#

与标量 Python UDF 类似,我们也可以直接对列调用 Pandas UDF。

[9]:
data = [("Alice",), ("Bob",), ("Charlie",)]
df = spark.createDataFrame(data, ["name"])
df.withColumn("name_length", to_upper(df["name"])).show()
+-------+-----------+
|   name|name_length|
+-------+-----------+
|  Alice|      ALICE|
|    Bob|        BOB|
|Charlie|    CHARLIE|
+-------+-----------+

更多示例#

示例1:使用 Python UDF 处理包含字符串和列表列的 DataFrame#

[10]:
from pyspark.sql.types import ArrayType, IntegerType, StringType
from pyspark.sql.functions import udf

data = [
    ("Hello World", [1, 2, 3]),
    ("PySpark is Fun", [4, 5, 6]),
    ("PySpark Rocks", [7, 8, 9])
]
df = spark.createDataFrame(data, ["text_column", "list_column"])

@udf(returnType="string")
def process_row(text: str, numbers):
    vowels_count = sum(1 for char in text if char in "aeiouAEIOU")
    doubled = [x * 2 for x in numbers]
    return f"Vowels: {vowels_count}, Doubled: {doubled}"

df.withColumn("process_row", process_row(df["text_column"], df["list_column"])).show(truncate=False)
+--------------+-----------+--------------------------------+
|text_column   |list_column|process_row                     |
+--------------+-----------+--------------------------------+
|Hello World   |[1, 2, 3]  |Vowels: 3, Doubled: [2, 4, 6]   |
|PySpark is Fun|[4, 5, 6]  |Vowels: 3, Doubled: [8, 10, 12] |
|PySpark Rocks |[7, 8, 9]  |Vowels: 2, Doubled: [14, 16, 18]|
+--------------+-----------+--------------------------------+

示例2:用于统计计算和复杂转换的 Pandas UDF#

[11]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StructType, StructField, DoubleType, StringType
import pandas as pd

data = [
    (10.0, "Spark"),
    (20.0, "Big Data"),
    (30.0, "AI"),
    (40.0, "Machine Learning"),
    (50.0, "Deep Learning")
]
df = spark.createDataFrame(data, ["numeric_column", "text_column"])

# Schema for the result
schema = StructType([
    StructField("mean_value", DoubleType(), True),
    StructField("sum_value", DoubleType(), True),
    StructField("processed_text", StringType(), True)
])

@pandas_udf(schema)
def compute_stats_and_transform_string(numeric_col: pd.Series, text_col: pd.Series) -> pd.DataFrame:
    mean_value = numeric_col.mean()
    sum_value = numeric_col.sum()

    # Reverse the string if its length is greater than 5, otherwise capitalize it
    processed_text = text_col.apply(lambda x: x[::-1] if len(x) > 5 else x.upper())

    result_df = pd.DataFrame({
        "mean_value": [mean_value] * len(text_col),
        "sum_value": [sum_value] * len(text_col),
        "processed_text": processed_text
    })

    return result_df

df.withColumn("result", compute_stats_and_transform_string(df["numeric_column"], df["text_column"])).show(truncate=False)
+--------------+----------------+------------------------------+
|numeric_column|text_column     |result                        |
+--------------+----------------+------------------------------+
|10.0          |Spark           |{10.0, 10.0, SPARK}           |
|20.0          |Big Data        |{20.0, 20.0, ataD giB}        |
|30.0          |AI              |{30.0, 30.0, AI}              |
|40.0          |Machine Learning|{40.0, 40.0, gninraeL enihcaM}|
|50.0          |Deep Learning   |{50.0, 50.0, gninraeL peeD}   |
+--------------+----------------+------------------------------+

Python UDTF#

Python 用户定义表函数 (UDTF) 是一种新型函数,它返回一个表作为输出,而不是单个标量结果值。一旦注册,它们就可以出现在 SQL 查询的 FROM 子句中。

何时使用 Python UDTF#

简而言之,如果您需要一个生成多行和多列的函数,并且希望利用丰富的 Python 生态系统,那么 Python UDTF 是您的不二选择。

  • Python UDTF 与 Python UDF 比较:Spark 中的 Python UDF 旨在每个接受零个或多个标量值作为输入,并返回单个值作为输出,而 UDTF 则提供了更大的灵活性。它们可以返回多行和多列,扩展了 UDF 的功能。以下是 UDTF 特别有用的几种场景:

    • 展开数组或结构体等嵌套数据类型,将其转换为多行

    • 处理需要拆分为多个部分(每个部分表示为单独的行或多列)的字符串数据

    • 根据输入范围生成行,例如创建数字序列、时间戳或不同日期的记录

  • Python UDTF 与 SQL UDTF 比较:SQL UDTF 效率高且用途广泛,但 Python 提供了更丰富的库和工具集。与 SQL 相比,Python 提供了实现高级转换或计算(例如统计函数或机器学习推理)的工具。

创建 Python UDTF#

在下面的代码中,我们创建了一个简单的 UDTF,它接受两个整数作为输入,并生成两列作为输出:原始数字及其平方。

请注意 yield 语句的使用;Python UDTF 要求返回类型为元组或 Row 对象,以便正确处理结果。

另请注意,返回类型必须是 StructType,带有块格式,或者是代表 Spark 中带有块格式的 StructType 的 DDL 字符串。

[12]:
from pyspark.sql.functions import udtf

@udtf(returnType="num: int, squared: int")
class SquareNumbers:
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num * num)

Arrow 优化#

Apache Arrow 是一种内存列式数据格式,允许在 Java 和 Python 进程之间高效传输数据。当 UDTF 输出多行时,它可以显著提升性能。Arrow 优化可以通过使用 useArrow=True 启用,例如:

from pyspark.sql.functions import udtf

@udtf(returnType="num: int, squared: int", useArrow=True)
class SquareNumbers:
    ...

使用 Python UDTF#

在 Python 中,我们可以直接使用类名调用 UDTF,如下所示。

[13]:
from pyspark.sql.functions import lit

SquareNumbers(lit(1), lit(3)).show()
+---+-------+
|num|squared|
+---+-------+
|  1|      1|
|  2|      4|
|  3|      9|
+---+-------+

在 SQL 中,我们可以注册 Python UDTF,然后将其作为表值函数在 SQL 查询的 FROM 子句中使用。

spark.sql("SELECT * FROM square_numbers(1, 3)").show()

更多示例#

示例1:为给定范围生成数字、它们的平方、立方和阶乘#

[14]:
from pyspark.sql.functions import lit, udtf
import math

@udtf(returnType="num: int, square: int, cube: int, factorial: int")
class GenerateComplexNumbers:
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num ** 2, num ** 3, math.factorial(num))

GenerateComplexNumbers(lit(1), lit(5)).show()
+---+------+----+---------+
|num|square|cube|factorial|
+---+------+----+---------+
|  1|     1|   1|        1|
|  2|     4|   8|        2|
|  3|     9|  27|        6|
|  4|    16|  64|       24|
|  5|    25| 125|      120|
+---+------+----+---------+

示例2:将句子拆分为单词并执行多项操作#

[15]:
from pyspark.sql.functions import lit, udtf

@udtf(returnType="word: string, length: int, is_palindrome: boolean")
class ProcessWords:
    def eval(self, sentence: str):
        words = sentence.split()  # Split sentence into words
        for word in words:
            is_palindrome = word == word[::-1]  # Check if the word is a palindrome
            yield (word, len(word), is_palindrome)

ProcessWords(lit("hello world")).show()
+-----+------+-------------+
| word|length|is_palindrome|
+-----+------+-------------+
|hello|     5|        false|
|world|     5|        false|
+-----+------+-------------+

示例3:将 JSON 字符串解析为带有数据类型的键值对#

[16]:
import json
from pyspark.sql.functions import lit, udtf

@udtf(returnType="key: string, value: string, value_type: string")
class ParseJSON:
    def eval(self, json_str: str):
        try:
            json_data = json.loads(json_str)
            for key, value in json_data.items():
                value_type = type(value).__name__
                yield (key, str(value), value_type)
        except json.JSONDecodeError:
            yield ("Invalid JSON", "", "")

ParseJSON(lit('{"name": "Alice", "age": 25, "is_student": false}')).show()
+----------+-----+----------+
|       key|value|value_type|
+----------+-----+----------+
|      name|Alice|       str|
|       age|   25|       int|
|is_student|False|      bool|
+----------+-----+----------+