An Intro to Gradient Descent for Kotlin Programmers

2019-09-01

Introduction

Gradient descent is an algorithm that’s used to solve supervised learning and deep learning problems. Here I’m going to try to give you an idea of why the algorithm works and how you’d implement it in Kotlin. I’ll also show the algorithm working with a simple Kaggle dataset involving video game sales and ratings.

Everything I cover here is covered in Andrew Ng’s excellent Coursera machine learning course with the exception of the Kotlin implementation of gradient descent. If you really want a clear and definitive introduction to gradient descent, I recommend that course over this article. Here I’m mostly interested in solidifying what I’ve learned by sharing it.

Our Toy Problem

Here’s a plot of a bunch of video games from 2010 to 2016. On the x-axis, we have the meta-critic score. On the y-axis, we have the number of sales in North America, in millions. Looking at the data, it looks like there is some sort of relationship between meta-critic score and game sales.1 Our toy problem is this: we want to write a program that can predict video game sales based on meta-critic score.

If we were looking at this in Excel or Google Sheets, we’d just tick the “draw trendline” checkbox in our chart options and be almost done with our machine learning problem. Excel is fine for some supervised learning tasks, but it won’t work for harder problems, which is where gradient descent becomes essential.2 Although gradient descent is probably overkill for this toy problem, we’re going to use it anyway, since presumably whatever boring learning algorithm has been sitting in Excel for decades can’t scale to more interesting contemporary ML problems.

Anyway, if we drew the line of best fit for the above data, it’d look like this:

If you remember your high-school algebra, equations for lines like this take the form:

y = m \cdot x + b

So, our problem here is really this: how do we figure out what m and b are, given the above video game data? Or, how would we implement the following Kotlin function:

fun findLearningParameters(videoGameData: Array<Pair<Double, Double>>): Pair<Double, Double> {
  //...
  return Pair(m, b)
}

This is really the crux of the problem, because once we have this function implemented, it’s trivial to write a program that can predict video game sales based on meta-critic score:

fun predictVideoGameSales(
  metaCriticScore: Int,
  learningParameters: Pair<Double, Double>
): Double {
  val (m, b) = learningParameters
  return metaCriticScore * m + b
}

What we want is for the learning parameters to be such that the difference between the predicted values (given by the line of best fit) and the actual values is minimized. This difference depends on the values we choose for m and b; it’s a function of those variables. This function is called the “cost function,” so another way of putting our problem is to say that we’re trying to find the values of m and b that minimize the cost function.

Solution

To start understanding how we’d solve this problem, we need to take a closer look at the cost function. What exactly is it? Let’s start with a thought experiment: suppose we choose 0 as the value for both m and b. What would the difference be between the actual values from our data set and the predicted values? The following table, with the first five data points, tries to capture this:

meta-critic score | actual sales (millions) | predicted sales | (predicted - actual)^2
61                | 15                      | 0               | 225
97                | 7.02                    | 0               | 49
97                | 9.66                    | 0               | 93.3
88                | 9.04                    | 0               | 81.7216
87                | 9.7                     | 0               | 94.09

One way to summarize this table is to take the average of the squared differences between the predicted and actual values and then divide by 2. That’s our cost function, which we’ll call J. Formally, that’s:

J(m, b) = \frac{1}{2n} \sum_{i=1}^{n} \left( (m \cdot x_i + b) - \text{actual}_i \right)^2

where n is the number of video games in our data set, and x_i and actual_i are the i-th game’s meta-critic score and actual sales. (If we were to use this equation on the above table, we’d get 54.31.) Let’s ignore b for a second and rewrite our equation:

J(m) = \frac{1}{2n} \sum_{i=1}^{n} \left( m \cdot x_i - \text{actual}_i \right)^2

Remember: x and actual aren’t variables here; we’ll be able to plug in values from our data. So we’re really just looking at a quadratic function of m, and you probably remember from high school that quadratic functions look like this:

The x-axis in this case is m, and the y-axis tells us how far off m is from where it needs to be. This graph tells us that 0 is our best choice for m, since choosing 0 will minimize the cost function. That isn’t true of our actual data set, but the basic idea of gradient descent applies regardless, and this graph is (probably) cleaner to look at.
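
If you’d rather see the algebra than trust a graph, here’s a quick sketch of why J(m) is a parabola in m: expanding the square, and writing avg(·) for the mean over our n games, gives

J(m) = \frac{1}{2n} \sum_{i=1}^{n} \left( m \cdot x_i - \text{actual}_i \right)^2 = \frac{\text{avg}(x^2)}{2} \, m^2 - \text{avg}(x \cdot \text{actual}) \, m + \frac{\text{avg}(\text{actual}^2)}{2}

which is a quadratic in m with a positive leading coefficient, i.e., a parabola that opens upward.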

At the highest level, the way gradient descent works is that we take a guess at the optimal value of m, then we repeatedly adjust that guess by a small step, scaled by a number called the “learning rate,” until we aren’t getting any better. This is how the algorithm got its name: we’re slowly descending down the gradient until we reach the minimum.

With that, we can start to fill in more of our findLearningParameters function:

fun findLearningParameters(videoGameData: Array<Pair<Double, Double>>): Pair<Double, Double> {
    val learningRate = .0003
    var guess = Pair(0.0, 0.0)
    for (i in 1..10000000) {
        guess = updateGuess(guess, learningRate, videoGameData)
        if (i % 1000 == 0) {
            println("Guess: $guess")
            println("Cost: ${cost(videoGameData, guess)}")
        }
    }
    return guess
}

The calls to updateGuess and cost point to some functions that need explaining. You should be able to guess what cost does: it computes how bad our guess at m and b is, using the J(m, b) formula above.

So, the implementation is just this:

import kotlin.math.pow

fun cost(videoGameData: Array<Pair<Double, Double>>, learningParameters: Pair<Double, Double>): Double =
    videoGameData.fold(0.0) { acc: Double, pair: Pair<Double, Double> ->
        val (m, b) = learningParameters
        val (x, y) = pair
        // Accumulate half the squared difference between the prediction (m * x + b) and the actual sales y
        acc + ((((m * x) + b) - y).pow(2) / 2)
    } / videoGameData.size

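As a quick sanity check, here’s one way you might feed the five rows from the earlier table into cost with a guess of 0 for both m and b:

fun main() {
    // The first five data points from the table: (meta-critic score, actual NA sales in millions)
    val firstFiveGames = arrayOf(
        Pair(61.0, 15.0),
        Pair(97.0, 7.02),
        Pair(97.0, 9.66),
        Pair(88.0, 9.04),
        Pair(87.0, 9.7)
    )
    println(cost(firstFiveGames, Pair(0.0, 0.0))) // prints roughly 54.3
}

The output is slightly different from the 54.31 we computed by hand only because the table rounds a couple of the squared differences.
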
The only reason we compute the cost here is to monitor the progress of gradient descent: as the for-loop executes, the cost should be going down.

Understanding the implementation of updateGuess requires a little detour into basic calculus.

If we look at our graph of m and J(m) and remember that the derivative of a function gives us the slope of the tangent line, we have what we need to see the central insight that makes gradient descent possible. So, let’s ask: what happens if we compute the derivative of this function and draw the tangent lines for m = 1 and m = -1 on this graph?

The tangent line for m = 1 has a positive slope, which means the derivative of J(m) at that point will be positive. The tangent line for m = -1 has a negative slope, which means the derivative of J(m) at that point will be negative. If m = 1, we want our next guess at m to be smaller, and if m = -1, we want our next guess to be larger. We can achieve this by subtracting the derivative (scaled by the learning rate), and if we do this, we wind up with the heart of the gradient descent algorithm:

m := m - \text{learningRate} \cdot \frac{\partial J}{\partial m}
b := b - \text{learningRate} \cdot \frac{\partial J}{\partial b}

where, for our cost function,

\frac{\partial J}{\partial m} = \frac{1}{n} \sum_{i=1}^{n} \left( (m \cdot x_i + b) - \text{actual}_i \right) \cdot x_i
\frac{\partial J}{\partial b} = \frac{1}{n} \sum_{i=1}^{n} \left( (m \cdot x_i + b) - \text{actual}_i \right)

We’ve got some partial derivatives here because we’re working with multiple variables, and in general, there may be many more learning parameters than m and b if we have more features/variables in our learning problem (e.g., we may think that the game studio is a predictor/variable that could tell us something about game sales). This detail doesn’t really matter for getting the gist of gradient descent. Here’s the Kotlin:

fun updateGuess(
    guess: Pair<Double, Double>,
    learningRate: Double,
    videoGameData: Array<Pair<Double, Double>>
): Pair<Double, Double> {
    val (m, b) = guess
    // Partial derivative of J with respect to m: the average of ((m * x + b) - y) * x
    val mGradient = videoGameData.fold(0.0) { acc: Double, pair: Pair<Double, Double> ->
        val (x, y) = pair
        acc + ((((m * x) + b) - y) * x)
    } / videoGameData.size
    // Partial derivative of J with respect to b: the average of ((m * x + b) - y)
    val bGradient = videoGameData.fold(0.0) { acc: Double, pair: Pair<Double, Double> ->
        val (x, y) = pair
        acc + (((m * x) + b) - y)
    } / videoGameData.size
    // Step each parameter a small amount in the direction that decreases the cost
    val newM = m - (learningRate * mGradient)
    val newB = b - (learningRate * bGradient)
    return Pair(newM, newB)
}

Running this code gives us .0281 for m and -.804 for b, which is pretty darn close to that line of best fit we saw earlier.
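
To close the loop, here’s a quick sketch of plugging those learned parameters back into predictVideoGameSales; the meta-critic score of 85 is just an arbitrary example:

fun main() {
    // The parameters gradient descent found above: m ≈ .0281, b ≈ -.804
    val learningParameters = Pair(0.0281, -0.804)

    // Predicted North American sales, in millions, for a game with a meta-critic score of 85
    val predictedSales = predictVideoGameSales(85, learningParameters)
    println("Predicted sales: $predictedSales") // prints roughly 1.58
}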

Notes


  1. The relationship isn’t as strong as I suspected. They say being a game dev is rough. Maybe this is evidence of that. Maybe the data just isn’t very good.

  2. I’m betting Excel and Google Sheets use the “normal equation” method of linear regression rather than gradient descent, which is apparently OK if you’re dealing with a linear regression problem with fewer than 10k features. I also doubt it can handle more than one variable without some sort of extension.

machine learning, kotlin
