# ref:
# - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
# - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py

from typing import List, Optional
import argparse
import ast
from pathlib import Path
import gc

import torch
from tqdm import tqdm
import wandb

from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
import train_util
import model_util
import prompt_util
from prompt_util import PromptEmbedsCache, PromptEmbedsPair, PromptSettings
import debug_util
import config_util
from config_util import RootConfig


def flush():
    """Run the garbage collector, then release cached CUDA memory."""
    # Collect first so that tensors kept alive only by reference cycles are
    # freed before the CUDA caching allocator is asked to give memory back.
    gc.collect()
    torch.cuda.empty_cache()


def _parse_optimizer_kwargs(optimizer_args):
    """Parse a whitespace-separated ``key=value`` string into a kwargs dict.

    Each value is evaluated with ``ast.literal_eval``, so it must be a valid
    Python literal (e.g. ``weight_decay=0.01``). ``None`` or an empty string
    yields an empty dict.

    Raises:
        ValueError: if a token has no ``=`` or its value is not a literal.
    """
    if not optimizer_args:
        return {}

    kwargs = {}
    # split() (no separator) tolerates repeated spaces; maxsplit=1 keeps any
    # additional "=" characters inside the value.
    for arg in optimizer_args.split():
        key, value = arg.split("=", 1)
        kwargs[key] = ast.literal_eval(value)
    return kwargs


def train(
    config: RootConfig,
    prompts: list[PromptSettings],
    device: torch.device,
):
    """Train a LoRA network on top of a frozen UNet and save its weights.

    Args:
        config: Root configuration (model, network, training, logging and
            save settings).
        prompts: Prompt settings; one ``PromptEmbedsPair`` is built per entry
            and one pair is sampled at random on every iteration.
        device: Device that the models and latents are moved to.

    Intermediate checkpoints are written every ``config.save.per_steps``
    iterations (excluding the first and last iteration) and a final
    ``*_last.pt`` file is written after the loop, all under
    ``config.save.path``.
    """
    metadata = {
        "prompts": ",".join([prompt.json() for prompt in prompts]),
        "config": config.json(),
    }
    save_path = Path(config.save.path)

    modules = DEFAULT_TARGET_REPLACE
    if config.network.type == "c3lier":
        # NOTE(review): if DEFAULT_TARGET_REPLACE is a list, `+=` extends the
        # imported object in place. `modules` is not read again in this
        # function, so that in-place extension may be the intended effect --
        # confirm against the `lora` module before changing this.
        modules += UNET_TARGET_REPLACE_MODULE_CONV

    if config.logging.verbose:
        print(metadata)

    if config.logging.use_wandb:
        wandb.init(project=f"LECO_{config.save.name}", config=metadata)

    weight_dtype = config_util.parse_precision(config.train.precision)
    # NOTE(review): the save dtype is parsed from the *train* precision too.
    save_weight_dtype = config_util.parse_precision(config.train.precision)

    tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models(
        config.pretrained_model.name_or_path,
        scheduler_name=config.train.noise_scheduler,
        v2=config.pretrained_model.v2,
        v_pred=config.pretrained_model.v_pred,
    )

    text_encoder.to(device, dtype=weight_dtype)
    text_encoder.eval()

    # The UNet itself is frozen; only the LoRA network is optimized.
    unet.to(device, dtype=weight_dtype)
    unet.enable_xformers_memory_efficient_attention()
    unet.requires_grad_(False)
    unet.eval()

    network = LoRANetwork(
        unet,
        rank=config.network.rank,
        multiplier=1.0,
        alpha=config.network.alpha,
        train_method=config.network.training_method,
    ).to(device, dtype=weight_dtype)

    optimizer_module = train_util.get_optimizer(config.train.optimizer)
    # optimizer_args
    optimizer_kwargs = _parse_optimizer_kwargs(config.train.optimizer_args)

    optimizer = optimizer_module(
        network.prepare_optimizer_params(),
        lr=config.train.lr,
        **optimizer_kwargs,
    )
    lr_scheduler = train_util.get_lr_scheduler(
        config.train.lr_scheduler,
        optimizer,
        max_iterations=config.train.iterations,
        lr_min=config.train.lr / 100,
    )
    criteria = torch.nn.MSELoss()

    print("Prompts")
    for settings in prompts:
        print(settings)

    # debug
    debug_util.check_requires_grad(network)
    debug_util.check_training_mode(network)

    cache = PromptEmbedsCache()
    prompt_pairs: list[PromptEmbedsPair] = []

    # Encode every prompt once up front so the tokenizer and text encoder can
    # be dropped before the training loop starts.
    with torch.no_grad():
        for settings in prompts:
            print(settings)
            for prompt in [
                settings.target,
                settings.positive,
                settings.neutral,
                settings.unconditional,
            ]:
                print(prompt)
                if isinstance(prompt, list):
                    # List-valued prompts are cached under a fixed key rather
                    # than under the prompt text itself.
                    if prompt == settings.positive:
                        key_setting = 'positive'
                    else:
                        key_setting = 'attributes'
                    if len(prompt) == 0:
                        cache[key_setting] = []
                    elif cache[key_setting] is None:
                        cache[key_setting] = train_util.encode_prompts(
                            tokenizer, text_encoder, prompt
                        )
                elif cache[prompt] is None:
                    cache[prompt] = train_util.encode_prompts(
                        tokenizer, text_encoder, [prompt]
                    )

            prompt_pairs.append(
                PromptEmbedsPair(
                    criteria,
                    cache[settings.target],
                    cache[settings.positive],
                    cache[settings.unconditional],
                    cache[settings.neutral],
                    settings,
                )
            )

    del tokenizer
    del text_encoder

    flush()

    # Defined up front so the cleanup `del` after the loop cannot fail when
    # config.train.iterations is 0 and the loop body never runs.
    loss = None

    pbar = tqdm(range(config.train.iterations))

    for i in pbar:
        with torch.no_grad():
            noise_scheduler.set_timesteps(
                config.train.max_denoising_steps, device=device
            )

            optimizer.zero_grad()

            prompt_pair: PromptEmbedsPair = prompt_pairs[
                torch.randint(0, len(prompt_pairs), (1,)).item()
            ]

            # Random step count in [1, max_denoising_steps - 1]
            # (originally noted as "random from 1 to 49").
            timesteps_to = torch.randint(
                1, config.train.max_denoising_steps, (1,)
            ).item()

            height, width = (
                prompt_pair.resolution,
                prompt_pair.resolution,
            )
            if prompt_pair.dynamic_resolution:
                height, width = train_util.get_random_resolution_in_bucket(
                    prompt_pair.resolution
                )

            if config.logging.verbose:
                print("guidance_scale:", prompt_pair.guidance_scale)
                print("resolution:", prompt_pair.resolution)
                print("dynamic_resolution:", prompt_pair.dynamic_resolution)
                if prompt_pair.dynamic_resolution:
                    print("bucketed resolution:", (height, width))
                print("batch_size:", prompt_pair.batch_size)

            latents = train_util.get_initial_latents(
                noise_scheduler, prompt_pair.batch_size, height, width, 1
            ).to(device, dtype=weight_dtype)

            with network:
                # Returns latents that have been denoised a little.
                denoised_latents = train_util.diffusion(
                    unet,
                    noise_scheduler,
                    latents,  # pass plain noise latents
                    train_util.concat_embeddings(
                        prompt_pair.unconditional,
                        prompt_pair.target,
                        prompt_pair.batch_size,
                    ),
                    start_timesteps=0,
                    total_timesteps=timesteps_to,
                    guidance_scale=3,
                )

            noise_scheduler.set_timesteps(1000)

            # Map the step count onto the 1000-step schedule.
            current_timestep = noise_scheduler.timesteps[
                int(timesteps_to * 1000 / config.train.max_denoising_steps)
            ]

            # Outside `with network:` only the empty LoRA is active.
            positive_latents = train_util.predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents,
                train_util.concat_embeddings(
                    prompt_pair.unconditional,
                    prompt_pair.positive,
                    prompt_pair.batch_size,
                ),
                guidance_scale=1,
            ).to(device, dtype=weight_dtype)
            neutral_latents = train_util.predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents,
                train_util.concat_embeddings(
                    prompt_pair.unconditional,
                    prompt_pair.neutral,
                    prompt_pair.batch_size,
                ),
                guidance_scale=1,
            ).to(device, dtype=weight_dtype)
            unconditional_latents = train_util.predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents,
                train_util.concat_embeddings(
                    prompt_pair.unconditional,
                    prompt_pair.unconditional,
                    prompt_pair.batch_size,
                ),
                guidance_scale=1,
            ).to(device, dtype=weight_dtype)

            #########################
            if config.logging.verbose:
                print("positive_latents:", positive_latents[0, 0, :5, :5])
                print("neutral_latents:", neutral_latents[0, 0, :5, :5])
                print("unconditional_latents:", unconditional_latents[0, 0, :5, :5])

        # Gradients are enabled here: this is the prediction that is
        # back-propagated through the LoRA network.
        with network:
            target_latents = train_util.predict_noise(
                unet,
                noise_scheduler,
                current_timestep,
                denoised_latents,
                train_util.concat_embeddings(
                    prompt_pair.unconditional,
                    prompt_pair.target,
                    prompt_pair.batch_size,
                ),
                guidance_scale=1,
            ).to(device, dtype=weight_dtype)

            #########################
            if config.logging.verbose:
                print("target_latents:", target_latents[0, 0, :5, :5])

        positive_latents.requires_grad = False
        neutral_latents.requires_grad = False
        unconditional_latents.requires_grad = False

        loss = prompt_pair.loss(
            target_latents=target_latents,
            positive_latents=positive_latents,
            neutral_latents=neutral_latents,
            unconditional_latents=unconditional_latents,
        )

        loss_value = loss.item()

        # Without the x1000 scaling the display just shows 0.000..., which
        # is not very informative to look at.
        pbar.set_description(f"Loss*1k: {loss_value*1000:.4f}")
        if config.logging.use_wandb:
            # Log the plain float rather than the live tensor.
            wandb.log(
                {
                    "loss": loss_value,
                    "iteration": i,
                    "lr": lr_scheduler.get_last_lr()[0],
                }
            )

        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        del (
            positive_latents,
            neutral_latents,
            unconditional_latents,
            target_latents,
            latents,
        )
        flush()

        # Periodic checkpoint; the last iteration is skipped because the
        # final weights are saved unconditionally after the loop.
        if (
            i % config.save.per_steps == 0
            and i != 0
            and i != config.train.iterations - 1
        ):
            print("Saving...")
            save_path.mkdir(parents=True, exist_ok=True)
            network.save_weights(
                save_path / f"{config.save.name}_{i}steps.pt",
                dtype=save_weight_dtype,
            )

    print("Saving...")
    save_path.mkdir(parents=True, exist_ok=True)
    network.save_weights(
        save_path / f"{config.save.name}_last.pt",
        dtype=save_weight_dtype,
    )

    del (
        unet,
        noise_scheduler,
        loss,
        optimizer,
        network,
    )

    flush()

    print("Done.")


