Python 数据源 API#

概览#

Python 数据源 API 是 Spark 4.0 中引入的一项新功能,它使开发人员能够用 Python 从自定义数据源读取数据并写入自定义数据汇。本指南提供了该 API 的全面概述,以及如何创建、使用和管理 Python 数据源的说明。

简单示例#

这是一个简单的 Python 数据源,它精确地生成两行合成数据。此示例演示了如何在不使用外部库的情况下设置自定义数据源,重点介绍了快速启动和运行所需的要素。

步骤 1:定义数据源

from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

class SimpleDataSource(DataSource):
    """
    A simple data source for PySpark that generates exactly two rows of synthetic data.
    """

    @classmethod
    def name(cls):
        return "simple"

    def schema(self):
        return StructType([
            StructField("name", StringType()),
            StructField("age", IntegerType())
        ])

    def reader(self, schema: StructType):
        return SimpleDataSourceReader()

class SimpleDataSourceReader(DataSourceReader):

    def read(self, partition):
        yield ("Alice", 20)
        yield ("Bob", 30)

步骤 2:注册数据源

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

spark.dataSource.register(SimpleDataSource)

步骤 3:从数据源读取

spark.read.format("simple").load().show()

# +-----+---+
# | name|age|
# +-----+---+
# |Alice| 20|
# |  Bob| 30|
# +-----+---+

创建 Python 数据源#

要创建自定义 Python 数据源,您需要继承 DataSource 基类,并实现读写数据所需的方法。

此示例演示了如何使用 faker 库创建一个简单的数据源来生成合成数据。请确保 faker 库已安装并在您的 Python 环境中可访问。

定义数据源

首先,创建一个新的 DataSource 子类,并指定源名称和 Schema。

为了在批处理或流式查询中用作源或汇,需要实现 DataSource 的相应方法。

需要为某项功能实现的方法

批处理

reader()

writer()

流式

streamReader() 或 simpleStreamReader()

streamWriter()

from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import StructType

class FakeDataSource(DataSource):
    """
    A fake data source for PySpark to generate synthetic data using the `faker` library.
    Options:
    - numRows: specify number of rows to generate. Default value is 3.
    """

    @classmethod
    def name(cls):
        return "fake"

    def schema(self):
        return "name string, date string, zipcode string, state string"

    def reader(self, schema: StructType):
        return FakeDataSourceReader(schema, self.options)

    def writer(self, schema: StructType, overwrite: bool):
        return FakeDataSourceWriter(self.options)

    def streamReader(self, schema: StructType):
        return FakeStreamReader(schema, self.options)

    # Please skip the implementation of this method if streamReader has been implemented.
    def simpleStreamReader(self, schema: StructType):
        return SimpleStreamReader()

    def streamWriter(self, schema: StructType, overwrite: bool):
        return FakeStreamWriter(self.options)

为 Python 数据源实现批处理读写器#

实现读取器

定义读取器逻辑以生成合成数据。使用 faker 库填充 Schema 中的每个字段。

class FakeDataSourceReader(DataSourceReader):

    def __init__(self, schema, options):
        self.schema: StructType = schema
        self.options = options

    def read(self, partition):
        from faker import Faker
        fake = Faker()
        # Note: every value in this `self.options` dictionary is a string.
        num_rows = int(self.options.get("numRows", 3))
        for _ in range(num_rows):
            row = []
            for field in self.schema.fields:
                value = getattr(fake, field.name)()
                row.append(value)
            yield tuple(row)

实现写入器

创建一个模拟数据源写入器,它处理数据的每个分区,统计行数,并在成功写入后打印总行数,或者在写入过程失败时打印失败任务的数量。

from dataclasses import dataclass
from typing import Iterator, List

from pyspark.sql.types import Row
from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage

@dataclass
class SimpleCommitMessage(WriterCommitMessage):
    partition_id: int
    count: int

class FakeDataSourceWriter(DataSourceWriter):

    def write(self, rows: Iterator[Row]) -> SimpleCommitMessage:
        from pyspark import TaskContext

        context = TaskContext.get()
        partition_id = context.partitionId()
        cnt = sum(1 for _ in rows)
        return SimpleCommitMessage(partition_id=partition_id, count=cnt)

    def commit(self, messages: List[SimpleCommitMessage]) -> None:
        total_count = sum(message.count for message in messages)
        print(f"Total number of rows: {total_count}")

    def abort(self, messages: List[SimpleCommitMessage]) -> None:
        failed_count = sum(message is None for message in messages)
        print(f"Number of failed tasks: {failed_count}")

为 Python 数据源实现流式读写器#

实现流式读取器

这是一个虚拟的流式数据读取器,它在每个微批次中生成 2 行数据。streamReader 实例有一个整数偏移量,在每个微批次中增加 2。

