Spring AI实现自然语言生成SQL和报表

本文演示如何借助 Spring AI 把自然语言问题转换为 SQL：先让大模型从数据库表结构中筛选出相关的表，再据此生成 SQL 并执行，最后用 JFreeChart 将查询结果渲染为饼图返回给调用方。

引入依赖（pom.xml 片段）

<!-- Import the Spring AI BOM so that the Spring AI artifacts declared below can omit <version>. -->
<!-- NOTE(review): 1.0.0-SNAPSHOT is a snapshot version; it is normally resolved from a Spring
     snapshot repository that must be declared in <repositories> (not shown here). Confirm, or pin a release. -->
<dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.springframework.ai</groupId>
                <artifactId>spring-ai-bom</artifactId>
                <version>1.0.0-SNAPSHOT</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <!-- NOTE(review): SqlController imports org.apache.commons.lang3.StringUtils, but commons-lang3 is not
         declared here; presumably it arrives transitively. Declaring it explicitly would be safer. -->
    <dependencies>
        <!-- Web MVC: provides @Controller / @GetMapping and the servlet response used by SqlController. -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <!-- Spring AI chat starter; SqlController injects the ChatModel bean it is expected to provide.
             NOTE(review): the project header says "spring-ollama-demo" while this is the OpenAI starter; confirm. -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-openai-spring-boot-starter</artifactId>
        </dependency>

        <!-- Redis vector store and Jedis client. NOTE(review): not referenced by SqlController;
             presumably used elsewhere in the project. Remove if unused. -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-redis-store</artifactId>
        </dependency>

        <dependency>
            <groupId>redis.clients</groupId>
            <artifactId>jedis</artifactId>
            <version>5.1.0</version>
        </dependency>


        <!-- MyBatis-Plus. NOTE(review): SqlController only uses JdbcTemplate; this starter presumably supplies
             the JDBC / DataSource auto-configuration transitively. Verify before removing it. -->
        <dependency>
            <groupId>com.baomidou</groupId>
            <artifactId>mybatis-plus-spring-boot3-starter</artifactId>
            <version>3.5.7</version>
        </dependency>

        <!-- MySQL JDBC driver. NOTE(review): since Connector/J 8.0.31 the artifact is published as
             com.mysql:mysql-connector-j; these old coordinates are a relocation. Consider switching. -->
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>8.0.32</version>
        </dependency>

        <!-- JFreeChart: ChartFactory / ChartUtils render the PNG pie chart returned by SqlController#chat. -->
        <dependency>
            <groupId>org.jfree</groupId>
            <artifactId>jfreechart</artifactId>
            <version>1.5.3</version>
        </dependency>
    </dependencies>

控制器代码（SqlController.java）

package com.qjc.demo.controller;

import jakarta.annotation.Resource;
import jakarta.servlet.http.HttpServletResponse;
import org.apache.commons.lang3.StringUtils;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartUtils;
import org.jfree.chart.JFreeChart;
import org.jfree.data.general.DefaultPieDataset;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;

import java.io.IOException;
import java.io.OutputStream;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import javax.sql.DataSource;

/***
 * @projectName spring-ollama-demo
 * @packageName com.qjc.demo.controller
 * @author qjc
 * @description Natural-language-to-SQL demo: the chat model first picks the relevant tables from the
 *              database schema, then generates a SELECT statement, which is executed and rendered as a
 *              PNG pie chart.
 * @Email [email protected]
 * @date 2024-10-24 17:56
 **/
@Controller
public class SqlController {

    @Resource
    private ChatModel chatModel;

    @Resource
    private JdbcTemplate jdbcTemplate;

    /**
     * Prompt asking the model to select, from the full schema description ({instruction}), the table
     * definitions relevant to the user's question ({input}); it contains one worked example.
     */
    private static final String FILTER_INSTRUCTION = """
    你需要根据指定的Input从Instruction中筛选出最相关的表信息(可能是单个表或多个表),
    首先,我将给你展示一个示例,Instruction后面跟着Input和对应的Response,
    然后,我会给你一个新的Instruction和新的Input,你需要生成一个新的Response来完成任务。

    ### Example1 Instruction:
    job(id, name, age), user(id, name, age), student(id, name, age, info)
    ### Example1 Input:
    Find the age of student table
    ### Example1 Response:
    student(id, name, age, info)
    ###New Instruction:
    {instruction}
    ###New Input:
    {input}
    ###New Response:
    """;

