Recurrent Neural Networks
What is a Recurrent Neural Network?
An RNN (Recurrent Neural Network) is a type of neural network that remembers
the order of the inputs we feed it.
These networks are used for NLP tasks such as machine translation, building
chatbots, text summarization, and many other scenarios.
RNNs are mostly used when working with sentences or speech.
A simple RNN
How is an RNN different from other neural networks?
Neural networks such as ANNs and CNNs are known as feedforward networks:
the cells and layers in these networks simply take inputs, process them, and
pass outputs forward.
In an RNN, the output at each timestep also depends on the outputs of the
previous timesteps. This means an RNN can remember and learn the order of
the words in each input sentence.
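This recurrence can be sketched as a minimal, single-unit RNN in plain Python. The weights `w_x` and `w_h` and the `tanh` activation are illustrative choices, not the parameters of any particular trained network:

```python
import math

def rnn_step(x_t, h_prev, w_x=0.5, w_h=0.8, b=0.0):
    # One timestep of a minimal single-unit RNN: the new hidden state
    # mixes the current input with the previous hidden state.
    return math.tanh(w_x * x_t + w_h * h_prev + b)

def run_rnn(inputs):
    h = 0.0  # initial hidden state
    states = []
    for x_t in inputs:
        h = rnn_step(x_t, h)
        states.append(h)
    return states

# The same input value produces a different hidden state depending on
# what came before it -- this is how the network "remembers" the sequence.
states = run_rnn([1.0, 1.0, 1.0])
```

Note that even though every input is identical, each hidden state differs, because each one also carries information from the previous timestep.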
What is teacher forcing?
Teacher forcing is one of the most important techniques for training RNNs.
It is the process of discarding the model's output at
timestep n (t_n) and using the expected (ground-truth) output of
t_n as the input for t_(n+1).
Why is teacher forcing so important?
Let's try to understand this with an example.
Given the following input sentence:
"The grass is green"
Let's say we want to train our model to predict the next word in the sequence
given the previous sequence of words.
Without teacher forcing:
| Timestep | Input | Output |
|---|---|---|
| t1 | The | bird |
| t2 | bird | was |
| t3 | was | flying |
Output Sentence: The bird was flying
Desired output sentence: The grass is green
In the example above, the model predicts a wrong word at t1. That wrong
prediction is fed back in as the next input, so the outputs at t2 and t3
are also wrong, and the model produces the wrong output sequence.
With teacher forcing:
(Source: https://towardsdatascience.com/what-is-teacher-forcing-3da6217fed1c)
| Timestep | Input | Output |
|---|---|---|
| t1 | The | bird |
| t2 | grass | was |
| t3 | is | green |
Sequence fed to the model: The grass is green
Here, even when the model predicts a wrong word, we correct it at each
timestep by feeding in the true word, so the model is always conditioned
on the correct sequence.
Hence, it learns the order of the words in the sentence.
This is why teacher forcing is so important and widely used in recurrent
neural networks.
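The two tables above can be reproduced with a small sketch. The `toy_model` dictionary below is a hypothetical stand-in for an untrained RNN cell that has picked up some wrong word associations, and the `decode` helper is also illustrative, not a real library API:

```python
def decode(model_predict, target_words, teacher_forcing=True):
    """Generate a sentence one word per timestep.

    model_predict is any function mapping the current input word to a
    predicted next word (a stand-in for a trained RNN cell).
    """
    outputs = []
    current = target_words[0]  # first word, e.g. "The"
    for t in range(1, len(target_words)):
        predicted = model_predict(current)
        outputs.append(predicted)
        if teacher_forcing:
            # Discard the prediction; feed the ground-truth word instead.
            current = target_words[t]
        else:
            # Free-running: feed the model its own (possibly wrong) prediction.
            current = predicted
    return outputs

# A toy "model" with some wrong word associations (hypothetical values).
toy_model = {"The": "bird", "bird": "was", "was": "flying",
             "grass": "was", "is": "green"}.get

target = ["The", "grass", "is", "green"]
free = decode(toy_model, target, teacher_forcing=False)   # ['bird', 'was', 'flying']
forced = decode(toy_model, target, teacher_forcing=True)  # ['bird', 'was', 'green']
```

Without teacher forcing, the first mistake ("bird") cascades through every later timestep; with teacher forcing, each timestep is conditioned on the true previous word, so the model trains on the real sequence.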
How does an RNN remember the sequence of words in a sentence?
RNNs use return states to do this.
There are two types of return states:
1. Hidden state (h_n)
2. Cell state (c_n)
Hidden State:
The hidden state is simply the output of the network at each timestep.
Cell State:
In LSTM-style RNNs, each cell also retains an internal state known as the
cell state. The cell state (c_n) is updated at every timestep using the
hidden state of the previous timestep (h_(n-1)).
Since the cell state effectively accumulates information from the hidden
states of all the previous timesteps, it helps the network remember the
sequence of words in each sentence.
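Here is a minimal sketch of how the two return states interact, assuming a heavily simplified LSTM-style cell: the fixed 0.5 weights are hypothetical and chosen only for readability, whereas a real cell learns separate weight matrices for each gate:

```python
import math

def lstm_like_step(x_t, h_prev, c_prev):
    # A single-unit illustration of the two return states.
    forget = 1 / (1 + math.exp(-(0.5 * x_t + 0.5 * h_prev)))  # forget gate in (0, 1)
    candidate = math.tanh(0.5 * x_t + 0.5 * h_prev)           # new candidate information
    c_t = forget * c_prev + (1 - forget) * candidate          # cell state: old memory blended with new info
    h_t = math.tanh(c_t)                                      # hidden state: this timestep's output
    return h_t, c_t

# The cell state carries a blend of everything seen so far, while the
# hidden state is what the cell emits at each timestep.
h, c = 0.0, 0.0
for x in [1.0, -0.5, 1.0]:
    h, c = lstm_like_step(x, h, c)
```

After the loop, `c` holds a weighted memory of all three inputs, which is what lets the cell retain the sequence rather than only the latest input.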
Thank you for taking the time to read this blog.
Feel free to connect with me using this link: https://linktr.ee/Yashas96
References