Paint2Pix代码笔记

代码来源:https://github.com/1jsingh/paint2pix

试玩demo网址:http://exposition.cecs.anu.edu.au:6009/

一、demo.py文件主要利用streamlit设计网页:

def display_alongside_batch(img_list, resize_dims=(256,256)):
    """Resize each image and lay them out side by side in one PIL image.

    Args:
        img_list: iterable of images exposing ``.resize(size)`` (PIL images).
        resize_dims: size handed to every ``resize`` call.

    Returns:
        A PIL image built from the resized inputs concatenated along axis 1,
        i.e. placed left-to-right in a single row.
    """
    def _resized(picture):
        # Resize first so every tile has the same height before concatenation.
        return np.array(picture.resize(resize_dims))

    row = np.concatenate(list(map(_resized, img_list)), axis=1)
    return Image.fromarray(row)

把img_list中的每张PIL图像先resize到resize_dims并转成numpy数组,再沿宽度方向(axis=1)拼接后转回PIL图像,即几张图横向排成一排。

def main():
    """Streamlit entry point: seed session defaults, then render the chosen demo page."""
    # Seed per-session defaults only on the first run of a browser session.
    session_defaults = {'button_id': '', 'color_to_label': {}}
    for state_key, default_value in session_defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = default_value

    # All pages render through the same demo function; the lower-cased page
    # name tells paint2pix_demo which variant to show.
    page_names = (
        "Real image editing",
        "Progressive image synthesis",
        "Artistic content generation",
    )
    pages = dict.fromkeys(page_names, paint2pix_demo)
    st.sidebar.subheader("Paint2Pix Demos")
    chosen_page = st.sidebar.selectbox("Demo:", options=list(pages.keys()))
    pages[chosen_page](chosen_page.lower())

import streamlit as st

1、st.session_state:一种在重新运行之间共享变量的方式,用于每个用户会话。除了存储和保持状态,还可以使用回调(Callbacks)来操作状态。读取状态并展示用st.write.

2、st.sidebar:把组件放在侧边栏,组件加入方式用st.sidebar.[element_name]/with st.sidebar:st.[element_name].

3、st.title、st.header、subheader:文本以title/header/subheader三级格式展示,使用方式st.subheader("text input here").

4、st.selectbox(label, options, index=0, format_func=special_internal_function, key=None, help=None, on_change=None, args=None, kwargs=None, *, disabled=False):选择组件,其参数label放选择提示的字符串;options放选项,字符串型;index预选选项;format_func用于修改标签显示的函数...

def paint2pix_demo(demo_type='real image editing'):
    """Render one Paint2Pix demo page.

    Args:
        demo_type: lower-cased page name; one of 'real image editing',
            'progressive image synthesis' or 'artistic content generation'.
            It selects the header text and the input/stylisation options.
    """
    st.sidebar.header("Configuration")
    
    # Specify canvas parameters in application
    experiment_type = 'ffhq'
    
    # The demo currently only serves the ffhq model.
    # resize_dims is (height, width): the canvas uses height=resize_dims[0], width=resize_dims[1].
    # NOTE(review): any other experiment_type would leave resize_dims unbound.
    if experiment_type == 'ffhq':
        resize_dims = (256,256)
    elif experiment_type == 'cars_encode':
        resize_dims = (192,256)
    
    st.sidebar.markdown('---')
    # User-adjustable brush width, opacity, colour and drawing tool.
    st.sidebar.header('User brush stroke parameters:')
    stroke_width = st.sidebar.slider("Stroke width: ", 1, 100, 40)
    stroke_opacity = st.sidebar.slider("Stroke opacity: ", 0, 100, 80)
    stroke_color = st.sidebar.color_picker("Stroke color hex: ", "#E8BEAC")
    
    
    # Convert the hex colour to 0-255 RGB and attach the opacity (0-100 slider -> 0.0-1.0 alpha).
    r,g,b = ImageColor.getcolor(stroke_color, "RGB")
    stroke_color_with_opacity = "rgba({},{},{},{})".format(r,g,b,stroke_opacity/100.0)
    
    bg_color = st.sidebar.color_picker("Background color hex: ", "#000000")
    drawing_mode = st.sidebar.selectbox(
        "Drawing tool:", ("freedraw", "line", "rect", "circle", "transform", "polygon")
    )
    realtime_update = st.sidebar.checkbox("Update in realtime", True)

    # Number of iterative refinement steps, and which step's output is displayed
    # (the second slider is bounded by the first).
    st.sidebar.subheader('Restyle Config:')
    restyle_iter = st.sidebar.slider("restyle iter: ", 1, 10, 5)
    restyle_select_iter = st.sidebar.slider("restyle select iter: ", 1, restyle_iter, restyle_iter)
    
    # Multi-modal (style-mixing) output is switched off in this demo.
    multi_modal = False
    num_multi_output = 5
    
    st.markdown(
    """
    Paint2Pix: Interactive Painting based Image Synthesis and Editing
    """
    )
    st.markdown("""<img src="https://1jsingh.github.io/assets/publications/images/paint2pix.png" alt="Streamlit logo" height="280"><br><br>""", unsafe_allow_html=True)
    
    # Header text for each of the three demo pages.
    if demo_type=='real image editing':
        st.header('Real Image Editing')
        st.markdown("Start editing by using a real image input or create your own image using brushstroke inputs")
    elif demo_type=='progressive image synthesis':
        st.header('Progressive Image Synthesis')
        st.markdown("Express your inner ideas ... synthesize your desired output image using just coarse scribbles")
    elif demo_type=='artistic content generation':
        st.header('Artistic Content Generation')
        st.markdown("Unleash your inner artist ... create high artistic paintings using just coarse scribbles")

