When training a machine learning image classification model, you often need to resize the images in your dataset to a smaller resolution. When you retrain the model on new data, you resize them all over again. This slows you down, especially when the source images are large. And when multiple models work on the same images, the problem gets even bigger.
In this blog I'll share how S3 can be used to cache the resized images.
Dataset
I'll be working with a dataset similar to this one. It has a label and an S3 key:
+-------------+---------------------+
| label       | image_key           |
+-------------+---------------------+
| home-living | 2119/16267364_eb_02 |
| fashion | 1682/16199538_mb_02 |
| fashion | 3405/16207450_mb_02 |
| home-living | 1804/741901_eb_01 |
| other | 2023/16310771_eb_06 |
| other | 1521/16297163_eb_05 |
| other | 1598/16299306_eb_06 |
| other | 9557/16318314_eb_03 |
| other | 1039/16216146_eb_03 |
| fashion | 3385/16348790_pb_01 |
| other | 1162/16354410_eb_01 |
| fashion | 2814/16206209_pb_01 |
| fashion | 1162/16328003_eb_04 |
| other | 7972/16079055_eb_04 |
+-------------+---------------------+
Import(ant) things first
Let's first take care of our imports:
from pyspark.sql.functions import *
from pyspark.sql.types import *
import io, os, pathlib, pyspark, sys
from PIL import Image
I have mounted two S3 buckets, so I can use the Databricks File System (DBFS) to access them. I'll be using the assets_bucket to read source images from and the ml_bucket to store the resized images in.
assets_bucket='assets-prod'
ml_bucket='dp-visual'
destination_path = '/dbfs/mnt/{}/resized'.format(ml_bucket)
source_path = '/dbfs/mnt/{}'.format(assets_bucket)
Persist
We will use a UDF to resize & cache. The advantage of a UDF is that it returns a value for every record and is executed in parallel on the Spark cluster.
The algorithm is:
- Resolve the destination key name, based on the key, image size and type.
- If the destination key does not exist:
- Get the original
- Resize
- Persist
- Return the destination key.
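The first step, resolving the destination key, is easy to check in isolation (a minimal sketch; destination_key_for is a name I made up for illustration):

```python
def destination_key_for(key, image_size=224, image_type='PNG'):
    # cache layout: '<size>x<size>/<original key>.<extension>'
    return "{}x{}/{}.{}".format(image_size, image_size, key, image_type.lower())

print(destination_key_for("2119/16267364_eb_02"))  # 224x224/2119/16267364_eb_02.png
```

Because the size and type are baked into the key, different resolutions of the same image never collide in the cache.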
Expressed in Python code, it looks like this:
@udf("string")
def persist(
    key,               # the S3 key, will be part of the destination key
    image_size=224,    # the resized image size
    image_type='PNG',  # the type of image to generate
):
    destination_key = "{}x{}/{}.{}".format(image_size, image_size, key, image_type.lower())
    destination_file = "{}/{}".format(destination_path, destination_key)
    # already converted?
    if os.path.exists(destination_file):
        return destination_key
    source_file = "{}/{}".format(source_path, key)
    # is the asset still there?
    if not os.path.exists(source_file):
        return None
    # read it
    with open(source_file, "rb") as f:
        image_bytes = bytearray(f.read())
    # resize it (LANCZOS is the filter formerly named ANTIALIAS)
    image = Image.open(io.BytesIO(image_bytes))
    converted_image = image.convert('RGB')
    resized_image = converted_image.resize((image_size, image_size), Image.LANCZOS)
    # ensure the destination directory exists & save
    destination_directory = os.path.dirname(destination_file)
    pathlib.Path(destination_directory).mkdir(parents=True, exist_ok=True)
    resized_image.save(destination_file, image_type)
    return destination_key
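The resize-and-persist core can be exercised locally before wiring it into Spark (a sketch, using a temporary directory in place of the DBFS mounts; resize_and_persist is a hypothetical stand-alone variant of the UDF body):

```python
import io, os, pathlib, tempfile
from PIL import Image

def resize_and_persist(source_file, destination_file, image_size=224, image_type='PNG'):
    # mirrors the UDF body, minus the key resolution and existence checks
    with open(source_file, "rb") as f:
        image = Image.open(io.BytesIO(f.read())).convert('RGB')
    resized = image.resize((image_size, image_size), Image.LANCZOS)
    pathlib.Path(os.path.dirname(destination_file)).mkdir(parents=True, exist_ok=True)
    resized.save(destination_file, image_type)

with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "source.png")
    Image.new('RGB', (640, 480), 'white').save(src, 'PNG')
    dst = os.path.join(tmp, "224x224", "source.png")
    resize_and_persist(src, dst)
    print(Image.open(dst).size)  # (224, 224)
```

On DBFS the mkdir and save behave the same way; only the mount paths differ.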
Using it
The dataset is stored in images_df. Let's say we want to store multiple resolutions on S3; then we execute the following code:
persisted_images_df = (
    images_df
    # .sample(fraction=0.1)
    # .limit(20)
    .select("label", "image_key")
    .distinct()
    .withColumn('size_122', persist(col('image_key'), lit(122)))
    .withColumn('size_244', persist(col('image_key'), lit(244)))
)
This results in:
+-------------+---------------------+-----------------------+-----------------------+
| label | image_key | size_122 | size_244 |
+-------------+---------------------+-----------------------+-----------------------+
| other | 6857/16330849_pb_01 | 122x122/6857/1633 ... | 244x244/6857/1633 ... |
| other | 4345/16322466_eb_01 | 122x122/4345/1632 ... | 244x244/4345/1632 ... |
| home-living | 2096/308693_eb_06 | 122x122/2096/3086 ... | 244x244/2096/3086 ... |
| home-living | 3075/16127396_pb_01 | 122x122/3075/1612 ... | 244x244/3075/1612 ... |
| home-living | 8092/16028030_eb_02 | 122x122/8092/1602 ... | 244x244/8092/1602 ... |
| fashion | 2043/16327332_eb_04 | 122x122/2043/1632 ... | 244x244/2043/1632 ... |
| fashion | 1104/16351166_eb_01 | 122x122/1104/1635 ... | 244x244/1104/1635 ... |
| fashion | 1273/16271703_eb_04 | 122x122/1273/1627 ... | 244x244/1273/1627 ... |
| fashion | 1226/16288579_eb_02 | 122x122/1226/1628 ... | 244x244/1226/1628 ... |
| other | 3644/954572_pb_01 | 122x122/3644/9545 ... | 244x244/3644/9545 ... |
| other | 2367/311719_pb_01 | 122x122/2367/3117 ... | 244x244/2367/3117 ... |
| fashion | 6074/357979_eb_04 | 122x122/6074/3579 ... | 244x244/6074/3579 ... |
| fashion | 2145/16358490_eb_04 | 122x122/2145/1635 ... | 244x244/2145/1635 ... |
| other | 3519/241094_pb_01 | 122x122/3519/2410 ... | 244x244/3519/2410 ... |
| other | 1209/16315365_eb_04 | 122x122/1209/1631 ... | 244x244/1209/1631 ... |
| other | 7313/16173709_eb_01 | 122x122/7313/1617 ... | 244x244/7313/1617 ... |
| other | 1130/16170786_eb_06 | 122x122/1130/1617 ... | 244x244/1130/1617 ... |
| other | 1707/16268527_eb_06 | 122x122/1707/1626 ... | 244x244/1707/1626 ... |
| other | 1670/16275881_pb_01 | 122x122/1670/1627 ... | 244x244/1670/1627 ... |
| other | 1689/628065_eb_03 | 122x122/1689/6280 ... | 244x244/1689/6280 ... |
+-------------+---------------------+-----------------------+-----------------------+
Final thoughts
Just a simple UDF is enough to persist resized images on S3. You might want to combine this with returning the image as a byte array for more performance. Swapping PIL for Pillow-SIMD might also help.
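A sketch of the byte-array idea (image_to_bytes is a hypothetical helper; a UDF returning it would declare a "binary" return type instead of "string"):

```python
import io
from PIL import Image

def image_to_bytes(image, image_type='PNG'):
    # serialize a PIL image into raw bytes, suitable for a binary column
    buffer = io.BytesIO()
    image.save(buffer, image_type)
    return buffer.getvalue()

png_bytes = image_to_bytes(Image.new('RGB', (2, 2), 'red'))
print(png_bytes[:4])  # b'\x89PNG'
```

Returning the bytes directly saves the training job a second round trip to S3, at the cost of a larger dataframe.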
Further reading
While working on this topic I found some excellent sources: