Python使用AI animegan2-pytorch制作属于你的漫画头像/风景图片

Python使用AI animegan2-pytorch制作属于你的漫画头像

git clone https://github.com/bryandlee/animegan2-pytorch
cd ./animegan2-pytorch
python test.py --checkpoint weights/face_paint_512_v2.pt --input_dir samples/inputs --output_dir samples/results --device cpu

1. 效果图

官方效果图如下:

效果图v2 512模型如下:
在这里插入图片描述

效果图v1 512模型如下:
在这里插入图片描述

v1 模型效果不太理想的效果图如下:
在这里插入图片描述

效果图celeba_distill 模型如下:
人物会有一种病态的美,过于白了,风景上效果更好一些;
人物与photo2cartoon的效果图有点像;
在这里插入图片描述

在这里插入图片描述

效果图paprika 模型如下
人物纹理痕迹太过明显,更适合风景
下一张明兰的效果还不错,不同的模型在不同的图像上也会有些微差别;
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

origin vs v1Res vs v2Res vs paprikaRes vs celedistillResAll 风景效果对比图如下:

在这里插入图片描述
在这里插入图片描述

origin vs v1Res vs v2Res vs paprikaRes vs celedistillResAll 人物效果对比图如下:
在这里插入图片描述
在这里插入图片描述

2. 原理

人像/风景卡通风格渲染的目标是,在保持原图像 ID 信息和纹理细节的同时,将真实照片转换为卡通风格的非真实感图像。

3. 源码

源码及示例文件模型等见资源:https://download.csdn.net/download/qq_40985985/87739198

# animegan2-pytorch 生成漫画头像或者风景图
# python test.py --checkpoint weights/face_paint_512_v2.pt --input_dir samples/faces/ --device cpu --output_dir samples/resv2
# model loaded: weights/face_paint_512_v2.pt

import os
import argparse

from PIL import Image
import numpy as np

import torch
from torchvision.transforms.functional import to_tensor, to_pil_image

from model import Generator


# Global cuDNN backend switches, set once at import time before any model runs:
# cuDNN itself is turned off, its autotuner ("benchmark") is turned off, and
# the deterministic-algorithms flag is turned on.
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


def load_image(image_path, x32=False):
    """Open an image file and return it as an RGB ``PIL.Image``.

    Args:
        image_path: Path of the image file to read.
        x32: When true, resize the image so each side is a multiple of 32;
            any side shorter than 256 pixels is set to 256.

    Returns:
        The RGB image, resized when ``x32`` is set.
    """
    # Release the file handle as soon as the pixels are decoded instead of
    # leaving it open until garbage collection; convert() hands back a
    # separate image object that stays valid after the source is closed.
    with Image.open(image_path) as source:
        img = source.convert("RGB")

    if x32:
        def to_32s(x):
            # Round down to a multiple of 32, with a floor of 256.
            return 256 if x < 256 else x - x % 32
        w, h = img.size
        img = img.resize((to_32s(w), to_32s(h)))

    return img


def test(args):
    """Run the generator over every supported image in ``args.input_dir``.

    Each result is written to ``args.output_dir`` under the input's file name.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``device``, ``checkpoint``, ``input_dir``, ``output_dir``,
            ``upsample_align`` and ``x32``.
    """
    device = args.device

    net = Generator()
    # Deserialize the weights onto the CPU first, then move the whole network
    # to the requested device.
    net.load_state_dict(torch.load(args.checkpoint, map_location="cpu"))
    net.to(device).eval()
    # NOTE: the replacement field must stay on one line; a line break inside
    # the braces is a SyntaxError on Python < 3.12.
    print(f"model loaded: {args.checkpoint}")

    os.makedirs(args.output_dir, exist_ok=True)

    supported_extensions = (".jpg", ".png", ".bmp", ".tiff")

    for image_name in sorted(os.listdir(args.input_dir)):
        # Skip anything that is not one of the supported image types.
        if os.path.splitext(image_name)[-1].lower() not in supported_extensions:
            continue

        image = load_image(os.path.join(args.input_dir, image_name), args.x32)

        with torch.no_grad():
            # Add a batch dimension and rescale the tensor with x * 2 - 1.
            image = to_tensor(image).unsqueeze(0) * 2 - 1
            out = net(image.to(device), args.upsample_align).cpu()
            # Drop the batch dimension, clamp to [-1, 1] and map back to [0, 1].
            out = out.squeeze(0).clip(-1, 1) * 0.5 + 0.5
            out = to_pil_image(out)

        out.save(os.path.join(args.output_dir, image_name))
        print(f"image saved: {image_name}")


