Understanding Cross Validation for Beginners

ichen · Published in Analytics Vidhya · Jun 19, 2021

Whether you’re building a regression model or a classifier, it is difficult to know how well your model performs if you don’t test it. We want a model that captures a pattern as close as possible to the true relationship in the data without being influenced by too much noise. Hence, testing our model is important, as it helps reveal whether the model is appropriate. Not to mention, skipping the test makes it easy to overfit without noticing. Overfitting occurs when the model fits the training data, noise included, so closely that it predicts the data it was trained on very well but performs terribly on new data it has not seen.

So how do we help combat overfitting? Through cross validation!

What is Cross Validation?

Cross validation (CV) is a technique for estimating how well our model generalizes to data it has not seen. In its simplest form, we split our dataset into two parts: a training set and a testing set. The idea is that we train our model on the training set and then test it on the testing set to see how well it does. “Seeing how well it does” can be measured in numerous ways, but one method I used when starting out was comparing mean squared error (MSE) values. Hence, we will have a training set MSE and a testing set MSE. Generally, we care “more” about the testing set MSE, since we expect the training set MSE to be fairly low anyway: the model was fit on that particular set. Note that a testing set MSE that is much higher than the training set MSE may indicate overfitting.
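
As a minimal sketch of that comparison, assuming a made-up toy dataset and a plain linear regression as a placeholder model (neither is part of the original example), the training and testing MSEs could be computed like this:

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

# Toy data: y is a noisy linear function of X (purely illustrative).
rng = np.random.default_rng(0)
X = rng.uniform(0, 10, size=(100, 1))
y = 3 * X.ravel() + rng.normal(scale=2.0, size=100)

# Split into a training set and a testing set, then fit on the training set only.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
model = LinearRegression().fit(X_train, y_train)

# A testing MSE far above the training MSE would hint at overfitting.
train_mse = mean_squared_error(y_train, model.predict(X_train))
test_mse = mean_squared_error(y_test, model.predict(X_test))
print(f"training MSE: {train_mse:.2f}, testing MSE: {test_mse:.2f}")
```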

Now that we have answered why we should split our dataset, the question becomes how to split it. Ideally, we want a split that leaves enough data to train the model as well as possible and enough data to properly test it.

The 50–50 Split

Let us consider naively splitting the data into a 50–50 split — that is, 50% of the data is used for training the model (the training dataset) and the other 50% is withheld for testing the model after training (the testing dataset).

While we do have 50% of the data to test our model on, that also means 50% is withheld from training. The withheld half may contain valuable information which our model will miss out on if we use this split. This may lead to higher bias, where our predictions/estimates land far from the actual values.

Useful documentation: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html
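
With scikit-learn’s train_test_split (documented above), the 50–50 split is just a matter of setting test_size=0.5; the tiny dataset below is made up purely to show the resulting sizes:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Ten toy points, just to show what a 50-50 split withholds.
X = np.arange(10).reshape(-1, 1)
y = np.arange(10)

# Half the data goes to training, half is withheld for testing.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)
print(len(X_train), len(X_test))  # 5 5
```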

Leave-One-Out Cross Validation (LOOCV)

Imagine a situation where we have very few data points, let’s say 10. Then the 50–50 split is even more unfavorable, since withholding 50% means setting aside 5 very influential data points. Consider LOOCV as an alternative.

In LOOCV, we train the model on the entire dataset, excluding only one data point. We iterate through the entire dataset so that each iteration excludes a different data point from the training set. As such, the testing set contains only that sole point, which changes in each iteration. We then fit the model on the training set and repeat. Essentially, if we have a total of n data points, the model is fitted on n-1 data points in each iteration, and each point serves as the test set exactly once.

Note that if we have a very large dataset, LOOCV is very time consuming as we are going through n iterations.
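
scikit-learn also provides a LeaveOneOut splitter that handles the iteration for you; here is a rough sketch on a made-up 10-point dataset, with a linear regression standing in for whatever model you are evaluating:

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import LeaveOneOut

# A made-up dataset of n = 10 points, as in the small-data example above.
rng = np.random.default_rng(1)
X = rng.uniform(0, 10, size=(10, 1))
y = 2 * X.ravel() + rng.normal(scale=1.0, size=10)

errors = []
for train_idx, test_idx in LeaveOneOut().split(X):
    # Fit on n-1 points, then test on the single held-out point.
    model = LinearRegression().fit(X[train_idx], y[train_idx])
    errors.append(mean_squared_error(y[test_idx], model.predict(X[test_idx])))

# One squared error per data point; the LOOCV estimate is their average.
print(len(errors), np.mean(errors))
```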

K-Fold Cross Validation

LOOCV is a special case (k = n, the size of the dataset) of K-Fold Cross Validation. In K-Fold Cross Validation, we split the dataset into k folds (a fold may be thought of as a group/subset, and k is the number of groups). One out of the k folds is withheld as the testing set, while the remaining k-1 folds are used as the training set. For example, a dataset with 6 data points split with k=2, k=3 or k=6 gives folds of 3, 2 or 1 points respectively, with k=6 being exactly LOOCV.

We can break down the steps of K-Fold Cross Validation into four steps:

  1. First, we will split the data into k-folds.
  2. Then, we will withhold one fold as the testing set.
  3. We will then train the model on the other k -1 folds.
  4. Repeat, iterating through the folds until every fold has been used as the testing set (k times in total). With the same dataset of 6 points and k=3, that means three iterations, each training on 4 points and testing on the remaining 2.

Before performing K-Fold Cross Validation, you should shuffle the data so that each fold is representative of the dataset. Note that the k testing set MSEs are then averaged to give a single cross validation score, as in the sketch below.
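
Here is a hedged sketch of those four steps using scikit-learn’s KFold with shuffling, again on invented data with a placeholder linear model:

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold

# Toy regression data (purely illustrative).
rng = np.random.default_rng(2)
X = rng.uniform(0, 10, size=(60, 1))
y = 1.5 * X.ravel() + rng.normal(scale=2.0, size=60)

# Step 1: split the data into k folds, shuffling so each fold is representative.
kf = KFold(n_splits=3, shuffle=True, random_state=0)

fold_mses = []
for train_idx, test_idx in kf.split(X):
    # Steps 2-3: withhold one fold as the testing set, train on the other k-1 folds.
    model = LinearRegression().fit(X[train_idx], y[train_idx])
    fold_mses.append(mean_squared_error(y[test_idx], model.predict(X[test_idx])))

# Step 4: every fold has been tested on once; average the k MSEs for the final score.
print(np.mean(fold_mses))
```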

In practice, we commonly use k=5 or k=10, but you can pick any suitable k value.
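
If you would rather not write the loop yourself, scikit-learn’s cross_val_score runs the same procedure in one call; a sketch with k=5 follows (the data is again invented, and note that scikit-learn reports MSE as a negative score by convention):

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_score

# Invented data, purely for illustration.
rng = np.random.default_rng(3)
X = rng.uniform(0, 10, size=(100, 1))
y = 0.5 * X.ravel() + rng.normal(scale=1.0, size=100)

# k=5 shuffled folds; "neg_mean_squared_error" is scikit-learn's MSE scoring name.
cv = KFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(LinearRegression(), X, y, cv=cv, scoring="neg_mean_squared_error")
print(-scores.mean())  # average testing MSE over the 5 folds
```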

Additionally, because we split into k folds with k smaller than n rather than performing LOOCV, K-Fold CV requires only k model fits instead of n, so it is much less costly in terms of time.

Useful documentation: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html

By performing cross validation, we get a more honest picture of our model’s performance and can use it to improve the model. The methods above can be easily implemented by writing your own functions for practice or through Scikit-Learn (Python). In choosing the right model for a given dataset, we will most likely have to make tradeoffs to arrive at the most appropriate one. Ultimately, cross validation is an excellent tool for achieving our goal of selecting the most appropriate model.
