Beam Search
:label:sec_beam-search
In :numref:sec_seq2seq,
we introduced the encoder--decoder architecture,
and the standard techniques for training them end-to-end. However, when it came to test-time prediction,
we mentioned only the greedy strategy,
where we select at each time step
the token with the highest
predicted probability of coming next,
until, at some time step,
we find that we have predicted
the special end-of-sequence "<eos>" token.
In this section, we will begin
by formalizing this greedy search strategy
and identifying some problems
that practitioners tend to run into.
Subsequently, we compare this strategy
with two alternatives:
exhaustive search (illustrative but not practical)
and beam search (the standard method in practice).
Let's begin by setting up our mathematical notation,
borrowing conventions from :numref:sec_seq2seq.
At any time step \(t'\), the decoder outputs
predictions representing the probability
of each token in the vocabulary
coming next in the sequence
(the likely value of \(y_{t'+1}\)),
conditioned on the previous tokens
\(y_1, \ldots, y_{t'}\) and
the context variable \(\mathbf{c}\),
produced by the encoder
to represent the input sequence.
To quantify computational cost,
denote by \(\mathcal{Y}\)
the output vocabulary
(including the special end-of-sequence token "<eos>").
Let's also specify the maximum number of tokens
of an output sequence as \(T'\).
Our goal is to search for an ideal output from all
\(\mathcal{O}(\left|\mathcal{Y}\right|^{T'})\)
possible output sequences.
Note that this slightly overestimates
the number of distinct outputs
because there are no subsequent tokens
once the "<eos>" token occurs.
However, for our purposes,
this number roughly captures
the size of the search space.
Greedy Search
Consider the simple greedy search strategy from :numref:sec_seq2seq.
Here, at any time step \(t'\),
we simply select the token
with the highest conditional probability
from \(\mathcal{Y}\), i.e.,

$$y_{t'} = \operatorname*{argmax}_{y \in \mathcal{Y}} P(y \mid y_1, \ldots, y_{t'-1}, \mathbf{c}).$$
Once our model outputs "<eos>" (or we reach the maximum length \(T'\)), the output sequence is completed.
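As a minimal sketch, greedy decoding can be written in a few lines of Python. Here `next_token_probs` is a hypothetical function (not defined in this section) that maps a prefix of tokens to a dictionary of conditional probabilities over \(\mathcal{Y}\):

```python
def greedy_search(next_token_probs, max_len):
    """Pick the most probable next token at every time step."""
    prefix = []
    for _ in range(max_len):
        probs = next_token_probs(prefix)   # dict: token -> P(token | prefix, c)
        token = max(probs, key=probs.get)  # highest conditional probability
        prefix.append(token)
        if token == '<eos>':               # stop at the end-of-sequence token
            break
    return prefix
```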
This strategy might look reasonable, and in fact it is not so bad! Considering how computationally undemanding it is, you'd be hard pressed to get more bang for your buck. However, if we put aside efficiency for a minute, it might seem more reasonable to search for the most likely sequence, not the sequence of (greedily selected) most likely tokens. It turns out that these two objects can be quite different. The most likely sequence is the one that maximizes the expression \(\prod_{t'=1}^{T'} P(y_{t'} \mid y_1, \ldots, y_{t'-1}, \mathbf{c})\). In our machine translation example, if the decoder truly recovered the probabilities of the underlying generative process, then this would give us the most likely translation. Unfortunately, there is no guarantee that greedy search will give us this sequence.
Let's illustrate it with an example.
Suppose that there are four tokens
"A", "B", "C", and "<eos>" in the output dictionary.
In :numref:fig_s2s-prob1,
the four numbers under each time step represent
the conditional probabilities of generating "A", "B", "C",
and "<eos>" respectively, at that time step.
:label:fig_s2s-prob1
At each time step, greedy search selects
the token with the highest conditional probability.
Therefore, the output sequence "A", "B", "C", and "<eos>"
will be predicted (:numref:fig_s2s-prob1).
The conditional probability of this output sequence
is \(0.5\times0.4\times0.4\times0.6 = 0.048\).
Next, let's look at another example in :numref:fig_s2s-prob2.
Unlike in :numref:fig_s2s-prob1,
at time step 2 we select the token "C",
which has the second highest conditional probability.
:label:fig_s2s-prob2
Since the output subsequences at time steps 1 and 2,
on which time step 3 is based,
have changed from "A" and "B" in :numref:fig_s2s-prob1
to "A" and "C" in :numref:fig_s2s-prob2
,
the conditional probability of each token
at time step 3 has also changed in :numref:fig_s2s-prob2.
Suppose that we choose the token "B" at time step 3.
Now time step 4 is conditional on
the output subsequence at the first three time steps
"A", "C", and "B",
which has changed from "A", "B", and "C" in :numref:fig_s2s-prob1.
Therefore, the conditional probability of generating
each token at time step 4 in :numref:fig_s2s-prob2
is also different from that in :numref:fig_s2s-prob1.
As a result, the conditional probability of the output sequence
"A", "C", "B", and "<eos>" in :numref:fig_s2s-prob2
is \(0.5\times0.3 \times0.6\times0.6=0.054\),
which is greater than that of greedy search in :numref:fig_s2s-prob1.
In this example, the output sequence "A", "B", "C", and "<eos>"
obtained by the greedy search is not optimal.
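We can verify the arithmetic of this comparison by multiplying the conditional probabilities read off the two figures:

```python
from math import prod

# Conditional probabilities read off fig_s2s-prob1 and fig_s2s-prob2
greedy_probs = [0.5, 0.4, 0.4, 0.6]   # "A", "B", "C", "<eos>"
alt_probs    = [0.5, 0.3, 0.6, 0.6]   # "A", "C", "B", "<eos>"

print(prod(greedy_probs))   # ~0.048
print(prod(alt_probs))      # ~0.054: the non-greedy sequence is more likely
```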
Exhaustive Search
If the goal is to obtain the most likely sequence, we may consider using exhaustive search: enumerate all the possible output sequences with their conditional probabilities, and then output the one that scores the highest predicted probability.
While this would certainly give us what we desire, it would come at a prohibitive computational cost of \(\mathcal{O}(\left|\mathcal{Y}\right|^{T'})\), exponential in the sequence length and with an enormous base given by the vocabulary size. For example, when \(|\mathcal{Y}|=10000\) and \(T'=10\), both small numbers compared with those in real applications, we would need to evaluate \(10000^{10} = 10^{40}\) sequences, which is already beyond the capabilities of any foreseeable computer. On the other hand, the computational cost of greedy search is \(\mathcal{O}(\left|\mathcal{Y}\right|T')\): miraculously cheap but far from optimal. For example, when \(|\mathcal{Y}|=10000\) and \(T'=10\), we only need to evaluate \(10000\times10=10^5\) sequences.
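For illustration only, exhaustive search can be sketched by enumerating every length-\(T'\) sequence with `itertools.product`. The `seq_prob` function below is a hypothetical scorer returning \(P(y_1, \ldots, y_{T'} \mid \mathbf{c})\); this is feasible only for toy vocabularies and lengths:

```python
from itertools import product

def exhaustive_search(seq_prob, vocab, max_len):
    """Enumerate all |vocab|**max_len sequences and return the most probable one."""
    best_seq, best_prob = None, -1.0
    for seq in product(vocab, repeat=max_len):
        p = seq_prob(list(seq))   # assumed to return P(y_1, ..., y_{T'} | c)
        if p > best_prob:
            best_seq, best_prob = list(seq), p
    return best_seq
```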
Beam Search
You could view sequence decoding strategies as lying on a spectrum, with beam search striking a compromise between the efficiency of greedy search and the optimality of exhaustive search. The most straightforward version of beam search is characterized by a single hyperparameter, the beam size, \(k\). Let's explain this terminology. At time step 1, we select the \(k\) tokens with the highest predicted probabilities. Each of them will be the first token of \(k\) candidate output sequences, respectively. At each subsequent time step, based on the \(k\) candidate output sequences at the previous time step, we continue to select \(k\) candidate output sequences with the highest predicted probabilities from \(k\left|\mathcal{Y}\right|\) possible choices.
:label:fig_beam-search
:numref:fig_beam-search
demonstrates the
process of beam search with an example.
Suppose that the output vocabulary
contains only five elements:
\(\mathcal{Y} = \{A, B, C, D, E\}\),
where one of them is “<eos>”.
Let the beam size be two and
the maximum length of an output sequence be three.
At time step 1,
suppose that the tokens with the highest conditional probabilities
\(P(y_1 \mid \mathbf{c})\) are \(A\) and \(C\).
At time step 2, for all \(y_2 \in \mathcal{Y}\),
we compute

$$\begin{aligned}P(A, y_2 \mid \mathbf{c}) &= P(A \mid \mathbf{c})P(y_2 \mid A, \mathbf{c}),\\ P(C, y_2 \mid \mathbf{c}) &= P(C \mid \mathbf{c})P(y_2 \mid C, \mathbf{c}),\end{aligned}$$
and pick the largest two among these ten values, say \(P(A, B \mid \mathbf{c})\) and \(P(C, E \mid \mathbf{c})\). Then at time step 3, for all \(y_3 \in \mathcal{Y}\), we compute

$$\begin{aligned}P(A, B, y_3 \mid \mathbf{c}) &= P(A, B \mid \mathbf{c})P(y_3 \mid A, B, \mathbf{c}),\\ P(C, E, y_3 \mid \mathbf{c}) &= P(C, E \mid \mathbf{c})P(y_3 \mid C, E, \mathbf{c}),\end{aligned}$$

and pick the largest two among these ten values, say \(P(A, B, D \mid \mathbf{c})\) and \(P(C, E, D \mid \mathbf{c})\). As a result, we get six candidate output sequences: (i) \(A\); (ii) \(C\); (iii) \(A\), \(B\); (iv) \(C\), \(E\); (v) \(A\), \(B\), \(D\); and (vi) \(C\), \(E\), \(D\).
In the end, we obtain the set of final candidate output sequences based on these six sequences (e.g., by discarding the portion of each sequence from "<eos>" onwards). Then we choose the output sequence that maximizes the following score:
$$ \frac{1}{L^\alpha} \log P(y_1, \ldots, y_{L}\mid \mathbf{c}) = \frac{1}{L^\alpha} \sum_{t'=1}^L \log P(y_{t'} \mid y_1, \ldots, y_{t'-1}, \mathbf{c});$$
:eqlabel:eq_beam-search-score
here \(L\) is the length of the final candidate sequence
and \(\alpha\) is usually set to 0.75.
Since a longer sequence has more logarithmic terms
in the summation of :eqref:eq_beam-search-score,
the term \(L^\alpha\) in the denominator penalizes
long sequences.
The computational cost of beam search is \(\mathcal{O}(k\left|\mathcal{Y}\right|T')\). This result is in between that of greedy search and that of exhaustive search. Greedy search can be treated as a special case of beam search arising when the beam size is set to 1.
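Putting the pieces together, the following is a minimal sketch of beam search, again assuming the hypothetical `next_token_probs` interface from the greedy search sketch above and ranking the final candidates with the length-normalized score in :eqref:eq_beam-search-score:

```python
import math

def beam_search(next_token_probs, beam_size, max_len, alpha=0.75):
    """Keep the `beam_size` most probable prefixes at every time step and
    rank the completed candidates by the length-normalized score."""
    beams = [([], 0.0)]            # (prefix, cumulative log-probability)
    finished = []
    for _ in range(max_len):
        candidates = []
        for prefix, log_p in beams:
            for token, p in next_token_probs(prefix).items():
                if p > 0:          # skip zero-probability extensions
                    candidates.append((prefix + [token], log_p + math.log(p)))
        # Keep the beam_size best among the beam_size * |Y| extensions.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for prefix, log_p in candidates[:beam_size]:
            if prefix[-1] == '<eos>':
                finished.append((prefix, log_p))
            else:
                beams.append((prefix, log_p))
        if not beams:
            break
    finished.extend(beams)         # unfinished prefixes also compete
    # Score each candidate by (1 / L^alpha) * log P(y_1, ..., y_L | c).
    return max(finished, key=lambda c: c[1] / len(c[0]) ** alpha)[0]
```

Setting `beam_size=1` in this sketch recovers greedy search, in line with the observation above.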
Summary
Sequence searching strategies include greedy search, exhaustive search, and beam search. Beam search provides a trade-off between accuracy and computational cost via the flexible choice of the beam size.
Exercises
- Can we treat exhaustive search as a special type of beam search? Why or why not?
- Apply beam search in the machine translation problem in :numref:sec_seq2seq. How does the beam size affect the translation results and the prediction speed?
- We used language modeling for generating text following user-provided prefixes in :numref:sec_rnn-scratch. Which kind of search strategy does it use? Can you improve it?