
pandas User-Defined Aggregate Functions

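pandas User-Defined Aggregate Functions (pandas UDAFs) are PythonUDFs (with the optional PandasUDFType.GROUPED_AGG function type) used as aggregate functions in the GroupedData.agg operator. The function type is optional because, since Spark 3.0, it can be inferred from Python type hints: a function that takes one or more pd.Series and returns a scalar is treated as a group aggregate.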

pandas UDAFs are also known as Group Aggregate pandas UDFs.
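A group aggregate pandas UDF is, at its core, a plain Python function that receives all the values of one group as a pd.Series (one Series per input column) and returns a single scalar. The minimal sketch below keeps the function bodies undecorated so they can be exercised locally with pandas alone. The names `count_non_null` and `weighted_mean` are hypothetical, chosen for illustration only.

```python
import pandas as pd


# Hypothetical aggregate body with one input column: Spark would pass
# all the values of one group as a single pd.Series.
def count_non_null(s: pd.Series) -> int:
    # Series.count() skips missing values (NaN/None).
    return int(s.count())


# Hypothetical aggregate body with two input columns: each column of the
# same group arrives as its own pd.Series, and the Series have equal length.
def weighted_mean(v: pd.Series, w: pd.Series) -> float:
    return float((v * w).sum() / w.sum())


# Simulate the values of a single group.
print(count_non_null(pd.Series([1.0, 2.0, None, 4.0])))  # 3
print(weighted_mean(pd.Series([1.0, 2.0, 4.0]), pd.Series([1.0, 1.0, 2.0])))  # 2.75
```

In a Spark session, `pandas_udf(count_non_null, "long")` would wrap the first function as a group aggregate pandas UDF usable in `GroupedData.agg`; that wrapping step is not exercised in this sketch.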

Limitations

  1. There is no partial aggregation with group aggregate UDFs (i.e., a full shuffle is required).
  2. All the data of a group is loaded into memory at once, so there is a risk of out-of-memory (OOM) errors if the data is skewed and certain groups are too large to fit in memory.
  3. Group aggregate pandas UDFs and built-in aggregate functions cannot be mixed in a single GroupedData.agg operator. Doing so throws the following AnalysisException:

    [INVALID_PANDAS_UDF_PLACEMENT] The group aggregate pandas UDF `my_udaf` cannot be invoked together with as other, non-pandas aggregate functions.
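The first two limitations follow from how data reaches the function. A built-in aggregate such as count can run in two phases: each partition produces a small partial result, and only those partials are shuffled and merged. A group aggregate pandas UDF has no merge step, so every row of a group must be shuffled to one place and held in memory at once. The stdlib-only simulation below contrasts the two data flows; the partition layout and the helper names `two_phase_count` and `full_group_count` are hypothetical, and it models only the data movement, not Spark's actual execution.

```python
from collections import defaultdict

# Hypothetical input: three partitions of (group_id, value) rows.
partitions = [
    [(0, 1), (0, 2), (0, 3), (1, 4)],
    [(0, 5), (0, 6), (1, 7)],
    [(1, 8), (1, 9), (0, 10)],
]


def two_phase_count(parts):
    """Partial aggregation: each partition emits one count per group,
    and only those partial counts cross the (simulated) shuffle."""
    partials = []  # what would be shuffled
    for part in parts:
        local = defaultdict(int)
        for gid, _ in part:
            local[gid] += 1
        partials.extend(local.items())
    merged = defaultdict(int)
    for gid, partial_count in partials:
        merged[gid] += partial_count
    return dict(merged), len(partials)


def full_group_count(parts, agg):
    """No partial aggregation: every row is shuffled, and each group is
    materialized in full before the aggregate function sees it."""
    shuffled_rows = 0
    groups = defaultdict(list)
    for part in parts:
        for gid, value in part:
            groups[gid].append(value)
            shuffled_rows += 1
    return {gid: agg(values) for gid, values in groups.items()}, shuffled_rows


print(two_phase_count(partitions))        # ({0: 6, 1: 4}, 6)
print(full_group_count(partitions, len))  # ({0: 6, 1: 4}, 10)
```

Both strategies agree on the result, but the second one moves every row and keeps each group's full list of values in memory, which is where skewed, oversized groups become a problem.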
    

Demo

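import pandas as pd
from pyspark.sql.functions import pandas_udf

# Series -> scalar type hints make this a group aggregate pandas UDF
@pandas_udf("long")
def my_count(s: pd.Series) -> int:
    # Return a single scalar per group (not a pd.Series)
    return int(s.count())

# Assumes an active SparkSession bound to `spark` (e.g., the pyspark shell)
nums = spark.range(5)
# Split the ids 0..4 into two groups: even (gid 0) and odd (gid 1)
grouped_nums = (nums
    .withColumn("gid", nums.id % 2)
    .groupBy("gid"))
count_by_gid_agg = my_count("id").alias("count")
counts_by_gid = grouped_nums.agg(count_by_gid_agg)
# Expect count 3 for gid 0 (ids 0, 2, 4) and 2 for gid 1 (ids 1, 3); row order may vary
counts_by_gid.show()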
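The same kind of per-group count can be rehearsed locally with plain pandas, without a Spark session. This is only a sketch under stated assumptions: a pandas DataFrame with an `id` column 0 to 4 stands in for `spark.range(5)`, and `id % 2` is used as a hypothetical grouping key so that the five ids split into two groups. `SeriesGroupBy.agg` calls the function once per group with that group's values as a pd.Series, which mirrors how each group is handed to a group aggregate pandas UDF.

```python
import pandas as pd

# Local stand-in for spark.range(5): a single "id" column with 0..4.
nums = pd.DataFrame({"id": range(5)})

# Hypothetical grouping key: even ids go to gid 0, odd ids to gid 1.
nums["gid"] = nums["id"] % 2


# Scalar-returning count, the shape a group aggregate function must have.
def my_count(s: pd.Series) -> int:
    return int(s.count())


# agg(...) invokes my_count once per group with that group's "id" values.
counts_by_gid = nums.groupby("gid")["id"].agg(my_count).rename("count")
print(counts_by_gid)
```

Unlike `show()` in Spark, where row order is not guaranteed, pandas sorts the group keys by default, so gid 0 is listed before gid 1.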