ChatGLM2-6B源码解析 web_demo.py

from transformers import AutoModel, AutoTokenizer
import gradio as gr
import mdtex2html

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
model = model.eval()

"""Override Chatbot.postprocess"""


def postprocess(self, y):
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert((message)),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess


def parse_text(text):
    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f'<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>"+line
    text = "".join(lines)
    return text


def predict(input, chatbot, max_length, top_p, temperature, history, past_key_values):
    chatbot.append((parse_text(input), ""))
    for response, history, past_key_values in model.stream_chat(tokenizer, input, history, past_key_values=past_key_values,
                                                                return_past_key_values=True,
                                                                max_length=max_length, top_p=top_p,
                                                                temperature=temperature):
        chatbot[-1] = (parse_text(input), parse_text(response))

        yield chatbot, history, past_key_values


def reset_user_input():
    return gr.update(value='')


def reset_state():
    return [], [], None


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">ChatGLM2-6B</h1>""")

    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

    history = gr.State([])
    past_key_values = gr.State(None)

    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
                    [chatbot, history, past_key_values], show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)

demo.queue().launch(share=False, inbrowser=True)

这是一个使用Hugging Face的transformers库和gradio库搭建的机器人聊天程序的Python代码,模型预训练的版本为"THUDM/chatglm2-6b",而且使用了mdtex2html来把markdown格式转化为HTML格式。我会逐行解释它的含义:

1-3: 导入必要的模块。transformers 是Hugging Face发布的处理NLP任务的库,提供了大量的预训练模型和相关工具。gradio 是一个创建快速demo用户界面的库,可以让模型快速地在web页面上进行展示和交互。mdtex2html用来将markdown格式的文本转化为HTML格式。

5-6: 从 Hugging Face 模型库中加载了一个预训练的模型及其相应的词条化器(tokenizer),模型的名字是 "THUDM/chatglm2-6b"。trust_remote_code=True 是指示允许从远程代码中加载特殊的tokenizer或model代码。

7: 把模型放在CUDA上,这会让模型运行在GPU上,加快运算速度。

8: 把模型设定为评估模式(eval mode),在这种模式下,模型中的某些层(如 Dropout 或 BatchNorm)的行为会与训练模式有所不同。

10-19: 定义了一个postprocess函数,它是为了对模型的输出进行后处理。在这个函数中,它会将每一条聊天信息相应的回复转化为HTML格式。它将此函数绑定到gradio的Chatbot类的后处理函数上。

21-42: parse_text函数用于处理用户的输入。它主要是处理markdown的代码部分,并把它们转化为HTML格式。

44-52: predict函数接受用户的输入和一些模型参数,它首先把用户的输入添加到聊天记录中,然后用模型生成回复,最后返回聊天记录和新的历史记录。

54-56: 定义了两个函数来重置用户的输入和模型的状态。

58-89: 在这部分中,它用gradio创建了一个用户界面。这个界面包括一个聊天机器人,一个用户输入框,一个提交按钮,一个清空历史按钮,以及一些滑块用于控制模型的参数。

91: 启动这个用户界面。share=False表示这个界面不会被分享,inbrowser=True表示这个界面会在浏览器中打开

总的来说,这是一个使用预训练的模型来生成对话的聊天机器人的程序,它有一个简单的用户界面,用户可以通过这个界面和机器人进行交流。

我会继续分析此代码中的一些关键部分:

AutoTokenizer.from_pretrainedAutoModel.from_pretrained 方法从预训练的模型库中加载了一个模型及其对应的词条化器。它们能够自动地识别模型的类型,并加载相应的模型和词条化器。

model.cuda() 将模型加载到GPU上进行计算。这个操作是在PyTorch中进行的,其目的是利用GPU进行更快的计算。这只有在你的机器上有可用的Nvidia GPU,并且你的PyTorch版本支持CUDA时才会有效。

gr.Chatbot.postprocess = postprocess 这行代码将定义的postprocess函数绑定到gr.Chatbotpostprocess方法。这使得我们可以修改gr.Chatbot的行为,以便于使用mdtex2html.convert函数处理消息和回应,从而将markdown格式的文本转化为HTML格式。

函数 parse_text 是用于处理输入的文本,主要是处理markdown的代码部分,并把它们转化为HTML格式。

函数 predict 是这个程序的核心部分。它使用了模型进行预测,生成聊天机器人的回应。

最后,这个程序使用gradio库来创建一个用户界面。这个界面包括一个聊天机器人,一个用户输入框,一个提交按钮,一个清空历史按钮,以及一些滑块用于控制模型的参数。用户可以通过这个界面和聊天机器人进行交流。

整体来看,这个程序是一个基于预训练模型和gradio库的聊天机器人。它能够处理用户的输入,生成聊天机器人的回应,并通过用户界面和用户进行交互。

reset_user_inputreset_state函数被设计用来重置用户输入和清空模型的状态。它们主要被绑定到了前端的按钮上,当用户点击这些按钮时,会触发相应的函数。

接下来,这个脚本使用gr.Blocks()来创建一个应用的界面。在这个界面中,gr.Chatbot()实例化了一个聊天机器人,gr.Textbox()实例化了一个输入框,用户可以在里面输入文本,然后点击gr.Button()实例化的提交按钮,这会触发绑定到按钮上的函数,把用户的输入提交到聊天机器人,并接收聊天机器人的回应。另外,gr.Slider()实例化了一些滑块,用于调整模型的参数。

gr.State([])gr.State(None)实例化了两个状态对象,这些对象用于保存聊天的历史记录和模型的内部状态。这些状态在每次预测时都会更新。

然后,submitBtn.click()emptyBtn.click()将前面定义的函数绑定到了按钮的点击事件上。当用户点击这些按钮时,会触发绑定到按钮上的函数。

最后,demo.queue().launch(share=False, inbrowser=True)启动了这个应用。用户可以在浏览器中访问这个应用,并与聊天机器人进行交互。

这个脚本实现了一个完整的基于预训练模型的聊天机器人应用,包括用户界面和模型的预测逻辑。它展示了如何使用Hugging Face的transformers库和gradio库来实现一个可交互的机器学习应用。

猜你喜欢

转载自blog.csdn.net/sinat_37574187/article/details/131404024