One of the most common operations in machine learning systems is sampling: taking random samples from a large set of data points. Examples include randomly selecting people from various states, counties, or districts, or deciding which set of your users should see the new button on the app or website. Sample size depends on the dataset you have: in some cases a million users is a decent sample, and in other cases you might only get 2,000 data points. It depends on the problem you are trying to solve, and we have to choose our modeling approach accordingly to remove bias.
We sample because we want to avoid overfitting the model: if you take the entire dataset and use it for modeling, your model will not know how to predict for new users. Choosing the right sample dataset is also important for building good ML models.
The most common way to do sampling is through the Scala or Python (PySpark) APIs of Apache Spark, which is maintained by the Apache Software Foundation. A common challenge is that when we sample, the entire dataset can end up stored on one machine, thereby leading to memory errors.
In this article, I will explain how random sampling can be achieved at scale using Scala Spark, and how the central limit theorem can be extended to solve this problem.
One of the most common ways users get a sample dataset is by calling df.sample on a DataFrame. A DataFrame is the object Spark uses to store the row information of your dataset. It is a distributed object that uses MapReduce-style operations to get work done (we will get into MapReduce in future articles). While the sample method works fine, it becomes difficult to handle as data scales.
Spark can end up pulling the data locally to a single machine in the cluster and then running the random selection there. This means that if the grouping logic generates a large result set, that single instance can hit OOM (out of memory) errors. The only ways to resolve this are to use bigger machines or to increase memory.
Increasing instance size means paying more to the cloud provider. Increasing memory means clusters can take longer, since some jobs don't need large memory, and Spark only takes memory as an input once for the entire MapReduce job.
Enter the central limit theorem. It states that regardless of the original distribution of a population, the distribution of sample means (or sums) tends toward a normal, bell-shaped curve as the sample size increases, provided the samples are independent, identically distributed, and of sufficient size. The practical consequence for us: if we pick a uniform random number between 0 and 1 for each row, the fraction of rows whose number falls below a threshold p converges to p as the number of rows grows.
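To make that concrete, here is a minimal, self-contained sketch in plain Python (no Spark needed; the function name and the 0.3 threshold are my own illustrative choices). It shows that the observed fraction of uniform draws below a threshold is close to the threshold itself once the number of draws is large:

```python
import random

def fraction_below(threshold: float, n: int, seed: int = 42) -> float:
    """Draw n uniform(0, 1) numbers and return the fraction below `threshold`."""
    rng = random.Random(seed)  # fixed seed so the result is reproducible
    hits = sum(1 for _ in range(n) if rng.random() < threshold)
    return hits / n

# With enough draws, the observed fraction is close to the threshold.
print(fraction_below(0.3, 1_000_000))  # roughly 0.3
```

The same convergence is what makes a per-row random filter behave like a fixed-ratio sample at scale.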
You might be wondering how this is useful for random sampling.
Well, if we iterate over the rows, draw a random number for each one, and then keep only the records whose number falls below a certain threshold, we get approximately that fraction of the data as our sample. This is another way to look at it.
Because of this, all we have to do is map over each row and use a UDF to pick a random number between 0 and 1. If your sample ratio is 30%, filter to the records whose sample value is less than 0.3. This scales, because every row is processed independently and it is a map-only operation, meaning low memory use and low cloud computing cost.
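As a sketch of the idea outside Spark (the function name, data, and seed are my own illustration; in Spark the same predicate would run inside a map over distributed partitions), the per-row filter looks like this in plain Python:

```python
import random

def random_sample(rows, ratio, seed=7):
    """Keep each row independently with probability `ratio`.

    Each row gets its own uniform(0, 1) draw, and rows whose draw falls
    below `ratio` are kept. There is no shuffle and no global state,
    which is why the same idea scales as a map-only operation in Spark.
    """
    rng = random.Random(seed)
    return [row for row in rows if rng.random() < ratio]

users = list(range(100_000))
sample = random_sample(users, 0.3)
print(len(sample))  # approximately 30,000
```

Note that the sample size is approximate, not exact: each row is an independent coin flip, so you get close to 30% rather than exactly 30%.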
This solution works if you have a decent number of samples, say 100k or above. As the number of points in the universe goes down, this sampling becomes less reliable and may not work as expected. At that point, you can just use the direct df.sample method itself, in my opinion.
To keep things consistent, use a fixed random seed for each sampling run so that all instances are in sync on the randomness. Mark your UDFs (user-defined functions) with asNondeterministic; otherwise, they will not work as expected due to the caching nature of Spark.
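One way to reason about seeding across distributed workers, sketched in plain Python (the partition layout and the seed-derivation scheme are assumptions for illustration, not Spark's internal behavior): derive each partition's seed from a base seed plus the partition id, so a rerun over the same partitioning reproduces exactly the same sample.

```python
import random

def sample_partition(partition_id, rows, ratio, base_seed=42):
    """Sample one partition with a seed derived from the base seed and
    the partition id, so every rerun selects the same rows."""
    rng = random.Random(base_seed + partition_id)
    return [row for row in rows if rng.random() < ratio]

# Two independent runs over the same partitions produce identical samples.
partitions = {0: list(range(0, 500)), 1: list(range(500, 1000))}
run1 = [sample_partition(pid, rows, 0.3) for pid, rows in partitions.items()]
run2 = [sample_partition(pid, rows, 0.3) for pid, rows in partitions.items()]
print(run1 == run2)  # True
```

Deriving per-partition seeds this way keeps the sample reproducible without requiring any coordination between workers.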
Overall, we went through how the CLT can help with random sampling at scale in big data systems. In future articles, I will go over common ML infrastructures used for maintaining ML models at scale.