vanna_ai 私有化部署(MySQL 数据库 + Milvus 向量库 + Qwen2.5 LLM)

背景

在上一篇博客(https://blog.csdn.net/m0_37659226/article/details/145797076)中介绍了 vanna_ai 的本地部署,但其中使用的仍然是百炼平台的大模型。如果想改用自己本地部署的大模型,应该怎么做呢?本文将介绍如何基于本地部署的 Milvus 向量库和 Qwen2.5 大模型来部署 vanna_ai(数据库使用 MySQL)。

部署前准备

在这里插入图片描述
首先,我这边已经部署好了一个 Milvus。但官方文档没有给出 Milvus 的连接配置示例,因此只能参考其他向量库的写法来实现:

class MyVanna(Milvus_VectorStore, MyCustomLLM):
    """Vanna app combining the Milvus vector store with the custom LLM backend.

    Config keys read here:
        milvus_uri (required): URI passed to ``MilvusClient``.
        milvus_db (required): database name passed to ``MilvusClient``.
        embedding_model (optional): alias looked up by ``CustomEmbeddingFunction``,
            default "bge_small".
    The whole config is also forwarded to both base classes.
    """

    def __init__(self, config=None):
        milvus_client = MilvusClient(uri=config["milvus_uri"], db_name=config["milvus_db"])
        # Forward the caller's config as well. Previously only the client and the
        # embedding function were passed, so any extra options meant for the
        # vector store / VannaBase (presumably e.g. n_results, language -- TODO
        # confirm against vanna) were silently dropped.
        Milvus_VectorStore.__init__(self, config={
            **config,
            "milvus_client": milvus_client,
            # Local embedding model, so nothing is downloaded from huggingface.co.
            "embedding_function": CustomEmbeddingFunction(
                config.get("embedding_model", "bge_small")
            ),
        })
        MyCustomLLM.__init__(self, config=config)

    # Implementations of the abstract message-construction methods.
    def assistant_message(self, message: str) -> dict:
        """Wrap *message* as an assistant chat message."""
        return {"role": "assistant", "content": message}

    def system_message(self, message: str) -> dict:
        """Wrap *message* as a system chat message."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap *message* as a user chat message."""
        return {"role": "user", "content": message}

这里需要说明一下:由于我这边处于无法访问外网的环境,没有办法从 HuggingFace 下载模型。而如果使用 Milvus_VectorStore 默认的 embedding_function,就不可避免地要去 HuggingFace 下载模型。因此这里改用自己本地的 embedding 模型,为此写了一个 CustomEmbeddingFunction,相关代码如下:

# Registry of embedding models: short alias -> Hugging Face repo id or local path.
# A local path (see "bge_small") lets the app load a model without internet access.
embedding_model_dict = {
    "ernie_tiny": "nghuyong/ernie-3.0-nano-zh",
    "ernie_base": "nghuyong/ernie-3.0-base-zh",
    "text2vec": "GanymedeNil/text2vec-large-chinese",
    "text2vec2": "uer/sbert-base-chinese-nli",
    # Hub repo owner is "shibing624" (was misspelled "shibing424").
    "text2vec3": "shibing624/text2vec-base-chinese",
    "bge_small": "D:/HuggingFaceModel/bge-small-zh-v1.5",
}


class CustomEmbeddingFunction:
    """Embedding function backed by a ``HuggingFaceEmbeddings`` model.

    Provides ``encode_documents`` / ``encode_queries`` returning lists of
    ``np.ndarray``. This is the object passed to ``Milvus_VectorStore`` as its
    "embedding_function".
    """

    def __init__(self, model_name="ernie_tiny", device="cuda:0", normalize_embeddings=False):
        """Load the embedding model registered under *model_name*.

        Args:
            model_name: key of ``embedding_model_dict``. The old default
                "ernie-tiny" (hyphen) was not a key and always raised KeyError.
            device: torch device string for the model, default "cuda:0".
            normalize_embeddings: whether vectors are L2-normalised on encode.

        Raises:
            KeyError: if *model_name* is not a registered alias.
        """
        # Resolve the alias before touching the model loader, so a typo fails
        # fast with a helpful message instead of a bare KeyError.
        try:
            model_path = embedding_model_dict[model_name]
        except KeyError:
            raise KeyError(
                f"unknown embedding model {model_name!r}; "
                f"expected one of {sorted(embedding_model_dict)}"
            ) from None
        self.embedding_model = HuggingFaceEmbeddings(
            model_name=model_path,
            model_kwargs={"device": device},
            encode_kwargs={"normalize_embeddings": normalize_embeddings},
        )

    def encode_documents(self, documents: List[str]) -> List[np.ndarray]:
        """Embed each document and return one ``np.ndarray`` per document."""
        embeddings = self.embedding_model.embed_documents(documents)
        return [np.array(embedding) for embedding in embeddings]

    def encode_queries(self, queries: Union[str, List[str]]) -> List[np.ndarray]:
        """Embed one query string or a list of them.

        Returns:
            One ``np.ndarray`` per query; a single string yields a 1-element list.

        Raises:
            TypeError: if *queries* is neither a string nor a list of strings.
        """
        # Normalise input so queries is always a list of strings.
        if isinstance(queries, str):
            queries = [queries]
        elif not isinstance(queries, list) or not all(isinstance(q, str) for q in queries):
            raise TypeError("queries must be a string or a list of strings")

        # Embed queries one by one via the query-specific embedding path.
        embeddings = [self.embedding_model.embed_query(query) for query in queries]
        return [np.array(embedding) for embedding in embeddings]

下面是 LLM 部分(通过 OpenAI 兼容接口调用本地部署的 Qwen2.5):

from vanna.base import VannaBase
from vanna.milvus import Milvus_VectorStore
from pymilvus import MilvusClient
from openai import OpenAI
from utils.file_processing import CustomEmbeddingFunction

class MyCustomLLM(VannaBase):
    """Vanna LLM backend that calls an OpenAI-compatible chat-completions endpoint.

    Intended for a self-hosted model (e.g. Qwen2.5) behind an OpenAI-compatible API.

    Config keys:
        api_key (required): key sent to the endpoint ("EMPTY" in this setup).
        model (required unless base_url is given): the endpoint's base URL.
            The key name is historical; it holds a URL, not a model name.
        base_url (optional): clearer alias for ``model``; wins if both are set.
        model_name (optional): served model id, default "Qwen2.5-72B-Instruct".
        temperature / top_p (optional): sampling overrides, default 0.7 / 0.8.
        completion_max_tokens (optional): reply length cap, default 5120.
            Deliberately not called ``max_tokens`` in case VannaBase reads that
            key for another purpose (TODO confirm against vanna).
        repetition_penalty (optional): sent via ``extra_body``, default 1.05.
    """

    def __init__(self, config=None):
        base_url = config["base_url"] if "base_url" in config else config["model"]
        self.client = OpenAI(
            api_key=config["api_key"],
            base_url=base_url,
        )
        # Stored in private attributes so they cannot collide with attributes
        # set by VannaBase or the other mixins in a multiple-inheritance class.
        self._llm_model = config.get("model_name", "Qwen2.5-72B-Instruct")
        self._chat_options = {
            "temperature": config.get("temperature", 0.7),
            "top_p": config.get("top_p", 0.8),
            "max_tokens": config.get("completion_max_tokens", 5120),
            "extra_body": {
                "repetition_penalty": config.get("repetition_penalty", 1.05),
            },
        }

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send *prompt* to the chat endpoint and return the reply text.

        Args:
            prompt: passed unchanged as ``messages``. Presumably a list of
                {"role": ..., "content": ...} dicts built by the *_message helpers.
            **kwargs: accepted for interface compatibility; ignored.

        Returns:
            The first choice's message content, or "" if the server sent none.
        """
        chat_response = self.client.chat.completions.create(
            model=self._llm_model,
            messages=prompt,
            **self._chat_options,
        )
        # message.content may be None; keep the declared -> str contract.
        return chat_response.choices[0].message.content or ""

最后是数据库连接和程序启动的代码:

import os

# Runtime settings: edit the defaults for your environment, or override them
# with environment variables so credentials do not have to live in source code.
vn = MyVanna(
    config={
        "api_key": os.environ.get("LLM_API_KEY", "EMPTY"),
        # Base URL of the OpenAI-compatible LLM endpoint (MyCustomLLM uses it as base_url).
        "model": os.environ.get("LLM_BASE_URL", "改成自己的模型api"),
        "milvus_uri": os.environ.get("MILVUS_URI", "http://127.0.0.1:19530"),
        "milvus_db": os.environ.get("MILVUS_DB", "vanna_db"),
    }
)

# Connect to MySQL (via pymysql.connect, per the original note); change to your own settings.
vn.connect_to_mysql(
    host=os.environ.get("MYSQL_HOST", "127.0.0.1"),
    port=int(os.environ.get("MYSQL_PORT", "3306")),
    dbname=os.environ.get("MYSQL_DB", "uc"),
    user=os.environ.get("MYSQL_USER", "root"),
    password=os.environ.get("MYSQL_PASSWORD", "123456"),
)

# Optional: add domain documentation to the training data, e.g.
# vn.train(documentation="请注意,在我们公司一般将1作为是,0作为否。")

from vanna.flask import VannaFlaskApp

app = VannaFlaskApp(vn, debug=True, allow_llm_to_see_data=True,
                    title="数据库问答", subtitle="您的私人智能助手",
                    show_training_data=True, suggested_questions=False,
                    sql=True, table=True, csv_download=False, chart=True,
                    redraw_chart=False, auto_fix_sql=False,
                    ask_results_correct=False, followup_questions=False)

# Start the server only when run as a script, so importing this module
# (e.g. from a WSGI host) does not block inside app.run().
if __name__ == "__main__":
    app.run()

启动后的主界面如下:
在这里插入图片描述