Rebuilding our next-gen ML Platform with the best of Spark and Tensorflow

Rue Gilt Groupe is a fashion eCommerce company located in Boston, MA, that has 50M+ members and daily flash sales on millions of products. Our Data Science team is a tight-knit group of Data Scientists and Machine Learning Engineers who work full-stack on cloud-native architectures to deliver DS and ML services, heavily utilizing Apache Spark and AWS.

This post focuses on recent updates to one of our stacks built for big data applications, which add support for running the latest deep learning based algorithms and models. This architecture gives us the flexibility to pick the right framework at each step of the Machine Learning lifecycle and unlocks scalable deep learning pipelines with minimal MLOps code. It also lets us move to a different MLOps platform later without major changes to our ML code.

For businesses like ours, fast prototyping and quick experimentation are key to building completely new experiences efficiently and iteratively. We always prefer to have tangible results before putting more resources into a project. This architecture gives us that capability and lets us spend more time on research, build models, test quickly, and iterate rapidly.

We hope this article helps folks who are on the fence about moving to Deep Learning based algorithms, and that it addresses some of the common concerns and questions that come up when taking on such a big project.

Why do we need Deep Learning?

Why distributed Deep Learning?

Most clickstream datasets, especially those used to train recommender systems, contain millions (even billions) of rows, which makes training on a single instance difficult (unless you are ready to wait days for a single epoch to complete). Also, in use cases like recommenders, it is hard to learn useful models from a smaller dataset, because drawing a sample that properly represents the population is itself difficult.

Even if you don't agree with any of the above, wouldn’t it be nice to have multiple cheap GPUs, or even a cluster of nodes each with multiple GPUs, doing the training? Granted, this comes with its own issues and complexities, which we want to address in this article to make the transition as easy as possible.

So let’s get started!


Where we are now

current stack and technologies used at different ML stages

With that out of the way, I'll quickly explain why both Spark and Tensorflow are great ecosystems and why their marriage is going to be a happily ever after one!

Spark Awesomeness

Data Pipelines with an easy SQL interface

In short, Spark allows engineers to write a UDF in the language they prefer (Java, Python, Scala, or R) and have Data Analysts and Data Scientists call it through the SQL API.

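# in python
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
def say_hello(name: str) -> str:
    return f"Hello {name}"
# register the Python function so it can be called from SQL (returns a string by default)
spark.udf.register("say_hello", say_hello)

-- in sql
SELECT say_hello('Bob')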

Pandas UDFs

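from typing import Iterator
import pandas as pd
from pyspark.sql.functions import pandas_udf
@pandas_udf("string")
def simple_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    return (x + "1" for x in iterator)  # vectorized: append "1" to every value, one batch at a time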

I point this feature out because we used it to do batch inference on Spark: it let us load our deep learning model inside a pandas UDF and run inference on whole batches of the dataset much faster on a CPU Spark cluster.
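A minimal sketch of that batch-inference pattern, assuming a Series-to-Series iterator pandas UDF (Spark 3.0+) and a model that maps a 1-D NumPy array of features to a 1-D array of scores. `batch_predict`, `load_my_model`, and the `feature`/`score` column names are hypothetical; the point is that the expensive model load happens once per iterator of batches, while scoring is vectorized over each batch.

```python
from typing import Callable, Iterator

import numpy as np
import pandas as pd

Model = Callable[[np.ndarray], np.ndarray]  # hypothetical: a model is any array -> array callable


def batch_predict(batches: Iterator[pd.Series],
                  load_model: Callable[[], Model]) -> Iterator[pd.Series]:
    """Core of an iterator-of-Series pandas UDF used for batch inference."""
    model = load_model()  # expensive step: runs once per iterator, not once per row
    for batch in batches:
        # score the whole Arrow batch at once as a NumPy array
        yield pd.Series(model(batch.to_numpy()), index=batch.index)


# Hypothetical wiring into Spark (needs pyspark and a real loader, not run here):
# from pyspark.sql.functions import pandas_udf
#
# @pandas_udf("double")
# def predict_udf(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
#     yield from batch_predict(batches, load_my_model)  # load_my_model is hypothetical
#
# scored_df = df.withColumn("score", predict_udf("feature"))
```

Keeping the loop in a plain Python function like this also makes it easy to unit test the inference logic without starting a Spark session.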

Pandas UDFs are a much bigger topic than this article can cover; read more about them here.

Spark ML

Spark ML also supports integration with XGBoost, scikit-learn, and most other popular libraries.

Tensorflow Awesomeness

Now let's quickly talk about some of the Tensorflow tools and libraries that were of interest to us.

TF Data

import tensorflow as tf
dataset = tf.data.TFRecordDataset(filenames)  # filenames: a list of TFRecord file paths

Another benefit of using tf.data is that, when running in distributed mode, a one-line code change is enough to make your dataset work across multiple workers.

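import tensorflow as tf
# The two main shard policies are DATA and FILE
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
dataset = dataset.with_options(options)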
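To make the difference between the two policies concrete, here is a pure-Python model of which records a single worker ends up with. It illustrates the semantics only and is not how tf.data implements sharding: with FILE each worker opens only its share of the files, while with DATA every worker reads everything and discards the elements that belong to other workers. The file names and record values are made up.

```python
from typing import Dict, List


def simulate_auto_shard(files: Dict[str, List[int]], num_workers: int,
                        worker_index: int, policy: str) -> List[int]:
    """Pure-Python model of the records one worker sees under each shard policy."""
    names = list(files)  # file order as passed to the dataset
    if policy == "FILE":
        # each worker opens only every num_workers-th file and reads all of its records
        mine = names[worker_index::num_workers]
        return [r for name in mine for r in files[name]]
    if policy == "DATA":
        # every worker reads every file, then keeps every num_workers-th record
        everything = [r for name in names for r in files[name]]
        return everything[worker_index::num_workers]
    raise ValueError(f"unknown policy: {policy}")


files = {"part-0": [0, 1, 2], "part-1": [3, 4, 5], "part-2": [6, 7], "part-3": [8, 9]}
print(simulate_auto_shard(files, 2, 0, "FILE"))  # [0, 1, 2, 6, 7]
print(simulate_auto_shard(files, 2, 0, "DATA"))  # [0, 2, 4, 6, 8]
```

FILE avoids redundant reads but needs at least as many files as workers to keep every worker busy, and uneven file sizes translate into uneven work; DATA balances records evenly at the cost of every worker reading the full dataset.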

Learn more about tf.data here.

TF Distribute

With tf.distribute, making your single-node development code work with distributed training later is as simple as adding these few lines up front.

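import tensorflow as tf

# we can also add more conditions here, e.g. check the number of workers
if tf.config.list_physical_devices('GPU'):
    my_strategy = tf.distribute.MirroredStrategy()
else:  # use the default strategy
    my_strategy = tf.distribute.get_strategy()

with my_strategy.scope():
    # create and compile the model inside this scope so its variables are distributed
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")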

This is just the tip of the tf.distribute iceberg. For a full walkthrough, check this guide out.


TF Transform: tf.transform is a great TFx component that can be used instead of Spark for feature engineering, in case you want to run the workload with TFx on Kubeflow or a similar platform, or simply stay in Tensorflow land without crossing over to Spark.

Tensorflow Model Analysis (TFMA): Another interesting TFx library, very useful for model validation both offline and for monitoring issues and drift in an online production model. Offline, we can use this library’s metrics for model management and for automatically promoting a model to production. In the future, we are more interested in the online use case: monitoring model performance and drift once we move to online model serving for real-time use cases.

TF Serving: Serving is now part of TFx and is the core Tensorflow component responsible for deploying and serving an ML model online. It loads a SavedModel and serves it through a high-performance REST (or gRPC) endpoint. The whole thing can be dockerized and run as a scalable endpoint using container services, an autoscaler, and an API gateway.

Check out the tf serving docs here.

Tensorflow also provides a Docker container for TF Serving (here), which we used to prototype our online serving component. This Docker-based approach is perfect for us because it fits seamlessly into our existing AWS ECS-based architecture for serving offline recommendations.
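For reference, a client for the REST endpoint can be as small as the sketch below. It assumes the default REST port (8501) exposed by the tensorflow/serving Docker image and the row-oriented `instances` request format of the predict API; the host, the model name `recommender`, and the input shape are hypothetical.

```python
import json
import urllib.request
from typing import Any, List, Optional


def build_predict_request(host: str, model_name: str, instances: List[Any],
                          version: Optional[int] = None,
                          port: int = 8501) -> urllib.request.Request:
    """Build a POST request for TF Serving's REST predict API (row format)."""
    path = f"/v1/models/{model_name}"
    if version is not None:
        path += f"/versions/{version}"  # pin a specific model version
    url = f"http://{host}:{port}{path}:predict"
    body = json.dumps({"instances": instances}).encode("utf-8")
    return urllib.request.Request(url, data=body, method="POST",
                                  headers={"Content-Type": "application/json"})


def predict(request: urllib.request.Request, timeout: float = 5.0) -> List[Any]:
    """Send the request and return the "predictions" field of the JSON response."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read())["predictions"]


req = build_predict_request("localhost", "recommender", [[0.1, 0.2, 0.3]])
print(req.full_url)  # http://localhost:8501/v1/models/recommender:predict
```

Behind an API gateway the host and port would change, but the path and request body stay the same.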

Learn more about TFx here. As expected, you will find really good tutorials and guides in this section.

TF Recommenders

This library uses Google’s ScaNN approximate nearest neighbor engine for multiple purposes. ScaNN is provided as a layer in this library, which can be used to build approximate nearest neighbor lookup indexes and save them as a SavedModel. That SavedModel can then be loaded into an online service with TF Serving (described above) for real-time vector lookups. The library can also compute top-K accuracy metrics directly from the embeddings, by comparing the query embedding with all product embeddings.
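The top-K metric boils down to a simple computation. The NumPy sketch below is a brute-force stand-in, not the TF Recommenders API: it scores each query embedding against every candidate embedding with a dot product and checks whether the true candidate lands in the top K. An approximate index such as ScaNN exists to avoid exactly this exhaustive scoring when the catalog is large. The toy embeddings are made up.

```python
import numpy as np


def top_k_accuracy(query_emb: np.ndarray, candidate_emb: np.ndarray,
                   true_ids: np.ndarray, k: int) -> float:
    """Fraction of queries whose true candidate is among the k highest-scoring candidates."""
    scores = query_emb @ candidate_emb.T             # (num_queries, num_candidates) dot products
    top_k = np.argsort(-scores, axis=1)[:, :k]       # indices of the k best candidates per query
    hits = (top_k == true_ids[:, None]).any(axis=1)  # was the true item retrieved?
    return float(hits.mean())


queries = np.array([[1.0, 0.0], [0.0, 1.0]])
candidates = np.array([[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]])
true_ids = np.array([2, 1])  # index of the item each query actually interacted with
print(top_k_accuracy(queries, candidates, true_ids, k=1))  # 0.5
print(top_k_accuracy(queries, candidates, true_ids, k=2))  # 1.0
```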

This library is still under active development, so things are changing a lot. We have already run into some issues, but we really like how simple the ScaNN indexes and top-K metrics are to use.

Learn more about TF recommenders here.

TF Agents

Check out more on TF agents here.

TF Hub

import tensorflow_hub as hub
# bert_path is the model handle, which can be found on TF Hub
# sometimes we also need to load the matching preprocessing layer
bert_embeddings = hub.KerasLayer(bert_path, trainable=False)

Check out the entire model hub here.

This concludes our section on Tensorflow. You can see the whole plethora of Tensorflow extension libraries here. The page keeps growing every time I look at it; the Tensorflow community is really active and amazing!


Bridging the gap between worlds

Read and write to TFRecords from Spark

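import tensorflow as tf

# Write tfRecords as Example from a Spark DataFrame (output_path and filenames are placeholders)
df.write.format("tfrecords").option("recordType", "Example").save(output_path)

# You need to define a schema for Example; FixedLenFeature supports int64, float32 and string
def read_tfrecord_sample(example):
    feature_description = {
        'feature0': tf.io.FixedLenFeature([], tf.int64),
        'feature1': tf.io.FixedLenFeature([], tf.int64),
    }
    return tf.io.parse_single_example(example, feature_description)

# Then we can load data into a TF Dataset using the above
# This method takes in a list of filenames to read from
dataset = tf.data.TFRecordDataset(filenames).map(read_tfrecord_sample)

# You can also read the tfRecords back into a DataFrame
read_df = spark.read.format("tfrecords").option("recordType", "Example").load(output_path)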

Read more on this package here.

Spark Tensorflow Distributor

Usage is pretty straightforward: after installing the package on your cluster, define the Tensorflow data loading and model building code in a train method. This method can live outside the current file and be imported as needed, which lets us develop single-node code first and then migrate quickly to a distributed runner. We just need to pass the number of workers in the cluster and the train method to the distributed runner.

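from spark_tensorflow_distributor import MirroredStrategyRunner
num_workers = 10
# add local_mode=True for local testing, or use_gpu=False to train on CPU nodes
# you can also specify another strategy in the train method and turn on use_custom_strategy=True
MirroredStrategyRunner(num_slots=num_workers).run(train)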

What makes this library great is how much work it automates under the hood, including designating the chief and worker nodes and handling communication between them. It also supports custom strategies, so you can use other tf.distribute strategies (for example, if you have multiple GPUs on each node or want to use TPUs). To enable this, create the new strategy inside the train method and pass use_custom_strategy=True to the runner.

how Spark tensorflow distributor works under the hood

This library also lets us run distributed training on CPU nodes (pass use_gpu=False to the runner). Training will obviously be slower, but you can now use many cheaper nodes and spot instances. In our experiments, training was 3x faster on a base GPU instance (AWS g4dn.xlarge) than on a similarly priced CPU instance (c4.2xlarge). This is where we need to tune the type and number of nodes to optimize the cluster’s performance for our budget and other requirements.

Read more about Spark Tensorflow distributor here. This guide is a great resource to learn more about all distribution strategies supported by TensorFlow and how to use them as a custom strategy in the distributed runner (Spoiler Alert: It’s very simple to switch strategies).

tf.distribute works by setting a TF_CONFIG environment variable on each node (read more about it here). This config includes a node index, and node 0 is the chief node. Sometimes we need to know a node’s role, for example when callbacks and other operations should behave differently on the chief than on the workers: we don’t want all the workers (over)writing model checkpoints to the same path! The Tensorflow guide recommends having the chief write to a persistent store like S3 and all other workers write to a temporary path.

Since the Spark distributor sets up this variable, including the node index, we can figure out which node is the chief with these few lines of Python:

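import json
import os

is_chief = True  # no TF_CONFIG (e.g. local mode): this process is the only, and chief, node
if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(os.environ['TF_CONFIG'])
    node_index = tf_config['task']['index']
    is_chief = node_index == 0
    print(f"Node Index: {node_index}, Is Chief: {is_chief}")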
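The snippet above only looks at the task index. A slightly more defensive sketch follows, assuming the standard TF_CONFIG layout (a `cluster` dict plus a `task` with `type` and `index`): it treats an explicit `chief` task as the chief if one exists and otherwise falls back to worker 0, and it picks a checkpoint directory per node so that only the chief writes to persistent storage while the others write to a throwaway local path. `node_role`, `checkpoint_dir`, and the S3 path are hypothetical names.

```python
import json
import tempfile
from typing import Optional, Tuple


def node_role(tf_config_json: Optional[str]) -> Tuple[Optional[str], int, bool]:
    """Return (task_type, task_index, is_chief) parsed from a TF_CONFIG JSON string."""
    if not tf_config_json:
        return None, 0, True  # no TF_CONFIG: single-node run, this process is the chief
    config = json.loads(tf_config_json)
    task_type = config["task"]["type"]
    task_index = config["task"]["index"]
    cluster = config.get("cluster", {})
    # an explicit "chief" task wins; otherwise worker 0 plays the chief role
    is_chief = task_type == "chief" or (
        task_type == "worker" and task_index == 0 and "chief" not in cluster)
    return task_type, task_index, is_chief


def checkpoint_dir(persistent_dir: str, tf_config_json: Optional[str]) -> str:
    """Chief saves to persistent storage; every other node saves to a throwaway local dir."""
    task_type, task_index, is_chief = node_role(tf_config_json)
    if is_chief:
        return persistent_dir
    return tempfile.mkdtemp(prefix=f"{task_type}_{task_index}_")


# Usage on each node, e.g. inside the train function passed to the runner:
# ckpt_dir = checkpoint_dir("s3://my-bucket/checkpoints", os.environ.get("TF_CONFIG"))
example = json.dumps({
    "cluster": {"worker": ["host1:2222", "host2:2222"]},
    "task": {"type": "worker", "index": 1},
})
print(node_role(example))  # ('worker', 1, False)
```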

This setup is still not perfect, and I want to point out some issues we tackled.

Since training happens on the cluster and the code is completely wrapped inside the runner method, we don’t have access to the final model at the end of training. If you don't explicitly handle this, you might run training for days with no way to retrieve your model (yikes!). Tensorflow recommends using the ModelCheckpoint or BackupAndRestore callbacks for this. We used the ModelCheckpoint callback:

from tensorflow.keras.callbacks import ModelCheckpoint
best_model_save_callback = ModelCheckpoint(filepath=best_filepath)  # somewhere on S3 or other persistent storage

We also noticed that if we added the callback only on the chief node, distributed training crashed: the workers moved on from the callback stage and started the next epoch while the chief was still saving the model, so the workers assumed the chief was dead and the whole training job failed.

callbacks need to be similar for chief and workers

A workaround we found was to add a dummy callback on all workers that saves the model to some temporary path while the chief saves to a persistent store (like S3). This makes sure all nodes take a similar amount of time to complete an epoch.

best_model_save_callback_dummy = ModelCheckpoint(filepath="some_random_path")  # throwaway local path on workers

callbacks = [best_model_save_callback if is_chief
             else best_model_save_callback_dummy]

Training also crashes if a worker dies or the cluster loses a worker, because this distribution strategy has no built-in recovery mechanism. Tensorflow recommends saving checkpoints at every epoch, using the same approach as above, so that training can be restarted if this happens. Read more about it here.
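Restarting then means picking up from the newest checkpoint. A minimal sketch, assuming checkpoints are written to a local directory under a hypothetical `ckpt-{epoch:02d}` naming scheme: find the highest epoch number, reload that model, and pass the epoch count to `fit` as `initial_epoch` so training continues where it stopped. Listing an S3 prefix would need an S3 client instead of `os.listdir`; `latest_checkpoint` and `resume_point` are illustrative names.

```python
import os
import re
from typing import Iterable, Optional, Tuple

# Hypothetical naming, e.g. ModelCheckpoint(filepath=os.path.join(ckpt_dir, "ckpt-{epoch:02d}.h5"));
# Keras fills {epoch} with the 1-based number of the epoch that just finished.
CKPT_RE = re.compile(r"^ckpt-(\d+)(\.\w+)?$")


def latest_checkpoint(names: Iterable[str]) -> Tuple[Optional[str], int]:
    """Return (name of newest checkpoint, epochs it covers), or (None, 0) if none match."""
    best_name, best_epoch = None, 0
    for name in names:
        match = CKPT_RE.match(name)
        if match and int(match.group(1)) > best_epoch:
            best_name, best_epoch = name, int(match.group(1))
    return best_name, best_epoch


def resume_point(ckpt_dir: str) -> Tuple[Optional[str], int]:
    """Local-directory lookup; returns (checkpoint path or None, epochs already completed)."""
    if not os.path.isdir(ckpt_dir):
        return None, 0
    name, epochs_done = latest_checkpoint(os.listdir(ckpt_dir))
    return (os.path.join(ckpt_dir, name) if name else None), epochs_done


# On restart (sketch, needs Tensorflow):
# path, epochs_done = resume_point(ckpt_dir)
# if path:
#     model = tf.keras.models.load_model(path)
# model.fit(dataset, epochs=total_epochs, initial_epoch=epochs_done, callbacks=callbacks)

print(latest_checkpoint(["ckpt-01.h5", "ckpt-03.h5", "ckpt-02.h5", "notes.txt"]))  # ('ckpt-03.h5', 3)
```

Comparing the parsed integers rather than the file names keeps "ckpt-10" ahead of "ckpt-9" even without zero padding.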

Besides the Spark Tensorflow distributor (which works only with Tensorflow), we can also run distributed deep learning training on Spark with Horovod, which supports both Tensorflow and PyTorch. If you want to learn more, Horovod and Petastorm are good starting points.

MLFlow Callback

import mlflow
import tensorflow as tf

class MlflowLogging(tf.keras.callbacks.Callback):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def on_train_end(self, logs=None):
        super().on_train_end(logs)

    def on_epoch_end(self, epoch, logs=None):
        super().on_epoch_end(epoch, logs)
        # log every metric Keras reports for this epoch (loss, val_loss, ...) to MLflow
        for key, value in (logs or {}).items():
            mlflow.log_metric(key, float(value), step=epoch)

Completing the picture

updated stack now extended with Tensorflow

In our initial use case, we kept our existing Spark-based Data Engineering pipelines; the key difference is that the training and inference datasets are now written as TFRecord files (instead of Parquet), which we then load into the model with a TFRecordDataset.

To wrap up, here are some pros and cons of this approach.

Pros
  • Leverage existing read/write connectors in Spark (most data warehouses, MySQL, S3, Kinesis, etc.)
  • Parallelized batch inference in Spark with CPU clusters
  • Combine the best of both ecosystems with the flexibility to choose one over the other at any stage of the ML lifecycle
  • Run distributed deep learning workloads on Spark clusters, with checkpoint-based fault tolerance
  • Easy to move from single-node development to distributed training
  • Use of any extension libraries from the TF ecosystem
  • Central storage, versioning, and management of models using MLflow
  • Any downstream application can now fetch the latest model from MLflow
  • Great for smaller teams who want to focus on ML and less infrastructure

Cons
  • Distributed training with MirroredStrategy is not perfect, especially when it comes to recovery (in the case of a dead worker). The parameter server strategy could be a better alternative here
  • Cryptic and messy Tensorflow error messages
  • Tensorflow is still a work in progress and a lot of things mentioned in this article are still experimental (especially with distribution strategies)
  • APIs between Keras and Tensorflow are in a state of flux right now

Conclusion and Next Steps

  • Personalization using Deep learning based Recommenders
  • Better text representation with BERT (and other transformers)
  • Product Tagging and Catalog Management with NER and text
  • Image embeddings with Autoencoders and CNN based CV models
  • Bandits and Online recommendations with Reinforcement Learning

We are primarily a Tensorflow shop at the moment, and we like how tf.distribute handles distribution strategies and how well it (almost always) integrates with Spark.

We are currently exploring additional pieces of technology that we believe would be great additions to our stack, including online inference using ScaNN indexes deployed on ECS, and streaming pipelines that would unlock real-time recommendations.

I will follow up with another article giving a sneak peek at the sequential (LSTM based) model we recently built and deployed using this stack, and share some of our learnings from the experience. This model has already won the hearts of our members (>30% lift in click-through rate), showing that deep learning based solutions are indeed the way to go and should be factored into the next generation of our products.

The future is exciting, with endless possibilities!

We’re also hiring, come help build some of these projects with our team.

PS: This is my first ever Medium post, so feel free to reach out with any questions, comments, or suggestions on this article!

AI/ML at Rue Gilt Groupe