Spark: replace array with IDs with values; or: how to join objects?

This week we've been looking at joining two huge tables in Spark into a single table. It turns out that it is not a straightforward exercise to join data based on an array of IDs.

I've written this blog to show a way of solving the problem. It felt like a "math exercise" in which I not only had to show the answer but also the steps to get there, so bear with me.

  1. Intro
  2. The Set & the Objective
  3. Imports
  4. Step 1: Let's join
  5. Step 2: Left Outer Join
  6. Step 3: Trim the fat
  7. Step 4: Empty array?
  8. Step 5: Get the position
  9. Step 6: Sorting with a Window
  10. Question: can it be functionized?
  11. Final thoughts

The Set & the Objective

We have two data frames: Item and Resource. An item references zero or more resources through an array of resource IDs. The set looks like this:

# `spark` is the active SparkSession (predefined in Databricks and the pyspark shell)
items_df = spark.createDataFrame([
    (1, "Alpha", [1]),
    (2, "Beta", [2,3]),
    (3, "Gamma", []),
    (4, "Delta", [4,3,2,1]),
    (5, "Epsilon", [5,5]),
    (6, "Zeta", [6])
], ["item_id", "name", "resources"])

resources_df = spark.createDataFrame([
    (1, "Resource I"),
    (2, "Resource II"),
    (3, "Resource III"),
    (4, "Resource IV"),
    (5, "Resource V")
], ["resource_id", "name"])

I've added two special cases: 5. Epsilon references the same resource twice, and 6. Zeta references a resource that does not exist.

The objective is simple: create a new data frame in which the items have resource objects instead of the IDs.
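To pin down what "resource objects instead of the IDs" should mean for the special cases, here is a plain-Python sketch (lists of dicts, no Spark) of the expected result. It is only a model, assuming that the original order and duplicates are kept, that an unknown ID keeps its resource_id with a null name, and that an empty array stays empty; the helper name replace_ids_with_objects is hypothetical.

```python
# Plain-Python model of the objective (no Spark involved).
# Assumptions: keep the original order and duplicates, keep unknown IDs with
# a None name, and leave empty arrays empty.
items = [
    {"item_id": 1, "name": "Alpha", "resources": [1]},
    {"item_id": 2, "name": "Beta", "resources": [2, 3]},
    {"item_id": 3, "name": "Gamma", "resources": []},
    {"item_id": 4, "name": "Delta", "resources": [4, 3, 2, 1]},
    {"item_id": 5, "name": "Epsilon", "resources": [5, 5]},
    {"item_id": 6, "name": "Zeta", "resources": [6]},
]
resources = [
    {"resource_id": 1, "name": "Resource I"},
    {"resource_id": 2, "name": "Resource II"},
    {"resource_id": 3, "name": "Resource III"},
    {"resource_id": 4, "name": "Resource IV"},
    {"resource_id": 5, "name": "Resource V"},
]


def replace_ids_with_objects(items, resources):
    """Hypothetical helper: replace each resource ID with its resource record."""
    by_id = {r["resource_id"]: r for r in resources}
    result = []
    for item in items:
        objects = [
            # Unknown IDs (6. Zeta) keep their ID and get name=None.
            dict(by_id.get(rid, {"resource_id": rid, "name": None}))
            for rid in item["resources"]
        ]
        result.append({**item, "resources": objects})
    return result


for row in replace_ids_with_objects(items, resources):
    print(row["item_id"], row["name"], [r["name"] for r in row["resources"]])
```

A model like this can double as a small oracle: collect the Spark result for the sample data and compare it with this output.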

Imports

Let's do some imports first:

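# Note: the wildcard import shadows Python built-ins such as max and sum;
# the code below relies on that to use Spark's max aggregate.
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql import Window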

Step 1: Let's join

The first thing I usually try is joining both data frames:

df = (items_df
      .select("item_id", explode("resources").alias("resource_id"))
      .join(resources_df, "resource_id")
      .groupBy("item_id")
      .agg(collect_list(struct("*")).alias("resources"))
      .join(items_df.drop("resources"), "item_id")
      .orderBy("item_id")
     )

The result looks like this:

item_id | resources | name
1 | [{"resource_id":1,"item_id":1,"name":"Resource I"}] | Alpha
2 | [{"resource_id":3,"item_id":2,"name":"Resource III"},{"resource_id":2,"item_id":2,"name":"Resource II"}] | Beta
4 | [{"resource_id":1,"item_id":4,"name":"Resource I"},{"resource_id":3,"item_id":4,"name":"Resource III"},{"resource_id":2,"item_id":4,"name":"Resource II"},{"resource_id":4,"item_id":4,"name":"Resource IV"}] | Delta
5 | [{"resource_id":5,"item_id":5,"name":"Resource V"},{"resource_id":5,"item_id":5,"name":"Resource V"}] | Epsilon

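There are many things wrong with this result, but that's fine; we'll address them step by step. The first thing you'll notice is that two items are missing: 3. Gamma and 6. Zeta. Gamma disappears because explode produces no rows for an empty array, and Zeta disappears because the (default) inner join drops rows that have no matching resource. Let's get them back.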
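Where exactly do the two items get lost? A plain-Python model of the pipeline (lists and dicts, not Spark; the helper names explode and inner_join are made up for this sketch) shows it on a subset of the sample data:

```python
# Plain-Python model of Step 1 (not Spark code): explode + inner join.
# A subset of the sample data; resources maps resource_id -> name.
items = [
    {"item_id": 1, "name": "Alpha", "resources": [1]},
    {"item_id": 3, "name": "Gamma", "resources": []},
    {"item_id": 6, "name": "Zeta", "resources": [6]},
]
resources = {1: "Resource I"}


def explode(rows):
    # One (item_id, resource_id) row per array element; [] yields no rows.
    return [(row["item_id"], rid) for row in rows for rid in row["resources"]]


def inner_join(pairs, lookup):
    # Keep only the pairs whose resource_id exists on the resources side.
    return [(item_id, rid, lookup[rid]) for item_id, rid in pairs if rid in lookup]


exploded = explode(items)                  # 3. Gamma is lost here
joined = inner_join(exploded, resources)   # 6. Zeta is lost here
print(exploded)
print(joined)
```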

Step 2: Left Outer Join

There are many types of joins. Here we want a left outer join: it keeps every exploded row and fills the resource columns with null values when no resource matches (6. Zeta). We also want explode_outer instead of explode: for an empty array it produces a single row with a null value, which preserves a record for 3. Gamma.

df = (items_df
      .select('item_id', explode_outer('resources').alias('resource_id'))
      .join(resources_df, "resource_id", "left_outer")
      .groupBy('item_id')
      .agg(collect_list(struct("*")).alias("resources"))
      .join(items_df.drop("resources"), 'item_id')
      .orderBy('item_id')
     )

This results in:

item_id | resources | name
1 | [{"resource_id":1,"item_id":1,"name":"Resource I"}] | Alpha
2 | [{"resource_id":3,"item_id":2,"name":"Resource III"},{"resource_id":2,"item_id":2,"name":"Resource II"}] | Beta
3 | [{"resource_id":null,"item_id":3,"name":null}] | Gamma
4 | [{"resource_id":1,"item_id":4,"name":"Resource I"},{"resource_id":3,"item_id":4,"name":"Resource III"},{"resource_id":2,"item_id":4,"name":"Resource II"},{"resource_id":4,"item_id":4,"name":"Resource IV"}] | Delta
5 | [{"resource_id":5,"item_id":5,"name":"Resource V"},{"resource_id":5,"item_id":5,"name":"Resource V"}] | Epsilon
6 | [{"resource_id":6,"item_id":6,"name":null}] | Zeta

Both records are back, but notice that our resources now contain an item_id field: struct("*") packs every column of the joined row into the struct. Let's get rid of it.
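A plain-Python model of the two outer variants (not Spark; the helper names explode_outer and left_outer_join are made up for this sketch) shows what the rows look like before grouping: the empty array still yields a row with a null ID, and the unknown ID survives with a null name.

```python
# Plain-Python model of Step 2 (not Spark code): explode_outer + left outer join.
# A subset of the sample data; resources maps resource_id -> name.
items = [
    {"item_id": 1, "name": "Alpha", "resources": [1]},
    {"item_id": 3, "name": "Gamma", "resources": []},
    {"item_id": 6, "name": "Zeta", "resources": [6]},
]
resources = {1: "Resource I"}


def explode_outer(rows):
    # Like explode, but an empty or missing array yields one row with None.
    out = []
    for row in rows:
        for rid in row["resources"] or [None]:
            out.append((row["item_id"], rid))
    return out


def left_outer_join(pairs, lookup):
    # Every left row survives; without a match the name becomes None.
    return [(item_id, rid, lookup.get(rid)) for item_id, rid in pairs]


rows = left_outer_join(explode_outer(items), resources)
print(rows)
```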

Step 3: Trim the fat

To remove the field, I join the resources differently: with an explicit join condition (rid == resource_id) and with alias("r") on the resources data frame. That alias lets me construct a struct with only the resource columns, using struct("r.*"):

df = (items_df
      .select("item_id", explode_outer("resources").alias("rid"))
      .join(resources_df.alias("r"), col("rid") == col("resource_id"), "left_outer")
      .groupBy('item_id')
      .agg(collect_list(struct("r.*")).alias("resources"))
      .join(items_df.drop("resources"), 'item_id')
      .orderBy('item_id')
     )

This will get rid of the item_id field:

item_id | resources | name
1 | [{"resource_id":1,"name":"Resource I"}] | Alpha
2 | [{"resource_id":3,"name":"Resource III"},{"resource_id":2,"name":"Resource II"}] | Beta
3 | [{"resource_id":null,"name":null}] | Gamma
4 | [{"resource_id":1,"name":"Resource I"},{"resource_id":3,"name":"Resource III"},{"resource_id":2,"name":"Resource II"},{"resource_id":4,"name":"Resource IV"}] | Delta
5 | [{"resource_id":5,"name":"Resource V"},{"resource_id":5,"name":"Resource V"}] | Epsilon
6 | [{"resource_id":null,"name":null}] | Zeta

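But... it also got rid of the information we had on 6. Zeta. Because r.resource_id comes from the resources side of the join, it is null when no resource matches. Now we have a record in which every field is null, instead of one in which only the name is null (not sure if that is a bad thing). If we want to get the ID back, we should take it from the items side (rid) and add it to the structure we create during aggregation: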

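df = (items_df
      .select("item_id", explode_outer("resources").alias("rid"))
      .join(resources_df.alias("r"), col("rid") == col("resource_id"), "left_outer")
      .groupBy('item_id')
      # resource_id comes from rid (the items side); the other fields are listed
      # explicitly, because struct("r.*", col("rid").alias("resource_id")) would
      # give the struct two fields named resource_id
      .agg(collect_list(struct(col("rid").alias("resource_id"), "r.name")).alias("resources"))
      .join(items_df.drop("resources"), 'item_id')
      .orderBy('item_id')
     )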

Which results in 6. Zeta having the resource_id back:

item_id | resources | name
1 | [{"resource_id":1,"name":"Resource I"}] | Alpha
2 | [{"resource_id":3,"name":"Resource III"},{"resource_id":2,"name":"Resource II"}] | Beta
3 | [{"resource_id":null,"name":null}] | Gamma
4 | [{"resource_id":1,"name":"Resource I"},{"resource_id":3,"name":"Resource III"},{"resource_id":2,"name":"Resource II"},{"resource_id":4,"name":"Resource IV"}] | Delta
5 | [{"resource_id":5,"name":"Resource V"},{"resource_id":5,"name":"Resource V"}] | Epsilon
6 | [{"resource_id":6,"name":null}] | Zeta
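Why does struct("r.*") lose Zeta's ID while rid keeps it? After a left outer join, every column from the right side is null for an unmatched row, including the key itself. A plain-Python model of one joined row (not Spark code; left_join_row and the constants are made up for this sketch):

```python
# Plain-Python model of Step 3 (not Spark code). After a left outer join on
# rid == resource_id, every column from the right side ("r") is None for an
# unmatched row, including r.resource_id.
RESOURCES = {1: {"resource_id": 1, "name": "Resource I"}}
NULL_RIGHT_SIDE = {"resource_id": None, "name": None}


def left_join_row(rid):
    """Return the left-side rid and the right-side ("r") columns for one row."""
    return rid, RESOURCES.get(rid, NULL_RIGHT_SIDE)


rid, r = left_join_row(6)             # 6. Zeta: there is no resource 6
only_r = dict(r)                      # models struct("r.*")
with_rid = {**r, "resource_id": rid}  # ID taken from the left side instead
print(only_r)
print(with_rid)
```

The dict merge simply overwrites the key; it only models which side the value comes from, not how Spark assembles the struct fields.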

Step 4: Empty array?

Now, when we look at 3. Gamma, we see a major problem: the original resource array had no data, and now we have a record. That's not okay. We can fix this by constructing the resource structure before doing the groupBy / agg, and only when rid is not null (remember: explode_outer introduced those null values). Because collect_list skips null values, an item whose only row has a null resource ends up with an empty array.

df = (items_df
      .select('item_id', explode_outer('resources').alias('rid'))
      .join(resources_df.alias("r"), col("rid") == col("resource_id"), "left_outer")
      # no struct when rid is null (empty array), so collect_list skips it
      .withColumn("resource", 
                  when(col("rid").isNull(), None).otherwise(
                    struct(
                      # explicit fields: "r.*" plus an aliased rid would give
                      # the struct two resource_id fields
                      col("rid").alias("resource_id"),
                      "r.name"
                    )))
      .groupBy('item_id')
      .agg(collect_list("resource").alias("resources"))
      .join(items_df.drop("resources"), 'item_id')
      .orderBy('item_id')
     )

Now 3. Gamma has an empty array of resources:

item_id | resources | name
1 | [{"resource_id":1,"name":"Resource I"}] | Alpha
2 | [{"resource_id":3,"name":"Resource III"},{"resource_id":2,"name":"Resource II"}] | Beta
3 | [] | Gamma
4 | [{"resource_id":1,"name":"Resource I"},{"resource_id":3,"name":"Resource III"},{"resource_id":2,"name":"Resource II"},{"resource_id":4,"name":"Resource IV"}] | Delta
5 | [{"resource_id":5,"name":"Resource V"},{"resource_id":5,"name":"Resource V"}] | Epsilon
6 | [{"resource_id":6,"name":null}] | Zeta
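To see why Gamma now gets [] while Zeta keeps its entry, here is a plain-Python model of the aggregation (not Spark code; the rows are hard-coded from the sample data). Filtering out None stands in for collect_list skipping null values.

```python
# Plain-Python model of Step 4 (not Spark code). Rows are (item_id, rid, name)
# after explode_outer + left outer join, hard-coded from the sample data.
rows = [
    (3, None, None),          # 3. Gamma: empty array, so rid is None
    (6, 6, None),             # 6. Zeta: unknown resource, name is None
    (5, 5, "Resource V"),     # 5. Epsilon references resource 5 twice
    (5, 5, "Resource V"),
]

collected = {}
for item_id, rid, name in rows:
    # Like when(rid is null, None).otherwise(struct(...)).
    resource = None if rid is None else {"resource_id": rid, "name": name}
    bucket = collected.setdefault(item_id, [])  # groupBy: one group per item
    if resource is not None:                    # collect_list skips nulls
        bucket.append(resource)

print(collected)
```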

So, are we ready? It depends... do you think the order of the resources is important? Some records have their resources in a different order than the original array: look at 2. Beta and 4. Delta.

Step 5: Get the position

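First, we need the position of each resource. posexplode_outer returns two columns: pos, the 0-based index of the element in the array, and col, the element itself. Like explode_outer, it still produces a row (with nulls) for an empty array: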

df = (items_df
      .select('item_id', posexplode_outer('resources'))
      .withColumnRenamed("col", "rid")
      .join(resources_df.alias("r"), col("rid") == col("resource_id"), "left_outer")
      .withColumn("resource", 
                  when(col("rid").isNull(), None).otherwise(
                    struct(
                      # explicit fields: "r.*" plus an aliased rid would give
                      # the struct two resource_id fields
                      col("rid").alias("resource_id"),
                      "r.name",
                      "pos"
                    )))
      .groupBy('item_id')
      .agg(collect_list("resource").alias("resources"))
      .join(items_df.drop("resources"), 'item_id')
      .orderBy('item_id')
     )

Here we see the pos field for each resource:

item_id | resources | name
1 | [{"resource_id":1,"name":"Resource I","pos":0}] | Alpha
2 | [{"resource_id":3,"name":"Resource III","pos":1},{"resource_id":2,"name":"Resource II","pos":0}] | Beta
3 | [] | Gamma
4 | [{"resource_id":1,"name":"Resource I","pos":3},{"resource_id":3,"name":"Resource III","pos":1},{"resource_id":2,"name":"Resource II","pos":2},{"resource_id":4,"name":"Resource IV","pos":0}] | Delta
5 | [{"resource_id":5,"name":"Resource V","pos":0},{"resource_id":5,"name":"Resource V","pos":1}] | Epsilon
6 | [{"resource_id":6,"name":null,"pos":0}] | Zeta
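The pos column gives every resource a sort key and makes Epsilon's two identical IDs distinguishable. A plain-Python stand-in for posexplode_outer (not Spark code; the function below is written for this post) shows the pairs it produces for three of the sample arrays, following the documented behavior: 0-based positions, and a single (null, null) row for an empty or null array.

```python
# Plain-Python stand-in for posexplode_outer (not Spark code): one (pos, col)
# pair per element with a 0-based pos; an empty or null array gives one
# (None, None) pair, so the item is not lost.
def posexplode_outer(values):
    if not values:
        return [(None, None)]
    return list(enumerate(values))


print(posexplode_outer([4, 3, 2, 1]))  # 4. Delta: [(0, 4), (1, 3), (2, 2), (3, 1)]
print(posexplode_outer([5, 5]))        # 5. Epsilon: [(0, 5), (1, 5)]
print(posexplode_outer([]))            # 3. Gamma: [(None, None)]
```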

Now we only need to sort the resources of each item by pos.

Step 6: Sorting with a Window

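collect_list does not guarantee any order: it collects the values in whatever order the rows arrive after the shuffle for the groupBy, and that order can change from run to run. One way of controlling this is a window function that orders the rows within each partition. With an orderBy and no explicit frame, the window runs from the first row of the partition up to the current row, so collect_list over it gives each row a growing list; taking the max of those lists per item gives the complete, ordered list.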

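# with orderBy and no explicit frame, the window runs from the first row of
# the partition to the current row, so each row gets a growing list
w = Window.partitionBy('item_id').orderBy('pos')

df = (items_df
      .select('item_id', posexplode_outer('resources'))
      .withColumnRenamed("col", "rid")
      .join(resources_df.alias("r"), col("rid") == col("resource_id"), "left_outer")
      .withColumn("resource", 
                  when(col("rid").isNull(), None).otherwise(
                    struct(
                      # explicit fields: "r.*" plus an aliased rid would give
                      # the struct two resource_id fields
                      col("rid").alias("resource_id"),
                      "r.name"
                    )))
      .withColumn("resources", collect_list('resource').over(w))
      .groupBy('item_id')
      # Spark's max: each list is a prefix of the next, so the max is the
      # complete list
      .agg(max('resources').alias('resources'))
      .join(items_df.drop("resources"), 'item_id')
      .orderBy('item_id')
     )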

This completes our objective:

item_id | resources | name
1 | [{"resource_id":1,"name":"Resource I"}] | Alpha
2 | [{"resource_id":2,"name":"Resource II"},{"resource_id":3,"name":"Resource III"}] | Beta
3 | [] | Gamma
4 | [{"resource_id":4,"name":"Resource IV"},{"resource_id":3,"name":"Resource III"},{"resource_id":2,"name":"Resource II"},{"resource_id":1,"name":"Resource I"}] | Delta
5 | [{"resource_id":5,"name":"Resource V"},{"resource_id":5,"name":"Resource V"}] | Epsilon
6 | [{"resource_id":6,"name":null}] | Zeta
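The window-plus-max trick is easy to misread, so here is a plain-Python model of it for 4. Delta (not Spark code; the rows are taken from the Step 5 output). It mirrors Spark only under the assumption that Spark compares arrays element by element, with a shorter prefix ranking lower, as Python does for lists.

```python
# Plain-Python model of Step 6 (not Spark code). With orderBy and no explicit
# frame, a window runs from the first row of the partition to the current row,
# so collect_list(...).over(w) gives each row a growing prefix of the list.
delta_rows = [  # (pos, name) for 4. Delta, in the shuffled order seen in Step 5
    (3, "Resource I"),
    (1, "Resource III"),
    (2, "Resource II"),
    (0, "Resource IV"),
]

ordered = sorted(delta_rows)  # orderBy('pos') within the partition
running = [
    [name for _, name in ordered[: i + 1]]  # frame: partition start .. current row
    for i in range(len(ordered))
]

# Each list is a prefix of the next one and a prefix compares as smaller,
# so max() picks the complete list, like max('resources') after the groupBy.
final = max(running)
print(final)
```

A different technique that avoids the window is to collect structs whose first field is pos and sort them with sort_array, which orders structs field by field; the pos field then still has to be removed afterwards.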

Question: can it be functionized?

A colleague of mine was wondering if it could be turned into a function. I'm not a big fan of functions, as they often obscure what's going on. But... it can be turned into a function quite easily:

def replace_column_with_objects(df, key_field, column_field, foreign_df, foreign_key):

  w = Window.partitionBy(key_field).orderBy('pos')

  return (df
        # create a "linking" table 
        .select(key_field, posexplode_outer(column_field))
        # add foreign information
        .join(foreign_df.alias("f"), col("col") == col("f." + foreign_key), "left_outer")
        # create a foreign structure; the key comes from the linking table and
        # the other foreign columns are listed explicitly, so the struct does
        # not end up with two fields named foreign_key
        .withColumn("x", 
                    when(col("col").isNull(), None).otherwise(
                      struct(
                        col("col").alias(foreign_key),
                        *[col("f." + c) for c in foreign_df.columns if c != foreign_key]
                      )))
        # make an array of foreign structures, ordered by position
        .withColumn('x', collect_list('x').over(w))
        # group back to a single main record
        .groupBy(key_field)
        .agg(max('x').alias(column_field))
        # add the information of the main record without the replaced column
        .join(df.drop(column_field), key_field)
       )

For readability, you can call it with named parameters:

result_df = replace_column_with_objects(
  df=items_df, 
  key_field="item_id",
  column_field="resources",
  foreign_df=resources_df,
  foreign_key="resource_id").orderBy('item_id')

Final thoughts

Code tends to grow. Readability, especially when you chain steps, will suffer over time. It is not always clear what code does or why it has been constructed in a certain way. Comments may be your only way out.
