引入依赖（pom.xml）：
<!-- Imports the Spring AI bill of materials so the org.springframework.ai artifacts declared below can omit their <version>. -->
<!-- NOTE(review): 1.0.0-SNAPSHOT is a moving snapshot version; it presumably needs the Spring snapshot repository in <repositories> to resolve. Consider pinning a released version. -->
<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-bom</artifactId>
<version>1.0.0-SNAPSHOT</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
<dependencies>
<!-- Spring MVC: serves the /chat endpoint of SqlController. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- OpenAI chat model starter (provides the ChatModel bean); version comes from spring-ai-bom. -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
</dependency>
<!-- Redis vector store; version comes from spring-ai-bom. NOTE(review): not used by SqlController - confirm it is needed elsewhere. -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-redis-store</artifactId>
</dependency>
<!-- Redis client, pinned explicitly (overrides any managed version). -->
<dependency>
<groupId>redis.clients</groupId>
<artifactId>jedis</artifactId>
<version>5.1.0</version>
</dependency>
<!-- MyBatis-Plus for Spring Boot 3. NOTE(review): JdbcTemplate used by SqlController presumably arrives transitively through this starter - confirm. -->
<dependency>
<groupId>com.baomidou</groupId>
<artifactId>mybatis-plus-spring-boot3-starter</artifactId>
<version>3.5.7</version>
</dependency>
<!-- MySQL JDBC driver. Since 8.0.31 it is published as com.mysql:mysql-connector-j; the old mysql:mysql-connector-java coordinates are only a relocation stub. -->
<dependency>
<groupId>com.mysql</groupId>
<artifactId>mysql-connector-j</artifactId>
<version>8.0.32</version>
</dependency>
<!-- Pie-chart rendering for the /chat response. -->
<dependency>
<groupId>org.jfree</groupId>
<artifactId>jfreechart</artifactId>
<version>1.5.3</version>
</dependency>
<!-- SqlController imports org.apache.commons.lang3.StringUtils, so declare it directly instead of relying on a transitive path; version is managed by Spring Boot. -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
</dependencies>
代码（SqlController.java）：
package com.qjc.demo.controller;
import jakarta.annotation.Resource;
import jakarta.servlet.http.HttpServletResponse;
import org.apache.commons.lang3.StringUtils;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartUtils;
import org.jfree.chart.JFreeChart;
import org.jfree.data.general.DefaultPieDataset;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import java.io.IOException;
import java.io.OutputStream;
import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import javax.sql.DataSource;
/**
 * Text-to-SQL demo endpoint: turns a natural-language question into SQL with a chat model,
 * runs the query and renders the result as a PNG pie chart.
 *
 * <p>Two model calls are made per request. The first narrows the full schema down to the
 * tables relevant to the question. The second generates SQL against only those tables.
 */
@Controller
public class SqlController {

    private static final System.Logger LOG = System.getLogger(SqlController.class.getName());

    /** HTTP 400 Bad Request. */
    private static final int BAD_REQUEST = 400;

    /** Markdown code fences (```, ```sql, ```mysql in any case) that chat models wrap SQL in. */
    private static final Pattern CODE_FENCE =
            Pattern.compile("```(?:sql|mysql)?", Pattern.CASE_INSENSITIVE);

    /** Whole statement starts with SELECT (case-insensitive, may span several lines). */
    private static final Pattern SELECT_STATEMENT = Pattern.compile("(?is)select\\b.*");

    /** MySQL {@code SELECT ... INTO OUTFILE/DUMPFILE} writes files on the database server. */
    private static final Pattern FILE_EXPORT =
            Pattern.compile("(?i)\\binto\\s+(?:outfile|dumpfile)\\b");

    @Resource
    private ChatModel chatModel;

    @Resource
    private JdbcTemplate jdbcTemplate;

