Machine Learning Series: How Does the Machine "Learn"?

by Melike Vurucu
Tags: artificial-intelligence, machine-learning, machine-learning-series

Well, you may have heard of some machine learning memes, such as:

Machine is still learning

Well, I am not actually a fan of this meme; I saw it on LinkedIn (and then found it on Reddit)... Here is the Reddit post (the original one seems to have been deleted).

You may not know what it means, why there is a warning sign on the keyboard, etc.

In this post, we will learn what "learning" actually is.

What Is Learning, Actually?!

Well, when we look at the anatomy of machine learning methods, we see that:

There is a set of parameters that must be chosen so that the model makes the best possible predictions.

Some of the parameters are learned from the data, and some of them are predefined.

Well, if there is a predefined set of parameters, why do we need to learn anything from the data?

Because the predefined parameters only adjust the model's behavior (for example, how it learns); they are not what turns an input into a predicted output.

Some of the parameters that are predefined are:

These parameters, which are set before training and are not learned from the data, are known as hyperparameters.

Hyperparameters are like the features of your classroom: the size of the class, the equipment available. They are predetermined by the school administration.

But, what about the parameters that are learned from the data?

These parameters are learned from the data and are usually referred to as weights.

Some of the parameters that are learned from the data are:

These parameters are learned from the data and are used to make the predictions.
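To make the split concrete, here is a minimal Python sketch assuming a hypothetical one-feature linear model, y = w * x + b. The names `hyperparameters`, `weights`, `predict`, and all the values are illustrative assumptions, not taken from any library: the learning rate and number of epochs are hyperparameters we pick up front, while `w` and `b` are the weights that training would adjust.

```python
# Hypothetical toy example: names and values are illustrative only.

# Hyperparameters: fixed by us BEFORE training, never updated from the data
hyperparameters = {
    "learning_rate": 0.05,  # size of each update step during training
    "epochs": 1000,         # number of passes over the training data
}

# Parameters (weights): start from arbitrary values and are adjusted by training
weights = {"w": 0.0, "b": 0.0}


def predict(x, weights):
    """One-feature linear model: the prediction depends only on the weights."""
    return weights["w"] * x + weights["b"]


# Suppose training has finished and produced these (made-up) weights:
trained_weights = {"w": 2.0, "b": 1.0}
print(predict(3.0, trained_weights))  # → 7.0
```

Notice that `predict` never reads `hyperparameters`: they shape how the weights are found, but only the weights are used to make a prediction.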

To sum up:

We now know what is "learned", but how does the machine actually learn?

Well, by using the magical power of optimization algorithms!

These optimization algorithms are used to minimize the error between the predicted output and the actual output.

How do we measure the error?

It is how much our output differs from the actual output!

This metric of error is calculated by using a loss function.

This is a topic for another post, but in short, a loss function is a function that calculates the difference between the predicted output and the actual output.

Our goal is to minimize the loss function: the lower the loss, the closer the model's predictions are to the actual outputs.
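As one common example (the post does not commit to a specific loss), the mean squared error averages the squared differences between the actual and predicted outputs. A minimal Python sketch, where the function name and sample numbers are illustrative:

```python
def mean_squared_error(y_true, y_pred):
    """Mean squared error: average of (actual - predicted)^2 over all examples."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


actual = [3.0, 5.0, 7.0]
predicted = [2.5, 5.0, 8.0]
print(mean_squared_error(actual, predicted))  # (0.25 + 0.0 + 1.0) / 3
print(mean_squared_error(actual, actual))     # → 0.0 for perfect predictions
```

Squaring makes every error positive and penalizes large mistakes more than small ones, so a single number summarizes how far off the predictions are.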

Let's make this goal less abstract.

Imagine we are walking somewhere using GPS, but the GPS is broken: it only shows how far we are from the destination, not which way to go.

We have to figure out, step by step and with our instincts or insights, what to do to get closer to the destination, and the surroundings change as we move. Here, the distance on the GPS plays the role of the loss.

How do we minimize the loss function?

By using gradient descent! (or one of its variants)

Gradient descent is an optimization algorithm that is used to minimize the loss function.

This is also a topic for another post, but in short, it is an algorithm that iteratively updates the weights, each time taking a small step in the direction that decreases the loss (opposite to the gradient).
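Here is a minimal sketch of plain (full-batch) gradient descent in Python, assuming a one-feature linear model y ≈ w * x + b and the mean squared error as the loss. The function name, toy data, learning rate, and number of epochs are illustrative assumptions. Each iteration computes how the loss changes with respect to each weight (the gradient) and moves the weights a small step in the opposite direction.

```python
def gradient_descent(xs, ys, learning_rate=0.05, epochs=1000):
    """Fit y ≈ w * x + b by minimizing the mean squared error with gradient descent."""
    n = len(xs)
    w, b = 0.0, 0.0  # arbitrary starting weights
    for _ in range(epochs):
        # Prediction error for every example with the current weights
        errors = [w * x + b - y for x, y in zip(xs, ys)]
        # Partial derivatives of MSE = (1/n) * sum(error^2) with respect to w and b
        grad_w = 2 * sum(e * x for e, x in zip(errors, xs)) / n
        grad_b = 2 * sum(errors) / n
        # Step against the gradient; the learning rate (a hyperparameter) sets the step size
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b


# Hypothetical toy data generated from y = 2x + 1
xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [1.0, 3.0, 5.0, 7.0, 9.0]
w, b = gradient_descent(xs, ys)
print(round(w, 3), round(b, 3))  # → 2.0 1.0
```

If the learning rate were too large, the updates could overshoot and diverge instead of settling down, which is one reason it is treated as a hyperparameter and tuned.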

We can revisit the previous visualization like this:

Types of Feeding Data to the Model

We feed data to the model so that it can learn, and there is more than one way to do it!

There are two main ways of feeding data to the model:

  1. Batch Learning: The whole dataset is fed to the model at once.
    • Advantages: The model is stable, and since it sees the whole dataset, it can make better predictions.
    • Disadvantages: It is computationally expensive and time-consuming. Also, the model needs to be retrained from scratch when new data is added.
  2. Online Learning: Data is fed to the model incrementally, either in mini-batches (small subsets of the dataset) or one example at a time.
    • Advantages: Each update is computationally cheaper and faster. Also, the model can be updated with new data without retraining from scratch.
    • Disadvantages: The model can be less stable and can forget what it learned from earlier data.
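The difference between the two comes down to how many examples each weight update looks at. Below is a minimal Python sketch assuming a toy linear model y ≈ w * x + b trained with the mean squared error; the function names, data, and settings are hypothetical. Setting `batch_size` to the size of the dataset gives batch learning, while `batch_size=1` (or a small mini-batch) gives the online style, where training can also continue on new data from the current weights.

```python
import random


def update(w, b, batch, learning_rate):
    """One gradient step for y ≈ w * x + b using the mean squared error on `batch`."""
    n = len(batch)
    grad_w = 2 * sum((w * x + b - y) * x for x, y in batch) / n
    grad_b = 2 * sum(w * x + b - y for x, y in batch) / n
    return w - learning_rate * grad_w, b - learning_rate * grad_b


def train(data, batch_size, epochs, learning_rate=0.05, w=0.0, b=0.0, seed=0):
    """batch_size == len(data): batch learning; batch_size == 1: one example at a time."""
    rng = random.Random(seed)
    data = list(data)  # copy so shuffling does not modify the caller's list
    for _ in range(epochs):
        rng.shuffle(data)
        for start in range(0, len(data), batch_size):
            w, b = update(w, b, data[start:start + batch_size], learning_rate)
    return w, b


# Hypothetical toy data generated from y = 2x + 1
data = [(x, 2 * x + 1) for x in [0.0, 1.0, 2.0, 3.0, 4.0]]

# Batch learning: every update uses the whole dataset
w_batch, b_batch = train(data, batch_size=len(data), epochs=1000)

# Online learning: one update per example
w_online, b_online = train(data, batch_size=1, epochs=1000)

# New data arrives: continue from the current weights instead of retraining from scratch
new_data = [(0.5, 2.0), (2.5, 6.0)]
w_online, b_online = train(new_data, batch_size=1, epochs=10, w=w_online, b=b_online)
```

In this noise-free toy both runs end up near w = 2 and b = 1; on real, noisy data, single-example updates typically fluctuate more from step to step, which is the stability trade-off listed above.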