代码来源:https://github.com/1jsingh/paint2pix
试玩demo网址:http://exposition.cecs.anu.edu.au:6009/
一、demo.py文件主要利用streamlit设计网页:
def display_alongside_batch(img_list, resize_dims=(256,256)):
    """Resize every image to ``resize_dims`` and tile them left to right.

    Args:
        img_list: Iterable of image objects exposing ``.resize`` (used here
            like PIL images).
        resize_dims: Size tuple handed unchanged to each image's ``resize``.

    Returns:
        The image built by ``Image.fromarray`` from the concatenated arrays.
    """
    # NOTE(review): other call sites in this file pass resize_dims[::-1] to
    # .resize; here the tuple is used as-is. Confirm the intended order for
    # non-square sizes.
    tiles = []
    for img in img_list:
        tiles.append(np.array(img.resize(resize_dims)))
    # axis=1 is the column axis, so the tiles end up in a single row.
    strip = np.concatenate(tiles, axis=1)
    return Image.fromarray(strip)
把img_list里的每张图片(PIL图片,而非tensor)缩放到resize_dims后转成数组,沿宽度方向(axis=1)拼接,再转回图片格式,即几张图横着连放在一起。
def main():
    """Seed the session state, then render the demo page picked in the sidebar."""
    # Create the per-session keys only if an earlier run has not done so.
    for key, initial in (('button_id', ''), ('color_to_label', {})):
        if key not in st.session_state:
            st.session_state[key] = initial

    # Every page is served by the same function; the lower-cased page name
    # is what tells it which variant to render.
    pages = {
        "Real image editing": paint2pix_demo,
        "Progressive image synthesis": paint2pix_demo,
        "Artistic content generation": paint2pix_demo,
    }

    st.sidebar.subheader("Paint2Pix Demos")
    chosen = st.sidebar.selectbox("Demo:", options=list(pages))
    render = pages[chosen]
    render(chosen.lower())
