第 3 章:函数交汇 - 使用 PySpark 进行数据操作#

清洗数据#

在数据科学中,垃圾进,垃圾出(GIGO)是指有缺陷、有偏见或低质量的信息或输入会产生类似质量的结果或输出的概念。为了提高分析质量,我们需要数据清洗,这个过程将垃圾转化为黄金,它包括识别、纠正或删除数据中的错误和不一致性,以提高数据的质量和可用性。

我们从一个包含不良值的 DataFrame 开始。

[1]:
!pip install pyspark==4.0.0.dev2
[2]:
from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Data Loading and Storage Example") \
    .getOrCreate()
[3]:
from pyspark.sql import Row

df = spark.createDataFrame([
    Row(age=10, height=80.0, NAME="Alice"),
    Row(age=10, height=80.0, NAME="Alice"),
    Row(age=5, height=float("nan"), NAME="BOB"),
    Row(age=None, height=None, NAME="Tom"),
    Row(age=None, height=float("nan"), NAME=None),
    Row(age=9, height=78.9, NAME="josh"),
    Row(age=18, height=1802.3, NAME="bush"),
    Row(age=7, height=75.3, NAME="jerry"),
])

df.show()
+----+------+-----+
| age|height| NAME|
+----+------+-----+
|  10|  80.0|Alice|
|  10|  80.0|Alice|
|   5|   NaN|  BOB|
|NULL|  NULL|  Tom|
|NULL|   NaN| NULL|
|   9|  78.9| josh|
|  18|1802.3| bush|
|   7|  75.3|jerry|
+----+------+-----+

重命名列#

乍一看,我们发现 NAME 列是大写的。为了保持一致性,我们可以使用 DataFrame.withColumnRenamed 来重命名列。

[4]:
df2 = df.withColumnRenamed("NAME", "name")

df2.show()
+----+------+-----+
| age|height| name|
+----+------+-----+
|  10|  80.0|Alice|
|  10|  80.0|Alice|
|   5|   NaN|  BOB|
|NULL|  NULL|  Tom|
|NULL|   NaN| NULL|
|   9|  78.9| josh|
|  18|1802.3| bush|
|   7|  75.3|jerry|
+----+------+-----+

删除空值#

然后我们可以注意到有两种缺失数据:

  • 所有三列中的 NULL 值;

  • 数值列中的 NaN 值,表示“非数字”;

没有有效 name 的记录很可能无用,所以我们先删除它们。DataFrameNaFunctions 中有一组用于处理缺失值的函数,我们可以使用 DataFrame.na.dropDataFrame.dropna 来省略包含 NULLNaN 值的行。

执行 df2.na.drop(subset="name") 这一步后,无效记录 (age=None, height=NaN, name=None) 被丢弃。

[5]:
df3 = df2.na.drop(subset="name")

df3.show()
+----+------+-----+
| age|height| name|
+----+------+-----+
|  10|  80.0|Alice|
|  10|  80.0|Alice|
|   5|   NaN|  BOB|
|NULL|  NULL|  Tom|
|   9|  78.9| josh|
|  18|1802.3| bush|
|   7|  75.3|jerry|
+----+------+-----+

填充值#

对于剩余的缺失值,我们可以使用 DataFrame.na.fillDataFrame.fillna 来填充它们。

通过 Dict 输入 {'age': 10, 'height': 80.1},我们可以同时指定 ageheight 列的值。

[6]:
df4 = df3.na.fill({'age': 10, 'height': 80.1})

df4.show()
+---+------+-----+
|age|height| name|
+---+------+-----+
| 10|  80.0|Alice|
| 10|  80.0|Alice|
|  5|  80.1|  BOB|
| 10|  80.1|  Tom|
|  9|  78.9| josh|
| 18|1802.3| bush|
|  7|  75.3|jerry|
+---+------+-----+

移除异常值#

经过上述步骤,所有缺失值都被删除或填充了。然而,我们发现 height=1802.3 似乎不合理,为了移除这类异常值,我们可以使用一个有效范围,例如 (65, 85) 来过滤 DataFrame。

[7]:
df5 = df4.where(df4.height.between(65, 85))

df5.show()
+---+------+-----+
|age|height| name|
+---+------+-----+
| 10|  80.0|Alice|
| 10|  80.0|Alice|
|  5|  80.1|  BOB|
| 10|  80.1|  Tom|
|  9|  78.9| josh|
|  7|  75.3|jerry|
+---+------+-----+

