《模型部署》—— 客户端与服务端之间的交互实现模型的输出结果

客户端

  • 采用PyCharm的requests库中的方法实现客户端向服务端发送请求

  • 完整代码如下:

    import requests
    
    # 127.0.0.1:5012 ——> 127.0.0.1 ——> 自己电脑发送给自己的ip, 5012 ——> 端口号
    # predict 为服务端下返回请求结果的函数名
    flask_url = 'http://127.0.0.1:5012/predict'
    
    # 定义发送请求内容的函数
    def predict_result(image_path):
        image = open(image_path, 'rb').read()
        payload = {
          
          'image': image}
    
        r = requests.post(flask_url, files=payload).json()
        # 向flask_url服务端发送一个POST请求,并将返回的JSON响应解析为一个Python字典
    
        if r['success']:
            # 输出结果
            for (i, result) in enumerate(r['predictions']):
                print('{}.预测类别为{}:的概率:{}'.format(i + 1, result['label'], result['probability']))
        # 失败了就打印
        else:
            print('Request failed')
    
    
    if __name__ == '__main__':
        predict_result('./train/6/image_07162.jpg')
    

服务端

  • 此案例中,服务端的功能是对花朵的识别,当客户端发送一张花朵的图片给服务端,服务端将返回这张花朵图片的类别,以及置信度。

  • 通过PyCharm中的 flask 方法创建一个服务端应用程序

  • 注意:在客户端发送请求之前,服务端需提前启动程序,确保服务程序处在一个等待请求的状态

  • 完整代码如下

    import io
    import flask
    import torch
    import torch.nn.functional as F
    from PIL import Image
    from torch import nn
    from torchvision import transforms, models, datasets
    
    # 初始化 Flask app
    app = flask.Flask(__name__)  # 创建一个新的flask应用程序实例
    # __name__ 参数通常被传递给flask应用程序的根路径,这样Flask就可以知道在哪里找到模板、静态文件等
    # 总体来说 app = flask.Flask(__name__) 是Flask应用程序的起点,它初始化了一个新的Flask应用程序实例。为后续添加路由、配置等奠定基础
    
    model = None
    use_gpu = False
    
    # 部署resnet18网络对花朵识别训练后的最优模型
    def load_model():
        global model
        # 加载 resnet18 网络
        model = models.resnet18()
        num_ftrs = model.fc.in_features
        model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))  # 类别数根据实际要求定义
    
        # 调用最优模型
        checkpoint = torch.load('best.pth')
        model.load_state_dict(checkpoint['state_dict'])
        # 将模型指定为测试模式
        model.eval()
    
        # 是否使用GPU
        if use_gpu:
            model.cuda()
    
    
    # 数据预处理
    def prepare_image(image, target_size):
        # 针对不同模型,image的格式不同,但需要统一为RGB格式
        if image.mode != 'RGB':
            image = image.convert('RGB')
    
        # 按照所使用的模型输入图片的尺寸修改,并转为 tensor 类型
        image = transforms.Resize(target_size)(image)
        image = transforms.ToTensor()(image)
    
        # 数据均衡化,这里的参数和数据中是对应的
        image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
        # 增加一个维度,用于按 batch 测试  本次这里一次测试一张
        image = image[None]
        if use_gpu:
            image = image.cuda()
    
        return torch.tensor(image)
    
    
    # 是一个装饰器,用于将指定的 URL 路径(在这个例子中是 /predict)与一个函数关联起来, 并指定该函数响应的HTTP方法(在这个例子中是POST方法)
    @app.route("/predict", methods=["POST"])
    def predict():
        # 做一个标志,刚开始无图像传入时为 False 传入图片时为 True
        data = {
          
          "success": False}
    
        if flask.request.method == 'POST':  # 如果收到请求
    
            if flask.request.files.get("image"):  # 判断是否是图像
    
                image = flask.request.files["image"].read()  # 将收到的图像进行读取,内容为二进制
                image = Image.open(io.BytesIO(image))  # io.BytesIO
                # 图像预处理
                image = prepare_image(image, target_size=(224, 224))
    
                preds = F.softmax(model(image), dim=1)  # 得到各个类别的概率
                results = torch.topk(preds.cpu().data, k=3, dim=1)  # 获取概率最大的前三个结果
                # torch.topk 用于返回输入张量中每行最大的k个元素及对应的索引
                results = (results[0].cpu().numpy(), results[1].cpu().numpy())
                # 将data字典增加一个key, value,其中value为list格式
                data['predictions'] = list()
    
                for prob, label in zip(results[0][0], results[1][0]):
                    r = {
          
          "label": str(label), "probability": float(prob)}
                    data['predictions'].append(r)
                data['success'] = True
    
        return flask.jsonify(data)  # 最后结果以json格式文件输出,为了能让客户端以多种不同的编程语言来解析获取
    
    
    if __name__ == '__main__':
        print('Loading PyTorch model and Flask starting server ...')
        print('Please wait until server has fully started')
        # 调用部署模型的函数
        load_model()
        # 当客户端与服务端在同一台设备上时,只需给出端口号
        app.run(port='5012')    # 给出端口号,为了能让客户端从指定的端口发送请求
    
        # 当客户端与服务端不在同一台设备上时,给定服务端的 IPv4 地址,并指定端口号
        # 客户端通过服务端的 IPv4 地址,和端口号来发送请求
        # app.run(host='192.168.24.68', port='5012')
    

猜你喜欢

转载自blog.csdn.net/weixin_73504499/article/details/143275976