AI对话交互场景使用WebSocket建立H5客户端和服务端的信息实时双向通信

WebSocket使得客户端和服务器之间的数据交换变得更加简单,允许服务端主动向客户端推送数据。在WebSocket API中,浏览器和服务器只需要完成一次握手,两者之间就可以创建持久性的连接,并进行双向数据传输。

一、为什么需要 WebSocket?

初次接触 WebSocket 的人,都会问同样的问题:我们已经有了 HTTP 协议,为什么还需要另一个协议?它能带来什么好处?

答案很简单,因为 HTTP 协议有一个缺陷:通信只能由客户端发起。

举例来说,我们想了解今天的天气,只能是客户端向服务器发出请求,服务器返回查询结果。HTTP 协议做不到服务器主动向客户端推送信息。
在这里插入图片描述
这种单向请求的特点,注定了如果服务器有连续的状态变化,客户端要获知就非常麻烦。我们只能使用"轮询":每隔一段时间,就发出一个询问,了解服务器有没有新的信息。最典型的场景就是聊天室。

轮询的效率低,非常浪费资源(因为必须不停连接,或者 HTTP 连接始终打开)。因此,工程师们一直在思考,有没有更好的方法。WebSocket 就是这样发明的。

二、简介

WebSocket 协议在2008年诞生,2011年成为国际标准。目前所有主流浏览器都已经支持。

它的最大特点就是,服务器可以主动向客户端推送信息,客户端也可以主动向服务器发送信息,是真正的双向平等对话,属于服务器推送技术的一种。
在这里插入图片描述
其他特点包括:

(1)建立在 TCP 协议之上,服务器端的实现比较容易。

(2)与 HTTP 协议有着良好的兼容性。默认端口也是80和443,并且握手阶段采用 HTTP 协议,因此握手时不容易被屏蔽,能通过各种 HTTP 代理服务器。

(3)数据格式比较轻量,性能开销小,通信高效。

(4)可以发送文本,也可以发送二进制数据。

(5)没有同源限制,客户端可以与任意服务器通信。

(6)协议标识符是ws(如果加密,则为wss),服务器网址就是 URL。

在这里插入图片描述

服务端的实现

依赖spring-boot-starter-websocket模块实现WebSocket实时对话交互。

CustomTextWebSocketHandler,扩展的TextWebSocketHandler


import cn.hutool.json.JSONUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PongMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import java.util.concurrent.CountDownLatch;

/**
 * 文本处理器
 *
 * @see org.springframework.web.socket.handler.TextWebSocketHandler
 */
@Slf4j
public class CustomTextWebSocketHandler extends TextWebSocketHandler {
    
    
    /**
     * 第三方身份,消息身份
     */
    private String thirdPartyId;
    /**
     * 回复消息内容
     */
    private String replyContent;
    private StringBuilder replyContentBuilder;
    /**
     * 完成信号
     */
    private final CountDownLatch doneSignal;

    public CustomTextWebSocketHandler(CountDownLatch doneSignal) {
    
    
        this.doneSignal = doneSignal;
    }

    public String getThirdPartyId() {
    
    
        return thirdPartyId;
    }

    public String getReplyContent() {
    
    
        return replyContent;
    }

    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
    
    
        log.info("connection established, session={}", session);
        replyContentBuilder = new StringBuilder(16);
//        super.afterConnectionEstablished(session);
    }

    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
    
    
        super.handleMessage(session, message);
    }

    /**
     * 消息已接收完毕("stop")
     */
    private static final String MESSAGE_DONE = "[DONE]";

    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
    
    
//        super.handleTextMessage(session, message);
        String payload = message.getPayload();
        log.info("payload={}", payload);
        OpenAiReplyResponse replyResponse = Jsons.fromJson(payload, OpenAiReplyResponse.class);
        if (replyResponse != null && replyResponse.isSuccess()) {
    
    
            String msg = replyResponse.getMsg();
            if (Strings.isEmpty(msg)) {
    
    
                return;
            } else if (msg.startsWith("【超出最大单次回复字数】")) {
    
    
                // {"msg":"【超出最大单次回复字数】该提示由GPT官方返回,非我司限制,请缩减回复字数","code":1,
                // "extParam":"{\"chatId\":\"10056:8889007174\",\"requestId\":\"b6af5830a5a64fa8a4ca9451d7cb5f6f\",\"bizId\":\"\"}",
                // "id":"chatcmpl-7LThw6J9KmBUOcwK1SSOvdBP2vK9w"}
                return;
            } else if (msg.startsWith("发送内容包含敏感词")) {
    
    
                // {"msg":"发送内容包含敏感词,请修改后重试。不合规汇如下:炸弹","code":1,
                // "extParam":"{\"chatId\":\"10024:8889006970\",\"requestId\":\"828068d945c8415d8f32598ef6ef4ad6\",\"bizId\":\"430\"}",
                // "id":"4d4106c3-f7d4-4393-8cce-a32766d43f8b"}
                matchSensitiveWords = msg;
                // 请求完成
                doneSignal.countDown();
                return;
            } else if (MESSAGE_DONE.equals(msg)) {
    
    
                // 消息已接收完毕
                replyContent = replyContentBuilder.toString();
                thirdPartyId = replyResponse.getId();
                // 请求完成
                doneSignal.countDown();
                log.info("replyContent={}", replyContent);
                return;
            }
            replyContentBuilder.append(msg);
        }
    }

    @Override
    protected void handlePongMessage(WebSocketSession session, PongMessage message) throws Exception {
    
    
        super.handlePongMessage(session, message);
    }

    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
    
    
        replyContentBuilder = null;
        log.info("handle transport error, session={}", session, exception);
        doneSignal.countDown();