移除重复项#

现在,所有无效记录都已处理完毕。但我们注意到记录 (age=10, height=80.0, name=Alice) 出现了重复。要移除此类重复项,我们可以直接应用 DataFrame.distinct

[8]:
df6 = df5.distinct()

df6.show()
+---+------+-----+
|age|height| name|
+---+------+-----+
| 10|  80.0|Alice|
|  5|  80.1|  BOB|
| 10|  80.1|  Tom|
|  9|  78.9| josh|
|  7|  75.3|jerry|
+---+------+-----+

字符串操作#

name 列同时包含小写和大写字母。我们可以应用 lower() 函数将所有字母转换为小写。

[9]:
from pyspark.sql import functions as sf

df7 = df6.withColumn("name", sf.lower("name"))
df7.show()
+---+------+-----+
|age|height| name|
+---+------+-----+
| 10|  80.0|alice|
|  5|  80.1|  bob|
| 10|  80.1|  tom|
|  9|  78.9| josh|
|  7|  75.3|jerry|
+---+------+-----+

对于更复杂的字符串操作,我们还可以使用 udf 来利用 Python 的强大函数。

[10]:
from pyspark.sql import functions as sf

capitalize = sf.udf(lambda s: s.capitalize())

df8 = df6.withColumn("name", capitalize("name"))
df8.show()
+---+------+-----+
|age|height| name|
+---+------+-----+
| 10|  80.0|Alice|
|  5|  80.1|  Bob|
| 10|  80.1|  Tom|
|  9|  78.9| Josh|
|  7|  75.3|Jerry|
+---+------+-----+

重新排序列#

经过上述过程,数据已清洗完毕,我们希望在将 DataFrame 保存到存储之前重新排列列。有关更多详细信息,请参阅前一章 加载与存储:数据加载、存储、文件格式

通常,我们为此目的使用 DataFrame.select

[11]:
df9 = df7.select("name", "age", "height")

df9.show()
+-----+---+------+
| name|age|height|
+-----+---+------+
|alice| 10|  80.0|
|  bob|  5|  80.1|
|  tom| 10|  80.1|
| josh|  9|  78.9|
|jerry|  7|  75.3|
+-----+---+------+

转换数据#

数据工程项目的主要部分是数据转换。我们从旧的 DataFrame 创建新的 DataFrame。

使用 select() 选择列#

输入表可能包含数百列,但对于特定项目,我们可能只对其中的一小部分感兴趣。

[12]:
from pyspark.sql import functions as sf
df = spark.range(10)

for i in range(20):
  df = df.withColumn(f"col_{i}", sf.lit(i))

df.show()
+---+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------+------+------+------+------+------+------+------+------+
| id|col_0|col_1|col_2|col_3|col_4|col_5|col_6|col_7|col_8|col_9|col_10|col_11|col_12|col_13|col_14|col_15|col_16|col_17|col_18|col_19|
+---+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------+------+------+------+------+------+------+------+------+
|  0|    0|    1|    2|    3|    4|    5|    6|    7|    8|    9|    10|    11|    12|    13|    14|    15|    16|    17|    18|    19|
|  1|    0|    1|    2|    3|    4|    5|    6|    7|    8|    9|    10|    11|    12|    13|    14|    15|    16|    17|    18|    19|
|  2|    0|    1|    2|    3|    4|    5|    6|    7|    8|    9|    10|    11|    12|    13|    14|    15|    16|    17|    18|    19|
|  3|    0|    1|    2|    3|    4|    5|    6|    7|    8|    9|    10|    11|    12|    13|    14|    15|    16|    17|    18|    19|
|  4|    0|    1|    2|    3|    4|    5|    6|    7|    8|    9|    10|    11|    12|    13|    14|    15|    16|    17|    18|    19|
|  5|    0|    1|    2|    3|    4|    5|    6|    7|    8|    9|    10|    11|    12|    13|    14|    15|    16|    17|    18|    19|
|  6|    0|    1|    2|    3|    4|    5|    6|    7|    8|    9|    10|    11|    12|    13|    14|    15|    16|    17|    18|    19|
|  7|    0|    1|    2|    3|    4|    5|    6|    7|    8|    9|    10|    11|    12|    13|    14|    15|    16|    17|    18|    19|
|  8|    0|    1|    2|    3|    4|    5|    6|    7|    8|    9|    10|    11|    12|    13|    14|    15|    16|    17|    18|    19|
|  9|    0|    1|    2|    3|    4|    5|    6|    7|    8|    9|    10|    11|    12|    13|    14|    15|    16|    17|    18|    19|
+---+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+------+------+------+------+------+------+------+------+------+

我们通过 for 循环创建一个包含 21 列的 DataFrame,然后只通过 select 选择 4 列。idcol_2col_3 列直接从之前的 DataFrame 中选择,而 sqrt_col_4_plus_5 列则由数学函数生成。

pyspark.sql.functionpyspark.sql.Column 中,我们有数百个用于列操作的函数。

[13]:

df2 = df.select("id", "col_2", "col_3", sf.sqrt(sf.col("col_4") + sf.col("col_5")).alias("sqrt_col_4_plus_5")) df2.show()
+---+-----+-----+-----------------+
| id|col_2|col_3|sqrt_col_4_plus_5|
+---+-----+-----+-----------------+
|  0|    2|    3|              3.0|
|  1|    2|    3|              3.0|
|  2|    2|    3|              3.0|
|  3|    2|    3|              3.0|
|  4|    2|    3|              3.0|
|  5|    2|    3|              3.0|
|  6|    2|    3|              3.0|
|  7|    2|    3|              3.0|
|  8|    2|    3|              3.0|
|  9|    2|    3|              3.0|
+---+-----+-----+-----------------+

使用 where() 过滤行#

输入表可能非常庞大,包含数十亿行,我们可能也只对其中的一小部分感兴趣。

我们可以使用 wherefilter 以及指定的条件来过滤行。

例如,我们可以选择 id 值为奇数的行。

[14]:
df3 = df2.where(sf.col("id") % 2 == 1)

df3.show()
+---+-----+-----+-----------------+
| id|col_2|col_3|sqrt_col_4_plus_5|
+---+-----+-----+-----------------+
|  1|    2|    3|              3.0|
|  3|    2|    3|              3.0|
|  5|    2|    3|              3.0|
|  7|    2|    3|              3.0|
|  9|    2|    3|              3.0|
+---+-----+-----+-----------------+

汇总数据#

在数据分析中,我们通常会汇总数据以生成图表或表格。

[15]:
from pyspark.sql import Row

df = spark.createDataFrame([
    Row(incomes=[123.0, 456.0, 789.0], NAME="Alice"),
    Row(incomes=[234.0, 567.0], NAME="BOB"),
    Row(incomes=[100.0, 200.0, 100.0], NAME="Tom"),
    Row(incomes=[79.0, 128.0], NAME="josh"),
    Row(incomes=[123.0, 145.0, 178.0], NAME="bush"),
    Row(incomes=[111.0, 187.0, 451.0, 188.0, 199.0], NAME="jerry"),
])