5、st.markdown:用markdown语法写入需要显示的文本

6、st.slider(label, min_value=None, max_value=None, value=None, step=None, format=None, key=None, help=None, on_change=None, args=None, kwargs=None, *, disabled=False):滑块部件,value是预先选择的值。

7、st.color_picker(label, value=None, key=None, help=None, on_change=None, args=None, kwargs=None, *, disabled=False):选颜色的部件。

8、st.checkbox勾选框

from PIL import ImageColor

ImageColor.getcolor()把字符串型输入的颜色转为rgb三通道0~255的数值。

    # choose input image from a list of custom image inputs for real image editing
    if demo_type=='real image editing':
        # The input-images folder holds the bundled example face images; sort
        # because glob.glob returns paths in arbitrary order.
        input_image_list = sorted(glob.glob('input-images/*'))
        input_image_labels = ["Input image {}".format(i+1) for i in range(len(input_image_list))]
        input_image_options = [Image.open(x).resize((128,128)) for x in input_image_list]
        st.image(input_image_options, caption = input_image_labels)
        selected_image = st.selectbox("Select image input:", input_image_labels+["Custom Input"])
    else:
        selected_image = "Paint from scratch"
    
    # Depending on the choice, bg_image is an uploaded file object, the path of
    # one of the bundled images, or None (paint from scratch).
    bg_image_container = st.empty()
    if selected_image == "Custom Input":
        bg_image = bg_image_container.file_uploader("Input image:", type=["png", "jpg"])
    elif selected_image.startswith('Input image'):
        # Parse the whole trailing number: reading only selected_image[-1]
        # picked the wrong file once there were 10 or more images.
        bg_image = input_image_list[int(selected_image.rsplit(' ', 1)[-1])-1]
    else:
        bg_image = None

    def get_value(x):
        """Comparable identity of a background source: None/paths as-is, uploads by their bytes."""
        return x if x is None or isinstance(x,str) else x.getvalue()

    # (Re)initialise the canvas background on the first run or whenever the
    # selected background source changes.
    if 'bg_img' not in st.session_state or get_value(st.session_state.bg_image) != get_value(bg_image):
        use_canvas_pred = False
        if 'bg_img' in st.session_state and get_value(st.session_state.bg_image) != get_value(bg_image):
            # Background changed: reset the whole session state. Snapshot the keys
            # first so deleting entries cannot disturb the iteration.
            for key in list(st.session_state.keys()):
                del st.session_state[key]
            time.sleep(1)
        st.session_state.bg_image = bg_image
        
        # Disabled path (use_canvas_pred is False): would seed the canvas from a
        # previous canvas prediction. NOTE(review): canvas_frame_id is not defined
        # anywhere visible in this function.
        if bg_image is None and 'canvas_pred' in st.session_state and use_canvas_pred:
            canvas_pred = cv2.resize(st.session_state['canvas_pred'][canvas_frame_id-1],resize_dims[::-1])
            bg_img = Image.fromarray(canvas_pred)
        elif bg_image is not None:
            # The chosen image becomes both the canvas background and the initial prediction.
            bg_img = Image.open(bg_image).convert('RGB')
            st.session_state['prediction'] = bg_img.copy()
        else:
            bg_img = None
        
        # With no background source, fall back to an all-black image of the canvas
        # size (PIL sizes are (width, height), hence the reversal).
        if bg_img is None:
            bg_img_ = np.zeros((resize_dims[0],resize_dims[1],3)).astype('float32')
            bg_img = Image.new('RGB', resize_dims[::-1])
        else:
            bg_img_ = np.float32(bg_img)/255.
            bg_img_ = cv2.resize(bg_img_,resize_dims[::-1])
        # bg_img: PIL background shown on the canvas; bg_img_: float32 copy scaled by
        # 1/255 used when compositing strokes; real_image_input: reference image
        # for identity correction.
        st.session_state.bg_img = bg_img
        st.session_state.bg_img_ = bg_img_
        st.session_state.real_image_input = bg_img

9、st.image展示一张图或一个列表的图,传入的图可以是PIL.Image.open打开的格式。

10、st.empty()一个单一元素的容器,可以插入/替换/清除一个元素。

11、st.file_uploader(label, type=None, accept_multiple_files=False, key=None, help=None, on_change=None, args=None, kwargs=None, *, disabled=False)是上传文件的部件,最大200MB,type是字符串/字符串列表型参数,用以注明文件上传的后缀;accept_multiple_files=True则可以一次性上传多个文件。

import glob

glob.glob()返回所有匹配给定通配模式的文件路径列表;返回顺序不保证,因此代码中用sorted排序以得到固定的顺序。

lambda用于定义匿名函数,格式为“lambda 参数: 表达式”,调用时返回表达式的值。例如这里的get_value:若x为None或字符串(图片路径)则原样返回,否则(上传的文件对象)调用x.getvalue()取出其字节内容,以便判断背景图是否发生变化。

col1, col2 = st.columns(2)
    with col1:
        st.subheader('Canvas')
        stylized_output = True if demo_type=='artistic content generation' else False
        use_image_pred_as_input = st.button('Use Image Prediction as Canvas')
        if use_image_pred_as_input and 'prediction' in st.session_state:
            bg_img = st.session_state.id_constrained_pred
            bg_img_ = np.float32(bg_img)/255.
            bg_img_ = cv2.resize(bg_img_,resize_dims[::-1])
            st.session_state.bg_img = bg_img
            st.session_state.bg_img_ = bg_img_
            st.session_state.input_latent = st.session_state.output_latent
            st.session_state.real_image_input = bg_img
            st.session_state.input_latent_ = st.session_state.input_latent

        # Create a canvas component
        canvas_result = st_canvas(
            fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
            stroke_width=stroke_width,
            stroke_color=stroke_color_with_opacity,
            background_color=bg_color,
            background_image=st.session_state.bg_img,
            update_streamlit=realtime_update,
            height=resize_dims[0],
            width=resize_dims[1],
            drawing_mode=drawing_mode,
            display_toolbar=st.sidebar.checkbox("Display drawing toolbar", True),
            key="full_app",
        )

12、st.columns(spec, *, gap="small"):插入spec个并排排列的容器,gap是列之间的距离,有"small", "medium", or "large"三种选择。通过with col1:这种方式在容器上加入部件。

13、st.button:一个按键,按下则返回True,否则为False.

14、from streamlit_drawable_canvas import st_canvas,在网页中加入一个可用鼠标绘画的画布组件;返回对象的image_data是RGBA像素数组(最后一个通道为透明度),尚未生成画布数据时为None。

