
Why Machine Learning Sampling is Harder Than You Think (And How to Do it Right)

by Rajesh Vakkalagadda
Too Long; Didn't Read

In this article, I will explain how random sampling can be achieved at scale using Scala Spark, and how the central limit theorem can be extended to solve this problem.


One of the most common operations in machine learning systems is sampling: taking random samples from a large set of data points. Examples include randomly selecting people from various states, counties, or districts, or deciding which set of your users should see the new button on the app or website. The right sample size depends on the size of the dataset you have: in some cases a million users is a decent sample, and in other cases you might only get 2,000 data points. It depends on the problem you are trying to solve, and we have to choose our modeling approaches carefully to remove bias.


We sample because we want to avoid overfitting the model: if you take the entire dataset and use it for modeling, your model will not know how to predict for new users. Choosing the right sample dataset is just as important for building good ML models.


The most common languages used for sampling are Scala and PySpark, both built on Apache Spark, which is maintained by the Apache Software Foundation. A common challenge with both is that when we sample, the entire dataset can end up on one machine, thereby leading to memory errors.


In this article, I will explain how random sampling can be achieved at scale using Scala Spark and how the central limit theorem can be extended to solve this problem.

Challenges

One of the most common ways users get a sample dataset is by calling df.sample on a DataFrame. A DataFrame is the object Spark uses to store the rows of your dataset. It is a distributed object that uses MapReduce to execute operations (we will get into MapReduce in future articles). While the sample method works fine on small data, it becomes difficult to handle as data scales.
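
For reference, here is a minimal sketch of that baseline approach; the app name, dataset, and parameters are illustrative:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder.appName("BaselineSampling").getOrCreate()

// Illustrative dataset: 10 million ids.
val df = spark.range(0L, 10000000L).toDF("id")

// Built-in sampling: keep ~30% of rows, without replacement, with a fixed seed.
val naiveSample = df.sample(withReplacement = false, fraction = 0.3, seed = 42L)
```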


Spark internally pulls the data onto a single machine in the cluster and then runs the random method there. This means that if the grouping logic generates a large result set, that single instance can hit OOM (out of memory) errors. The only way to resolve the error is to use bigger machines or increase memory.


Increasing instance size means we pay the cloud provider more. Increasing memory means clusters can take longer, since some jobs don't need large memory. And Spark takes the memory setting as input only once, for the entire MapReduce job.

Solution

Enter the central limit theorem. What it says, roughly, is that regardless of the original distribution of a population, the distribution of the sample means (or sums) will tend to be approximately normal (i.e., a bell-shaped curve) as the sample size increases, provided the samples are independent and identically distributed and of sufficient size. For our purposes, this means that if we assign each row an independent random number between 0 and 1, the fraction of rows falling below any threshold concentrates tightly around that threshold.


You might be wondering: how is this useful for random sampling?

Well, if we iterate over the rows, draw a random number for each one, and then filter for the records whose number falls below a certain threshold, we keep approximately that fraction of the records. This is just another way to look at sampling.
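
To make that concrete with a quick back-of-the-envelope calculation (the 30% ratio here is illustrative): with N rows and a threshold p, the count of kept rows follows a Binomial(N, p) distribution, with mean Np and standard deviation sqrt(Np(1-p)), and by the CLT the realized fraction is approximately normal around p. For N = 1,000,000 and p = 0.3, that is a mean of 300,000 kept rows with a standard deviation of about 458, so the realized sample size is almost always within about 0.5% of the target.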


Because of this, all we have to do is loop through each row and use a UDF to pick a random number between 0 and 1. If your sample ratio is 30%, filter for the records whose random value is less than 0.3. This scales, because every row is processed independently and it is a map-only operation, meaning low memory usage and low cloud computing cost.
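
Here is a minimal sketch of that idea in Scala Spark, reusing the df from the earlier snippet; the column name "u" and the sample ratio are illustrative:

```scala
import org.apache.spark.sql.functions.{col, udf}
import scala.util.Random

// UDF that draws a uniform value in [0, 1) for each row. Marking it
// asNondeterministic tells the optimizer not to cache or collapse calls.
val uniform = udf(() => Random.nextDouble()).asNondeterministic()

val sampleRatio = 0.3

// Map-only pipeline: attach a random value to each row, keep ~30% of rows.
val sampled = df
  .withColumn("u", uniform())
  .filter(col("u") < sampleRatio)
  .drop("u")
```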


This solution works well when you have a decent number of samples, say 100k or above. As the number of points in the universe goes down, the realized sample fraction fluctuates more and the sampling does not behave as expected. At that scale, in my opinion, you might as well use the direct df.sample method.


To keep things consistent, use a random seed for each sampling run so that all instances are in sync on randomness. And mark your UDFs (user-defined functions) as asNondeterministic; otherwise, due to the caching nature of Spark, they will not work as expected.
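
If you prefer to avoid a custom UDF, Spark's built-in rand(seed) column expression achieves the same effect: it is already nondeterministic and distributed, and the seed keeps runs reproducible. A sketch, reusing df and sampleRatio from above:

```scala
import org.apache.spark.sql.functions.{col, rand}

// Built-in alternative: one seeded uniform draw per row, same filter as before.
val sampledSeeded = df
  .withColumn("u", rand(42L))
  .filter(col("u") < sampleRatio)
  .drop("u")
```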

Summary

Overall, we went through how the CLT can help with random sampling at scale in big data systems. In future articles, I will go over the common ML infrastructures used for maintaining ML models at scale.
