%load_ext d2lbook.tab
tab.interact_select(['mxnet', 'pytorch', 'tensorflow', 'jax'])
Concise Implementation of Softmax Regression
:label:sec_softmax_concise
Just as high-level deep learning frameworks made it easier to implement linear regression (see :numref:sec_linear_concise), they are similarly convenient here.
%%tab mxnet
from d2l import mxnet as d2l
from mxnet import gluon, init, npx
from mxnet.gluon import nn
npx.set_np()
%%tab pytorch
from d2l import torch as d2l
import torch
from torch import nn
from torch.nn import functional as F
%%tab tensorflow
from d2l import tensorflow as d2l
import tensorflow as tf
%%tab jax
from d2l import jax as d2l
from flax import linen as nn
from functools import partial
import jax
from jax import numpy as jnp
import optax
Defining the Model
As in :numref:sec_linear_concise, we construct our fully connected layer using the built-in layer. The built-in __call__ method then invokes forward whenever we need to apply the network to some input.
:begin_tab:mxnet
Even though the input X is a fourth-order tensor, the built-in Dense layer will automatically convert X into a second-order tensor by keeping the dimensionality along the first axis unchanged.
:end_tab:
:begin_tab:pytorch
We use a Flatten layer to convert the fourth-order tensor X to second order by keeping the dimensionality along the first axis unchanged.
:end_tab:
:begin_tab:tensorflow
We use a Flatten layer to convert the fourth-order tensor X to second order by keeping the dimensionality along the first axis unchanged.
:end_tab:
:begin_tab:jax
Flax allows users to write the network class in a more compact way using the @nn.compact decorator. With @nn.compact, one can simply write all network logic inside a single "forward pass" method, without needing to define the standard setup method in the dataclass.
:end_tab:
%%tab pytorch
class SoftmaxRegression(d2l.Classifier):  #@save
    """The softmax regression model."""
    def __init__(self, num_outputs, lr):
        super().__init__()
        self.save_hyperparameters()
        self.net = nn.Sequential(nn.Flatten(),
                                 nn.LazyLinear(num_outputs))

    def forward(self, X):
        return self.net(X)
%%tab mxnet, tensorflow
class SoftmaxRegression(d2l.Classifier):  #@save
    """The softmax regression model."""
    def __init__(self, num_outputs, lr):
        super().__init__()
        self.save_hyperparameters()
        if tab.selected('mxnet'):
            self.net = nn.Dense(num_outputs)
            self.net.initialize()
        if tab.selected('tensorflow'):
            self.net = tf.keras.models.Sequential()
            self.net.add(tf.keras.layers.Flatten())
            self.net.add(tf.keras.layers.Dense(num_outputs))

    def forward(self, X):
        return self.net(X)
%%tab jax
class SoftmaxRegression(d2l.Classifier):  #@save
    num_outputs: int
    lr: float

    @nn.compact
    def __call__(self, X):
        X = X.reshape((X.shape[0], -1))  # Flatten
        X = nn.Dense(self.num_outputs)(X)
        return X
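As a quick sanity check, we can apply the untrained model to a dummy minibatch and confirm that it emits one logit per class. This is a minimal sketch, assuming the PyTorch tab and the class defined above; the shapes are those of Fashion-MNIST images.

```python
# A hedged sketch (PyTorch tab): the lazily initialized linear layer infers
# its input dimension (1 * 28 * 28 = 784) on the first forward pass.
model = SoftmaxRegression(num_outputs=10, lr=0.1)
X = torch.randn(2, 1, 28, 28)  # dummy minibatch of two grayscale images
print(model(X).shape)          # expected: torch.Size([2, 10])
```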
Softmax Revisited
:label:subsec_softmax-implementation-revisited
In :numref:sec_softmax_scratch we calculated our model's output and applied the cross-entropy loss. While this is perfectly reasonable mathematically, it is risky computationally, because of numerical underflow and overflow in the exponentiation.
Recall that the softmax function computes probabilities via \(\hat y_j = \frac{\exp(o_j)}{\sum_k \exp(o_k)}\). If some of the \(o_k\) are very large, i.e., very positive, then \(\exp(o_k)\) might be larger than the largest number we can have for certain data types. This is called overflow. Likewise, if every argument is a very large negative number, we will get underflow. For instance, single precision floating point numbers approximately cover the range of \(10^{-38}\) to \(10^{38}\). As such, if the largest term in \(\mathbf{o}\) lies outside the interval \([-90, 90]\), the result will not be stable. A way round this problem is to subtract \(\bar{o} \stackrel{\textrm{def}}{=} \max_k o_k\) from all entries:

\[\hat y_j = \frac{\exp(o_j - \bar{o}) \exp(\bar{o})}{\sum_k \exp(o_k - \bar{o}) \exp(\bar{o})} = \frac{\exp(o_j - \bar{o})}{\sum_k \exp(o_k - \bar{o})}.\]
By construction we know that \(o_j - \bar{o} \leq 0\) for all \(j\). As such, for a \(q\)-class
classification problem, the denominator is contained in the interval \([1, q]\). Moreover, the
numerator never exceeds \(1\), thus preventing numerical overflow. Numerical underflow only
occurs when \(\exp(o_j - \bar{o})\) numerically evaluates as \(0\). Nonetheless, a few steps down
the road we might find ourselves in trouble when we want to compute \(\log \hat{y}_j\) as \(\log 0\).
In particular, in backpropagation, we might find ourselves faced with a screenful of the dreaded NaN (Not a Number) results.
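To make the failure mode concrete, the following is a small illustration (a hedged sketch using the PyTorch tab, not part of the original text) of how naive exponentiation overflows and how subtracting the maximum \(\bar{o}\) restores finite values:

```python
# exp overflows to inf in single precision once its argument exceeds ~88
o = torch.tensor([100.0, 10.0, 1.0])
naive = torch.exp(o) / torch.exp(o).sum()    # contains nan: exp(100) is inf
o_bar = o.max()
stable = torch.exp(o - o_bar) / torch.exp(o - o_bar).sum()  # finite probabilities
print(naive, stable)
```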
Fortunately, we are saved by the fact that even though we are computing exponential functions, we ultimately intend to take their log (when calculating the cross-entropy loss). By combining softmax and cross-entropy, we can escape the numerical stability issues altogether. We have:

\[\log \hat{y}_j = \log \frac{\exp(o_j - \bar{o})}{\sum_k \exp(o_k - \bar{o})} = o_j - \bar{o} - \log \sum_k \exp(o_k - \bar{o}).\]
This avoids both overflow and underflow. We will want to keep the conventional softmax function handy in case we ever want to evaluate the probabilities output by our model. But instead of passing softmax probabilities into our new loss function, we just pass the logits and compute the softmax and its log all at once inside the cross-entropy loss function, which does smart things like the "LogSumExp trick".
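As a rough check of this identity (a sketch assuming the PyTorch tab, where F.cross_entropy is the fused, logit-accepting loss), the fused loss matches an explicit log-sum-exp computation:

```python
logits = torch.tensor([[50.0, 1.0, -2.0]])  # one example, three classes
label = torch.tensor([1])                   # index of the true class
fused = F.cross_entropy(logits, label)      # softmax, log, and NLL in one stable call
# Manual form of the identity: loss = logsumexp(o) - o_y
manual = torch.logsumexp(logits[0], dim=0) - logits[0, 1]
print(fused, manual)                        # both roughly 49.0
```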
%%tab pytorch, mxnet, tensorflow
@d2l.add_to_class(d2l.Classifier)  #@save
def loss(self, Y_hat, Y, averaged=True):
    Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1]))
    Y = d2l.reshape(Y, (-1,))
    if tab.selected('mxnet'):
        fn = gluon.loss.SoftmaxCrossEntropyLoss()
        l = fn(Y_hat, Y)
        return l.mean() if averaged else l
    if tab.selected('pytorch'):
        return F.cross_entropy(
            Y_hat, Y, reduction='mean' if averaged else 'none')
    if tab.selected('tensorflow'):
        fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        return fn(Y, Y_hat)
%%tab jax
@d2l.add_to_class(d2l.Classifier)  #@save
@partial(jax.jit, static_argnums=(0, 5))
def loss(self, params, X, Y, state, averaged=True):
    # To be used later (e.g., for batch norm)
    Y_hat = state.apply_fn({'params': params}, *X,
                           mutable=False, rngs=None)
    Y_hat = d2l.reshape(Y_hat, (-1, Y_hat.shape[-1]))
    Y = d2l.reshape(Y, (-1,))
    fn = optax.softmax_cross_entropy_with_integer_labels
    # The returned empty dictionary is a placeholder for auxiliary data,
    # which will be used later (e.g., for batch norm)
    return (fn(Y_hat, Y).mean(), {}) if averaged else (fn(Y_hat, Y), {})
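A quick way to exercise the new loss method is to feed it dummy logits and labels. This is a hedged usage sketch for the PyTorch tab; the other tabs behave analogously, and the JAX version additionally takes the parameters and training state.

```python
# Assumes the PyTorch tab: loss(Y_hat, Y, averaged=True) as defined above
model = SoftmaxRegression(num_outputs=10, lr=0.1)
Y_hat = torch.randn(4, 10)      # dummy logits for four examples
Y = torch.tensor([0, 3, 9, 1])  # dummy integer labels
print(model.loss(Y_hat, Y))                        # averaged loss, a scalar
print(model.loss(Y_hat, Y, averaged=False).shape)  # per-example losses: (4,)
```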
Training
Next we train our model. We use Fashion-MNIST images, flattened to 784-dimensional feature vectors.
%%tab all
data = d2l.FashionMNIST(batch_size=256)
model = SoftmaxRegression(num_outputs=10, lr=0.1)
trainer = d2l.Trainer(max_epochs=10)
trainer.fit(model, data)
As before, this algorithm converges to a reasonably accurate solution, albeit this time with fewer lines of code.
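If you want to look beyond the training curves, a small hedged sketch (PyTorch tab; data.val_dataloader() is the validation iterator provided by the d2l data module) compares predictions with labels on one validation batch:

```python
# Assumes the PyTorch tab and the model/data objects trained above
X, y = next(iter(data.val_dataloader()))
preds = model(X).argmax(dim=1)       # most likely class for each image
print((preds == y).float().mean())   # fraction of correct predictions
```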
Summary
High-level APIs are very convenient at hiding from their user potentially dangerous aspects, such as numerical stability. Moreover, they allow users to design models concisely with very few lines of code. This is both a blessing and a curse. The obvious benefit is that it makes things highly accessible, even to engineers who never took a single class of statistics in their life (in fact, they are part of the target audience of the book). But hiding the sharp edges also comes with a price: a disincentive to add new and different components on your own, since there is little muscle memory for doing it. Moreover, it makes it more difficult to fix things whenever the protective padding of a framework fails to cover all the corner cases entirely. Again, this is due to lack of familiarity.
As such, we strongly urge you to review both the bare bones and the elegant versions of many of the implementations that follow. While we emphasize ease of understanding, the implementations are nonetheless usually quite performant (convolutions are the big exception here). It is our intention to allow you to build on these when you invent something new that no framework can give you.
Exercises
- Deep learning uses many different number formats, including FP64 double precision (used extremely rarely), FP32 single precision, BFLOAT16 (good for compressed representations), FP16 (very unstable), TF32 (a new format from NVIDIA), and INT8. Compute the smallest and largest argument of the exponential function for which the result does not lead to numerical underflow or overflow.
- INT8 is a very limited format consisting of nonzero numbers from \(1\) to \(255\). How could you extend its dynamic range without using more bits? Do standard multiplication and addition still work?
- Increase the number of epochs for training. Why might the validation accuracy decrease after a while? How could we fix this?
- What happens as you increase the learning rate? Compare the loss curves for several learning rates. Which one works better? When?
:begin_tab:mxnet
Discussions
:end_tab:
:begin_tab:pytorch
Discussions
:end_tab:
:begin_tab:tensorflow
Discussions
:end_tab:
:begin_tab:jax
Discussions
:end_tab: