How to handle data skew in Spark

In Spark, data is split into partitions so that it can be processed in parallel. If the split is "unfair", with some partitions holding a huge amount of data while others hold very little, we have data skew. This is undesirable: when such skewed data is processed, some nodes may run out of memory (OOM) while others sit idle, having already finished the little data they were assigned.

What causes data skew?

Uneven Key Distribution in Join Operations

  • Suppose you have two datasets (key-value pairs) you want to join:
    • Dataset A: [(China, "A"), (Fiji, "B"), (Bhutan, "C")]
    • Dataset B: [(China, "X"), (China, "Y"), (China, "Z"), (Fiji, "P"), (Bhutan, "Q")]
  • When joining on the first column (key), the resulting dataset would look like:
    • Joined Data: [(China, "A", "X"), (China, "A", "Y"), (China, "A", "Z"), (Fiji, "B", "P"), (Bhutan, "C", "Q")]

Now, if you split (partition) the joined data based on the key (China/Fiji/Bhutan), there will be three partitions:

Partition 1: [(China, "A", "X"), (China, "A", "Y"), (China, "A", "Z")]
Partition 2: [(Fiji, "B", "P")]
Partition 3: [(Bhutan, "C", "Q")]

Partition 1 holds three times as many rows as each of the other two, so the split is unfair. The node with Partition 1 will be overloaded, while the other nodes finish early and sit idle.
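The join above can be reproduced in plain Python to see where the imbalance comes from. This is only a model of the idea, not Spark code; the helper names join_on_key and rows_per_key are made up for this sketch.

```python
from collections import Counter

dataset_a = [("China", "A"), ("Fiji", "B"), ("Bhutan", "C")]
dataset_b = [("China", "X"), ("China", "Y"), ("China", "Z"), ("Fiji", "P"), ("Bhutan", "Q")]


def join_on_key(left, right):
    """Inner join of two lists of (key, value) pairs on the key."""
    right_by_key = {}
    for key, value in right:
        right_by_key.setdefault(key, []).append(value)
    return [
        (key, left_value, right_value)
        for key, left_value in left
        for right_value in right_by_key.get(key, [])
    ]


def rows_per_key(rows):
    """Count how many joined rows each key produced."""
    return Counter(row[0] for row in rows)


joined = join_on_key(dataset_a, dataset_b)
sizes = rows_per_key(joined)
```

Here sizes maps China to 3 rows and Fiji and Bhutan to 1 row each, which is exactly the imbalance that ends up in the partitions.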

Presence of Null Values or a Dominant Key

  • Suppose you have a dataset with some entries having a null key or the same key:
    • Dataset: [(null, "A"), (null, "B"), (1, "C"), (1, "D"), (2, "E")]
  • When partitioning based on the key:
    • Partition 1: [(null, "A"), (null, "B")]
    • Partition 2: [(1, "C"), (1, "D")]
    • Partition 3: [(2, "E")]
  • The partitions for null and 1 have more entries, causing an imbalance.
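One rough way to put a number on this imbalance is to compare the largest partition with the average one. The sketch below does that in plain Python for the dataset above; skew_ratio and group_sizes are hypothetical helpers, and "one partition per key" is a simplification of what a real partitioner does.

```python
from collections import Counter

dataset = [(None, "A"), (None, "B"), (1, "C"), (1, "D"), (2, "E")]


def group_sizes(rows):
    """Number of rows per key, assuming one partition per distinct key."""
    return Counter(key for key, _ in rows)


def skew_ratio(partition_sizes):
    """Largest partition size divided by the mean partition size (1.0 means perfectly even)."""
    sizes = list(partition_sizes)
    return max(sizes) / (sum(sizes) / len(sizes))


sizes = group_sizes(dataset)
ratio = skew_ratio(sizes.values())
```

A ratio close to 1 means the partitions are balanced; the further above 1 it climbs, the longer the slowest task runs compared with the rest.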

Poorly Chosen Partitioning Key

  • Suppose you have a dataset and you partition it using a key that does not distribute data well:
    • Dataset: [(10, "A"), (20, "B"), (30, "C"), (10, "D"), (20, "E")]
  • If you partition by the key (first column) and the data is unevenly distributed:
    • Partition 1: [(10, "A"), (10, "D")]
    • Partition 2: [(20, "B"), (20, "E")]
    • Partition 3: [(30, "C")]
  • Here, Partitions 1 and 2 have more entries than Partition 3, leading to skew.

Faulty GroupBy

Just like join operations, GroupBy operations can also cause data imbalance when some keys have a lot more values than others.

Example:

If you group sales data by product ID and one product has many more sales records than others, the partition for that product will be much bigger and could slow down processing.

Handling Data Skew

Salting (for join operations)

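Add a random number (a "salt") to the keys on the skewed side, and replicate each row of the other side once per salt value, so that the rows of a hot key are spread across several partitions while every row can still find its match.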

Example:

from pyspark.sql.functions import array, col, explode, lit, rand

NUM_SALTS = 5

# df1 has skewed keys: give every row a random salt in [0, NUM_SALTS)
salted_df1 = df1.withColumn("salt", (rand() * NUM_SALTS).cast("int"))

# df2 is the other side: replicate each row once per possible salt value
salted_df2 = df2.withColumn("salt", explode(array([lit(i) for i in range(NUM_SALTS)])))

# Joining on (key, salt) spreads each hot key over up to NUM_SALTS partitions
result = salted_df1.join(salted_df2, ["key", "salt"]).drop("salt")
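To check that salting changes how the rows are distributed but not the join result, here is a small plain-Python model of the same idea. It is a sketch, not Spark code: the function names, the sample data, and the fixed random seed are all assumptions made for illustration.

```python
import random
from collections import Counter

NUM_SALTS = 5


def salt_skewed_side(rows, num_salts, rng):
    """Attach a random salt in [0, num_salts) to the key of every (key, value) row."""
    return [((key, rng.randrange(num_salts)), value) for key, value in rows]


def replicate_other_side(rows, num_salts):
    """Copy every (key, value) row once per salt so each salted key still has a match."""
    return [((key, salt), value) for key, value in rows for salt in range(num_salts)]


def join(left, right):
    """Inner join on the (key, salt) pair, returning (key, left_value, right_value)."""
    right_by_key = {}
    for salted_key, value in right:
        right_by_key.setdefault(salted_key, []).append(value)
    return [
        (salted_key[0], left_value, right_value)
        for salted_key, left_value in left
        for right_value in right_by_key.get(salted_key, [])
    ]


rng = random.Random(0)  # fixed seed so the sketch is repeatable
skewed = [("China", i) for i in range(1000)] + [("Fiji", 0), ("Bhutan", 0)]
other = [("China", "X"), ("Fiji", "P"), ("Bhutan", "Q")]

salted_left = salt_skewed_side(skewed, NUM_SALTS, rng)
salted_right = replicate_other_side(other, NUM_SALTS)
joined = join(salted_left, salted_right)

# How many rows each (key, salt) group holds on the skewed side
rows_per_salted_key = Counter(salted_key for salted_key, _ in salted_left)
china_groups = {k: n for k, n in rows_per_salted_key.items() if k[0] == "China"}
```

The 1000 China rows end up in several (China, salt) groups instead of one, and the joined output is the same as an unsalted join would give. The price is that the other side grows NUM_SALTS times larger.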

Custom Partitioning

Repartition data based on a more evenly distributed column.

Example:

from pyspark.sql.functions import col

# Assume 'id' is more evenly distributed than 'skewed_column'
evenly_distributed_df = df.repartition(col("id"))
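The effect of the partitioning column can be illustrated with a plain-Python model in which a row goes to partition "value modulo number of partitions". This is a stand-in for Spark's real hash partitioner, and the data (100 rows, 90 of which share the value 0 in the skewed column) is invented for the sketch.

```python
def partition_sizes(rows, key_index, num_partitions):
    """Count rows per partition when partitioning by rows[i][key_index] % num_partitions."""
    sizes = [0] * num_partitions
    for row in rows:
        sizes[row[key_index] % num_partitions] += 1
    return sizes


# Each row is (id, skewed_column): ids are unique, but 90 rows share skewed_column == 0
rows = [(i, 0 if i < 90 else i) for i in range(100)]

sizes_by_skewed_column = partition_sizes(rows, key_index=1, num_partitions=4)
sizes_by_id = partition_sizes(rows, key_index=0, num_partitions=4)
```

Partitioning by the skewed column gives sizes [92, 2, 3, 3], while partitioning by id gives [25, 25, 25, 25].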

Broadcast Join

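Use this when one DataFrame is small enough to fit in memory on every executor. The small DataFrame is copied to each executor, so the large DataFrame is joined where it already sits and never has to be shuffled by the skewed key.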

Example:

from pyspark.sql.functions import broadcast

result = large_df.join(broadcast(small_df), "key")
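Why this sidesteps skew is easier to see in a plain-Python model: every partition of the large side looks its keys up in a local copy of the small side, so no rows of the large side move between partitions. The function name and sample data below are assumptions for illustration, not Spark internals.

```python
def broadcast_join(large_partitions, small_rows):
    """Inner-join each partition of the large side against an in-memory lookup of the small side."""
    lookup = {}
    for key, value in small_rows:
        lookup.setdefault(key, []).append(value)
    # Partitions are processed independently and keep their original sizes and layout
    return [
        [(key, large_value, small_value)
         for key, large_value in partition
         for small_value in lookup.get(key, [])]
        for partition in large_partitions
    ]


large_partitions = [
    [("China", 1), ("Fiji", 2)],
    [("China", 3), ("China", 4)],
    [("Bhutan", 5), ("Nepal", 6)],
]
small_rows = [("China", "X"), ("Fiji", "P"), ("Bhutan", "Q")]

joined_partitions = broadcast_join(large_partitions, small_rows)
```

The China rows stay spread over the first two partitions instead of being gathered into one. The Nepal row has no match in the small side and is dropped, as in any inner join.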

Separate out the skewed keys

Handle the most skewed keys separately.

Example:

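from pyspark.sql.functions import col

# Identify the most frequent keys (top 5 here; tune the cutoff for your data)
key_counts = df.groupBy("key").count().orderBy(col("count").desc())
# collect() returns Row objects, so extract the plain key values for isin()
skewed_keys = [row["key"] for row in key_counts.limit(5).collect()]

# Separate skewed and non-skewed data
skewed_data = df.filter(col("key").isin(skewed_keys))
non_skewed_data = df.filter(~col("key").isin(skewed_keys))

# Process separately and union results
# (process_skewed_data and process_non_skewed_data are placeholders for your own logic)
skewed_result = process_skewed_data(skewed_data)
non_skewed_result = process_non_skewed_data(non_skewed_data)
final_result = skewed_result.union(non_skewed_result)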

Using AQE (Adaptive Query Execution)

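Enable AQE in Spark 3.0+ to dynamically optimize query plans at runtime. Among other things, it can split oversized partitions in skewed joins. AQE is on by default from Spark 3.2.

Example:

spark.conf.set("spark.sql.adaptive.enabled", "true")
# Skew-join handling is a separate switch (on by default once AQE is enabled)
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")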
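As far as the Spark configuration docs describe it, AQE's skew-join handling treats a partition as skewed when it is both larger than a factor times the median partition size (spark.sql.adaptive.skewJoin.skewedPartitionFactor) and larger than an absolute threshold (spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes). The plain-Python sketch below models that rule only. The defaults of 5 and 256 MB are the documented Spark 3.x values as I understand them, so check them against your version; the function name and partition sizes are made up.

```python
from statistics import median

MB = 1024 * 1024


def skewed_partitions(sizes_in_bytes, factor=5, threshold_bytes=256 * MB):
    """Indices of partitions larger than both factor * median size and threshold_bytes."""
    mid = median(sizes_in_bytes)
    return [
        index
        for index, size in enumerate(sizes_in_bytes)
        if size > factor * mid and size > threshold_bytes
    ]


# Hypothetical shuffle partition sizes
sizes = [64 * MB, 80 * MB, 72 * MB, 900 * MB, 300 * MB]
flagged = skewed_partitions(sizes)
```

With a median of 80 MB, only the 900 MB partition is flagged. The 300 MB partition clears the absolute threshold but is not more than five times the median.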

Avoid groupByKey for Large Datasets

When one key has a much larger presence in the dataset, try to avoid gathering all of that key's raw values in one place. With the RDD API, groupByKey sends every value across the network before anything is combined. Instead, use alternatives like reduceByKey, which first combines data locally on each partition and only then shuffles the partial results, making it more efficient.

Example with RDDs:

Instead of this:

# groupByKey shuffles every (key, value) pair, then sums
pairs = df.rdd.map(lambda row: (row["key"], row["value"]))
grouped_data = pairs.groupByKey().mapValues(sum)

Use this:

from operator import add

# reduceByKey sums within each partition first, then shuffles only the partial sums
pairs = df.rdd.map(lambda row: (row["key"], row["value"]))
reduced_data = pairs.reduceByKey(add)

With DataFrames no rewrite is needed: df.groupBy("key").agg(F.sum("value")) already aggregates locally on each partition before the final grouping for built-in aggregates such as sum.
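The saving from combining locally first can be counted in a plain-Python model. The sketch below is not Spark code: it only counts how many records would cross the shuffle under each strategy, and the function names and the two sample partitions are invented for illustration.

```python
from collections import defaultdict


def shuffled_records_group_by_key(partitions):
    """groupByKey style: every (key, value) record is sent across the shuffle."""
    return sum(len(partition) for partition in partitions)


def shuffled_records_reduce_by_key(partitions):
    """reduceByKey style: each partition sends one partial result per distinct key it holds."""
    return sum(len({key for key, _ in partition}) for partition in partitions)


def reduce_by_key(partitions):
    """Sum values per key inside each partition first, then merge the partial sums."""
    totals = defaultdict(int)
    for partition in partitions:
        local = defaultdict(int)
        for key, value in partition:
            local[key] += value
        for key, partial_sum in local.items():
            totals[key] += partial_sum
    return dict(totals)


partitions = [
    [("a", 1), ("a", 2), ("a", 3), ("b", 4)],
    [("a", 5), ("a", 6), ("b", 7), ("b", 8)],
]

records_grouped = shuffled_records_group_by_key(partitions)
records_reduced = shuffled_records_reduce_by_key(partitions)
totals = reduce_by_key(partitions)
```

Here 8 records would be shuffled without local combining and only 4 with it. The gap widens as a hot key contributes more rows per partition.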

Check your knowledge

Question: Which of the following operations is most likely to induce a skew in the size of your data’s partitions?

A. DataFrame.collect()
B. DataFrame.cache()
C. DataFrame.repartition(n)
D. DataFrame.coalesce(n)
E. DataFrame.persist()

Answer: D. DataFrame.coalesce(n)

Explanation:

  • DataFrame.collect(): This operation brings all the data to the driver node, which doesn’t affect the partition sizes.
  • DataFrame.cache(): This operation stores the DataFrame in memory but doesn’t change the partition sizes.
  • DataFrame.repartition(n): This operation reshuffles the data into a specified number of partitions with a relatively even distribution.
  • DataFrame.coalesce(n): This operation reduces the number of partitions by merging them without a shuffle, which can lead to uneven partition sizes if the original data was not evenly distributed.
  • DataFrame.persist(): Similar to cache(), this stores the DataFrame in memory/disk but doesn’t change partition sizes.

Therefore, coalesce(n) is the operation that is most likely to cause skew in the size of your data’s partitions.
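The difference between the two operations can be modelled in plain Python. This is a deliberately simplified sketch: the real coalesce also considers data locality when choosing which partitions to merge, and the function names and partition sizes here are assumptions for illustration.

```python
def coalesce(partitions, n):
    """Merge whole partitions into n contiguous groups; individual rows are never redistributed."""
    merged = [[] for _ in range(n)]
    for index, partition in enumerate(partitions):
        merged[index * n // len(partitions)].extend(partition)
    return merged


def repartition(partitions, n):
    """Redistribute individual rows round-robin across n partitions (a full shuffle)."""
    merged = [[] for _ in range(n)]
    rows = [row for partition in partitions for row in partition]
    for index, row in enumerate(rows):
        merged[index % n].append(row)
    return merged


# One large partition and three small ones
partitions = [list(range(100)), list(range(10)), list(range(10)), list(range(10))]

coalesced_sizes = [len(p) for p in coalesce(partitions, 2)]
repartitioned_sizes = [len(p) for p in repartition(partitions, 2)]
```

Merging whole partitions gives sizes [110, 20], because the large partition is carried over intact. Reshuffling the rows gives [65, 65].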