How does a neural network learn?



Each neuron computes a linear function of the form z = Σᵢ wᵢxᵢ + b, with i running from 1 to 4096 in our example above. The result is then passed to a non-linear activation function g(z).
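To make this concrete, here is a minimal pure-Python sketch of a single neuron, using 4 made-up inputs instead of 4096. The function name neuron and all the values are hypothetical, chosen only for illustration, and ReLU stands in for g(z).

```python
def relu(z):
    """ReLU activation, used here as the non-linear g(z)."""
    return max(0.0, z)

def neuron(x, w, b, g):
    """One neuron: linear part z = sum_i w_i * x_i + b, then the activation a = g(z)."""
    z = sum(w_i * x_i for w_i, x_i in zip(w, x)) + b
    return g(z)

# 4 hypothetical inputs, weights and a bias (the article's example has 4096 inputs)
x = [0.5, -1.0, 0.25, 0.0]
w = [0.2, -0.4, 0.8, 0.1]
b = 0.1
print(neuron(x, w, b, relu))
```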

Stacking the outputs of the 8 hidden neurons for all 209 training examples results in a matrix with shape (8, 209) in our example, which is fed into the output layer. The output layer in our example is a single node that performs a binary classification.

Let’s now walk through the key concepts to understand how the neural network learns through this architecture.

Linear functions and non-linear activation functions

If we look at each neuron, we see that there is both a linear function, z = wx + b, and a non-linear activation function. This set-up allows for non-linear fitting, whilst keeping each individual function simple.

As you will see later when we discuss back propagation, we use derivatives to minimise the prediction loss, and adjust the weights and biases accordingly. Pairing a linear function with a non-linear one gives us the non-linearity we need while still letting us use the chain rule to compute the derivatives easily. But we’ll come to that later.

We discussed using the sigmoid function as an activation function in Part 1. This function is defined as:

Sigmoid function: σ(z) = 1 / (1 + e^(−z))

The sigmoid function is not often used in hidden layers today, as the ReLU (Rectified Linear Unit), covered later, generally allows more efficient training. However, because its output is constrained between 0 and 1, it is still useful for the output of binary classification tasks.

Functions like this also suffer from what is commonly known as the vanishing gradient problem: the slope of the function approaches 0 at both extremes of the input values. If you recall, we will be using derivatives during back propagation to minimise the loss. These activation functions present a challenge when the magnitude of the input is large, because the derivative (slope) is then close to zero. This results in very small updates, meaning very slow learning!
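A small sketch of the sigmoid and its slope makes the vanishing gradient visible. It uses the standard derivative σ'(z) = σ(z)(1 − σ(z)); the function names are illustrative, and the two-branch form is just one way of keeping math.exp from overflowing for very negative z.

```python
import math

def sigmoid(z):
    """Sigmoid 1 / (1 + e^(-z)), written in two branches so math.exp never overflows."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)  # z < 0, so this is at most 1
    return e / (1.0 + e)

def sigmoid_derivative(z):
    """Slope of the sigmoid: sigma(z) * (1 - sigma(z)), which peaks at 0.25 when z = 0."""
    s = sigmoid(z)
    return s * (1.0 - s)

# The slope collapses towards 0 as |z| grows: the vanishing gradient problem
for z in [0, 2, 5, 10, -10]:
    print(f"z = {z:>3}  sigmoid = {sigmoid(z):.6f}  slope = {sigmoid_derivative(z):.2e}")
```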

This is partly the reason why we normally standardise our inputs to values that are typically between -1 and 1. When I first started learning this topic, I had no idea why this was needed … a number is a number, right? Most literature only indicates that the network will perform better when the inputs are scaled to a number either between -1 and 1 or 0 and 1 depending on the scenario.

Understanding the vanishing gradient problem gives some insights into why standardising the inputs is recommended.
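As a sketch of what scaling inputs can look like, here is a min-max rescaling to either [0, 1] or [−1, 1]. The 8-bit pixel values (0 to 255) and the function name min_max_scale are assumptions for illustration; for images, dividing by 255 is a common shortcut for the [0, 1] case.

```python
def min_max_scale(values, low=0.0, high=1.0):
    """Linearly rescale values so the smallest maps to `low` and the largest to `high`."""
    v_min, v_max = min(values), max(values)
    span = v_max - v_min
    if span == 0:
        # All values are identical: map them all to the middle of the target range
        return [(low + high) / 2.0 for _ in values]
    return [low + (v - v_min) * (high - low) / span for v in values]

pixels = [0, 64, 128, 255]               # hypothetical 8-bit pixel values
print(min_max_scale(pixels))             # scaled to [0, 1]
print(min_max_scale(pixels, -1.0, 1.0))  # scaled to [-1, 1]
print([p / 255 for p in pixels])         # the common divide-by-255 shortcut for images
```

Unlike dividing by 255, min-max scaling uses the smallest and largest values actually present, so the two only coincide when the data spans the full 0 to 255 range.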

A more common activation function used today is the Rectified Linear Unit (ReLU). Mathematically, ReLU can be defined as:

Rectified Linear Unit (ReLU): g(z) = z if z > 0, and g(z) = 0 otherwise
Fun fact: technically, the gradient does not exist when z is exactly equal to zero. In practice this is very rare, and you can programmatically just set the gradient to 0 if it ever happens, so it is not a real practical issue.

Programmatically, it can easily be defined as the max of 0 and the input value.

g(z) = max(0, z)

Another side benefit of this activation function during back propagation is that its slope is simply 1 for any positive input (and 0 for any negative input). Though the vanishing gradient problem is avoided for positive values, ReLU can still result in what is known as a “Dying ReLU” if there is a large negative bias on one of the neurons: the neuron gets stuck on the negative side and always outputs 0. Because its slope there is also 0, its weights stop being updated, so the neuron is effectively “dead” and no longer plays a part in the fitting process.
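Here is a sketch of ReLU and its slope, following the convention from the fun fact above of returning a gradient of 0 at exactly z = 0. The single-input neuron with a large negative bias is hypothetical; it shows how the output and the slope can both be stuck at 0, which is the “Dying ReLU” situation.

```python
def relu(z):
    """ReLU activation: max(0, z)."""
    return max(0.0, z)

def relu_derivative(z):
    """Slope of ReLU: 1 for z > 0, 0 for z < 0, and (by convention) 0 at exactly z == 0."""
    return 1.0 if z > 0 else 0.0

# Hypothetical neuron with one input and a large negative bias
w, b = 0.5, -10.0
for x in [-2.0, 0.0, 3.0, 8.0]:
    z = w * x + b
    # For all of these inputs z < 0, so both the output and the slope are 0
    print(f"x = {x:>4}  output = {relu(z)}  slope = {relu_derivative(z)}")
```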

Some implementations use a hybrid approach known as a Leaky ReLU to solve this issue, where:

Leaky ReLU: g(z) = z if z > 0, and g(z) = 0.01z otherwise (0.01 being a commonly used small slope)
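A minimal Leaky ReLU sketch, assuming the commonly used slope of 0.01 for negative inputs (the parameter name slope is illustrative). Unlike plain ReLU, the slope for negative inputs is small but never zero, so a neuron pushed to the negative side can still be updated.

```python
def leaky_relu(z, slope=0.01):
    """Leaky ReLU: z for z > 0, slope * z otherwise."""
    return z if z > 0 else slope * z

def leaky_relu_derivative(z, slope=0.01):
    """Slope of Leaky ReLU: 1 for z > 0, `slope` otherwise, so it is never exactly 0."""
    return 1.0 if z > 0 else slope

for z in [-6.0, -0.5, 2.0]:
    print(z, leaky_relu(z), leaky_relu_derivative(z))
```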

For this series, we will only use ReLU for the hidden layers, and the sigmoid function for the output node (again, only because we are doing binary classification). If you are interested in reading more on the other hybrid ReLU activation functions, there is an interesting Medium post on the subject by Liu Danqing.

The network’s first attempt

Prior to the first training run, each weights matrix is initialised with very small random numbers.

It is important that the weights start off as small random numbers. If all the weights (and biases) were initialised with zeros, every z would be zero because of z = wx + b. Setting all the weights to the same value does not work well either: every neuron in a layer would compute the same output and receive the same back-propagated error, so all the weights would be updated by the same amount and the neurons would never learn different features.

The bias vector can be initialized to 0 in the first run.
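A minimal initialisation sketch for our architecture, using nested Python lists as matrices. The scale of 0.01 on Gaussian random numbers, the seed, and the function name initialize_parameters are assumptions (a common choice, not necessarily the article's). The last part shows why identical weights are a problem: every hidden neuron computes exactly the same value.

```python
import random

def initialize_parameters(n_x, n_h, n_y, scale=0.01, seed=1):
    """Small random weights and zero biases for one hidden layer (matrices as lists of rows).
    Shapes: W1 (n_h, n_x), b1 (n_h, 1), W2 (n_y, n_h), b2 (n_y, 1)."""
    rng = random.Random(seed)
    return {
        "W1": [[rng.gauss(0.0, 1.0) * scale for _ in range(n_x)] for _ in range(n_h)],
        "b1": [[0.0] for _ in range(n_h)],
        "W2": [[rng.gauss(0.0, 1.0) * scale for _ in range(n_h)] for _ in range(n_y)],
        "b2": [[0.0] for _ in range(n_y)],
    }

params = initialize_parameters(n_x=4096, n_h=8, n_y=1)
print(len(params["W1"]), len(params["W1"][0]))  # 8 4096

# Why not identical weights? Every hidden neuron then computes exactly the same z
x = [0.3, -0.7, 1.2]
same_W = [[0.5, 0.5, 0.5] for _ in range(3)]
print([sum(w * xi for w, xi in zip(row, x)) for row in same_W])  # three identical values
```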

Clearly, the network’s first attempt is really nothing more than a random guess!

Forward propagation

Forward propagation is the process of taking the inputs, weights, biases, and activation functions to compute the values at each stage. For example, in our architecture with a single hidden layer, the computation will happen in this order:

Block diagram showing forward and backward propagation
  • Compute Z1

    Remember that z = wx + b. In matrix form, this becomes Z1 = W1X + b1, where W1 has shape (8, 4096), each column of X holds the 4096 input values of one training example, and b1 has shape (8, 1) and is added to every column. For one training example the result is an (8, 1) column, but Z1 will be a matrix with shape (8, 209) after applying the entire training set.

  • Compute A1 (activation function)
    We are using ReLU on the hidden layer, so A1 will simply be computed as the max of 0 and the current value of Z1 (i.e. any negative numbers just become 0).
  • Feed A1 to the output layer
    The matrix A1 (also with shape (8, 209)) will now be fed into the output layer as its input.
  • Compute Z[L]
    In this final output layer, the weights matrix has a shape of (1, 8). Multiplying it by A1 and adding the bias results in a final output with a shape of (1, 209), which is then passed into a sigmoid function.
  • Compute A[L] (final prediction)
    The sigmoid function turns each of these values into a number between 0 and 1, which is returned as the prediction for that training example.
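The forward pass above can be sketched in pure Python with nested lists standing in for matrices. To keep it quick to run, the sizes are hypothetical (3 inputs, 8 hidden units, 5 examples) rather than 4096 inputs and 209 examples, and helper names like matmul and add_bias are illustrative; a real implementation would normally use a numerical library.

```python
import math
import random

def matmul(A, B):
    """Multiply A (n x k) by B (k x m); matrices are lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def add_bias(Z, b):
    """Add the column vector b (n x 1) to every column of Z (n x m)."""
    return [[z + b_row[0] for z in z_row] for z_row, b_row in zip(Z, b)]

def relu(Z):
    return [[max(0.0, z) for z in row] for row in Z]

def sigmoid(Z):
    # Two branches so math.exp never overflows
    return [[1.0 / (1.0 + math.exp(-z)) if z >= 0 else math.exp(z) / (1.0 + math.exp(z))
             for z in row] for row in Z]

def forward(X, W1, b1, W2, b2):
    Z1 = add_bias(matmul(W1, X), b1)   # (n_h, m)
    A1 = relu(Z1)                      # (n_h, m)
    Z2 = add_bias(matmul(W2, A1), b2)  # (1, m): Z[L]
    A2 = sigmoid(Z2)                   # (1, m): A[L], each value between 0 and 1
    return A2

rng = random.Random(0)
n_x, n_h, m = 3, 8, 5
X = [[rng.uniform(-1, 1) for _ in range(m)] for _ in range(n_x)]      # (n_x, m)
W1 = [[rng.gauss(0, 1) * 0.01 for _ in range(n_x)] for _ in range(n_h)]
b1 = [[0.0] for _ in range(n_h)]
W2 = [[rng.gauss(0, 1) * 0.01 for _ in range(n_h)]]                   # (1, n_h)
b2 = [[0.0]]

A2 = forward(X, W1, b1, W2, b2)
print(len(A2), len(A2[0]))  # 1 5
```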

Compute the loss function

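The prediction from this run depends on the model’s current weights and biases. It is a probability between 0 and 1, which we can threshold (for example at 0.5) to output either a 1 or a 0 for this binary classification example. This is compared with the ground truth, which is found in the labels that we extracted as part of the training set (read into the train_set_y variable).

For binary classification, the loss function used for each training example is as follows:

Loss function (binary classification): L(a, y) = −[y log(a) + (1 − y) log(1 − a)], where a is the prediction

Intuitively, you can see how this achieves what we need for the two cases of the ground truth, y = 0 or y = 1. When y = 1, the loss reduces to −log(a), which is 0 when a = 1 and grows without bound as a approaches 0. When y = 0, it reduces to −log(1 − a), which is 0 when a = 0 and grows without bound as a approaches 1. It helps to remember that the log curve cuts the x-axis at x = 1.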

Log Function
Intuitive explanation for the Loss Function
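A direct translation of the loss for a single example; the function name is illustrative, and a must lie strictly between 0 and 1, otherwise math.log fails. The loop shows the two cases y = 1 and y = 0: the loss is small when the prediction agrees with the label and large when it does not.

```python
import math

def binary_cross_entropy(a, y):
    """Loss for one example: -(y*log(a) + (1 - y)*log(1 - a)), for 0 < a < 1 and y in {0, 1}."""
    return -(y * math.log(a) + (1 - y) * math.log(1 - a))

for a in [0.99, 0.5, 0.01]:
    print(f"a = {a:.2f}  loss if y = 1: {binary_cross_entropy(a, 1):.3f}  "
          f"loss if y = 0: {binary_cross_entropy(a, 0):.3f}")
```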

The cost function is the average of all the losses over the training set, and this is the equation that we want to minimize (i.e. we take its derivative with respect to the weights and biases).

Cost function (binary classification): J(W, b) = (1/m) Σᵢ L(a⁽ⁱ⁾, y⁽ⁱ⁾), summing over the m training examples (m = 209 in our example)
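The cost is then just the average of the per-example losses. This sketch takes predictions and labels as plain lists; clipping predictions with a small eps so that the log never sees exactly 0 is an assumption of this sketch, not something from the article, and the function name compute_cost is illustrative.

```python
import math

def compute_cost(A, Y, eps=1e-12):
    """Average binary cross-entropy over m examples.
    A: predicted probabilities, Y: labels (0 or 1), both lists of length m.
    Predictions are clipped to [eps, 1 - eps] so math.log never sees 0."""
    m = len(Y)
    total = 0.0
    for a, y in zip(A, Y):
        a = min(max(a, eps), 1.0 - eps)
        total += -(y * math.log(a) + (1 - y) * math.log(1.0 - a))
    return total / m

print(compute_cost([0.9, 0.2, 0.7], [1, 0, 1]))  # hypothetical predictions and labels
```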

Using back propagation to adjust the weights and biases

The network uses back propagation to adjust the weights and biases. To understand this intuitively, consider the following steps.

  • We want to minimise the cost function with respect to the trainable parameters, i.e. what values should W and b be adjusted to? This diagram shows the functions used in the forward propagation, and the corresponding derivatives.
Block diagram showing forward and backward propagation
  • To minimise the cost, we take its derivative, starting with the derivative of the loss with respect to the output activation a:
Derivative of the loss with respect to a: ∂L/∂a = −y/a + (1 − y)/(1 − a)

The proof of this derivative is outside the scope of this article, though other writers have covered it in full. I enjoyed reading the detailed derivation provided by Patrick David in his article on the topic.

  • We use the chain rule to compute the derivative of the loss with respect to the weights:
Chain rule: ∂L/∂w = (∂L/∂a)(∂a/∂z)(∂z/∂w). For the sigmoid output, ∂a/∂z = a(1 − a), so ∂L/∂z = a − y and ∂L/∂w = (a − y)x, where x is the input to that layer. The gradient of the cost J is the average of these over all m training examples.
  • We then update each of these weights and biases:
    w := w − α ∂J/∂w
    b := b − α ∂J/∂b

where α is the learning rate hyperparameter. We will cover the details separately, but it defines how large each adjustment is. Too small, and the learning process will be slow. Too large, and you might overshoot the local minimum. There’s an art and science to tuning this.
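To see the chain rule and the update rule working together, here is a sketch for a single sigmoid output unit and one hypothetical training example (the inputs x could be the hidden-layer outputs A1). It computes ∂L/∂w = (a − y)x and ∂L/∂b = a − y, checks the first weight's gradient against a finite-difference estimate, and takes one gradient descent step. With m examples you would average these gradients over the training set; all names and values here are illustrative.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def loss(a, y):
    return -(y * math.log(a) + (1 - y) * math.log(1 - a))

def predict(x, w, b):
    """Sigmoid output unit: a = sigmoid(sum_i w_i * x_i + b)."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)

def gradients(x, y, w, b):
    """Chain rule: dL/da * da/dz simplifies to (a - y); dz/dw_i = x_i and dz/db = 1."""
    dz = predict(x, w, b) - y
    return [dz * xi for xi in x], dz

def update(w, b, dw, db, learning_rate):
    """One gradient descent step: w := w - alpha * dw, b := b - alpha * db."""
    return [wi - learning_rate * dwi for wi, dwi in zip(w, dw)], b - learning_rate * db

x, y = [0.5, -1.2, 0.3], 1
w, b = [0.1, 0.2, -0.1], 0.0
loss_before = loss(predict(x, w, b), y)

dw, db = gradients(x, y, w, b)

# Compare dL/dw_0 with a finite-difference estimate; they should agree closely
h = 1e-6
numeric = (loss(predict(x, [w[0] + h] + w[1:], b), y) - loss_before) / h
print(dw[0], numeric)

w, b = update(w, b, dw, db, learning_rate=0.1)
print(loss_before, loss(predict(x, w, b), y))  # the loss should go down after the step
```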

Repeat

The process above then repeats itself for the number of iterations that you have defined. This variable is known as a hyperparameter of the model, similar to the learning rate that you encountered above. These are not trainable parameters, and it is both an art and science to tune these hyperparameters. We will cover more of this in a later article.
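Putting the pieces together, here is a sketch of the whole loop for a network with one ReLU hidden layer and a sigmoid output, trained with batch gradient descent on a small made-up, linearly separable dataset (two inputs instead of 4096). The number of iterations and the learning rate are the hyperparameters; the values used here, the 0.01 initialisation scale, and the function name train are assumptions for illustration rather than tuned settings.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z)) if z >= 0 else math.exp(z) / (1.0 + math.exp(z))

def train(X, Y, n_h=8, learning_rate=0.5, num_iterations=1000, seed=0):
    """Batch gradient descent for one ReLU hidden layer and a sigmoid output.
    X: m examples, each a list of n_x features; Y: m labels (0 or 1). Returns the cost per iteration."""
    rng = random.Random(seed)
    m, n_x = len(X), len(X[0])
    W1 = [[rng.gauss(0, 1) * 0.01 for _ in range(n_x)] for _ in range(n_h)]
    b1 = [0.0] * n_h
    W2 = [rng.gauss(0, 1) * 0.01 for _ in range(n_h)]
    b2 = 0.0
    costs = []
    for _ in range(num_iterations):
        dW1 = [[0.0] * n_x for _ in range(n_h)]
        db1, dW2, db2, cost = [0.0] * n_h, [0.0] * n_h, 0.0, 0.0
        for x, y in zip(X, Y):
            # Forward propagation
            z1 = [sum(w * xi for w, xi in zip(W1[j], x)) + b1[j] for j in range(n_h)]
            a1 = [max(0.0, z) for z in z1]
            a2 = sigmoid(sum(w * a for w, a in zip(W2, a1)) + b2)
            a2c = min(max(a2, 1e-12), 1.0 - 1e-12)  # clip only for the log
            cost += -(y * math.log(a2c) + (1 - y) * math.log(1.0 - a2c))
            # Back propagation: chain rule from the output back to each weight
            dz2 = a2 - y
            db2 += dz2
            for j in range(n_h):
                dW2[j] += dz2 * a1[j]
                dz1 = dz2 * W2[j] * (1.0 if z1[j] > 0 else 0.0)  # ReLU slope
                db1[j] += dz1
                for i in range(n_x):
                    dW1[j][i] += dz1 * x[i]
        # Gradient descent step, using gradients averaged over the m examples
        for j in range(n_h):
            W2[j] -= learning_rate * dW2[j] / m
            b1[j] -= learning_rate * db1[j] / m
            for i in range(n_x):
                W1[j][i] -= learning_rate * dW1[j][i] / m
        b2 -= learning_rate * db2 / m
        costs.append(cost / m)
    return costs

rng = random.Random(42)
X = [[rng.uniform(-1, 1), rng.uniform(-1, 1)] for _ in range(20)]
Y = [1 if x0 + x1 > 0 else 0 for x0, x1 in X]
costs = train(X, Y)
print(f"cost at first iteration: {costs[0]:.3f}, after training: {costs[-1]:.3f}")
```

Recording the cost at every iteration makes it easy to check that it is actually going down, which is a useful first sanity check when tuning the learning rate.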

Conclusion

This is a very heavy topic, but I hope that it highlights at a high level how a neural network learns: it makes predictions through forward propagation, and then nudges the weights and biases during back propagation by taking the derivatives of the cost function with respect to them.

It is mainly through this back propagation that the weights and biases are slowly nudged towards values that minimize the cost, which in practice tends to improve the accuracy. That is why you often hear people say that training a neural network is about minimizing the cost, rather than directly maximizing the accuracy.

If you do not understand the concepts the first time round, do not despair. It took me a while to understand back propagation, and I must admit that I am still discovering new things while writing this article.

Writing these articles helps me firm up my understanding, and I hope that you have learnt something from them.

Again, I have made every attempt to check the work. If there are errors in my understanding, please let me know!

The next article will get into some code to implement the concepts that we have discussed.

