pandas User-Defined Aggregate Functions
pandas User-Defined Aggregate Functions (pandas UDAFs) are PythonUDFs (with optional PandasUDFType.GROUPED_AGG function type) that are used as aggregation functions in the GroupedData.agg operator.
pandas UDAFs are also known as Group Aggregate pandas UDFs.
- There is no partial aggregation with group aggregate UDFs (i.e., a full shuffle is required).
- All the data of a group is loaded into memory, so there is a risk of out-of-memory (OOM) errors if the data is skewed and some groups are too large to fit in memory.
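The two limitations above can be illustrated with a small pure-Python simulation. This is only a conceptual sketch, not how Spark implements aggregation: the `partitions` data and both helper functions are hypothetical names introduced for illustration. It contrasts a count that merges small per-partition partial results with a group-level function that needs every value of a group collected in one place first.

```python
from collections import defaultdict

# Hypothetical input: two "partitions" of (group_id, value) rows.
partitions = [
    [(0, 10), (1, 20), (0, 30)],
    [(1, 40), (0, 50)],
]


def count_with_partial_aggregation(partitions):
    """Pre-aggregate inside each partition, then merge the partial counts."""
    partials = []
    for rows in partitions:
        partial = defaultdict(int)
        for gid, _ in rows:
            partial[gid] += 1
        # Only one counter per group leaves the partition.
        partials.append(dict(partial))

    merged = defaultdict(int)
    for partial in partials:
        for gid, n in partial.items():
            merged[gid] += n
    return dict(merged)


def count_with_group_udf(partitions, udf):
    """Move every row to its group first, then call udf once per complete group."""
    groups = defaultdict(list)
    for rows in partitions:
        for gid, value in rows:
            # All values of a group are held in memory at the same time.
            groups[gid].append(value)
    return {gid: udf(values) for gid, values in groups.items()}


print(count_with_partial_aggregation(partitions))  # {0: 3, 1: 2}
print(count_with_group_udf(partitions, len))       # {0: 3, 1: 2}
```

Both functions produce the same counts, but the second one has to materialize each group in full before the function runs, which is where the memory pressure for skewed groups comes from.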
Group aggregate pandas UDFs and built-in aggregation functions cannot be mixed in a single GroupedData.agg operator. Otherwise, the following error is raised:
[INVALID_PANDAS_UDF_PLACEMENT] The group aggregate pandas UDF `my_udaf` cannot be invoked together with as other, non-pandas aggregate functions.
import pandas as pd
from pyspark.sql.functions import pandas_udf
@pandas_udf(returnType="long")
def my_count(s: pd.Series) -> int:
    # A group aggregate pandas UDF reduces a whole group (a pandas Series) to a single scalar
    return s.count()
from pyspark.sql.functions import abs

nums = spark.range(5) # FIXME More meaningful dataset
grouped_nums = (nums
    .withColumn("gid", abs((nums.id * 100) % 2))
    .groupBy("gid"))
count_by_gid_agg = my_count("gid").alias("count")
counts_by_gid = grouped_nums.agg(count_by_gid_agg)