//        super.handleTransportError(session, exception);
    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
    
    
        replyContentBuilder = null;
        log.info("connection closed, session={}, status={}", session, status);
        if (status == CloseStatus.NORMAL) {
    
    
            log.error("connection closed fail, session={}, status={}", session, status);
        }
        doneSignal.countDown();
//        super.afterConnectionClosed(session, status);
    }
}

OpenAiHandler


/**
 * Lifecycle hooks around an OpenAI request.
 *
 * <p>Both hooks do nothing unless an implementation overrides them.
 *
 * @param <Req> request type
 * @param <Rsp> response type
 */
public interface OpenAiHandler<Req, Rsp> {

    /**
     * Hook invoked before the request is sent.
     *
     * @param req the request about to be sent
     */
    default void beforeRequest(Req req) {
        // intentionally empty: no pre-processing by default
    }

    /**
     * Hook invoked after the response has been produced.
     *
     * @param req the request that was sent
     * @param rsp the response that was produced
     */
    default void afterResponse(Req req, Rsp rsp) {
        // intentionally empty: no post-processing by default
    }
}

OpenAiService


/**
 * OpenAI service contract.
 * <pre>
 * API reference introduction
 * https://platform.openai.com/docs/api-reference/introduction
 * </pre>
 *
 * @param <Req> request type
 * @param <Rsp> response type
 */
public interface OpenAiService<Req, Rsp> extends OpenAiHandler<Req, Rsp> {

    /**
     * Performs the actual completion call.
     *
     * @param req the request
     * @return the response
     */
    Rsp doCompletions(Req req);

    /**
     * Runs a completion: calls {@link #beforeRequest(Object)}, then
     * {@link #doCompletions(Object)}, then {@link #afterResponse(Object, Object)} with the
     * produced response, and finally returns that response.
     *
     * @param req the request
     * @return the response produced by {@link #doCompletions(Object)}
     */
    default Rsp completions(Req req) {
        beforeRequest(req);
        final Rsp response = doCompletions(req);
        afterResponse(req, response);
        return response;
    }
}

OpenAiServiceImpl


import cn.hutool.core.util.StrUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Configuration;
import org.springframework.stereotype.Service;
import org.springframework.util.StopWatch;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.WebSocketClient;

import javax.annotation.Nullable;
import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/**
 * OpenAI service implementation that exchanges messages with the remote endpoint over WebSocket.
 *
 * <p>Each completion opens its own WebSocket session, sends one signed request, waits for the
 * streamed reply to finish (or for the configured timeout), and closes the session again.
 */
@Slf4j
@Configuration(proxyBeanMethods = false)
@EnableConfigurationProperties(OpenAiProperties.class)
@Service("openAiService")
public class OpenAiServiceImpl implements OpenAiService<CompletionReq, CompletionRsp> {

    private static final String THREAD_NAME_PREFIX = "gpt.openai";

    private static final String URI_TEMPLATE = "wss://socket.******.com/websocket/{sessionId}";

    /**
     * Smallest per-request token limit that is honoured; lower values fall back to the
     * configured maximum.
     */
    private static final int MIN_TOKENS = 11;

    private final OpenAiProperties properties;
    /**
     * WebSocket client used for every completion.
     */
    private final WebSocketClient webSocketClient;
    /**
     * Service that persists model request records.
     */
    private final ModelRequestRecordService modelRequestRecordService;

    public OpenAiServiceImpl(
            OpenAiProperties properties,
            ModelRequestRecordService modelRequestRecordService
    ) {
        this.properties = properties;
        this.modelRequestRecordService = modelRequestRecordService;
        webSocketClient = WebSocketUtil.applyWebSocketClient(THREAD_NAME_PREFIX);
        log.info("create OpenAiServiceImpl instance");
    }

