# Spark: replace array with IDs with values; or: how to join objects?

**Date:** 2019-12-11  
**Author:** Kees C. Bakker  
**Categories:** Databricks / Spark  
**Original:** https://keestalkstech.com/spark-replace-array-with-ids-with-values-or-how-to-join-objects/

![Spark: replace array with IDs with values; or: how to join objects?](https://keestalkstech.com/wp-content/uploads/2019/12/perry-grone-lbLgFFlADrY-unsplash-scaled.jpg)

---

This week we've been looking at joining two huge tables in Spark into a single table. It turns out that it is *not a straightforward exercise* to join data based on an array of IDs.

I've written this blog post to show *a way* of solving the problem. It felt like a "math exercise" in which I not only had to show the answer but also the steps to get there - so bear with me.

## The Set & the Objective

We have two data frames: *Item* and *Resource*. An item has multiple resources. The set looks like this:

```py
items_df = spark.createDataFrame([
    (1, "Alpha", [1]),
    (2, "Beta", [2,3]),
    (3, "Gamma", []),
    (4, "Delta", [4,3,2,1]),
    (5, "Epsilon", [5,5]),
    (6, "Zeta", [6])
], ["item_id", "name", "resources"])

resources_df = spark.createDataFrame([
    (1, "Resource I"),
    (2, "Resource II"),
    (3, "Resource III"),
    (4, "Resource IV"),
    (5, "Resource V")
], ["resource_id", "name"])
```

I've added two special use cases: the *5. Epsilon* item has two references to the same resource and the *6. Zeta* item has a reference to a resource that does not exist.

*The objective is simple:*
*Create a new data frame in which the items* 
*have resource objects instead of the IDs.*
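
To make the objective concrete, here is a plain-Python sketch (no Spark involved) of roughly the result we are after for the sample data above. The names `items`, `resources` and `expected` are only for illustration. It assumes we want to keep the array order, keep duplicates, and keep unknown IDs with an empty name.

```python
# Plain-Python model of the target result; illustrative only, not Spark code.
items = [
    (1, "Alpha", [1]),
    (2, "Beta", [2, 3]),
    (3, "Gamma", []),
    (4, "Delta", [4, 3, 2, 1]),
    (5, "Epsilon", [5, 5]),
    (6, "Zeta", [6]),
]
resources = {
    1: "Resource I",
    2: "Resource II",
    3: "Resource III",
    4: "Resource IV",
    5: "Resource V",
}

expected = [
    {
        "item_id": item_id,
        "name": name,
        # keep the original order and duplicates; unknown IDs get name None
        "resources": [
            {"resource_id": rid, "name": resources.get(rid)} for rid in rids
        ],
    }
    for item_id, name, rids in items
]

for row in expected:
    print(row)
```

The interesting rows are *3. Gamma* (an empty list), *5. Epsilon* (the same resource twice) and *6. Zeta* (a resource without a name).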

## Imports

Let's do some imports first:

```py
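# Note: the star import shadows Python built-ins such as max();
# the code below relies on Spark's max.
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql import Window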
```

## Step 1: Let's join

The first thing I usually try is joining both data frames:

```py
df = (items_df
      .select("item_id", explode("resources").alias("resource_id"))
      .join(resources_df, "resource_id")
      .groupBy("item_id")
      .agg(collect_list(struct("*")).alias("resources"))
      .join(items_df.drop("resources"), "item_id")
      .orderBy("item_id")
     )
```

The result looks like this:

[table id=1 /]

There are many things wrong with this result, but that's fine. We'll address them step by step. The first thing that you'll notice is that we are missing 2 items: *3. Gamma* and *6. Zeta*. Let's get them back.

## Step 2: Left Outer Join

[There are many types of joins](https://www.diffen.com/difference/Inner_Join_vs_Outer_Join). Here we want to use a *left outer join*: it keeps every row from the left side and fills the *resource* columns with `null` values when no match is found. We also want to do an `explode_outer` instead of an `explode`, because `explode` drops rows whose array is empty. That preserves a record for *3. Gamma*.

```py
df = (items_df
      .select('item_id', explode_outer('resources').alias('resource_id'))
      .join(resources_df, "resource_id", "left_outer")
      .groupBy('item_id')
      .agg(collect_list(struct("*")).alias("resources"))
      .join(items_df.drop("resources"), 'item_id')
      .orderBy('item_id')
     )
```

This results in:

[table id=2 /]

Both records are back, but notice that we got an *item_id* field in our resources. Let's get rid of it.
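
Why did both records come back? The sketch below models the two changes in plain Python. It is not Spark code: `explode`, `explode_outer` and `names` are stand-ins written for illustration, assuming that `explode` skips empty arrays and that a left outer join fills unmatched rows with `None`.

```python
# Plain-Python model of explode vs. explode_outer; illustrative only.
def explode(rows):
    """One output row per array element; empty arrays produce no rows."""
    return [(item_id, rid) for item_id, rids in rows for rid in (rids or [])]


def explode_outer(rows):
    """Like explode, but an empty or missing array yields one (id, None) row."""
    return [(item_id, rid) for item_id, rids in rows for rid in (rids or [None])]


rows = [(2, [2, 3]), (3, []), (6, [6])]
names = {2: "Resource II", 3: "Resource III"}  # resource 6 does not exist

inner = explode(rows)        # item 3 is gone
outer = explode_outer(rows)  # item 3 survives with a None resource ID

# model of the left outer join: unmatched IDs keep their row, name is None
left_joined = [(item_id, rid, names.get(rid)) for item_id, rid in outer]
print(left_joined)
```

Item 3 only survives because of `explode_outer`, and item 6 only survives because of the left outer join.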

## Step 3: Trim the fat

To remove the field, I use a different way of joining the resources. When I join them with `alias("r")`, I can use that alias to construct a `struct` with only the resource information:

```py
df = (items_df
      .select("item_id", explode_outer("resources").alias("rid"))
      .join(resources_df.alias("r"), col("rid") == col("resource_id"), "left_outer")
      .groupBy('item_id')
      .agg(collect_list(struct("r.*")).alias("resources"))
      .join(items_df.drop("resources"), 'item_id')
      .orderBy('item_id')
     )
```

This will get rid of the *item_id* field:

[table id=3 /]

But... it also got rid of the information we had on *6. Zeta*. Now we have a `null` record, instead of only a *name* that was `null` (not sure if that is a bad thing). If we want to get the information back, we should add it to the structure we create during aggregation:

```py
df = (items_df
      .select("item_id", explode_outer("resources").alias("rid"))
      .join(resources_df.alias("r"), col("rid") == col("resource_id"), "left_outer")
      .groupBy('item_id')
      .agg(collect_list(struct("r.*", col("rid").alias("resource_id"))).alias("resources"))
      .join(items_df.drop("resources"), 'item_id')
      .orderBy('item_id')
     )
```

Which results in *6. Zeta* having the *resource_id* back:

[table id=5 /]

## Step 4: Empty array?

Now, when we look at *3. Gamma*, we see a major problem: the original resource array was empty, but now it contains a record. That's not okay. We can fix this by constructing the resource structure before doing the `groupBy` / `agg`, and by not creating it when the `rid` is `null`. Remember: we introduced those `null` values when we added `explode_outer`.

```py
df = (items_df
      .select('item_id', explode_outer('resources').alias('rid'))
      .join(resources_df.alias("r"), col("rid") == col("resource_id"), "left_outer")
      .withColumn("resource", 
                  when(col("rid").isNull(), None).otherwise(
                    struct(
                      "r.*", 
                      col("rid").alias("resource_id")
                    )))
      .groupBy('item_id')
      .agg(collect_list("resource").alias("resources"))
      .join(items_df.drop("resources"), 'item_id')
      .orderBy('item_id')
     )
```

Now *3. Gamma* has an empty array of resources:

[table id=4 /]
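
The fix works because `collect_list` leaves out `null` values: a `null` resource never reaches the array, while a structure with a `null` name does. Here is a plain-Python model of that behaviour. The helpers `make_resource` and `collect_list` are hypothetical stand-ins, not the Spark functions.

```python
# Plain-Python model of Step 4; illustrative only, not Spark code.
def make_resource(rid, name):
    """Mimic when(col("rid").isNull(), None).otherwise(struct(...))."""
    if rid is None:
        return None
    return {"resource_id": rid, "name": name}


def collect_list(values):
    """Mimic collect_list, which leaves out null values."""
    return [v for v in values if v is not None]


# "3. Gamma": explode_outer produced one row with rid = None
gamma = collect_list([make_resource(None, None)])

# "6. Zeta": rid = 6 exists in the item, but the join found no name
zeta = collect_list([make_resource(6, None)])

print(gamma, zeta)
```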

So, are we ready? It depends... do you think the *order* of the resources is important? Most records have the wrong order.

## Step 5: Get the position

First, we need to get the position of each resource. By doing a `posexplode_outer` we get a *col* and *pos* column that we can use:

```py
df = (items_df
      .select('item_id', posexplode_outer('resources'))
      .withColumnRenamed("col", "rid")
      .join(resources_df.alias("r"), col("rid") == col("resource_id"), "left_outer")
      .withColumn("resource", 
                  when(col("rid").isNull(), None).otherwise(
                    struct(
                      "r.*", 
                      col("rid").alias("resource_id"),
                      "pos"
                    )))
      .groupBy('item_id')
      .agg(collect_list("resource").alias("resources"))
      .join(items_df.drop("resources"), 'item_id')
      .orderBy('item_id')
     )
```

Here we see the *pos* field for each resource:

[table id=6 /]

Now we only need to sort the resources by that position.

## Step 6: Sorting with a Window

`collect_list` does not guarantee any order: it collects the values in whatever order the rows arrive after the shuffle caused by the `groupBy`. One way of controlling this is to use a [`Window`](http://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=window#pyspark.sql.Window) specification that takes care of the ordering within each partition.

```py
w = Window.partitionBy('item_id').orderBy('pos')

df = (items_df
      .select('item_id', posexplode_outer('resources'))
      .withColumnRenamed("col", "rid")
      .join(resources_df.alias("r"), col("rid") == col("resource_id"), "left_outer")
      .withColumn("resource", 
                  when(col("rid").isNull(), None).otherwise(
                    struct(
                      "r.*", 
                      col("rid").alias("resource_id")
                    )))
      .withColumn("resources", collect_list('resource').over(w))
      .groupBy('item_id')
      .agg(max('resources').alias('resources'))
      .join(items_df.drop("resources"), 'item_id')
      .orderBy('item_id')
     )
```

This results in the completion of our objective:

[table id=7 /]
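
The last step may look odd: why does `max` return the complete, ordered list? A window with an `orderBy` uses a growing frame by default (from the start of the partition up to the current row), so each row gets a longer prefix of the final list. The sketch below models this in plain Python. It assumes Spark compares arrays the way Python compares lists (element by element, with a prefix sorting before the longer list); `running_lists` is a hypothetical helper.

```python
# Plain-Python model of the windowed collect_list followed by max;
# illustrative only. This uses Python's built-in max, not Spark's.
def running_lists(values_in_order):
    """Mimic collect_list(...).over(w) on an ordered window: row i sees rows 0..i."""
    return [values_in_order[: i + 1] for i in range(len(values_in_order))]


delta = [4, 3, 2, 1]  # resource IDs of "4. Delta", already sorted by pos
prefixes = running_lists(delta)

# every prefix sorts before its extension, so max picks the complete list
longest = max(prefixes)
print(prefixes)
print(longest)
```

Because every row of an item holds a prefix of the same list, the largest one is always the full list in `pos` order.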

## Question: can it be functionized?

[A colleague of mine](https://www.linkedin.com/in/jari-koopman/) was wondering if it could be turned into a function. I'm not a big fan of functions, as they often obscure what's going on. But... it can be turned into a function quite easily:

```py
def replace_column_with_objects(df, key_field, column_field, foreign_df, foreign_key):

  w = Window.partitionBy(key_field).orderBy('pos')

  return (df
        # create a "linking" table 
        .select(key_field, posexplode_outer(column_field))
        # add foreign information
        .join(foreign_df.alias("f"), col("col") == col(foreign_key), "left_outer")
        # create a foreign structure
        .withColumn("x", 
                    when(col("col").isNull(), None).otherwise(
                      struct(
                        "f.*", 
                        col("col").alias(foreign_key)
                      )))
        # make an array of foreign structures
        .withColumn('x', collect_list('x').over(w))
        # group back to a single main record
        .groupBy(key_field)
        .agg(max('x').alias(column_field))
        # add the information of the main record without the replace column
        .join(df.drop(column_field), key_field)
       )
```

For readability you could call it with named parameters:

```py
result_df = replace_column_with_objects(
  df=items_df, 
  key_field="item_id",
  column_field="resources",
  foreign_df=resources_df,
  foreign_key="resource_id").orderBy('item_id')
```

## Final thoughts

Code tends to grow. Readability, especially when you chain steps, will suffer over time. It is not always clear what code does or why it has been constructed in a certain way. Comments may be your only way out.
