前言:IP-Adapter作为Diffusion Models最成功的技术之一,已经在诸多互联网应用中落地。介绍IP-Adapter原理和应用的博客有很多,但是逐行详细解读代码的博客很少。这篇博客从细节出发,结合原理详细解读个性化定制大杀器IP-Adapter代码。
目录
原理概述
一句话概括:原有的Cross Attention计算是用text condition计算的;IP-Adapter在原有Cross Attention计算上加上了image condition。其中Q共用,K V重新计算。
在整个训练过程中,冻结了最初的 UNet 模型,所以在上述解耦交叉注意力中,只有W′k,W′v 两个参数是可训练的。而且W′k,W′v 的权重从原来对应的Cross Attention初始化。
代码详解
冻结模型
这张图中蓝色部分全部冻结:
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
红色部分可训练:
params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
Image Projection
对应于下图中的这个部分,由一个Linear层和一个Norm组成:
image_proj_model = ImageProjModel(
cross_attention_dim=unet.config.cross_attention_dim,
clip_embeddings_dim=image_encoder.config.projection_dim,
clip_extra_context_tokens=4,
)
class ImageProjModel(torch.nn.Module):
"""Projection Model"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.generator = None
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
经过这个模型得到Image Feature,它的维度需要与计算Attention的维度对齐。
注意力替换
在原始的Unet模型中,需要把Cross Attention的计算替换成新的,但是self- attention不变。
其中attn1.processor 是 self- attention,这部分不变;attn2.processor代表交叉注意力层,这部分替换!
W′k,W′v 的权重从原来对应的Cross Attention初始化!也就是代码中的:
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
layer_name = name.split(".processor")[0]
weights = {
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
}
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
attn_procs[name].load_state_dict(weights)
这部分的完整代码如下:
attn_procs = {}
unet_sd = unet.state_dict()
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
layer_name = name.split(".processor")[0]
weights = {
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
}
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
attn_procs[name].load_state_dict(weights)
unet.set_attn_processor(attn_procs)
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
新注意力计算
现在已经在Unet中把原来的Cross注意力计算都替换并且组装好了。
下一步我们看看被替换的注意力是如何计算的,也就是看看IPAttnProcessor 如何实现。
class IPAttnProcessor(nn.Module):
r"""
Attention processor for IP-Adapater.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
The context length of the image features.
"""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.num_tokens = num_tokens
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
*args,
**kwargs,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
self.attn_map = ip_attention_probs
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
核心是计算出ip_hidden_states,然后与原来的hidden_state相加:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
self.attn_map = ip_attention_probs
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
相加的时候有一个scale,这个scale是一个固定的参数1:
hidden_states = hidden_states + self.scale * ip_hidden_states
打包IP-Adapter
因为是用accelerate训练的,所以把可训练的部分都打包成一个单独的类 ip_adapter,用accelerator包装一下,就可以很方便更新梯度和保存权重。
ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
推理
累了,下次再写...
def generate(
self,
pil_image=None,
clip_image_embeds=None,
prompt=None,
negative_prompt=None,
scale=1.0,
num_samples=4,
seed=None,
guidance_scale=7.5,
num_inference_steps=30,
**kwargs,
):
self.set_scale(scale)
if pil_image is not None:
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
else:
num_prompts = clip_image_embeds.size(0)
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
pil_image=pil_image, clip_image_embeds=clip_image_embeds
)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode():
prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
prompt,
device=self.device,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = get_generator(seed, self.device)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
**kwargs,
).images
return images
源码地址
IP-Adapter/tutorial_train.py at main · tencent-ailab/IP-Adapter · GitHub