Sorting an array of a complex data type in Spark

Today we'll be looking at sorting and reducing an array of a complex data type. I'm using Databricks, but the code should run on any recent Spark environment. I've kept the data as simple as possible; the same approach applies to more complex scenarios. I'll walk through the steps in Spark SQL, and at the end there's a PySpark version of the same code.

Big shout-out to Jesse Bouwman for the collaboration.

Sample data

I'll be working on the following sample:

+---------+---------+--------+
|sessionId|articleId|sequence|
+---------+---------+--------+
| 39582930|   467323|       2|
| 56438837|   765645|       3|
| 56438837|   484482|       1|
| 39582930|   948521|       4|
| 87466464|   485756|       1|
| 74755638|   984755|       2|
| 56438837|   128842|       4|
| 56438837|   475532|       2|
| 74755638|   354461|       1|
| 39582930|   127333|       3|
| 39582930|   484482|       1|
+---------+---------+--------+

You can recreate the data with this Python code:

from pyspark.sql import Row
from pyspark.sql.functions import array_sort, col, collect_list, struct

data_df = spark.createDataFrame([
  Row(sessionId=39582930, articleId=467323, sequence=2),
  Row(sessionId=56438837, articleId=765645, sequence=3),
  Row(sessionId=56438837, articleId=484482, sequence=1),
  Row(sessionId=39582930, articleId=948521, sequence=4),
  Row(sessionId=87466464, articleId=485756, sequence=1),
  Row(sessionId=74755638, articleId=984755, sequence=2),
  Row(sessionId=56438837, articleId=128842, sequence=4),
  Row(sessionId=56438837, articleId=475532, sequence=2),
  Row(sessionId=74755638, articleId=354461, sequence=1),
  Row(sessionId=39582930, articleId=127333, sequence=3),
  Row(sessionId=39582930, articleId=484482, sequence=1),
])

# Row sorts keyword fields alphabetically (in Spark 2.x), so reorder the columns explicitly
views_df = data_df.select('sessionId', 'articleId', 'sequence')
views_df.createOrReplaceTempView('views')

Objective: create a list of sessions and articles that were viewed during the session. The articles need to be in the right order.

Group by session

Let's group the articles that were viewed during each session. We're using collect_list to collect the (articleId, sequence) pairs into a single array per session.

SELECT   sessionId, 
         collect_list((articleId, sequence)) AS articles 
FROM     views 
GROUP BY sessionId

This gives the following result:

+---------+----------------------------------------------------+
|sessionId|articles                                            |
+---------+----------------------------------------------------+
|74755638 |[[984755, 2], [354461, 1]]                          |
|87466464 |[[485756, 1]]                                       |
|39582930 |[[467323, 2], [948521, 4], [127333, 3], [484482, 1]]|
|56438837 |[[765645, 3], [484482, 1], [128842, 4], [475532, 2]]|
+---------+----------------------------------------------------+
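To see what collect_list produces per session, the grouping step can be sketched in plain Python (no Spark needed; the tuples stand in for the structs):

```python
from collections import defaultdict

# the (sessionId, articleId, sequence) rows from the sample data
rows = [
    (39582930, 467323, 2), (56438837, 765645, 3), (56438837, 484482, 1),
    (39582930, 948521, 4), (87466464, 485756, 1), (74755638, 984755, 2),
    (56438837, 128842, 4), (56438837, 475532, 2), (74755638, 354461, 1),
    (39582930, 127333, 3), (39582930, 484482, 1),
]

# collect_list analogue: gather (articleId, sequence) pairs per sessionId
grouped = defaultdict(list)
for session_id, article_id, sequence in rows:
    grouped[session_id].append((article_id, sequence))

print(grouped[74755638])  # [(984755, 2), (354461, 1)]
```

Note that, like collect_list, this preserves the encounter order of the rows, which is why the arrays are not yet sorted.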

Sort the array

Now we'll sort the array using the array_sort function.

SELECT   sessionId, 
         array_sort(collect_list((articleId, sequence))) AS articles 
FROM     views 
GROUP BY sessionId

That will give this result:

+---------+----------------------------------------------------+
|sessionId|articles                                            |
+---------+----------------------------------------------------+
|74755638 |[[354461, 1], [984755, 2]]                          |
|87466464 |[[485756, 1]]                                       |
|39582930 |[[127333, 3], [467323, 2], [484482, 1], [948521, 4]]|
|56438837 |[[128842, 4], [475532, 2], [484482, 1], [765645, 3]]|
+---------+----------------------------------------------------+

Not all arrays are sorted correctly. Why? array_sort sorts structs by their first field, which here is articleId. When we swap the sequence and articleId fields, the sort order will be correct.
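This mirrors how Python compares tuples: comparison starts at the first element. A quick plain-Python illustration (not Spark, just the same ordering rule) using the pairs from session 39582930:

```python
# structs sort like tuples: the first field is compared first
pairs = [(467323, 2), (948521, 4), (127333, 3), (484482, 1)]

# sorting (articleId, sequence) pairs orders by articleId, not by sequence
by_article = sorted(pairs)

# swapping the fields to (sequence, articleId) gives the order we want
by_sequence = sorted((seq, art) for art, seq in pairs)

print(by_article)   # [(127333, 3), (467323, 2), (484482, 1), (948521, 4)]
print(by_sequence)  # [(1, 484482), (2, 467323), (3, 127333), (4, 948521)]
```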

SELECT   sessionId, 
         array_sort(collect_list((sequence, articleId))) AS articles 
FROM     views 
GROUP BY sessionId

This results in:

+---------+----------------------------------------------------+
|sessionId|articles                                            |
+---------+----------------------------------------------------+
|74755638 |[[1, 354461], [2, 984755]]                          |
|87466464 |[[1, 485756]]                                       |
|39582930 |[[1, 484482], [2, 467323], [3, 127333], [4, 948521]]|
|56438837 |[[1, 484482], [2, 475532], [3, 765645], [4, 128842]]|
+---------+----------------------------------------------------+

Transform the articles

The only thing left to do is ditch the sequence numbers by transforming the array with a lambda expression:

SELECT   sessionId, 
         TRANSFORM(
              array_sort(collect_list((sequence, articleId))),
              a -> a.articleId) AS articles 
FROM     views 
GROUP BY sessionId

Note: higher-order functions like transform were introduced in Spark 2.4. The query now produces the end result we're looking for:

+---------+--------------------------------+
|sessionId|articles                        |
+---------+--------------------------------+
|74755638 |[354461, 984755]                |
|87466464 |[485756]                        |
|39582930 |[484482, 467323, 127333, 948521]|
|56438837 |[484482, 475532, 765645, 128842]|
+---------+--------------------------------+
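Putting the three steps together (group, sort, transform), the whole pipeline can be mimicked in plain Python as a sanity check, again using tuples in place of structs:

```python
from collections import defaultdict

rows = [
    (39582930, 467323, 2), (56438837, 765645, 3), (56438837, 484482, 1),
    (39582930, 948521, 4), (87466464, 485756, 1), (74755638, 984755, 2),
    (56438837, 128842, 4), (56438837, 475532, 2), (74755638, 354461, 1),
    (39582930, 127333, 3), (39582930, 484482, 1),
]

# collect_list analogue, with sequence as the first field so sorting uses it
sessions = defaultdict(list)
for session_id, article_id, sequence in rows:
    sessions[session_id].append((sequence, article_id))

# array_sort + TRANSFORM(..., a -> a.articleId): sort, then keep only the articleId
articles = {s: [art for _, art in sorted(pairs)] for s, pairs in sessions.items()}

print(articles[39582930])  # [484482, 467323, 127333, 948521]
```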

PySpark, please?

The PySpark code looks almost the same, with some slight differences:

(views_df
  .groupBy('sessionId')
  .agg(
    array_sort(
      collect_list(struct('sequence', 'articleId'))  # sequence first, so array_sort uses it
    ).alias('articles')
  )
  # extract the articleId field from every struct in the array
  .select('sessionId', col('articles.articleId').alias('articles'))
).show(20, False)