    /**
     * Assigns a generated request id when the request does not carry one.
     *
     * @param req the request
     */
    @Override
    public void beforeRequest(CompletionReq req) {
        if (Strings.isEmpty(req.getRequestId())) {
            req.setRequestId(UuidUtil.getUuid());
        }
    }

    /**
     * Persists a record of the exchange.
     *
     * <p>Nothing is stored when there is no response, the reply content is empty, or the
     * response carries a sensitive-word hit reported by the remote side.
     *
     * @param req the request
     * @param rsp the response, may be {@code null}
     */
    @Override
    public void afterResponse(CompletionReq req, CompletionRsp rsp) {
        if (rsp == null || Strings.isEmpty(rsp.getReplyContent())) {
            return;
        }
        // Sensitive words detected by the third party: the hit is already on the response.
        if (Strings.isNotEmpty(rsp.getMatchSensitiveWords())) {
            return;
        }
        // Per-stage timing.
        StopWatch stopWatch = new StopWatch(req.getRequestId());
        try {
            // The local sensitive-word check is currently disabled. Its StopWatch task must not
            // be started here: StopWatch rejects start() while another task is still running.
            // Persist the record.
            stopWatch.start("saveModelRequestRecord");
            ModelRequestRecord entity = applyModelRequestRecord(req, rsp);
            modelRequestRecordService.save(entity);
        } finally {
            if (stopWatch.isRunning()) {
                stopWatch.stop();
            }
            log.info("afterResponse execute time, {}", stopWatch);
        }
    }

    /**
     * Builds the record to persist from the request and its response.
     */
    private static ModelRequestRecord applyModelRequestRecord(
            CompletionReq req, CompletionRsp rsp) {
        Long orgId = req.getOrgId();
        Long userId = req.getUserId();
        String chatId = applyChatId(orgId, userId);
        return new ModelRequestRecord()
                .setOrgId(orgId)
                .setUserId(userId)
                .setModelType(req.getModelType())
                .setRequestId(req.getRequestId())
                .setBizId(req.getBizId())
                .setChatId(chatId)
                .setThirdPartyId(rsp.getThirdPartyId())
                .setInputMessage(req.getMessage())
                .setReplyContent(rsp.getReplyContent());
    }

    /**
     * @return the chat id in the form {@code orgId:userId}
     */
    private static String applyChatId(Long orgId, Long userId) {
        return orgId + ":" + userId;
    }

    /**
     * @return the session id in the form {@code appId_chatId}
     */
    private static String applySessionId(String appId, String chatId) {
        return appId + '_' + chatId;
    }

    /**
     * Sends the request over a new WebSocket session and waits for the complete reply.
     *
     * @param req the request
     * @return the response, or {@code null} when the connection could not be established, the
     *         message could not be sent, the wait was interrupted, or neither a reply nor a
     *         sensitive-word hit was received
     */
    @Nullable
    @Override
    public CompletionRsp doCompletions(CompletionReq req) {
        // Per-stage timing.
        StopWatch stopWatch = new StopWatch(req.getRequestId());
        stopWatch.start("doHandshake");

        // One-shot latch released by the handler when the exchange has finished or failed.
        CountDownLatch doneSignal = new CountDownLatch(1);
        CustomTextWebSocketHandler webSocketHandler = new CustomTextWebSocketHandler(doneSignal);
        String chatId = applyChatId(req.getOrgId(), req.getUserId());
        String sessionId = applySessionId(properties.getAppId(), chatId);
        ListenableFuture<WebSocketSession> listenableFuture = webSocketClient
                .doHandshake(webSocketHandler, URI_TEMPLATE, sessionId);
        stopWatch.stop();
        stopWatch.start("getWebSocketSession");
        long connectionTimeout = properties.getConnectionTimeout().getSeconds();
        try (WebSocketSession webSocketSession = listenableFuture.get(connectionTimeout, TimeUnit.SECONDS)) {
            stopWatch.stop();
            stopWatch.start("sendMessage");
            OpenAiParam param = applyParam(chatId, req);
            webSocketSession.sendMessage(new TextMessage(Jsons.toJson(param)));
            long requestTimeout = properties.getRequestTimeout().getSeconds();
            // Wait until the handler reports completion, or give up after the request timeout.
            boolean completed = doneSignal.await(requestTimeout, TimeUnit.SECONDS);
            if (!completed) {
                log.error("await doneSignal fail, req={}", req);
            }
            String replyContent = webSocketHandler.getReplyContent();
            String matchSensitiveWords = webSocketHandler.getMatchSensitiveWords();
            if (Strings.isEmpty(replyContent) && Strings.isEmpty(matchSensitiveWords)) {
                // Abnormal reply: nothing usable was received.
                return null;
            }
            // Strip the delimiters that applyParam wrapped around the user message.
            String delimiters = properties.getDelimiters();
            replyContent = StrUtil.replaceFirst(replyContent, delimiters, "");
            replyContent = StrUtil.replaceLast(replyContent, delimiters, "");
            String thirdPartyId = webSocketHandler.getThirdPartyId();
            return new CompletionRsp()
                    .setThirdPartyId(thirdPartyId)
                    .setReplyContent(replyContent)
                    .setMatchSensitiveWords(matchSensitiveWords);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            log.error("get WebSocketSession fail, req={}", req, e);
        } catch (ExecutionException | TimeoutException e) {
            log.error("get WebSocketSession fail, req={}", req, e);
        } catch (IOException e) {
            log.error("sendMessage fail, req={}", req, e);
        } finally {
            // Abandon a handshake nobody is waiting for any more; no effect once it has completed.
            if (!listenableFuture.isDone()) {
                listenableFuture.cancel(true);
            }
            if (stopWatch.isRunning()) {
                stopWatch.stop();
            }
            log.info("doCompletions execute time, {}", stopWatch);
        }
        return null;
    }

