Basic large-model text generation with the Transformers library, and KV cache caveats

Given an input prompt, generate a passage of a specified length. Llama is too slow to run here, so we use GPT-2 as the example.

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)

prompt_text = "This is a nice story that makes me"
max_gen_len = 9
input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
prompt_len = input_ids.shape[-1]
print(f'length of prompt: {prompt_len}, length of generation: {max_gen_len}')

Reprinted from blog.csdn.net/qq_16763983/article/details/138828388