DreamBooth fine-tuning and LoRA fine-tuning in diffusers

train_dreambooth.py

Code:

accelerator = Accelerator()

# Generate class images if prior preservation is enabled
if args.with_prior_preservation:
    # cur_class_images = number of images already present in args.class_data_dir
    if cur_class_images < args.num_class_images:
        pipeline = DiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, safety_checker=None)
        num_new_images = args.num_class_images - cur_class_images
        sample_dataset = PromptDataset(args.class_prompt, num_new_images)
        sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
        for example in sample_dataloader:
            images = pipeline(example["prompt"]).images
            # each generated image is saved into args.class_data_dir
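
For reference, PromptDataset in the script is only a thin wrapper that repeats the class prompt; a rough sketch of its shape (simplified from the actual file):

class PromptDataset(torch.utils.data.Dataset):
    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # the index is used when naming the generated class images
        return {"prompt": self.prompt, "index": index}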

tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

vae.requires_grad_(False)
text_encoder.requires_grad_(False)

if args.gradient_checkpointing:
    unet.enable_gradient_checkpointing()
    if args.train_text_encoder:
        text_encoder.gradient_checkpointing_enable()

optimizer_class = torch.optim.AdamW
params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters())
optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2),
                            weight_decay=args.adam_weight_decay, eps=args.adam_epsilon)
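
The optimizer class is not hard-coded: when --use_8bit_adam is passed, the script switches to bitsandbytes (sketch, assuming bitsandbytes is installed):

if args.use_8bit_adam:
    import bitsandbytes as bnb  # optional dependency
    optimizer_class = bnb.optim.AdamW8bit
else:
    optimizer_class = torch.optim.AdamW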

train_dataset = DreamBoothDataset(instance_data_root, instance_prompt, class_data_root, class_prompt, class_num, tokenizer, size, center_crop, encoder_hidden_states, instance_prompt_encoder_hidden_states, tokenizer_max_length)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True,
                                               collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation))
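
The collate_fn is what makes the prior-preservation chunking in the training loop work: when --with_prior_preservation is set, class examples are appended after the instance examples, so every batch is laid out as [instance..., class...] along dim 0. A simplified sketch of the script's collate_fn:

def collate_fn(examples, with_prior_preservation=False):
    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]
    if with_prior_preservation:
        # class examples go after the instance examples in the same batch
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]
    pixel_values = torch.stack(pixel_values).float()
    input_ids = torch.cat(input_ids, dim=0)
    return {"input_ids": input_ids, "pixel_values": pixel_values}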

num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps, num_cycles=args.lr_num_cycles, power=args.lr_power)
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
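
Right after prepare() the script reconciles max_train_steps with num_train_epochs, since the dataloader length can change under distributed training; roughly:

num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)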

for epoch in range(first_epoch, args.num_train_epochs):
    unet.train()
    if args.train_text_encoder:
        text_encoder.train()
    for step,batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

            # Encode the images into VAE latents and scale them
            model_input = vae.encode(pixel_values).latent_dist.sample()
            model_input = model_input * vae.config.scaling_factor

            # Sample noise and a random timestep per image, then add noise (forward diffusion)
            noise = torch.randn_like(model_input)
            bsz = model_input.shape[0]
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device).long()
            noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

            # Text conditioning
            encoder_hidden_states = encode_prompt(text_encoder, batch["input_ids"], batch["attention_mask"],
                                                  text_encoder_use_attention_mask=args.text_encoder_use_attention_mask)

            # Predict the noise residual (class_labels is None unless timestep conditioning is enabled)
            model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels).sample
            
            # For "epsilon" prediction the target is the added noise (v-prediction would use get_velocity)
            target = noise

            if args.with_prior_preservation:
                # The batch stacks [instance, class] examples along dim 0 (see collate_fn), so split both tensors
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
                loss = loss + args.prior_loss_weight * prior_loss
            else:
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
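
After the loop, the script assembles a full pipeline around the fine-tuned weights and writes it to disk; a simplified sketch (the real code also guards this with accelerator.is_main_process):

pipeline = DiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    unet=accelerator.unwrap_model(unet),
    text_encoder=accelerator.unwrap_model(text_encoder),
)
pipeline.save_pretrained(args.output_dir)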

train_dreambooth_lora.py

accelerator = Accelerator()

tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)

# now we will add new LoRA weights to the attention layers
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.

# Let's first see how many attention processors we will have to set.
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
# - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
# => 32 layers

unet_lora_attn_procs = {}
unet_lora_parameters = []
for name, attn_processor in unet.attn_processors.items():
    # hidden_size is derived from unet.config.block_out_channels for the block this processor belongs to;
    # cross_attention_dim is unet.config.cross_attention_dim for cross-attention (attn2) and None for self-attention (attn1)
    if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
        lora_attn_processor_class = LoRAAttnAddedKVProcessor
    else:
        lora_attn_processor_class = (LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor)

    module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank)
    unet_lora_attn_procs[name] = module
    unet_lora_parameters.extend(module.parameters())
unet.set_attn_processor(unet_lora_attn_procs)
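
A quick standalone way to sanity-check the layer count described in the comments above (assuming the Stable Diffusion v1.5 UNet; other checkpoints may differ):

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
print(len(unet.attn_processors))        # 32 attention processors for SD 1.5
print(unet.config.cross_attention_dim)  # 768, the "cross attention size" above
print(unet.config.block_out_channels)   # [320, 640, 1280, 1280], source of the per-block hidden_size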

if args.train_text_encoder:
    # add LoRA layers to the text encoder's attention modules as well
    text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder)
optimizer_class = torch.optim.AdamW
# Only the LoRA parameters are trained; the base UNet/text encoder weights stay frozen
params_to_optimize = (itertools.chain(unet_lora_parameters, text_lora_parameters) if args.train_text_encoder else unet_lora_parameters)
optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2),
                            weight_decay=args.adam_weight_decay, eps=args.adam_epsilon)

train_dataset = DreamBoothDataset(instance_data_root, instance_prompt, class_data_root, class_prompt, class_num, tokenizer, size, center_crop, encoder_hidden_states, instance_prompt_encoder_hidden_states, tokenizer_max_length)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True,
                                               collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation))

num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps, num_cycles=args.lr_num_cycles, power=args.lr_power)
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader, lr_scheduler)

for epoch in range(first_epoch, args.num_train_epochs):
    unet.train()
    if args.train_text_encoder:
        text_encoder.train()
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

            # Encode the images into VAE latents and scale them
            model_input = vae.encode(pixel_values).latent_dist.sample()
            model_input = model_input*vae.config.scaling_factor
            
            # Sample noise and a random timestep per image, then add noise (forward diffusion)
            noise = torch.randn_like(model_input)
            bsz = model_input.shape[0]
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device).long()
            noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

            # Text conditioning
            encoder_hidden_states = encode_prompt(text_encoder, batch["input_ids"], batch["attention_mask"],
                                                  text_encoder_use_attention_mask=args.text_encoder_use_attention_mask)

            # Predict the noise residual and compute the plain MSE loss against the noise
            model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels).sample

            target = noise
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
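
After training, the LoRA script saves only the adapter weights instead of a full pipeline, and at inference they are loaded on top of the base model. Approximately (the helper used to collect the LoRA state dicts varies between diffusers versions, and the prompt is just an example):

LoraLoaderMixin.save_lora_weights(
    save_directory=args.output_dir,
    unet_lora_layers=unet_lora_layers,                   # LoRA state dict collected from the UNet
    text_encoder_lora_layers=text_encoder_lora_layers,   # None if the text encoder was not trained
)

# Inference: load the base pipeline, then apply the trained LoRA weights on top
pipeline = DiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path)
pipeline.load_lora_weights(args.output_dir)
image = pipeline("a photo of sks dog", num_inference_steps=25).images[0]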


Reposted from blog.csdn.net/u012193416/article/details/132878543