%load_ext d2lbook.tab
tab.interact_select(['mxnet', 'pytorch', 'tensorflow', 'jax'])
Custom Layers
One factor behind deep learning's success is the availability of a wide range of layers that can be composed in creative ways to design architectures suitable for a wide variety of tasks. For instance, researchers have invented layers specifically for handling images and text, for looping over sequential data, and for performing dynamic programming. Sooner or later, you will need a layer that does not yet exist in the deep learning framework. In these cases, you must build a custom layer. In this section, we show you how.
%%tab mxnet
from d2l import mxnet as d2l
from mxnet import np, npx
from mxnet.gluon import nn
npx.set_np()
%%tab pytorch
from d2l import torch as d2l
import torch
from torch import nn
from torch.nn import functional as F
%%tab tensorflow
from d2l import tensorflow as d2l
import tensorflow as tf
%%tab jax
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
Layers without Parameters
To start, we construct a custom layer that does not have any parameters of its own. This should look familiar if you recall our introduction to modules in :numref:`sec_model_construction`. The following `CenteredLayer` class simply subtracts the mean from its input. To build it, we need only inherit from the base layer class and implement the forward propagation function.
%%tab mxnet
class CenteredLayer(nn.Block):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, X):
        return X - X.mean()
%%tab pytorch
class CenteredLayer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        return X - X.mean()
%%tab tensorflow
class CenteredLayer(tf.keras.Model):
    def __init__(self):
        super().__init__()

    def call(self, X):
        return X - tf.reduce_mean(X)
%%tab jax
class CenteredLayer(nn.Module):
    def __call__(self, X):
        return X - X.mean()
Let's verify that our layer works as intended by feeding some data through it.
%%tab all
layer = CenteredLayer()
layer(d2l.tensor([1.0, 2, 3, 4, 5]))
We can now incorporate our layer as a component in constructing more complex models.
%%tab mxnet
net = nn.Sequential()
net.add(nn.Dense(128), CenteredLayer())
net.initialize()
%%tab pytorch
net = nn.Sequential(nn.LazyLinear(128), CenteredLayer())
%%tab tensorflow
net = tf.keras.Sequential([tf.keras.layers.Dense(128), CenteredLayer()])
%%tab jax
net = nn.Sequential([nn.Dense(128), CenteredLayer()])
As an extra sanity check, we can send random data through the network and check that the mean is in fact 0. Because we are dealing with floating point numbers, we may still see a very small nonzero value due to the finite precision of floating-point arithmetic.
:begin_tab:jax
Here we use the `init_with_output` method, which returns both the output of the network and its parameters. In this case we focus only on the output.
:end_tab:
%%tab pytorch, mxnet
Y = net(d2l.rand(4, 8))
Y.mean()
%%tab tensorflow
Y = net(tf.random.uniform((4, 8)))
tf.reduce_mean(Y)
%%tab jax
Y, _ = net.init_with_output(d2l.get_key(),
                            jax.random.uniform(d2l.get_key(), (4, 8)))
Y.mean()
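For instance, in the PyTorch tab we can make this check explicit by comparing the mean against zero with a tolerance, rather than eyeballing the printed value. This is a minimal sketch; the tolerance of 1e-6 is an illustrative choice.
%%tab pytorch
# Sanity check: the centered output's mean should be zero up to
# floating-point rounding error (tolerance chosen for illustration).
torch.allclose(Y.mean(), torch.zeros(1), atol=1e-6)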
Layers with Parameters
Now that we know how to define simple layers, let's move on to defining layers with parameters that can be adjusted through training. We can use built-in functions to create parameters, which provide some basic housekeeping functionality. In particular, they govern access, initialization, sharing, saving, and loading model parameters. This way, among other benefits, we will not need to write custom serialization routines for every custom layer.
Now let's implement our own version of the fully connected layer. Recall that this layer requires two parameters, one to represent the weight and the other for the bias. In this implementation, we bake in the ReLU activation as a default. This layer requires two input arguments: `in_units` and `units`, which denote the number of inputs and outputs, respectively.
%%tab mxnet
class MyDense(nn.Block):
    def __init__(self, units, in_units, **kwargs):
        super().__init__(**kwargs)
        self.weight = self.params.get('weight', shape=(in_units, units))
        self.bias = self.params.get('bias', shape=(units,))

    def forward(self, x):
        linear = np.dot(x, self.weight.data(ctx=x.ctx)) + self.bias.data(
            ctx=x.ctx)
        return npx.relu(linear)
%%tab pytorch
class MyLinear(nn.Module):
    def __init__(self, in_units, units):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.randn(units,))

    def forward(self, X):
        # Use the parameters directly so that gradients flow through them
        linear = torch.matmul(X, self.weight) + self.bias
        return F.relu(linear)
%%tab tensorflow
class MyDense(tf.keras.Model):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, X_shape):
        self.weight = self.add_weight(
            name='weight', shape=[X_shape[-1], self.units],
            initializer=tf.random_normal_initializer())
        self.bias = self.add_weight(
            name='bias', shape=[self.units],
            initializer=tf.zeros_initializer())

    def call(self, X):
        linear = tf.matmul(X, self.weight) + self.bias
        return tf.nn.relu(linear)
%%tab jax
class MyDense(nn.Module):
    in_units: int
    units: int

    def setup(self):
        self.weight = self.param('weight', nn.initializers.normal(stddev=1),
                                 (self.in_units, self.units))
        self.bias = self.param('bias', nn.initializers.zeros, self.units)

    def __call__(self, X):
        linear = jnp.matmul(X, self.weight) + self.bias
        return nn.relu(linear)
:begin_tab:mxnet, tensorflow, jax
Next, we instantiate the `MyDense` class and access its model parameters.
:end_tab:
:begin_tab:pytorch
Next, we instantiate the `MyLinear` class and access its model parameters.
:end_tab:
%%tab mxnet
dense = MyDense(units=3, in_units=5)
dense.params
%%tab pytorch
linear = MyLinear(5, 3)
linear.weight
%%tab tensorflow
dense = MyDense(3)
dense(tf.random.uniform((2, 5)))
dense.get_weights()
%%tab jax
dense = MyDense(5, 3)
params = dense.init(d2l.get_key(), jnp.zeros((3, 5)))
params
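In the PyTorch case, we can also confirm that both tensors were registered with the module by enumerating the parameters by name. This is a quick check on the `linear` instance created above:
%%tab pytorch
# Both 'weight' and 'bias' should show up as registered parameters.
for name, param in linear.named_parameters():
    print(name, param.shape)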
We can directly carry out forward propagation calculations using custom layers.
%%tab mxnet
dense.initialize()
dense(np.random.uniform(size=(2, 5)))
%%tab pytorch
linear(torch.rand(2, 5))
%%tab tensorflow
dense(tf.random.uniform((2, 5)))
%%tab jax
dense.apply(params, jax.random.uniform(d2l.get_key(), (2, 5)))
We can also construct models using custom layers. Once we have them, we can use them just like the built-in fully connected layer.
%%tab mxnet
net = nn.Sequential()
net.add(MyDense(8, in_units=64), MyDense(1, in_units=8))
net.initialize()
net(np.random.uniform(size=(2, 64)))
%%tab pytorch
net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
net(torch.rand(2, 64))
%%tab tensorflow
net = tf.keras.models.Sequential([MyDense(8), MyDense(1)])
net(tf.random.uniform((2, 64)))
%%tab jax
net = nn.Sequential([MyDense(64, 8), MyDense(8, 1)])
Y, _ = net.init_with_output(d2l.get_key(),
                            jax.random.uniform(d2l.get_key(), (2, 64)))
Y
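Because the weight and bias were created through the framework's parameter machinery, the serialization mentioned above comes for free. For instance, in PyTorch the network of custom layers can be saved and restored with the standard state dict API; a minimal sketch (the file name is illustrative):
%%tab pytorch
# Save the parameters of the network built from custom layers, then
# restore them into a freshly constructed clone with the same architecture.
torch.save(net.state_dict(), 'my_linear_net.params')
clone = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
clone.load_state_dict(torch.load('my_linear_net.params'))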
Summary
We can design custom layers via the basic layer class. This allows us to define flexible new layers that behave differently from any existing layers in the library. Once defined, custom layers can be invoked in arbitrary contexts and architectures. Layers can have local parameters, which can be created through built-in functions.
Exercises
- Design a layer that takes an input and computes a tensor reduction, i.e., it returns \(y_k = \sum_{i, j} W_{ijk} x_i x_j\).
- Design a layer that returns the leading half of the Fourier coefficients of the data.
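As a hint for the first exercise, the reduction \(y_k = \sum_{i, j} W_{ijk} x_i x_j\) can be written as a single einsum call; a PyTorch sketch with illustrative shapes (4 inputs, 2 outputs):
%%tab pytorch
# Hypothetical shapes for illustration: x has 4 entries, W maps to 2 outputs.
x = torch.randn(4)
W = torch.randn(4, 4, 2)
y = torch.einsum('i,ijk,j->k', x, W, x)  # y_k = sum_{i,j} W_{ijk} x_i x_j
y.shape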
:begin_tab:mxnet
Discussions
:end_tab:
:begin_tab:pytorch
Discussions
:end_tab:
:begin_tab:tensorflow
Discussions
:end_tab:
:begin_tab:jax
Discussions
:end_tab: