使用预训练好的 DALLE 模型进行 Text-to-Image 生成图像

使用预训练好的 DALLE 模型进行 Text-to-Image 任务

Hugging Face 文档:https://huggingface.co/kuprel/min-dalle

安装库:

pip install min-dalle

本文使用的库:

import torch
from min_dalle import MinDalle

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

加载模型:

model = MinDalle(
    models_root='./pretrained',  # 预训练模型的保存地址, 运行代码时自动从网上下载到这里, 即使该地址不存在都没事
                                 # 首次运行时需等待较长时间, 因为从网上下载预训练好的模型需要一些时间
    dtype=torch.float32,
    device='cuda',
    is_mega=True,  # True 表示使用dalle-mega, 大模型, 效果更好, 占用显存也多
                   # False表示使用dalle-mini, 小模型
    is_reusable=True
)

生成图像:

images = model.generate_images(
    text='Objects in the photo: Dessert, Fast food, Snack, Drink',  # 文本
    seed=-1,
    grid_size=3,               # 最终生成的图像为 grid_size*grid_size 个
    is_seamless=False,
    temperature=1,
    top_k=256,                 # 从生成的 top-k 个中再选择最贴合文本的 grid_size*grid_size 张图像
    supercondition_factor=16,
    is_verbose=False
)

显示并保存生成的图片:

images = images.to('cpu').numpy()  # images.shape = (grid_size^2, 256, 256, 3)

# 显示图片
for i in range(images.shape[0]):
    image = Image.fromarray(np.uint8(images[i]))
    plt.subplot(3, 3, i+1)  #表示第i张图片,下标只能从1开始,不能从0
    plt.imshow(image)
    plt.axis('off')  # 去掉横纵坐标
plt.show()

# 保存图片
for i in range(images.shape[0]):
    image = Image.fromarray(np.uint8(images[i]))
    image.save('image_{}.png'.format(i))  # 保存地址

使用 dalle-mini 生成图像的示例(因为显存不够所以用的是 dalle-mini,而且没有 fine-tune,所以效果并不是很好):

完整代码:

https://github.com/friedrichor/Text-to-Image-Summary/blob/main/demo/DALLE.ipynb

猜你喜欢

转载自blog.csdn.net/Friedrichor/article/details/128086733