Caching resized images on S3 with Databricks

When you are training a machine learning image classification model, you often need to resize the images in your dataset to a smaller size. When you retrain your model on new data, you resize the images all over again. This slows you down, especially when your source images are large. Now imagine that you have multiple models working on the same images: the problem gets even bigger.

In this blog post I'll share how S3 can be used to cache the resized images.

Dataset

I'll be working on a set similar to this. It has a label and an S3 key:

+-------------+---------------------+
| label       | image_key           |
+-------------+---------------------+
| home-living | 2119/16267364_eb_02 | 
| fashion     | 1682/16199538_mb_02 | 
| fashion     | 3405/16207450_mb_02 | 
| home-living | 1804/741901_eb_01   | 
| other       | 2023/16310771_eb_06 | 
| other       | 1521/16297163_eb_05 | 
| other       | 1598/16299306_eb_06 | 
| other       | 9557/16318314_eb_03 | 
| other       | 1039/16216146_eb_03 | 
| fashion     | 3385/16348790_pb_01 | 
| other       | 1162/16354410_eb_01 | 
| fashion     | 2814/16206209_pb_01 | 
| fashion     | 1162/16328003_eb_04 | 
| other       | 7972/16079055_eb_04 | 
+-------------+---------------------+

Import(ant) things first

Let's first take care of our imports:

from pyspark.sql.functions import *
from pyspark.sql.types import *
import os, pathlib, pyspark, sys
from PIL import Image

I have mounted two S3 buckets, so I can use the Databricks File System (DBFS) to access them. I'll use the assets_bucket to read the source images from and the ml_bucket to store the resized images in.

assets_bucket='assets-prod'
ml_bucket='dp-visual'

destination_path = '/dbfs/mnt/{}/resized'.format(ml_bucket)
source_path = '/dbfs/mnt/{}'.format(assets_bucket)
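To make the mapping between keys and files concrete, here is a small sketch of how an image_key from the dataset resolves to a source file and to a cached destination file, using the same format strings as the persist UDF. The helpers `destination_key_for`, `source_file_for` and `destination_file_for` are hypothetical names introduced only for illustration; they are not part of the original code.

```python
# Mount points as configured above (the assets-prod and dp-visual buckets).
assets_bucket = 'assets-prod'
ml_bucket = 'dp-visual'
destination_path = '/dbfs/mnt/{}/resized'.format(ml_bucket)
source_path = '/dbfs/mnt/{}'.format(assets_bucket)


def destination_key_for(key, image_size=224, image_type='PNG'):
    """Hypothetical helper: the relative key under which a resized image is cached."""
    return "{}x{}/{}.{}".format(image_size, image_size, key, image_type.lower())


def source_file_for(key):
    """Hypothetical helper: the DBFS path of the original image."""
    return "{}/{}".format(source_path, key)


def destination_file_for(key, image_size=224, image_type='PNG'):
    """Hypothetical helper: the DBFS path of the cached, resized image."""
    return "{}/{}".format(destination_path, destination_key_for(key, image_size, image_type))


key = "2119/16267364_eb_02"
print(source_file_for(key))            # /dbfs/mnt/assets-prod/2119/16267364_eb_02
print(destination_file_for(key, 122))  # /dbfs/mnt/dp-visual/resized/122x122/2119/16267364_eb_02.png
```

Because the size and image type are part of the destination key, every resolution lives under its own prefix, so caches for different sizes never overwrite each other.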

Persist

We'll use a user-defined function (UDF) to resize and cache the images. The advantage of a UDF is that it returns a value for each record and is executed in parallel across the Spark cluster.

The algorithm is:

  1. Resolve the destination key name, based on the key, image size and type.
  2. If the destination key does not exist:
    1. Get the original
    2. Resize
    3. Persist
  3. Return the destination key.
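The algorithm above can be tried without Spark or S3, as plain Python against local directories. This is a minimal sketch, not the UDF itself: `resize_and_cache` is a hypothetical name, and the `resize` argument is an assumed stand-in for the Pillow resize step, so the caching logic can be exercised in isolation.

```python
import os
import pathlib
import tempfile


def resize_and_cache(key, source_root, destination_root, resize, image_size=224, image_type='PNG'):
    """Sketch of the persist algorithm on a plain filesystem.

    `resize` is a hypothetical callable (bytes, size, type) -> bytes standing in
    for the real image resize; it is only invoked on a cache miss.
    """
    # 1. Resolve the destination key from the key, image size and type.
    destination_key = "{}x{}/{}.{}".format(image_size, image_size, key, image_type.lower())
    destination_file = os.path.join(destination_root, destination_key)

    # 2. Cache hit: the resized image already exists, so skip all work.
    if os.path.exists(destination_file):
        return destination_key

    source_file = os.path.join(source_root, key)
    if not os.path.exists(source_file):
        return None  # the original asset is gone

    # 2.1 Get the original.
    with open(source_file, "rb") as f:
        original = f.read()

    # 2.2 Resize.
    resized = resize(original, image_size, image_type)

    # 2.3 Persist, creating intermediate directories as needed.
    pathlib.Path(destination_file).parent.mkdir(parents=True, exist_ok=True)
    with open(destination_file, "wb") as f:
        f.write(resized)

    # 3. Return the destination key.
    return destination_key


calls = []


def fake_resize(data, size, image_type):
    calls.append(size)
    return data[:size]  # placeholder: truncation instead of real image resizing


with tempfile.TemporaryDirectory() as src, tempfile.TemporaryDirectory() as dst:
    pathlib.Path(src, "2119").mkdir()
    pathlib.Path(src, "2119", "16267364_eb_02").write_bytes(b"x" * 1000)
    print(resize_and_cache("2119/16267364_eb_02", src, dst, fake_resize))  # 224x224/2119/16267364_eb_02.png
    print(resize_and_cache("2119/16267364_eb_02", src, dst, fake_resize))  # same key, served from cache
    print(len(calls))  # 1
```

The second call returns immediately because the destination file already exists; that existence check is what lets repeated training runs, or several models, share one set of resized images.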

Expressed in Python, it looks like this:

@udf("string")
def persist(
    key,                  # the S3 key of the source image; also part of the destination key
    image_size=224,       # width and height of the resized image, in pixels
    image_type='PNG',     # the image format to generate
):
    destination_key = "{}x{}/{}.{}".format(image_size, image_size, key, image_type.lower())
    destination_file_path = "{}/{}".format(destination_path, destination_key)

    # already resized and cached? then there is nothing to do
    if os.path.exists(destination_file_path):
        return destination_key

    source_file = "{}/{}".format(source_path, key)

    # is the original asset still there?
    if not os.path.exists(source_file):
        return None

    # read and resize it; Image.LANCZOS is the filter the old Image.ANTIALIAS
    # alias pointed to (ANTIALIAS was removed in Pillow 10)
    with Image.open(source_file) as image:
        resized_image = image.convert('RGB').resize((image_size, image_size), Image.LANCZOS)

    # ensure the destination directory exists, then save
    destination_directory = os.path.dirname(destination_file_path)
    pathlib.Path(destination_directory).mkdir(parents=True, exist_ok=True)

    resized_image.save(destination_file_path, image_type)

    return destination_key

Using it

The dataset is stored in images_df. To store multiple resolutions on S3, we add one column per size:

persisted_images_df = (
    images_df
    # .sample(fraction=0.1)
    # .limit(20)
    .select("label", "image_key")
    .distinct()
    .withColumn('size_122', persist(col('image_key'), lit(122)))
    .withColumn('size_244', persist(col('image_key'), lit(244)))
)

Spark evaluates lazily, so the images are only resized and cached once an action runs. Displaying the result gives:

+-------------+---------------------+-----------------------+-----------------------+
|       label |           image_key |              size_122 |              size_244 |
+-------------+---------------------+-----------------------+-----------------------+
|       other | 6857/16330849_pb_01 | 122x122/6857/1633 ... | 244x244/6857/1633 ... |
|       other | 4345/16322466_eb_01 | 122x122/4345/1632 ... | 244x244/4345/1632 ... |
| home-living |   2096/308693_eb_06 | 122x122/2096/3086 ... | 244x244/2096/3086 ... |
| home-living | 3075/16127396_pb_01 | 122x122/3075/1612 ... | 244x244/3075/1612 ... |
| home-living | 8092/16028030_eb_02 | 122x122/8092/1602 ... | 244x244/8092/1602 ... |
|     fashion | 2043/16327332_eb_04 | 122x122/2043/1632 ... | 244x244/2043/1632 ... |
|     fashion | 1104/16351166_eb_01 | 122x122/1104/1635 ... | 244x244/1104/1635 ... |
|     fashion | 1273/16271703_eb_04 | 122x122/1273/1627 ... | 244x244/1273/1627 ... |
|     fashion | 1226/16288579_eb_02 | 122x122/1226/1628 ... | 244x244/1226/1628 ... |
|       other |   3644/954572_pb_01 | 122x122/3644/9545 ... | 244x244/3644/9545 ... |
|       other |   2367/311719_pb_01 | 122x122/2367/3117 ... | 244x244/2367/3117 ... |
|     fashion |   6074/357979_eb_04 | 122x122/6074/3579 ... | 244x244/6074/3579 ... |
|     fashion | 2145/16358490_eb_04 | 122x122/2145/1635 ... | 244x244/2145/1635 ... |
|       other |   3519/241094_pb_01 | 122x122/3519/2410 ... | 244x244/3519/2410 ... |
|       other | 1209/16315365_eb_04 | 122x122/1209/1631 ... | 244x244/1209/1631 ... |
|       other | 7313/16173709_eb_01 | 122x122/7313/1617 ... | 244x244/7313/1617 ... |
|       other | 1130/16170786_eb_06 | 122x122/1130/1617 ... | 244x244/1130/1617 ... |
|       other | 1707/16268527_eb_06 | 122x122/1707/1626 ... | 244x244/1707/1626 ... |
|       other | 1670/16275881_pb_01 | 122x122/1670/1627 ... | 244x244/1670/1627 ... |
|       other |   1689/628065_eb_03 | 122x122/1689/6280 ... | 244x244/1689/6280 ... |
+-------------+---------------------+-----------------------+-----------------------+

Final thoughts

A simple UDF is all you need to persist the resized images on S3. For more performance, you might combine this with converting the images to byte arrays. Swapping Pillow (the PIL package used above) for Pillow-SIMD, a faster drop-in fork, might also help.

