A simple guide to gradient descent in machine learning

3D gradient descent
Image credit: 123RF

This article is part of Demystifying AI, a series of posts that (try to) disambiguate the jargon and myths surrounding AI.

Deep learning models come in various forms, each tailored to specific tasks. From dense layers to convolutional neural networks (CNNs), recurrent neural networks (RNNs), and the more recent transformers, the landscape of deep learning is diverse and fascinating. However, beneath this diversity lies a common thread that binds these models together—the principle of gradient descent, the technique used for training machine learning models.

The mathematics underpinning deep learning can often seem intimidating. Recently, I had the pleasure of reading Math for Deep Learning by Ronald T. Kneusel, which delves into the mathematical intricacies of deep learning and makes them understandable through examples, Python code, and visuals.

In this article, I will try to demystify gradient descent without delving too deeply into the mathematical details. This article draws inspiration from Chapter 11 of Math for Deep Learning, which provides a comprehensive explanation of gradient descent. For a more in-depth understanding, I highly recommend reading the entire book. It is an invaluable resource for anyone interested in the mathematics of deep learning.

What is gradient descent?

The general idea behind gradient descent is to iteratively adjust a machine learning model’s parameters to minimize its errors. Errors are measured by a loss function, which quantifies the difference between the model’s predictions and the actual values, also referred to as the ground truth. In other words, the goal of gradient descent is to find the parameter values at which the loss function reaches its minimum.
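As a minimal sketch, the snippet below computes mean squared error, one common loss function (the article mentions it again later), for a handful of made-up predictions and ground-truth values. The function name `mean_squared_error` and the numbers are illustrative assumptions.

```python
def mean_squared_error(predictions, targets):
    """Average squared difference between predictions and ground-truth values."""
    if len(predictions) != len(targets) or not predictions:
        raise ValueError("need two non-empty sequences of equal length")
    squared_errors = [(p - t) ** 2 for p, t in zip(predictions, targets)]
    return sum(squared_errors) / len(squared_errors)


# Made-up model outputs and ground truth, for illustration only.
predictions = [2.5, 0.0, 2.0]
targets = [3.0, -0.5, 2.0]
print(mean_squared_error(predictions, targets))  # (0.25 + 0.25 + 0) / 3
```

A lower value means the predictions sit closer to the ground truth. Gradient descent tries to push this number down by changing the parameters that produced the predictions.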

To understand how gradient descent works, imagine you’re trying to find the lowest point in a valley but you’re blindfolded. You can feel the slope under your feet and decide which way to go. You take a step, feel the slope again, and adjust your direction.

This is essentially what gradient descent does. It starts with a prediction, measures the error, and then adjusts the model’s parameters based on the slope of the loss function. By repeating this process, the model learns to approximate the distribution of the underlying data.

The term “gradient” in gradient descent refers to the slope of the loss function with respect to the model’s parameters. Since we want to minimize the error, we go in the opposite direction of the slope—hence the term “descent.”

Gradient descent in one dimension

To understand how gradient descent works, let’s consider a simple one-dimensional problem. Imagine we have a machine learning model with a single parameter, x. Our goal is to find the value for x that minimizes the loss function, which calculates the error between the model’s predictions and the actual values.

Suppose the loss, expressed as a function of our single parameter x, is f(x) = 6x^2 – 12x + 3. Analytically, we can find the function’s minimum by setting the first derivative to zero and solving for x. However, in real-world applications, problems are often so complex that we don’t know the form of the function in advance and can’t solve it analytically. This is where gradient descent comes into play and helps us find the minimum value step by step.
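For this particular function, the analytical route takes only a couple of lines. It also gives us a known answer to check gradient descent against:

```latex
f'(x) = 12x - 12 = 0 \;\Rightarrow\; x^{*} = 1, \qquad
f''(x) = 12 > 0 \ (\text{so } x^{*} \text{ is a minimum}), \qquad
f(x^{*}) = 6 - 12 + 3 = -3
```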

A simple one-dimensional function

In gradient descent, we calculate the gradient, which for a one-dimensional problem is simply the function’s first derivative. In our case, the derivative is df/dx = 12x – 12. We then initialize x to a random value and evaluate the function. Calculating the gradient at x tells us which way the slope points.

Next, we adjust x by a small amount in the opposite direction of the gradient. The size of this step is controlled by the learning rate, which is kept small to avoid overshooting and missing the minimum. In mathematical literature and programming libraries, you’ll usually see the Greek letter η (eta) used for the learning rate. The gradient descent update rule is x <- x – η * df/dx.

Suppose we initialize x at -0.9 and eta at 0.03. In the first step, the gradient will be 12 * (-0.9) – 12 = -22.8. The first update to x will be x <- -0.9 – 0.03 * (-22.8), which gives us -0.216. We’re now closer to the minimum, which lies at x = 1.

Applying the same rule again, the next few updates give x ≈ 0.222, 0.502, 0.681, and 0.796.

Gradient descent on a simple one-dimensional function

As we continue, x keeps getting closer to 1. As you can see, the changes become smaller as we approach the minimum, because the slope flattens.
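The update rule can be sketched in a few lines of Python. This is a minimal illustration, not the book’s code. The names `df` and `gradient_descent_1d` are chosen here. The run uses the starting point x = -0.9 and learning rate η = 0.03 from the example above, and it stops after a fixed number of steps rather than testing for convergence.

```python
def df(x):
    """Derivative of f(x) = 6x^2 - 12x + 3."""
    return 12 * x - 12


def gradient_descent_1d(x0, eta, steps):
    """Apply x <- x - eta * df(x) a fixed number of times and return every iterate."""
    xs = [x0]
    for _ in range(steps):
        xs.append(xs[-1] - eta * df(xs[-1]))
    return xs


# Same starting point and learning rate as the worked example in the text.
trajectory = gradient_descent_1d(x0=-0.9, eta=0.03, steps=30)
for step, x in enumerate(trajectory[:6]):
    print(f"step {step}: x = {x:.4f}")
print(f"step 30: x = {trajectory[-1]:.6f}")
```

For this quadratic, each update multiplies the distance between x and 1 by the same factor, 1 – 12η = 0.64. That is why the steps shrink as the slope flattens: x approaches 1 ever more closely without jumping past it.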

In this simple one-dimensional example, the process seems straightforward. However, in real-world machine learning problems, we often deal with high-dimensional data and complex models with millions of parameters. Despite this complexity, the principle remains the same: iteratively adjust the parameters in the direction of descent until we reach a minimum.

This is the power of gradient descent. It provides a practical, scalable method for training complex machine learning models, even when the underlying mathematical functions are too complex to solve analytically. 

If the learning rate is too high, gradient descent might overshoot and miss the optimal point of the function
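The effect shown in the figure can be reproduced numerically. In this sketch the helper names are illustrative. Each update multiplies the distance between x and the minimum by (1 – 12η), so for this function the steps begin to overshoot past x = 1 once η > 1/12 and grow without bound once η > 1/6. The three learning rates below are arbitrary picks on either side of those thresholds.

```python
def df(x):
    """Derivative of f(x) = 6x^2 - 12x + 3."""
    return 12 * x - 12


def descend(x0, eta, steps):
    """Return the iterates of x <- x - eta * df(x)."""
    xs = [x0]
    for _ in range(steps):
        xs.append(xs[-1] - eta * df(xs[-1]))
    return xs


# Illustrative learning rates: small, too large (overshoots), far too large (diverges).
for eta in (0.03, 0.15, 0.2):
    xs = descend(x0=-0.9, eta=eta, steps=10)
    first = ", ".join(f"{x:.2f}" for x in xs[:5])
    print(f"eta={eta}: {first}, ... step 10: {xs[-1]:.2f}")
```

With η = 0.15 the iterates bounce back and forth across x = 1 but still settle down. With η = 0.2 every step lands farther from the minimum than the last.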

Gradient descent in multiple dimensions

In real-world applications, machine learning models, and deep learning models in particular, often have many dimensions. Each dimension corresponds to a different parameter that the model adjusts to make better predictions. Despite the added complexity, the principle of gradient descent remains the same as explained above. The main difference is that we now have to compute the partial derivative of the function with respect to each parameter separately.

Consider a two-dimensional function f(x,y) = 6x^2 + 9y^2 – 12x – 14y + 3. This function represents a model with two parameters, x and y.

A simple two-dimensional function with a single minimum

The gradient for this function is a vector of the following partial derivatives:

∂f/∂x = 12x – 12

∂f/∂y = 18y – 14

These partial derivatives tell us the slope of the function with respect to each parameter. At each step of gradient descent, we need to update x and y as follows:

x <- x – η * ∂f/∂x

y <- y – η * ∂f/∂y

In this way, we adjust each parameter in the direction that reduces the function’s value the most, guided by the corresponding partial derivative.
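The two update rules can be sketched the same way as the one-dimensional case. This is an illustration rather than the code behind the figure. The names `grad` and `gradient_descent_2d`, the learning rate η = 0.02, and the starting points are all assumptions. Setting both partial derivatives to zero shows that the minimum is at x = 1, y = 7/9 ≈ 0.778, which is where every run should end up.

```python
def grad(x, y):
    """Partial derivatives of f(x, y) = 6x^2 + 9y^2 - 12x - 14y + 3."""
    return 12 * x - 12, 18 * y - 14


def gradient_descent_2d(x, y, eta, steps):
    """Update x and y together, each against its own partial derivative."""
    for _ in range(steps):
        dfdx, dfdy = grad(x, y)  # both evaluated at the current point
        x, y = x - eta * dfdx, y - eta * dfdy
    return x, y


# Illustrative starting points (not necessarily those used in the figure).
starts = [(-0.5, 1.5), (1.5, -0.5), (0.0, 0.0)]
results = [gradient_descent_2d(x0, y0, eta=0.02, steps=200) for x0, y0 in starts]
for (x0, y0), (x, y) in zip(starts, results):
    print(f"start ({x0}, {y0}) -> x = {x:.4f}, y = {y:.4f}")
```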

In the following graph, the function landscape is represented with contours and shaded areas, with the lighter area representing the minimum. We have illustrated the gradient descent trajectories of three different initialization points. As you can see, regardless of the initial values of x and y, they end up reaching the minimum after several iterations. This illustrates the power of gradient descent: even in high-dimensional spaces, it can guide us toward the minimum of a complex function.

Gradient descent on a two-dimensional function with a single minimum.

In deep learning, obtaining the gradient is much more complicated due to the layered and interconnected nature of the parameters and the non-linearity introduced across the network. This is done through a process called backpropagation, which uses the chain rule to calculate partial derivatives across multiple layers of artificial neurons. Backpropagation is a crucial part of training deep learning models, but it’s a complex topic that deserves its own article. For a comprehensive explanation of backpropagation, I recommend Chapter 10 of Math for Deep Learning.
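Backpropagation itself is beyond the scope of this article. The chain rule at its core, however, can be shown on a hypothetical single neuron with one weight w, a sigmoid activation, and a squared-error loss. The gradient is the product of the local derivatives along the path from the loss back to w. The finite-difference estimate is included only as a sanity check. This is a toy illustration, not how deep learning frameworks implement backpropagation.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def loss(w, x, target):
    """Squared error of a single sigmoid neuron with weight w."""
    return (sigmoid(w * x) - target) ** 2


def grad_chain_rule(w, x, target):
    """dL/dw via the chain rule: dL/ds * ds/dz * dz/dw."""
    z = w * x
    s = sigmoid(z)
    dL_ds = 2 * (s - target)  # derivative of (s - target)^2 with respect to s
    ds_dz = s * (1 - s)       # derivative of the sigmoid with respect to z
    dz_dw = x                 # derivative of w * x with respect to w
    return dL_ds * ds_dz * dz_dw


def grad_numeric(w, x, target, h=1e-6):
    """Central finite-difference estimate of dL/dw, used only for checking."""
    return (loss(w + h, x, target) - loss(w - h, x, target)) / (2 * h)


w, x, target = 0.5, 2.0, 1.0  # hypothetical values
print(grad_chain_rule(w, x, target), grad_numeric(w, x, target))
```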

Gradient descent with multiple minima

In most real-world problems, the loss landscape doesn’t have a single minimum. Instead, it has several minima, each representing a potential solution.

To illustrate this, consider a two-dimensional function that has two minima. One of them is lower than the other, meaning the function reaches a smaller value there. The lower one is the global minimum; the other is a local minimum.

A two-dimensional function with multiple minima.

However, the path to this optimal solution is not always straightforward. Depending on the initial values of the parameters, the model may converge to either of the two minima. This is because gradient descent always moves in the steepest downhill direction at its current point, which does not necessarily lead to the global minimum.
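The article’s two-minimum function isn’t given, so the sketch below uses a hypothetical stand-in, f(x, y) = x^4 – 3x^2 + x + y^2. This function has a deeper (global) minimum near x ≈ -1.30 and a shallower (local) one near x ≈ 1.13, both at y = 0, separated by a ridge near x ≈ 0.17. Running the same gradient descent loop from two different starting points shows each run settling into whichever basin it started in. The function, learning rate, and starting points are all illustrative assumptions.

```python
def f(x, y):
    """Hypothetical function with a deeper minimum near x = -1.30 and a shallower one near x = 1.13."""
    return x**4 - 3 * x**2 + x + y**2


def grad(x, y):
    """Partial derivatives of f."""
    return 4 * x**3 - 6 * x + 1, 2 * y


def gradient_descent_2d(x, y, eta=0.01, steps=2000):
    """Plain gradient descent with a fixed learning rate and step count."""
    for _ in range(steps):
        dfdx, dfdy = grad(x, y)
        x, y = x - eta * dfdx, y - eta * dfdy
    return x, y


# Two illustrative starting points on either side of the ridge near x = 0.17.
for start in [(-0.5, 1.0), (1.0, -1.0)]:
    x, y = gradient_descent_2d(*start)
    print(f"start {start} -> x = {x:.3f}, y = {y:.3f}, f = {f(x, y):.3f}")
```

Both runs stop at a point where the gradient is zero. Only the first one finds the lower value of f, and nothing in the update rule tells the second run that a better minimum exists elsewhere.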

In an ideal world, we would know the form of the loss function in advance, allowing us to choose a good initialization point that would lead us directly to the global minimum. However, in reality, this is rarely the case. The form of the loss function is typically unknown, making it impossible to predict the best starting point.

Gradient descent on a function with multiple minima.

In very high-dimensional problems with millions of parameters, there are often many local minima that are equally good. This is a consequence of the high-dimensional nature of the problem, which increases the likelihood of finding numerous solutions that are nearly identical in terms of their quality.

In such cases, gradient descent will eventually converge on one of the many equally good local minima.

Stochastic gradient descent

In real-world deep learning applications, we usually start by creating a dataset of annotated training examples, such as a collection of x-ray images with their corresponding labels (benign or malignant). 

The process of training a model involves running each example through the model, making a prediction, and computing the loss. After processing the entire training set, we calculate the loss over the whole set (for instance, using mean squared error) and then adjust the model’s parameters using gradient descent. When you use the entire dataset to calculate the gradient of the loss function, it is called batch gradient descent.
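Here is a minimal sketch of batch gradient descent. It assumes a hypothetical linear model y_hat = w * x + b trained with mean squared error on a small, noise-free dataset generated from y = 2x + 1. The helper names, learning rate, and epoch count are illustrative choices. Note that the parameters are updated only once per full pass over the data.

```python
def mse_gradients(w, b, xs, ys):
    """Gradients of the mean squared error of y_hat = w * x + b over the given examples."""
    n = len(xs)
    errors = [(w * x + b) - y for x, y in zip(xs, ys)]
    dw = 2 / n * sum(e * x for e, x in zip(errors, xs))
    db = 2 / n * sum(errors)
    return dw, db


def batch_gradient_descent(xs, ys, eta=0.1, epochs=1000):
    """One parameter update per pass over the entire dataset."""
    w, b = 0.0, 0.0
    for _ in range(epochs):
        dw, db = mse_gradients(w, b, xs, ys)
        w, b = w - eta * dw, b - eta * db
    return w, b


# Hypothetical noise-free data generated from y = 2x + 1.
xs = [i / 10 for i in range(20)]
ys = [2 * x + 1 for x in xs]
w, b = batch_gradient_descent(xs, ys)
print(f"w = {w:.3f}, b = {b:.3f}")
```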

Math for Deep Learning by Ronald T. Kneusel

As the model and dataset scale up, the memory and computational costs of performing batch gradient descent become prohibitive. This is why machine learning practitioners use minibatch gradient descent. Instead of using the entire dataset at each step of the training process, minibatch gradient descent uses a subset of the examples. 

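By choosing a random sample, we obtain an estimate of the true gradient. In deep learning practice, minibatch gradient descent is commonly referred to as stochastic gradient descent (SGD), although the term originally described updates computed from a single example. The term “stochastic” refers to the randomness introduced by using only part of the dataset at each step.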
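Switching to minibatches changes only the inner loop. This sketch reuses the same hypothetical linear model and noise-free data as a standalone example, and `minibatch_sgd` and its defaults are assumptions. The examples are shuffled at the start of every epoch with Python’s `random` module. The parameters are then updated once per minibatch, so each update uses a noisy estimate of the full gradient. Setting `batch_size` to `len(xs)` recovers batch gradient descent.

```python
import random


def mse_gradients(w, b, xs, ys):
    """Gradients of the mean squared error of y_hat = w * x + b over the given examples."""
    n = len(xs)
    errors = [(w * x + b) - y for x, y in zip(xs, ys)]
    dw = 2 / n * sum(e * x for e, x in zip(errors, xs))
    db = 2 / n * sum(errors)
    return dw, db


def minibatch_sgd(xs, ys, batch_size=4, eta=0.1, epochs=500, seed=0):
    """Shuffle every epoch, then take one gradient step per minibatch."""
    rng = random.Random(seed)  # seeded so runs are reproducible
    indices = list(range(len(xs)))
    w, b = 0.0, 0.0
    for _ in range(epochs):
        rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            dw, db = mse_gradients(w, b, [xs[i] for i in batch], [ys[i] for i in batch])
            w, b = w - eta * dw, b - eta * db
    return w, b


# Hypothetical noise-free data generated from y = 2x + 1.
xs = [i / 10 for i in range(20)]
ys = [2 * x + 1 for x in xs]
w, b = minibatch_sgd(xs, ys, batch_size=4)
print(f"w = {w:.3f}, b = {b:.3f}")
```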

In practice, SGD often outperforms batch gradient descent. A commonly cited reason is that the randomness introduced by the minibatches helps the model generalize to the underlying distribution instead of overfitting the training set. This stochasticity can also help the model escape shallow local minima and find a better overall solution.

There is no fixed rule for choosing the number of examples in a minibatch. The optimal minibatch size can depend on a variety of factors, including the nature of the problem, the size of the dataset, and the complexity of the model. It’s often a matter of trial and error, and it’s one of the hyperparameters that machine learning practitioners spend a lot of time tuning to get the best performance from their models.

Further reading

This article has provided a brief overview of gradient descent. It’s important to note that this is just the tip of the iceberg. There’s a wealth of knowledge to be discovered beyond what we’ve covered here.

One topic we omitted is momentum in gradient descent. This technique is inspired by physics and helps the algorithm navigate the parameter space more efficiently. Beyond basic gradient descent, there are several other variants of the algorithm worth exploring, including RMSprop, Adagrad, and Adam, each with its own strengths and applications.

To delve deeper into these topics and more, consider reading Math for Deep Learning by Ronald T. Kneusel. This comprehensive guide will provide you with a solid foundation in these concepts, equipping you with the knowledge to tackle more complex problems in the field of deep learning.
