
The Complete PySpark Gotchas Guide

About This Guide

This comprehensive guide covers the most expensive mistakes in PySpark development, with practical solutions and performance optimizations. Each gotcha includes real-world examples and measurable improvements.

Overview

PySpark offers incredible power for big data processing, but with great power comes great responsibility. This guide helps you avoid the most common and costly mistakes that can turn your lightning-fast distributed processing into a crawling disaster.

Performance Impact

The gotchas in this guide can cause:

  • 10-100x slower job execution
  • OutOfMemoryError crashes
  • Wasted cluster resources costing thousands of dollars
  • Failed production jobs affecting business operations

Categories

  • Small files problem
  • Schema inference overhead
  • Wrong file formats
  • Over/under-caching
  • Wrong storage levels
  • Memory leaks
  • Data skew handling
  • Broadcasting decisions
  • Window function optimization
  • Default settings trap
  • Resource allocation
  • Dynamic scaling
  • UDF optimization
  • Streaming pitfalls
  • Monitoring blind spots

1. Data Loading & I/O Gotchas

Gotcha #1: The Small Files Performance Killer

Performance Impact: 10-50x slower

Reading thousands of small files creates excessive task overhead and kills performance.

The Problem:

# BAD: Reading 10,000 small JSON files
df = spark.read.json("s3://bucket/small-files/*.json")
# Creates 10,000 tasks with massive overhead

Why This Happens
  • Each file becomes a separate task
  • Task scheduling overhead dominates actual processing
  • Executors spend more time on coordination than computation
  • Network latency multiplied by number of files

The Solution:

# GOOD: Coalesce after reading
df = spark.read.json("s3://bucket/small-files/*.json").coalesce(100)
# BETTER: Use wholeTextFiles for very small files
rdd = spark.sparkContext.wholeTextFiles("s3://bucket/small-files/*.json")
df = spark.read.json(rdd.values())
# BEST: Combine files during ingestion
def optimize_small_files(input_path, output_path, target_size_mb=128):
    df = spark.read.json(input_path)

    # Calculate optimal partitions (estimate_dataframe_size_mb: a sample-based size helper; see Gotcha #10 for one version)
    total_size_mb = estimate_dataframe_size_mb(df)
    optimal_partitions = max(1, int(total_size_mb // target_size_mb))  # coalesce() needs an int

    df.coalesce(optimal_partitions) \
      .write \
      .mode("overwrite") \
      .parquet(output_path)

Performance Gain

Before: 10,000 tasks, 45 minutes
After: 100 tasks, 3 minutes
Improvement: 15x faster ⚡
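
Spark's file-based readers can also pack many small files into a single read partition on their own; the packing is governed by two standard configs worth checking before resorting to manual compaction. A minimal sketch (the values are common starting points, not tuned recommendations):

# Pack small files into fewer, larger read partitions
spark.conf.set("spark.sql.files.maxPartitionBytes", str(128 * 1024 * 1024))  # target ~128MB per read partition
spark.conf.set("spark.sql.files.openCostInBytes", str(4 * 1024 * 1024))      # cost charged per file open when packing

df = spark.read.json("s3://bucket/small-files/*.json")
print(f"Read partitions after packing: {df.rdd.getNumPartitions()}")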


Gotcha #2: Schema Inference Double-Read Penalty

Performance Impact: 2x slower, 2x I/O cost

Schema inference requires reading the entire dataset twice.

The Problem:

# BAD: Schema inference on large datasets
df = spark.read.csv("huge_dataset.csv", header=True, inferSchema=True)
# Spark reads the entire file twice!

What Happens Under the Hood
  1. First read: Spark scans entire dataset to infer schema
  2. Second read: Spark reads dataset again with inferred schema
  3. Cost: Double I/O, double time, double cloud storage charges

The Solution:

from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType

# GOOD: Define schema upfront
schema = StructType([
    StructField("user_id", StringType(), True),
    StructField("event_time", TimestampType(), True),
    StructField("event_count", IntegerType(), True)
])

df = spark.read.csv("huge_dataset.csv", header=True, schema=schema)
def generate_schema_from_sample(file_path, sample_size=1000):
    """Generate schema from a small sample"""
    sample_df = spark.read.csv(file_path, header=True, inferSchema=True).limit(sample_size)

    print("Generated Schema:")
    print("schema = StructType([")
    for field in sample_df.schema.fields:
        print(f'    StructField("{field.name}", {field.dataType}, {field.nullable}),')
    print("])")

    return sample_df.schema

# Use for large datasets
schema = generate_schema_from_sample("huge_dataset.csv")
df = spark.read.csv("huge_dataset.csv", header=True, schema=schema)
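
If you do rely on inference once, you can also persist the resulting schema and reload it on later runs instead of re-inferring; `df.schema.json()` and `StructType.fromJson` are standard PySpark APIs, and the schema file path here is just an example:

import json
from pyspark.sql.types import StructType

# Save the schema once (e.g. after a one-off inference or sampling run)
with open("huge_dataset.schema.json", "w") as f:
    f.write(df.schema.json())

# Reuse it on subsequent runs -- no inference pass needed
with open("huge_dataset.schema.json") as f:
    saved_schema = StructType.fromJson(json.load(f))

df = spark.read.csv("huge_dataset.csv", header=True, schema=saved_schema)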

Performance Gain

Before: 2 full dataset scans
After: 1 dataset scan
Improvement: 50% faster, 50% less I/O cost 💰


Gotcha #3: Suboptimal File Format Choices

Performance Impact: 5-20x slower queries

Using row-based formats (CSV/JSON) for analytical workloads instead of columnar formats.

The Problem:

# BAD: Large analytical datasets in row-based formats
df = spark.read.csv("10TB_dataset.csv")  
# No compression, no predicate pushdown, reads entire rows

Format Comparison
Format    Compression  Predicate Pushdown  Schema Evolution  ACID
CSV       ❌           ❌                  ❌                ❌
JSON      ❌           ❌                  ❌                ❌
Parquet   ✅           ✅                  ✅                ❌
Delta     ✅           ✅                  ✅                ✅

The Solution:

# GOOD: Columnar format with compression
df = spark.read.parquet("10TB_dataset.parquet")

# Benefits:
# - 70-90% compression
# - Column pruning
# - Predicate pushdown
# - Fast aggregations
# BETTER: Delta Lake for ACID transactions
df = spark.read.format("delta").load("delta-table/")

# Additional benefits:
# - Time travel
# - ACID transactions
# - Schema enforcement
# - Automatic optimization
def migrate_to_optimized_format(source_path, target_path, format_type="delta"):
    """Migrate data to optimized format with partitioning"""

    # Read source data
    df = spark.read.csv(source_path, header=True, inferSchema=True)

    # Optimize partitioning
    if "date" in df.columns:
        df = df.withColumn("year", year(col("date"))) \
               .withColumn("month", month(col("date")))
        partition_cols = ["year", "month"]
    else:
        partition_cols = None

    # Write in optimized format
    writer = df.write.mode("overwrite")

    if partition_cols:
        writer = writer.partitionBy(*partition_cols)

    if format_type == "delta":
        writer.format("delta").save(target_path)
    else:
        writer.parquet(target_path)

    print(f"Migration complete. Data saved to {target_path}")

Storage & Performance Comparison

Storage Size: CSV 1TB → Parquet 200GB → Delta 180GB
Query Speed: CSV baseline → Parquet 10x faster → Delta 12x faster
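
Column pruning and predicate pushdown are easy to confirm in the physical plan; a quick check, assuming a Parquet dataset at an illustrative path:

from pyspark.sql.functions import col

sales = spark.read.parquet("sales.parquet")  # example path
pruned = sales.select("region", "amount").filter(col("amount") > 100)

# The Parquet scan node should show a narrowed ReadSchema and non-empty PushedFilters
pruned.explain()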


2. Partitioning Nightmares

Gotcha #4: The Goldilocks Partition Problem

Performance Impact: 5-100x slower

Partitions that are too small create overhead; too large cause memory issues.

The Problem:

# BAD: Ignoring partition sizes
df = spark.read.parquet("data/")
print(f"Partitions: {df.rdd.getNumPartitions()}")  # Could be 1 or 10,000!

Partition Size Impact

Too Small (< 10MB): - High task scheduling overhead - Underutilized executors - Inefficient network usage

Too Large (> 1GB): - Memory pressure - GC overhead - Risk of OOM errors

Just Right (100-200MB): - Optimal resource utilization - Balanced parallelism - Efficient processing

The Solution:

def analyze_partitions(df, df_name="DataFrame"):
    """Comprehensive partition analysis"""
    print(f"\n=== {df_name} Partition Analysis ===")

    num_partitions = df.rdd.getNumPartitions()
    print(f"Number of partitions: {num_partitions}")

    # Sample partition sizes
    partition_counts = df.rdd.mapPartitions(lambda x: [sum(1 for _ in x)]).collect()

    if partition_counts:
        min_size = min(partition_counts)
        max_size = max(partition_counts)
        avg_size = sum(partition_counts) / len(partition_counts)

        print(f"Partition sizes:")
        print(f"  Min: {min_size:,} rows")
        print(f"  Max: {max_size:,} rows") 
        print(f"  Avg: {avg_size:,.0f} rows")

        # Skew detection
        skew_ratio = max_size / avg_size if avg_size > 0 else 0
        if skew_ratio > 3:
            print(f"⚠️  WARNING: High skew detected! Max is {skew_ratio:.1f}x larger than average")

        # Size recommendation
        estimated_mb_per_partition = avg_size * 0.001  # Rough estimate assuming ~1 KB per row
        print(f"Estimated avg partition size: ~{estimated_mb_per_partition:.1f} MB")

        if estimated_mb_per_partition < 50:
            print("💡 Consider reducing partitions (coalesce)")
        elif estimated_mb_per_partition > 300:
            print("💡 Consider increasing partitions (repartition)")
        else:
            print("✅ Partition sizes look good!")

    return df

# Usage
df = analyze_partitions(df, "Raw Data")
def optimize_partitions(df, target_partition_size_mb=128):
    """Calculate and apply optimal partitioning"""

    # Estimate total size
    sample_count = df.sample(0.01).count()
    if sample_count == 0:
        return df.coalesce(1)

    total_count = df.count()
    sample_data = df.sample(0.01).take(min(100, sample_count))

    if sample_data:
        # Estimate row size (rough approximation)
        avg_row_size_bytes = sum(len(str(row)) for row in sample_data) / len(sample_data) * 2
        total_size_mb = (total_count * avg_row_size_bytes) / (1024 * 1024)

        optimal_partitions = max(1, int(total_size_mb / target_partition_size_mb))
        optimal_partitions = min(optimal_partitions, 4000)  # Cap at reasonable max

        print(f"Estimated dataset size: {total_size_mb:.1f} MB")
        print(f"Target partition size: {target_partition_size_mb} MB")
        print(f"Optimal partitions: {optimal_partitions}")

        current_partitions = df.rdd.getNumPartitions()

        if optimal_partitions < current_partitions:
            print("Applying coalesce...")
            return df.coalesce(optimal_partitions)
        elif optimal_partitions > current_partitions * 1.5:
            print("Applying repartition...")
            return df.repartition(optimal_partitions)
        else:
            print("Current partitioning is acceptable")
            return df

    return df

# Apply optimization
df_optimized = optimize_partitions(df)
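
On Spark 3.x, Adaptive Query Execution can also right-size shuffle partitions at runtime, which removes much of this guesswork for shuffle-heavy stages. A minimal configuration sketch (the advisory size is illustrative):

# Let AQE merge small shuffle partitions after each stage (Spark 3.x)
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", str(128 * 1024 * 1024))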

Optimization Results

Before: 10,000 partitions (1MB each)
After: 100 partitions (100MB each)
Improvement: 20x fewer tasks, 80% less overhead


Gotcha #5: High-Cardinality Partitioning Disaster

Performance Impact: Creates millions of tiny files

Partitioning by high-cardinality columns creates too many small partitions.

The Problem:

# BAD: Partitioning by high-cardinality column
df.write.partitionBy("user_id").parquet("output/")  
# Creates millions of tiny partitions (one per user)

Small Files Problem

Impact of millions of small files: - Metadata overhead in storage systems - Slow listing operations - Inefficient subsequent reads - Increased storage costs (minimum block sizes)

The Solution:

# GOOD: Partition by low-cardinality columns
df.write.partitionBy("year", "month").parquet("output/")

# For time-based partitioning
from pyspark.sql.functions import date_format

df_with_partition = df.withColumn(
    "year_month", 
    date_format(col("timestamp"), "yyyy-MM")
)
df_with_partition.write.partitionBy("year_month").parquet("output/")
# BETTER: Use hash-based partitioning for high-cardinality
from pyspark.sql.functions import hash, col, expr

# Create buckets for high-cardinality column
num_buckets = 100  # Adjust based on data size

df_bucketed = df.withColumn(
    "user_bucket",
    expr(f"pmod(hash(user_id), {num_buckets})")  # pmod avoids the negative ids that hash() % n can produce
)

df_bucketed.write.partitionBy("user_bucket").parquet("output/")
def analyze_cardinality(df, columns, sample_fraction=0.1):
    """Analyze cardinality of potential partition columns"""

    print("=== Cardinality Analysis ===")
    sample_df = df.sample(sample_fraction)
    total_rows = df.count()
    sample_rows = sample_df.count()

    results = {}

    for column in columns:
        if column in df.columns:
            distinct_count = sample_df.select(column).distinct().count()

            # Estimate total distinct values
            estimated_distinct = distinct_count * (total_rows / sample_rows)

            results[column] = {
                'estimated_distinct': int(estimated_distinct),
                'cardinality_ratio': estimated_distinct / total_rows
            }

            # Partitioning recommendation
            if estimated_distinct < 100:
                recommendation = "✅ Good for partitioning"
            elif estimated_distinct < 1000:
                recommendation = "⚠️  Consider hash bucketing"
            else:
                recommendation = "❌ Too high cardinality"

            print(f"{column}:")
            print(f"  Estimated distinct values: {int(estimated_distinct):,}")
            print(f"  Cardinality ratio: {estimated_distinct/total_rows:.4f}")
            print(f"  Recommendation: {recommendation}")
            print()

    return results

# Usage
partition_analysis = analyze_cardinality(
    df, 
    columns=["user_id", "category", "date", "region"]
)
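
If a high-cardinality partition column has already produced a sea of tiny files, a one-off compaction pass can rewrite the data with a few large files per partition. A sketch assuming Parquet output partitioned by year and month (paths and the row cap are illustrative):

from pyspark.sql.functions import col

def compact_partitioned_output(input_path, output_path, partition_cols):
    """Rewrite an over-partitioned dataset with fewer, larger files per partition."""
    df = spark.read.parquet(input_path)

    (df.repartition(*[col(c) for c in partition_cols])  # co-locate each partition's rows in one task
       .write
       .mode("overwrite")
       .option("maxRecordsPerFile", 5_000_000)          # optional cap on rows per output file
       .partitionBy(*partition_cols)
       .parquet(output_path))

compact_partitioned_output("output/", "output_compacted/", ["year", "month"])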

Partitioning Guidelines

Ideal partition column characteristics:

  • ✅ Low cardinality (< 1000 distinct values)
  • ✅ Evenly distributed data
  • ✅ Frequently used in WHERE clauses
  • ✅ Stable over time

3. Caching & Persistence Pitfalls

Gotcha #6: The Over-Caching Memory Waste

Performance Impact: Memory exhaustion, slower jobs

Caching DataFrames that are used only once wastes precious executor memory.

The Problem:

# BAD: Cache everything approach
df1 = spark.read.parquet("data1.parquet").cache()  # Used once
df2 = spark.read.parquet("data2.parquet").cache()  # Used once  
df3 = spark.read.parquet("data3.parquet").cache()  # Used once

result = df1.join(df2, "key").join(df3, "key")  # Memory wasted!

Memory Impact

Over-caching consequences: - Executor memory exhaustion - Increased GC pressure - Spilling to disk (defeating cache purpose) - Reduced performance for actually reused data

The Solution:

# GOOD: Cache only reused DataFrames
expensive_df = df.groupBy("category").agg(
    count("*").alias("count"),
    avg("price").alias("avg_price"),
    sum("revenue").alias("total_revenue")
)

# This will be reused multiple times
expensive_df.cache()

# Multiple operations using cached data
high_volume = expensive_df.filter(col("count") > 1000)
low_volume = expensive_df.filter(col("count") < 100)
mid_range = expensive_df.filter(
    (col("count") >= 100) & (col("count") <= 1000)
)

# Clean up when done
expensive_df.unpersist()
def should_cache(df, usage_count, computation_cost="medium"):
    """Intelligent caching decision based on usage patterns"""

    cost_weights = {
        "low": 1,      # Simple transformations
        "medium": 3,   # Joins, groupBy
        "high": 10     # Complex aggregations, multiple joins
    }

    weight = cost_weights.get(computation_cost, 3)
    cache_benefit_score = usage_count * weight

    # Memory consideration
    partition_count = df.rdd.getNumPartitions()
    memory_concern = partition_count > 1000  # Many partitions = more memory

    recommendation = {
        "should_cache": cache_benefit_score >= 6 and not memory_concern,
        "score": cache_benefit_score,
        "memory_concern": memory_concern
    }

    return recommendation

# Usage example
expensive_computation = df.groupBy("category", "region").agg(
    countDistinct("user_id"),
    percentile_approx("amount", 0.5),
    collect_list("product_id")
)

cache_decision = should_cache(
    expensive_computation, 
    usage_count=3, 
    computation_cost="high"
)

if cache_decision["should_cache"]:
    expensive_computation.cache()
    print(f"✅ Caching recommended (score: {cache_decision['score']})")
else:
    print(f"❌ Caching not recommended (score: {cache_decision['score']})")

Caching Best Practices

Cache when:

  • DataFrame is used 2+ times ✅
  • Computation is expensive ✅
  • Memory is available ✅

Don't cache when:

  • One-time use ❌
  • Simple transformations ❌
  • Memory is constrained ❌
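
If exploratory work leaves stray cached DataFrames behind, the catalog API offers a blunt but effective cleanup between pipeline stages (a minimal sketch using standard PySpark calls):

# Unpersist one DataFrame if it is still cached...
if expensive_df.is_cached:
    expensive_df.unpersist()

# ...or drop everything cached in this SparkSession
spark.catalog.clearCache()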


Gotcha #7: Wrong Storage Level Choices

Performance Impact: Cache eviction, memory pressure

Using inappropriate storage levels can cause cache thrashing and poor performance.

The Problem:

# BAD: Default MEMORY_ONLY when data doesn't fit
large_df.cache()  # Uses MEMORY_ONLY, causes eviction cascades

Storage Level Comparison
Storage Level        Memory  Disk  Serialized  Replicated
MEMORY_ONLY          ✅      ❌    ❌          ❌
MEMORY_AND_DISK      ✅      ✅    ❌          ❌
MEMORY_ONLY_SER      ✅      ❌    ✅          ❌
MEMORY_AND_DISK_SER  ✅      ✅    ✅          ❌
DISK_ONLY            ❌      ✅    ✅          ❌

The Solution:

from pyspark import StorageLevel

# For large datasets that might not fit in memory
large_df.persist(StorageLevel.MEMORY_AND_DISK_SER)

# For critical data that needs high availability
critical_df.persist(StorageLevel.MEMORY_ONLY_2)  # Replicated

# For infrequently accessed but expensive to compute
archive_df.persist(StorageLevel.DISK_ONLY)

# For iterative algorithms with memory constraints
ml_features.persist(StorageLevel.MEMORY_AND_DISK_SER_2)
def select_storage_level(df, access_pattern="frequent", memory_available_gb=8):
    """Select optimal storage level based on usage pattern"""

    # Estimate DataFrame size (estimate_dataframe_size_gb: an assumed sample-based helper, analogous to the MB estimator in Gotcha #10)
    estimated_size_gb = estimate_dataframe_size_gb(df)

    # Storage level decision matrix
    if access_pattern == "frequent":
        if estimated_size_gb < memory_available_gb * 0.3:
            return StorageLevel.MEMORY_ONLY
        elif estimated_size_gb < memory_available_gb * 0.6:
            return StorageLevel.MEMORY_ONLY_SER
        else:
            return StorageLevel.MEMORY_AND_DISK_SER

    elif access_pattern == "occasional":
        if estimated_size_gb < memory_available_gb * 0.2:
            return StorageLevel.MEMORY_ONLY
        else:
            return StorageLevel.MEMORY_AND_DISK_SER

    elif access_pattern == "rare":
        return StorageLevel.DISK_ONLY

    elif access_pattern == "critical":
        if estimated_size_gb < memory_available_gb * 0.4:
            return StorageLevel.MEMORY_ONLY_2  # Replicated
        else:
            return StorageLevel.MEMORY_AND_DISK_SER_2

    return StorageLevel.MEMORY_AND_DISK_SER  # Safe default

# Usage
storage_level = select_storage_level(
    df=expensive_computation,
    access_pattern="frequent", 
    memory_available_gb=16
)

expensive_computation.persist(storage_level)
print(f"Using storage level: {storage_level}")
def monitor_cache_usage():
    """Monitor cache usage across the cluster"""
    print("=== Cache Usage Report ===")

    storage_infos = spark.sparkContext._jsc.sc().getRDDStorageInfo()

    total_memory_used = 0
    total_disk_used = 0

    for storage_info in storage_infos:
        rdd_id = storage_info.id()
        memory_size_mb = storage_info.memSize() / (1024 * 1024)
        disk_size_mb = storage_info.diskSize() / (1024 * 1024)
        storage_level = storage_info.storageLevel()

        total_memory_used += memory_size_mb
        total_disk_used += disk_size_mb

        print(f"RDD {rdd_id}:")
        print(f"  Memory: {memory_size_mb:.1f} MB")
        print(f"  Disk: {disk_size_mb:.1f} MB")
        print(f"  Storage Level: {storage_level}")
        print()

    print(f"Total Cache Usage:")
    print(f"  Memory: {total_memory_used:.1f} MB")
    print(f"  Disk: {total_disk_used:.1f} MB")

    # Get executor memory info
    executors = spark.sparkContext.statusTracker().getExecutorInfos()
    total_executor_memory = sum(exec.maxMemory for exec in executors)
    cache_memory_ratio = (total_memory_used * 1024 * 1024) / total_executor_memory

    print(f"  Cache Memory Ratio: {cache_memory_ratio:.1%}")

    if cache_memory_ratio > 0.8:
        print("⚠️  WARNING: High cache memory usage!")

# Monitor periodically
monitor_cache_usage()
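
A quick way to confirm which level a DataFrame actually ended up with is its storageLevel property (a one-line check; expensive_computation as defined above):

# Prints the StorageLevel flags (disk, memory, off-heap, deserialized, replication)
print(expensive_computation.storageLevel)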

Storage Level Guidelines

Choose based on: - Data size vs available memory - Access frequency (frequent = memory priority) - Fault tolerance needs (critical = replication) - Cost sensitivity (disk cheaper than memory)


Gotcha #8: Lazy Cache Evaluation Trap

Performance Impact: Cache never populated

Cache is lazy - without triggering an action, the cache remains empty.

The Problem:

# BAD: Cache without triggering action
expensive_df = df.groupBy("category").agg(count("*"))
expensive_df.cache()  # Nothing cached yet!

# Later operations don't benefit from cache
result1 = expensive_df.filter(col("count") > 100).collect()  # Computed
result2 = expensive_df.filter(col("count") < 50).collect()   # Recomputed!

The Solution:

# GOOD: Force cache with action
expensive_df = df.groupBy("category").agg(count("*"))
expensive_df.cache()

# Trigger cache population
cache_trigger = expensive_df.count()  # Forces computation and caching
print(f"Cached {cache_trigger} rows")

# Now subsequent operations use cache
result1 = expensive_df.filter(col("count") > 100).collect()  # Uses cache
result2 = expensive_df.filter(col("count") < 50).collect()   # Uses cache
from contextlib import contextmanager

@contextmanager
def cached_dataframe(df, storage_level=None, trigger_action=True):
    """Context manager for automatic cache management"""

    if storage_level:
        df.persist(storage_level)
    else:
        df.cache()

    try:
        if trigger_action:
            # Trigger cache population
            row_count = df.count()
            print(f"✅ Cached DataFrame with {row_count:,} rows")

        yield df

    finally:
        df.unpersist()
        print("🧹 Cache cleaned up")

# Usage
expensive_computation = df.groupBy("category").agg(
    count("*").alias("count"),
    avg("price").alias("avg_price")
)

with cached_dataframe(expensive_computation) as cached_df:
    # All operations within this block use cache
    high_count = cached_df.filter(col("count") > 1000).collect()
    low_count = cached_df.filter(col("count") < 100).collect()
    stats = cached_df.describe().collect()

# Cache automatically cleaned up here
def verify_cache_usage(df, operation_name="operation"):
    """Verify that cache is actually being used"""

    # Check if DataFrame is cached
    if not df.is_cached:
        print(f"⚠️  WARNING: {operation_name} - DataFrame not cached!")
        return False

    # Get RDD storage info
    rdd_id = df.rdd.id()
    storage_infos = spark.sparkContext._jsc.sc().getRDDStorageInfo()

    for storage_info in storage_infos:
        if storage_info.id() == rdd_id:
            memory_size = storage_info.memSize()
            disk_size = storage_info.diskSize()

            if memory_size > 0 or disk_size > 0:
                print(f"✅ {operation_name} - Cache verified: "
                      f"{memory_size/(1024**2):.1f}MB memory, "
                      f"{disk_size/(1024**2):.1f}MB disk")
                return True
            else:
                print(f"⚠️  {operation_name} - Cache empty, triggering population...")
                df.count()  # Trigger cache
                return True

    print(f"❌ {operation_name} - Cache not found!")
    return False

# Usage
expensive_df.cache()
verify_cache_usage(expensive_df, "Expensive Computation")

Cache Checklist

Before relying on cache:

  • [ ] DataFrame is marked as cached (.cache() or .persist())
  • [ ] Action has been triggered (.count(), .collect(), etc.)
  • [ ] Verify cache population with monitoring
  • [ ] Plan cache cleanup (.unpersist())

4. Join Operation Hell

Gotcha #9: Data Skew - The Silent Performance Killer

Performance Impact: Some tasks take 100x longer

Uneven key distribution causes massive partitions while others remain tiny, creating severe bottlenecks.

The Problem:

# BAD: Join with severely skewed data
user_events = spark.read.table("user_events")    # Some users: millions of events
user_profiles = spark.read.table("user_profiles") # Even distribution

# Hot keys create massive partitions
result = user_events.join(user_profiles, "user_id")
# 99% of tasks finish in 30 seconds, 1% take 2 hours!

Why Skew Kills Performance

The anatomy of skew: - 95% of partitions: 1,000 records each (finish quickly) - 5% of partitions: 1,000,000 records each (become stragglers) - Result: Entire job waits for slowest partition

Real-world impact: - Job that should take 10 minutes takes 3 hours - Cluster sits 95% idle waiting for stragglers - Potential executor OOM on large partitions

The Solution:

from pyspark.sql.functions import col, min, max, avg, expr

def detect_join_skew(df, join_column, sample_fraction=0.1, skew_threshold=1000):
    """Detect data skew in join keys"""

    print(f"=== Skew Analysis for '{join_column}' ===")

    # Sample for performance on large datasets
    sample_df = df.sample(sample_fraction)

    # Get key distribution
    key_counts = sample_df.groupBy(join_column).count() \
                          .orderBy(col("count").desc())

    stats = key_counts.agg(
        min("count").alias("min_count"),
        max("count").alias("max_count"), 
        avg("count").alias("avg_count"),
        expr("percentile_approx(count, 0.95)").alias("p95_count")
    ).collect()[0]

    # Scale up from sample
    scale_factor = 1 / sample_fraction
    scaled_max = stats["max_count"] * scale_factor
    scaled_avg = stats["avg_count"] * scale_factor

    skew_ratio = scaled_max / scaled_avg if scaled_avg > 0 else 0

    print(f"Key distribution (scaled from {sample_fraction*100}% sample):")
    print(f"  Average count per key: {scaled_avg:,.0f}")
    print(f"  Maximum count per key: {scaled_max:,.0f}")
    print(f"  Skew ratio (max/avg): {skew_ratio:.1f}")

    # Show top skewed keys
    print(f"\nTop 10 most frequent keys:")
    top_keys = key_counts.limit(10).collect()
    for row in top_keys:
        scaled_count = row["count"] * scale_factor
        print(f"  {row[join_column]}: {scaled_count:,.0f} records")

    # Skew assessment
    if skew_ratio > 10:
        print(f"\n🚨 SEVERE SKEW DETECTED! Ratio: {skew_ratio:.1f}")
        return "severe"
    elif skew_ratio > 3:
        print(f"\n⚠️  Moderate skew detected. Ratio: {skew_ratio:.1f}")
        return "moderate"
    else:
        print(f"\n✅ No significant skew. Ratio: {skew_ratio:.1f}")
        return "none"

# Usage
skew_level = detect_join_skew(user_events, "user_id")
# GOOD: Broadcast small table to avoid shuffle
from pyspark.sql.functions import broadcast

def smart_broadcast_join(large_df, small_df, join_keys):
    """Intelligently decide on broadcast join"""

    # Estimate small table size from a handful of sampled rows
    sample_rows = small_df.sample(0.1).take(100)
    if sample_rows:
        avg_row_size_bytes = sum(len(str(row)) for row in sample_rows) / len(sample_rows)
        total_rows = small_df.count()
        estimated_size_mb = (total_rows * avg_row_size_bytes) / (1024 * 1024)

        print(f"Small table estimated size: {estimated_size_mb:.1f} MB")

        if estimated_size_mb < 200:  # Safe broadcast threshold
            print("✅ Using broadcast join")
            return large_df.join(broadcast(small_df), join_keys)
        else:
            print("❌ Table too large for broadcast, using regular join")
            return large_df.join(small_df, join_keys)

    return large_df.join(small_df, join_keys)

# Apply smart broadcast
result = smart_broadcast_join(user_events, user_profiles, "user_id")
# BETTER: Salting technique for severe skew
from pyspark.sql.functions import rand, concat, lit, col

def salted_join(large_df, small_df, join_key, salt_buckets=100):
    """Handle severe skew using salting technique"""

    print(f"Applying salting with {salt_buckets} buckets...")

    # Add salt to large table
    large_salted = large_df.withColumn(
        "salt", 
        (rand() * salt_buckets).cast("int")
    ).withColumn(
        "salted_key",
        concat(col(join_key).cast("string"), lit("_"), col("salt").cast("string"))
    )

    # Explode small table across all salt values
    salt_range = spark.range(salt_buckets).select(col("id").alias("salt"))
    small_exploded = small_df.crossJoin(salt_range).withColumn(
        "salted_key",
        concat(col(join_key).cast("string"), lit("_"), col("salt").cast("string"))
    )

    # Join on salted keys
    result = large_salted.join(small_exploded, "salted_key") \
                        .drop("salt", "salted_key")  # Clean up helper columns

    print("✅ Salted join completed")
    return result

# Apply when severe skew detected
if skew_level == "severe":
    result = salted_join(user_events, user_profiles, "user_id", salt_buckets=200)
else:
    result = user_events.join(broadcast(user_profiles), "user_id")
# BEST: Pre-bucketing for repeated skewed joins
def create_bucketed_tables(df, table_name, bucket_column, num_buckets=200):
    """Create bucketed table for optimal joins"""

    print(f"Creating bucketed table: {table_name}")

    df.write \
      .mode("overwrite") \
      .option("path", f"/bucketed_tables/{table_name}") \
      .bucketBy(num_buckets, bucket_column) \
      .sortBy(bucket_column) \
      .saveAsTable(table_name)

    print(f"✅ Bucketed table created with {num_buckets} buckets")

# Create bucketed tables (one-time setup)
create_bucketed_tables(user_events, "user_events_bucketed", "user_id", 200)
create_bucketed_tables(user_profiles, "user_profiles_bucketed", "user_id", 200)

# Fast joins on bucketed tables (no shuffle needed!)
bucketed_events = spark.table("user_events_bucketed")
bucketed_profiles = spark.table("user_profiles_bucketed")

# This join will be much faster - no shuffle required
result = bucketed_events.join(bucketed_profiles, "user_id")
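
On Spark 3.x, Adaptive Query Execution can also detect and split skewed shuffle partitions at runtime, which often resolves moderate skew without any code changes. A minimal configuration sketch (the thresholds shown are illustrative defaults, not tuned values):

# Runtime skew handling via AQE (Spark 3.x)
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", str(256 * 1024 * 1024))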

Skew Handling Results

Before (skewed join): 3 hours, 95% cluster idle
After (salted join): 25 minutes, even distribution
Improvement: 7x faster, better resource utilization


Gotcha #10: Broadcasting Memory Bombs

Performance Impact: OutOfMemoryError, cluster crashes

Broadcasting tables larger than executor memory causes catastrophic failures.

The Problem:

# BAD: Broadcasting without size validation
large_lookup = spark.read.table("product_catalog")  # 5GB table!
orders = spark.read.table("orders")

# This will crash executors
result = orders.join(broadcast(large_lookup), "product_id")  # OOM!

Broadcast Memory Requirements

Memory needed for broadcast: - Table size × Number of executor cores - Example: 1GB table × 100 cores = 100GB total memory needed - Each executor must hold entire broadcasted table in memory

Failure cascade: - Executors run out of memory - Tasks fail and retry - Driver struggles with retries - Eventually entire application crashes

The Solution:

def safe_broadcast_join(left_df, right_df, join_keys, max_broadcast_mb=200):
    """Safely determine optimal join strategy"""

    def estimate_dataframe_size_mb(df, sample_fraction=0.01):
        """Estimate DataFrame size in MB"""
        try:
            sample = df.sample(sample_fraction)
            sample_count = sample.count()

            if sample_count == 0:
                return float('inf')  # Cannot estimate

            # Get sample data for size estimation
            sample_data = sample.take(min(100, sample_count))
            if not sample_data:
                return float('inf')

            # Estimate average row size
            avg_row_size = sum(len(str(row)) for row in sample_data) / len(sample_data)
            total_rows = df.count()

            # Conservative size estimate (includes serialization overhead)
            estimated_size_mb = (total_rows * avg_row_size * 2) / (1024 * 1024)
            return estimated_size_mb

        except Exception as e:
            print(f"Size estimation failed: {e}")
            return float('inf')

    # Estimate sizes
    left_size = estimate_dataframe_size_mb(left_df)
    right_size = estimate_dataframe_size_mb(right_df)

    print(f"Join size analysis:")
    print(f"  Left table:  {left_size:.1f} MB")
    print(f"  Right table: {right_size:.1f} MB")
    print(f"  Broadcast threshold: {max_broadcast_mb} MB")

    # Determine join strategy
    if right_size <= max_broadcast_mb:
        print("✅ Broadcasting right table")
        return left_df.join(broadcast(right_df), join_keys)
    elif left_size <= max_broadcast_mb:
        print("✅ Broadcasting left table")
        return broadcast(left_df).join(right_df, join_keys)
    else:
        print("⚠️  Both tables too large for broadcast, using shuffle join")

        # Optimize shuffle join
        optimized_left = left_df.repartition(col(join_keys[0]) if isinstance(join_keys, list) else col(join_keys))
        optimized_right = right_df.repartition(col(join_keys[0]) if isinstance(join_keys, list) else col(join_keys))

        return optimized_left.join(optimized_right, join_keys)

# Usage
result = safe_broadcast_join(orders, product_catalog, "product_id", max_broadcast_mb=150)
def calculate_safe_broadcast_threshold():
    """Calculate safe broadcast threshold based on cluster resources"""

    # Get executor information
    executors = spark.sparkContext.statusTracker().getExecutorInfos()

    if not executors:
        return 50  # Conservative default

    # Calculate available memory per executor
    min_executor_memory = min(exec.maxMemory for exec in executors)
    executor_count = len(executors)

    # Reserve memory for other operations (conservative 30%)
    available_memory_per_executor = min_executor_memory * 0.3

    # Convert to MB
    safe_threshold_mb = available_memory_per_executor / (1024 * 1024)

    print(f"Cluster broadcast analysis:")
    print(f"  Executors: {executor_count}")
    print(f"  Min executor memory: {min_executor_memory/(1024**3):.1f} GB")
    print(f"  Safe broadcast threshold: {safe_threshold_mb:.0f} MB")

    return min(safe_threshold_mb, 500)  # Cap at 500MB for safety

# Auto-calculate threshold
safe_threshold = calculate_safe_broadcast_threshold()
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", f"{int(safe_threshold)}MB")
def monitor_broadcast_usage():
    """Monitor broadcast variable usage across cluster"""

    print("=== Broadcast Usage Report ===")

    # Get broadcast information
    broadcast_vars = spark.sparkContext._jsc.sc().getBroadcastInfos()

    total_broadcast_size = 0

    for broadcast_info in broadcast_vars:
        broadcast_id = broadcast_info.id()
        broadcast_size_mb = broadcast_info.size() / (1024 * 1024)
        total_broadcast_size += broadcast_size_mb

        print(f"Broadcast {broadcast_id}: {broadcast_size_mb:.1f} MB")

    print(f"Total broadcast size: {total_broadcast_size:.1f} MB")

    # Check against executor memory
    executors = spark.sparkContext.statusTracker().getExecutorInfos()
    if executors:
        total_executor_memory_gb = sum(exec.maxMemory for exec in executors) / (1024**3)
        broadcast_ratio = (total_broadcast_size / 1024) / total_executor_memory_gb

        print(f"Broadcast memory ratio: {broadcast_ratio:.1%}")

        if broadcast_ratio > 0.1:  # More than 10% of cluster memory
            print("⚠️  WARNING: High broadcast memory usage!")

    return total_broadcast_size

# Monitor after joins
monitor_broadcast_usage()
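
Before running at full scale, it is also worth confirming which join strategy the optimizer actually picked; the physical plan makes this explicit (a quick check on the result DataFrame from above):

# A BroadcastHashJoin node means the broadcast happened; SortMergeJoin means a shuffle join was used
result.explain()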

Broadcasting Guidelines

Safe broadcast sizes:

  • Small cluster: < 50MB
  • Medium cluster: < 200MB
  • Large cluster: < 500MB
  • Rule of thumb: < 10% of executor memory

Gotcha #11: Inefficient Join Ordering

Performance Impact: 10x larger intermediate results

Joining large tables first instead of filtering with smaller ones creates massive intermediate datasets.

The Problem:

# BAD: Join large tables first
events = spark.read.table("events")           # 1TB
sessions = spark.read.table("sessions")       # 500GB  
active_users = spark.read.table("active_users")  # 10MB

# Creates massive intermediate result
# Joins the two largest tables first: ~1.5TB intermediate result,
# which is only then filtered down to ~50GB by the active_users join
result = events.join(sessions, "session_id") \
               .join(active_users, "user_id")

The Solution:

# GOOD: Filter early, join strategically

# Step 1: Start with most selective filters
active_events = events.join(broadcast(active_users), "user_id")  # Filter first

# Step 2: Apply additional filters before expensive joins
recent_events = active_events.filter(col("event_date") >= "2023-01-01")

# Step 3: Join with larger tables only on filtered data
result = recent_events.join(sessions, "session_id")

# Result: 50GB intermediate instead of 1.5TB!
class JoinOptimizer:
    def __init__(self, spark_session):
        self.spark = spark_session

    def estimate_join_selectivity(self, left_df, right_df, join_key):
        """Estimate how much a join will reduce data size"""

        # Sample to estimate selectivity
        left_sample = left_df.sample(0.01)
        right_sample = right_df.sample(0.01)

        left_keys = set(row[join_key] for row in left_sample.select(join_key).collect())
        right_keys = set(row[join_key] for row in right_sample.select(join_key).collect())

        # Estimate join selectivity
        intersection_ratio = len(left_keys.intersection(right_keys)) / max(len(left_keys), 1)

        return intersection_ratio

    def optimize_join_order(self, tables_with_info):
        """
        Optimize join order based on table sizes and selectivity
        tables_with_info: [(df, name, estimated_size_mb, join_key)]
        """

        print("=== Join Order Optimization ===")

        # Sort by size (smallest first for broadcast candidates)
        sorted_tables = sorted(tables_with_info, key=lambda x: x[2])

        optimized_plan = []

        for i, (df, name, size_mb, join_key) in enumerate(sorted_tables):
            if size_mb < 200:  # Broadcast candidate
                optimized_plan.append({
                    'table': df,
                    'name': name,
                    'size_mb': size_mb,
                    'strategy': 'broadcast',
                    'order': i
                })
            else:
                optimized_plan.append({
                    'table': df,
                    'name': name, 
                    'size_mb': size_mb,
                    'strategy': 'shuffle',
                    'order': i
                })

        print("Optimized join plan:")
        for plan in optimized_plan:
            print(f"  {plan['order']}: {plan['name']} ({plan['size_mb']:.1f}MB) - {plan['strategy']}")

        return optimized_plan

    def execute_optimized_joins(self, join_plan, base_df=None):
        """Execute joins in optimized order"""

        if not join_plan:
            return base_df

        result_df = base_df or join_plan[0]['table']

        for i, plan in enumerate(join_plan[1:], 1):
            join_df = plan['table']

            if plan['strategy'] == 'broadcast':
                print(f"Step {i}: Broadcasting {plan['name']}")
                result_df = result_df.join(broadcast(join_df), plan.get('join_key', 'id'))
            else:
                print(f"Step {i}: Shuffle joining {plan['name']}")
                result_df = result_df.join(join_df, plan.get('join_key', 'id'))

        return result_df

# Usage
optimizer = JoinOptimizer(spark)

tables_info = [
    (events, "events", 1000000, "user_id"),      # 1TB
    (sessions, "sessions", 500000, "session_id"), # 500GB
    (active_users, "active_users", 10, "user_id") # 10MB
]

join_plan = optimizer.optimize_join_order(tables_info)
optimized_result = optimizer.execute_optimized_joins(join_plan)
import time

def monitor_join_performance(df, operation_name):
    """Monitor join operation performance"""

    print(f"\n=== {operation_name} Performance ===")

    start_time = time.time()

    # Get initial metrics
    initial_partitions = df.rdd.getNumPartitions()

    # Trigger computation
    result_count = df.count()

    end_time = time.time()
    duration = end_time - start_time

    print(f"Operation: {operation_name}")
    print(f"Duration: {duration:.2f} seconds")
    print(f"Result rows: {result_count:,}")
    print(f"Partitions: {initial_partitions}")
    print(f"Throughput: {result_count/duration:,.0f} rows/second")

    # Check for skew in result
    partition_counts = df.rdd.mapPartitions(lambda x: [sum(1 for _ in x)]).collect()
    if partition_counts:
        max_partition = max(partition_counts)
        avg_partition = sum(partition_counts) / len(partition_counts)
        skew_ratio = max_partition / avg_partition if avg_partition > 0 else 0

        print(f"Result skew ratio: {skew_ratio:.1f}")
        if skew_ratio > 3:
            print("⚠️  High skew in join result!")

    return duration, result_count

# Monitor join performance (plain function call; it is not a context manager)
duration, row_count = monitor_join_performance(result, "Optimized Multi-Join")
final_result = result.collect()

Join Ordering Benefits

Unoptimized: 1TB → 1.5TB → 50GB (45 min)
Optimized: 1TB → 100GB → 50GB (8 min)
Improvement: 5.6x faster, 93% less shuffle data


5. Aggregation & GroupBy Traps

Gotcha #12: Multiple-Pass Aggregation Waste

Performance Impact: 3-5x unnecessary data scans

Calling separate aggregation functions triggers multiple passes over the same dataset.

The Problem:

# BAD: Multiple scans of the same data
total_sales = df.agg(sum("sales")).collect()[0][0]      # Scan 1
avg_sales = df.agg(avg("sales")).collect()[0][0]        # Scan 2  
max_sales = df.agg(max("sales")).collect()[0][0]        # Scan 3
count_sales = df.agg(count("sales")).collect()[0][0]    # Scan 4
# Four full dataset scans!

The Solution:

# GOOD: Single pass for multiple aggregations
from pyspark.sql.functions import sum, avg, max, min, count, countDistinct, stddev, expr

# Compute all statistics in one pass
stats = df.agg(
    sum("sales").alias("total_sales"),
    avg("sales").alias("avg_sales"), 
    max("sales").alias("max_sales"),
    min("sales").alias("min_sales"),
    count("sales").alias("count_sales"),
    stddev("sales").alias("stddev_sales"),
    expr("percentile_approx(sales, 0.5)").alias("median_sales"),
    expr("percentile_approx(sales, array(0.25, 0.75))").alias("quartiles")
).collect()[0]

# Extract results
total_sales = stats["total_sales"]
avg_sales = stats["avg_sales"]
max_sales = stats["max_sales"]

print(f"Sales Statistics (single pass):")
print(f"  Total: ${total_sales:,.2f}")
print(f"  Average: ${avg_sales:,.2f}")
print(f"  Range: ${stats['min_sales']:,.2f} - ${max_sales:,.2f}")
def comprehensive_stats(df, numeric_columns, categorical_columns=None):
    """Generate comprehensive statistics in a single pass"""

    print("=== Comprehensive Statistics Analysis ===")

    # Build aggregation expressions
    agg_exprs = []

    # Numeric column statistics
    for col_name in numeric_columns:
        if col_name in df.columns:
            agg_exprs.extend([
                count(col_name).alias(f"{col_name}_count"),
                sum(col_name).alias(f"{col_name}_sum"),
                avg(col_name).alias(f"{col_name}_avg"),
                min(col_name).alias(f"{col_name}_min"),
                max(col_name).alias(f"{col_name}_max"),
                stddev(col_name).alias(f"{col_name}_stddev"),
                expr(f"percentile_approx({col_name}, 0.5)").alias(f"{col_name}_median")
            ])

    # Categorical column statistics  
    if categorical_columns:
        for col_name in categorical_columns:
            if col_name in df.columns:
                agg_exprs.extend([
                    countDistinct(col_name).alias(f"{col_name}_distinct_count"),
                    count(col_name).alias(f"{col_name}_non_null_count")
                ])

    # Execute single aggregation
    if agg_exprs:
        stats_result = df.agg(*agg_exprs).collect()[0]

        # Format and display results
        for col_name in numeric_columns:
            if f"{col_name}_count" in stats_result.asDict():
                print(f"\n{col_name.upper()} Statistics:")
                print(f"  Count: {stats_result[f'{col_name}_count']:,}")
                print(f"  Sum: {stats_result[f'{col_name}_sum']:,.2f}")
                print(f"  Average: {stats_result[f'{col_name}_avg']:,.2f}")
                print(f"  Range: {stats_result[f'{col_name}_min']:,.2f} - {stats_result[f'{col_name}_max']:,.2f}")
                print(f"  Std Dev: {stats_result[f'{col_name}_stddev']:,.2f}")
                print(f"  Median: {stats_result[f'{col_name}_median']:,.2f}")

        if categorical_columns:
            print(f"\nCATEGORICAL COLUMNS:")
            for col_name in categorical_columns:
                if f"{col_name}_distinct_count" in stats_result.asDict():
                    distinct_count = stats_result[f"{col_name}_distinct_count"]
                    non_null_count = stats_result[f"{col_name}_non_null_count"]
                    print(f"  {col_name}: {distinct_count:,} distinct values, {non_null_count:,} non-null")

        return stats_result.asDict()

    return {}

# Usage
stats = comprehensive_stats(
    df,
    numeric_columns=["sales", "quantity", "profit"],
    categorical_columns=["category", "region", "customer_segment"]
)
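
For quick exploration, PySpark's built-in summary() computes count, mean, stddev, min, approximate percentiles, and max for the selected columns in a single job, a lighter-weight alternative to the custom helper above:

# One pass over the data for the standard summary statistics
df.select("sales", "quantity", "profit").summary().show()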

Performance Improvement

Before: 4 separate scans (20 minutes)
After: 1 comprehensive scan (5 minutes)
Improvement: 4x faster data processing


Gotcha #13: High-Cardinality GroupBy Memory Explosion

Performance Impact: OutOfMemoryError from massive state

GroupBy operations on high-cardinality columns create enormous intermediate state.

The Problem:

# BAD: GroupBy on millions of unique values
user_stats = df.groupBy("user_id").agg(        # 10M unique users
    count("*").alias("event_count"),
    sum("amount").alias("total_spent")
)
# Creates 10M groups in memory - potential OOM!

The Solution:

# Strategy 1: Binning/Bucketing
from pyspark.sql.functions import floor, col, count, countDistinct, sum, avg

# GOOD: Reduce cardinality with intelligent binning
user_binned = df.withColumn(
    "user_bucket", 
    floor(col("user_id") / 1000)  # Group users into buckets of 1000
)

bucket_stats = user_binned.groupBy("user_bucket").agg(
    count("*").alias("total_events"),
    countDistinct("user_id").alias("unique_users"),
    avg("amount").alias("avg_amount")
)

# Strategy 2: Sampling for exploration
sample_stats = df.sample(0.1).groupBy("user_id").agg(
    count("*").alias("sample_event_count"),
    sum("amount").alias("sample_total")
)

# Strategy 3: Top-N analysis instead of full groupby
top_users = df.groupBy("user_id") \
              .agg(sum("amount").alias("total_spent")) \
              .orderBy(col("total_spent").desc()) \
              .limit(1000)  # Only top 1000 users
# BETTER: Use approximate functions for large-scale analytics
from pyspark.sql.functions import approx_count_distinct, expr

# Approximate statistics (much faster, less memory)
approx_stats = df.agg(
    approx_count_distinct("user_id", 0.05).alias("approx_unique_users"),  # 5% error
    expr("percentile_approx(amount, 0.5)").alias("median_amount"),
    expr("percentile_approx(amount, array(0.25, 0.75))").alias("quartiles")
)

# Compare exact vs approximate
print("Approximate vs Exact Comparison:")

# Exact (expensive)
exact_unique = df.select("user_id").distinct().count()

# Approximate (fast)
approx_unique = approx_stats.collect()[0]["approx_unique_users"]

error_rate = abs(exact_unique - approx_unique) / exact_unique
print(f"Exact unique users: {exact_unique:,}")
print(f"Approximate unique users: {approx_unique:,}")  
print(f"Error rate: {error_rate:.2%}")

def memory_aware_groupby(df, group_columns, agg_exprs, max_groups=1000000):
    """Perform GroupBy with memory awareness"""