import streamlit as st
1、st.session_state:一种在重新运行之间共享变量的方式,用于每个用户会话。除了存储和保持状态,还可以使用回调(Callbacks)来操作状态。读取状态并展示用st.write.
2、st.sidebar:把组件放在侧边栏,组件加入方式用st.sidebar.[element_name]/with st.sidebar:st.[element_name].
3、st.title、st.header、subheader:文本以title/header/subheader三级格式展示,使用方式st.subheader("text input here").
4、st.selectbox(label, options, index=0, format_func=special_internal_function, key=None, help=None, on_change=None, args=None, kwargs=None, *, disabled=False):选择组件,其参数label放选择提示的字符串;options放选项,字符串型;index预选选项;format_func用于修改标签显示的函数...
# Streamlit page shared by all three demos. demo_type is the lower-cased page
# name; it selects the header text shown below.
def paint2pix_demo(demo_type='real image editing'):
st.sidebar.header("Configuration")
# Specify canvas parameters in application
experiment_type = 'ffhq'
# The demo currently appears to offer only the ffhq experiment.
if experiment_type == 'ffhq':
resize_dims = (256,256)
elif experiment_type == 'cars_encode':
resize_dims = (192,256)
st.sidebar.markdown('---')
# User-adjustable brush width, opacity, colour and drawing tool.
st.sidebar.header('User brush stroke parameters:')
stroke_width = st.sidebar.slider("Stroke width: ", 1, 100, 40)
stroke_opacity = st.sidebar.slider("Stroke opacity: ", 0, 100, 80)
stroke_color = st.sidebar.color_picker("Stroke color hex: ", "#E8BEAC")
# Turn the hex colour plus the 0-100 opacity slider into an rgba(...) string.
r,g,b = ImageColor.getcolor(stroke_color, "RGB")
stroke_color_with_opacity = "rgba({},{},{},{})".format(r,g,b,stroke_opacity/100.0)
bg_color = st.sidebar.color_picker("Background color hex: ", "#000000")
drawing_mode = st.sidebar.selectbox(
"Drawing tool:", ("freedraw", "line", "rect", "circle", "transform", "polygon")
)
realtime_update = st.sidebar.checkbox("Update in realtime", True)
st.sidebar.subheader('Restyle Config:')
restyle_iter = st.sidebar.slider("restyle iter: ", 1, 10, 5)
# The selectable iteration is capped at (and defaults to) restyle_iter.
restyle_select_iter = st.sidebar.slider("restyle select iter: ", 1, restyle_iter, restyle_iter)
# Multi-modal output is switched off by this hard-coded flag.
multi_modal = False
num_multi_output = 5
st.markdown(
"""
Paint2Pix: Interactive Painting based Image Synthesis and Editing
"""
)
st.markdown("""<img src="https://1jsingh.github.io/assets/publications/images/paint2pix.png" alt="Streamlit logo" height="280"><br><br>""", unsafe_allow_html=True)
# Header and intro text for each of the three pages.
if demo_type=='real image editing':
st.header('Real Image Editing')
st.markdown("Start editing by using a real image input or create your own image using brushstroke inputs")
elif demo_type=='progressive image synthesis':
st.header('Progressive Image Synthesis')
st.markdown("Express your inner ideas ... synthesize your desired output image using just coarse scribbles")
elif demo_type=='artistic content generation':
st.header('Artistic Content Generation')
st.markdown("Unleash your inner artist ... create high artistic paintings using just coarse scribbles")
5、st.markdown:用markdown语法写入需要显示的文本
6、st.slider(label, min_value=None, max_value=None, value=None, step=None, format=None, key=None, help=None, on_change=None, args=None, kwargs=None, *, disabled=False):滑块部件,value是预先选择的值。
7、st.color_picker(label, value=None, key=None, help=None, on_change=None, args=None, kwargs=None, *, disabled=False):选颜色的部件。
8、st.checkbox勾选框
from PIL import ImageColor
ImageColor.getcolor()把字符串型输入的颜色转为rgb三通道0~255的数值。
# choose input image from a list of custom image inputs for real image editing
if demo_type=='real image editing':
# Per the original note, the input-images folder holds 5 face images.
input_image_list = sorted(list(glob.glob('input-images/*')))
input_image_options = [Image.open(x).resize((128,128)) for x in input_image_list]
st.image(input_image_options, caption = ["Input image {}".format(i+1) for i in range(len(input_image_list))])
selected_image = st.selectbox("Select image input:",["Input image {}".format(i+1) for i in range(len(input_image_list))]+["Custom Input"])
else:
selected_image = "Paint from scratch"
# Depending on the selection, bg_image is an uploaded file, a path from
# input_image_list, or None.
bg_image_container = st.empty()
if selected_image == "Custom Input":
bg_image = bg_image_container.file_uploader("Input image:", type=["png", "jpg"])
elif selected_image.startswith('Input image'):
# NOTE(review): only the last character of the label is parsed, so this looks
# wrong once there are 10 or more input images — confirm.
bg_image = input_image_list[int(selected_image[-1])-1]
else:
bg_image = None
# Normalise a background source for comparison: None and str pass through,
# anything else is read via .getvalue().
get_value = lambda x: x if x is None or isinstance(x,str) else x.getvalue()
# using canvas prediction as initialization of canvas
# NOTE(review): the guard tests the 'bg_img' key but reads
# st.session_state.bg_image — confirm both are always set together.
if 'bg_img' not in st.session_state or get_value(st.session_state.bg_image) != get_value(bg_image):
use_canvas_pred = False
if 'bg_img' in st.session_state and get_value(st.session_state.bg_image) != get_value(bg_image):
# Reset Session state
for key in st.session_state.keys():
del st.session_state[key]
time.sleep(1)
st.session_state.bg_image = bg_image
# Use the stored canvas_pred result as the background image.
# NOTE(review): use_canvas_pred is only assigned in the branch above and
# canvas_frame_id is not defined anywhere in the visible code — confirm.
if bg_image is None and 'canvas_pred' in st.session_state and use_canvas_pred:
canvas_pred = cv2.resize(st.session_state['canvas_pred'][canvas_frame_id-1],resize_dims[::-1])
bg_img = Image.fromarray(canvas_pred)
elif bg_image is not None:
# When bg_image is not None, load it into bg_img and into
# session_state['prediction'].
bg_img = Image.open(bg_image).convert('RGB')
st.session_state['prediction'] = Image.open(bg_image).convert('RGB')
else:
bg_img = None
# With neither a selected image nor a canvas_pred result, fall back to an
# all-zero (black) image.
if bg_img is None:
bg_img_ = np.zeros((resize_dims[0],resize_dims[1],3)).astype('float32')
# NOTE(review): hard-coded (256,256) rather than resize_dims — confirm intended.
bg_img = Image.new('RGB',(256,256))
else:
# bg_img_ is the float32 copy scaled by 1/255 and resized with cv2.
bg_img_ = np.float32(bg_img)/255.#.convert('RGBA')
bg_img_ = cv2.resize(bg_img_,resize_dims[::-1])
st.session_state.bg_img = bg_img
st.session_state.bg_img_ = bg_img_
st.session_state.real_image_input = bg_img
9、st.image展示一张图或一个列表的图,传入的图可以是PIL.Image.open打开的格式。
10、st.empty()一个单一元素的容器,可以插入/替换/清除一个元素。
11、st.file_uploader(label, type=None, accept_multiple_files=False, key=None, help=None, on_change=None, args=None, kwargs=None, *, disabled=False)是上传文件的部件,最大200MB,type是字符串/字符串列表型参数,用以注明文件上传的后缀;accept_multiple_files=True则可以一次性上传多个文件。
import glob
glob.glob()返回所有匹配的文件路径列表
lambda用于定义匿名函数,格式为lambda x: 关于x的表达式,表达式的值即为返回值。例如上面的get_value:x为None或字符串时原样返回,否则返回x.getvalue()。
# Two-column layout: this fragment fills the left column with the canvas.
col1, col2 = st.columns(2)
with col1:
st.subheader('Canvas')
# Stylized output is requested only on the artistic content generation page.
stylized_output = True if demo_type=='artistic content generation' else False
use_image_pred_as_input = st.button('Use Image Prediction as Canvas')
# NOTE(review): the guard checks for 'prediction' but the body reads
# id_constrained_pred and output_latent from the session state — confirm those
# keys always exist when this branch runs.
if use_image_pred_as_input and 'prediction' in st.session_state:
# Promote the stored identity-constrained prediction to the new background and
# carry the output latent over as the new input latent.
bg_img = st.session_state.id_constrained_pred
bg_img_ = np.float32(bg_img)/255.
bg_img_ = cv2.resize(bg_img_,resize_dims[::-1])
st.session_state.bg_img = bg_img
st.session_state.bg_img_ = bg_img_
st.session_state.input_latent = st.session_state.output_latent
st.session_state.real_image_input = bg_img
st.session_state.input_latent_ = st.session_state.input_latent
# Create a canvas component
canvas_result = st_canvas(
fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
stroke_width=stroke_width,
stroke_color=stroke_color_with_opacity,
background_color=bg_color,
background_image=st.session_state.bg_img,
update_streamlit=realtime_update,
height=resize_dims[0],
width=resize_dims[1],
drawing_mode=drawing_mode,
display_toolbar=st.sidebar.checkbox("Display drawing toolbar", True),
key="full_app",
)
12、st.columns(spec, *, gap="small"):插入spec个并排排列的容器,gap是列之间的距离,有"small", "medium", or "large"三种选择。通过with col1:这种方式在容器上加入部件。
13、st.button:一个按键,按下则返回True,否则为False.
14、from streamlit_drawable_canvas import st_canvas,在网页加一个手写输入面板。
with col2:
# image completion predictions
st.subheader('Image Prediction')
id_constrain = False
if True:
# Every cached image starts out as an all-black placeholder.
if 'prediction' not in st.session_state:
st.session_state['prediction'] = Image.fromarray(np.zeros((resize_dims[0],resize_dims[1],3)).astype('uint8'))
if 'multi_modal_prediction' not in st.session_state:
st.session_state['multi_modal_prediction'] = Image.fromarray(np.zeros((resize_dims[0],resize_dims[1]*7,3)).astype('uint8'))
if 'canvas_pred' not in st.session_state:
st.session_state['canvas_pred'] = np.zeros((100,resize_dims[0],resize_dims[1],3)).astype('uint8')
if 'restyle_prediction' not in st.session_state:
st.session_state['restyle_prediction'] = np.zeros((10,resize_dims[0],resize_dims[1],3)).astype('uint8')
if 'canvas' not in st.session_state:
st.session_state['canvas'] = Image.fromarray(np.zeros((resize_dims[0],resize_dims[1],3)).astype('uint8'))
img_pred_button = st.button('Predict')
if img_pred_button:
with st.spinner('Computing image completion prediction ...'):
# Per the original note, net here is the canvas encoder.
net, transform, opts = load_model(experiment_type,id_constrain,stylized_output=stylized_output)
# If the canvas returned no image data, input_image is bg_img.
if canvas_result.image_data is None:
input_image = bg_img
else:
# Scale the canvas data to 0..1 and use its last channel as the alpha map.
painting = np.float32(canvas_result.image_data)/255.
alpha = cv2.cvtColor(painting[:,:,-1],cv2.COLOR_GRAY2RGB)
painting_fg = painting[:,:,:3]
# Alpha-blend the painted foreground over the stored background image.
foreground = cv2.multiply(alpha, painting_fg)
background = cv2.multiply(1.0 - alpha, st.session_state.bg_img_)
outImage = cv2.add(foreground, background)
input_image = Image.fromarray(np.uint8(255*outImage))
st.session_state['canvas'] = input_image
# NOTE(review): mixing_layers and mix_alpha are not defined in the visible
# code; this branch is only reached when multi_modal is True — confirm.
if multi_modal:
latent_mask = list(range(mixing_layers[0],mixing_layers[1]+1))
mix_alpha = mix_alpha/100
else:
latent_mask, mix_alpha = None, None
# Key step: run the image completion prediction.
result_images, multi_out, latents = predict_image_completion(input_image, net, transform, opts, multi_modal=multi_modal, experiment_type=experiment_type, resize_dims=resize_dims,num_multi_output=num_multi_output, n_iters=restyle_iter,latent_mask=latent_mask,mix_alpha=mix_alpha, id_constrain=id_constrain, target_id_feat=None)
st.session_state.restyle_prediction = result_images
# Keep only the first batch entry of the returned latents.
st.session_state.latents = latents[0]
if multi_modal:
st.session_state['multi_modal_prediction'] = multi_out
# NOTE(review): select_prediction is not defined in the visible code — confirm.
if multi_modal:
bbox = (select_prediction*resize_dims[1],0,(select_prediction+1)*resize_dims[1],resize_dims[0])
st.session_state.prediction = st.session_state['multi_modal_prediction'].crop(bbox)
else:
st.session_state.prediction = st.session_state.restyle_prediction[restyle_select_iter-1]
# session_state 'prediction' now holds either the restyle_prediction result or
# the crop of the multi_modal_prediction result.
output_image_placeholder = st.empty()
15、st.spinner(text="In progress..."):在执行某段代码时临时显示文本,使用方式是with st.spinner(text="In progress..."):执行代码块。
# Expander with the edit-strength controls (clipping threshold, scale factor
# and the range of editable layers).
with st.expander('Control Edit Strength'):
st.markdown("The user can use extrapolation of edit strength in order to achieve desired output attributes. This is helpful in achieving semantic image edits *e.g.,* aging which are otherwise difficult to describe using just coarse user scribbles.")
col1, col2 = st.columns(2)
with col1:
threshold_beta = st.number_input('threshold-beta',value=1.0,step=0.5)
with col2:
edit_alpha = st.number_input('edit-alpha',value=1.0,step=0.2)
num_style_layers = 18 if experiment_type=='ffhq' else 16
# Editable StyleGAN layers; the preselected range covers every layer
# (0..17 for ffhq).
editable_layers_global = st.slider('Editable StyleGAN layers:',0, num_style_layers-1, (0,num_style_layers-1))
if img_pred_button:
st.session_state.output_latent = st.session_state.latents[restyle_select_iter-1]
# On the first prediction the output latent also becomes the input latent.
if 'input_latent' not in st.session_state:
st.session_state.input_latent = st.session_state.input_latent_ = st.session_state.output_latent
# Build a 0/1 mask over the layers with ones on the selected range.
mask = np.arange(editable_layers_global[0],editable_layers_global[1]+1)
latent_mask = np.zeros(num_style_layers)
latent_mask[mask] = 1.0
# Latent edit: clip the change to +/- threshold_beta, zero it on unselected
# layers, then scale it by edit_alpha and add it to the input latent.
delta_w = st.session_state.output_latent - st.session_state.input_latent_
delta_w = np.expand_dims(latent_mask,-1) * np.clip(delta_w,-threshold_beta,threshold_beta)
w0 = st.session_state.input_latent
w1 = w0 + delta_w * edit_alpha
# Generate the image from the edited latent.
# NOTE(review): experiment_type and resize_dims are hard-coded here instead of
# using the local variables — confirm intended.
edited_img = decode_latent(w1, net, opts, experiment_type='ffhq', resize_dims=(256,256), truncation=1.)
st.session_state.edited_img = edited_img
st.session_state.output_latent = w1
st.image([st.session_state.prediction.resize((200,200)),st.session_state.edited_img.resize((200,200))],["before","after"])
16、st.expander插入一个可以展开/折叠的多元素容器,折叠时只显示label.
17、st.number_input(label, min_value=None, max_value=None, value=NoValue(), step=None, format=None, key=None, help=None, on_change=None, args=None, kwargs=None, *, disabled=False):输入数字的部件,参数value是预先显示值,step是增减数字时的步长。
np.expand_dims(a,axis)将a在axis维上升维(插入一个长度为1的新维度)。
np.clip用于截取数组中小于或者大于某值的部分,所有比最小值小的数都会强制变为最小值,
所有比最大值大的数都会强制变为最大值。
# Expander with the identity-correction controls and the correction itself.
with st.expander('Identity Correction Config'):
num_id_gd_iter = st.slider('Number of iterations:',0,1000,20)
col1, col2 = st.columns(2)
with col1:
lambda_reg = st.number_input('lambda id-reg',value=0.01)
with col2:
editable_layers = st.slider('StyleGAN Editable Layers:',0, 17, (0,8))
# 0/1 mask over 18 layers with ones on the selected range.
mask = np.arange(editable_layers[0],editable_layers[1]+1)
latent_mask = np.zeros(18)
latent_mask[mask] = 1.0
col1, col2 = st.columns(2)
with col1:
use_encoder = st.checkbox('Use Identity Encoder',True)
with col2:
num_id_enc_iter = st.slider('Number of identity-encoding steps:',1,5,1)
if img_pred_button:
id_loss_func = load_faceid_model()
x = st.session_state.output_latent
target_img = st.session_state.real_image_input
if use_encoder:
original_img = st.session_state.edited_img
id_net,_,_ = load_model(experiment_type,id_constrain=True,stylized_output=stylized_output)
# Use the identity encoder.
id_constrained_pred, delta_w = encoder_based_id_edit(original_img, x, target_img, id_net, transform, opts, latent_mask=latent_mask, num_id_iter=num_id_enc_iter)
else:
id_constrained_pred, loss_log, delta_w = identity_constrained_latent_pred(x, target_img, net, transform, opts, id_loss_func, input_code=True, n_iter=num_id_gd_iter, lr=5e-3, lambda_id=1.0, lambda_reg=lambda_reg,latent_mask=latent_mask, lambda_l2=1e1)
# NOTE(review): loss_log is only produced by the non-encoder call above —
# confirm this assignment is not reached when use_encoder is True.
st.session_state.id_constrained_loss_log = loss_log
st.session_state.id_constrained_pred = id_constrained_pred
# Fold the first entry of the returned latent offset into the output latent.
st.session_state.output_latent = st.session_state.output_latent + delta_w[0]
if 'id_constrained_pred' in st.session_state:
output_image_placeholder.image(st.session_state.id_constrained_pred)
st.image([st.session_state.real_image_input.resize((200,200)),st.session_state.edited_img.resize((200,200)),st.session_state.id_constrained_pred.resize((200,200))],["Identity Image","w/o Identity Encoder","with Identity Encoder"])
二、predict.py文件中提供了论文中所用的模型:
用canvas_encoder的函数predict_image_completion:
def predict_image_completion(image, net, transform, opts, experiment_type='ffhq', resize_dims=(256,256), multi_modal=False, num_multi_output=5, n_iters=5, latent_mask=None, mix_alpha=None, id_constrain=False, target_id_feat=None):
    """Run the iterative encoder on one image and return the per-step outputs.

    Args:
        image: Input image; it is passed through ``transform``.
        net: Encoder network handed to ``run_on_batch``.
        transform: Callable turning ``image`` into a tensor.
        opts: Options object. It is mutated: ``n_iters_per_batch`` is set to
            ``n_iters`` and ``resize_outputs`` to ``False``.
        experiment_type: Forwarded to ``get_avg_image``.
        resize_dims: Output size; reversed before being given to ``.resize``
            and passed unreversed to ``display_alongside_batch``.
        multi_modal: When true, also produce style-mixed variations.
        num_multi_output: Number of random vectors drawn for style mixing.
        n_iters: Number of refinement iterations.
        latent_mask: Forwarded to ``get_multi_modal_outputs``.
        mix_alpha: Forwarded to ``get_multi_modal_outputs``.
        id_constrain: Accepted for interface compatibility; unused here.
        target_id_feat: Accepted for interface compatibility; unused here.

    Returns:
        ``(result_images, res, latents)``: the resized image of every
        iteration, the side-by-side strip of multi-modal outputs (``None``
        when ``multi_modal`` is false), and the latents from ``run_on_batch``.
    """
    opts.n_iters_per_batch = n_iters
    opts.resize_outputs = False  # generate outputs at full resolution
    output_size = resize_dims[::-1]
    transformed_image = transform(image).to(device)

    with torch.no_grad():
        avg_image = get_avg_image(net, experiment_type)
        images, latents = run_on_batch(transformed_image.unsqueeze(0), net, opts, avg_image)

    # The batch holds a single image, so only entry 0 is of interest.
    step_images, step_latents = images[0], latents[0]
    result_images = [
        tensor2im(step_images[step]).resize(output_size)
        for step in range(opts.n_iters_per_batch)
    ]

    if not multi_modal:
        return result_images, None, latents

    # Randomly draw the latents to use for style mixing.
    vectors_to_inject = np.random.randn(num_multi_output, 512).astype('float32')
    with torch.no_grad():
        # Use the same `device` as the input tensor rather than a hard-coded
        # "cuda", so both live on one device.
        last_latent = torch.tensor(step_latents[-1]).to(device).float().unsqueeze(0)
        multi_results = get_multi_modal_outputs(last_latent, net, vectors_to_inject, latent_mask, mix_alpha, input_code=True)

    img_list = [result_images[-1]] + [tensor2im(x).resize(output_size) for x in multi_results]
    res = display_alongside_batch(img_list, resize_dims)
    return result_images, res, latents
