Module 12 of 26 · Deep Learning with PyTorch · Intermediate

Long Short-Term Memory Networks

Duration: 8 min

This module covers Long Short-Term Memory (LSTM) networks, a type of recurrent neural network (RNN) designed to learn from sequential data. Because they can retain information across long sequences, LSTMs are widely used for time series prediction, natural language processing, and speech recognition.

Understanding LSTM Architecture

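An LSTM network is built from LSTM cells. Each cell carries a cell state alongside its hidden state and uses three gates to regulate the flow of information: the forget gate decides how much of the existing cell state to keep, the input gate decides how much new information to write into it, and the output gate decides how much of the cell state to expose as the hidden state. Because the cell state is updated largely additively, gradients can flow across many time steps, which helps the network capture long-term dependencies and mitigates the vanishing gradient problem.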
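To make the gate description concrete, here is a minimal sketch of a single cell step computed by hand and compared against PyTorch's `nn.LSTMCell`. The sizes and the random inputs are illustrative assumptions, not part of the module; the gate ordering (input, forget, candidate, output) follows how `nn.LSTMCell` stacks its weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

input_size, hidden_size = 4, 3               # illustrative sizes (assumption)
cell = nn.LSTMCell(input_size, hidden_size)  # reference implementation

x = torch.randn(1, input_size)        # one time step for a batch of 1
h_prev = torch.randn(1, hidden_size)  # previous hidden state
c_prev = torch.randn(1, hidden_size)  # previous cell state

# nn.LSTMCell stacks the four projections in the order:
# input gate (i), forget gate (f), candidate values (g), output gate (o).
gates = (x @ cell.weight_ih.T + cell.bias_ih
         + h_prev @ cell.weight_hh.T + cell.bias_hh)
i, f, g, o = gates.chunk(4, dim=1)

i = torch.sigmoid(i)  # input gate: how much new information to write
f = torch.sigmoid(f)  # forget gate: how much of the old cell state to keep
g = torch.tanh(g)     # candidate values that could be written
o = torch.sigmoid(o)  # output gate: how much of the cell state to expose

c_next = f * c_prev + i * g      # keep part of the old state, add new content
h_next = o * torch.tanh(c_next)  # hidden state passed to the next time step

# Compare with PyTorch's own computation.
h_ref, c_ref = cell(x, (h_prev, c_prev))
print(torch.allclose(h_next, h_ref, atol=1e-5),
      torch.allclose(c_next, c_ref, atol=1e-5))
```

Both comparisons should print `True`, which indicates that the hand-written gate arithmetic matches what the built-in cell computes.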

import torch
import torch.nn as nn

# Define an LSTM network
class LSTMNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTMNetwork, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        self.fc = nn.Linear(hidden_size, output_size)

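    def forward(self, x):
        # nn.LSTM is sequence-first by default: x is (seq_len, input_size)
        # for a single sequence, or (seq_len, batch, input_size) for a batch.
        batch_shape = x.shape[1:-1]  # empty for a single unbatched sequence
        # Zero initial hidden and cell states, one per layer
        h0 = torch.zeros(self.num_layers, *batch_shape, self.hidden_size, device=x.device)
        c0 = torch.zeros_like(h0)
        out, _ = self.lstm(x, (h0, c0))
        # Feed the output at the last time step to the linear layer
        out = self.fc(out[-1])
        return out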

# Initialize the network
input_size = 10
hidden_size = 20
num_layers = 2
output_size = 1

model = LSTMNetwork(input_size, hidden_size, num_layers, output_size)
print(model)


LSTMNetwork(
  (lstm): LSTM(10, 20, num_layers=2)
  (fc): Linear(in_features=20, out_features=1, bias=True)
)
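As a quick check on the tensor shapes involved, the sketch below runs a bare `nn.LSTM` with the same sizes on a random batch. It assumes the sequence-first layout that `nn.LSTM` uses by default; the sequence length and batch size are arbitrary illustrative values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

seq_len, batch = 5, 3                # illustrative sizes (assumption)
x = torch.randn(seq_len, batch, 10)  # (seq_len, batch, input_size)

# Initial hidden and cell states default to zeros when not supplied.
out, (h_n, c_n) = lstm(x)

print(out.shape)  # torch.Size([5, 3, 20]): top-layer output at every time step
print(h_n.shape)  # torch.Size([2, 3, 20]): final hidden state of each layer
print(c_n.shape)  # torch.Size([2, 3, 20]): final cell state of each layer

# The last time step of the output is the final hidden state of the top layer.
print(torch.allclose(out[-1], h_n[-1]))  # True
```

This is why a sequence-level prediction is commonly read from the last time step of the output: it summarizes the whole sequence as seen by the top layer.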

Training an LSTM Network

Training an LSTM network involves defining a loss function and an optimizer, then repeatedly passing data through the network, computing the loss, and updating the weights from its gradients. The example below performs a single training step; in practice this step is repeated over many batches and epochs until the loss stops improving.

import torch.optim as optim

# Sample data: one unbatched sequence of 5 time steps with 10 features each
input_sequence = torch.randn(5, 10)
# Target value for that sequence
target = torch.tensor([1.0])

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Forward pass
output = model(input_sequence)

# Compute loss
loss = criterion(output, target)
print(f'Loss: {loss.item()}')

# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
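A single step rarely changes the loss much, so the same steps are normally wrapped in a loop. The following is a minimal, self-contained sketch under stated assumptions: `TinyLSTM` is a hypothetical single-layer model defined only for this example, and the synthetic task (predict the sum of the first feature over time) is made up to give the network something learnable.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

class TinyLSTM(nn.Module):  # hypothetical model, for illustration only
    def __init__(self, input_size=10, hidden_size=20):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)  # sequence-first by default
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out, _ = self.lstm(x)    # initial states default to zeros
        return self.fc(out[-1])  # predict from the last time step

# Synthetic data (assumption): 32 sequences, 5 time steps, 10 features.
inputs = torch.randn(5, 32, 10)                    # (seq_len, batch, features)
targets = inputs[:, :, 0].sum(dim=0).unsqueeze(1)  # (batch, 1)

model = TinyLSTM()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

losses = []
for epoch in range(100):
    optimizer.zero_grad()                     # clear gradients from the last step
    loss = criterion(model(inputs), targets)  # forward pass and loss
    loss.backward()                           # backpropagate through time
    optimizer.step()                          # update the weights
    losses.append(loss.item())
    if (epoch + 1) % 20 == 0:
        print(f"epoch {epoch + 1:3d}  loss {loss.item():.4f}")
```

Keeping the per-epoch losses in a list makes it easy to plot or inspect the trend afterwards. A real project would iterate over mini-batches from a `DataLoader` and track a held-out validation loss as well.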

💡 Tip: When training LSTM networks, it's important to monitor the loss to ensure it is decreasing over time. If the loss plateaus or increases, consider adjusting the learning rate or the network architecture.
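One way to act on a stalled loss automatically is a learning-rate scheduler. The sketch below uses PyTorch's `ReduceLROnPlateau`, which lowers the learning rate when the monitored value stops improving. The data is random placeholder data, and the `factor` and `patience` values are arbitrary assumptions; in practice the scheduler is usually fed a validation loss rather than the training loss.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

torch.manual_seed(0)

# Small stand-in model: an LSTM followed by a linear head.
lstm = nn.LSTM(10, 20)
head = nn.Linear(20, 1)
params = list(lstm.parameters()) + list(head.parameters())

inputs = torch.randn(5, 16, 10)  # placeholder data: (seq_len, batch, features)
targets = torch.randn(16, 1)

criterion = nn.MSELoss()
optimizer = optim.Adam(params, lr=0.01)
# Halve the learning rate if the loss has not improved for 5 epochs.
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

history = []  # (loss, learning rate) per epoch
for epoch in range(60):
    optimizer.zero_grad()
    out, _ = lstm(inputs)
    loss = criterion(head(out[-1]), targets)
    loss.backward()
    optimizer.step()
    scheduler.step(loss.item())  # let the scheduler see the monitored value
    history.append((loss.item(), optimizer.param_groups[0]["lr"]))

print(f"first loss {history[0][0]:.4f}, last loss {history[-1][0]:.4f}")
print(f"final learning rate {history[-1][1]}")
```

Logging the learning rate next to the loss shows whether and when the scheduler stepped in.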

❓ What is the primary function of the forget gate in an LSTM cell?

❓ Which of the following is a common issue that LSTMs help mitigate?