    /**
     * Prompt asking the model to act as a SQL terminal and return only the SQL command that answers
     * {input} over the tables listed in {instruction}.
     */
    private static final String GENERATE_INSTRUCTION = """
    你扮演一个SQL终端,您只需要返回SQL命令给我,而不需要返回其他任何字符。下面是一个描述任务的Instruction,返回适当的结果完成Input对应的请求.
    ###Instruction:
    {instruction}
    ###Input:
    {input}
    ###Response:
    """;

    /**
     * Turns a natural-language question into SQL, executes it and writes the result as a PNG pie chart.
     *
     * <p>Pipeline: (1) describe the tables of the current database, (2) ask the model which tables are
     * relevant to {@code query}, (3) ask the model for a SQL statement over those tables, (4) execute it
     * and plot the first result column as the slice label and the second as the slice value.
     *
     * @param query    the user's question in natural language
     * @param response servlet response that receives the PNG, or an HTTP 400 error when the query is
     *                 blank, the generated SQL is not a single SELECT, or the result has fewer than two columns
     * @throws SQLException if reading the database metadata fails
     * @throws IOException  if writing the response fails
     */
    @GetMapping("/chat")
    public void chat(@RequestParam("query") String query, HttpServletResponse response) throws SQLException, IOException {
        if (StringUtils.isBlank(query)) {
            response.sendError(HttpServletResponse.SC_BAD_REQUEST, "query must not be blank");
            return;
        }

        // 1. Describe the schema as "table(col(remarks),...),table2(...)".
        List<String> tableInfoList = getTableInfo().entrySet().stream()
                .map(entry -> String.format("%s(%s)", entry.getKey(), String.join(",", entry.getValue())))
                .toList();
        String tableInfoPrompt = String.join(",", tableInfoList);

        // 2. Let the model narrow the schema down to the tables relevant to the question.
        PromptTemplate filterPromptTemplate = new PromptTemplate(FILTER_INSTRUCTION);
        filterPromptTemplate.add("instruction", tableInfoPrompt);
        filterPromptTemplate.add("input", query);
        String filterResult = chatModel.call(filterPromptTemplate.render());

        // 3. Generate SQL against the narrowed schema.
        PromptTemplate generatePromptTemplate = new PromptTemplate(GENERATE_INSTRUCTION);
        generatePromptTemplate.add("instruction", filterResult);
        generatePromptTemplate.add("input", query);
        String sql = cleanGeneratedSql(chatModel.call(generatePromptTemplate.render()));
        System.out.println(sql);

        // The SQL is model output steered by user text, so never execute anything but one SELECT.
        // This is defence in depth only (e.g. MySQL's SELECT ... INTO OUTFILE can still write files);
        // the DataSource should additionally use a read-only database account.
        if (!isSingleSelect(sql)) {
            response.sendError(HttpServletResponse.SC_BAD_REQUEST, "Only a single SELECT statement can be executed");
            return;
        }

        // 4. Execute and map rows to pie slices: first column = label, second column = value.
        List<Map<String, Object>> rows = jdbcTemplate.queryForList(sql);
        DefaultPieDataset dataset = new DefaultPieDataset();
        for (Map<String, Object> row : rows) {
            Object[] values = row.values().toArray();
            if (values.length < 2) {
                response.sendError(HttpServletResponse.SC_BAD_REQUEST,
                        "The generated query must return a label column and a value column");
                return;
            }
            Number value = toChartValue(values[1]);
            if (value != null) {
                // String.valueOf: a NULL label (e.g. grouping by a nullable column) becomes "null", not an NPE.
                dataset.setValue(String.valueOf(values[0]), value);
            }
        }

        // Build the pie chart.
        JFreeChart chart = ChartFactory.createPieChart(
                "统计结果", // chart title
                dataset,   // data set
                false,     // legend
                true,      // tooltips
                true);     // URLs

        // The response body is the PNG image.
        response.setContentType("image/png");
        OutputStream out = response.getOutputStream();
        ChartUtils.writeChartAsPNG(out, chart, 800, 600);
        out.flush();
    }