with col2:
        # image completion predictions
        st.subheader('Image Prediction')
        # The base prediction step runs without the identity constraint.
        id_constrain = False
        if True:
            # First run: initialise every cached prediction buffer with all-black placeholders.
            if 'prediction' not in st.session_state:
                st.session_state['prediction'] = Image.fromarray(np.zeros((resize_dims[0],resize_dims[1],3)).astype('uint8'))
            if 'multi_modal_prediction' not in st.session_state:
                st.session_state['multi_modal_prediction'] = Image.fromarray(np.zeros((resize_dims[0],resize_dims[1]*7,3)).astype('uint8'))
            if 'canvas_pred' not in st.session_state:
                st.session_state['canvas_pred'] = np.zeros((100,resize_dims[0],resize_dims[1],3)).astype('uint8')
            if 'restyle_prediction' not in st.session_state:
                st.session_state['restyle_prediction'] = np.zeros((10,resize_dims[0],resize_dims[1],3)).astype('uint8')
            if 'canvas' not in st.session_state:
                st.session_state['canvas'] = Image.fromarray(np.zeros((resize_dims[0],resize_dims[1],3)).astype('uint8'))
                
            img_pred_button = st.button('Predict')
            if img_pred_button:
                with st.spinner('Computing image completion prediction ...'):
                    # net is the canvas encoder; load_model also returns its input transform and options.
                    net, transform, opts = load_model(experiment_type,id_constrain,stylized_output=stylized_output)
                    if canvas_result.image_data is None:
                        # Nothing on the canvas yet: predict from the stored background.
                        # Read it from session state because the local bg_img is only
                        # bound on reruns where the background was (re)initialised.
                        input_image = st.session_state.bg_img
                    else:
                        # Canvas pixels scaled by 1/255; the last channel is the stroke alpha.
                        painting = np.float32(canvas_result.image_data)/255.
                        alpha = cv2.cvtColor(painting[:,:,-1],cv2.COLOR_GRAY2RGB)
                        painting_fg = painting[:,:,:3]
                        # Alpha-composite: out = alpha * strokes + (1 - alpha) * background.
                        foreground = cv2.multiply(alpha, painting_fg)
                        background = cv2.multiply(1.0 - alpha, st.session_state.bg_img_)
                        outImage = cv2.add(foreground, background)
                        input_image = Image.fromarray(np.uint8(255*outImage))
                    st.session_state['canvas'] = input_image
                    if multi_modal:
                        # NOTE(review): unreachable while multi_modal is False; mixing_layers
                        # and mix_alpha are not defined in this function.
                        latent_mask = list(range(mixing_layers[0],mixing_layers[1]+1))
                        mix_alpha = mix_alpha/100
                    else:
                        latent_mask, mix_alpha  = None, None
                    
                    # Core step: encode the composited canvas and run the iterative prediction.
                    result_images, multi_out, latents = predict_image_completion(input_image, net, transform, opts, multi_modal=multi_modal, experiment_type=experiment_type, resize_dims=resize_dims,num_multi_output=num_multi_output, n_iters=restyle_iter,latent_mask=latent_mask,mix_alpha=mix_alpha, id_constrain=id_constrain, target_id_feat=None)
                st.session_state.restyle_prediction = result_images
                st.session_state.latents = latents[0]
                if multi_modal:
                    st.session_state['multi_modal_prediction'] = multi_out
            # Store into 'prediction' either the selected multi-modal tile or the
            # output of the selected refinement iteration.
            if multi_modal:
                # NOTE(review): select_prediction is not defined in this function (unreachable while multi_modal is False).
                bbox = (select_prediction*resize_dims[1],0,(select_prediction+1)*resize_dims[1],resize_dims[0])
                st.session_state.prediction = st.session_state['multi_modal_prediction'].crop(bbox)
            else:
                st.session_state.prediction = st.session_state.restyle_prediction[restyle_select_iter-1]
            output_image_placeholder = st.empty()

15、st.spinner(text="In progress..."):在执行某段代码时临时显示文本,使用方式是with st.spinner(text="In progress..."):执行代码块。

