A machine learning model captures what a machine has learned from its training data in a form that can later be used to predict or understand similar, previously unseen data.
When building a machine learning model, there is a well-known trade-off between bias and variance.
But what exactly does it mean when we say a model has a high bias or high variance? Can we visualize what is happening?
I have explored these questions using Python and a few basic libraries: NumPy for data generation and calculations, Matplotlib and Seaborn for visualizing the generated and predicted data points, Scikit-learn for model building, and Pandas for handling the data.
# Importing Libraries
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set()
import pandas as pd
from sklearn.tree import DecisionTreeRegressor
from sklearn.linear_model import LinearRegression
In most real-world problems, we don't really know the true relation between the independent variables (features) and the dependent variable (target). In statistical learning, we try to estimate this true function using various models.
Since we cannot estimate the true function perfectly, our predictions will contain error. This error can be broken down into three components: bias, variance, and irreducible error.
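For a single point x, this decomposition of the expected squared prediction error can be written as follows (a standard textbook identity, stated here for reference rather than derived in this post):

$$\mathbb{E}\big[(y - \hat{f}(x))^2\big] = \underbrace{\big(\mathbb{E}[\hat{f}(x)] - f(x)\big)^2}_{\text{bias}^2} \;+\; \underbrace{\mathbb{E}\big[(\hat{f}(x) - \mathbb{E}[\hat{f}(x)])^2\big]}_{\text{variance}} \;+\; \underbrace{\sigma^2}_{\text{irreducible error}}$$

where f is the true function, f̂ is the model fitted on a random training sample, and σ² is the variance of the noise.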
Bias is the error introduced when we approximate a complex real-world problem with a model that is too simple.
Linear regression often has high bias, since it assumes a linear relationship, which is a strong simplification. For this reason, we will use linear regression as one of the models we visualize.
Variance is the amount by which the model would change if we trained it on a different sample of data.
A decision tree is a model with high variance but low bias: it makes almost no assumptions about the true function, but if we change the training data, the fitted model changes drastically. The decision tree is the second model we will use for our predictions.
Irreducible error is the part of the error that no model can eliminate.
In terms of features, there can be other factors that affect the target variable but are not part of our available features.
We will use NumPy to create our feature X, which lies between -4 and 4, and generate our target variable Y as a sine function of X. We also know there will be some irreducible error that cannot be learned from the available features; to simulate it, we add small random noise values to Y.
n = 2000
np.random.seed(63)
X = np.linspace(-4,4,n) # generating data with noise
Y = np.sin(X) + 0.4*np.random.randn(n) # for sine function
We can visualize the generated data using a scatter plot.
plt.figure(figsize=(12,6))
plt.scatter(X,Y)
plt.xlabel('X')
plt.ylabel('Y')
plt.show()
Note that we are using only one feature so that the results are easy to visualize. In real-world problems we will have many features, which cannot be visualized this way. Our objective here is simply to see what happens in terms of bias and variance.
Now we will store this data as a Pandas data frame and create 5 subsets of it by randomly sampling 100 indices without replacement. Sampling without replacement ensures that no point appears twice within a subset, so each subset is a distinct random sample of the data.
One important point: out of the 2000 data points we generated, each subset uses only 100. This way, even though the data in each subset is different, each subset is still a reasonable representation of our population of 2000 points.
df = pd.DataFrame({'X':X,'Y':Y}) # Dataframe that contains our population
# creating subsets
np.random.seed(63)
X1 = df.iloc[np.random.choice(range(2000),100,replace=False),:]
X2 = df.iloc[np.random.choice(range(2000),100,replace=False),:]
X3 = df.iloc[np.random.choice(range(2000),100,replace=False),:]
X4 = df.iloc[np.random.choice(range(2000),100,replace=False),:]
X5 = df.iloc[np.random.choice(range(2000),100,replace=False),:]
We can build 5 different linear regression models using the 5 subsets we have created, and similarly build 5 different decision tree models.
Our objective is to build five linear regression models and visualize their predictions over the interval -4 to 4 (the interval of our population). To make this easier, we create a Python function that takes the training data as input and returns predictions for 2000 evenly spaced numbers between -4 and 4.
# For Linear regression
def LR_gen(data):
    lr1 = LinearRegression()
    lr1.fit(data['X'].values.reshape(-1,1), data['Y'])
    prediction = lr1.predict(np.linspace(-4,4,2000).reshape(-1,1))
    return prediction
# For Decision tree regressor
def DT_gen(data):
    dt1 = DecisionTreeRegressor()
    dt1.fit(data['X'].values.reshape(-1,1), data['Y'])
    prediction = dt1.predict(np.linspace(-4,4,2000).reshape(-1,1))
    return prediction
Now we can simply train a model using a subset of data and get the predictions using the function that we defined. We can visualize these 5 predictions along with the ground truth in one plot.
plt.figure(figsize=(12,6))
for i,x in enumerate([X1,X2,X3,X4,X5]):
    plt.plot(np.linspace(-4,4,2000),LR_gen(x),label=f'Prediction_{i+1}')
plt.scatter(df['X'],df['Y'],label='Ground Truth',color='b',alpha=0.1)
plt.legend()
plt.show()
What does the plot say?
From the plot, we can make two observations:
The first observation concerns the high bias of linear regression. In reality, our true function is non-linear; since we cannot know this, we assumed it was linear and built linear regression models. The result is models that predict poorly: if we take the average of the 5 predictions and compare it with the true function, the difference is large.
The second observation concerns the low variance of linear regression. All 5 models are close to each other, which means that changing the training data affects the fitted model very little.
mean_predictions = [] # To store predictions from 5 models
for i in [X1,X2,X3,X4,X5]:
    lr1 = LinearRegression()
    lr1.fit(i['X'].values.reshape(-1,1), i['Y'])
    prediction = lr1.predict(df['X'].values.reshape(-1,1))
    mean_predictions.append(prediction)
mean_prediction = np.mean(mean_predictions,axis=0) # average of all 5 models
# Plotting
fig, (ax1,ax2) = plt.subplots(2,1,figsize=(12,8),sharey=True)
# Plotting the average prediction of all the model and the true function
ax1.scatter(df['X'],mean_prediction,label='Mean prediction',color='green')
ax1.scatter(df['X'],df['Y'],label='Ground Truth',color='b',alpha=0.1)
ax1.title.set_text('Average prediction v/s Ground truth')
ax1.legend()
for i,x in enumerate([X1,X2,X3,X4,X5]):
    ax2.plot(np.linspace(-4,4,2000),LR_gen(x),label=f'Prediction_{i+1}')
ax2.scatter(df['X'],mean_prediction,label='Mean prediction',color='green')
ax2.title.set_text('Average prediction v/s models with different data')
ax2.legend()
plt.show()
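Since this is a simulated example and we know the true function is sin(x), we can also attach rough numbers to these two observations. The short sketch below is my own addition (the names true_f, squared_bias, and model_variance are not part of the original walkthrough): it approximates the squared bias as the mean squared gap between the average prediction and sin(x), and the variance as the average spread of the 5 models around their mean.
# Rough numerical check for the linear regression models (a sketch, not exact expectations)
true_f = np.sin(df['X'].values)                              # true function used to generate the data
squared_bias = np.mean((mean_prediction - true_f)**2)        # gap between average prediction and truth
model_variance = np.mean(np.var(mean_predictions, axis=0))   # spread of the 5 models around their mean
print(f'Approx. squared bias: {squared_bias:.3f}, approx. variance: {model_variance:.3f}')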
plt.figure(figsize=(12,6))
for i,x in enumerate([X1,X2,X3,X4,X5]):
    plt.plot(np.linspace(-4,4,2000),DT_gen(x),label=f'Prediction_{i+1}')
plt.scatter(df['X'],df['Y'],label='Ground Truth',color='b',alpha=0.1)
plt.legend()
plt.show()
What does the plot say?
From the plot, we can make two observations:
The first observation concerns the low bias of the decision tree. Our true function is non-linear, and since the decision tree makes no assumption about the form of the function, it is able to learn the underlying shape reasonably well. The result is models that predict well: if we take the average of the 5 predictions and compare it with the true function, the two are very close.
The second observation concerns the high variance of the decision tree. The 5 models differ considerably from each other, which means that changing the training data affects the fitted model a great deal.
mean_predictions = []
for i in [X1,X2,X3,X4,X5]:
    dt1 = DecisionTreeRegressor()
    dt1.fit(i['X'].values.reshape(-1,1), i['Y'])
    prediction = dt1.predict(df['X'].values.reshape(-1,1))
    mean_predictions.append(prediction)
mean_prediction = np.mean(mean_predictions,axis=0)
# Plotting
fig, (ax1,ax2) = plt.subplots(2,1,figsize=(12,8))
ax1.scatter(df['X'],mean_prediction,label='Mean prediction',color='green')
ax1.scatter(df['X'],df['Y'],label='Ground Truth',color='b',alpha=0.1)
ax1.title.set_text('Average prediction v/s Ground truth')
ax1.legend()
for i,x in enumerate([X1,X2,X3,X4,X5]):
    ax2.plot(np.linspace(-4,4,2000),DT_gen(x),label=f'Prediction_{i+1}')
ax2.scatter(df['X'],mean_prediction,label='Mean prediction',color='green')
ax2.title.set_text('Average prediction v/s models with different data')
ax2.legend()
plt.show()
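The same rough check can be repeated here, since mean_prediction and mean_predictions now hold the decision tree results (again, this is a sketch I am adding for illustration rather than part of the original walkthrough). We would expect a noticeably smaller squared bias and a larger variance than for linear regression.
# Rough numerical check for the decision tree models (same sketch as above)
true_f = np.sin(df['X'].values)
squared_bias = np.mean((mean_prediction - true_f)**2)
model_variance = np.mean(np.var(mean_predictions, axis=0))
print(f'Approx. squared bias: {squared_bias:.3f}, approx. variance: {model_variance:.3f}')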
We are now able to understand bias and variance visually. In practice, however, we will have many features rather than just one, which makes it impossible to visualize the model this way and tune it by eye. Instead, we rely on statistical methods during model training to keep both bias and variance as low as possible.
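As one concrete example of such a method (not part of the visual walkthrough above, just a hedged sketch using Scikit-learn's standard cross-validation tools), we could tune the max_depth of the decision tree with GridSearchCV, trading a little bias for a large reduction in variance; the depth grid below is an arbitrary choice of mine.
from sklearn.model_selection import GridSearchCV

# Tune tree depth by 5-fold cross-validation (a sketch; the max_depth grid is chosen arbitrarily)
param_grid = {'max_depth': [1, 2, 3, 5, 10, None]}
search = GridSearchCV(DecisionTreeRegressor(), param_grid, cv=5)
search.fit(df['X'].values.reshape(-1,1), df['Y'])
print('Best max_depth:', search.best_params_['max_depth'])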