gradio入门示例

    随着chat-gpt等机器人对话框架的流行,让一个名为gradio的框架也火热起来,这个框架可以开启一个http服务,并且带输入输出界面,可以让对话类的人工智能项目快速运行。

    gradio号称可以快速部署ai可视化项目。

    下面通过两个示例来感受一下,首先我们需要安装gradio库。

pip install gradio

    接着编写如下的代码,用户输入一个字符串xxx,提交之后,输出一个hello,xxx 。

import gradio as gr


def hello(name):
    return "hello," + name + "!"


def launch():
    demo = gr.Interface(fn=hello, inputs='text', outputs='text')
    demo.launch()


if __name__ == '__main__':
    launch()

    运行这段代码,可以开启7860端口监听http服务, 浏览器访问http://localhost:7860,可以打开如下界面:

     再编写一个示例,是关于图像识别的,代码如下:

import torch
from PIL import Image
from torchvision import transforms
import gradio as gr
import json

with open('imagenet-simple-labels.json', 'r') as load_f:
    labels = json.load(load_f)
model = torch.hub.load("pytorch/vision:v0.6.0", "resnet18", pretrained=True).eval()


def predict(inp):
    inp = Image.fromarray(inp.astype("uint8"), "RGB")
    inp = transforms.ToTensor()(inp).unsqueeze(0)
    with torch.no_grad():
        prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
    return {labels[i]: float(prediction[i]) for i in range(1000)}


inputs = gr.Image()
outputs = gr.Label(num_top_classes=3)
demo = gr.Interface(fn=predict, inputs=inputs, outputs=outputs)

if __name__ == '__main__':
    demo.launch()


    运行代码,会下载pytorch/vision:v0.6.0版本,并下载一个resnet18的模型文件:resnet18-f37072fd.pth到用户目录下的.cache\torch\hub\checkpoints\目录下。

    运行打印信息如下:

    我们打开浏览器http://localhost:7860,在界面上选择我们事先准备好的豹子和狗的图片:

    这里识别了豹子,显示cheetah。

    换一只狗的再试一下:

    识别结果为一只拉布拉多。

    代码中设置了三个最可能的结果,outputs = gr.Label(num_top_classes=3),所以这里会列出最有可能的三种情况。

     以上代码运行的时候报了警告:

D:\Program Files\Python\Python310\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
D:\Program Files\Python\Python310\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)

    意思是 torch.hub.load加载模型的时候,pretrained参数过时了,可以使用weights=ResNet18_Weights.DEFAULT替代。

    修改代码之后,就不报警告了。如下所示:

    官网的例子,文中有个文件来自https://git.io/JJkYN,现在已经无法下载了,但是它可以直接在github找到:https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json

    这里就是提前下载,然后通过json读取,内容是1000个目标标签。 

猜你喜欢

转载自blog.csdn.net/feinifi/article/details/130816960
今日推荐