with st.expander('Control Edit Strength'):
        st.markdown("The user can use extrapolation of edit strength in order to achieve desired output attributes. This is helpful in achieving semantic image edits *e.g.,* aging which are otherwise difficult to describe using just coarse user scribbles.")
        col1, col2 = st.columns(2)
        with col1:
            threshold_beta = st.number_input('threshold-beta',value=1.0,step=0.5)
        with col2:
            edit_alpha = st.number_input('edit-alpha',value=1.0,step=0.2)
        
        num_style_layers = 18 if experiment_type=='ffhq' else 16

        # Range of StyleGAN layers the edit may modify; defaults to all of them.
        editable_layers_global = st.slider('Editable StyleGAN layers:',0, num_style_layers-1, (0,num_style_layers-1))
        
        if img_pred_button:
            # Latent of the prediction at the selected refinement iteration.
            st.session_state.output_latent = st.session_state.latents[restyle_select_iter-1]
            if 'input_latent' not in st.session_state:
                st.session_state.input_latent = st.session_state.input_latent_ = st.session_state.output_latent
            
            # Per-layer 0/1 mask selecting the editable layers.
            mask = np.arange(editable_layers_global[0],editable_layers_global[1]+1)
            latent_mask = np.zeros(num_style_layers)
            latent_mask[mask] = 1.0
            
            # Edit direction: change from the input latent to the new prediction's latent.
            delta_w = st.session_state.output_latent - st.session_state.input_latent_
            # Clip every component to [-threshold_beta, threshold_beta] and zero non-editable layers.
            delta_w = np.expand_dims(latent_mask,-1) * np.clip(delta_w,-threshold_beta,threshold_beta)
            w0 = st.session_state.input_latent
            # edit_alpha > 1 extrapolates the edit, < 1 damps it.
            w1 = w0 + delta_w * edit_alpha
            # Decode the edited latent with this demo's experiment type and output size
            # (previously hard-coded to 'ffhq' and (256,256)).
            edited_img = decode_latent(w1, net, opts, experiment_type=experiment_type, resize_dims=resize_dims, truncation=1.)
            st.session_state.edited_img = edited_img
            st.session_state.output_latent = w1
            st.image([st.session_state.prediction.resize((200,200)),st.session_state.edited_img.resize((200,200))],["before","after"])

16、st.expander插入一个可以展开/折叠的多元素容器,折叠时只显示label.

17、st.number_input(label, min_value=None, max_value=None, value=NoValue, step=None, format=None, key=None, help=None, on_change=None, args=None, kwargs=None, *, disabled=False):输入数字的部件。参数value是首次渲染时显示的值(未传入时默认取min_value,若min_value也为None则为0.0);step是点击+/-按钮时每次增减的步长。

np.expand_dims(a, axis)将a在axis位置插入一个长度为1的新维度。此处把一维的latent_mask扩展为(num_style_layers, 1),从而能与delta_w按层广播相乘。

np.clip用于截取数组中小于或者大于某值的部分,所有比最小值小的数都会强制变为最小值,
所有比最大值大的数都会强制变为最大值。

 with st.expander('Identity Correction Config'):
        num_id_gd_iter = st.slider('Number of iterations:',0,1000,20)
        col1, col2 = st.columns(2)
        with col1:
            lambda_reg = st.number_input('lambda id-reg',value=0.01)
        with col2:
            # Use the same layer count as the edit-strength controls instead of hard-coding 17/18.
            editable_layers = st.slider('StyleGAN Editable Layers:',0, num_style_layers-1, (0,8))
            mask = np.arange(editable_layers[0],editable_layers[1]+1)
            latent_mask = np.zeros(num_style_layers)
            latent_mask[mask] = 1.0
            
        col1, col2 = st.columns(2)
        with col1:
            use_encoder = st.checkbox('Use Identity Encoder',True)
        with col2:
            num_id_enc_iter = st.slider('Number of identity-encoding steps:',1,5,1)

        if img_pred_button:
            x = st.session_state.output_latent
            target_img = st.session_state.real_image_input
            if use_encoder:
                original_img = st.session_state.edited_img
                id_net,_,_ = load_model(experiment_type,id_constrain=True,stylized_output=stylized_output)
                # Identity-encoder path: refine the latent towards the target identity.
                id_constrained_pred, delta_w = encoder_based_id_edit(original_img, x, target_img, id_net, transform, opts, latent_mask=latent_mask, num_id_iter=num_id_enc_iter)
            else:
                # Only the gradient-based path uses the face-ID loss, so load it only here.
                id_loss_func = load_faceid_model()
                id_constrained_pred, loss_log, delta_w = identity_constrained_latent_pred(x, target_img, net, transform, opts, id_loss_func, input_code=True, n_iter=num_id_gd_iter, lr=5e-3, lambda_id=1.0, lambda_reg=lambda_reg,latent_mask=latent_mask, lambda_l2=1e1)
                st.session_state.id_constrained_loss_log = loss_log
            st.session_state.id_constrained_pred = id_constrained_pred
            st.session_state.output_latent = st.session_state.output_latent + delta_w[0]
        if 'id_constrained_pred' in st.session_state:
            output_image_placeholder.image(st.session_state.id_constrained_pred)
            st.image([st.session_state.real_image_input.resize((200,200)),st.session_state.edited_img.resize((200,200)),st.session_state.id_constrained_pred.resize((200,200))],["Identity Image","w/o Identity Encoder","with Identity Encoder"])

