Lawyer LLaMA(中文法律大模型本地部署)
1.模型选择(lawyer-llama-13b-v2)
2.运行环境
1.建议使用Python 3.8及以上版本。
2.主要依赖库如下:
transformers >= 4.28.0(注意:检索模块需要使用 transformers <= 4.30)
sentencepiece >= 0.1.97
gradio
3.使用步骤
1.从HuggingFace下载 **Lawyer LLaMA 2 (lawyer-llama-13b-v2)** 模型参数。(需要安装 torch)
# Build a text-generation pipeline for the Lawyer LLaMA 2 checkpoint using the
# high-level `pipeline` helper from transformers.
from transformers import pipeline

pipe = pipeline(
    task="text-generation",
    model="pkupie/lawyer-llama-13b-v2",
)
2.从HuggingFace下载法条检索模块,并运行其中的 python server.py 启动法条检索服务,默认挂在9098端口。(注意:拉取的代码有可能缺少 labels2id.pkl、pytorch_model.bin 等文件)
1.git lfs install
2.git clone https://huggingface.co/pkupie/marriage_law_retrieval
3.(可选,用于替代第2步)如果只想拉取代码、暂不下载LFS大文件,可改用:GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/pkupie/marriage_law_retrieval (此时大文件只是指针文件,之后需要再单独下载)
4.server.py代码如下,其中的模型路径需要根据本地环境手动修改
import json
import subprocess
import os
import codecs
import logging
import os
import math
import json
import random
from tqdm import tqdm
from transformers import pipeline
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from flask import Flask, request, jsonify
import json
import random
from tqdm import tqdm
import os
import pickle as pkl
from argparse import Namespace
from models import Elect
import torch
from transformers import AutoModel, AutoTokenizer
from sklearn.preprocessing import MultiLabelBinarizer
# Module-level logger for this server script.
logger = logging.getLogger(__name__)
# Flask application that exposes the /check_hunyin endpoint.
app = Flask(__name__)
# Marriage-topic classifier; stays None until the __main__ block builds it.
hunyin_classifier = None
# Shared namespace holding settings/objects for the law-article retrieval
# model (fields are filled in by the __main__ block).
fatiao_args = Namespace()
# Tokenizer and model for law-article retrieval; assigned in the __main__ block.
fatiao_tokenizer = None
fatiao_model = None
@app.route('/check_hunyin', methods=['GET', 'POST'])
def check_hunyin():
    """Return the three highest-scoring law articles for the posted text.

    Expected JSON body:
        input (str): the consultation text.
        force_return (optional): when truthy, skip the marriage-topic
            classifier and always run the retrieval model.

    Returns:
        A JSON response ``{"output": [...]}``. Each entry has ``id``,
        ``score`` and ``text``. The list is empty when the input is missing
        or blank, or when the text is judged unrelated to marriage.
    """
    # silent=True yields None instead of raising when the body is absent or
    # not JSON (e.g. a plain GET); such requests get an empty result.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    input_text = (payload.get('input') or '').strip()
    force_return = payload.get('force_return', False)
    print("input_text:", input_text)

    if not input_text:
        return jsonify({"output": []})

    if not force_return:
        # Only the first 500 characters are passed to the classifier.
        classifier_result = hunyin_classifier(input_text[:500])
        print(classifier_result)
        classifier_result = classifier_result[0]['label']

        # Rule: text containing the character "婚" is always treated as
        # marriage-related.
        if '婚' in input_text:
            classifier_result = True

        # Not marriage-related: return an empty result.
        # NOTE(review): this comparison only filters when the pipeline label
        # is a falsy bool/0, which depends on the model config's id2label
        # mapping -- confirm; string labels would never match.
        if classifier_result == False:  # noqa: E712 - keep original semantics
            return jsonify({"output": []})

    inputs = fatiao_tokenizer(
        input_text,
        padding='max_length',
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    batch = {
        'ids': inputs['input_ids'],
        'mask': inputs['attention_mask'],
        'token_type_ids': inputs["token_type_ids"],
    }

    # Inference only: avoid building an autograd graph on every request.
    with torch.no_grad():
        model_output = fatiao_model(batch)
        pred = torch.sigmoid(model_output).cpu().detach().numpy()[0]

    # Rank every law by score, then keep only the best three.
    ranked = sorted(enumerate(pred), key=lambda item: item[1], reverse=True)
    pred_laws = [
        {
            'id': law_id,
            'score': float(score),
            'text': fatiao_args.mlb.classes_[law_id],
        }
        for law_id, score in ranked[:3]
    ]

    json_result = {"output": pred_laws}
    print("json_result:", json_result)
    return jsonify(json_result)
if __name__ == "__main__":
    # ---- Marriage-topic classifier: decides whether a query is about marriage ----
    # NOTE: machine-specific path; change it to match the local checkout.
    hunyin_classifier_path = "C:/Users/win10/PycharmProjects/lawyer-llama_/marriage_law_retrieval/pretrained_models/roberta_wwm_ext_hunyin_2epoch/"

    # Fail early with a clear message instead of an error from os.listdir.
    if not os.path.isdir(hunyin_classifier_path):
        raise FileNotFoundError(f"Classifier directory not found at {hunyin_classifier_path}")

    # Diagnostics: list the directory and report whether the weights file is
    # present (a clone without LFS files may lack pytorch_model.bin).
    model_file = os.path.join(hunyin_classifier_path, "pytorch_model.bin")
    print("Files in directory:")
    for filename in os.listdir(hunyin_classifier_path):
        print(filename)
    if not os.path.exists(model_file):
        print(f"Model file not found at {model_file}")
    else:
        print(f"Model file found at {model_file}")

    hunyin_config = AutoConfig.from_pretrained(
        hunyin_classifier_path,
        num_labels=2,
    )
    hunyin_tokenizer = AutoTokenizer.from_pretrained(
        hunyin_classifier_path
    )
    hunyin_model = AutoModelForSequenceClassification.from_pretrained(
        hunyin_classifier_path,
        config=hunyin_config,
    )
    hunyin_classifier = pipeline(model=hunyin_model, tokenizer=hunyin_tokenizer, task="text-classification", device=0)
    print("Model loaded successfully")

    # ---- Law-article retrieval model ----
    # NOTE: machine-specific path; change it to match the local checkout.
    fatiao_args.ckpt_dir = r"C:\Users\win10\PycharmProjects\lawyer-llama_\marriage_law_retrieval\pretrained_models\chinese-roberta-wwm-ext"
    fatiao_args.device = "cuda:0"

    # The label file is resolved relative to the current working directory.
    labels2id_path = os.path.join("data", "labels2id.pkl")
    if not os.path.exists(labels2id_path):
        # Previously this was only printed and open() then failed anyway.
        raise FileNotFoundError(f"Labels2id file not found at {labels2id_path}")
    print(f"Labels2id file found at {labels2id_path}")

    # NOTE(security): unpickling can execute arbitrary code; only load a
    # labels2id.pkl obtained from a trusted source.
    with open(labels2id_path, "rb") as f:
        laws2id = pkl.load(f)
    fatiao_args.labels = list(laws2id.keys())
    id2laws = {law_idx: law for law, law_idx in laws2id.items()}
    print("法条个数:", len(id2laws))

    fatiao_tokenizer = AutoTokenizer.from_pretrained(fatiao_args.ckpt_dir)
    fatiao_args.tokenizer = fatiao_tokenizer
    # Use the configured device instead of repeating the "cuda:0" literal.
    fatiao_model = Elect(fatiao_args, fatiao_args.device).to(fatiao_args.device)
    fatiao_model.eval()

    mlb = MultiLabelBinarizer()
    mlb.fit([fatiao_args.labels])
    fatiao_args.mlb = mlb

    # Encode each label's text with the model's encoder (`plm`) and add the
    # first-token vector into the matching row of `fatiao_model.la`.
    with torch.no_grad():
        for idx, label in enumerate(fatiao_args.labels):
            # Keep everything after the first ':' of the label, lower-cased.
            text = label.partition(':')[2].lower()
            la_in = fatiao_tokenizer(text, padding='max_length', truncation=True, max_length=256, return_tensors="pt")
            ids = la_in['input_ids'].to(fatiao_args.device)
            mask = la_in['attention_mask'].to(fatiao_args.device)
            fatiao_model.la[idx] += (fatiao_model.plm(input_ids=ids, attention_mask=mask)[0][:, 0]).squeeze(0)

    # NOTE(review): the checkpoint is loaded after `la` is filled in; whether
    # it overwrites `la` depends on how Elect registers it -- confirm.
    # The checkpoint path is relative to the current working directory.
    fatiao_model.load_state_dict(torch.load('./pretrained_models/ELECT', map_location=torch.device(fatiao_args.device)))
    fatiao_model.to(fatiao_args.device)
    logger.info("model loaded")

    app.run(host="0.0.0.0", port=9098, debug=False)
5.如需使用nginx反向代理访问此服务,可参考https://github.com/LeetJoe/lawyer-llama/blob/main/demo/nginx_proxy.md (Credit to @LeetJoe)
1.启动命令(内存不够时使用):python demo_web.py --port 7863 --checkpoint "C:/Users/win10/.cache/huggingface/hub/models--pkupie--lawyer-llama-13b-v2/snapshots/f61a4a16c97b6bd546790d88eaec7bc7fcd7344b" --classifier_url "http://127.0.0.1:9098/check_hunyin" --offload_folder "C:/path/to/offload/folder"
(在这个命令中,--offload_folder "C:/path/to/offload/folder" 用于指定一个目录,用来存储模型的部分数据,从而减轻内存负担。这通常是在处理大模型时的一种策略,通过将一些不常用的模型部分卸载到磁盘上,可以节省系统内存(RAM)的使用。)
2.启动命令(内存够时使用):python demo_web.py --port 7863 --checkpoint "C:/Users/win10/.cache/huggingface/hub/models--pkupie--lawyer-llama-13b-v2/snapshots/f61a4a16c97b6bd546790d88eaec7bc7fcd7344b" --classifier_url "http://127.0.0.1:9098/check_hunyin"
demo_web.py代码
import gradio as gr
import requests
import json
from transformers import LlamaForCausalLM, LlamaTokenizer, TextIteratorStreamer
import torch
import threading
import argparse
class StoppableThread(threading.Thread):
    """A ``threading.Thread`` that carries a cooperative stop flag.

    ``stop()`` only raises the flag; the code running in the thread is
    expected to poll ``stopped()`` and finish on its own.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Event used as a boolean flag shared with the worker.
        self._stop_event = threading.Event()

    def stopped(self):
        """Return True once ``stop()`` has been called."""
        return self._stop_event.is_set()

    def stop(self):
        """Raise the stop flag."""
        self._stop_event.set()
def json_send(url, data=None, method="POST", timeout=30):
    """Send a JSON request and return the decoded JSON response.

    Args:
        url: Target URL.
        data: JSON-serialisable payload for POST requests; ignored for GET.
            ``None`` sends a POST without a body.
        method: ``"POST"`` or ``"GET"`` (case-insensitive).
        timeout: Seconds to wait for the server before giving up, so a hung
            service cannot block the caller forever.

    Returns:
        The parsed JSON body, or an empty dict when the request fails, the
        server answers with an error status, or the body is not valid JSON.

    Raises:
        ValueError: If ``method`` is neither POST nor GET.
    """
    verb = method.upper()
    # Reject unknown verbs up front; previously this fell through to an
    # UnboundLocalError on `response`.
    if verb not in ("POST", "GET"):
        raise ValueError(f"Unsupported HTTP method: {method!r}")

    headers = {"Content-type": "application/json", "Accept": "text/plain", "charset": "UTF-8"}
    try:
        if verb == "POST":
            body = json.dumps(data) if data is not None else None
            response = requests.post(url=url, headers=headers, data=body, timeout=timeout)
        else:
            response = requests.get(url=url, headers=headers, timeout=timeout)
        response.raise_for_status()  # Treat 4xx/5xx answers as failures.
        return response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose
        # JSON decode error is not a RequestException.
        print(f"HTTP Request failed: {e}")
        return {}
def _build_prompt(chat_history, retrieve_output):
    """Assemble the model prompt from the dialogue and retrieved law articles.

    Args:
        chat_history: List of ``[user_message, assistant_reply]`` pairs; the
            last pair holds the current user message.
        retrieve_output: List of dicts with a ``'text'`` key (may be empty).
            At most the first three entries are used.

    Returns:
        The prompt string, ending with ``"### Assistant: "``.
    """
    if retrieve_output:
        parts = ["你是人工智能法律助手“Lawyer LLaMA”,能够回答与中国法律相关的问题。请参考给出的\"参考法条\",回复用户的咨询问题。\"参考法条\"中可能存在与咨询无关的法条,请回复时不要引用这些无关的法条。\n"]
    else:
        parts = ["你是人工智能法律助手“Lawyer LLaMA”,能够回答与中国法律相关的问题。\n"]

    # Earlier turns of the conversation.
    for turn in chat_history[:-1]:
        parts.append(f"### Human: {turn[0]}\n### Assistant: {turn[1]}\n")

    parts.append(f"### Human: {chat_history[-1][0]}\n")
    if retrieve_output:
        # Join whatever was returned (up to three) instead of indexing
        # [0], [1], [2], which fails when fewer than three laws come back.
        law_texts = "\n".join(item['text'] for item in retrieve_output[:3])
        parts.append(f"### 参考法条: {law_texts}\n")
    parts.append("### Assistant: ")
    return "".join(parts)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--checkpoint", type=str, default="")
    parser.add_argument("--classifier_url", type=str, default="")
    parser.add_argument("--load_in_8bit", action="store_true")
    parser.add_argument("--offload_folder", type=str, default="./offload")
    args = parser.parse_args()
    checkpoint = args.checkpoint
    classifier_url = args.classifier_url

    print("Loading model...")
    tokenizer = LlamaTokenizer.from_pretrained(checkpoint)
    if args.load_in_8bit:
        model = LlamaForCausalLM.from_pretrained(checkpoint, device_map="auto", load_in_8bit=True, offload_folder=args.offload_folder)
    else:
        model = LlamaForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.float16, offload_folder=args.offload_folder)
    print("Model loaded.")

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        input_msg = gr.Textbox(label="Input")
        with gr.Row():
            generate_button = gr.Button('Generate', elem_id='generate', variant='primary')
            clear_button = gr.Button('Clear', elem_id='clear', variant='secondary')

        def user(user_message, chat_history):
            """Clear the textbox and append the user's message as a new turn."""
            user_message = user_message.strip()
            return "", chat_history + [[user_message, None]]

        def bot(chat_history):
            """Stream the assistant reply for the latest user turn."""
            current_user_input = chat_history[-1][0]
            if not current_user_input:
                # Nothing was typed: drop the empty turn and stop.
                yield chat_history[:-1]
                return

            # Retrieve law articles using all user inputs so far as the query.
            # Skip the HTTP call when no classifier URL was configured.
            retrieve_output = []
            if classifier_url:
                input_to_classifier = " ".join(turn[0] for turn in chat_history)
                result = json_send(classifier_url, {"input": input_to_classifier}, method="POST")
                retrieve_output = result.get('output', [])

            input_text = _build_prompt(chat_history, retrieve_output)
            print("=== Input ===")
            print("input_text: ", input_text)

            inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
            streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

            # Run generation in a separate thread so the generated text can
            # be fetched from the streamer without blocking.
            generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=400, do_sample=False, repetition_penalty=1.1)
            thread = StoppableThread(target=model.generate, kwargs=generation_kwargs)
            thread.start()

            # Stream partial output into the last chat turn.
            chat_history[-1][1] = ""
            for new_text in streamer:
                chat_history[-1][1] += new_text
                yield chat_history

            streamer.end()
            # Only sets the flag on StoppableThread; nothing passed to the
            # thread here polls it.
            thread.stop()
            print("Output: ", chat_history[-1][1])

        input_msg.submit(user, [input_msg, chatbot], [input_msg, chatbot], queue=False).then(
            bot, [chatbot], chatbot
        )
        generate_button.click(user, [input_msg, chatbot], [input_msg, chatbot], queue=False).then(
            bot, [chatbot], chatbot
        )
        # The Clear button had no handler; reset the conversation on click.
        clear_button.click(lambda: None, None, chatbot, queue=False)

    demo.queue()
    demo.launch(share=False, server_port=args.port, server_name='0.0.0.0')