# Sorting an array of a complex data type in Spark

**Date:** 2019-09-05  
**Author:** Kees C. Bakker  
**Categories:** Databricks / Spark  
**Original:** https://keestalkstech.com/sorting-an-array-of-a-complex-data-type-in-spark/

![Sorting an array of a complex data type in Spark](https://keestalkstech.com/wp-content/uploads/2019/09/thomas-habr-wprOCzLIEYI-unsplash.jpg)

---

Today we'll be looking at sorting and reducing an array of a complex data type. I'm using Databricks to run Spark, but the code should work in any Spark environment. I'll use Spark SQL to show the steps, and I've kept the data as simple as possible; the approach should translate to more complex scenarios. I'll walk through the Spark SQL functions first, and [at the end there is a PySpark version of the code](https://keestalkstech.com/2019/09/sorting-an-array-of-a-complex-data-type-in-spark/#pyspark-please).

Big shout-out to [Jesse Bouwman](https://www.linkedin.com/in/jesse-bouwman-610b49a5/) for the collaboration.

## Sample data

I'll be working on the following sample:

```spark_output
+---------+---------+--------+
|sessionId|articleId|sequence|
+---------+---------+--------+
| 39582930|   467323|       2|
| 56438837|   765645|       3|
| 56438837|   484482|       1|
| 39582930|   948521|       4|
| 87466464|   485756|       1|
| 74755638|   984755|       2|
| 56438837|   128842|       4|
| 56438837|   475532|       2|
| 74755638|   354461|       1|
| 39582930|   127333|       3|
| 39582930|   484482|       1|
+---------+---------+--------+
```

You can recreate the data with the following Python code:

```py
from pyspark.sql import Row
from pyspark.sql.functions import array_sort, col, collect_list, struct

data_df = spark.createDataFrame([
  Row(sessionId=39582930, articleId=467323, sequence=2),
  Row(sessionId=56438837, articleId=765645, sequence=3),
  Row(sessionId=56438837, articleId=484482, sequence=1),
  Row(sessionId=39582930, articleId=948521, sequence=4),
  Row(sessionId=87466464, articleId=485756, sequence=1),
  Row(sessionId=74755638, articleId=984755, sequence=2),
  Row(sessionId=56438837, articleId=128842, sequence=4),
  Row(sessionId=56438837, articleId=475532, sequence=2),
  Row(sessionId=74755638, articleId=354461, sequence=1),
  Row(sessionId=39582930, articleId=127333, sequence=3),
  Row(sessionId=39582930, articleId=484482, sequence=1),
])

# Row sorts its keyword arguments alphabetically (in Spark < 3.0),
# so we reselect the columns to restore the original order
views_df = data_df.select('sessionId', 'articleId', 'sequence')
views_df.createOrReplaceTempView('views')
```

*Objective: create a list of sessions and articles that were viewed during the session. The articles need to be in the right order.*

## Group by session

Let's group the sessions and articles that have been viewed during that session. We're using `collect_list` to group the articles into a single array.

```sql
SELECT   sessionId, 
         collect_list((articleId, sequence)) AS articles 
FROM     views 
GROUP BY sessionId
```

This gives the following result:

```spark_output
+---------+----------------------------------------------------+
|sessionId|articles                                            |
+---------+----------------------------------------------------+
|74755638 |[[984755, 2], [354461, 1]]                          |
|87466464 |[[485756, 1]]                                       |
|39582930 |[[467323, 2], [948521, 4], [127333, 3], [484482, 1]]|
|56438837 |[[765645, 3], [484482, 1], [128842, 4], [475532, 2]]|
+---------+----------------------------------------------------+
```

## Sort the array

Now we'll sort the array with `array_sort`.

```sql
SELECT   sessionId, 
         array_sort(collect_list((articleId, sequence))) AS articles 
FROM     views 
GROUP BY sessionId
```

That will give this result:

```spark_output
+---------+----------------------------------------------------+
|sessionId|articles                                            |
+---------+----------------------------------------------------+
|74755638 |[[354461, 1], [984755, 2]]                          |
|87466464 |[[485756, 1]]                                       |
|39582930 |[[127333, 3], [467323, 2], [484482, 1], [948521, 4]]|
|56438837 |[[128842, 4], [475532, 2], [484482, 1], [765645, 3]]|
+---------+----------------------------------------------------+
```

Not all records are sorted correctly. Why? `array_sort` compares structs field by field, and `articleId` is the first field in the struct — so the arrays are sorted by article id instead of by sequence. When we switch the `sequence` and `articleId` fields, the sorting is correct.

```sql
SELECT   sessionId, 
         array_sort(collect_list((sequence, articleId))) AS articles 
FROM     views 
GROUP BY sessionId
```

This results in:

```spark_output
+---------+----------------------------------------------------+
|sessionId|articles                                            |
+---------+----------------------------------------------------+
|74755638 |[[1, 354461], [2, 984755]]                          |
|87466464 |[[1, 485756]]                                       |
|39582930 |[[1, 484482], [2, 467323], [3, 127333], [4, 948521]]|
|56438837 |[[1, 484482], [2, 475532], [3, 765645], [4, 128842]]|
+---------+----------------------------------------------------+
```

## Transform the articles

The only thing left to do is ditch the sequence numbers by [transforming](https://spark.apache.org/docs/latest/api/sql/index.html#transform) the array with a lambda expression:

```sql
SELECT   sessionId, 
         TRANSFORM(
              array_sort(collect_list((sequence, articleId))),
              a -> a.articleId) AS articles 
FROM     views 
GROUP BY sessionId
```

*Note: [higher-order functions](https://docs.databricks.com/_static/notebooks/apache-spark-2.4-functions.html) like `transform` were introduced in Spark [2.4](https://databricks.com/blog/2018/11/08/introducing-apache-spark-2-4.html).* The query now gives the end result we're looking for:

```spark_output
+---------+--------------------------------+
|sessionId|articles                        |
+---------+--------------------------------+
|74755638 |[354461, 984755]                |
|87466464 |[485756]                        |
|39582930 |[484482, 467323, 127333, 948521]|
|56438837 |[484482, 475532, 765645, 128842]|
+---------+--------------------------------+
```

## PySpark, please?

The Python code looks almost the same, with some slight differences:

```py
(views_df
  .groupBy('sessionId')
  .agg(
    # sequence comes first in the struct, so array_sort orders by sequence
    array_sort(
      collect_list(struct('sequence', 'articleId'))
    ).alias('articles')
  )
  # extract the articleId field from each struct in the sorted array
  .select('sessionId', col('articles.articleId').alias('articles'))
).show(20, False)
```
