Python Data Source API#
Overview#
The Python Data Source API is a new feature introduced in Spark 4.0 that enables developers to read from custom data sources and write to custom data sinks in Python. This guide provides a comprehensive overview of the API, along with instructions on how to create, use, and manage Python data sources.
Simple Example#
Here is a simple Python data source that generates exactly two rows of synthetic data. This example demonstrates how to set up a custom data source without using any external libraries, focusing on the essentials needed to get it up and running quickly.
Step 1: Define the data source
from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import IntegerType, StringType, StructField, StructType
class SimpleDataSource(DataSource):
    """
    A simple data source for PySpark that generates exactly two rows of synthetic data.
    """

    @classmethod
    def name(cls):
        return "simple"

    def schema(self):
        return StructType([
            StructField("name", StringType()),
            StructField("age", IntegerType())
        ])

    def reader(self, schema: StructType):
        return SimpleDataSourceReader()


class SimpleDataSourceReader(DataSourceReader):

    def read(self, partition):
        yield ("Alice", 20)
        yield ("Bob", 30)
Step 2: Register the data source
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
spark.dataSource.register(SimpleDataSource)
Step 3: Read from the data source
spark.read.format("simple").load().show()
# +-----+---+
# | name|age|
# +-----+---+
# |Alice| 20|
# | Bob| 30|
# +-----+---+
Creating a Python Data Source#
To create a custom Python data source, you'll need to subclass the DataSource
base class and implement the required methods for reading and writing data.
This example demonstrates how to create a simple data source that generates synthetic data using the faker library. Make sure the faker library is installed and accessible in your Python environment.
Define the data source
Start by creating a new subclass of DataSource
and specify the source name and schema.
To be used as a source or sink in a batch or streaming query, the corresponding methods of DataSource need to be implemented.
Methods that need to be implemented for a capability:

|           | source                                 | sink           |
|-----------|----------------------------------------|----------------|
| batch     | reader()                               | writer()       |
| streaming | streamReader() or simpleStreamReader() | streamWriter() |
from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import StructType
class FakeDataSource(DataSource):
    """
    A fake data source for PySpark to generate synthetic data using the `faker` library.
    Options:
    - numRows: specify number of rows to generate. Default value is 3.
    """

    @classmethod
    def name(cls):
        return "fake"

    def schema(self):
        return "name string, date string, zipcode string, state string"

    def reader(self, schema: StructType):
        return FakeDataSourceReader(schema, self.options)

    def writer(self, schema: StructType, overwrite: bool):
        return FakeDataSourceWriter(self.options)

    def streamReader(self, schema: StructType):
        return FakeStreamReader(schema, self.options)

    # Please skip the implementation of this method if streamReader has been implemented.
    def simpleStreamReader(self, schema: StructType):
        return SimpleStreamReader()

    def streamWriter(self, schema: StructType, overwrite: bool):
        return FakeStreamWriter(self.options)
Implementing Batch Reader and Writer for Python Data Source#
Implement the reader
Define the reader logic to generate synthetic data, using the faker library to populate each field in the schema. Note that the reader below looks up each schema field name as a faker provider (for example name, date, zipcode, state), so the field names must correspond to faker generator methods.
class FakeDataSourceReader(DataSourceReader):

    def __init__(self, schema, options):
        self.schema: StructType = schema
        self.options = options

    def read(self, partition):
        from faker import Faker
        fake = Faker()
        # Note: every value in this `self.options` dictionary is a string.
        num_rows = int(self.options.get("numRows", 3))
        for _ in range(num_rows):
            row = []
            for field in self.schema.fields:
                value = getattr(fake, field.name)()
                row.append(value)
            yield tuple(row)
Implement the writer
Create a fake data source writer that processes each partition of data, counts the rows, and prints the total number of rows after a successful write, or the number of failed tasks if the write fails.
from dataclasses import dataclass
from typing import Iterator, List
from pyspark.sql.types import Row
from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage
@dataclass
class SimpleCommitMessage(WriterCommitMessage):
    partition_id: int
    count: int


class FakeDataSourceWriter(DataSourceWriter):

    def write(self, rows: Iterator[Row]) -> SimpleCommitMessage:
        from pyspark import TaskContext

        context = TaskContext.get()
        partition_id = context.partitionId()
        cnt = sum(1 for _ in rows)
        return SimpleCommitMessage(partition_id=partition_id, count=cnt)

    def commit(self, messages: List[SimpleCommitMessage]) -> None:
        total_count = sum(message.count for message in messages)
        print(f"Total number of rows: {total_count}")

    def abort(self, messages: List[SimpleCommitMessage]) -> None:
        failed_count = sum(message is None for message in messages)
        print(f"Number of failed tasks: {failed_count}")
Implementing Streaming Reader and Writer for Python Data Source#
Implement the streaming reader
This is a dummy streaming data reader that generates 2 rows in every microbatch. The streamReader instance keeps an integer offset that increases by 2 in every microbatch.
from typing import Iterator, Tuple

from pyspark.sql.datasource import DataSourceStreamReader, InputPartition


class RangePartition(InputPartition):
    def __init__(self, start, end):
        self.start = start
        self.end = end


class FakeStreamReader(DataSourceStreamReader):

    def __init__(self, schema, options):
        self.current = 0

    def initialOffset(self) -> dict:
        """
        Return the initial start offset of the reader.
        """
        return {"offset": 0}

    def latestOffset(self) -> dict:
        """
        Return the current latest offset that the next microbatch will read to.
        """
        self.current += 2
        return {"offset": self.current}

    def partitions(self, start: dict, end: dict):
        """
        Plan the partitioning of the current microbatch defined by the start and end offsets.
        It needs to return a sequence of :class:`InputPartition` objects.
        """
        return [RangePartition(start["offset"], end["offset"])]

    def commit(self, end: dict):
        """
        This is invoked when the query has finished processing data before the end offset.
        It can be used to clean up resources.
        """
        pass

    def read(self, partition) -> Iterator[Tuple]:
        """
        Take a partition as input and read an iterator of tuples from the data source.
        """
        start, end = partition.start, partition.end
        for i in range(start, end):
            yield (i, str(i))
Implement the simple streaming reader
If the data source has low throughput and doesn't require partitioning, you can implement SimpleDataSourceStreamReader instead of DataSourceStreamReader.
One of simpleStreamReader() and streamReader() must be implemented for a readable streaming data source, and simpleStreamReader() is only invoked when streamReader() is not implemented.
This is the same dummy streaming reader, generating 2 rows in every batch, implemented with the SimpleDataSourceStreamReader interface.
from typing import Iterator, Tuple

from pyspark.sql.datasource import SimpleDataSourceStreamReader


class SimpleStreamReader(SimpleDataSourceStreamReader):

    def initialOffset(self):
        """
        Return the initial start offset of the reader.
        """
        return {"offset": 0}

    def read(self, start: dict) -> (Iterator[Tuple], dict):
        """
        Take the start offset as input, then return an iterator of tuples and the start offset of the next read.
        """
        start_idx = start["offset"]
        it = iter([(i,) for i in range(start_idx, start_idx + 2)])
        return (it, {"offset": start_idx + 2})

    def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
        """
        Take the start and end offsets as input, then read an iterator of data deterministically.
        This is called when the query replays batches during restart or after a failure.
        """
        start_idx = start["offset"]
        end_idx = end["offset"]
        return iter([(i,) for i in range(start_idx, end_idx)])

    def commit(self, end):
        """
        This is invoked when the query has finished processing data before the end offset.
        It can be used to clean up resources.
        """
        pass
Implement the streaming writer
This is a streaming data writer that writes the metadata of each microbatch to a local path.
import json
import os
from dataclasses import dataclass

from pyspark.sql.datasource import DataSourceStreamWriter, WriterCommitMessage


@dataclass
class SimpleCommitMessage(WriterCommitMessage):
    partition_id: int
    count: int


class FakeStreamWriter(DataSourceStreamWriter):

    def __init__(self, options):
        self.options = options
        self.path = self.options.get("path")
        assert self.path is not None

    def write(self, iterator):
        """
        Write the data and return the commit message of that partition.
        """
        from pyspark import TaskContext

        context = TaskContext.get()
        partition_id = context.partitionId()
        cnt = 0
        for row in iterator:
            cnt += 1
        return SimpleCommitMessage(partition_id=partition_id, count=cnt)

    def commit(self, messages, batchId) -> None:
        """
        Receives a sequence of :class:`WriterCommitMessage` when all write tasks have succeeded, then decides what to do with it.
        In this FakeStreamWriter, the metadata of the microbatch (number of rows and partitions) is written into a JSON file inside commit().
        """
        status = dict(num_partitions=len(messages), rows=sum(m.count for m in messages))
        with open(os.path.join(self.path, f"{batchId}.json"), "a") as file:
            file.write(json.dumps(status) + "\n")

    def abort(self, messages, batchId) -> None:
        """
        Receives a sequence of :class:`WriterCommitMessage` from successful tasks when some other tasks have failed, then decides what to do with it.
        In this FakeStreamWriter, a failure message is written into a text file inside abort().
        """
        with open(os.path.join(self.path, f"{batchId}.txt"), "w") as file:
            file.write(f"failed in batch {batchId}")
Serialization Requirements#
The user-defined DataSource, DataSourceReader, DataSourceWriter, DataSourceStreamReader, and DataSourceStreamWriter classes, along with their methods, must be serializable with pickle.
Any library used inside a method must be imported inside that method. For example, TaskContext must be imported inside the read() method in the code below.
def read(self, partition):
    from pyspark import TaskContext

    context = TaskContext.get()
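In the same spirit, avoid storing non-picklable objects (such as open files or database connections) as attributes on the data source or reader; create them lazily inside the method that uses them. Below is a minimal sketch of this pattern; the LazyConnectionReader class, the dbPath option, and the SQLite table people(name, age) are hypothetical and serve only as an illustration.
from pyspark.sql.datasource import DataSourceReader


class LazyConnectionReader(DataSourceReader):
    """A hypothetical reader that keeps only picklable state and opens its connection inside read()."""

    def __init__(self, schema, options):
        # Only picklable values (strings, numbers, dicts) are stored on the instance.
        self.db_path = options.get("dbPath", "example.db")

    def read(self, partition):
        # Import and connect inside the method so the reader object itself stays picklable.
        import sqlite3

        conn = sqlite3.connect(self.db_path)
        try:
            for name, age in conn.execute("SELECT name, age FROM people"):
                yield (name, age)
        finally:
            conn.close()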
Using a Python Data Source#
Use a Python data source in a batch query
After defining the data source, it must be registered before it can be used.
spark.dataSource.register(FakeDataSource)
Read from a Python data source
Read from the fake data source with the default schema and options:
spark.read.format("fake").load().show()
# +-----------+----------+-------+-------+
# | name| date|zipcode| state|
# +-----------+----------+-------+-------+
# |Carlos Cobb|2018-07-15| 73003|Indiana|
# | Eric Scott|1991-08-22| 10085| Idaho|
# | Amy Martin|1988-10-28| 68076| Oregon|
# +-----------+----------+-------+-------+
Read from the fake data source with a custom schema:
spark.read.format("fake").schema("name string, company string").load().show()
# +---------------------+--------------+
# |name |company |
# +---------------------+--------------+
# |Tanner Brennan |Adams Group |
# |Leslie Maxwell |Santiago Group|
# |Mrs. Jacqueline Brown|Maynard Inc |
# +---------------------+--------------+
Read from the fake data source with a different number of rows:
spark.read.format("fake").option("numRows", 5).load().show()
# +--------------+----------+-------+------------+
# | name| date|zipcode| state|
# +--------------+----------+-------+------------+
# | Pam Mitchell|1988-10-20| 23788| Tennessee|
# |Melissa Turner|1996-06-14| 30851| Nevada|
# | Brian Ramsey|2021-08-21| 55277| Washington|
# | Caitlin Reed|1983-06-22| 89813|Pennsylvania|
# | Douglas James|2007-01-18| 46226| Alabama|
# +--------------+----------+-------+------------+
Write to a Python data source
To write data to a custom location, make sure to specify the mode() clause. The supported modes are append and overwrite.
df = spark.range(0, 10, 1, 5)
df.write.format("fake").mode("append").save()
# You can check the Spark log (standard error) to see the output of the write operation.
# Total number of rows: 10
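For illustration, overwrite mode is used the same way; the only difference is that the overwrite flag passed to FakeDataSource.writer() is True for such a call. A short sketch based on the writer defined above:
# Overwrite mode: FakeDataSource.writer(schema, overwrite) is called with overwrite=True.
df.write.format("fake").mode("overwrite").save()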
Use a Python data source in a streaming query
Once the Python data source is registered, we can also use it in streaming queries as the source of readStream() or the sink of writeStream() by passing its short name or full name to format().
Start a query that reads from the fake Python data source and writes to the console:
query = spark.readStream.format("fake").load().writeStream.format("console").start()
# +---+
# | id|
# +---+
# | 0|
# | 1|
# +---+
# +---+
# | id|
# +---+
# | 2|
# | 3|
# +---+
We can also use the same data source as both the streaming reader and writer:
query = spark.readStream.format("fake").load().writeStream.format("fake").start("/output_path")
Python Data Source Reader with direct Arrow Batch support for improved performance#
Python data source readers support yielding Arrow batches directly, which can significantly improve data-processing performance. By using the efficient Arrow format, this feature avoids the overhead of traditional row-by-row processing and can improve performance by up to an order of magnitude, especially when handling large datasets.
Enabling Arrow batch support: To enable this feature, have your custom DataSource return pyarrow.RecordBatch objects from the read method of its DataSourceReader (or DataSourceStreamReader) implementation. This simplifies data handling and reduces the number of I/O operations, which is especially beneficial for large-scale data-processing tasks.
Arrow batch example: The following example demonstrates how to implement a basic data source with Arrow batch support.
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql import SparkSession
import pyarrow as pa
# Define the ArrowBatchDataSource
class ArrowBatchDataSource(DataSource):
    """
    A Data Source for testing Arrow Batch Serialization
    """

    @classmethod
    def name(cls):
        return "arrowbatch"

    def schema(self):
        return "key int, value string"

    def reader(self, schema: str):
        return ArrowBatchDataSourceReader(schema, self.options)


# Define the ArrowBatchDataSourceReader
class ArrowBatchDataSourceReader(DataSourceReader):

    def __init__(self, schema, options):
        self.schema: str = schema
        self.options = options

    def read(self, partition):
        # Create an Arrow record batch
        keys = pa.array([1, 2, 3, 4, 5], type=pa.int32())
        values = pa.array(["one", "two", "three", "four", "five"], type=pa.string())
        schema = pa.schema([("key", pa.int32()), ("value", pa.string())])
        record_batch = pa.RecordBatch.from_arrays([keys, values], schema=schema)
        yield record_batch

    def partitions(self):
        # Define the number of partitions
        num_part = 1
        return [InputPartition(i) for i in range(num_part)]


# Initialize the Spark session
spark = SparkSession.builder.appName("ArrowBatchExample").getOrCreate()

# Register the ArrowBatchDataSource
spark.dataSource.register(ArrowBatchDataSource)

# Load data using the custom data source
df = spark.read.format("arrowbatch").load()

df.show()
Usage Notes#
During data source resolution, built-in and Scala/Java data sources take precedence over Python data sources with the same name; to explicitly use a Python data source, make sure its name does not conflict with the other data sources.
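As a minimal sketch of this rule, using the FakeDataSource defined earlier in this guide and an active SparkSession named spark: a short name like "fake" does not clash with any built-in source, so format("fake") resolves to the Python implementation, whereas a name like "json" would resolve to the built-in JSON source instead.
# "fake" does not collide with a built-in source, so the Python data source is used.
spark.dataSource.register(FakeDataSource)
spark.read.format("fake").load().show()

# A Python data source registered under a built-in name such as "json" would be
# shadowed by the built-in JSON source during resolution, so choose a distinct name.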