def main(args):
    """Load the YAML config, apply command-line overrides and run training.

    Args:
        args: Parsed command-line namespace providing ``config_file``,
            ``name``, ``attributes``, ``alpha``, ``rank`` and ``device``.
    """
    config_file = args.config_file

    config = config_util.load_config_from_yaml(config_file)
    if args.name is not None:
        config.save.name = args.name

    attributes = []
    if args.attributes is not None:
        attributes = [a.strip() for a in args.attributes.split(',')]

    config.network.alpha = args.alpha
    config.network.rank = args.rank

    # Encode the main hyper-parameters in the run name and save directory.
    config.save.name += f'_alpha{args.alpha}'
    config.save.name += f'_rank{config.network.rank}'
    config.save.name += f'_{config.network.training_method}'
    config.save.path += f'/{config.save.name}'

    prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)

    device = torch.device(f"cuda:{args.device}")
    train(config=config, prompts=prompts, device=device)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        required=False,
        default='data/config.yaml',
        help="Config file for training.",
    )
    # config_file 'data/config.yaml'
    parser.add_argument(
        "--alpha",
        type=float,
        required=True,
        help="LoRA weight.",
    )
    # --alpha 1.0
    parser.add_argument(
        "--rank",
        type=int,
        required=False,
        help="Rank of LoRA.",
        default=4,
    )
    # --rank 4
    parser.add_argument(
        "--device",
        type=int,
        required=False,
        default=0,
        help="Device to train on.",
    )
    # --device 0
    parser.add_argument(
        "--name",
        type=str,
        required=False,
        default=None,
        help="Name used for the saved run (overrides the config's save name).",
    )
    # --name 'eyesize_slider'
    parser.add_argument(
        "--attributes",
        type=str,
        required=False,
        default=None,
        help="attritbutes to disentangle (comma seperated string)",
    )
    # --attributes 'male, female'

    args = parser.parse_args()

    main(args)