Removing Duplicate Rows in PySpark DataFrames

Duplicate rows are very common in DataFrames. Here, I will show you the most common ways to remove them in PySpark, from the built-in methods to window functions for custom rules.

What is a Duplicate Row?

A duplicate row is a row that is an exact copy of another row, with the same value in every column. Let’s look at an example:

data = [
    ("Donald", "Trump", 70),
    ("Jo", "Biden", 99),
    ("Barrack", "Obama", 60),
    ("Donald", "Trump", 70),  # Duplicate: identical to the first row in every column
    ("Donald", "Duck", 70)    # Not a duplicate: the last name differs
]

Common Methods to Remove Duplicates

  1. Using distinct():
    • The first method that comes to mind. It removes rows that are 100% the same, meaning all columns are checked, and keeps one copy of each.
      df = df.distinct()
      
  2. Using dropDuplicates():
    • Called with no arguments, this method behaves exactly like distinct(): all columns are compared.
      df = df.dropDuplicates()
      # Alternative: df = df.drop_duplicates()
      
  3. Removing duplicates when only a few columns match:
    • Using subset=["col1", "col2"]: If rows are not 100% the same but match on some columns, you can still remove duplicates based on just those columns using df.dropDuplicates(subset=["col1", "col2"]). One row is kept for each combination of values, and Spark does not guarantee which one.
      df = df.dropDuplicates(subset=["name", "age"])
      # Alternative: df = df.drop_duplicates(subset=["name", "age"])
      

Advanced Duplicate Removal

Normally, a single call such as df.dropDuplicates(subset=["name", "age"]) is enough for most duplicate removal. But in some cases the requirement is more customized, for example when you need to control which of the duplicate rows is kept. In such cases, we have to use window functions.

The code below shows how to handle special cases of removing duplicates using window functions in PySpark. Each example shows how window functions can be used for specific needs that dropDuplicates() cannot handle.

from pyspark.sql import SparkSession
from pyspark.sql.window import Window
# Note: importing sum here shadows Python's built-in sum() for the rest of the script
from pyspark.sql.functions import row_number, sum, desc

# Initialize Spark session
spark = SparkSession.builder.appName("AdvancedDuplicateRemoval").getOrCreate()

# Sample DataFrames for each scenario
data1 = [
    (1, "Alice", "2021-01-01"),
    (1, "Alice", "2021-01-02"),
    (2, "Bob", "2021-01-01"),
    (2, "Bob", "2021-01-03"),
    (3, "Charlie", "2021-01-01")
]
columns1 = ["user_id", "name", "timestamp"]
df1 = spark.createDataFrame(data1, columns1)

data2 = [
    (1, "GroupA", 5),
    (2, "GroupA", 10),
    (3, "GroupB", 15),
    (4, "GroupB", 8),
    (5, "GroupC", 12)
]
columns2 = ["id", "group_id", "priority"]
df2 = spark.createDataFrame(data2, columns2)

data3 = [
    (1, "TXN1", 100),
    (2, "TXN1", 200),
    (3, "TXN2", 300),
    (4, "TXN2", 400),
    (5, "TXN3", 500)
]
columns3 = ["id", "transaction_id", "sales_amount"]
df3 = spark.createDataFrame(data3, columns3)

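# Scenario 1: Keep the latest duplicate based on a timestamp column
# (timestamps are ISO-formatted strings here, so string ordering matches date ordering;
# if two rows tie on timestamp, row_number() picks one of them arbitrarily)
window_spec1 = Window.partitionBy("user_id").orderBy(desc("timestamp"))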
df1_with_row_num = df1.withColumn("row_num", row_number().over(window_spec1))
unique_df1 = df1_with_row_num.filter(df1_with_row_num.row_num == 1).drop("row_num")
print("Scenario 1: Keep the most recent entry")
unique_df1.show()

# Scenario 2: Keep the duplicate with the highest priority value within each group
window_spec2 = Window.partitionBy("group_id").orderBy(desc("priority"))
df2_with_row_num = df2.withColumn("row_num", row_number().over(window_spec2))
unique_df2 = df2_with_row_num.filter(df2_with_row_num.row_num == 1).drop("row_num")
print("Scenario 2: Keep the highest priority entry within each group")
unique_df2.show()

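# Scenario 3: Aggregate data from duplicate rows before removing them
# The window has no orderBy, so sum() covers every row in the partition
window_spec3 = Window.partitionBy("transaction_id")
df3_with_aggregates = df3.withColumn("total_sales", sum("sales_amount").over(window_spec3))
# One row per transaction_id survives; its id and sales_amount come from an arbitrary row of the group
unique_df3 = df3_with_aggregates.dropDuplicates(["transaction_id"])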
print("Scenario 3: Aggregate sales amounts and keep one entry per transaction ID")
unique_df3.show()