pandas User-Defined Aggregate Functions
pandas User-Defined Aggregate Functions (pandas UDAFs) are PythonUDFs (with the optional PandasUDFType.GROUPED_AGG function type) that are used as aggregation functions in the GroupedData.agg operator.
pandas UDAFs are also known as Group Aggregate pandas UDFs.
Limitations
- There is no partial aggregation with group aggregate UDFs (i.e., a full shuffle is required).
- All the data of a group is loaded into memory, so there is a risk of an out-of-memory (OOM) error if the data is skewed and some groups are too large to fit in memory.
- Group aggregate pandas UDFs and built-in aggregation functions cannot be mixed in a single GroupedData.agg operator. Otherwise, the following AnalysisException is thrown:
  [INVALID_PANDAS_UDF_PLACEMENT] The group aggregate pandas UDF `my_udaf` cannot be invoked together with as other, non-pandas aggregate functions.
Demo
import pandas as pd
from pyspark.sql.functions import pandas_udf
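# The Series-to-scalar type hints mark this as a group aggregate pandas UDF
@pandas_udf(returnType="long")
def my_count(s: pd.Series) -> int:
    # Return one scalar per group (not a pd.Series)
    return int(s.count())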
from pyspark.sql.functions import abs
nums = spark.range(5) # FIXME More meaningful dataset
grouped_nums = (nums
    .withColumn("gid", abs((nums.id * 100) % 2))
    .groupBy("gid"))
count_by_gid_agg = my_count("gid").alias("count")
counts_by_gid = grouped_nums.agg(count_by_gid_agg)
counts_by_gid.show()