    /**
     * Reads the tables of the connection's current catalog and their columns from JDBC metadata.
     *
     * @return map from table name to column descriptors, each {@code column(remarks)}, or just
     *         {@code column} when the column has no remarks; iteration order follows the order in which
     *         the driver's {@code getTables} result returns the tables
     * @throws SQLException if the metadata cannot be read
     * @throws IllegalStateException if the JdbcTemplate has no DataSource
     */
    public Map<String, List<String>> getTableInfo() throws SQLException {
        DataSource dataSource = jdbcTemplate.getDataSource();
        if (dataSource == null) {
            throw new IllegalStateException("JdbcTemplate has no DataSource configured");
        }
        // try-with-resources returns the (typically pooled) connection; it used to be leaked on every call.
        try (Connection connection = dataSource.getConnection()) {
            DatabaseMetaData metaData = connection.getMetaData();
            // A null catalog means "all catalogs" for some drivers (e.g. MySQL Connector/J 8 with
            // nullCatalogMeansCurrent=false), which would pull other databases' tables into the prompt.
            String catalog = connection.getCatalog();

            Map<String, List<String>> result = new LinkedHashMap<>();
            try (ResultSet tables = metaData.getTables(catalog, null, "%", new String[]{"TABLE"})) {
                while (tables.next()) {
                    String tableName = tables.getString("TABLE_NAME");
                    result.put(tableName, readColumns(metaData,
                            tables.getString("TABLE_CAT"), tables.getString("TABLE_SCHEM"), tableName));
                }
            }
            return result;
        }
    }

    /** Lists one table's columns as {@code name(remarks)}, or just {@code name} when remarks are blank. */
    private static List<String> readColumns(DatabaseMetaData metaData, String catalog, String schema,
                                            String tableName) throws SQLException {
        List<String> columnNames = new ArrayList<>();
        try (ResultSet columns = metaData.getColumns(catalog, schema, tableName, null)) {
            while (columns.next()) {
                // tableName is a LIKE pattern here ('_' matches any char): skip columns of look-alike tables.
                if (!tableName.equals(columns.getString("TABLE_NAME"))) {
                    continue;
                }
                String columnName = columns.getString("COLUMN_NAME");
                String remarks = columns.getString("REMARKS");
                columnNames.add(StringUtils.isBlank(remarks)
                        ? columnName
                        : String.format("%s(%s)", columnName, remarks));
            }
        }
        return columnNames;
    }

    /**
     * Removes Markdown code fences the model may wrap around its answer (```sql, ```SQL, ``` ...),
     * surrounding whitespace, and trailing semicolons; a null answer becomes the empty string.
     */
    private static String cleanGeneratedSql(String raw) {
        String sql = (raw == null ? "" : raw).replaceAll("```[A-Za-z]*", "").strip();
        while (sql.endsWith(";")) {
            sql = sql.substring(0, sql.length() - 1).strip();
        }
        return sql;
    }

    /**
     * Returns true when {@code sql} starts with SELECT (any case) and contains no statement separator.
     * Deliberately conservative: a ';' inside a string literal is also rejected.
     */
    private static boolean isSingleSelect(String sql) {
        String keyword = "select";
        return sql.regionMatches(true, 0, keyword, 0, keyword.length()) && sql.indexOf(';') < 0;
    }

    /**
     * Converts a JDBC cell into a chart value. Numbers are used as-is, because aggregates often come
     * back as Long or BigDecimal depending on the driver; any other value is parsed as a decimal, and
     * {@code null} stays {@code null}.
     */
    private static Number toChartValue(Object cell) {
        if (cell == null) {
            return null;
        }
        if (cell instanceof Number number) {
            return number;
        }
        return Double.valueOf(cell.toString().strip());
    }
}