run_on_batch函数:
def run_on_batch(inputs, net, opts, avg_image, target_id_feat=None):
    """Iteratively refine a batch, feeding each output back in as context.

    On the first iteration the network sees the inputs paired with
    ``avg_image``; on later iterations it sees them paired with its own
    previous (pooled) output.

    Args:
        inputs: Batch tensor; its first dimension is the batch size.
        net: Network exposing ``forward(...)`` and ``face_pool(...)``.
        opts: Options with ``n_iters_per_batch``, ``resize_outputs`` and
            ``dataset_type``.
        avg_image: Single image tensor repeated across the batch for the
            first iteration.
        target_id_feat: Forwarded unchanged to ``net.forward``.

    Returns:
        ``(results_batch, results_latent)``: two dicts keyed by batch index.
        The first lists the output tensor of every iteration, the second the
        matching latent as a NumPy array.
    """
    batch_size = inputs.shape[0]
    is_cars = opts.dataset_type == "cars_encode"

    y_hat, latent = None, None
    results_batch = {idx: [] for idx in range(batch_size)}
    results_latent = {idx: [] for idx in range(batch_size)}

    for step in range(opts.n_iters_per_batch):
        if step == 0:
            context = avg_image.unsqueeze(0).repeat(batch_size, 1, 1, 1)
        else:
            context = y_hat
        x_input = torch.cat([inputs, context, inputs, context], dim=1)

        y_hat, latent = net.forward(x_input,
                                    target_id_feat=target_id_feat,
                                    latent=latent,
                                    randomize_noise=False,
                                    return_latents=True,
                                    resize=opts.resize_outputs)

        if is_cars:
            # Crop the rows; the window depends on whether outputs were resized.
            if opts.resize_outputs:
                y_hat = y_hat[:, :, 32:224, :]
            else:
                y_hat = y_hat[:, :, 64:448, :]

        # Store intermediate outputs.
        for idx in range(batch_size):
            results_batch[idx].append(y_hat[idx])
            results_latent[idx].append(latent[idx].cpu().numpy())

        # Pool the output before feeding it into the next iteration.
        if is_cars:
            y_hat = torch.nn.AdaptiveAvgPool2d((192, 256))(y_hat)
        else:
            y_hat = net.face_pool(y_hat)

    return results_batch, results_latent