    /**
     * Few-shot prompt asking the model to pick the tables relevant to {input} out of the full
     * schema passed as {instruction}.
     */
    private static final String FILTER_INSTRUCTION = """
            你需要根据指定的Input从Instruction中筛选出最相关的表信息(可能是单个表或多个表),
            首先,我将给你展示一个示例,Instruction后面跟着Input和对应的Response,
            然后,我会给你一个新的Instruction和新的Input,你需要生成一个新的Response来完成任务。
            ### Example1 Instruction:
            job(id, name, age), user(id, name, age), student(id, name, age, info)
            ### Example1 Input:
            Find the age of student table
            ### Example1 Response:
            student(id, name, age, info)
            ###New Instruction:
            {instruction}
            ###New Input:
            {input}
            ###New Response:
            """;

    /**
     * Prompt asking the model to act as a SQL terminal and answer {input} using only the tables
     * in {instruction}, replying with SQL only.
     */
    private static final String GENERATE_INSTRUCTION = """
            你扮演一个SQL终端,您只需要返回SQL命令给我,而不需要返回其他任何字符。下面是一个描述任务的Instruction,返回适当的结果完成Input对应的请求.
            ###Instruction:
            {instruction}
            ###Input:
            {input}
            ###Response:
            """;

    /**
     * Answers a natural-language question with a pie chart.
     *
     * <p>The generated query is expected to return the slice label in its first column and a
     * numeric value in its second column. Rows whose value is NULL are skipped.
     *
     * <p>Responds 400 when the question is blank, when the model produces anything other than a
     * single SELECT, or when the result has fewer than two columns. Otherwise it writes an
     * 800x600 PNG.
     *
     * @param query the user's question in natural language
     * @param response servlet response the PNG (or error) is written to
     * @throws SQLException if reading schema metadata fails
     * @throws IOException if writing the response fails
     * @throws NumberFormatException if a value cell is neither a number nor numeric text
     */
    @GetMapping("/chat")
    public void chat(@RequestParam("query") String query, HttpServletResponse response)
            throws SQLException, IOException {
        if (query == null || query.isBlank()) {
            response.sendError(BAD_REQUEST, "query must not be blank");
            return;
        }

        // Step 1: let the model pick the relevant tables out of the whole schema.
        String schemaPrompt = describeSchema(getTableInfo());
        String relevantTables = chatModel.call(renderPrompt(FILTER_INSTRUCTION, schemaPrompt, query));

        // Step 2: generate SQL against the relevant tables only.
        String sql = extractSql(chatModel.call(renderPrompt(GENERATE_INSTRUCTION, relevantTables, query)));
        LOG.log(System.Logger.Level.INFO, "Generated SQL: {0}", sql);

        // The SQL is model output driven by user input, so never execute it blindly. This guard
        // is best effort; the real protection is a read-only database account.
        if (!isReadOnlySelect(sql)) {
            response.sendError(BAD_REQUEST, "Generated statement is not a single read-only SELECT");
            return;
        }

        List<Map<String, Object>> rows = jdbcTemplate.queryForList(sql);
        // All rows share the same columns, so checking the first one is enough.
        if (!rows.isEmpty() && rows.get(0).size() < 2) {
            response.sendError(BAD_REQUEST, "Query result needs a label column and a value column");
            return;
        }

        DefaultPieDataset dataset = new DefaultPieDataset();
        for (Map<String, Object> row : rows) {
            // queryForList keeps the column order of the SELECT list.
            Object[] cells = row.values().toArray();
            Number value = toNumber(cells[1]);
            if (value != null) {
                dataset.setValue(String.valueOf(cells[0]), value);
            }
        }

        JFreeChart chart = ChartFactory.createPieChart(
                "统计结果",
                dataset,
                false,
                true,
                true);
        response.setContentType("image/png");
        OutputStream out = response.getOutputStream();
        ChartUtils.writeChartAsPNG(out, chart, 800, 600);
        out.flush();
    }