    /**
     * Limits the maximum number of reply tokens for a single request.
     *
     * @param reqMaxTokens limit asked for by the request
     * @param maxTokensConfig configured upper limit
     * @return {@code reqMaxTokens} when it lies within {@code [MIN_TOKENS, maxTokensConfig]},
     *         otherwise {@code maxTokensConfig}
     */
    private static int applyMaxTokens(int reqMaxTokens, int maxTokensConfig) {
        if (reqMaxTokens < MIN_TOKENS || maxTokensConfig < reqMaxTokens) {
            return maxTokensConfig;
        }
        return reqMaxTokens;
    }

    /**
     * Builds the signed request parameter sent over the WebSocket session.
     *
     * <p>The outgoing message is {@code prompt + delimiters + message + delimiters}; the
     * signature is computed over that message with the configured secret.
     */
    private OpenAiParam applyParam(String chatId, CompletionReq req) {
        OpenAiDataExtParam extParam = new OpenAiDataExtParam()
                .setChatId(chatId)
                .setRequestId(req.getRequestId())
                .setBizId(req.getBizId());
        // Prompt.
        String prompt = req.getPrompt();
        // Delimiters.
        String delimiters = properties.getDelimiters();
        String message = prompt + delimiters + req.getMessage() + delimiters;
        int maxTokens = applyMaxTokens(req.getMaxTokens(), properties.getMaxTokens());
        OpenAiData data = new OpenAiData()
                .setMsg(message)
                .setContext(properties.getContext())
                .setLimitTokens(maxTokens)
                .setExtParam(extParam);
        String sign = OpenAiUtil.applySign(message, properties.getSecret());
        return new OpenAiParam()
                .setData(data)
                .setSign(sign);
    }
}

WebSocketUtil


import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;

/**
 * WebSocket helper methods.
 */
public final class WebSocketUtil {

    /**
     * Creates a new {@link StandardWebSocketClient} backed by its own daemon-thread
     * {@link ThreadPoolTaskExecutor}.
     *
     * <p>The executor uses one core thread per available processor and a maximum pool size of
     * 200.
     *
     * @param threadNamePrefix prefix for the executor's thread names; when it has no text,
     *                         {@code "gpt.web.socket"} is used instead
     * @return the configured client
     */
    public static WebSocketClient applyWebSocketClient(String threadNamePrefix) {
        String prefix = StringUtils.hasText(threadNamePrefix) ? threadNamePrefix : "gpt.web.socket";

        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
        executor.setCorePoolSize(Runtime.getRuntime().availableProcessors());
        executor.setMaxPoolSize(200);
        executor.setDaemon(true);
        executor.setThreadNamePrefix(prefix);
        executor.initialize();

        StandardWebSocketClient client = new StandardWebSocketClient();
        client.setTaskExecutor(executor);
        return client;
    }
}

OpenAiUtil


import org.springframework.util.DigestUtils;

import java.nio.charset.StandardCharsets;

/**
 * OpenAI helper methods.
 */
public final class OpenAiUtil {

    /**
     * Computes the request signature: the MD5 digest of {@code message + secret}, taken over the
     * UTF-8 bytes of the concatenation and rendered as a hexadecimal string.
     *
     * @param message message content
     * @param secret signing secret
     * @return hexadecimal MD5 digest of the concatenated input
     */
    public static String applySign(String message, String secret) {
        return DigestUtils.md5DigestAsHex((message + secret).getBytes(StandardCharsets.UTF_8));
    }
}

参考资料

猜你喜欢

转载自blog.csdn.net/shupili141005/article/details/130811504