def str2bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``"False"`` -- becomes ``True``. This
    converter interprets the text instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``, ``"False"``, ``"1"``, ``"0"``, ``"yes"``, ``"no"``.

    Returns:
        The boolean the text stands for. An empty string yields ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--checkpoint',
        type=str,
        default='./weights/paprika.pt',
    )
    parser.add_argument(
        '--input_dir',
        type=str,
        default='./samples/inputs',
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        default='./samples/results',
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cuda:0',
    )
    parser.add_argument(
        '--upsample_align',
        # str2bool instead of bool: "--upsample_align False" must mean False.
        type=str2bool,
        # A bare "--upsample_align" (no value) is accepted and means True.
        nargs='?',
        const=True,
        default=False,
        help="Align corners in decoder upsampling layers"
    )
    parser.add_argument(
        '--x32',
        action="store_true",
        help="Resize images to multiple of 32"
    )
    args = parser.parse_args()

    test(args)

# Draw "original vs. stylised result" comparison images.
# python plot_sample.py

import os

import cv2
import imutils
import numpy as np
from imutils import paths

# Only input images whose path contains this name are compared.
TARGET_NAME = "ml2.jpg"
# Width in pixels that every side-by-side pair is resized to.
PAIR_WIDTH = 300
# Title of the window showing all stacked pairs.
FINAL_TITLE = 'origin vs v1Res vs v2Res vs paprikaRes vs celedistillResAll'
# Substrings selecting each result group, in the order the pairs are stacked.
RESULT_KEYS = ("resv1", "resv2", "paprika", "cele")


def _model_label(result_path):
    """Return the result directory name with "res" removed (e.g. "resv1" -> "v1")."""
    return os.path.basename(os.path.dirname(result_path)).replace("res", "")


def _side_by_side(origin_path, result_path):
    """Return origin and result joined horizontally, or None if either is unreadable."""
    origin = cv2.imread(origin_path)
    res = cv2.imread(result_path)
    if origin is None or res is None:
        # cv2.imread signals failure by returning None rather than raising.
        print('skipped unreadable image pair:', origin_path, result_path)
        return None
    # np.hstack needs equal heights, so bring the result to the origin's size.
    if origin.shape[:2] != res.shape[:2]:
        res = cv2.resize(res, (origin.shape[1], origin.shape[0]))
    return imutils.resize(np.hstack([origin, res]), width=PAIR_WIDTH)


def main():
    """Show the target input next to each model's output, stacked vertically."""
    # Collect every image found under the samples directory.
    image_paths = sorted(paths.list_images("samples"))

    input_paths = [p for p in image_paths if p.find('inputs') > 0]
    print(input_paths)

    result_groups = [
        [p for p in image_paths if p.find(key) > 0] for key in RESULT_KEYS
    ]
    last_group = len(result_groups) - 1

    for origin_path in input_paths:
        if origin_path.find(TARGET_NAME) < 0:
            continue
        # Compare by file name using os.path so the script is not tied to
        # the Windows "\\" separator.
        origin_name = os.path.basename(origin_path)
        stacked = None
        for group_index, group in enumerate(result_groups):
            for result_path in group:
                if os.path.basename(result_path) != origin_name:
                    continue
                pair = _side_by_side(origin_path, result_path)
                if pair is None:
                    continue
                label = _model_label(result_path)
                # Only the first group (resv1) gets its own preview windows.
                if group_index == 0:
                    cv2.imshow('origin vs ' + label + 'Res', pair)
                stacked = pair if stacked is None else np.vstack([stacked, pair])
                if group_index == 0:
                    cv2.imshow('origin vs ' + label + 'ResAll', stacked)
                # The last group (cele) shows the full stack and waits for a key.
                if group_index == last_group:
                    cv2.imshow(FINAL_TITLE, stacked)
                    cv2.waitKey(0)


if __name__ == '__main__':
    main()

参考

猜你喜欢

转载自blog.csdn.net/qq_40985985/article/details/130379775