    /**
     * Reads the tables of the connection's current catalog together with their columns.
     *
     * <p>Each column is rendered as {@code name(remarks)} when the driver reports non-blank
     * REMARKS, otherwise as just {@code name}. Tables keep the order returned by
     * {@link DatabaseMetaData#getTables}.
     *
     * @return table name mapped to its column descriptions, in metadata order
     * @throws SQLException if reading database metadata fails
     * @throws IllegalStateException if the JdbcTemplate has no DataSource
     */
    public Map<String, List<String>> getTableInfo() throws SQLException {
        DataSource dataSource = jdbcTemplate.getDataSource();
        if (dataSource == null) {
            throw new IllegalStateException("JdbcTemplate has no DataSource configured");
        }
        Map<String, List<String>> result = new LinkedHashMap<>();
        // try-with-resources hands the pooled connection back; it used to leak on every call.
        try (Connection connection = dataSource.getConnection()) {
            DatabaseMetaData metaData = connection.getMetaData();
            // Restrict to the current database: with a null catalog some drivers list the
            // tables of every database the user can see.
            String catalog = connection.getCatalog();
            try (ResultSet tables = metaData.getTables(catalog, null, "%", new String[] {"TABLE"})) {
                while (tables.next()) {
                    String tableName = tables.getString("TABLE_NAME");
                    result.put(tableName, readColumns(metaData, catalog, tableName));
                }
            }
        }
        return result;
    }

    /** Lists the columns of one table as {@code name} or {@code name(remarks)}. */
    private static List<String> readColumns(DatabaseMetaData metaData, String catalog, String tableName)
            throws SQLException {
        // getColumns takes a LIKE pattern, so '_' or '%' in the table name must be escaped.
        String tablePattern = escapeLikePattern(tableName, metaData.getSearchStringEscape());
        List<String> columnNames = new ArrayList<>();
        try (ResultSet columns = metaData.getColumns(catalog, null, tablePattern, null)) {
            while (columns.next()) {
                String columnName = columns.getString("COLUMN_NAME");
                String remarks = columns.getString("REMARKS");
                columnNames.add(remarks == null || remarks.isBlank()
                        ? columnName
                        : String.format("%s(%s)", columnName, remarks));
            }
        }
        return columnNames;
    }

    /** Escapes LIKE wildcards in {@code name} with the driver's escape string, if it has one. */
    private static String escapeLikePattern(String name, String escape) {
        if (name == null || escape == null || escape.isEmpty()) {
            return name;
        }
        return name.replace(escape, escape + escape)
                .replace("_", escape + "_")
                .replace("%", escape + "%");
    }

    /** Formats the schema as {@code table(col,col),table(col,...)} for the filter prompt. */
    private static String describeSchema(Map<String, List<String>> tableInfo) {
        List<String> tables = tableInfo.entrySet().stream()
                .map(entry -> String.format("%s(%s)", entry.getKey(), String.join(",", entry.getValue())))
                .toList();
        return String.join(",", tables);
    }

    /** Fills the {instruction} and {input} placeholders of a prompt template. */
    private static String renderPrompt(String template, String instruction, String input) {
        PromptTemplate promptTemplate = new PromptTemplate(template);
        promptTemplate.add("instruction", instruction);
        promptTemplate.add("input", input);
        return promptTemplate.render();
    }

    /** Strips Markdown code fences, surrounding whitespace and trailing semicolons. */
    private static String extractSql(String modelReply) {
        String sql = CODE_FENCE.matcher(modelReply == null ? "" : modelReply).replaceAll("").strip();
        while (sql.endsWith(";")) {
            sql = sql.substring(0, sql.length() - 1).strip();
        }
        return sql;
    }

    /**
     * Returns whether {@code sql} looks like a single SELECT that does not write files.
     * Statements containing ';' are rejected outright (fail closed).
     */
    private static boolean isReadOnlySelect(String sql) {
        return SELECT_STATEMENT.matcher(sql).matches()
                && sql.indexOf(';') < 0
                && !FILE_EXPORT.matcher(sql).find();
    }

    /**
     * Converts a result cell to a chart value.
     *
     * @return the cell itself if it is a Number, null for a NULL cell, otherwise the cell text
     *     parsed as a BigDecimal
     * @throws NumberFormatException if the cell text is not numeric
     */
    private static Number toNumber(Object cell) {
        if (cell == null) {
            return null;
        }
        if (cell instanceof Number number) {
            return number;
        }
        return new BigDecimal(cell.toString().strip());
    }
}