使用预训练好的 DALLE 模型进行 Text-to-Image 任务
Hugging Face 文档:https://huggingface.co/kuprel/min-dalle
安装库:
pip install min-dalle
本文使用的库:
import torch
from min_dalle import MinDalle
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
加载模型:
model = MinDalle(
models_root='./pretrained', # 预训练模型的保存地址, 运行代码时自动从网上下载到这里, 即使该地址不存在都没事
# 首次运行时需等待较长时间, 因为从网上下载预训练好的模型需要一些时间
dtype=torch.float32,
device='cuda',
is_mega=True, # True 表示使用dalle-mega, 大模型, 效果更好, 占用显存也多
# False表示使用dalle-mini, 小模型
is_reusable=True
)
生成图像:
images = model.generate_images(
text='Objects in the photo: Dessert, Fast food, Snack, Drink', # 文本
seed=-1,
grid_size=3, # 最终生成的图像为 grid_size*grid_size 个
is_seamless=False,
temperature=1,
top_k=256, # 从生成的 top-k 个中再选择最贴合文本的 grid_size*grid_size 张图像
supercondition_factor=16,
is_verbose=False
)
显示并保存生成的图片:
images = images.to('cpu').numpy() # images.shape = (grid_size^2, 256, 256, 3)
# 显示图片
for i in range(images.shape[0]):
image = Image.fromarray(np.uint8(images[i]))
plt.subplot(3, 3, i+1) #表示第i张图片,下标只能从1开始,不能从0
plt.imshow(image)
plt.axis('off') # 去掉横纵坐标
plt.show()
# 保存图片
for i in range(images.shape[0]):
image = Image.fromarray(np.uint8(images[i]))
image.save('image_{}.png'.format(i)) # 保存地址
使用 dalle-mini 生成图像的示例(因为显存不够所以用的是 dalle-mini,而且没有 fine-tune,所以效果并不是很好):
完整代码:
https://github.com/friedrichor/Text-to-Image-Summary/blob/main/demo/DALLE.ipynb