Concise Implementation of Recurrent Neural Networks
:label:sec_rnn-concise
Like most of our from-scratch implementations,
:numref:sec_rnn-scratch
was designed
to provide insight into how each component works.
But when you are using RNNs every day
or writing production code,
you will want to rely more on libraries
that cut down on both implementation time
(by supplying library code for common models and functions)
and computation time
(by optimizing the heck out of these library implementations).
This section will show you how to implement
the same language model more efficiently
using the high-level API provided
by your deep learning framework.
We begin, as before, by loading
The Time Machine dataset.
%load_ext d2lbook.tab
tab.interact_select('mxnet', 'pytorch', 'tensorflow', 'jax')
%%tab mxnet
from d2l import mxnet as d2l
from mxnet import np, npx
from mxnet.gluon import nn, rnn
npx.set_np()
%%tab pytorch
from d2l import torch as d2l
import torch
from torch import nn
from torch.nn import functional as F
%%tab tensorflow
from d2l import tensorflow as d2l
import tensorflow as tf
%%tab jax
from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
Defining the Model
We define the following class using the RNN implemented by high-level APIs.
:begin_tab:mxnet
Specifically, to initialize the hidden state,
we invoke the member method begin_state.
This returns a list that contains
an initial hidden state for each example in the minibatch,
whose shape is
(number of hidden layers, batch size, number of hidden units).
For some models to be introduced later
(e.g., long short-term memory),
this list will also contain other information.
:end_tab:
:begin_tab:jax
As of today, Flax does not provide an RNNCell for a concise
implementation of vanilla RNNs. More advanced RNN variants,
such as LSTMs and GRUs, are available in the Flax linen API.
:end_tab:
%%tab mxnet
class RNN(d2l.Module):  #@save
    """The RNN model implemented with high-level APIs."""
    def __init__(self, num_hiddens):
        super().__init__()
        self.save_hyperparameters()
        self.rnn = rnn.RNN(num_hiddens)

    def forward(self, inputs, H=None):
        if H is None:
            H, = self.rnn.begin_state(inputs.shape[1], ctx=inputs.ctx)
        outputs, (H, ) = self.rnn(inputs, (H, ))
        return outputs, H
%%tab pytorch
class RNN(d2l.Module):  #@save
    """The RNN model implemented with high-level APIs."""
    def __init__(self, num_inputs, num_hiddens):
        super().__init__()
        self.save_hyperparameters()
        self.rnn = nn.RNN(num_inputs, num_hiddens)

    def forward(self, inputs, H=None):
        return self.rnn(inputs, H)
%%tab tensorflow
class RNN(d2l.Module):  #@save
    """The RNN model implemented with high-level APIs."""
    def __init__(self, num_hiddens):
        super().__init__()
        self.save_hyperparameters()
        self.rnn = tf.keras.layers.SimpleRNN(
            num_hiddens, return_sequences=True, return_state=True,
            time_major=True)

    def forward(self, inputs, H=None):
        outputs, H = self.rnn(inputs, H)
        return outputs, H
%%tab jax
class RNN(nn.Module):  #@save
    """The RNN model implemented with high-level APIs."""
    num_hiddens: int

    @nn.compact
    def __call__(self, inputs, H=None):
        raise NotImplementedError
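To make the tensor shapes concrete, here is a minimal standalone sketch (using PyTorch's nn.RNN directly with arbitrary sizes; it is not part of the classes above) that passes a dummy time-major minibatch through the layer and inspects the shapes of the per-step outputs and the final hidden state.

```python
import torch
from torch import nn

# Arbitrary sizes chosen for illustration only.
num_steps, batch_size, num_inputs, num_hiddens = 32, 4, 28, 32
rnn_layer = nn.RNN(num_inputs, num_hiddens)          # a single hidden layer by default
X = torch.randn(num_steps, batch_size, num_inputs)   # time-major input
outputs, H = rnn_layer(X)                            # initial hidden state defaults to zeros
print(outputs.shape)  # (num_steps, batch_size, num_hiddens): hidden state at every step
print(H.shape)        # (num_layers, batch_size, num_hiddens): final hidden state
```

The shape of H matches the (number of hidden layers, batch size, number of hidden units) layout described above.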
Inheriting from the RNNLMScratch class in :numref:sec_rnn-scratch,
the following RNNLM class defines a complete RNN-based language model.
Note that we need to create a separate fully connected output layer.
%%tab pytorch
class RNNLM(d2l.RNNLMScratch):  #@save
    """The RNN-based language model implemented with high-level APIs."""
    def init_params(self):
        self.linear = nn.LazyLinear(self.vocab_size)

    def output_layer(self, hiddens):
        return d2l.swapaxes(self.linear(hiddens), 0, 1)
%%tab mxnet, tensorflow
class RNNLM(d2l.RNNLMScratch):  #@save
    """The RNN-based language model implemented with high-level APIs."""
    def init_params(self):
        if tab.selected('mxnet'):
            self.linear = nn.Dense(self.vocab_size, flatten=False)
            self.initialize()
        if tab.selected('tensorflow'):
            self.linear = tf.keras.layers.Dense(self.vocab_size)

    def output_layer(self, hiddens):
        if tab.selected('mxnet'):
            return d2l.swapaxes(self.linear(hiddens), 0, 1)
        if tab.selected('tensorflow'):
            return d2l.transpose(self.linear(hiddens), (1, 0, 2))
%%tab jax
class RNNLM(d2l.RNNLMScratch):  #@save
    """The RNN-based language model implemented with high-level APIs."""
    training: bool = True

    def setup(self):
        self.linear = nn.Dense(self.vocab_size)

    def output_layer(self, hiddens):
        return d2l.swapaxes(self.linear(hiddens), 0, 1)

    def forward(self, X, state=None):
        embs = self.one_hot(X)
        rnn_outputs, _ = self.rnn(embs, state, self.training)
        return self.output_layer(rnn_outputs)
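A note on the axis swap in output_layer: the RNN produces time-major hidden states of shape (num_steps, batch_size, num_hiddens), whereas the labels are arranged batch-major, so after projecting to the vocabulary we swap the first two axes. The following minimal sketch illustrates the shape bookkeeping with hypothetical sizes and a plain PyTorch linear layer standing in for the framework-specific output layers.

```python
import torch
from torch import nn

# Hypothetical sizes for illustration only.
num_steps, batch_size, num_hiddens, vocab_size = 32, 4, 32, 28
hiddens = torch.randn(num_steps, batch_size, num_hiddens)  # time-major RNN outputs
linear = nn.Linear(num_hiddens, vocab_size)                # stand-in for the output layer
logits = linear(hiddens)                                   # (num_steps, batch_size, vocab_size)
logits = logits.swapaxes(0, 1)                             # (batch_size, num_steps, vocab_size)
print(logits.shape)  # torch.Size([4, 32, 28])
```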
Training and Predicting
Before training the model, let's make a prediction with a model initialized with random weights. Given that we have not trained the network, it will generate nonsensical predictions.
%%tab pytorch, mxnet, tensorflow
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
if tab.selected('mxnet', 'tensorflow'):
    rnn = RNN(num_hiddens=32)
if tab.selected('pytorch'):
    rnn = RNN(num_inputs=len(data.vocab), num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)
Next, we train our model, leveraging the high-level API.
%%tab pytorch, mxnet, tensorflow
if tab.selected('mxnet', 'pytorch'):
    trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
if tab.selected('tensorflow'):
    with d2l.try_gpu():
        trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1)
trainer.fit(model, data)
Compared with :numref:sec_rnn-scratch,
this model achieves comparable perplexity,
but runs faster due to the optimized implementations.
As before, we can generate predicted tokens
following the specified prefix string.
%%tab mxnet, pytorch
model.predict('it has', 20, data.vocab, d2l.try_gpu())
%%tab tensorflow
model.predict('it has', 20, data.vocab)
Summary
High-level APIs in deep learning frameworks provide implementations of standard RNNs. These libraries help you to avoid wasting time reimplementing standard models. Moreover, framework implementations are often highly optimized, leading to significant (computational) performance gains when compared with implementations from scratch.
Exercises
- Can you make the RNN model overfit using the high-level APIs?
- Implement the autoregressive model of :numref:sec_sequence using an RNN.
:begin_tab:mxnet
Discussions
:end_tab:
:begin_tab:pytorch
Discussions
:end_tab:
:begin_tab:tensorflow
Discussions
:end_tab:
:begin_tab:jax
Discussions
:end_tab:
Created: November 25, 2023