This post analyzes how FRESCO stores the attention features of earlier keyframes and reuses them when sampling later keyframes. (The source code of the FRESCO-based feature optimization is not covered here.)
FRESCO Inference Pipeline
In the get_models function, after Stable Diffusion is loaded:
- apply_freeu sets up FreeU for the UNet.
- apply_FRESCO_attn creates a FRESCOAttnProcessor2_0, assigns it as the AttnProcessor of the UNet's up_blocks.2 and up_blocks.3, and assigns AttnProcessor2_0 to every other block. The returned frescoProc is used to switch FRESCO attention on and off inside the UNet.
- apply_FRESCO_opt sets up the UNet feature optimization.
if config['use_freeu']:
from src.free_lunch_utils import apply_freeu
apply_freeu(pipe, b1=1.2, b2=1.5, s1=1.0, s2=1.0)
frescoProc = apply_FRESCO_attn(pipe)
frescoProc.controller.disable_controller()
apply_FRESCO_opt(pipe)
print('create diffusion model ' + config['sd_path'] + ' successfully!')
The key frames are split into multiple batches, and each batch is denoised together. Note that the first batch contains batch_size frames, while every later batch contains only batch_size-2 frames; the 2 missing slots are used to carry information from the previous batch into the current one (see record_latents in the inference code, and the worked example after the snippet below).
sublists = [keys[i:i+config['batch_size']-2] for i in range(2, len(keys), config['batch_size']-2)]
sublists[0].insert(0, keys[0])
sublists[0].insert(1, keys[1])
if len(sublists) > 1 and len(sublists[-1]) < 3: # if the last sublist has fewer than 3 frames, borrow frames from the previous sublist so that it contains 3
add_num = 3 - len(sublists[-1])
sublists[-1] = sublists[-2][-add_num:] + sublists[-1]
sublists[-2] = sublists[-2][:-add_num]
if not sublists[-2]:
del sublists[-2]
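A small worked example, assuming a hypothetical batch_size=8 and 14 key frames:
keys = list(range(14))
batch_size = 8
sublists = [keys[i:i+batch_size-2] for i in range(2, len(keys), batch_size-2)]
sublists[0].insert(0, keys[0])
sublists[0].insert(1, keys[1])
print(sublists)  # [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13]]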
Before inference starts, get_flow_and_interframe_paras computes the inter-frame optical-flow correspondences, and get_intraframe_paras records the intra-frame attention features of the input frames (see Section 3 below):
# prepare parameters for inter-frame and intra-frame consistency
flows, occs, attn_mask, interattn_paras = get_flow_and_interframe_paras(flow_model, imgs)
correlation_matrix = get_intraframe_paras(pipe, imgs_torch, frescoProc,
prompt_embeds, seed = config['seed'])
'''
Flexible settings for attention:
* Turn off FRESCO-guided attention: frescoProc.controller.disable_controller()
Then you can turn on one specific attention submodule
* Turn on Cross-frame attention: frescoProc.controller.enable_cfattn(attn_mask)
* Turn on Spatial-guided attention: frescoProc.controller.enable_intraattn()
* Turn on Temporal-guided attention: frescoProc.controller.enable_interattn(interattn_paras)
Flexible settings for optimization:
* Turn off Spatial-guided optimization: set optimize_temporal = False in apply_FRESCO_opt()
* Turn off Temporal-guided optimization: set correlation_matrix = [] in apply_FRESCO_opt()
* Turn off FRESCO-guided optimization: disable_FRESCO_opt(pipe)
Flexible settings for background smoothing:
* Turn off background smoothing: set saliency = None in apply_FRESCO_opt()
'''
# Turn on all FRESCO support
frescoProc.controller.enable_controller(interattn_paras=interattn_paras, attn_mask=attn_mask)
inference then runs the FRESCO sampling:
# run!
latents = inference(pipe, controlnet, frescoProc,
imgs_torch, prompt_embeds, edges, timesteps,
cond_scale, config['num_inference_steps'], config['num_warmup_steps'],
do_classifier_free_guidance, config['seed'], guidance_scale, config['use_controlnet'],
record_latents, propagation_mode,
flows = flows, occs = occs, saliency=saliency, repeat_noise=True)
1. apply_FRESCO_attn
PART I - FRESCO-based attention
* Class AttentionControl: Control the function of FRESCO-based attention
* Class FRESCOAttnProcessor2_0: FRESCO-based attention
* apply_FRESCO_attn(): Apply FRESCO-based attention to a StableDiffusionPipeline
The apply_FRESCO_attn function creates a FRESCOAttnProcessor2_0 object and an AttnProcessor2_0 object. The UNet has 3 down_blocks, 1 mid_block and 3 up_blocks that carry attention processors, but only up_blocks.2 and up_blocks.3 are switched to FRESCOAttnProcessor2_0; every other block keeps AttnProcessor2_0. The returned frescoProc is used to switch FRESCO attention on and off inside the UNet (a small usage check follows the code below).
def apply_FRESCO_attn(pipe):
"""
Apply FRESCO-guided attention to a StableDiffusionPipeline
"""
frescoProc = FRESCOAttnProcessor2_0(2, AttentionControl())
attnProc = AttnProcessor2_0()
attn_processor_dict = {
}
for k in pipe.unet.attn_processors.keys():
if k.startswith("up_blocks.2") or k.startswith("up_blocks.3"):
attn_processor_dict[k] = frescoProc
else:
attn_processor_dict[k] = attnProc
pipe.unet.set_attn_processor(attn_processor_dict)
return frescoProc
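A quick sanity check, as a sketch assuming pipe is an already-loaded StableDiffusionPipeline: only the up_blocks.2/up_blocks.3 processors are replaced, and FRESCO attention starts disabled, as in get_models.
frescoProc = apply_FRESCO_attn(pipe)
fresco_keys = [k for k, v in pipe.unet.attn_processors.items()
               if isinstance(v, FRESCOAttnProcessor2_0)]
print(fresco_keys)  # only entries under up_blocks.2 / up_blocks.3
frescoProc.controller.disable_controller()  # keep FRESCO attention off until configured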
AttentionControl
The AttentionControl class stores the self-attention features of the UNet decoder; the stored features live in stored_attn. In the naming below, intra-attention means attention within a frame (spatial attention) and inter-attention means attention between frames (temporal attention).
class AttentionControl():
"""
Control FRESCO-based attention
* enable/disable spatial-guided attention
* enable/disable temporal-guided attention
* enable/disable cross-frame attention
* collect intermediate attention feature (for spatial-guided attention)
"""
def __init__(self):
self.stored_attn = self.get_empty_store() # storage for attention features
self.store = False # enable_store/disable_store: whether to store attention features
self.index = 0 # replay index for spatial-guided attention (intra-frame)
self.attn_mask = None # cross-frame attention parameter: attn_mask
self.interattn_paras = None # temporal-guided attention (inter-frame) parameters: fwd_mappings, bwd_mappings, interattn_masks
self.use_interattn = False # enable_interattn/disable_interattn: whether to use temporal-guided attention (inter-frame)
self.use_cfattn = False # enable_cfattn/disable_cfattn: whether to use cross-frame attention
self.use_intraattn = False # enable_intraattn/disable_intraattn: whether to use spatial-guided attention (intra-frame)
self.intraattn_bias = 0
self.intraattn_scale_factor = 0.2
self.interattn_scale_factor = 0.2
@staticmethod
def get_empty_store():
return {
'decoder_attn': [],
}
def clear_store(self):
del self.stored_attn
torch.cuda.empty_cache()
gc.collect()
self.stored_attn = self.get_empty_store()
self.disable_intraattn()
# store attention feature of the input frame for spatial-guided attention
def enable_store(self):
self.store = True
def disable_store(self):
self.store = False
The following methods control whether each of the three kinds of attention is applied:
- spatial-guided attention
- temporal-guided attention
- cross-frame attention
# spatial-guided attention
def enable_intraattn(self):
self.index = 0
self.use_intraattn = True
self.disable_store()
if len(self.stored_attn['decoder_attn']) == 0:
self.use_intraattn = False
def disable_intraattn(self):
self.index = 0
self.use_intraattn = False
self.disable_store()
def disable_cfattn(self):
self.use_cfattn = False
# cross frame attention
def enable_cfattn(self, attn_mask=None):
if attn_mask:
if self.attn_mask:
del self.attn_mask
torch.cuda.empty_cache()
self.attn_mask = attn_mask
self.use_cfattn = True
else:
if self.attn_mask:
self.use_cfattn = True
else:
print('Warning: no valid cross-frame attention parameters available!')
self.disable_cfattn()
def disable_interattn(self):
self.use_interattn = False
# temporal-guided attention
def enable_interattn(self, interattn_paras=None):
if interattn_paras:
if self.interattn_paras:
del self.interattn_paras
torch.cuda.empty_cache()
self.interattn_paras = interattn_paras
self.use_interattn = True
else:
if self.interattn_paras:
self.use_interattn = True
else:
print('Warning: no valid temporal-guided attention parameters available!')
self.disable_interattn()
enable_controller and disable_controller toggle all three kinds of attention at once.
Calling the object (__call__) runs forward: if self.store is true, the incoming context (an attention feature) is appended to stored_attn['decoder_attn']. If self.use_intraattn is enabled and stored_attn['decoder_attn'] is non-empty, the stored feature at index self.index is returned instead, and the index is advanced (wrapping back to 0 and disabling storing once all stored features have been consumed). If no stored feature is available or self.use_intraattn is off, the input context is returned unchanged. A minimal usage sketch follows the code below.
def disable_controller(self):
self.disable_intraattn()
self.disable_interattn()
self.disable_cfattn()
def enable_controller(self, interattn_paras=None, attn_mask=None):
self.enable_intraattn()
self.enable_interattn(interattn_paras)
self.enable_cfattn(attn_mask)
def forward(self, context):
if self.store:
self.stored_attn['decoder_attn'].append(context.detach())
if self.use_intraattn and len(self.stored_attn['decoder_attn']) > 0:
tmp = self.stored_attn['decoder_attn'][self.index]
self.index = self.index + 1
if self.index >= len(self.stored_attn['decoder_attn']):
self.index = 0
self.disable_store()
return tmp
return context
def __call__(self, context):
context = self.forward(context)
return context
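A minimal sketch of the store/replay behaviour, with a hypothetical tensor shape and assuming the AttentionControl class above:
import torch
ctrl = AttentionControl()
ctrl.enable_store()
feat = torch.randn(2, 1280, 640)            # a decoder self-attn feature
ctrl(feat)                                  # appended to stored_attn['decoder_attn']
ctrl.enable_intraattn()                     # switch to replay mode (also stops storing)
replayed = ctrl(torch.randn(2, 1280, 640))  # the stored feature is returned instead of the input
assert torch.equal(replayed, feat)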
FRESCOAttnProcessor2_0
When constructing the FRESCOAttnProcessor2_0 class, a unet_chunk_size and an AttentionControl object are passed in. It replaces the UNet's self-attention with FRESCO-based attention (spatial-guided attention, temporal-guided attention, cross-frame attention).
class FRESCOAttnProcessor2_0:
"""
Hack self attention to FRESCO-based attention
* adding spatial-guided attention
* adding temporal-guided attention
* adding cross-frame attention
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
Usage
frescoProc = FRESCOAttnProcessor2_0(2, attn_mask)
attnProc = AttnProcessor2_0()
attn_processor_dict = {}
for k in pipe.unet.attn_processors.keys():
if k.startswith("up_blocks.2") or k.startswith("up_blocks.3"):
attn_processor_dict[k] = frescoProc
else:
attn_processor_dict[k] = attnProc
pipe.unet.set_attn_processor(attn_processor_dict)
"""
def __init__(self, unet_chunk_size=2, controller=None):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.unet_chunk_size = unet_chunk_size
self.controller = controller
__call__ applies spatial-guided attention, temporal-guided attention, and cross-frame attention in turn:
- Before running the current frame's self-attention, the current frame's self-attn feature is saved to the controller; query_raw and key_raw keep copies of the current frame's query and key.
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
crossattn = False # check whether this is cross-attn
if encoder_hidden_states is None: # self-attn
encoder_hidden_states = hidden_states
if self.controller and self.controller.store: # save self-attn-feature to controller
self.controller(hidden_states.detach().clone())
else: # cross-attn
crossattn = True
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# BC * HW * 8D
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# if self-attn and temporal-guided attention, save query_raw and key_raw
query_raw, key_raw = None, None
if self.controller and self.controller.use_interattn and (not crossattn):
query_raw, key_raw = query.clone(), key.clone()
inner_dim = key.shape[-1] # 8D
head_dim = inner_dim // attn.heads # D
- Spatial-guided intra-frame attention: ref_hidden_states, the UNet decoder features of the original frames in the current window recorded during a single DDPM forward/backward pass, are fetched from the controller. They provide the query and key of an attention operation whose value is the current query, i.e. the current query is re-aggregated using the spatial self-attention pattern of the original frame (a stripped-down sketch follows the code below).
'''for spatial-guided intra-frame attention'''
if self.controller and self.controller.use_intraattn and (not crossattn):
ref_hidden_states = self.controller(None)
assert ref_hidden_states.shape == encoder_hidden_states.shape
query_ = attn.to_q(ref_hidden_states)
key_ = attn.to_k(ref_hidden_states)
'''
# for xformers implementation
if importlib.util.find_spec("xformers") is not None:
# BC * HW * 8D --> BC * HW * 8 * D
query_ = rearrange(query_, "b d (h c) -> b d h c", h=attn.heads)
key_ = rearrange(key_, "b d (h c) -> b d h c", h=attn.heads)
# BC * 8 * HW * D --> 8BC * HW * D
query = rearrange(query, "b h d c -> b d h c")
query = xformers.ops.memory_efficient_attention(
query_, key_ * self.sattn_scale_factor, query,
attn_bias=torch.eye(query_.size(1), key_.size(1),
dtype=query.dtype, device=query.device) * self.bias_weight, op=None
)
query = rearrange(query, "b d h c -> b h d c").detach()
'''
# BC * 8 * HW * D
query_ = query_.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_ = key_.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
query = F.scaled_dot_product_attention(
query_, key_ * self.controller.intraattn_scale_factor, query,
attn_mask = torch.eye(query_.size(-2), key_.size(-2),
dtype=query.dtype, device=query.device) * self.controller.intraattn_bias,
).detach()
#print('intra: ', GPU.getGPUs()[1].memoryUsed)
del query_, key_
torch.cuda.empty_cache()
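Stripped of bookkeeping, the operation reads as follows; a sketch with hypothetical shapes, using intraattn_scale_factor=0.2 and intraattn_bias=0 as in the controller defaults:
import torch
import torch.nn.functional as F

B, heads, HW, D = 2, 8, 1280, 40
query  = torch.randn(B, heads, HW, D)  # current denoising query (acts as the value)
query_ = torch.randn(B, heads, HW, D)  # query projected from the stored reference feature
key_   = torch.randn(B, heads, HW, D)  # key projected from the stored reference feature
out = F.scaled_dot_product_attention(
    query_, key_ * 0.2, query,         # attention pattern comes from the reference feature
    attn_mask=torch.zeros(HW, HW),     # the eye-mask vanishes because intraattn_bias = 0
)
print(out.shape)  # (2, 8, 1280, 40): the query re-mixed by the reference attention pattern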
- Efficient cross-frame attention: the flow-based attn_mask of shape [video_len, seq_len] is taken from the AttentionControl. The masked positions of key are gathered across frames and repeated video_length times, so that the query of each frame can attend to the masked keys of all frames; value is treated the same way as key. This yields an efficient cross-frame attention. The query here is the output of the spatial-guided intra-frame attention above.
import torch
from einops import repeat
batch_size = 2 # unet_chunk_size for CFG
video_length = 8
sequence_length = 1280
hidden_size = 640
key = torch.randn(batch_size, video_length, sequence_length, hidden_size)
# random True/False matrix of shape (video_length, sequence_length), marking which positions are selected by the mask
attn_mask = torch.randint(0, 2, (video_length, sequence_length))
# gather the keys at the masked positions of every frame and repeat them for each of the f frames, so that every frame's query can attend to the masked keys of all frames
# implementation 1:
masked_key = key[:, attn_mask.bool()].unsqueeze(1).repeat(1, video_length, 1, 1)
print(masked_key.shape)
# implementation 2:
masked_key = repeat(key[:, attn_mask.bool()], "b d c -> b f d c", f=video_length) # d is the number of tokens kept by the mask
print(masked_key.shape)
'''for efficient cross-frame attention'''
if self.controller and self.controller.use_cfattn and (not crossattn):
video_length = key.size()[0] // self.unet_chunk_size # BC // C = B
former_frame_index = [0] * video_length # fall back to attending to the first frame when no mask is given
attn_mask = None # get attn_mask from controller
if self.controller.attn_mask is not None:
for m in self.controller.attn_mask:
if m.shape[1] == key.shape[1]:
attn_mask = m
# get key and value in former_frame_index range
# BC * HW * 8D --> B * C * HW * 8D
key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
# B * C * HW * 8D --> B * C * HW * 8D
if attn_mask is None:
key = key[:, former_frame_index]
else:
key = repeat(key[:, attn_mask], "b d c -> b f d c", f=video_length)
# B * C * HW * 8D --> BC * HW * 8D
key = rearrange(key, "b f d c -> (b f) d c").detach()
value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
if attn_mask is None:
value = value[:, former_frame_index]
else:
value = repeat(value[:, attn_mask], "b d c -> b f d c", f=video_length)
value = rearrange(value, "b f d c -> (b f) d c").detach()
# BC * HW * 8D --> BC * HW * 8 * D --> BC * 8 * HW * D
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# BC * 8 * HW2 * D
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# BC * 8 * HW2 * D2
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
# output: BC * 8 * HW * D2
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
- Temporal-guided inter-frame attention (FLATTEN): bwd_mapping, fwd_mapping and interattn_mask are taken from the controller's interattn_paras. query/key/value are gathered with fwd_mapping so that pixels lying on the same optical-flow trajectory end up at the same index, attention is then computed along the frame axis, and the result is scattered back to pixel order with bwd_mapping. A toy illustration of the gather step is shown below, followed by the actual code.
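Toy illustration of trajectory alignment via torch.gather, with hypothetical sizes; fwd_mapping here is a random permutation per frame standing in for the real flow-derived mapping:
import torch

frames, channels, hw = 4, 6, 10
feat = torch.randn(frames, channels, hw)
# fwd_mapping[f, 0, d] = pixel index in frame f that belongs to trajectory d
fwd_mapping = torch.stack([torch.randperm(hw) for _ in range(frames)]).unsqueeze(1)  # [frames, 1, hw]
aligned = torch.gather(feat, 2, fwd_mapping.expand(-1, channels, -1))  # [frames, channels, hw]
# after alignment, attending along dim 0 (frames) mixes features on the same trajectory
traj_tokens = aligned.permute(2, 0, 1)  # [hw, frames, channels]
print(traj_tokens.shape)
# bwd_mapping would be the inverse mapping used to scatter the result back to pixel order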
'''for temporal-guided inter-frame attention (FLATTEN)'''
if self.controller and self.controller.use_interattn and (not crossattn):
del query, key, value
torch.cuda.empty_cache()
bwd_mapping, fwd_mapping, flattn_mask = None, None, None
for i, f in enumerate(self.controller.interattn_paras['fwd_mappings']):
if f.shape[2] == hidden_states.shape[2]:
fwd_mapping = f
bwd_mapping = self.controller.interattn_paras['bwd_mappings'][i]
interattn_mask = self.controller.interattn_paras['interattn_masks'][i]
video_length = key_raw.size()[0] // self.unet_chunk_size
# BC * HW * 8D --> C * 8BD * HW
key = rearrange(key_raw, "(b f) d c -> f (b c) d", f=video_length)
query = rearrange(query_raw, "(b f) d c -> f (b c) d", f=video_length)
# BC * 8 * HW * D --> C * 8BD * HW
#key = rearrange(hidden_states, "(b f) h d c -> f (b h c) d", f=video_length) ########
#query = rearrange(hidden_states, "(b f) h d c -> f (b h c) d", f=video_length) #######
value = rearrange(hidden_states, "(b f) h d c -> f (b h c) d", f=video_length)
key = torch.gather(key, 2, fwd_mapping.expand(-1,key.shape[1],-1))
query = torch.gather(query, 2, fwd_mapping.expand(-1,query.shape[1],-1))
value = torch.gather(value, 2, fwd_mapping.expand(-1,value.shape[1],-1))
# C * 8BD * HW --> BHW, C, 8D
key = rearrange(key, "f (b c) d -> (b d) f c", b=self.unet_chunk_size)
query = rearrange(query, "f (b c) d -> (b d) f c", b=self.unet_chunk_size)
value = rearrange(value, "f (b c) d -> (b d) f c", b=self.unet_chunk_size)
# BHW * C * 8D --> BHW * C * 8 * D--> BHW * 8 * C * D
query = query.view(-1, video_length, attn.heads, head_dim).transpose(1, 2).detach()
key = key.view(-1, video_length, attn.heads, head_dim).transpose(1, 2).detach()
value = value.view(-1, video_length, attn.heads, head_dim).transpose(1, 2).detach()
hidden_states_ = F.scaled_dot_product_attention(
query, key * self.controller.interattn_scale_factor, value,
attn_mask = (interattn_mask.repeat(self.unet_chunk_size,1,1,1))#.to(query.dtype)-1.0) * 1e6 -
#torch.eye(interattn_mask.shape[2]).to(query.device).to(query.dtype) * 1e4,
)
# BHW * 8 * C * D --> C * 8BD * HW
hidden_states_ = rearrange(hidden_states_, "(b d) h f c -> f (b h c) d", b=self.unet_chunk_size)
hidden_states_ = torch.gather(hidden_states_, 2, bwd_mapping.expand(-1,hidden_states_.shape[1],-1)).detach()
# C * 8BD * HW --> BC * 8 * HW * D
hidden_states = rearrange(hidden_states_, "f (b h c) d -> (b f) h d c", b=self.unet_chunk_size, h=attn.heads)
#print('inter: ', GPU.getGPUs()[1].memoryUsed)
# BC * 8 * HW * D --> BC * HW * 8D
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
2. get_flow_and_interframe_paras
Uses the optical-flow model GMFlow to compute the parameters for inter-frame temporal-guided attention, namely optical_flow and occlusion_mask, and computes the pixel-index correspondences (pixels on the same optical-flow trajectory) needed by temporal-guided attention (FLATTEN).
- Data preprocessing
Convert the input numpy images to PyTorch tensors and reorder the dimensions from HWC to CHW; stack all images into one batch and move it to the GPU.
Build another tensor, imgs_torch, by converting each image in imgs with numpy2tensor and concatenating the results; this one is mainly used for later visualization.
def get_flow_and_interframe_paras(flow_model, imgs, visualize_pipeline=False):
"""
Get parameters for temporal-guided attention and optimization
* predict optical_flow and occlusion_mask
* compute pixel index correspondence for FLATTEN
"""
images = torch.stack([torch.from_numpy(img).permute(2, 0, 1).float() for img in imgs], dim=0).cuda()
imgs_torch = torch.cat([numpy2tensor(img) for img in imgs], dim=0) # torch.Size([batch=8, c=3, h=512, w=640])
- Optical-flow computation and occlusion-mask generation
Build a reshuffle_list that contains all frame indices except the first, in order, followed by the first index; it is used to compute flow between the original images and the reshuffled images.
Run flow_model on the original and reshuffled images to obtain results_dict, take the final flow prediction flow_pr, and split it into forward flow fwd_flows and backward flow bwd_flows.
Call forward_backward_consistency_check on the two flows to obtain the forward occlusion mask fwd_occs and the backward occlusion mask bwd_occs.
reshuffle_list = list(range(1,len(images)))+[0]
# compute flow between images and reshuffle_images
results_dict = flow_model(images, images[reshuffle_list], attn_splits_list=[2],
corr_radius_list=[-1], prop_radius_list=[-1], pred_bidir_flow=True)
flow_pr = results_dict['flow_preds'][-1] # [2*B, 2, H, W] torch.Size([16, 2, 512, 640])
fwd_flows, bwd_flows = flow_pr.chunk(2) # [B, 2, H, W] torch.Size([8, 2, 512, 640]), torch.Size([8, 2, 512, 640])
fwd_occs, bwd_occs = forward_backward_consistency_check(fwd_flows, bwd_flows) # [B, H, W] torch.Size([8, 512, 640]), torch.Size([8, 512, 640])
- Occlusion-mask refinement
Warp the original images with the backward flow to obtain warped_image1.
Update the backward occlusion mask bwd_occs: positions where the per-pixel difference between the reshuffled images and warped_image1 exceeds the threshold 255*0.25 are additionally marked as occluded.
Likewise, warp the reshuffled images with the forward flow to obtain warped_image2 and update the forward occlusion mask fwd_occs by comparing it with the original images.
warped_image1 = flow_warp(images, bwd_flows) # torch.Size([8, 3, 512, 640])
bwd_occs = torch.clamp(bwd_occs + (abs(images[reshuffle_list]-warped_image1).mean(dim=1)>255*0.25).float(), 0 ,1) # torch.Size([8, 512, 640])
warped_image2 = flow_warp(images[reshuffle_list], fwd_flows) # torch.Size([8, 3, 512, 640])
fwd_occs = torch.clamp(fwd_occs + (abs(images-warped_image2).mean(dim=1)>255*0.25).float(), 0 ,1) # torch.Size([8, 512, 640])
- Visualization (optional)
If visualize_pipeline is True, print a message and visualize the original images masked by the flow-based occlusion masks: build a grid with torchvision.utils.make_grid and display it with visualize.
if visualize_pipeline:
print('visualized occlusion masks based on optical flows')
viz = torchvision.utils.make_grid(imgs_torch * (1-fwd_occs.unsqueeze(1)), len(images), 1)
visualize(viz.cpu(), 90)
viz = torchvision.utils.make_grid(imgs_torch[reshuffle_list] * (1-bwd_occs.unsqueeze(1)), len(images), 1)
visualize(viz.cpu(), 90)
- Mask computation for cross-frame attention
For each scale (8.0, 16.0, 32.0), corresponding to up_blocks.1, up_blocks.2 and up_blocks.3, downsample the backward occlusion mask and append an attention mask to the attn_mask list. At each scale the mask keeps all pixels of the first frame and, for the later frames, only the pixels in occluded regions; these are the positions whose keys/values every frame attends to in cross-frame attention.
attn_mask = []
for scale in [8.0, 16.0, 32.0]:
bwd_occs_ = F.interpolate(bwd_occs[:-1].unsqueeze(1), scale_factor=1./scale, mode='bilinear') # torch.Size([7, 1, 64, 80]), torch.Size([7, 1, 32, 40]), torch.Size([7, 1, 16, 20])
attn_mask += [torch.cat((bwd_occs_[0:1].reshape(1,-1)>-1, bwd_occs_.reshape(bwd_occs_.shape[0],-1)>0.5), dim=0)] # torch.Size([8, 5120]), torch.Size([8, 1280]), torch.Size([8, 320])
- Pixel mappings along flow trajectories and masks for temporal-guided attention
For the selected scales (8.0, 16.0), i.e. up_blocks.2 and up_blocks.3, call get_mapping_ind to compute the forward (fwd_mappings) and backward (bwd_mappings) pixel mappings as well as the inter-frame attention masks (interattn_masks).
fwd_mappings = []
bwd_mappings = []
interattn_masks = []
for scale in [8.0, 16.0]:
fwd_mapping, bwd_mapping, interattn_mask = get_mapping_ind(bwd_flows, bwd_occs, imgs_torch, scale=scale)
fwd_mappings += [fwd_mapping] # torch.Size([8, 1, 5120]), torch.Size([8, 1, 1280])
bwd_mappings += [bwd_mapping] # torch.Size([8, 1, 5120]), torch.Size([8, 1, 1280])
interattn_masks += [interattn_mask] # torch.Size([5120, 1, 8, 8]), torch.Size([1280, 1, 8, 8])
interattn_paras = {
}
interattn_paras['fwd_mappings'] = fwd_mappings
interattn_paras['bwd_mappings'] = bwd_mappings
interattn_paras['interattn_masks'] = interattn_masks
gc.collect()
torch.cuda.empty_cache()
return [fwd_flows, bwd_flows], [fwd_occs, bwd_occs], attn_mask, interattn_paras
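The four returned structures map directly onto the controller switches listed in the flexible-settings block earlier; a sketch of the wiring:
flows, occs, attn_mask, interattn_paras = get_flow_and_interframe_paras(flow_model, imgs)
frescoProc.controller.enable_cfattn(attn_mask)           # cross-frame attention masks
frescoProc.controller.enable_interattn(interattn_paras)  # FLATTEN mappings and masks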
3. get_intraframe_paras
Here we only look at its effect on spatial-guided attention: get_intraframe_paras first clears and disables the controller, then enables storing so that a single UNet pass over the input frames lets FRESCOAttnProcessor2_0 save the decoder self-attention features into the controller, and finally disables storing (a rough sketch follows the snippet below):
frescoProc.controller.disable_controller()
frescoProc.controller.clear_store()
frescoProc.controller.enable_store()
frescoProc.controller.disable_store()
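A rough sketch of the part of get_intraframe_paras relevant to spatial-guided attention; the UNet call in the middle is only indicated and the correlation_matrix computation for optimization is omitted:
def get_intraframe_paras_sketch(pipe, imgs, frescoProc, prompt_embeds, seed=0):
    frescoProc.controller.disable_controller()
    frescoProc.controller.clear_store()
    frescoProc.controller.enable_store()   # record decoder self-attn features of the input frames
    # ... one UNet forward pass on the noised latents of imgs ...
    frescoProc.controller.disable_store()  # features now sit in stored_attn['decoder_attn']
    return []                              # correlation_matrix computation omitted in this sketch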
Back in the FRESCO inference pipeline, the three kinds of attention are then switched on:
# Turn on all FRESCO support
frescoProc.controller.enable_controller(interattn_paras=interattn_paras, attn_mask=attn_mask)
4. inference
One call samples a batch of batch_size=8 key frames:
latents = inference(pipe, controlnet, frescoProc,
imgs_torch, prompt_embeds, edges, timesteps,
cond_scale, config['num_inference_steps'], config['num_warmup_steps'],
do_classifier_free_guidance, config['seed'], guidance_scale, config['use_controlnet'],
record_latents, propagation_mode,
flows = flows, occs = occs, saliency=saliency, repeat_noise=True)
- Run garbage collection, empty the CUDA cache, and set up the device and the random generator;
- Initialize / prepare the latents from the input arguments;
- Shared initial noise: if repeat_noise is set, the first element of latents is repeated B times along the batch dimension, so all frames start from the same noise;
if repeat_noise: # True
latents = latents[0:1].repeat(B,1,1,1).detach() # repeat the first element of latents B times along the batch dimension
- SDEdit: if num_warmup_steps < 0, keep the pure-noise latents as the starting point and reset num_warmup_steps to 0; otherwise encode the input frames and add noise at timesteps[num_warmup_steps], so that the noisy latents of the images, rather than pure Gaussian noise, are used as input;
if num_warmup_steps < 0: # num_warmup_steps: SDEdit add noise step = (num_inference_steps-num_warmup_steps)
latents_init = latents.detach()
num_warmup_steps = 0
else:
# SDEdit: use the noisy latents of the images as the input rather than pure Gaussian noise
latent_x0 = pipe.vae.config.scaling_factor * pipe.vae.encode(imgs.to(pipe.unet.dtype)).latent_dist.sample()
latents_init = noise_scheduler.add_noise(latent_x0, latents, timesteps[num_warmup_steps]).detach()
- With pipe's progress bar, run num_inference_steps-num_warmup_steps denoising steps;
- In each step, conditionally apply spatial/temporal-guided attention, record/restore latents across batches, and expand the latents for classifier-free guidance;
# timesteps=[951, 901, 851, 801, 751, 701, 651, 601, 551, 501, 451, 401, 351, 301, 251, 201, 151, 101, 51, 1]
if i >= num_intraattn_steps: # spatial-guided attention is only applied during the first num_intraattn_steps=1 steps
frescoProc.controller.disable_intraattn()
if t < step_interattn_end: # temporal-guided attention is only applied while t >= step_interattn_end=350, i.e. roughly the first 13 of 20 steps
frescoProc.controller.disable_interattn()
Because record_latents is a list passed by reference, modifications inside inference affect the caller's list directly. For the first batch of 8 frames, propagation_mode=False: at every timestep the latents of the first and last frame of the current batch are appended to record_latents. For later batches, propagation_mode=True: at every timestep the entry recorded by the previous batch is read back to replace the first 2 frames of the current batch, and the entry is then overwritten with the current batch's first and last frame (a toy illustration follows the snippet below):
if propagation_mode: # restore latent from previous batch and record latent of the current batch
latents[0:2] = record_latents[i].detach().clone()
record_latents[i] = latents[[0,len(latents)-1]].detach().clone()
else: # first batch: record_latents[i] = [x_{1,t}, x_{N,t}] = [first_frame, last_frame]
record_latents += [latents[[0,len(latents)-1]].detach().clone()] # add torch.Size([2, 4, 64, 80])
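Toy illustration (hypothetical shapes, 20 denoising steps) of how record_latents links consecutive batches; note that frame 0 of every later batch is always the latent of the very first key frame, which therefore acts as a global anchor:
import torch
record_latents = []
# first batch of 8 frames, propagation_mode=False: append one entry per denoising step
for i in range(20):
    latents = torch.randn(8, 4, 64, 80)
    record_latents += [latents[[0, 7]].detach().clone()]  # [first frame, last frame]
# a later batch of batch_size-2 = 6 frames, propagation_mode=True
for i in range(20):
    latents = torch.randn(6, 4, 64, 80)
    latents[0:2] = record_latents[i]                      # frame 0: global anchor, frame 1: last frame of the previous batch
    record_latents[i] = latents[[0, 5]].detach().clone()  # re-record for the next batch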
- If ControlNet is used, its outputs are added as residuals to the UNet inputs;
- Predict the noise residual;
- If classifier-free guidance is enabled, combine the unconditional and text-conditioned noise predictions;
- Depending on whether the current step is a background-smoothing step, call step with or without flows/occs/saliency to compute the previous noisy sample;
- Return the latents.
@torch.no_grad()
def inference(pipe, controlnet, frescoProc,
imgs, prompt_embeds, edges, timesteps,
cond_scale=[0.7]*20, num_inference_steps=20, num_warmup_steps=6,
do_classifier_free_guidance=True, seed=0, guidance_scale=7.5, use_controlnet=True,
record_latents=[], propagation_mode=False, visualize_pipeline=False,
flows = None, occs = None, saliency=None, repeat_noise=False,
num_intraattn_steps = 1, step_interattn_end = 350, bg_smoothing_steps = [16,17]):
"""
video-to-video translation inference pipeline with FRESCO
* add controlnet and SDEdit
* add FRESCO-guided attention
* add FRESCO-guided optimization
* add background smoothing
* add support for inter-batch long video translation
[input of the original pipe]
pipe: base diffusion model
imgs: a batch of the input frames
prompt_embeds: prompts
num_inference_steps: number of DDPM steps
timesteps: generated by pipe.scheduler.set_timesteps(num_inference_steps)
do_classifier_free_guidance: cfg, should be always true
guidance_scale: cfg scale
seed
[input of SDEdit]
num_warmup_steps: skip the first num_warmup_steps DDPM steps
[input of controlnet]
use_controlnet: bool, whether using controlnet
controlnet: controlnet model
edges: input for controlnet (edge/stroke/depth, etc.)
cond_scale: controlnet scale
[input of FRESCO]
frescoProc: FRESCO attention controller
flows: optical flows
occs: occlusion mask
num_intraattn_steps: apply num_intraattn_steps steps of spatial-guided attention
step_interattn_end: apply temporal-guided attention in [step_interattn_end, 1000] steps
[input for background smoothing]
saliency: saliency mask
repeat_noise: bool, use the same noise for all frames
bg_smoothing_steps: apply background smoothing in bg_smoothing_steps
[input for long video translation]
record_latents: recorded latents in the last batch
propagation_mode: bool, whether this is not the first batch
[output]
latents: a batch of latents of the translated frames
"""
gc.collect()
torch.cuda.empty_cache()
device = pipe._execution_device
noise_scheduler = pipe.scheduler
generator = torch.Generator(device=device).manual_seed(seed)
B, C, H, W = imgs.shape # torch.Size([8, 3, 512, 640])
latents = pipe.prepare_latents( # torch.Size([8, 4, 64, 80])
B,
pipe.unet.config.in_channels,
H,
W,
prompt_embeds.dtype,
device,
generator,
latents = None,
)
if repeat_noise: # True
latents = latents[0:1].repeat(B,1,1,1).detach() # repeat the first element of latents B times along the batch dimension
if num_warmup_steps < 0: # num_warmup_steps: SDEdit add noise step = (num_inference_steps-num_warmup_steps)
latents_init = latents.detach()
num_warmup_steps = 0
else:
# SDEdit: use the noisy latents of the images as the input rather than pure Gaussian noise
latent_x0 = pipe.vae.config.scaling_factor * pipe.vae.encode(imgs.to(pipe.unet.dtype)).latent_dist.sample()
latents_init = noise_scheduler.add_noise(latent_x0, latents, timesteps[num_warmup_steps]).detach()
# SDEdit, run num_inference_steps-num_warmup_steps steps
with pipe.progress_bar(total=num_inference_steps-num_warmup_steps) as progress_bar:
latents = latents_init # torch.Size([8, 4, 64, 80])
for i, t in enumerate(timesteps[num_warmup_steps:]):
"""
[HACK] control the steps to apply spatial/temporal-guided attention
[HACK] record and restore latents from previous batch
"""
# timesteps=[951, 901, 851, 801, 751, 701, 651, 601, 551, 501, 451, 401, 351, 301, 251, 201, 151, 101, 51, 1]
if i >= num_intraattn_steps: # spatial-guided attention is only applied during the first num_intraattn_steps=1 steps
frescoProc.controller.disable_intraattn()
if t < step_interattn_end: # temporal-guided attention is only applied while t >= step_interattn_end=350, i.e. roughly the first 13 of 20 steps
frescoProc.controller.disable_interattn()
if propagation_mode: # restore latent from previous batch and record latent of the current batch
latents[0:2] = record_latents[i].detach().clone()
record_latents[i] = latents[[0,len(latents)-1]].detach().clone()
else: # first batch: record_latents[i] = [x_{1,t}, x_{N,t}] = [first_frame, last_frame]
record_latents += [latents[[0,len(latents)-1]].detach().clone()] # add torch.Size([2, 4, 64, 80])
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
if use_controlnet:
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds
down_block_res_samples, mid_block_res_sample = controlnet(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=edges,
conditioning_scale=cond_scale[i+num_warmup_steps],
guess_mode=False,
return_dict=False,
)
else:
down_block_res_samples, mid_block_res_sample = None, None
# predict the noise residual
noise_pred = pipe.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=None,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
"""
[HACK] background smoothing
Note: bg_smoothing_steps should be rescaled based on num_inference_steps
current [16,17] is based on num_inference_steps=20
"""
if i + num_warmup_steps in bg_smoothing_steps:
latents = step(pipe, noise_pred, t, latents, generator,
visualize_pipeline=visualize_pipeline,
flows = flows, occs = occs, saliency=saliency)[0]
else:
latents = step(pipe, noise_pred, t, latents, generator,
visualize_pipeline=visualize_pipeline)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > 0 and (i + 1) % pipe.scheduler.order == 0):
progress_bar.update()
return latents