Insults: everyone’s been on the receiving end of them, and more likely than not it happened on the internet. It’s easy to hide behind a username and say awful things to other people. Reading these comments can ruin your day, and social media use has been linked to increased depression in the U.S. Using AI, more specifically an LSTM, we can automatically flag toxic comments and help fix that.
What is an LSTM?
An LSTM, or Long Short-Term Memory network, is an improvement on recurrent neural networks (RNNs), which learn by passing a hidden state along with the input through each step of a sequence. LSTMs are useful when a problem requires us to remember both recent and past events. For example, say we have a sequence of images about forest organisms: a fox, a bear, and then an image that could be either a wolf or a dog. We want our network to predict that the last image is a wolf by taking a hint from the previous images.
We do this by using the output of one RNN step as part of the input to the next. Mathematically, we combine the previous output and the current input with a linear function (multiplying by a weight matrix and adding a bias), and then squeeze the result through an activation function such as tanh or sigmoid. By the end, the final step should “know” that the sequence is about forest organisms and use that information when making a prediction. But the example above is an ideal scenario: what if there are other things in between the forest animals?
Suppose the two images right before the wolf/dog show a tree and a squirrel. Looking only at those, the RNN can’t tell whether the image is a dog or a wolf, since trees and squirrels don’t necessarily mean a forest. And remember, at every step the information from the bear gets squeezed by the activation function again, so by the end it is pretty much lost. That is the main drawback of RNNs: poor long-term memory. To recap, in an RNN, memory is passed in, joined with the event to produce an output, and that output becomes part of the input for the next step.
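The recurrence described above can be sketched in a few lines of NumPy. This is a minimal illustration, assuming the common formulation $h_t = \tanh(W[h_{t-1}, x_t] + b)$ with tanh as the squashing function; the function name `rnn_step`, the sizes, and the random weights and events are hypothetical and only there to make it runnable.

```python
import numpy as np

def rnn_step(h_prev, x_t, W, b):
    """One RNN step: join the previous memory and the event,
    apply a linear function, then squash with tanh."""
    combined = np.concatenate([h_prev, x_t])  # join the two vectors
    return np.tanh(W @ combined + b)          # linear function + squashing

rng = np.random.default_rng(0)
hidden_size, input_size = 4, 3
W = rng.normal(scale=0.5, size=(hidden_size, hidden_size + input_size))
b = np.zeros(hidden_size)

# Five hypothetical events (e.g. fox, bear, tree, squirrel, wolf/dog),
# represented here by random feature vectors.
events = rng.normal(size=(5, input_size))

h = np.zeros(hidden_size)
for x_t in events:
    h = rnn_step(h, x_t, W, b)  # this step's output is the next step's memory
print(h.shape)  # (4,)
```

Because the fox and bear only reach the last step after being squashed through tanh several times, their influence on the final `h` shrinks, which is the long-term memory problem described above.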
At a high level, an LSTM works similarly.
Instead of a single memory, an LSTM takes both a long-term and a short-term memory as input. Just like in an RNN, they merge with the event to produce a prediction. Unlike an RNN, though, each step also produces a new long-term and a new short-term memory. This lets the network hold on to events from further in the past, fixing the memory problem of RNNs.
Basics of LSTMs
Let’s expand a bit on how the LSTM works with long- and short-term memory. Let an elephant represent the long-term memory, a fish the short-term memory, and the wolf/dog image our event. Inside an LSTM there are four gates: the forget gate, the learn gate, the remember gate, and the use gate.
First, the long-term memory goes to the forget gate, where any unnecessary information is removed. Next, the short-term memory and the event are joined in the learn gate, which also discards unnecessary information from what we’ve just learned. Then the output of the forget gate and the output of the learn gate are combined in the remember gate, which produces the new long-term memory. Finally, the use gate combines information from the forget and learn gates to make a prediction and create a new short-term memory.
Okay, we know what goes on inside our LSTM, but how do we actually accomplish these things? More specifically, what functions do we use to transform our vectors and data? It looks like this:
It might seem very complicated, but it can be broken down pretty simply.
The architecture of LSTMs
Note: The next few sections will require some calculus and linear algebra knowledge.
The learn gate
Let’s go back to our forest example. Remember that the learn gate combines the event and the short-term memory, then forgets some of the result. So how does it work mathematically?
(Excuse my poor labeling.) We have the short-term memory $STM_{t-1}$ and the event $E_t$. To combine them, we join the two vectors, multiply by a matrix $W_n$, add a bias $b_n$, and squeeze the result with a tanh activation function: $N_t = \tanh(W_n[STM_{t-1}, E_t] + b_n)$. But we’re not done yet: the learn gate also ignores part of this information, so we multiply by an ignore factor $i_t$. Since $i_t$ is also a vector, how do we calculate it?
To calculate $i_t$, we use the same short-term memory and event, but with a new matrix $W_i$ and bias $b_i$, and we squeeze the result with a sigmoid function so each entry lies between 0 and 1: $i_t = \sigma(W_i[STM_{t-1}, E_t] + b_i)$. The output of the learn gate is the element-wise product $N_t i_t$. That’s all the learn gate does.
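Here is a minimal NumPy sketch of the learn gate, assuming $N_t = \tanh(W_n[STM_{t-1}, E_t] + b_n)$ and $i_t = \sigma(W_i[STM_{t-1}, E_t] + b_i)$ with element-wise multiplication at the end. The function name `learn_gate`, the sizes, and the random weights are hypothetical, not taken from the original project.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def learn_gate(stm_prev, event, W_n, b_n, W_i, b_i):
    """Learn gate: N_t = tanh(W_n [STM_{t-1}, E_t] + b_n),
    i_t = sigmoid(W_i [STM_{t-1}, E_t] + b_i); returns N_t * i_t."""
    combined = np.concatenate([stm_prev, event])  # [STM_{t-1}, E_t]
    n_t = np.tanh(W_n @ combined + b_n)           # new information, in (-1, 1)
    i_t = sigmoid(W_i @ combined + b_i)           # ignore factor, in (0, 1)
    return n_t * i_t                              # keep only part of N_t

rng = np.random.default_rng(0)
hidden, n_in = 4, 3
stm_prev = np.zeros(hidden)
event = rng.normal(size=n_in)
W_n = rng.normal(size=(hidden, hidden + n_in))
W_i = rng.normal(size=(hidden, hidden + n_in))
b_n, b_i = np.zeros(hidden), np.zeros(hidden)
learned = learn_gate(stm_prev, event, W_n, b_n, W_i, b_i)
print(learned.shape)  # (4,)
```

When an entry of $i_t$ is close to 0 the corresponding piece of new information is ignored; when it is close to 1 it passes through almost unchanged.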
The forget gate
Remember, the forget gate takes the long-term memory and decides what to keep and what to forget. Mathematically, it works much like the learn gate.
The long-term memory $LTM_{t-1}$ is multiplied by a forget factor $f_t$. We calculate $f_t$ the same way we calculated $i_t$, with its own matrix $W_f$ and bias $b_f$ and a sigmoid function: $f_t = \sigma(W_f[STM_{t-1}, E_t] + b_f)$. The output of the forget gate is $LTM_{t-1} f_t$. That is the forget gate.
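A minimal NumPy sketch of the forget gate, assuming $f_t = \sigma(W_f[STM_{t-1}, E_t] + b_f)$ and an output of $LTM_{t-1} f_t$. The function name `forget_gate` and the example values are hypothetical; the all-zero weights are chosen only so the result is easy to check by hand.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forget_gate(ltm_prev, stm_prev, event, W_f, b_f):
    """Forget gate: f_t = sigmoid(W_f [STM_{t-1}, E_t] + b_f);
    returns LTM_{t-1} * f_t."""
    combined = np.concatenate([stm_prev, event])  # [STM_{t-1}, E_t]
    f_t = sigmoid(W_f @ combined + b_f)           # forget factor, in (0, 1)
    return ltm_prev * f_t                         # shrink what should be forgotten

ltm_prev = np.array([2.0, -1.0, 0.5])
stm_prev = np.zeros(3)
event = np.array([1.0, 0.0])
W_f = np.zeros((3, 5))  # zero weights and bias give f_t = 0.5 everywhere
b_f = np.zeros(3)
kept = forget_gate(ltm_prev, stm_prev, event, W_f, b_f)
print(kept)  # every entry of the long-term memory is halved
```

With trained weights, $f_t$ differs per entry, so the gate can keep some parts of the long-term memory almost intact while nearly erasing others.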
The remember gate
The remember gate takes the outputs of the learn and forget gates and adds them together to produce the new long-term memory: $LTM_t = LTM_{t-1} f_t + N_t i_t$. Mathematically, this is the easiest gate yet!
The use gate
The use gate once again combines the long- and short-term memory, but this time it produces a new short-term memory and an output. Its equations are a bit more involved.
On the output of the forget gate, we apply a layer with the tanh activation function: $U_t = \tanh(W_u LTM_{t-1} f_t + b_u)$. On the short-term memory and the event, we apply the sigmoid activation function we’ve seen before: $V_t = \sigma(W_v[STM_{t-1}, E_t] + b_v)$. Multiplying them gives both our output and the new short-term memory: $STM_t = U_t V_t$.
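Putting the four gates together, one LSTM step can be sketched in NumPy as below. This is a minimal illustration that follows the gate equations described in this section; the names `lstm_step` and `init_params`, the parameter dictionary, and the random weights are hypothetical. Note that the standard LSTM used by libraries such as Keras computes its output slightly differently (an output gate applied to tanh of the new cell state), so treat this as a teaching sketch rather than a drop-in implementation.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(ltm_prev, stm_prev, event, p):
    """One step of the four-gate LSTM. p holds weight matrices and biases."""
    combined = np.concatenate([stm_prev, event])  # [STM_{t-1}, E_t]

    # Learn gate: new information N_t, scaled by the ignore factor i_t
    n_t = np.tanh(p["W_n"] @ combined + p["b_n"])
    i_t = sigmoid(p["W_i"] @ combined + p["b_i"])

    # Forget gate: scale the long-term memory by the forget factor f_t
    f_t = sigmoid(p["W_f"] @ combined + p["b_f"])
    kept_ltm = ltm_prev * f_t

    # Remember gate: new long-term memory
    ltm_t = kept_ltm + n_t * i_t

    # Use gate: tanh on the forget-gate output, sigmoid on [STM_{t-1}, E_t]
    u_t = np.tanh(p["W_u"] @ kept_ltm + p["b_u"])
    v_t = sigmoid(p["W_v"] @ combined + p["b_v"])
    stm_t = u_t * v_t  # new short-term memory, also the output

    return ltm_t, stm_t

def init_params(hidden, n_in, rng):
    """Random weights for illustration only."""
    p = {}
    for gate in ("n", "i", "f", "v"):
        p[f"W_{gate}"] = rng.normal(scale=0.5, size=(hidden, hidden + n_in))
        p[f"b_{gate}"] = np.zeros(hidden)
    p["W_u"] = rng.normal(scale=0.5, size=(hidden, hidden))  # acts on LTM only
    p["b_u"] = np.zeros(hidden)
    return p

rng = np.random.default_rng(0)
hidden, n_in = 4, 3
params = init_params(hidden, n_in, rng)
ltm = np.zeros(hidden)
stm = np.zeros(hidden)
for event in rng.normal(size=(5, n_in)):  # a sequence of five events
    ltm, stm = lstm_step(ltm, stm, event, params)
print(ltm.shape, stm.shape)  # (4,) (4,)
```

Unlike the RNN, the long-term memory is updated by multiplication and addition only, without being squashed at every step, which is what lets older information survive longer.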
LSTMs on Toxic Comments
Note: This section requires Python knowledge, and some NLP experience will help, since I’m going to show screenshots of my work. I encourage you to recreate this code and play around with the parameters.
In my example, I classify comments by toxicity on a scale from 0 to 1. Comments with insults like “stupid” or “idiot” should get a much higher toxicity rating than comments with words like “great” or “amazing”. Each comment is analyzed and given a rating from 0 to 1; if the rating is above a threshold, the comment is labeled as toxic.
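The final labeling step is just a comparison against the threshold. A minimal sketch, assuming a threshold of 0.5 (the original doesn’t state which value it used); `label_comments` and the example scores are hypothetical.

```python
def label_comments(scores, threshold=0.5):
    """Turn toxicity scores in [0, 1] into labels: 1 = toxic, 0 = not toxic."""
    return [1 if s > threshold else 0 for s in scores]

# Made-up scores for four comments; only scores strictly above 0.5 count as toxic.
print(label_comments([0.92, 0.10, 0.51, 0.50]))  # [1, 0, 1, 0]
```

Raising the threshold flags fewer comments (fewer false alarms, more missed insults); lowering it does the opposite.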
Before we can analyze the comments, we have to preprocess the data so it’s easier for our LSTM to spot toxic words and comments. This means writing a few helper functions that prepare the input for the LSTM. The first takes in a batch of comments and builds a vocabulary of the words that appear in them. To make lookups easier, we store the vocabulary as a dictionary, then count how many times each word appears across the comments.
Next, we use that dictionary to sort the words in a table by how often they appear in the comments, and then convert the sorted table back into a dictionary.
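The two preprocessing steps above can be sketched with only the standard library. This is not the original code (which used a table for sorting); it assumes lower-casing and whitespace tokenization, and the names `build_vocab` and `build_word_index`, plus the reserved indices 0 (padding) and 1 (unknown word), are illustrative choices.

```python
from collections import Counter

def build_vocab(comments):
    """Count how often each (lower-cased, whitespace-split) word appears."""
    counts = Counter()
    for comment in comments:
        counts.update(comment.lower().split())
    return counts

def build_word_index(counts):
    """Map words to integers, most frequent first (ties broken alphabetically).
    0 is reserved for padding and 1 for unknown words."""
    ranked = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    return {word: i + 2 for i, (word, _) in enumerate(ranked)}

comments = ["you are stupid", "you are great", "what an idiot you are"]
counts = build_vocab(comments)
word_index = build_word_index(counts)
print(word_index["are"], word_index["you"])  # 2 3
```

Giving frequent words small indices makes it easy to cap the vocabulary later by dropping every index above some cutoff.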
First, we use the vocabulary-building function described above to generate word mappings for the next function. Then we use those mappings to replace every word in every sentence of our input with its number, so each sentence becomes a sequence of integers.
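A minimal sketch of that conversion, assuming a word-to-index dictionary like the one built earlier and a fixed sequence length so every comment can be fed to the network as a same-sized array. The `encode` helper, the example mapping, and `max_len` are hypothetical; padding and truncating at the front mirrors the default behavior of Keras’s `pad_sequences`.

```python
def encode(comments, word_index, max_len):
    """Turn each comment into a list of exactly max_len integers.
    Unknown words map to 1; short comments are left-padded with 0."""
    sequences = []
    for comment in comments:
        ids = [word_index.get(w, 1) for w in comment.lower().split()]
        ids = ids[max(0, len(ids) - max_len):]              # keep the last max_len words
        sequences.append([0] * (max_len - len(ids)) + ids)  # pad on the left
    return sequences

word_index = {"are": 2, "you": 3, "stupid": 4, "great": 5}
print(encode(["You are stupid", "you are amazing"], word_index, max_len=5))
# [[0, 0, 3, 2, 4], [0, 0, 3, 2, 1]]
```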
Now we divide our data into train, test, and cross-validation (cv) sets and print out the shape of each. (The x_train, x_test, and x_cv arrays were created beforehand, as was the text preprocessing.)
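One way to do such a split with NumPy alone is sketched below. The 80/10/10 proportions, the helper name `train_test_cv_split`, and the placeholder arrays are assumptions for illustration; the original split sizes aren’t given.

```python
import numpy as np

def train_test_cv_split(X, y, test_frac=0.1, cv_frac=0.1, seed=0):
    """Shuffle, then split into train, test and cross-validation (cv) sets."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    n_test = int(len(X) * test_frac)
    n_cv = int(len(X) * cv_frac)
    test_idx = idx[:n_test]
    cv_idx = idx[n_test:n_test + n_cv]
    train_idx = idx[n_test + n_cv:]
    return ((X[train_idx], y[train_idx]),
            (X[test_idx], y[test_idx]),
            (X[cv_idx], y[cv_idx]))

X = np.zeros((100, 50), dtype=np.int64)  # placeholder: 100 padded comments of length 50
y = np.zeros(100)                        # placeholder labels
(x_train, y_train), (x_test, y_test), (x_cv, y_cv) = train_test_cv_split(X, y)
print(x_train.shape, x_test.shape, x_cv.shape)  # (80, 50) (10, 50) (10, 50)
```

Shuffling before splitting matters if the raw dataset is ordered, for example with all toxic comments grouped together.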
In the end, our network has a single embedding layer, two LSTM layers, and a final layer with a sigmoid activation that squishes the result into a number between 0 and 1.
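In Keras, an architecture like this could be sketched as follows. It assumes TensorFlow is installed; the vocabulary size, sequence length, embedding size, and LSTM widths are hypothetical values, not the ones from the original project.

```python
import numpy as np
import tensorflow as tf

vocab_size = 20000   # hypothetical: number of word indices, including 0 and 1
max_len = 100        # hypothetical: length of each padded comment
embedding_dim = 64   # hypothetical

model = tf.keras.Sequential([
    tf.keras.Input(shape=(max_len,), dtype="int32"),
    tf.keras.layers.Embedding(vocab_size, embedding_dim),  # word index -> vector
    tf.keras.layers.LSTM(64, return_sequences=True),       # passes every step on
    tf.keras.layers.LSTM(32),                               # keeps only the last output
    tf.keras.layers.Dense(1, activation="sigmoid"),         # toxicity score in (0, 1)
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

batch = np.zeros((2, max_len), dtype="int32")  # two dummy padded comments
print(model(batch).shape)  # (2, 1)
```

The first LSTM needs `return_sequences=True` so the second LSTM receives a full sequence rather than a single vector. Training would then look like `model.fit(x_train, y_train, validation_data=(x_cv, y_cv), epochs=...)`.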
Neural networks have a tendency to overfit: they train so much on their training data that they can’t handle anything else, which makes them useless in the real world. To prevent this, we watch for the point where the model stops improving on data it hasn’t trained on, which tells us it may be time to stop training. Our model stopped improving at around four epochs (full passes over the training data), so that is roughly where it begins to overfit. To visualize this, let’s make a graph!
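The “stop when validation loss stops improving” rule can be sketched in plain Python. The helper `best_epoch`, the `patience` parameter, and the loss values below are made up for illustration.

```python
def best_epoch(val_losses, patience=1):
    """Return the 1-based epoch with the lowest validation loss, stopping
    once the loss has failed to improve for `patience` epochs in a row."""
    best, best_loss, waited = 0, float("inf"), 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best, best_loss, waited = epoch, loss, 0
        else:
            waited += 1
            if waited >= patience:
                break
    return best

# Made-up validation losses that bottom out at epoch 3, then rise.
print(best_epoch([0.30, 0.22, 0.19, 0.21, 0.25]))  # 3
```

In Keras, the built-in `tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=1, restore_best_weights=True)` callback applies the same idea during `model.fit`, and the per-epoch losses to plot are available in the returned history object under `history.history["loss"]` and `history.history["val_loss"]`.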
As expected, the validation loss really starts to tick up at around 3.5 to 4.0 epochs, so our best model is the one from around epoch 3.
LSTMs are great for natural language processing, and in the end we were able to achieve 98.38% accuracy in classifying toxic and unwanted comments.
While I wrote the code for the LSTM myself, the original project idea was not mine. All credit goes to this GitHub project, which linked to the dataset I used and helped a lot with data preprocessing. The visuals used to explain LSTMs came from this Udacity course.