Baidu engineers analyze decoding strategies

Author | Jane

Introduction

There are two main types of decoding methods for generative models: deterministic methods (such as greedy search and beam search) and stochastic methods. Text generated by deterministic methods is often unnatural and may contain repetitive or overly simple expressions, whereas stochastic methods introduce randomness into the decoding process in order to generate more diverse and natural text. Two common stochastic methods are:

1. Top-k sampling: At each decoding step, the model keeps the k words with the highest probability and then randomly selects one of them as the next generated word. This adds variety to the text while still maintaining some control.

2. Nucleus sampling (Top-p sampling): In this method, the model selects the next word from the smallest set of words whose cumulative probability exceeds a threshold p, where cumulative probability is the running sum of word probabilities sorted from high to low. This reduces repetitiveness and adaptively keeps fewer or more candidate words than a fixed value of k would.

Although nucleus sampling can alleviate repetition in the generated text, the semantic consistency of the text is not always good; this problem of semantic inconsistency can be partially resolved by lowering the temperature. Temperature is a parameter that controls randomness: a higher temperature leads to a more even distribution, making the generated text more diverse, while a lower temperature makes the distribution more concentrated and closer to deterministic. This introduces a trade-off: a higher temperature may produce semantically inconsistent text, while a lower temperature loses some diversity.

In practical applications, the decoding method, randomness parameters and temperature should be selected according to the task and the desired characteristics of the output text. Different combinations of methods and parameters suit different situations and balance the diversity, accuracy and consistency of the generated text.

The full text is 3646 words, and the estimated reading time is 10 minutes.
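The code snippets in this article assume that a causal language model and its tokenizer have already been loaded with Hugging Face Transformers and that a prompt has been encoded into input_ids. A minimal setup sketch follows; the gpt2 checkpoint and the prompt text are placeholders chosen for illustration, not specified by the original article.

# minimal setup assumed by the snippets below (checkpoint and prompt are illustrative)
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# encode the prefix text into the input_ids tensor used by model.generate
input_ids = tokenizer("I enjoy walking with my cute dog", return_tensors="pt").input_ids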

01 Contrastive search (contrastive_search)

Given the prefix text \(x_{<t}\), contrastive search selects the output token \(x_{t}\) according to the following formula:

\[
x_{t} = \underset{v \in V^{(k)}}{\operatorname{argmax}} \Big\{ (1-\alpha)\, p_{\theta}(v \mid x_{<t}) \;-\; \alpha \max_{1 \le j \le t-1} s\big(h_{v}, h_{x_{j}}\big) \Big\}
\]

where \(s(h_{v}, h_{x_{j}})\) is the cosine similarity between the token representations.

Here \(p_{\theta}(v|x_{<t})\) is the probability distribution over the next token predicted by the language model, and \(V^{(k)}\) is the set of the k most probable candidate tokens in that distribution.

  • The first term, model confidence, is the probability of candidate token v as predicted by the language model.

  • The second term, the degeneration penalty, measures how similar candidate token v is to the preceding context \(x_{<t}\): the cosine similarity between the representation \(h_{v}\) of v and the representation of each token in \(x_{<t}\) is computed, and the maximum similarity is taken as the penalty. Intuitively, a larger degeneration penalty means that v is more similar to the preceding context (in representation space) and is therefore more likely to cause model degeneration. The hyperparameter \(\alpha\) trades off the two terms; when \(\alpha=0\), contrastive search reduces to pure greedy search.

In summary, contrastive search considers two factors when generating output:

  • The probability predicted by the language model, to maintain semantic coherence between the generated text and the prefix text.

  • The similarity to the preceding context, to avoid model degeneration.

# generate the result with contrastive search
output = model.generate(
    input_ids, 
    penalty_alpha=0.6,  # hyperparameter alpha in contrastive search
    top_k=4,  # hyperparameter k in contrastive search
    max_length=512
)
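For intuition, the per-step score that contrastive search maximizes can be written out directly. The sketch below is illustrative only and is not the internal Transformers implementation; the tensors probs, cand_ids, cand_hidden and prefix_hidden are assumed inputs named here for the example.

# illustrative scoring of the candidates v in V^(k):
# score(v) = (1 - alpha) * p_theta(v | x_<t) - alpha * max_j cos(h_v, h_{x_j})
import torch
import torch.nn.functional as F

def contrastive_score(probs, cand_ids, cand_hidden, prefix_hidden, alpha=0.6):
    # probs: (vocab_size,) next-token distribution p_theta(v | x_<t)
    # cand_ids: (k,) ids of the top-k candidate tokens V^(k)
    # cand_hidden: (k, d) representation h_v of each candidate
    # prefix_hidden: (t-1, d) representations of the prefix tokens x_<t
    confidence = probs[cand_ids]                                # model confidence
    sim = F.cosine_similarity(                                  # cos(h_v, h_{x_j})
        cand_hidden.unsqueeze(1), prefix_hidden.unsqueeze(0), dim=-1
    )                                                           # shape (k, t-1)
    degeneration_penalty = sim.max(dim=-1).values               # max similarity to the prefix
    return (1 - alpha) * confidence - alpha * degeneration_penalty

The candidate with the highest score is selected as \(x_{t}\); with alpha=0 the score reduces to the plain model confidence, i.e. greedy search.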

02 Greedy search (greedy_search)

Greedy search simply selects the word with the highest probability as the current output word at each time step: \(w_t = argmax_{w}P(w | w_{1:t-1})\)


△ greedy search

Problems:

  • Prone to repetitive output, a very common problem in language generation that appears to be even more pronounced in greedy search and beam search.

  • The main disadvantage is that it misses high-probability words hidden behind a low-probability word: greedy search picks "The" -> "nice" -> "woman" (0.5*0.4=0.20) and misses "The" -> "dog" -> "has" (0.4*0.9=0.36). Beam search can alleviate this problem.
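In Transformers, greedy decoding is what generate does when sampling and beam search are disabled, so it serves as a baseline for the methods below. A minimal sketch (parameter values are illustrative):

# greedy decoding: no sampling, num_beams=1 (the defaults)
greedy_output = model.generate(
    input_ids, 
    max_length=50
)
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))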

03 beam search (beam_search)

The whole process of beam search can be summarized as: expand candidates, rank them, prune, and repeat. Beam search reduces the risk of missing potentially high-probability sequences by keeping the num_beams most likely hypotheses (partial sequences) at each time step, from which the sequence with the highest overall probability is eventually selected.

The example below shows num_beams=2:


△ beam search num_beams=2

Beam search will generally find an output sequence with higher probability than greedy search, but it is still not guaranteed to find the global optimal solution.
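A minimal sketch of plain beam search with Transformers, before any repetition penalty is applied (the parameter values are illustrative):

# keep the 5 most likely sequences at each step and return the best one
beam_output = model.generate(
    input_ids, 
    max_length=50, 
    num_beams=5, 
    early_stopping=True
)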

While the results are smoother than greedy search, the output still contains repetitions. A simple remedy is to introduce an n-gram penalty (an n-gram is a sequence of n consecutive words): the most common n-gram penalty ensures that no n-gram appears more than once, by setting the probability of a candidate word to 0 if it would complete an n-gram that has already appeared. Try it out by setting no_repeat_ngram_size=2 so that no 2-gram appears twice:

beam_output = model.generate(
    input_ids, 
    max_length=50, 
    num_beams=5, 
    no_repeat_ngram_size=2, # forbid any 2-gram from appearing twice
    early_stopping=True
)

However, the n-gram penalty must be used with care: an article about the city New York should not use a 2-gram penalty, otherwise the city name would appear only once in the entire text!

Beam search has been shown to still suffer from repetitive generation. In scenarios such as story generation it is hard to control this with n-gram or other penalties, because finding a good compromise between "no repetition" and the maximum allowed repetition of n-grams requires a lot of fine-tuning. As Ari Holtzman et al. (2019) (https://arxiv.org/abs/1904.09751) demonstrated, high-quality human language does not follow the law of maximum probability; human language is creative and surprising, not merely predictable.

Therefore, introducing random and creative elements is the key to generating more interesting and diverse text.

04 Sampling

4.1 Sampling

Text generation itself is no longer deterministic when using sampling methods (do_sample=True).

# activate sampling and deactivate top_k by setting top_k sampling to 0
sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_k=0
)

A problem with sampling word sequences is that the model often produces incoherent gibberish. One trick to alleviate this is to make the distribution \(P(w|w_{1:t-1})\) sharper: lowering the "temperature" essentially increases the likelihood of high-probability words and decreases the likelihood of low-probability words.

sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_k=0, 
    temperature=0.7  # temperature < 1 sharpens the distribution
)

Although lowering the temperature reduces the randomness of the distribution, in the extreme case where the temperature is set to 0, temperature-scaled sampling degenerates into greedy decoding and therefore runs into the same problems as greedy decoding.
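To see what the temperature actually does to the distribution, here is a small standalone sketch in plain PyTorch (the logits are made-up values for illustration and are independent of the generate call):

# temperature scaling: softmax(logits / T); T < 1 sharpens, T > 1 flattens
import torch

logits = torch.tensor([3.0, 1.0, 0.2])        # illustrative next-token logits
for T in (1.0, 0.7, 0.1):
    probs = torch.softmax(logits / T, dim=-1)
    print(T, [round(p, 3) for p in probs.tolist()])
# as T approaches 0, almost all probability mass collapses onto the argmax, i.e. greedy decoding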

4.2 Top-k Sampling

In Top-K sampling, the K words with the highest probability are selected, their probabilities are renormalized, and the next word is then sampled from this renormalized distribution. GPT-2 adopted this sampling scheme, which is one of the reasons for its success on tasks such as story generation.
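In Transformers, Top-K sampling corresponds to setting top_k in generate; the value 50 below is only an illustrative choice:

# sample only from the 50 most likely words at each step
sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_k=50
)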

4.3 Top-p (Nucleus) Sampling

△ Top-p sampling: next-token distributions at t=1 (flat, 9 words in the sampling pool) and t=2 (sharp, 3 words in the sampling pool)

Assuming p=0.92, Top-p sampling sorts the word probabilities in descending order and accumulates them, then takes the smallest set of words whose cumulative probability first exceeds p=92% as the sampling pool, denoted \(V_{\text{top-p}}\). At t=1, \(V_{\text{top-p}}\) contains 9 words, while at t=2 the first 3 words already exceed 92%.

When the next word is hard to predict (the flatter distribution on the left, e.g. \(P(w \mid \text{"The"})\)), more candidates are retained, whereas when the next word is easier to predict (the sharper distribution on the right, e.g. \(P(w \mid \text{"The"}, \text{"car"})\)), only a few candidate words are kept.

# deactivate top_k sampling and sample only from 92% most likely words
sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_p=0.92, 
    top_k=0
)
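For intuition, a small standalone sketch of how the sampling pool \(V_{\text{top-p}}\) is built from a probability vector (plain PyTorch with made-up probabilities; this is not the internal Transformers implementation):

# build the top-p pool: the smallest prefix of the sorted probabilities whose sum reaches p
import torch

probs = torch.tensor([0.30, 0.25, 0.20, 0.10, 0.08, 0.07])  # illustrative distribution
p = 0.92

sorted_probs, sorted_idx = torch.sort(probs, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=0)
cutoff = int((cumulative < p).sum()) + 1          # first position where the sum reaches p
pool_idx = sorted_idx[:cutoff]                    # V_top-p
pool_probs = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()  # renormalize
next_token = pool_idx[torch.multinomial(pool_probs, num_samples=1)]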

Although Top-p seems theoretically more elegant than Top-K, both approaches work well in practice. Top-p can also be combined with Top-K, which avoids very low-ranked words while still allowing some dynamic selection of the pool size. If both k and p are enabled, the p filter is applied after the k filter.

# set top_k = 50, top_p = 0.95, and num_return_sequences = 3
sample_outputs = model.generate(
    input_ids,
    do_sample=True, 
    max_length=50, 
    top_k=50, 
    top_p=0.95, 
    num_return_sequences=3
)

——END——

References:

[1] A simple and effective decoding strategy: Contrastive Search

[2] HF: How to generate text: generating text with different decoding methods using Transformers

[3]https://docs.cohere.ai/docs/controlling-generation-with-top-k-top-p

[4]https://docs.cohere.ai/docs/temperature