class RangePartition(InputPartition):
    def __init__(self, start, end):
        self.start = start
        self.end = end

class FakeStreamReader(DataSourceStreamReader):
    def __init__(self, schema, options):
        self.current = 0

    def initialOffset(self) -> dict:
        """
        Return the initial start offset of the reader.
        """
        return {"offset": 0}

    def latestOffset(self) -> dict:
        """
        Return the current latest offset that the next microbatch will read to.
        """
        self.current += 2
        return {"offset": self.current}

    def partitions(self, start: dict, end: dict):
        """
        Plans the partitioning of the current microbatch defined by start and end offset,
        it needs to return a sequence of :class:`InputPartition` object.
        """
        return [RangePartition(start["offset"], end["offset"])]

    def commit(self, end: dict):
        """
        This is invoked when the query has finished processing data before end offset, this can be used to clean up resource.
        """
        pass

    def read(self, partition) -> Iterator[Tuple]:
        """
        Takes a partition as an input and read an iterator of tuples from the data source.
        """
        start, end = partition.start, partition.end
        for i in range(start, end):
            yield (i, str(i))

实现简易流式读取器

如果数据源吞吐量较低且不需要分区,您可以实现 SimpleDataSourceStreamReader 而不是 DataSourceStreamReader。

对于可读的流式数据源,必须实现 simpleStreamReader() 和 streamReader() 中的一个。并且只有当 streamReader() 未实现时,才会调用 simpleStreamReader()。

这是使用 SimpleDataSourceStreamReader 接口实现的虚拟流式读取器,它在每个批次中生成 2 行数据。

class SimpleStreamReader(SimpleDataSourceStreamReader):
    def initialOffset(self):
        """
        Return the initial start offset of the reader.
        """
        return {"offset": 0}

    def read(self, start: dict) -> (Iterator[Tuple], dict):
        """
        Takes start offset as an input, return an iterator of tuples and the start offset of next read.
        """
        start_idx = start["offset"]
        it = iter([(i,) for i in range(start_idx, start_idx + 2)])
        return (it, {"offset": start_idx + 2})

    def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
        """
        Takes start and end offset as input and read an iterator of data deterministically.
        This is called whe query replay batches during restart or after failure.
        """
        start_idx = start["offset"]
        end_idx = end["offset"]
        return iter([(i,) for i in range(start_idx, end_idx)])

    def commit(self, end):
        """
        This is invoked when the query has finished processing data before end offset, this can be used to clean up resource.
        """
        pass

实现流式写入器

这是一个流式数据写入器,它将每个微批次的元数据信息写入本地路径。

class SimpleCommitMessage(WriterCommitMessage):
   partition_id: int
   count: int

class FakeStreamWriter(DataSourceStreamWriter):
   def __init__(self, options):
       self.options = options
       self.path = self.options.get("path")
       assert self.path is not None

   def write(self, iterator):
       """
       Write the data and return the commit message of that partition
       """
       from pyspark import TaskContext
       context = TaskContext.get()
       partition_id = context.partitionId()
       cnt = 0
       for row in iterator:
           cnt += 1
       return SimpleCommitMessage(partition_id=partition_id, count=cnt)

   def commit(self, messages, batchId) -> None:
       """
       Receives a sequence of :class:`WriterCommitMessage` when all write tasks succeed and decides what to do with it.
       In this FakeStreamWriter, we write the metadata of the microbatch(number of rows and partitions) into a json file inside commit().
       """
       status = dict(num_partitions=len(messages), rows=sum(m.count for m in messages))
       with open(os.path.join(self.path, f"{batchId}.json"), "a") as file:
           file.write(json.dumps(status) + "\n")

   def abort(self, messages, batchId) -> None:
       """
       Receives a sequence of :class:`WriterCommitMessage` from successful tasks when some tasks fail and decides what to do with it.
       In this FakeStreamWriter, we write a failure message into a txt file inside abort().
       """
       with open(os.path.join(self.path, f"{batchId}.txt"), "w") as file:
           file.write(f"failed in batch {batchId}")

序列化要求#

用户定义的 DataSource、DataSourceReader、DataSourceWriter、DataSourceStreamReader 和 DataSourceStreamWriter 及其方法必须能够通过 pickle 进行序列化。

对于在方法内部使用的库,它必须在方法内部导入。例如,在下面的代码中,TaskContext 必须在 read() 方法内部导入。

def read(self, partition):
    from pyspark import TaskContext
    context = TaskContext.get()

使用 Python 数据源#

在批处理查询中使用 Python 数据源

定义数据源后,必须先注册才能使用。

spark.dataSource.register(FakeDataSource)

从 Python 数据源读取

使用默认 Schema 和选项从模拟数据源读取

spark.read.format("fake").load().show()

# +-----------+----------+-------+-------+
# |       name|      date|zipcode|  state|
# +-----------+----------+-------+-------+
# |Carlos Cobb|2018-07-15|  73003|Indiana|
# | Eric Scott|1991-08-22|  10085|  Idaho|
# | Amy Martin|1988-10-28|  68076| Oregon|
# +-----------+----------+-------+-------+

使用自定义 Schema 从模拟数据源读取

spark.read.format("fake").schema("name string, company string").load().show()

# +---------------------+--------------+
# |name                 |company       |
# +---------------------+--------------+
# |Tanner Brennan       |Adams Group   |
# |Leslie Maxwell       |Santiago Group|
# |Mrs. Jacqueline Brown|Maynard Inc   |
# +---------------------+--------------+

从模拟数据源读取不同数量的行

spark.read.format("fake").option("numRows", 5).load().show()

# +--------------+----------+-------+------------+
# |          name|      date|zipcode|       state|
# +--------------+----------+-------+------------+
# |  Pam Mitchell|1988-10-20|  23788|   Tennessee|
# |Melissa Turner|1996-06-14|  30851|      Nevada|
# |  Brian Ramsey|2021-08-21|  55277|  Washington|
# |  Caitlin Reed|1983-06-22|  89813|Pennsylvania|
# | Douglas James|2007-01-18|  46226|     Alabama|
# +--------------+----------+-------+------------+

写入 Python 数据源

要将数据写入自定义位置,请确保指定 mode() 子句。支持的模式有 appendoverwrite

df = spark.range(0, 10, 1, 5)
df.write.format("fake").mode("append").save()

# You can check the Spark log (standard error) to see the output of the write operation.
# Total number of rows: 10

在流式查询中使用 Python 数据源

一旦我们注册了 Python 数据源,我们还可以通过将短名称或全名称传递给 format(),在流式查询中将其用作 readStream() 的源或 writeStream() 的汇。

启动一个从模拟 Python 数据源读取并写入控制台的查询

query = spark.readStream.format("fake").load().writeStream.format("console").start()

# +---+
# | id|
# +---+
# |  0|
# |  1|
# +---+
# +---+
# | id|
# +---+
# |  2|
# |  3|
# +---+

我们还可以在流式读取器和写入器中使用相同的数据源

query = spark.readStream.format("fake").load().writeStream.format("fake").start("/output_path")

支持直接 Arrow 批处理以提高性能的 Python 数据源读取器#

Python 数据源读取器支持直接生成 Arrow 批处理,这可以显著提高数据处理性能。通过使用高效的 Arrow 格式,此功能避免了传统逐行数据处理的开销,尤其是在处理大型数据集时,性能可提高一个数量级。

启用 Arrow 批处理支持:要启用此功能,请配置您的自定义 DataSource,使其在 DataSourceReader(或 DataSourceStreamReader)实现的 read 方法中返回 pyarrow.RecordBatch 对象以生成 Arrow 批处理。此方法简化了数据处理并减少了 I/O 操作次数,对于大规模数据处理任务尤其有益。

Arrow 批处理示例:以下示例演示了如何使用 Arrow 批处理支持实现一个基本的数据源。

from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql import SparkSession
import pyarrow as pa

# Define the ArrowBatchDataSource
class ArrowBatchDataSource(DataSource):
    """
    A Data Source for testing Arrow Batch Serialization
    """

    @classmethod
    def name(cls):
        return "arrowbatch"

    def schema(self):
        return "key int, value string"

    def reader(self, schema: str):
        return ArrowBatchDataSourceReader(schema, self.options)

# Define the ArrowBatchDataSourceReader
class ArrowBatchDataSourceReader(DataSourceReader):
    def __init__(self, schema, options):
        self.schema: str = schema
        self.options = options

    def read(self, partition):
        # Create Arrow Record Batch
        keys = pa.array([1, 2, 3, 4, 5], type=pa.int32())
        values = pa.array(["one", "two", "three", "four", "five"], type=pa.string())
        schema = pa.schema([("key", pa.int32()), ("value", pa.string())])
        record_batch = pa.RecordBatch.from_arrays([keys, values], schema=schema)
        yield record_batch

    def partitions(self):
        # Define the number of partitions
        num_part = 1
        return [InputPartition(i) for i in range(num_part)]

# Initialize the Spark Session
spark = SparkSession.builder.appName("ArrowBatchExample").getOrCreate()

# Register the ArrowBatchDataSource
spark.dataSource.register(ArrowBatchDataSource)

# Load data using the custom data source
df = spark.read.format("arrowbatch").load()

df.show()

使用须知#

  • 在数据源解析期间,内置和 Scala/Java 数据源优先于同名的 Python 数据源;要显式使用 Python 数据源,请确保其名称与其他数据源不冲突。