FRESCO Source Code Walkthrough (AttentionControl + FRESCO_Attention)

This post explains how FRESCO stores the attention features of earlier keyframes and uses them when sampling later keyframes. (It does not cover the source code of FRESCO-based feature optimization.)

FRESCO inference pipeline

In get_models, after Stable Diffusion is loaded:

  • apply_freeu sets up FreeU on the unet.
  • apply_FRESCO_attn creates a FRESCOAttnProcessor2_0 and installs it as the AttnProcessor of up_blocks.2 and up_blocks.3 of the unet; all other blocks get AttnProcessor2_0. The returned frescoProc is the switch for FRESCO attention inside the unet.
  • apply_FRESCO_opt sets up the unet feature optimization.
    if config['use_freeu']:
        from src.free_lunch_utils import apply_freeu
        apply_freeu(pipe, b1=1.2, b2=1.5, s1=1.0, s2=1.0)

    frescoProc = apply_FRESCO_attn(pipe)
    frescoProc.controller.disable_controller()
    apply_FRESCO_opt(pipe)
    print('create diffusion model ' + config['sd_path'] + ' successfully!')

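The key frames are split into several batches, and each batch is denoised together. Note that the first batch has batch_size frames, while every later batch normally has batch_size-2 new frames; the 2 remaining slots are used to bring information from the previous batch into the current one. See record_latents in the inference code for details.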

    sublists = [keys[i:i+config['batch_size']-2] for i in range(2, len(keys), config['batch_size']-2)]
    sublists[0].insert(0, keys[0])
    sublists[0].insert(1, keys[1])
    if len(sublists) > 1 and len(sublists[-1]) < 3:  # if the last sublist has fewer than 3 elements, move elements over from the previous sublist so that the last one has exactly 3
        add_num = 3 - len(sublists[-1])
        sublists[-1] = sublists[-2][-add_num:] + sublists[-1]
        sublists[-2] = sublists[-2][:-add_num]

    if not sublists[-2]:
        del sublists[-2]
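
To make the batch sizes concrete, here is a minimal, self-contained sketch of the same splitting logic. The function name split_keys and the integer key lists are illustrative assumptions, not part of the FRESCO code; the sketch also assumes more than 2 key frames and adds a length guard before touching sublists[-2].

```python
def split_keys(keys, batch_size):
    """Hypothetical helper mirroring the splitting logic above (assumes len(keys) > 2)."""
    step = batch_size - 2
    sublists = [keys[i:i + step] for i in range(2, len(keys), step)]
    # The first batch additionally receives the first two key frames.
    sublists[0].insert(0, keys[0])
    sublists[0].insert(1, keys[1])
    # Make sure the last batch holds at least 3 frames.
    if len(sublists) > 1 and len(sublists[-1]) < 3:
        add_num = 3 - len(sublists[-1])
        sublists[-1] = sublists[-2][-add_num:] + sublists[-1]
        sublists[-2] = sublists[-2][:-add_num]
    # Drop the second-to-last batch if the borrowing emptied it.
    if len(sublists) > 1 and not sublists[-2]:
        del sublists[-2]
    return sublists


sizes_20 = [len(s) for s in split_keys(list(range(20)), 8)]
sizes_15 = [len(s) for s in split_keys(list(range(15)), 8)]
print(sizes_20)  # [8, 6, 6]
print(sizes_15)  # [8, 4, 3]
```

With 20 key frames and batch_size=8 the first batch holds 8 frames and the later ones 6; with 15 key frames the single leftover frame borrows two frames from the previous batch.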

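Before inference starts, get_flow_and_interframe_paras computes the optical-flow correspondences between frames, and get_intraframe_paras computes the intra-frame parameters (returned as correlation_matrix in the code below):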

        # prepare parameters for inter-frame and intra-frame consistency
        flows, occs, attn_mask, interattn_paras = get_flow_and_interframe_paras(flow_model, imgs)
        correlation_matrix = get_intraframe_paras(pipe, imgs_torch, frescoProc, 
                            prompt_embeds, seed = config['seed'])
        '''
        Flexible settings for attention:
        * Turn off FRESCO-guided attention: frescoProc.controller.disable_controller() 
        Then you can turn on one specific attention submodule
        * Turn on Cross-frame attention: frescoProc.controller.enable_cfattn(attn_mask) 
        * Turn on Spatial-guided attention: frescoProc.controller.enable_intraattn() 
        * Turn on Temporal-guided attention: frescoProc.controller.enable_interattn(interattn_paras)
    
        Flexible settings for optimization:
        * Turn off Spatial-guided optimization: set optimize_temporal = False in apply_FRESCO_opt()
        * Turn off Temporal-guided optimization: set correlation_matrix = [] in apply_FRESCO_opt()
        * Turn off FRESCO-guided optimization: disable_FRESCO_opt(pipe)
    
        Flexible settings for background smoothing:
        * Turn off background smoothing: set saliency = None in apply_FRESCO_opt()
        '''    
        # Turn on all FRESCO support
        frescoProc.controller.enable_controller(interattn_paras=interattn_paras, attn_mask=attn_mask)

inference runs the FRESCO sampling:

        # run!
        latents = inference(pipe, controlnet, frescoProc, 
                  imgs_torch, prompt_embeds, edges, timesteps,
                  cond_scale, config['num_inference_steps'], config['num_warmup_steps'], 
                  do_classifier_free_guidance, config['seed'], guidance_scale, config['use_controlnet'],         
                  record_latents, propagation_mode,
                  flows = flows, occs = occs, saliency=saliency, repeat_noise=True)

1. apply_FRESCO_attn

PART I - FRESCO-based attention
* Class AttentionControl: Control the function of FRESCO-based attention
* Class FRESCOAttnProcessor2_0: FRESCO-based attention
* apply_FRESCO_attn(): Apply FRESCO-based attention to a StableDiffusionPipeline

apply_FRESCO_attn creates one FRESCOAttnProcessor2_0 object and one AttnProcessor2_0 object. The unet has attention layers in 3 down_blocks, 1 mid_block and 3 up_blocks, but only up_blocks.2 and up_blocks.3 get FRESCOAttnProcessor2_0 as their AttnProcessor; all other blocks get AttnProcessor2_0. The returned frescoProc is the switch for FRESCO attention inside the unet.

def apply_FRESCO_attn(pipe):
    """
    Apply FRESCO-guided attention to a StableDiffusionPipeline
    """    
    frescoProc = FRESCOAttnProcessor2_0(2, AttentionControl())
    attnProc = AttnProcessor2_0()
    attn_processor_dict = {}
    for k in pipe.unet.attn_processors.keys():
        if k.startswith("up_blocks.2") or k.startswith("up_blocks.3"):
            attn_processor_dict[k] = frescoProc
        else:
            attn_processor_dict[k] = attnProc
    pipe.unet.set_attn_processor(attn_processor_dict)
    return frescoProc
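
The prefix-based routing can be tried without loading a pipeline. The sketch below is only an illustration: route_processors is a hypothetical helper, the processor names are example keys in the style of pipe.unet.attn_processors, and plain strings stand in for the two processor objects.

```python
def route_processors(names, fresco="frescoProc", default="attnProc"):
    """Map each attention-processor name to the FRESCO or the default processor."""
    fresco_prefixes = ("up_blocks.2", "up_blocks.3")
    return {k: (fresco if k.startswith(fresco_prefixes) else default) for k in names}


# Example keys (illustrative, not read from a real UNet).
names = [
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor",
    "mid_block.attentions.0.transformer_blocks.0.attn1.processor",
    "up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor",
    "up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor",
    "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor",
]
routing = route_processors(names)
for name, proc in routing.items():
    print(proc, "<-", name)
```

Note that the same frescoProc instance is shared by every matching layer, which is why a single controller can switch FRESCO attention on and off for all of them.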

AttentionControl

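The AttentionControl class stores the self-attention features of the UNet decoder; they are kept in stored_attn. It also holds the switches and parameters for intra-attention (within a frame, i.e. spatial attention), inter-attention (between frames, i.e. temporal attention) and cross-frame attention.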

class AttentionControl():
    """
    Control FRESCO-based attention
    * enable/disable spatial-guided attention
    * enable/disable temporal-guided attention
    * enable/disable cross-frame attention
    * collect intermediate attention feature (for spatial-guided attention)
    """
    def __init__(self):
        self.stored_attn = self.get_empty_store()  # stored attention features
        self.store = False  # toggled by enable_store / disable_store: whether to record attention features
        self.index = 0  # read position in stored_attn for spatial-guided (intra-frame) attention
        self.attn_mask = None  # cross-frame attention parameter: attn_mask
        self.interattn_paras = None  # temporal-guided (inter-frame) attention parameters: fwd_mappings, bwd_mappings, interattn_masks
        self.use_interattn = False  # toggled by enable_interattn / disable_interattn: whether to use temporal-guided (inter-frame) attention
        self.use_cfattn = False  # toggled by enable_cfattn / disable_cfattn: whether to use cross-frame attention
        self.use_intraattn = False  # toggled by enable_intraattn / disable_intraattn: whether to use spatial-guided (intra-frame) attention
        self.intraattn_bias = 0
        self.intraattn_scale_factor = 0.2
        self.interattn_scale_factor = 0.2
    
    @staticmethod
    def get_empty_store():
        return {
            'decoder_attn': [],
        }

    def clear_store(self):
        del self.stored_attn
        torch.cuda.empty_cache()
        gc.collect()
        self.stored_attn = self.get_empty_store()
        self.disable_intraattn()
        
    # store attention feature of the input frame for spatial-guided attention
    def enable_store(self):
        self.store = True
        
    def disable_store(self):
        self.store = False  

The following methods control whether each of the three attention types is executed:

  • spatial-guided attention
  • temporal-guided attention
  • cross-frame attention
    # spatial-guided attention
    def enable_intraattn(self):
        self.index = 0
        self.use_intraattn = True
        self.disable_store()
        if len(self.stored_attn['decoder_attn']) == 0:
            self.use_intraattn = False
        
    def disable_intraattn(self):
        self.index = 0
        self.use_intraattn = False
        self.disable_store()

    def disable_cfattn(self):
        self.use_cfattn = False        

    # cross frame attention
    def enable_cfattn(self, attn_mask=None):
        if attn_mask:
            if self.attn_mask:
                del self.attn_mask
                torch.cuda.empty_cache()
            self.attn_mask = attn_mask
            self.use_cfattn = True  
        else:
            if self.attn_mask:
                self.use_cfattn = True
            else:
                print('Warning: no valid cross-frame attention parameters available!')
                self.disable_cfattn()       
        
    def disable_interattn(self):
        self.use_interattn = False

    # temporal-guided attention
    def enable_interattn(self, interattn_paras=None):
        if interattn_paras:
            if self.interattn_paras:
                del self.interattn_paras
                torch.cuda.empty_cache()
            self.interattn_paras = interattn_paras
            self.use_interattn = True
        else:
            if self.interattn_paras:
                self.use_interattn = True
            else:
                print('Warning: no valid temporal-guided attention parameters available!')
                self.disable_interattn()

enable_controller and disable_controller switch all three attention types on or off together.

__call__ simply runs forward. If self.store is true, the incoming context (attention feature) is appended to stored_attn['decoder_attn']. If self.use_intraattn is enabled and stored_attn['decoder_attn'] is non-empty, forward returns the stored feature at position self.index and advances the index, wrapping back to 0 after the last stored feature. Otherwise it returns the input context unchanged.

    def disable_controller(self):
        self.disable_intraattn()
        self.disable_interattn()
        self.disable_cfattn()
    
    def enable_controller(self, interattn_paras=None, attn_mask=None):
        self.enable_intraattn()
        self.enable_interattn(interattn_paras)
        self.enable_cfattn(attn_mask)    
    
    def forward(self, context):
        if self.store:
            self.stored_attn['decoder_attn'].append(context.detach())
        if self.use_intraattn and len(self.stored_attn['decoder_attn']) > 0:
            tmp = self.stored_attn['decoder_attn'][self.index]
            self.index = self.index + 1
            if self.index >= len(self.stored_attn['decoder_attn']):
                self.index = 0
                self.disable_store()
            return tmp
        return context
    
    def __call__(self, context):
        context = self.forward(context)
        return context
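
The store-then-replay behaviour of forward can be reproduced with a toy class. This is a simplified stand-in, not the FRESCO class: TinyStore is a hypothetical name, strings replace tensors, and the disable_store call on wrap-around is omitted.

```python
class TinyStore:
    """Toy stand-in for AttentionControl: record features, then replay them in order."""

    def __init__(self):
        self.stored = []
        self.store = False
        self.use_intraattn = False
        self.index = 0

    def __call__(self, context):
        if self.store:
            self.stored.append(context)
        if self.use_intraattn and len(self.stored) > 0:
            tmp = self.stored[self.index]
            # Advance the read position and wrap around after the last entry.
            self.index = (self.index + 1) % len(self.stored)
            return tmp
        return context


ctrl = TinyStore()

# Phase 1: one pass with storing enabled records one feature per attention layer.
ctrl.store = True
for feature in ["layer_a", "layer_b", "layer_c"]:
    ctrl(feature)
ctrl.store = False

# Phase 2: with intra-attention enabled, each call returns the next stored feature.
ctrl.use_intraattn = True
replayed = [ctrl(None) for _ in range(4)]
print(replayed)  # ['layer_a', 'layer_b', 'layer_c', 'layer_a']
```

The replay only lines up with the right layer because the attention layers are visited in the same order in every UNet pass, which is what the wrap-around index relies on.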

FRESCOAttnProcessor2_0

A FRESCOAttnProcessor2_0 is constructed from unet_chunk_size and an AttentionControl object. It replaces the self-attention of the unet with FRESCO-based attention (spatial-guided attention, temporal-guided attention and cross-frame attention).

class FRESCOAttnProcessor2_0:
    """
    Hack self attention to FRESCO-based attention
    * adding spatial-guided attention
    * adding temporal-guided attention
    * adding cross-frame attention
    
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    Usage
    frescoProc = FRESCOAttnProcessor2_0(2, attn_mask)
    attnProc = AttnProcessor2_0()
    
    attn_processor_dict = {}
    for k in pipe.unet.attn_processors.keys():
        if k.startswith("up_blocks.2") or k.startswith("up_blocks.3"):
            attn_processor_dict[k] = frescoProc
        else:
            attn_processor_dict[k] = attnProc
    pipe.unet.set_attn_processor(attn_processor_dict)
    """

    def __init__(self, unet_chunk_size=2, controller=None):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.unet_chunk_size = unet_chunk_size
        self.controller = controller

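The three attention types (spatial-guided attention, temporal-guided attention and cross-frame attention) are applied as follows:

  • Before the self-attention of the current frames is computed, their self-attention input feature is saved to the controller (when storing is enabled). query_raw and key_raw keep a copy of the query and key of the current frames.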
    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        crossattn = False  # check whether this is cross-attn
        if encoder_hidden_states is None:  # self-attn
            encoder_hidden_states = hidden_states
            if self.controller and self.controller.store:  # save self-attn-feature to controller
                self.controller(hidden_states.detach().clone())
        else:  # cross-attn
            crossattn = True
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
            
        # BC * HW * 8D
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        
        # if self-attn and temporal-guided attention, save query_raw and key_raw
        query_raw, key_raw = None, None
        if self.controller and self.controller.use_interattn and (not crossattn):
            query_raw, key_raw = query.clone(), key.clone()

        inner_dim = key.shape[-1] # 8D
        head_dim = inner_dim // attn.heads # D
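  • spatial-guided intra-frame attention: the controller returns ref_hidden_states, the UNet decoder features that the original images of the current window produced in a single DDPM forward (add-noise) and backward (denoise) pass. They are projected to query_ and key_, whose similarity is then used to aggregate the current query (passed in as the value); the result replaces query.
        '''for spatial-guided intra-frame attention'''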
        if self.controller and self.controller.use_intraattn and (not crossattn): 
            ref_hidden_states = self.controller(None)
            assert ref_hidden_states.shape == encoder_hidden_states.shape
            query_ = attn.to_q(ref_hidden_states)
            key_ = attn.to_k(ref_hidden_states) 
            
            ''' 
            # for xformers implementation 
            if importlib.util.find_spec("xformers") is not None:
                # BC * HW * 8D --> BC * HW * 8 * D
                query_ = rearrange(query_, "b d (h c) -> b d h c", h=attn.heads)
                key_ = rearrange(key_, "b d (h c) -> b d h c", h=attn.heads)
                # BC * 8 * HW * D --> 8BC * HW * D
                query = rearrange(query, "b h d c -> b d h c")
                query = xformers.ops.memory_efficient_attention(
                    query_, key_ * self.sattn_scale_factor, query, 
                    attn_bias=torch.eye(query_.size(1), key_.size(1), 
                    dtype=query.dtype, device=query.device) * self.bias_weight, op=None
                )
                query = rearrange(query, "b d h c -> b h d c").detach()
            '''
            # BC * 8 * HW * D
            query_ = query_.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            key_ = key_.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            query = F.scaled_dot_product_attention(
                query_, key_ * self.controller.intraattn_scale_factor, query, 
                attn_mask = torch.eye(query_.size(-2), key_.size(-2), 
                                      dtype=query.dtype, device=query.device) * self.controller.intraattn_bias,
            ).detach()
            #print('intra: ', GPU.getGPUs()[1].memoryUsed)
            del query_, key_
            torch.cuda.empty_cache()
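
The spatial-guided step can be sketched in plain Python for a single head and a single frame. This is an illustration under simplifying assumptions, not the processor's code: spatial_guided is a hypothetical function, tokens are short lists of floats, and the diagonal bias mirrors the torch.eye(...) * intraattn_bias term.

```python
import math


def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]


def spatial_guided(query, ref_q, ref_k, scale_factor=0.2, bias=0.0):
    """Aggregate the current query tokens with weights taken from reference-feature similarity."""
    d = len(ref_q[0])
    out = []
    for i, qi in enumerate(ref_q):
        logits = [
            sum(a * b * scale_factor for a, b in zip(qi, kj)) / math.sqrt(d)
            + (bias if i == j else 0.0)
            for j, kj in enumerate(ref_k)
        ]
        w = softmax(logits)
        # The current query plays the role of the value.
        out.append([sum(w[j] * query[j][c] for j in range(len(query)))
                    for c in range(len(query[0]))])
    return out


ref = [[10.0, 0.0], [10.0, 0.0], [0.0, 10.0]]   # tokens 0 and 1 look alike in the reference
query = [[1.0, 0.0], [3.0, 0.0], [0.0, 5.0]]
new_query = spatial_guided(query, ref, ref)
print([round(v, 3) for v in new_query[0]])  # close to [2.0, 0.0]
```

Tokens that are similar in the reference features end up with nearly the same query, here the average of the queries of tokens 0 and 1.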
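  • efficient cross-frame attention: the optical-flow-based attn_mask of shape [video_len, seq_len] is read from the AttentionControl object. The key tokens selected by the mask are gathered from every frame and repeated video_length times, so that the query of each frame attends to the masked keys of all frames; value is handled the same way. This gives an efficient cross-frame attention. The query is the output of the spatial-guided intra-frame attention.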
import torch
from einops import repeat

batch_size = 2  # unet_chunk_size for CFG
video_length = 8
sequence_length = 1280
hidden_size = 640

key = torch.randn(batch_size, video_length, sequence_length, hidden_size)
# Random 0/1 matrix of shape (video_length, sequence_length); after .bool(), True marks the tokens selected by the mask
attn_mask = torch.randint(0, 2, (video_length, sequence_length))

# Take the key tokens selected by the mask across all frames, then repeat them video_length times
# so that the query of every frame can attend to the masked keys of all frames
# Option 1:
masked_key = key[:, attn_mask.bool()].unsqueeze(1).repeat(1, video_length, 1, 1)
print(masked_key.shape)
# Option 2:
masked_key = repeat(key[:, attn_mask.bool()], "b d c -> b f d c", f=video_length)  # d is the number of tokens kept by the mask
print(masked_key.shape)
        '''for efficient cross-frame attention'''
        if self.controller and self.controller.use_cfattn and (not crossattn):
            video_length = key.size()[0] // self.unet_chunk_size  # BC // C = B
            former_frame_index = [0] * video_length  # previous frame index
            attn_mask = None  # get attn_mask from controller
            if self.controller.attn_mask is not None:
                for m in self.controller.attn_mask:
                    if m.shape[1] == key.shape[1]:
                        attn_mask = m
            # get key and value in former_frame_index range
            # BC * HW * 8D --> B * C * HW * 8D
            key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
            # B * C * HW * 8D --> B * C * HW * 8D
            if attn_mask is None:
                key = key[:, former_frame_index]
            else:
                key = repeat(key[:, attn_mask], "b d c -> b f d c", f=video_length)
            # B * C * HW * 8D --> BC * HW * 8D 
            key = rearrange(key, "b f d c -> (b f) d c").detach()
            value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
            if attn_mask is None:
                value = value[:, former_frame_index]
            else:
                value = repeat(value[:, attn_mask], "b d c -> b f d c", f=video_length)              
            value = rearrange(value, "b f d c -> (b f) d c").detach()
        
        # BC * HW * 8D --> BC * HW * 8 * D --> BC * 8 * HW * D
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # BC * 8 * HW2 * D
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # BC * 8 * HW2 * D2
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        # output: BC * 8 * HW * D2      
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
  • temporal-guided inter-frame attention (FLATTEN): bwd_mapping, fwd_mapping and interattn_mask are read from the controller's interattn_paras.
        '''for temporal-guided inter-frame attention (FLATTEN)'''
        if self.controller and self.controller.use_interattn and (not crossattn):
            del query, key, value
            torch.cuda.empty_cache()
            bwd_mapping, fwd_mapping, flattn_mask = None, None, None
            for i, f in enumerate(self.controller.interattn_paras['fwd_mappings']):
                if f.shape[2] == hidden_states.shape[2]:
                    fwd_mapping = f
                    bwd_mapping = self.controller.interattn_paras['bwd_mappings'][i]
                    interattn_mask = self.controller.interattn_paras['interattn_masks'][i]
            video_length = key_raw.size()[0] // self.unet_chunk_size
            # BC * HW * 8D --> C * 8BD * HW
            key = rearrange(key_raw, "(b f) d c -> f (b c) d", f=video_length)
            query = rearrange(query_raw, "(b f) d c -> f (b c) d", f=video_length)
            # BC * 8 * HW * D --> C * 8BD * HW
            #key = rearrange(hidden_states, "(b f) h d c -> f (b h c) d", f=video_length) ########
            #query = rearrange(hidden_states, "(b f) h d c -> f (b h c) d", f=video_length) #######
            
            value = rearrange(hidden_states, "(b f) h d c -> f (b h c) d", f=video_length)
            key = torch.gather(key, 2, fwd_mapping.expand(-1,key.shape[1],-1))
            query = torch.gather(query, 2, fwd_mapping.expand(-1,query.shape[1],-1))
            value = torch.gather(value, 2, fwd_mapping.expand(-1,value.shape[1],-1))
            # C * 8BD * HW --> BHW, C, 8D
            key = rearrange(key, "f (b c) d -> (b d) f c", b=self.unet_chunk_size)
            query = rearrange(query, "f (b c) d -> (b d) f c", b=self.unet_chunk_size)
            value = rearrange(value, "f (b c) d -> (b d) f c", b=self.unet_chunk_size) 
            # BHW * C * 8D --> BHW * C * 8 * D--> BHW * 8 * C * D
            query = query.view(-1, video_length, attn.heads, head_dim).transpose(1, 2).detach()
            key = key.view(-1, video_length, attn.heads, head_dim).transpose(1, 2).detach()
            value = value.view(-1, video_length, attn.heads, head_dim).transpose(1, 2).detach()
            hidden_states_ = F.scaled_dot_product_attention(
                query, key * self.controller.interattn_scale_factor, value, 
                attn_mask = (interattn_mask.repeat(self.unet_chunk_size,1,1,1))#.to(query.dtype)-1.0) * 1e6 -
                #torch.eye(interattn_mask.shape[2]).to(query.device).to(query.dtype) * 1e4,
            )
                
            # BHW * 8 * C * D --> C * 8BD * HW
            hidden_states_ = rearrange(hidden_states_, "(b d) h f c -> f (b h c) d", b=self.unet_chunk_size)
            hidden_states_ = torch.gather(hidden_states_, 2, bwd_mapping.expand(-1,hidden_states_.shape[1],-1)).detach()
            # C * 8BD * HW --> BC * 8 * HW * D
            hidden_states = rearrange(hidden_states_, "f (b h c) d -> (b f) h d c", b=self.unet_chunk_size, h=attn.heads)
            #print('inter: ', GPU.getGPUs()[1].memoryUsed)
        # BC * 8 * HW * D --> BC * HW * 8D
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
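
In the temporal-guided branch, tokens are reordered with torch.gather and fwd_mapping before attention, and the result is reordered back with bwd_mapping. The plain-Python sketch below shows that gather / inverse-gather round trip on one row. It assumes that bwd_mapping is the inverse permutation of fwd_mapping; get_mapping_ind is not shown in this post, so treat that as an assumption, and the helper names are hypothetical.

```python
def gather(row, index):
    """1-D analogue of torch.gather: out[i] = row[index[i]]."""
    return [row[i] for i in index]


def invert(index):
    """Inverse permutation: inv[index[pos]] = pos."""
    inv = [0] * len(index)
    for pos, src in enumerate(index):
        inv[src] = pos
    return inv


tokens = ["a", "b", "c", "d"]           # tokens of one frame in raster order
fwd_mapping = [2, 0, 3, 1]              # illustrative reordering along flow trajectories
bwd_mapping = invert(fwd_mapping)

aligned = gather(tokens, fwd_mapping)   # order used while attending across frames
restored = gather(aligned, bwd_mapping) # back to raster order
print(aligned)   # ['c', 'a', 'd', 'b']
print(restored)  # ['a', 'b', 'c', 'd']
```

After the forward gather, the same position in every frame is meant to refer to the same trajectory, so attention over the frame axis at that position mixes corresponding pixels.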

2. get_flow_and_interframe_paras

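Uses the GMFlow optical-flow model to compute the inter-frame parameters: optical_flow and occlusion_mask, plus the pixel-index correspondence (pixels on the same optical-flow trajectory) needed by temporal-guided attention (FLATTEN).

  1. Data preprocessing

Each input numpy image is converted to a PyTorch tensor with its dimensions reordered from HWC to CHW. All images are stacked into one batch, which is moved to the GPU.
A second tensor, imgs_torch, is built by converting each image in imgs with numpy2tensor and concatenating the results. It is mainly used for the later visualization (it is also passed to get_mapping_ind).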

def get_flow_and_interframe_paras(flow_model, imgs, visualize_pipeline=False):
    """
    Get parameters for temporal-guided attention and optimization
    * predict optical_flow and occlusion_mask
    * compute pixel index correspondence for FLATTEN
    """
    images = torch.stack([torch.from_numpy(img).permute(2, 0, 1).float() for img in imgs], dim=0).cuda()
    imgs_torch = torch.cat([numpy2tensor(img) for img in imgs], dim=0)  # torch.Size([batch=8, c=3, h=512, w=640])
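  2. Optical flow computation and occlusion mask generation

A reshuffle_list is built that holds all frame indices except the first, in order, followed by the first index. It is used to compute the optical flow between the original frames and the reshuffled (shifted-by-one) frames.
flow_model computes the flow between the two sequences and returns the dictionary results_dict. The last entry of its flow_preds list, flow_pr, is split into the forward flow fwd_flows and the backward flow bwd_flows.
forward_backward_consistency_check then derives the forward occlusion mask fwd_occs and the backward occlusion mask bwd_occs from the forward and backward flows.

    reshuffle_list = list(range(1,len(images)))+[0]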
    # compute flow between images and reshuffle_images
    results_dict = flow_model(images, images[reshuffle_list], attn_splits_list=[2], 
                              corr_radius_list=[-1], prop_radius_list=[-1], pred_bidir_flow=True)
    flow_pr = results_dict['flow_preds'][-1]  # [2*B, 2, H, W]  torch.Size([16, 2, 512, 640])
    fwd_flows, bwd_flows = flow_pr.chunk(2)   # [B, 2, H, W] torch.Size([8, 2, 512, 640]), torch.Size([8, 2, 512, 640])
    fwd_occs, bwd_occs = forward_backward_consistency_check(fwd_flows, bwd_flows) # [B, H, W] torch.Size([8, 512, 640]), torch.Size([8, 512, 640])
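
The pairing produced by reshuffle_list is easy to check in isolation. The helper name frame_pairs below is an illustrative assumption.

```python
def frame_pairs(num_frames):
    """Return (source, target) index pairs as produced by reshuffle_list."""
    reshuffle_list = list(range(1, num_frames)) + [0]
    return list(zip(range(num_frames), reshuffle_list))


pairs = frame_pairs(4)
print(pairs)  # [(0, 1), (1, 2), (2, 3), (3, 0)]
```

Every frame is paired with its successor, and the last frame wraps around to the first; that wrap-around pair is the one dropped later by bwd_occs[:-1].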
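  3. Occlusion mask refinement

The frames are warped with the backward flow to obtain warped_image1.
bwd_occs is updated: wherever the channel-averaged absolute difference between the reshuffled frames and the warped frames exceeds the threshold (255*0.25), the pixel is also marked as occluded.
Likewise, the reshuffled frames are warped with the forward flow to obtain warped_image2, which is compared with the original frames to update fwd_occs.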

    warped_image1 = flow_warp(images, bwd_flows)  # torch.Size([8, 3, 512, 640])
    bwd_occs = torch.clamp(bwd_occs + (abs(images[reshuffle_list]-warped_image1).mean(dim=1)>255*0.25).float(), 0 ,1)  # torch.Size([8, 512, 640])
    warped_image2 = flow_warp(images[reshuffle_list], fwd_flows)  # torch.Size([8, 3, 512, 640])
    fwd_occs = torch.clamp(fwd_occs + (abs(images-warped_image2).mean(dim=1)>255*0.25).float(), 0 ,1)  # torch.Size([8, 512, 640])
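  4. Visualization (optional)

If visualize_pipeline is True, a message is printed and the frames are shown with the occluded regions masked out. torchvision.utils.make_grid builds the image grid and visualize displays it.

    if visualize_pipeline: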
        print('visualized occlusion masks based on optical flows')
        viz = torchvision.utils.make_grid(imgs_torch * (1-fwd_occs.unsqueeze(1)), len(images), 1)
        visualize(viz.cpu(), 90)
        viz = torchvision.utils.make_grid(imgs_torch[reshuffle_list] * (1-bwd_occs.unsqueeze(1)), len(images), 1)
        visualize(viz.cpu(), 90) 
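  5. Mask computation for cross-frame attention

For each scale factor (8.0, 16.0, 32.0), the backward occlusion mask is downsampled and appended to the attn_mask list. For a 512x640 input these give 64x80, 32x40 and 16x20 token grids, i.e. the attention resolutions of up_blocks.3, up_blocks.2 and up_blocks.1 respectively. In each mask the first row (first frame) is all True, and each later row is True where the downsampled occlusion value exceeds 0.5, so the mask marks which tokens of each frame are kept as keys/values in cross-frame attention.

    attn_mask = []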
    for scale in [8.0, 16.0, 32.0]:
        bwd_occs_ = F.interpolate(bwd_occs[:-1].unsqueeze(1), scale_factor=1./scale, mode='bilinear')  # torch.Size([7, 1, 64, 80]), torch.Size([7, 1, 32, 40]), torch.Size([7, 1, 16, 20])
        attn_mask += [torch.cat((bwd_occs_[0:1].reshape(1,-1)>-1, bwd_occs_.reshape(bwd_occs_.shape[0],-1)>0.5), dim=0)]   # torch.Size([8, 5120]), torch.Size([8, 1280]), torch.Size([8, 320])
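
The way one attn_mask entry is assembled can be sketched with nested lists. This is a simplified illustration: build_attn_mask is a hypothetical helper, each occlusion map is already flattened and downsampled, and the values are made up.

```python
def build_attn_mask(bwd_occs, threshold=0.5):
    """bwd_occs: one flattened occlusion map per frame pair, values in [0, 1]."""
    usable = bwd_occs[:-1]                    # drop the wrap-around pair (last -> first)
    first_row = [True] * len(usable[0])       # every token of the first frame is kept
    return [first_row] + [[v > threshold for v in row] for row in usable]


# 3 frames, 4 tokens each (illustrative values).
bwd_occs = [
    [0.0, 0.9, 0.1, 0.6],   # occlusion of frame 1 w.r.t. frame 0
    [0.7, 0.0, 0.0, 0.2],   # occlusion of frame 2 w.r.t. frame 1
    [1.0, 1.0, 1.0, 1.0],   # wrap-around pair, ignored
]
mask = build_attn_mask(bwd_occs)
for row in mask:
    print(row)
```

The result has one row per frame: the first frame contributes all of its tokens, and each later frame contributes only the tokens flagged as occluded, which keeps the number of keys small.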
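  6. Pixel mapping along optical-flow trajectories and masks for temporal-guided attention

For the selected scale factors (8.0, 16.0), which correspond to up_blocks.3 and up_blocks.2 respectively, get_mapping_ind computes the forward (fwd_mappings) and backward (bwd_mappings) pixel mappings and the inter-frame attention masks (interattn_masks).

    fwd_mappings = []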
    bwd_mappings = []
    interattn_masks = []
    for scale in [8.0, 16.0]:
        fwd_mapping, bwd_mapping, interattn_mask = get_mapping_ind(bwd_flows, bwd_occs, imgs_torch, scale=scale)
        fwd_mappings += [fwd_mapping]  # torch.Size([8, 1, 5120]), torch.Size([8, 1, 1280])
        bwd_mappings += [bwd_mapping]  # torch.Size([8, 1, 5120]), torch.Size([8, 1, 1280])
        interattn_masks += [interattn_mask]  # torch.Size([5120, 1, 8, 8]), torch.Size([1280, 1, 8, 8])
    interattn_paras = {}
    interattn_paras['fwd_mappings'] = fwd_mappings
    interattn_paras['bwd_mappings'] = bwd_mappings
    interattn_paras['interattn_masks'] = interattn_masks    

    gc.collect()
    torch.cuda.empty_cache()
    
    return [fwd_flows, bwd_flows], [fwd_occs, bwd_occs], attn_mask, interattn_paras

3. get_intraframe_paras

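Here we only look at its effect on spatial-guided attention: the controller is first fully disabled and its store cleared; storing is then switched on so that the features can be collected, and switched off again afterwards.

    frescoProc.controller.disable_controller()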
    frescoProc.controller.clear_store()
    frescoProc.controller.enable_store()
    frescoProc.controller.disable_store()

Then, in the FRESCO inference pipeline, all three attention types are switched on:

        # Turn on all FRESCO support
        frescoProc.controller.enable_controller(interattn_paras=interattn_paras, attn_mask=attn_mask)

4. inference

One call samples batch_size=8 key frames:

        latents = inference(pipe, controlnet, frescoProc, 
                  imgs_torch, prompt_embeds, edges, timesteps,
                  cond_scale, config['num_inference_steps'], config['num_warmup_steps'], 
                  do_classifier_free_guidance, config['seed'], guidance_scale, config['use_controlnet'],         
                  record_latents, propagation_mode,
                  flows = flows, occs = occs, saliency=saliency, repeat_noise=True)
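  • Run garbage collection, clear the CUDA cache, and set up the device and the random generator;
  • Prepare the initial latents from the input arguments;
  • Shared noise (repeat_noise): if repeat_noise is set, the first element of latents is repeated B times along the batch dimension, so all frames start from the same noise;
    if repeat_noise:  # True  
        latents = latents[0:1].repeat(B,1,1,1).detach()  # repeat the first element of latents B times along the batch dimension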
        
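  • SDEdit: if num_warmup_steps is less than 0, the pure-noise latents are used as the starting point and num_warmup_steps is set to 0; otherwise the noised latent of the input images is used as the starting point;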
    if num_warmup_steps < 0:  # num_warmup_steps: SDEdit add noise step = (num_inference_steps-num_warmup_steps)
        latents_init = latents.detach()
        num_warmup_steps = 0
    else:
        # SDEdit: use the noisy latent of images as the input rather than a pure gaussian noise
        latent_x0 = pipe.vae.config.scaling_factor * pipe.vae.encode(imgs.to(pipe.unet.dtype)).latent_dist.sample()
        latents_init = noise_scheduler.add_noise(latent_x0, latents, timesteps[num_warmup_steps]).detach()
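  • Run num_inference_steps-num_warmup_steps steps under the pipe's progress bar;
  • At each step, conditionally apply spatial/temporal-guided attention, record and restore the latents, and expand the latents;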
# timesteps=[951, 901, 851, 801, 751, 701, 651, 601, 551, 501, 451, 401, 351, 301, 251, 201, 151, 101, 51, 1]
            if i >= num_intraattn_steps:  # spatial-guided attention is only used while i < num_intraattn_steps=1 (the first executed step)
                frescoProc.controller.disable_intraattn()
            if t < step_interattn_end:  # temporal-guided attention is only used while t >= step_interattn_end=350 (disabled for the last 7/20 timesteps)
                frescoProc.controller.disable_interattn()
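
Which steps actually use each attention type follows directly from the two conditions. The sketch below replays them on the timestep list from the comment; attention_schedule is a hypothetical helper, and it assumes both attention types are enabled when sampling starts.

```python
def attention_schedule(timesteps, num_warmup_steps=6,
                       num_intraattn_steps=1, step_interattn_end=350):
    """Return the timesteps at which spatial- and temporal-guided attention are active."""
    use_intra, use_inter = True, True
    intra_steps, inter_steps = [], []
    for i, t in enumerate(timesteps[num_warmup_steps:]):
        if i >= num_intraattn_steps:
            use_intra = False          # mirrors disable_intraattn()
        if t < step_interattn_end:
            use_inter = False          # mirrors disable_interattn()
        if use_intra:
            intra_steps.append(t)
        if use_inter:
            inter_steps.append(t)
    return intra_steps, inter_steps


timesteps = list(range(951, 0, -50))   # 951, 901, ..., 51, 1
intra_steps, inter_steps = attention_schedule(timesteps)
print(intra_steps)  # [651]
print(inter_steps)  # [651, 601, 551, 501, 451, 401, 351]
```

With the default num_warmup_steps=6, sampling starts at t=651, so spatial-guided attention is applied once and temporal-guided attention for the 7 steps down to t=351.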

Changes that inference makes to record_latents modify the caller's list in place. For the first batch of 8 frames, propagation_mode=False: at every timestep the latents of the first and last frame of the current batch are appended to record_latents. For later batches, propagation_mode=True: at every timestep the first 2 frames of the current batch are replaced with the entry recorded by the previous batch, and that entry is then overwritten with the first and last frame of the current batch:

            if propagation_mode: # restore latent from previous batch and record latent of the current batch
                latents[0:2] = record_latents[i].detach().clone()
                record_latents[i] = latents[[0,len(latents)-1]].detach().clone()
            else: # first batch, record_latents[0][t] = [x_{1,t}, x_{N,t}] = [first_frame, last_frame]
                record_latents += [latents[[0,len(latents)-1]].detach().clone()]  # add torch.Size([2, 4, 64, 80])
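  • If ControlNet is used, its outputs are passed to the unet as additional inputs;
  • Predict the noise residual;
  • If classifier-free guidance is enabled, apply the guidance to the noise residual;
  • Depending on whether the current step is a background-smoothing step, call a different step function to compute the previous-step sample;
  • Return the latents.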
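
The record/restore mechanism for record_latents can be simulated with strings instead of tensors. This is a toy sketch: run_batch is a hypothetical function, and appending ">" stands in for one denoising step.

```python
def run_batch(frames, num_steps, record_latents, propagation_mode):
    """Toy version of the per-step record/restore logic; mutates record_latents in place."""
    latents = list(frames)
    for i in range(num_steps):
        if propagation_mode:
            # Restore the two carried-over frames, then record for the next batch.
            latents[0:2] = list(record_latents[i])
            record_latents[i] = (latents[0], latents[-1])
        else:
            # First batch: remember its first and last frame at this step.
            record_latents.append((latents[0], latents[-1]))
        latents = [x + ">" for x in latents]   # stand-in for one denoising step
    return latents


record = []
out1 = run_batch(["f0", "f1", "f2", "f3"], 2, record, propagation_mode=False)
# The first two slots of a later batch are placeholders that get overwritten.
out2 = run_batch(["p0", "p1", "f4", "f5"], 2, record, propagation_mode=True)
print(out2)    # ['f0>>', 'f3>>', 'f4>>', 'f5>>']
print(record)  # [('f0', 'f5'), ('f0>', 'f5>')]
```

The first frame of the whole video is carried through every batch as an anchor, while the second slot always holds the last frame of the preceding batch.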
@torch.no_grad()
def inference(pipe, controlnet, frescoProc, 
              imgs, prompt_embeds, edges, timesteps,
              cond_scale=[0.7]*20, num_inference_steps=20, num_warmup_steps=6, 
              do_classifier_free_guidance=True, seed=0, guidance_scale=7.5, use_controlnet=True,         
              record_latents=[], propagation_mode=False, visualize_pipeline=False, 
              flows = None, occs = None, saliency=None, repeat_noise=False,
              num_intraattn_steps = 1, step_interattn_end = 350, bg_smoothing_steps = [16,17]):
    """
    video-to-video translation inference pipeline with FRESCO
    * add controlnet and SDEdit
    * add FRESCO-guided attention
    * add FRESCO-guided optimization
    * add background smoothing
    * add support for inter-batch long video translation
    
    [input of the original pipe]
    pipe: base diffusion model
    imgs: a batch of the input frames
    prompt_embeds: prompts
    num_inference_steps: number of DDPM steps 
    timesteps: generated by pipe.scheduler.set_timesteps(num_inference_steps)
    do_classifier_free_guidance: cfg, should be always true
    guidance_scale: cfg scale
    seed

    [input of SDEdit]
    num_warmup_steps: skip the first num_warmup_steps DDPM steps

    [input of controlnet]
    use_controlnet: bool, whether using controlnet
    controlnet: controlnet model
    edges: input for controlnet (edge/stroke/depth, etc.)
    cond_scale: controlnet scale

    [input of FRESCO]
    frescoProc: FRESCO attention controller 
    flows: optical flows 
    occs: occlusion mask
    num_intraattn_steps: apply num_interattn_steps steps of spatial-guided attention
    step_interattn_end: apply temporal-guided attention in [step_interattn_end, 1000] steps

    [input for background smoothing]
    saliency: saliency mask
    repeat_noise: bool, use the same noise for all frames
    bg_smoothing_steps: apply background smoothing in bg_smoothing_steps

    [input for long video translation]
    record_latents: recorded latents in the last batch
    propagation_mode: bool, whether this is not the first batch
    
    [output]
    latents: a batch of latents of the translated frames 
    """
    gc.collect()
    torch.cuda.empty_cache()

    device = pipe._execution_device
    noise_scheduler = pipe.scheduler 
    generator = torch.Generator(device=device).manual_seed(seed)
    B, C, H, W = imgs.shape  # torch.Size([8, 3, 512, 640])
    latents = pipe.prepare_latents(  # torch.Size([8, 4, 64, 80])
        B,
        pipe.unet.config.in_channels,
        H,
        W,
        prompt_embeds.dtype,
        device,
        generator,
        latents = None,
    )   
    
    if repeat_noise:  # True  
        latents = latents[0:1].repeat(B,1,1,1).detach()  # repeat the first latent B times along the batch dimension so every frame shares the same noise
        
    if num_warmup_steps < 0:  # num_warmup_steps: SDEdit add noise step = (num_inference_steps-num_warmup_steps)
        latents_init = latents.detach()
        num_warmup_steps = 0
    else:
        # SDEdit: use the noisy latent of the images as the input rather than pure Gaussian noise
        latent_x0 = pipe.vae.config.scaling_factor * pipe.vae.encode(imgs.to(pipe.unet.dtype)).latent_dist.sample()
        latents_init = noise_scheduler.add_noise(latent_x0, latents, timesteps[num_warmup_steps]).detach()

    # SDEdit, run num_inference_steps-num_warmup_steps steps
    with pipe.progress_bar(total=num_inference_steps-num_warmup_steps) as progress_bar:
        latents = latents_init  # torch.Size([8, 4, 64, 80])
        for i, t in enumerate(timesteps[num_warmup_steps:]):
            """
            [HACK] control the steps to apply spatial/temporal-guided attention
            [HACK] record and restore latents from previous batch
            """
            # timesteps=[951, 901, 851, 801, 751, 701, 651, 601, 551, 501, 451, 401, 351, 301, 251, 201, 151, 101, 51, 1]
            if i >= num_intraattn_steps:  # spatial-guided attention only runs while i < num_intraattn_steps (=1), i.e. the first of 20 steps
                frescoProc.controller.disable_intraattn()
            if t < step_interattn_end:  # temporal-guided attention only runs while t >= step_interattn_end (=350); it is disabled for the last 7 of 20 steps
                frescoProc.controller.disable_interattn()

            if propagation_mode: # restore latent from previous batch and record latent of the current batch
                latents[0:2] = record_latents[i].detach().clone()
                record_latents[i] = latents[[0,len(latents)-1]].detach().clone()
            else: # first batch: record_latents[i] = [x_{1,t}, x_{N,t}], the latents of the first and last frame at step i
                record_latents += [latents[[0,len(latents)-1]].detach().clone()]  # add torch.Size([2, 4, 64, 80])
            
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            
            if use_controlnet:
                control_model_input = latent_model_input
                controlnet_prompt_embeds = prompt_embeds

                down_block_res_samples, mid_block_res_sample = controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=edges,
                    conditioning_scale=cond_scale[i+num_warmup_steps],
                    guess_mode=False,
                    return_dict=False,
                )
            else:
                down_block_res_samples, mid_block_res_sample = None, None 
            
            # predict the noise residual
            noise_pred = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=None,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
                return_dict=False,
            )[0]
            
            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            """
            [HACK] background smoothing
            Note: bg_smoothing_steps should be rescaled based on num_inference_steps
            current [16,17] is based on num_inference_steps=20
            """
            if i + num_warmup_steps in bg_smoothing_steps:
                latents = step(pipe, noise_pred, t, latents, generator, 
                               visualize_pipeline=visualize_pipeline, 
                               flows = flows, occs = occs, saliency=saliency)[0]  
            else:
                latents = step(pipe, noise_pred, t, latents, generator, 
                           visualize_pipeline=visualize_pipeline)[0]                            

            # update the progress bar
            if i == len(timesteps) - 1 or ((i + 1) > 0 and (i + 1) % pipe.scheduler.order == 0):
                progress_bar.update()
                
    return latents
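
To make the two attention switches in the loop easier to follow, here is a minimal sketch that replays only the gating logic. `attention_schedule` is a hypothetical helper, not part of FRESCO. It assumes both attention modes are enabled when the loop starts and that a disabled mode stays off for the rest of the loop, which is how the `disable_intraattn()` / `disable_interattn()` calls above appear to behave.

```python
def attention_schedule(timesteps, num_intraattn_steps=1, step_interattn_end=350):
    """Return (step index, timestep, intra_on, inter_on) for every denoising step.

    Hypothetical helper for illustration; it mirrors the two `if` checks
    at the top of the denoising loop and nothing else.
    """
    intra_on, inter_on = True, True  # assumption: both modes start enabled
    schedule = []
    for i, t in enumerate(timesteps):
        if i >= num_intraattn_steps:   # spatial-guided attention is switched off
            intra_on = False
        if t < step_interattn_end:     # temporal-guided attention is switched off
            inter_on = False
        schedule.append((i, t, intra_on, inter_on))
    return schedule


# The 20-step timestep list quoted in the comment above: 951, 901, ..., 1
timesteps = list(range(951, 0, -50))
schedule = attention_schedule(timesteps)
for i, t, intra_on, inter_on in schedule:
    print(f"step {i:2d}  t={t:3d}  intra={intra_on}  inter={inter_on}")
```

With the default arguments, spatial-guided attention is active only at step 0, while temporal-guided attention stays active through t=351 and is off from t=301 onward.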

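The `record_latents` bookkeeping can be illustrated with a toy loop that uses tuples in place of tensors. This is only a sketch: `toy_denoise` and the frame ids `k0`...`k5` are made up for illustration, and it assumes the caller fills the first two slots of every later batch with the frames whose latents were recorded by the previous batch.

```python
def toy_denoise(frame_ids, num_steps, record_latents, propagation_mode):
    """Toy version of the record/restore logic; not the real pipeline.

    Each toy "latent" is (frame_id, remaining_noise); real latents are tensors.
    """
    latents = [(fid, num_steps) for fid in frame_ids]
    for i in range(num_steps):
        if propagation_mode:
            # Overwrite the two leading slots with what the previous batch recorded at step i
            latents[0:2] = list(record_latents[i])
            # Then record [restored first frame, last frame of this batch] for the next batch
            record_latents[i] = [latents[0], latents[-1]]
        else:
            # First batch: record the first and last frame at every step
            record_latents.append([latents[0], latents[-1]])
        # Stand-in for one UNet call plus one scheduler step
        latents = [(fid, noise - 1) for fid, noise in latents]
    return latents


record = []
out1 = toy_denoise(["k0", "k1", "k2", "k3"], 3, record, propagation_mode=False)
out2 = toy_denoise(["k0", "k3", "k4", "k5"], 3, record, propagation_mode=True)
print(record)
```

Under these assumptions, slot 0 of `record_latents[i]` always holds the first frame of the whole video, and slot 1 is replaced by the last frame of the most recent batch. That is how each new batch sees both a fixed anchor frame and its immediate predecessor at the matching noise level.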

Reposted from blog.csdn.net/weixin_54338498/article/details/137429124