二、predict.py文件中提供了论文中所用的模型:

用canvas_encoder的函数predict_image_completion:

def predict_image_completion(image, net, transform, opts, experiment_type='ffhq', resize_dims=(256,256), multi_modal=False, num_multi_output=5, n_iters=5, latent_mask=None ,mix_alpha=None, id_constrain=False, target_id_feat=None):
    """Predict a completed image from a (partially painted) canvas image.

    Runs ``n_iters`` refinement iterations through ``run_on_batch``, starting
    from the average image. Optionally also builds a strip of style-mixed
    variations of the final result.

    Args:
        image: input canvas image, preprocessed with ``transform``.
        net: encoder/generator network.
        transform: preprocessing applied to ``image``.
        opts: options object. NOTE: this function sets
            ``opts.n_iters_per_batch`` and ``opts.resize_outputs`` on it.
        experiment_type: dataset key passed to ``get_avg_image``.
        resize_dims: (height, width) of the returned images.
        multi_modal: if True, also build the style-mixing strip.
        num_multi_output: number of random latents injected when multi_modal.
        n_iters: number of refinement iterations.
        latent_mask, mix_alpha: style-mixing controls forwarded to
            ``get_multi_modal_outputs``.
        id_constrain: accepted for API compatibility; unused here.
        target_id_feat: optional identity feature forwarded to the network.

    Returns:
        ``(result_images, strip_or_None, latents)``.
        ``result_images`` holds one resized image per iteration.
        ``strip_or_None`` is the side-by-side multi-modal image, or None when
        ``multi_modal`` is False.
        ``latents`` is the per-sample dict of per-iteration latents returned by
        ``run_on_batch``.
    """
    opts.n_iters_per_batch = n_iters
    opts.resize_outputs = False  # generate outputs at full resolution
    transformed_image = transform(image).to(device)
    with torch.no_grad():
        avg_image = get_avg_image(net,experiment_type)
        # Forward target_id_feat; it used to be accepted and then silently dropped.
        images, latents = run_on_batch(transformed_image.unsqueeze(0), net, opts, avg_image, target_id_feat=target_id_feat)
    # The batch holds a single sample, so keep sample 0.
    result_images, latent = images[0], latents[0]
    pil_size = resize_dims[::-1]  # PIL resize expects (width, height)
    result_images = [tensor2im(step_output).resize(pil_size) for step_output in result_images]

    if not multi_modal:
        return result_images, None, latents

    # randomly draw the latents to use for style mixing
    vectors_to_inject = np.random.randn(num_multi_output, 512).astype('float32')
    with torch.no_grad():
        # Mix styles into the latent of the last refinement iteration. Use the same
        # device as the rest of this function rather than a hard-coded "cuda".
        final_latent = torch.tensor(latent[-1]).to(device).float().unsqueeze(0)
        multi_results = get_multi_modal_outputs(final_latent, net, vectors_to_inject, latent_mask, mix_alpha, input_code=True)
    img_list = [result_images[-1]] + [tensor2im(x).resize(pil_size) for x in multi_results]
    res = display_alongside_batch(img_list, resize_dims)
    return result_images, res, latents
    

run_on_batch函数:

def run_on_batch(inputs, net, opts, avg_image, target_id_feat=None):
    """Run ``opts.n_iters_per_batch`` refinement iterations of the encoder on a batch.

    On the first iteration each input is paired with ``avg_image``; on later
    iterations it is paired with the previous prediction. The network input
    is ``[inputs, previous, inputs, previous]`` concatenated on dim 1, and the
    latent from the previous step is fed back in.

    Args:
        inputs: batch tensor; dim 0 is the batch.
        net: network exposing ``forward(...) -> (y_hat, latent)`` and ``face_pool``.
        opts: options providing ``n_iters_per_batch``, ``dataset_type`` and
            ``resize_outputs``.
        avg_image: single image (no batch dim) used as the first "previous output".
        target_id_feat: optional identity feature passed to ``net.forward``.

    Returns:
        ``(results_batch, results_latent)``, two dicts keyed by sample index.
        ``results_batch`` holds the per-iteration outputs, stored before
        pooling.
        ``results_latent`` holds the per-iteration latents as numpy arrays.
    """
    batch_size = inputs.shape[0]
    y_hat, latent = None, None
    results_batch = {idx: [] for idx in range(batch_size)}
    results_latent = {idx: [] for idx in range(batch_size)}
    # `step` rather than `iter` so the builtin is not shadowed.
    for step in range(opts.n_iters_per_batch):
        if step == 0:
            previous = avg_image.unsqueeze(0).repeat(batch_size, 1, 1, 1)
        else:
            previous = y_hat
        x_input = torch.cat([inputs, previous, inputs, previous], dim=1)

        y_hat, latent = net.forward(x_input,
                                    target_id_feat=target_id_feat,
                                    latent=latent,
                                    randomize_noise=False,
                                    return_latents=True,
                                    resize=opts.resize_outputs)

        if opts.dataset_type == "cars_encode":
            # Crop the car outputs vertically (rows 32:224 when resized, 64:448 otherwise).
            if opts.resize_outputs:
                y_hat = y_hat[:, :, 32:224, :]
            else:
                y_hat = y_hat[:, :, 64:448, :]

        # store intermediate outputs
        for idx in range(batch_size):
            results_batch[idx].append(y_hat[idx])
            results_latent[idx].append(latent[idx].cpu().numpy())

        # resize input to 256 before feeding into next iteration
        if opts.dataset_type == "cars_encode":
            y_hat = torch.nn.AdaptiveAvgPool2d((192, 256))(y_hat)
        else:
            y_hat = net.face_pool(y_hat)

    return results_batch, results_latent

用id_encoder

def encoder_based_id_edit(original_img, initial_latent, target_img, net, transform, opts, latent_mask=None, num_id_iter=5):
    """Refine a latent with the identity encoder so the output keeps the target's identity.

    Args:
        original_img: image fed as the encoder's first "previous output".
            Typically this is the image decoded from ``initial_latent``.
        initial_latent: starting latent (array-like). A batch dim of 1 is added.
        target_img: image whose identity should be preserved.
        net: identity encoder network exposing ``forward(...) -> (y_hat, latent)``.
        transform: preprocessing applied to both images.
        opts: options; ``opts.dataset_type`` selects the cars crop.
        latent_mask: per-layer 0/1 weights; only masked-in layers are updated.
            Defaults to all 18 layers.
        num_id_iter: number of encoder refinement steps. With 0, the initial
            latent is decoded unchanged.

    Returns:
        ``(out_img, delta_w)``. ``out_img`` is the decoded refined latent;
        ``delta_w`` is ``refined - initial`` as a numpy array with the batch
        dim kept.
    """
    # Use the module's device consistently (previously hard-coded to "cuda",
    # which mismatched `mask` and the images whenever device differed).
    initial_latent = torch.tensor(initial_latent).to(device).float().unsqueeze(0)
    
    if latent_mask is None:
        latent_mask = np.ones(18)
    # Expand the per-layer mask to (1, num_layers, 512) so it selects whole layers.
    mask = torch.tensor(latent_mask).float().repeat((512,1)).transpose(1,0).unsqueeze(0).to(device)
    mask.requires_grad = False
    
    with torch.no_grad():
        y_hat = transform(original_img).to(device).unsqueeze(0)
        x = transform(target_img).to(device).unsqueeze(0)
        
        latent1 = initial_latent
        # Defined up front so num_id_iter == 0 returns the initial latent instead
        # of raising UnboundLocalError.
        latent2 = initial_latent
        for _ in range(num_id_iter):
            x_input = torch.cat([x, y_hat], dim=1)
            y_hat, latent2 = net.forward(x_input, target_id_feat=None, latent=latent1, return_latents=True)
            # Keep non-editable layers at their previous value.
            latent2 = latent1 + (latent2 - latent1)*mask
            latent1 = latent2
            
            if opts.dataset_type == "cars_encode":
                y_hat = y_hat[:, :, 32:224, :]
                
    out_img = decode_latent(latent2,net, opts, preprocess=False) 
    
    return out_img, (latent2-initial_latent).detach().cpu().numpy()

猜你喜欢

转载自blog.csdn.net/qq_43522986/article/details/126699430