df.show()
+--------------------+-----+
|             incomes| NAME|
+--------------------+-----+
|[123.0, 456.0, 78...|Alice|
|      [234.0, 567.0]|  BOB|
|[100.0, 200.0, 10...|  Tom|
|       [79.0, 128.0]| josh|
|[123.0, 145.0, 17...| bush|
|[111.0, 187.0, 45...|jerry|
+--------------------+-----+

例如,给定每月收入,我们想找出每个名称的平均收入。

[16]:
from pyspark.sql import functions as sf

df2 = df.select(sf.lower("NAME").alias("name"), "incomes")

df2.show(truncate=False)
+-----+-----------------------------------+
|name |incomes                            |
+-----+-----------------------------------+
|alice|[123.0, 456.0, 789.0]              |
|bob  |[234.0, 567.0]                     |
|tom  |[100.0, 200.0, 100.0]              |
|josh |[79.0, 128.0]                      |
|bush |[123.0, 145.0, 178.0]              |
|jerry|[111.0, 187.0, 451.0, 188.0, 199.0]|
+-----+-----------------------------------+

使用 explode() 重塑数据#

为了方便数据聚合,我们可以使用 explode() 函数来重塑数据

[17]:
df3 = df2.select("name", sf.explode("incomes").alias("income"))

df3.show()
+-----+------+
| name|income|
+-----+------+
|alice| 123.0|
|alice| 456.0|
|alice| 789.0|
|  bob| 234.0|
|  bob| 567.0|
|  tom| 100.0|
|  tom| 200.0|
|  tom| 100.0|
| josh|  79.0|
| josh| 128.0|
| bush| 123.0|
| bush| 145.0|
| bush| 178.0|
|jerry| 111.0|
|jerry| 187.0|
|jerry| 451.0|
|jerry| 188.0|
|jerry| 199.0|
+-----+------+

通过 groupBy() 和 agg() 汇总数据#

然后我们通常使用 DataFrame.groupBy(...).agg(...) 来聚合数据。要计算平均收入,我们可以应用聚合函数 avg

[18]:
df4 = df3.groupBy("name").agg(sf.avg("income").alias("avg_income"))

df4.show()
+-----+------------------+
| name|        avg_income|
+-----+------------------+
|alice|             456.0|
|  bob|             400.5|
|  tom|133.33333333333334|
| josh|             103.5|
| bush|148.66666666666666|
|jerry|             227.2|
+-----+------------------+

排序#

对于最终分析,我们通常希望对数据进行排序。在这种情况下,我们可以按 name 排序数据。

[19]:
df5 = df4.orderBy("name")

df5.show()
+-----+------------------+
| name|        avg_income|
+-----+------------------+
|alice|             456.0|
|  bob|             400.5|
| bush|148.66666666666666|
|jerry|             227.2|
| josh|             103.5|
|  tom|133.33333333333334|
+-----+------------------+

当 DataFrame 碰撞时:连接的艺术#

当处理多个 DataFrame 时,我们可能需要以某种方式将它们组合在一起。最常用的方法是连接(joining)。

例如,给定 incomes 数据和 height 数据,我们可以使用 DataFrame.join 通过 name 将它们连接起来。

我们可以看到最终结果中只有 alicejoshbush,因为它们同时出现在两个 DataFrame 中。

[20]:
from pyspark.sql import Row

df1 = spark.createDataFrame([
    Row(age=10, height=80.0, name="alice"),
    Row(age=9, height=78.9, name="josh"),
    Row(age=18, height=82.3, name="bush"),
    Row(age=7, height=75.3, name="tom"),
])

df2 = spark.createDataFrame([
    Row(incomes=[123.0, 456.0, 789.0], name="alice"),
    Row(incomes=[234.0, 567.0], name="bob"),
    Row(incomes=[79.0, 128.0], name="josh"),
    Row(incomes=[123.0, 145.0, 178.0], name="bush"),
    Row(incomes=[111.0, 187.0, 451.0, 188.0, 199.0], name="jerry"),
])
[21]:
df3 = df1.join(df2, on="name")

df3.show(truncate=False)
+-----+---+------+---------------------+
|name |age|height|incomes              |
+-----+---+------+---------------------+
|alice|10 |80.0  |[123.0, 456.0, 789.0]|
|bush |18 |82.3  |[123.0, 145.0, 178.0]|
|josh |9  |78.9  |[79.0, 128.0]        |
+-----+---+------+---------------------+

有七种连接方法: - INNER - LEFT - RIGHT - FULL - CROSS - LEFTSEMI - LEFTANTI

默认方法是 INNER

我们以 LEFT 连接为例。左连接(left join)包含来自两个表中的第一个(左侧)表的所有记录,即使第二个(右侧)表中没有匹配的记录值。

[22]:
df4 = df1.join(df2, on="name", how="left")

df4.show(truncate=False)
+-----+---+------+---------------------+
|name |age|height|incomes              |
+-----+---+------+---------------------+
|alice|10 |80.0  |[123.0, 456.0, 789.0]|
|josh |9  |78.9  |[79.0, 128.0]        |
|bush |18 |82.3  |[123.0, 145.0, 178.0]|
|tom  |7  |75.3  |NULL                 |
+-----+---+------+---------------------+

RIGHT 连接则保留右侧表的所有记录。

[23]:
df5 = df1.join(df2, on="name", how="right")

df5.show(truncate=False)
+-----+----+------+-----------------------------------+
|name |age |height|incomes                            |
+-----+----+------+-----------------------------------+
|alice|10  |80.0  |[123.0, 456.0, 789.0]              |
|bob  |NULL|NULL  |[234.0, 567.0]                     |
|josh |9   |78.9  |[79.0, 128.0]                      |
|bush |18  |82.3  |[123.0, 145.0, 178.0]              |
|jerry|NULL|NULL  |[111.0, 187.0, 451.0, 188.0, 199.0]|
+-----+----+------+-----------------------------------+