用id_encoder的函数encoder_based_id_edit:
def encoder_based_id_edit(original_img, initial_latent, target_img, net, transform, opts, latent_mask=None, num_id_iter=5):
    """Refine a latent with the identity encoder and decode the result.

    Args:
        original_img: Image used as the context input on the first iteration.
        initial_latent: Starting latent (array-like); a batch dimension is
            added in front.
        target_img: Image paired with the context on every iteration.
        net: Encoder network exposing ``forward(...)``.
        transform: Callable turning an image into a tensor.
        opts: Options; ``dataset_type`` is read here.
        latent_mask: Per-layer weights for the latent update; defaults to
            ones over 18 layers.
        num_id_iter: Number of encoder iterations. With 0 the initial latent
            is decoded unchanged and the returned offset is all zeros.

    Returns:
        ``(out_img, delta)``: the result of ``decode_latent`` on the final
        latent, and the final-minus-initial latent as a NumPy array.
    """
    # Use the shared `device` (as the mask and images below do) instead of a
    # hard-coded "cuda", so every tensor ends up on the same device.
    initial_latent = torch.tensor(initial_latent).to(device).float().unsqueeze(0)
    if latent_mask is None:
        latent_mask = np.ones(18)
    # (n_layers,) -> (1, n_layers, 512) so it multiplies against the latent.
    mask = torch.tensor(latent_mask).float().repeat((512, 1)).transpose(1, 0).unsqueeze(0).to(device)

    latent = initial_latent
    with torch.no_grad():
        context = transform(original_img).to(device).unsqueeze(0)
        x = transform(target_img).to(device).unsqueeze(0)
        for _ in range(num_id_iter):
            x_input = torch.cat([x, context], dim=1)
            y_hat, proposed = net.forward(x_input, target_id_feat=None, latent=latent, return_latents=True)
            # Accept the encoder's update only where the mask is non-zero.
            latent = latent + (proposed - latent) * mask
            if opts.dataset_type == "cars_encode":
                y_hat = y_hat[:, :, 32:224, :]
            # The output of this step is the context for the next one.
            context = y_hat

    out_img = decode_latent(latent, net, opts, preprocess=False)
    return out_img, (latent - initial_latent).detach().cpu().numpy()