RohitGandikota commited on
Commit
1f8beea
1 Parent(s): cb9665a

testing layout

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +1 -0
  2. __init__.py +1 -0
  3. app.py +261 -4
  4. models/age.pt +3 -0
  5. models/cartoon_style.pt +3 -0
  6. models/chubby.pt +3 -0
  7. models/clay_style.pt +3 -0
  8. models/cluttered_room.pt +3 -0
  9. models/curlyhair.pt +3 -0
  10. models/dark_weather.pt +3 -0
  11. models/eyebrow.pt +3 -0
  12. models/eyesize.pt +3 -0
  13. models/festive.pt +3 -0
  14. models/fix_hands.pt +3 -0
  15. models/long_hair.pt +3 -0
  16. models/muscular.pt +3 -0
  17. models/pixar_style.pt +3 -0
  18. models/professional.pt +3 -0
  19. models/repair_slider.pt +3 -0
  20. models/sculpture_style.pt +3 -0
  21. models/smiling.pt +3 -0
  22. models/stylegan_latent1.pt +3 -0
  23. models/stylegan_latent2.pt +3 -0
  24. models/suprised_look.pt +3 -0
  25. models/tropical_weather.pt +3 -0
  26. models/winter_weather.pt +3 -0
  27. requirements.txt → reqs.txt +0 -0
  28. trainscripts/__init__.py +1 -0
  29. trainscripts/imagesliders/config_util.py +104 -0
  30. trainscripts/imagesliders/data/config-xl.yaml +28 -0
  31. trainscripts/imagesliders/data/config.yaml +28 -0
  32. trainscripts/imagesliders/data/prompts-xl.yaml +275 -0
  33. trainscripts/imagesliders/data/prompts.yaml +174 -0
  34. trainscripts/imagesliders/debug_util.py +16 -0
  35. trainscripts/imagesliders/lora.py +256 -0
  36. trainscripts/imagesliders/model_util.py +283 -0
  37. trainscripts/imagesliders/prompt_util.py +174 -0
  38. trainscripts/imagesliders/train_lora-scale-xl.py +548 -0
  39. trainscripts/imagesliders/train_lora-scale.py +501 -0
  40. trainscripts/imagesliders/train_util.py +458 -0
  41. trainscripts/textsliders/__init__.py +0 -0
  42. trainscripts/textsliders/config_util.py +104 -0
  43. trainscripts/textsliders/data/config-xl.yaml +28 -0
  44. trainscripts/textsliders/data/config.yaml +28 -0
  45. trainscripts/textsliders/data/prompts-xl.yaml +477 -0
  46. trainscripts/textsliders/data/prompts.yaml +193 -0
  47. trainscripts/textsliders/debug_util.py +16 -0
  48. trainscripts/textsliders/flush.py +5 -0
  49. trainscripts/textsliders/generate_images_xl.py +513 -0
  50. trainscripts/textsliders/lora.py +258 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from trainscripts.textsliders import lora
app.py CHANGED
@@ -1,7 +1,264 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import os
4
+ from utils import call
5
+ from diffusers.pipelines import StableDiffusionXLPipeline
6
+ StableDiffusionXLPipeline.__call__ = call
7
 
8
+ model_map = {'Age' : 'models/age.pt',
9
+ 'Chubby': 'models/chubby.pt',
10
+ 'Muscular': 'models/muscular.pt',
11
+ 'Wavy Eyebrows': 'models/eyebrows.pt',
12
+ 'Small Eyes': 'models/eyesize.pt',
13
+ 'Long Hair' : 'models/longhair.pt',
14
+ 'Curly Hair' : 'models/curlyhair.pt',
15
+ 'Smiling' : 'models/smiling.pt',
16
+ 'Pixar Style' : 'models/pixar_style.pt',
17
+ 'Sculpture Style': 'models/sculpture_style.pt',
18
+ 'Repair Images': 'models/repair_slider.pt',
19
+ 'Fix Hands': 'models/fix_hands.pt',
20
+ }
21
 
22
+ ORIGINAL_SPACE_ID = 'baulab/ConceptSliders'
23
+ SPACE_ID = os.getenv('SPACE_ID')
24
+
25
+ SHARED_UI_WARNING = f'''## Attention - Training does not work in this shared UI. You can either duplicate and use it with a gpu with at least 40GB, or clone this repository to run on your own machine.
26
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
27
+ '''
28
+
29
+
30
+ class Demo:
31
+
32
+ def __init__(self) -> None:
33
+
34
+ self.training = False
35
+ self.generating = False
36
+ self.device = 'cuda'
37
+ self.weight_dtype = torch.float16
38
+ self.pipe = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=weight_dtype)
39
+
40
+ with gr.Blocks() as demo:
41
+ self.layout()
42
+ demo.queue(concurrency_count=5).launch()
43
+
44
+
45
+ def layout(self):
46
+
47
+ with gr.Row():
48
+
49
+ if SPACE_ID == ORIGINAL_SPACE_ID:
50
+
51
+ self.warning = gr.Markdown(SHARED_UI_WARNING)
52
+
53
+ with gr.Row():
54
+
55
+ with gr.Tab("Test") as inference_column:
56
+
57
+ with gr.Row():
58
+
59
+ self.explain_infr = gr.Markdown(interactive=False,
60
+ value='This is a demo of [Concept Sliders: LoRA Adaptors for Precise Control in Diffusion Models](https://sliders.baulab.info/). To try out a model that can control a particular concept, select a model and enter any prompt. For example, if you select the model "Surprised Look" you can generate images for the prompt "A picture of a person, realistic, 8k" and compare the slider effect to the image generated by original model. We have also provided several other pre-fine-tuned models like "repair" sliders to repair flaws in SDXL generated images (Check out the "Pretrained Sliders" drop-down). You can also train and run your own custom sliders. Check out the "train" section for custom concept slider training.')
61
+
62
+ with gr.Row():
63
+
64
+ with gr.Column(scale=1):
65
+
66
+ self.prompt_input_infr = gr.Text(
67
+ placeholder="Enter prompt...",
68
+ label="Prompt",
69
+ info="Prompt to generate"
70
+ )
71
+
72
+ with gr.Row():
73
+
74
+ self.model_dropdown = gr.Dropdown(
75
+ label="Pretrained Sliders",
76
+ choices= list(model_map.keys()),
77
+ value='Age',
78
+ interactive=True
79
+ )
80
+
81
+ self.seed_infr = gr.Number(
82
+ label="Seed",
83
+ value=12345
84
+ )
85
+
86
+ with gr.Column(scale=2):
87
+
88
+ self.infr_button = gr.Button(
89
+ value="Generate",
90
+ interactive=True
91
+ )
92
+
93
+ with gr.Row():
94
+
95
+ self.image_new = gr.Image(
96
+ label="Slider",
97
+ interactive=False
98
+ )
99
+ self.image_orig = gr.Image(
100
+ label="Original SD",
101
+ interactive=False
102
+ )
103
+
104
+ with gr.Tab("Train") as training_column:
105
+
106
+ with gr.Row():
107
+
108
+ self.explain_train= gr.Markdown(interactive=False,
109
+ value='In this part you can train a concept slider for Stable Diffusion XL. Enter a target concept you wish to make an edit on. Next, enter a enhance prompt of the attribute you wish to edit (for controlling age of a person, enter "person, old"). Then, type the supress prompt of the attribute (for our example, enter "person, young"). Then press "train" button. With default settings, it takes about 15 minutes to train a slider; then you can try inference above or download the weights. Code and details are at [github link](https://github.com/rohitgandikota/sliders).')
110
+
111
+ with gr.Row():
112
+
113
+ with gr.Column(scale=3):
114
+
115
+ self.target_concept = gr.Text(
116
+ placeholder="Enter target concept to make edit on ...",
117
+ label="Prompt of concept on which edit is made",
118
+ info="Prompt corresponding to concept to edit"
119
+ )
120
+
121
+ self.positive_prompt = gr.Text(
122
+ placeholder="Enter the enhance prompt for the edit...",
123
+ label="Prompt to enhance",
124
+ info="Prompt corresponding to concept to enhance"
125
+ )
126
+
127
+ self.negative_prompt = gr.Text(
128
+ placeholder="Enter the suppress prompt for the edit...",
129
+ label="Prompt to suppress",
130
+ info="Prompt corresponding to concept to supress"
131
+ )
132
+
133
+
134
+ self.rank = gr.Number(
135
+ value=4,
136
+ label="Rank of the Slider",
137
+ info='Slider Rank to train'
138
+ )
139
+
140
+ self.iterations_input = gr.Number(
141
+ value=1000,
142
+ precision=0,
143
+ label="Iterations",
144
+ info='iterations used to train'
145
+ )
146
+
147
+ self.lr_input = gr.Number(
148
+ value=2e-4,
149
+ label="Learning Rate",
150
+ info='Learning rate used to train'
151
+ )
152
+
153
+ with gr.Column(scale=1):
154
+
155
+ self.train_status = gr.Button(value='', variant='primary', label='Status', interactive=False)
156
+
157
+ self.train_button = gr.Button(
158
+ value="Train",
159
+ )
160
+
161
+ self.download = gr.Files()
162
+
163
+ self.infr_button.click(self.inference, inputs = [
164
+ self.prompt_input_infr,
165
+ self.seed_infr,
166
+ self.model_dropdown
167
+ ],
168
+ outputs=[
169
+ self.image_new,
170
+ self.image_orig
171
+ ]
172
+ )
173
+ self.train_button.click(self.train, inputs = [
174
+ self.target_concept,
175
+ self.positive_prompt,
176
+ slef.negative_prompt,
177
+ self.rank,
178
+ self.iterations_input,
179
+ self.lr_input
180
+ ],
181
+ outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
182
+ )
183
+
184
+ def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
185
+
186
+ if self.training:
187
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
188
+
189
+ if train_method == 'ESD-x':
190
+
191
+ modules = ".*attn2$"
192
+ frozen = []
193
+
194
+ elif train_method == 'ESD-u':
195
+
196
+ modules = "unet$"
197
+ frozen = [".*attn2$", "unet.time_embedding$", "unet.conv_out$"]
198
+
199
+ elif train_method == 'ESD-self':
200
+
201
+ modules = ".*attn1$"
202
+ frozen = []
203
+
204
+ randn = torch.randint(1, 10000000, (1,)).item()
205
+
206
+ save_path = f"models/{randn}_{prompt.lower().replace(' ', '')}.pt"
207
+
208
+ self.training = True
209
+
210
+ train(prompt, modules, frozen, iterations, neg_guidance, lr, save_path)
211
+
212
+ self.training = False
213
+
214
+ torch.cuda.empty_cache()
215
+
216
+ model_map['Custom'] = save_path
217
+
218
+ return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom model in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
219
+
220
+
221
+ def inference(self, prompt, seed, model_name, pbar = gr.Progress(track_tqdm=True)):
222
+
223
+ seed = seed or 12345
224
+
225
+ generator = torch.manual_seed(seed)
226
+
227
+ model_path = model_map[model_name]
228
+
229
+ checkpoint = torch.load(model_path)
230
+
231
+ finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half()
232
+
233
+ torch.cuda.empty_cache()
234
+
235
+ images = self.diffuser(
236
+ prompt,
237
+ n_steps=50,
238
+ generator=generator
239
+ )
240
+
241
+
242
+ orig_image = images[0][0]
243
+
244
+ torch.cuda.empty_cache()
245
+
246
+ generator = torch.manual_seed(seed)
247
+
248
+ with finetuner:
249
+
250
+ images = self.diffuser(
251
+ prompt,
252
+ n_steps=50,
253
+ generator=generator
254
+ )
255
+
256
+ edited_image = images[0][0]
257
+
258
+ del finetuner
259
+ torch.cuda.empty_cache()
260
+
261
+ return edited_image, orig_image
262
+
263
+
264
+ demo = Demo()
models/age.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c1c096f7cc1109b4072cbc604c811a5f0ff034fc0f6dc7cf66a558550aa4890
3
+ size 9142347
models/cartoon_style.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e07c30e4f82f709a474ae11dc5108ac48f81b6996b937757c8dd198920ea9b4d
3
+ size 9146507
models/chubby.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a70fb34187821a06a39bf36baa400090a32758d56771c3f54fcc4d9089f0d88
3
+ size 9144427
models/clay_style.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b0deeb787248811fb8e54498768e303cffaeb3125d00c5fd303294af59a9380
3
+ size 9143387
models/cluttered_room.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee409a45bfaa7ca01fbffe63ec185c0f5ccf0e7b0fa67070a9e0b41886b7ea66
3
+ size 9140267
models/curlyhair.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9b8d7d44da256291e3710f74954d352160ade5cbe291bce16c8f4951db31e7b
3
+ size 9136043
models/dark_weather.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eecd2ae8b35022cbfb9c32637d9fa8c3c0ca3aa5ea189369c027f938064ada3c
3
+ size 9135003
models/eyebrow.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:442770d2c30de92e30a1c2fcf9aab6b6cf5a3786eff84d513b7455345c79b57d
3
+ size 9135003
models/eyesize.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8fdffa3e7788f4bd6be9a2fe3b91957b4f35999fc9fa19eabfb49f92fbf6650b
3
+ size 9139227
models/festive.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70d6c5d5be5f001510988852c2d233a916d23766675d9a000c6f785cd7e9127c
3
+ size 9133963
models/fix_hands.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d98c4828468c8d5831c439f49914672710f63219a561b191670fa54d542fa57b
3
+ size 9131883
models/long_hair.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e93dba27fa012bba0ea468eb2f9877ec0934424a9474e30ac9e94ea0517822ca
3
+ size 9147547
models/muscular.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b46b8eeac992f2d0e76ce887ea45ec1ce70bfbae053876de26d1f33f986eb37
3
+ size 9135003
models/pixar_style.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e07c30e4f82f709a474ae11dc5108ac48f81b6996b937757c8dd198920ea9b4d
3
+ size 9146507
models/professional.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d4289f4c60dd008fe487369ddccf3492bd678cc1e6b30de2c17f9ce802b12ac
3
+ size 9151707
models/repair_slider.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6e589e7d3b2174bb1d5d861a7218c4c26a94425b6dcdce0085b57f87ab841c5
3
+ size 9133963
models/sculpture_style.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2779746c08062ccb128fdaa6cb66f061070ac8f19386701a99fb9291392d5985
3
+ size 9148587
models/smiling.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6430ab47393ba15222ea0988c3479f547c8b59f93a41024bcddd7121ef7147d1
3
+ size 9146507
models/stylegan_latent1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dca6cda8028af4587968cfed07c3bc6a2e79e5f8d01dad9351877f9de28f232d
3
+ size 9142347
models/stylegan_latent2.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4bbe239c399a4fc7b73a034b643c406106cd4c8392ad806ee3fd8dd8c80ba5fc
3
+ size 9142347
models/suprised_look.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36806271ca61dced2a506430c6c0b53ace09c68f65a90e09778c2bb5bcad31d4
3
+ size 9148587
models/tropical_weather.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:215e5445bbb7288ebea2e523181ca6db991417deca2736de29f0c3a76eb69ac0
3
+ size 9135003
models/winter_weather.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38f0bc81bc3cdef0c1c47895df6c9f0a9b98507f48928ef971f341e02c76bb4c
3
+ size 9132923
requirements.txt → reqs.txt RENAMED
File without changes
trainscripts/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # from textsliders import lora
trainscripts/imagesliders/config_util.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, Optional
2
+
3
+ import yaml
4
+
5
+ from pydantic import BaseModel
6
+ import torch
7
+
8
+ from lora import TRAINING_METHODS
9
+
10
+ PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"]
11
+ NETWORK_TYPES = Literal["lierla", "c3lier"]
12
+
13
+
14
+ class PretrainedModelConfig(BaseModel):
15
+ name_or_path: str
16
+ v2: bool = False
17
+ v_pred: bool = False
18
+
19
+ clip_skip: Optional[int] = None
20
+
21
+
22
+ class NetworkConfig(BaseModel):
23
+ type: NETWORK_TYPES = "lierla"
24
+ rank: int = 4
25
+ alpha: float = 1.0
26
+
27
+ training_method: TRAINING_METHODS = "full"
28
+
29
+
30
+ class TrainConfig(BaseModel):
31
+ precision: PRECISION_TYPES = "bfloat16"
32
+ noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim"
33
+
34
+ iterations: int = 500
35
+ lr: float = 1e-4
36
+ optimizer: str = "adamw"
37
+ optimizer_args: str = ""
38
+ lr_scheduler: str = "constant"
39
+
40
+ max_denoising_steps: int = 50
41
+
42
+
43
+ class SaveConfig(BaseModel):
44
+ name: str = "untitled"
45
+ path: str = "./output"
46
+ per_steps: int = 200
47
+ precision: PRECISION_TYPES = "float32"
48
+
49
+
50
+ class LoggingConfig(BaseModel):
51
+ use_wandb: bool = False
52
+
53
+ verbose: bool = False
54
+
55
+
56
+ class OtherConfig(BaseModel):
57
+ use_xformers: bool = False
58
+
59
+
60
+ class RootConfig(BaseModel):
61
+ prompts_file: str
62
+ pretrained_model: PretrainedModelConfig
63
+
64
+ network: NetworkConfig
65
+
66
+ train: Optional[TrainConfig]
67
+
68
+ save: Optional[SaveConfig]
69
+
70
+ logging: Optional[LoggingConfig]
71
+
72
+ other: Optional[OtherConfig]
73
+
74
+
75
+ def parse_precision(precision: str) -> torch.dtype:
76
+ if precision == "fp32" or precision == "float32":
77
+ return torch.float32
78
+ elif precision == "fp16" or precision == "float16":
79
+ return torch.float16
80
+ elif precision == "bf16" or precision == "bfloat16":
81
+ return torch.bfloat16
82
+
83
+ raise ValueError(f"Invalid precision type: {precision}")
84
+
85
+
86
+ def load_config_from_yaml(config_path: str) -> RootConfig:
87
+ with open(config_path, "r") as f:
88
+ config = yaml.load(f, Loader=yaml.FullLoader)
89
+
90
+ root = RootConfig(**config)
91
+
92
+ if root.train is None:
93
+ root.train = TrainConfig()
94
+
95
+ if root.save is None:
96
+ root.save = SaveConfig()
97
+
98
+ if root.logging is None:
99
+ root.logging = LoggingConfig()
100
+
101
+ if root.other is None:
102
+ root.other = OtherConfig()
103
+
104
+ return root
trainscripts/imagesliders/data/config-xl.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prompts_file: "trainscripts/imagesliders/data/prompts-xl.yaml"
2
+ pretrained_model:
3
+ name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" # you can also use .ckpt or .safetensors models
4
+ v2: false # true if model is v2.x
5
+ v_pred: false # true if model uses v-prediction
6
+ network:
7
+ type: "c3lier" # or "c3lier" or "lierla"
8
+ rank: 4
9
+ alpha: 1.0
10
+ training_method: "noxattn"
11
+ train:
12
+ precision: "bfloat16"
13
+ noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
14
+ iterations: 1000
15
+ lr: 0.0002
16
+ optimizer: "AdamW"
17
+ lr_scheduler: "constant"
18
+ max_denoising_steps: 50
19
+ save:
20
+ name: "temp"
21
+ path: "./models"
22
+ per_steps: 500
23
+ precision: "bfloat16"
24
+ logging:
25
+ use_wandb: false
26
+ verbose: false
27
+ other:
28
+ use_xformers: true
trainscripts/imagesliders/data/config.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prompts_file: "trainscripts/imagesliders/data/prompts.yaml"
2
+ pretrained_model:
3
+ name_or_path: "CompVis/stable-diffusion-v1-4" # you can also use .ckpt or .safetensors models
4
+ v2: false # true if model is v2.x
5
+ v_pred: false # true if model uses v-prediction
6
+ network:
7
+ type: "c3lier" # or "c3lier" or "lierla"
8
+ rank: 4
9
+ alpha: 1.0
10
+ training_method: "noxattn"
11
+ train:
12
+ precision: "bfloat16"
13
+ noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
14
+ iterations: 1000
15
+ lr: 0.0002
16
+ optimizer: "AdamW"
17
+ lr_scheduler: "constant"
18
+ max_denoising_steps: 50
19
+ save:
20
+ name: "temp"
21
+ path: "./models"
22
+ per_steps: 500
23
+ precision: "bfloat16"
24
+ logging:
25
+ use_wandb: false
26
+ verbose: false
27
+ other:
28
+ use_xformers: true
trainscripts/imagesliders/data/prompts-xl.yaml ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ####################################################################################################### AGE SLIDER
2
+ # - target: "male person" # what word for erasing the positive concept from
3
+ # positive: "male person, very old" # concept to erase
4
+ # unconditional: "male person, very young" # word to take the difference from the positive concept
5
+ # neutral: "male person" # starting point for conditioning the target
6
+ # action: "enhance" # erase or enhance
7
+ # guidance_scale: 4
8
+ # resolution: 512
9
+ # dynamic_resolution: false
10
+ # batch_size: 1
11
+ # - target: "female person" # what word for erasing the positive concept from
12
+ # positive: "female person, very old" # concept to erase
13
+ # unconditional: "female person, very young" # word to take the difference from the positive concept
14
+ # neutral: "female person" # starting point for conditioning the target
15
+ # action: "enhance" # erase or enhance
16
+ # guidance_scale: 4
17
+ # resolution: 512
18
+ # dynamic_resolution: false
19
+ # batch_size: 1
20
+ ####################################################################################################### GLASSES SLIDER
21
+ # - target: "male person" # what word for erasing the positive concept from
22
+ # positive: "male person, wearing glasses" # concept to erase
23
+ # unconditional: "male person" # word to take the difference from the positive concept
24
+ # neutral: "male person" # starting point for conditioning the target
25
+ # action: "enhance" # erase or enhance
26
+ # guidance_scale: 4
27
+ # resolution: 512
28
+ # dynamic_resolution: false
29
+ # batch_size: 1
30
+ # - target: "female person" # what word for erasing the positive concept from
31
+ # positive: "female person, wearing glasses" # concept to erase
32
+ # unconditional: "female person" # word to take the difference from the positive concept
33
+ # neutral: "female person" # starting point for conditioning the target
34
+ # action: "enhance" # erase or enhance
35
+ # guidance_scale: 4
36
+ # resolution: 512
37
+ # dynamic_resolution: false
38
+ # batch_size: 1
39
+ ####################################################################################################### ASTRONAUGHT SLIDER
40
+ # - target: "astronaught" # what word for erasing the positive concept from
41
+ # positive: "astronaught, with orange colored spacesuit" # concept to erase
42
+ # unconditional: "astronaught" # word to take the difference from the positive concept
43
+ # neutral: "astronaught" # starting point for conditioning the target
44
+ # action: "enhance" # erase or enhance
45
+ # guidance_scale: 4
46
+ # resolution: 512
47
+ # dynamic_resolution: false
48
+ # batch_size: 1
49
+ ####################################################################################################### SMILING SLIDER
50
+ # - target: "male person" # what word for erasing the positive concept from
51
+ # positive: "male person, smiling" # concept to erase
52
+ # unconditional: "male person, frowning" # word to take the difference from the positive concept
53
+ # neutral: "male person" # starting point for conditioning the target
54
+ # action: "enhance" # erase or enhance
55
+ # guidance_scale: 4
56
+ # resolution: 512
57
+ # dynamic_resolution: false
58
+ # batch_size: 1
59
+ # - target: "female person" # what word for erasing the positive concept from
60
+ # positive: "female person, smiling" # concept to erase
61
+ # unconditional: "female person, frowning" # word to take the difference from the positive concept
62
+ # neutral: "female person" # starting point for conditioning the target
63
+ # action: "enhance" # erase or enhance
64
+ # guidance_scale: 4
65
+ # resolution: 512
66
+ # dynamic_resolution: false
67
+ # batch_size: 1
68
+ ####################################################################################################### CAR COLOR SLIDER
69
+ # - target: "car" # what word for erasing the positive concept from
70
+ # positive: "car, white color" # concept to erase
71
+ # unconditional: "car, black color" # word to take the difference from the positive concept
72
+ # neutral: "car" # starting point for conditioning the target
73
+ # action: "enhance" # erase or enhance
74
+ # guidance_scale: 4
75
+ # resolution: 512
76
+ # dynamic_resolution: false
77
+ # batch_size: 1
78
+ ####################################################################################################### DETAILS SLIDER
79
+ # - target: "" # what word for erasing the positive concept from
80
+ # positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality, hyper realistic" # concept to erase
81
+ # unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept
82
+ # neutral: "" # starting point for conditioning the target
83
+ # action: "enhance" # erase or enhance
84
+ # guidance_scale: 4
85
+ # resolution: 512
86
+ # dynamic_resolution: false
87
+ # batch_size: 1
88
+ ####################################################################################################### BOKEH SLIDER
89
+ # - target: "" # what word for erasing the positive concept from
90
+ # positive: "blurred background, narrow DOF, bokeh effect" # concept to erase
91
+ # # unconditional: "high detail background, 8k, intricate, detailed, high resolution background, high res, high quality background" # word to take the difference from the positive concept
92
+ # unconditional: ""
93
+ # neutral: "" # starting point for conditioning the target
94
+ # action: "enhance" # erase or enhance
95
+ # guidance_scale: 4
96
+ # resolution: 512
97
+ # dynamic_resolution: false
98
+ # batch_size: 1
99
+ ####################################################################################################### LONG HAIR SLIDER
100
+ # - target: "male person" # what word for erasing the positive concept from
101
+ # positive: "male person, with long hair" # concept to erase
102
+ # unconditional: "male person, with short hair" # word to take the difference from the positive concept
103
+ # neutral: "male person" # starting point for conditioning the target
104
+ # action: "enhance" # erase or enhance
105
+ # guidance_scale: 4
106
+ # resolution: 512
107
+ # dynamic_resolution: false
108
+ # batch_size: 1
109
+ # - target: "female person" # what word for erasing the positive concept from
110
+ # positive: "female person, with long hair" # concept to erase
111
+ # unconditional: "female person, with short hair" # word to take the difference from the positive concept
112
+ # neutral: "female person" # starting point for conditioning the target
113
+ # action: "enhance" # erase or enhance
114
+ # guidance_scale: 4
115
+ # resolution: 512
116
+ # dynamic_resolution: false
117
+ # batch_size: 1
118
+ ####################################################################################################### IMAGE SLIDER
119
+ - target: "" # what word for erasing the positive concept from
120
+ positive: "" # concept to erase
121
+ unconditional: "" # word to take the difference from the positive concept
122
+ neutral: "" # starting point for conditioning the target
123
+ action: "enhance" # erase or enhance
124
+ guidance_scale: 4
125
+ resolution: 512
126
+ dynamic_resolution: false
127
+ batch_size: 1
128
+ ####################################################################################################### IMAGE SLIDER
129
+ # - target: "food" # what word for erasing the positive concept from
130
+ # positive: "food, expensive and fine dining" # concept to erase
131
+ # unconditional: "food, cheap and low quality" # word to take the difference from the positive concept
132
+ # neutral: "food" # starting point for conditioning the target
133
+ # action: "enhance" # erase or enhance
134
+ # guidance_scale: 4
135
+ # resolution: 512
136
+ # dynamic_resolution: false
137
+ # batch_size: 1
138
+ # - target: "room" # what word for erasing the positive concept from
139
+ # positive: "room, dirty disorganised and cluttered" # concept to erase
140
+ # unconditional: "room, neat organised and clean" # word to take the difference from the positive concept
141
+ # neutral: "room" # starting point for conditioning the target
142
+ # action: "enhance" # erase or enhance
143
+ # guidance_scale: 4
144
+ # resolution: 512
145
+ # dynamic_resolution: false
146
+ # batch_size: 1
147
+ # - target: "male person" # what word for erasing the positive concept from
148
+ # positive: "male person, with a surprised look" # concept to erase
149
+ # unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept
150
+ # neutral: "male person" # starting point for conditioning the target
151
+ # action: "enhance" # erase or enhance
152
+ # guidance_scale: 4
153
+ # resolution: 512
154
+ # dynamic_resolution: false
155
+ # batch_size: 1
156
+ # - target: "female person" # what word for erasing the positive concept from
157
+ # positive: "female person, with a surprised look" # concept to erase
158
+ # unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept
159
+ # neutral: "female person" # starting point for conditioning the target
160
+ # action: "enhance" # erase or enhance
161
+ # guidance_scale: 4
162
+ # resolution: 512
163
+ # dynamic_resolution: false
164
+ # batch_size: 1
165
+ # - target: "sky" # what word for erasing the positive concept from
166
+ # positive: "peaceful sky" # concept to erase
167
+ # unconditional: "sky" # word to take the difference from the positive concept
168
+ # neutral: "sky" # starting point for conditioning the target
169
+ # action: "enhance" # erase or enhance
170
+ # guidance_scale: 4
171
+ # resolution: 512
172
+ # dynamic_resolution: false
173
+ # batch_size: 1
174
+ # - target: "sky" # what word for erasing the positive concept from
175
+ # positive: "chaotic dark sky" # concept to erase
176
+ # unconditional: "sky" # word to take the difference from the positive concept
177
+ # neutral: "sky" # starting point for conditioning the target
178
+ # action: "erase" # erase or enhance
179
+ # guidance_scale: 4
180
+ # resolution: 512
181
+ # dynamic_resolution: false
182
+ # batch_size: 1
183
+ # - target: "person" # what word for erasing the positive concept from
184
+ # positive: "person, very young" # concept to erase
185
+ # unconditional: "person" # word to take the difference from the positive concept
186
+ # neutral: "person" # starting point for conditioning the target
187
+ # action: "erase" # erase or enhance
188
+ # guidance_scale: 4
189
+ # resolution: 512
190
+ # dynamic_resolution: false
191
+ # batch_size: 1
192
+ # overweight
193
+ # - target: "art" # what word for erasing the positive concept from
194
+ # positive: "realistic art" # concept to erase
195
+ # unconditional: "art" # word to take the difference from the positive concept
196
+ # neutral: "art" # starting point for conditioning the target
197
+ # action: "enhance" # erase or enhance
198
+ # guidance_scale: 4
199
+ # resolution: 512
200
+ # dynamic_resolution: false
201
+ # batch_size: 1
202
+ # - target: "art" # what word for erasing the positive concept from
203
+ # positive: "abstract art" # concept to erase
204
+ # unconditional: "art" # word to take the difference from the positive concept
205
+ # neutral: "art" # starting point for conditioning the target
206
+ # action: "erase" # erase or enhance
207
+ # guidance_scale: 4
208
+ # resolution: 512
209
+ # dynamic_resolution: false
210
+ # batch_size: 1
211
+ # sky
212
+ # - target: "weather" # what word for erasing the positive concept from
213
+ # positive: "bright pleasant weather" # concept to erase
214
+ # unconditional: "weather" # word to take the difference from the positive concept
215
+ # neutral: "weather" # starting point for conditioning the target
216
+ # action: "enhance" # erase or enhance
217
+ # guidance_scale: 4
218
+ # resolution: 512
219
+ # dynamic_resolution: false
220
+ # batch_size: 1
221
+ # - target: "weather" # what word for erasing the positive concept from
222
+ # positive: "dark gloomy weather" # concept to erase
223
+ # unconditional: "weather" # word to take the difference from the positive concept
224
+ # neutral: "weather" # starting point for conditioning the target
225
+ # action: "erase" # erase or enhance
226
+ # guidance_scale: 4
227
+ # resolution: 512
228
+ # dynamic_resolution: false
229
+ # batch_size: 1
230
+ # hair
231
+ # - target: "person" # what word for erasing the positive concept from
232
+ # positive: "person with long hair" # concept to erase
233
+ # unconditional: "person" # word to take the difference from the positive concept
234
+ # neutral: "person" # starting point for conditioning the target
235
+ # action: "enhance" # erase or enhance
236
+ # guidance_scale: 4
237
+ # resolution: 512
238
+ # dynamic_resolution: false
239
+ # batch_size: 1
240
+ # - target: "person" # what word for erasing the positive concept from
241
+ # positive: "person with short hair" # concept to erase
242
+ # unconditional: "person" # word to take the difference from the positive concept
243
+ # neutral: "person" # starting point for conditioning the target
244
+ # action: "erase" # erase or enhance
245
+ # guidance_scale: 4
246
+ # resolution: 512
247
+ # dynamic_resolution: false
248
+ # batch_size: 1
249
+ # - target: "girl" # what word for erasing the positive concept from
250
+ # positive: "baby girl" # concept to erase
251
+ # unconditional: "girl" # word to take the difference from the positive concept
252
+ # neutral: "girl" # starting point for conditioning the target
253
+ # action: "enhance" # erase or enhance
254
+ # guidance_scale: -4
255
+ # resolution: 512
256
+ # dynamic_resolution: false
257
+ # batch_size: 1
258
+ # - target: "boy" # what word for erasing the positive concept from
259
+ # positive: "old man" # concept to erase
260
+ # unconditional: "boy" # word to take the difference from the positive concept
261
+ # neutral: "boy" # starting point for conditioning the target
262
+ # action: "enhance" # erase or enhance
263
+ # guidance_scale: 4
264
+ # resolution: 512
265
+ # dynamic_resolution: false
266
+ # batch_size: 1
267
+ # - target: "boy" # what word for erasing the positive concept from
268
+ # positive: "baby boy" # concept to erase
269
+ # unconditional: "boy" # word to take the difference from the positive concept
270
+ # neutral: "boy" # starting point for conditioning the target
271
+ # action: "enhance" # erase or enhance
272
+ # guidance_scale: -4
273
+ # resolution: 512
274
+ # dynamic_resolution: false
275
+ # batch_size: 1
trainscripts/imagesliders/data/prompts.yaml ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - target: "person" # what word for erasing the positive concept from
2
+ # positive: "person, very old" # concept to erase
3
+ # unconditional: "person" # word to take the difference from the positive concept
4
+ # neutral: "person" # starting point for conditioning the target
5
+ # action: "enhance" # erase or enhance
6
+ # guidance_scale: 4
7
+ # resolution: 512
8
+ # dynamic_resolution: false
9
+ # batch_size: 1
10
+ - target: "" # what word for erasing the positive concept from
11
+ positive: "" # concept to erase
12
+ unconditional: "" # word to take the difference from the positive concept
13
+ neutral: "" # starting point for conditioning the target
14
+ action: "enhance" # erase or enhance
15
+ guidance_scale: 1
16
+ resolution: 512
17
+ dynamic_resolution: false
18
+ batch_size: 1
19
+ # - target: "" # what word for erasing the positive concept from
20
+ # positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality" # concept to erase
21
+ # unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept
22
+ # neutral: "" # starting point for conditioning the target
23
+ # action: "enhance" # erase or enhance
24
+ # guidance_scale: 4
25
+ # resolution: 512
26
+ # dynamic_resolution: false
27
+ # batch_size: 1
28
+ # - target: "food" # what word for erasing the positive concept from
29
+ # positive: "food, expensive and fine dining" # concept to erase
30
+ # unconditional: "food, cheap and low quality" # word to take the difference from the positive concept
31
+ # neutral: "food" # starting point for conditioning the target
32
+ # action: "enhance" # erase or enhance
33
+ # guidance_scale: 4
34
+ # resolution: 512
35
+ # dynamic_resolution: false
36
+ # batch_size: 1
37
+ # - target: "room" # what word for erasing the positive concept from
38
+ # positive: "room, dirty disorganised and cluttered" # concept to erase
39
+ # unconditional: "room, neat organised and clean" # word to take the difference from the positive concept
40
+ # neutral: "room" # starting point for conditioning the target
41
+ # action: "enhance" # erase or enhance
42
+ # guidance_scale: 4
43
+ # resolution: 512
44
+ # dynamic_resolution: false
45
+ # batch_size: 1
46
+ # - target: "male person" # what word for erasing the positive concept from
47
+ # positive: "male person, with a surprised look" # concept to erase
48
+ # unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept
49
+ # neutral: "male person" # starting point for conditioning the target
50
+ # action: "enhance" # erase or enhance
51
+ # guidance_scale: 4
52
+ # resolution: 512
53
+ # dynamic_resolution: false
54
+ # batch_size: 1
55
+ # - target: "female person" # what word for erasing the positive concept from
56
+ # positive: "female person, with a surprised look" # concept to erase
57
+ # unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept
58
+ # neutral: "female person" # starting point for conditioning the target
59
+ # action: "enhance" # erase or enhance
60
+ # guidance_scale: 4
61
+ # resolution: 512
62
+ # dynamic_resolution: false
63
+ # batch_size: 1
64
+ # - target: "sky" # what word for erasing the positive concept from
65
+ # positive: "peaceful sky" # concept to erase
66
+ # unconditional: "sky" # word to take the difference from the positive concept
67
+ # neutral: "sky" # starting point for conditioning the target
68
+ # action: "enhance" # erase or enhance
69
+ # guidance_scale: 4
70
+ # resolution: 512
71
+ # dynamic_resolution: false
72
+ # batch_size: 1
73
+ # - target: "sky" # what word for erasing the positive concept from
74
+ # positive: "chaotic dark sky" # concept to erase
75
+ # unconditional: "sky" # word to take the difference from the positive concept
76
+ # neutral: "sky" # starting point for conditioning the target
77
+ # action: "erase" # erase or enhance
78
+ # guidance_scale: 4
79
+ # resolution: 512
80
+ # dynamic_resolution: false
81
+ # batch_size: 1
82
+ # - target: "person" # what word for erasing the positive concept from
83
+ # positive: "person, very young" # concept to erase
84
+ # unconditional: "person" # word to take the difference from the positive concept
85
+ # neutral: "person" # starting point for conditioning the target
86
+ # action: "erase" # erase or enhance
87
+ # guidance_scale: 4
88
+ # resolution: 512
89
+ # dynamic_resolution: false
90
+ # batch_size: 1
91
+ # overweight
92
+ # - target: "art" # what word for erasing the positive concept from
93
+ # positive: "realistic art" # concept to erase
94
+ # unconditional: "art" # word to take the difference from the positive concept
95
+ # neutral: "art" # starting point for conditioning the target
96
+ # action: "enhance" # erase or enhance
97
+ # guidance_scale: 4
98
+ # resolution: 512
99
+ # dynamic_resolution: false
100
+ # batch_size: 1
101
+ # - target: "art" # what word for erasing the positive concept from
102
+ # positive: "abstract art" # concept to erase
103
+ # unconditional: "art" # word to take the difference from the positive concept
104
+ # neutral: "art" # starting point for conditioning the target
105
+ # action: "erase" # erase or enhance
106
+ # guidance_scale: 4
107
+ # resolution: 512
108
+ # dynamic_resolution: false
109
+ # batch_size: 1
110
+ # sky
111
+ # - target: "weather" # what word for erasing the positive concept from
112
+ # positive: "bright pleasant weather" # concept to erase
113
+ # unconditional: "weather" # word to take the difference from the positive concept
114
+ # neutral: "weather" # starting point for conditioning the target
115
+ # action: "enhance" # erase or enhance
116
+ # guidance_scale: 4
117
+ # resolution: 512
118
+ # dynamic_resolution: false
119
+ # batch_size: 1
120
+ # - target: "weather" # what word for erasing the positive concept from
121
+ # positive: "dark gloomy weather" # concept to erase
122
+ # unconditional: "weather" # word to take the difference from the positive concept
123
+ # neutral: "weather" # starting point for conditioning the target
124
+ # action: "erase" # erase or enhance
125
+ # guidance_scale: 4
126
+ # resolution: 512
127
+ # dynamic_resolution: false
128
+ # batch_size: 1
129
+ # hair
130
+ # - target: "person" # what word for erasing the positive concept from
131
+ # positive: "person with long hair" # concept to erase
132
+ # unconditional: "person" # word to take the difference from the positive concept
133
+ # neutral: "person" # starting point for conditioning the target
134
+ # action: "enhance" # erase or enhance
135
+ # guidance_scale: 4
136
+ # resolution: 512
137
+ # dynamic_resolution: false
138
+ # batch_size: 1
139
+ # - target: "person" # what word for erasing the positive concept from
140
+ # positive: "person with short hair" # concept to erase
141
+ # unconditional: "person" # word to take the difference from the positive concept
142
+ # neutral: "person" # starting point for conditioning the target
143
+ # action: "erase" # erase or enhance
144
+ # guidance_scale: 4
145
+ # resolution: 512
146
+ # dynamic_resolution: false
147
+ # batch_size: 1
148
+ # - target: "girl" # what word for erasing the positive concept from
149
+ # positive: "baby girl" # concept to erase
150
+ # unconditional: "girl" # word to take the difference from the positive concept
151
+ # neutral: "girl" # starting point for conditioning the target
152
+ # action: "enhance" # erase or enhance
153
+ # guidance_scale: -4
154
+ # resolution: 512
155
+ # dynamic_resolution: false
156
+ # batch_size: 1
157
+ # - target: "boy" # what word for erasing the positive concept from
158
+ # positive: "old man" # concept to erase
159
+ # unconditional: "boy" # word to take the difference from the positive concept
160
+ # neutral: "boy" # starting point for conditioning the target
161
+ # action: "enhance" # erase or enhance
162
+ # guidance_scale: 4
163
+ # resolution: 512
164
+ # dynamic_resolution: false
165
+ # batch_size: 1
166
+ # - target: "boy" # what word for erasing the positive concept from
167
+ # positive: "baby boy" # concept to erase
168
+ # unconditional: "boy" # word to take the difference from the positive concept
169
+ # neutral: "boy" # starting point for conditioning the target
170
+ # action: "enhance" # erase or enhance
171
+ # guidance_scale: -4
172
+ # resolution: 512
173
+ # dynamic_resolution: false
174
+ # batch_size: 1
trainscripts/imagesliders/debug_util.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # デバッグ用...
2
+
3
+ import torch
4
+
5
+
6
+ def check_requires_grad(model: torch.nn.Module):
7
+ for name, module in list(model.named_modules())[:5]:
8
+ if len(list(module.parameters())) > 0:
9
+ print(f"Module: {name}")
10
+ for name, param in list(module.named_parameters())[:2]:
11
+ print(f" Parameter: {name}, Requires Grad: {param.requires_grad}")
12
+
13
+
14
+ def check_training_mode(model: torch.nn.Module):
15
+ for name, module in list(model.named_modules())[:5]:
16
+ print(f"Module: {name}, Training Mode: {module.training}")
trainscripts/imagesliders/lora.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ref:
2
+ # - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
3
+ # - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
4
+
5
+ import os
6
+ import math
7
+ from typing import Optional, List, Type, Set, Literal
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from diffusers import UNet2DConditionModel
12
+ from safetensors.torch import save_file
13
+
14
+
15
+ UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
16
+ # "Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2
17
+ "Attention"
18
+ ]
19
+ UNET_TARGET_REPLACE_MODULE_CONV = [
20
+ "ResnetBlock2D",
21
+ "Downsample2D",
22
+ "Upsample2D",
23
+ # "DownBlock2D",
24
+ # "UpBlock2D"
25
+ ] # locon, 3clier
26
+
27
+ LORA_PREFIX_UNET = "lora_unet"
28
+
29
+ DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
30
+
31
+ TRAINING_METHODS = Literal[
32
+ "noxattn", # train all layers except x-attns and time_embed layers
33
+ "innoxattn", # train all layers except self attention layers
34
+ "selfattn", # ESD-u, train only self attention layers
35
+ "xattn", # ESD-x, train only x attention layers
36
+ "full", # train all layers
37
+ "xattn-strict", # q and k values
38
+ "noxattn-hspace",
39
+ "noxattn-hspace-last",
40
+ # "xlayer",
41
+ # "outxattn",
42
+ # "outsattn",
43
+ # "inxattn",
44
+ # "inmidsattn",
45
+ # "selflayer",
46
+ ]
47
+
48
+
49
+ class LoRAModule(nn.Module):
50
+ """
51
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ lora_name,
57
+ org_module: nn.Module,
58
+ multiplier=1.0,
59
+ lora_dim=4,
60
+ alpha=1,
61
+ ):
62
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
63
+ super().__init__()
64
+ self.lora_name = lora_name
65
+ self.lora_dim = lora_dim
66
+
67
+ if "Linear" in org_module.__class__.__name__:
68
+ in_dim = org_module.in_features
69
+ out_dim = org_module.out_features
70
+ self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
71
+ self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
72
+
73
+ elif "Conv" in org_module.__class__.__name__: # 一応
74
+ in_dim = org_module.in_channels
75
+ out_dim = org_module.out_channels
76
+
77
+ self.lora_dim = min(self.lora_dim, in_dim, out_dim)
78
+ if self.lora_dim != lora_dim:
79
+ print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
80
+
81
+ kernel_size = org_module.kernel_size
82
+ stride = org_module.stride
83
+ padding = org_module.padding
84
+ self.lora_down = nn.Conv2d(
85
+ in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
86
+ )
87
+ self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
88
+
89
+ if type(alpha) == torch.Tensor:
90
+ alpha = alpha.detach().numpy()
91
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
92
+ self.scale = alpha / self.lora_dim
93
+ self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
94
+
95
+ # same as microsoft's
96
+ nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
97
+ nn.init.zeros_(self.lora_up.weight)
98
+
99
+ self.multiplier = multiplier
100
+ self.org_module = org_module # remove in applying
101
+
102
+ def apply_to(self):
103
+ self.org_forward = self.org_module.forward
104
+ self.org_module.forward = self.forward
105
+ del self.org_module
106
+
107
+ def forward(self, x):
108
+ return (
109
+ self.org_forward(x)
110
+ + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
111
+ )
112
+
113
+
114
+ class LoRANetwork(nn.Module):
115
+ def __init__(
116
+ self,
117
+ unet: UNet2DConditionModel,
118
+ rank: int = 4,
119
+ multiplier: float = 1.0,
120
+ alpha: float = 1.0,
121
+ train_method: TRAINING_METHODS = "full",
122
+ ) -> None:
123
+ super().__init__()
124
+ self.lora_scale = 1
125
+ self.multiplier = multiplier
126
+ self.lora_dim = rank
127
+ self.alpha = alpha
128
+
129
+ # LoRAのみ
130
+ self.module = LoRAModule
131
+
132
+ # unetのloraを作る
133
+ self.unet_loras = self.create_modules(
134
+ LORA_PREFIX_UNET,
135
+ unet,
136
+ DEFAULT_TARGET_REPLACE,
137
+ self.lora_dim,
138
+ self.multiplier,
139
+ train_method=train_method,
140
+ )
141
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
142
+
143
+ # assertion 名前の被りがないか確認しているようだ
144
+ lora_names = set()
145
+ for lora in self.unet_loras:
146
+ assert (
147
+ lora.lora_name not in lora_names
148
+ ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
149
+ lora_names.add(lora.lora_name)
150
+
151
+ # 適用する
152
+ for lora in self.unet_loras:
153
+ lora.apply_to()
154
+ self.add_module(
155
+ lora.lora_name,
156
+ lora,
157
+ )
158
+
159
+ del unet
160
+
161
+ torch.cuda.empty_cache()
162
+
163
+ def create_modules(
164
+ self,
165
+ prefix: str,
166
+ root_module: nn.Module,
167
+ target_replace_modules: List[str],
168
+ rank: int,
169
+ multiplier: float,
170
+ train_method: TRAINING_METHODS,
171
+ ) -> list:
172
+ loras = []
173
+ names = []
174
+ for name, module in root_module.named_modules():
175
+ if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last": # Cross Attention と Time Embed 以外学習
176
+ if "attn2" in name or "time_embed" in name:
177
+ continue
178
+ elif train_method == "innoxattn": # Cross Attention 以外学習
179
+ if "attn2" in name:
180
+ continue
181
+ elif train_method == "selfattn": # Self Attention のみ学習
182
+ if "attn1" not in name:
183
+ continue
184
+ elif train_method == "xattn" or train_method == "xattn-strict": # Cross Attention のみ学習
185
+ if "attn2" not in name:
186
+ continue
187
+ elif train_method == "full": # 全部学習
188
+ pass
189
+ else:
190
+ raise NotImplementedError(
191
+ f"train_method: {train_method} is not implemented."
192
+ )
193
+ if module.__class__.__name__ in target_replace_modules:
194
+ for child_name, child_module in module.named_modules():
195
+ if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
196
+ if train_method == 'xattn-strict':
197
+ if 'out' in child_name:
198
+ continue
199
+ if train_method == 'noxattn-hspace':
200
+ if 'mid_block' not in name:
201
+ continue
202
+ if train_method == 'noxattn-hspace-last':
203
+ if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name:
204
+ continue
205
+ lora_name = prefix + "." + name + "." + child_name
206
+ lora_name = lora_name.replace(".", "_")
207
+ # print(f"{lora_name}")
208
+ lora = self.module(
209
+ lora_name, child_module, multiplier, rank, self.alpha
210
+ )
211
+ # print(name, child_name)
212
+ # print(child_module.weight.shape)
213
+ loras.append(lora)
214
+ names.append(lora_name)
215
+ # print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
216
+ return loras
217
+
218
+ def prepare_optimizer_params(self):
219
+ all_params = []
220
+
221
+ if self.unet_loras: # 実質これしかない
222
+ params = []
223
+ [params.extend(lora.parameters()) for lora in self.unet_loras]
224
+ param_data = {"params": params}
225
+ all_params.append(param_data)
226
+
227
+ return all_params
228
+
229
+ def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
230
+ state_dict = self.state_dict()
231
+
232
+ if dtype is not None:
233
+ for key in list(state_dict.keys()):
234
+ v = state_dict[key]
235
+ v = v.detach().clone().to("cpu").to(dtype)
236
+ state_dict[key] = v
237
+
238
+ # for key in list(state_dict.keys()):
239
+ # if not key.startswith("lora"):
240
+ # # lora以外除外
241
+ # del state_dict[key]
242
+
243
+ if os.path.splitext(file)[1] == ".safetensors":
244
+ save_file(state_dict, file, metadata)
245
+ else:
246
+ torch.save(state_dict, file)
247
+ def set_lora_slider(self, scale):
248
+ self.lora_scale = scale
249
+
250
+ def __enter__(self):
251
+ for lora in self.unet_loras:
252
+ lora.multiplier = 1.0 * self.lora_scale
253
+
254
+ def __exit__(self, exc_type, exc_value, tb):
255
+ for lora in self.unet_loras:
256
+ lora.multiplier = 0
trainscripts/imagesliders/model_util.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, Union, Optional
2
+
3
+ import torch
4
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
5
+ from diffusers import (
6
+ UNet2DConditionModel,
7
+ SchedulerMixin,
8
+ StableDiffusionPipeline,
9
+ StableDiffusionXLPipeline,
10
+ AutoencoderKL,
11
+ )
12
+ from diffusers.schedulers import (
13
+ DDIMScheduler,
14
+ DDPMScheduler,
15
+ LMSDiscreteScheduler,
16
+ EulerAncestralDiscreteScheduler,
17
+ )
18
+
19
+
20
+ TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
21
+ TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
22
+
23
+ AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"]
24
+
25
+ SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
26
+
27
+ DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this
28
+
29
+
30
+ def load_diffusers_model(
31
+ pretrained_model_name_or_path: str,
32
+ v2: bool = False,
33
+ clip_skip: Optional[int] = None,
34
+ weight_dtype: torch.dtype = torch.float32,
35
+ ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
36
+ # VAE はいらない
37
+
38
+ if v2:
39
+ tokenizer = CLIPTokenizer.from_pretrained(
40
+ TOKENIZER_V2_MODEL_NAME,
41
+ subfolder="tokenizer",
42
+ torch_dtype=weight_dtype,
43
+ cache_dir=DIFFUSERS_CACHE_DIR,
44
+ )
45
+ text_encoder = CLIPTextModel.from_pretrained(
46
+ pretrained_model_name_or_path,
47
+ subfolder="text_encoder",
48
+ # default is clip skip 2
49
+ num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
50
+ torch_dtype=weight_dtype,
51
+ cache_dir=DIFFUSERS_CACHE_DIR,
52
+ )
53
+ else:
54
+ tokenizer = CLIPTokenizer.from_pretrained(
55
+ TOKENIZER_V1_MODEL_NAME,
56
+ subfolder="tokenizer",
57
+ torch_dtype=weight_dtype,
58
+ cache_dir=DIFFUSERS_CACHE_DIR,
59
+ )
60
+ text_encoder = CLIPTextModel.from_pretrained(
61
+ pretrained_model_name_or_path,
62
+ subfolder="text_encoder",
63
+ num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
64
+ torch_dtype=weight_dtype,
65
+ cache_dir=DIFFUSERS_CACHE_DIR,
66
+ )
67
+
68
+ unet = UNet2DConditionModel.from_pretrained(
69
+ pretrained_model_name_or_path,
70
+ subfolder="unet",
71
+ torch_dtype=weight_dtype,
72
+ cache_dir=DIFFUSERS_CACHE_DIR,
73
+ )
74
+
75
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
76
+
77
+ return tokenizer, text_encoder, unet, vae
78
+
79
+
80
+ def load_checkpoint_model(
81
+ checkpoint_path: str,
82
+ v2: bool = False,
83
+ clip_skip: Optional[int] = None,
84
+ weight_dtype: torch.dtype = torch.float32,
85
+ ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
86
+ pipe = StableDiffusionPipeline.from_ckpt(
87
+ checkpoint_path,
88
+ upcast_attention=True if v2 else False,
89
+ torch_dtype=weight_dtype,
90
+ cache_dir=DIFFUSERS_CACHE_DIR,
91
+ )
92
+
93
+ unet = pipe.unet
94
+ tokenizer = pipe.tokenizer
95
+ text_encoder = pipe.text_encoder
96
+ vae = pipe.vae
97
+ if clip_skip is not None:
98
+ if v2:
99
+ text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
100
+ else:
101
+ text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)
102
+
103
+ del pipe
104
+
105
+ return tokenizer, text_encoder, unet, vae
106
+
107
+
108
+ def load_models(
109
+ pretrained_model_name_or_path: str,
110
+ scheduler_name: AVAILABLE_SCHEDULERS,
111
+ v2: bool = False,
112
+ v_pred: bool = False,
113
+ weight_dtype: torch.dtype = torch.float32,
114
+ ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]:
115
+ if pretrained_model_name_or_path.endswith(
116
+ ".ckpt"
117
+ ) or pretrained_model_name_or_path.endswith(".safetensors"):
118
+ tokenizer, text_encoder, unet, vae = load_checkpoint_model(
119
+ pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
120
+ )
121
+ else: # diffusers
122
+ tokenizer, text_encoder, unet, vae = load_diffusers_model(
123
+ pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
124
+ )
125
+
126
+ # VAE はいらない
127
+
128
+ scheduler = create_noise_scheduler(
129
+ scheduler_name,
130
+ prediction_type="v_prediction" if v_pred else "epsilon",
131
+ )
132
+
133
+ return tokenizer, text_encoder, unet, scheduler, vae
134
+
135
+
136
+ def load_diffusers_model_xl(
137
+ pretrained_model_name_or_path: str,
138
+ weight_dtype: torch.dtype = torch.float32,
139
+ ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
140
+ # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet
141
+
142
+ tokenizers = [
143
+ CLIPTokenizer.from_pretrained(
144
+ pretrained_model_name_or_path,
145
+ subfolder="tokenizer",
146
+ torch_dtype=weight_dtype,
147
+ cache_dir=DIFFUSERS_CACHE_DIR,
148
+ ),
149
+ CLIPTokenizer.from_pretrained(
150
+ pretrained_model_name_or_path,
151
+ subfolder="tokenizer_2",
152
+ torch_dtype=weight_dtype,
153
+ cache_dir=DIFFUSERS_CACHE_DIR,
154
+ pad_token_id=0, # same as open clip
155
+ ),
156
+ ]
157
+
158
+ text_encoders = [
159
+ CLIPTextModel.from_pretrained(
160
+ pretrained_model_name_or_path,
161
+ subfolder="text_encoder",
162
+ torch_dtype=weight_dtype,
163
+ cache_dir=DIFFUSERS_CACHE_DIR,
164
+ ),
165
+ CLIPTextModelWithProjection.from_pretrained(
166
+ pretrained_model_name_or_path,
167
+ subfolder="text_encoder_2",
168
+ torch_dtype=weight_dtype,
169
+ cache_dir=DIFFUSERS_CACHE_DIR,
170
+ ),
171
+ ]
172
+
173
+ unet = UNet2DConditionModel.from_pretrained(
174
+ pretrained_model_name_or_path,
175
+ subfolder="unet",
176
+ torch_dtype=weight_dtype,
177
+ cache_dir=DIFFUSERS_CACHE_DIR,
178
+ )
179
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
180
+ return tokenizers, text_encoders, unet, vae
181
+
182
+
183
+ def load_checkpoint_model_xl(
184
+ checkpoint_path: str,
185
+ weight_dtype: torch.dtype = torch.float32,
186
+ ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
187
+ pipe = StableDiffusionXLPipeline.from_single_file(
188
+ checkpoint_path,
189
+ torch_dtype=weight_dtype,
190
+ cache_dir=DIFFUSERS_CACHE_DIR,
191
+ )
192
+
193
+ unet = pipe.unet
194
+ tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
195
+ text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
196
+ if len(text_encoders) == 2:
197
+ text_encoders[1].pad_token_id = 0
198
+
199
+ del pipe
200
+
201
+ return tokenizers, text_encoders, unet
202
+
203
+
204
+ def load_models_xl(
205
+ pretrained_model_name_or_path: str,
206
+ scheduler_name: AVAILABLE_SCHEDULERS,
207
+ weight_dtype: torch.dtype = torch.float32,
208
+ ) -> tuple[
209
+ list[CLIPTokenizer],
210
+ list[SDXL_TEXT_ENCODER_TYPE],
211
+ UNet2DConditionModel,
212
+ SchedulerMixin,
213
+ ]:
214
+ if pretrained_model_name_or_path.endswith(
215
+ ".ckpt"
216
+ ) or pretrained_model_name_or_path.endswith(".safetensors"):
217
+ (
218
+ tokenizers,
219
+ text_encoders,
220
+ unet,
221
+ ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype)
222
+ else: # diffusers
223
+ (
224
+ tokenizers,
225
+ text_encoders,
226
+ unet,
227
+ vae
228
+ ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype)
229
+
230
+ scheduler = create_noise_scheduler(scheduler_name)
231
+
232
+ return tokenizers, text_encoders, unet, scheduler, vae
233
+
234
+
235
+ def create_noise_scheduler(
236
+ scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
237
+ prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
238
+ ) -> SchedulerMixin:
239
+ # 正直、どれがいいのかわからない。元の実装だとDDIMとDDPMとLMSを選べたのだけど、どれがいいのかわからぬ。
240
+
241
+ name = scheduler_name.lower().replace(" ", "_")
242
+ if name == "ddim":
243
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
244
+ scheduler = DDIMScheduler(
245
+ beta_start=0.00085,
246
+ beta_end=0.012,
247
+ beta_schedule="scaled_linear",
248
+ num_train_timesteps=1000,
249
+ clip_sample=False,
250
+ prediction_type=prediction_type, # これでいいの?
251
+ )
252
+ elif name == "ddpm":
253
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
254
+ scheduler = DDPMScheduler(
255
+ beta_start=0.00085,
256
+ beta_end=0.012,
257
+ beta_schedule="scaled_linear",
258
+ num_train_timesteps=1000,
259
+ clip_sample=False,
260
+ prediction_type=prediction_type,
261
+ )
262
+ elif name == "lms":
263
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
264
+ scheduler = LMSDiscreteScheduler(
265
+ beta_start=0.00085,
266
+ beta_end=0.012,
267
+ beta_schedule="scaled_linear",
268
+ num_train_timesteps=1000,
269
+ prediction_type=prediction_type,
270
+ )
271
+ elif name == "euler_a":
272
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
273
+ scheduler = EulerAncestralDiscreteScheduler(
274
+ beta_start=0.00085,
275
+ beta_end=0.012,
276
+ beta_schedule="scaled_linear",
277
+ num_train_timesteps=1000,
278
+ prediction_type=prediction_type,
279
+ )
280
+ else:
281
+ raise ValueError(f"Unknown scheduler name: {name}")
282
+
283
+ return scheduler
trainscripts/imagesliders/prompt_util.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, Optional, Union, List
2
+
3
+ import yaml
4
+ from pathlib import Path
5
+
6
+
7
+ from pydantic import BaseModel, root_validator
8
+ import torch
9
+ import copy
10
+
11
+ ACTION_TYPES = Literal[
12
+ "erase",
13
+ "enhance",
14
+ ]
15
+
16
+
17
+ # XL は二種類必要なので
18
+ class PromptEmbedsXL:
19
+ text_embeds: torch.FloatTensor
20
+ pooled_embeds: torch.FloatTensor
21
+
22
+ def __init__(self, *args) -> None:
23
+ self.text_embeds = args[0]
24
+ self.pooled_embeds = args[1]
25
+
26
+
27
+ # SDv1.x, SDv2.x は FloatTensor、XL は PromptEmbedsXL
28
+ PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL]
29
+
30
+
31
+ class PromptEmbedsCache: # 使いまわしたいので
32
+ prompts: dict[str, PROMPT_EMBEDDING] = {}
33
+
34
+ def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None:
35
+ self.prompts[__name] = __value
36
+
37
+ def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]:
38
+ if __name in self.prompts:
39
+ return self.prompts[__name]
40
+ else:
41
+ return None
42
+
43
+
44
+ class PromptSettings(BaseModel): # yaml のやつ
45
+ target: str
46
+ positive: str = None # if None, target will be used
47
+ unconditional: str = "" # default is ""
48
+ neutral: str = None # if None, unconditional will be used
49
+ action: ACTION_TYPES = "erase" # default is "erase"
50
+ guidance_scale: float = 1.0 # default is 1.0
51
+ resolution: int = 512 # default is 512
52
+ dynamic_resolution: bool = False # default is False
53
+ batch_size: int = 1 # default is 1
54
+ dynamic_crops: bool = False # default is False. only used when model is XL
55
+
56
+ @root_validator(pre=True)
57
+ def fill_prompts(cls, values):
58
+ keys = values.keys()
59
+ if "target" not in keys:
60
+ raise ValueError("target must be specified")
61
+ if "positive" not in keys:
62
+ values["positive"] = values["target"]
63
+ if "unconditional" not in keys:
64
+ values["unconditional"] = ""
65
+ if "neutral" not in keys:
66
+ values["neutral"] = values["unconditional"]
67
+
68
+ return values
69
+
70
+
71
+ class PromptEmbedsPair:
72
+ target: PROMPT_EMBEDDING # not want to generate the concept
73
+ positive: PROMPT_EMBEDDING # generate the concept
74
+ unconditional: PROMPT_EMBEDDING # uncondition (default should be empty)
75
+ neutral: PROMPT_EMBEDDING # base condition (default should be empty)
76
+
77
+ guidance_scale: float
78
+ resolution: int
79
+ dynamic_resolution: bool
80
+ batch_size: int
81
+ dynamic_crops: bool
82
+
83
+ loss_fn: torch.nn.Module
84
+ action: ACTION_TYPES
85
+
86
+ def __init__(
87
+ self,
88
+ loss_fn: torch.nn.Module,
89
+ target: PROMPT_EMBEDDING,
90
+ positive: PROMPT_EMBEDDING,
91
+ unconditional: PROMPT_EMBEDDING,
92
+ neutral: PROMPT_EMBEDDING,
93
+ settings: PromptSettings,
94
+ ) -> None:
95
+ self.loss_fn = loss_fn
96
+ self.target = target
97
+ self.positive = positive
98
+ self.unconditional = unconditional
99
+ self.neutral = neutral
100
+
101
+ self.guidance_scale = settings.guidance_scale
102
+ self.resolution = settings.resolution
103
+ self.dynamic_resolution = settings.dynamic_resolution
104
+ self.batch_size = settings.batch_size
105
+ self.dynamic_crops = settings.dynamic_crops
106
+ self.action = settings.action
107
+
108
+ def _erase(
109
+ self,
110
+ target_latents: torch.FloatTensor, # "van gogh"
111
+ positive_latents: torch.FloatTensor, # "van gogh"
112
+ unconditional_latents: torch.FloatTensor, # ""
113
+ neutral_latents: torch.FloatTensor, # ""
114
+ ) -> torch.FloatTensor:
115
+ """Target latents are going not to have the positive concept."""
116
+ return self.loss_fn(
117
+ target_latents,
118
+ neutral_latents
119
+ - self.guidance_scale * (positive_latents - unconditional_latents)
120
+ )
121
+
122
+
123
+ def _enhance(
124
+ self,
125
+ target_latents: torch.FloatTensor, # "van gogh"
126
+ positive_latents: torch.FloatTensor, # "van gogh"
127
+ unconditional_latents: torch.FloatTensor, # ""
128
+ neutral_latents: torch.FloatTensor, # ""
129
+ ):
130
+ """Target latents are going to have the positive concept."""
131
+ return self.loss_fn(
132
+ target_latents,
133
+ neutral_latents
134
+ + self.guidance_scale * (positive_latents - unconditional_latents)
135
+ )
136
+
137
+ def loss(
138
+ self,
139
+ **kwargs,
140
+ ):
141
+ if self.action == "erase":
142
+ return self._erase(**kwargs)
143
+
144
+ elif self.action == "enhance":
145
+ return self._enhance(**kwargs)
146
+
147
+ else:
148
+ raise ValueError("action must be erase or enhance")
149
+
150
+
151
+ def load_prompts_from_yaml(path, attributes = []):
152
+ with open(path, "r") as f:
153
+ prompts = yaml.safe_load(f)
154
+ print(prompts)
155
+ if len(prompts) == 0:
156
+ raise ValueError("prompts file is empty")
157
+ if len(attributes)!=0:
158
+ newprompts = []
159
+ for i in range(len(prompts)):
160
+ for att in attributes:
161
+ copy_ = copy.deepcopy(prompts[i])
162
+ copy_['target'] = att + ' ' + copy_['target']
163
+ copy_['positive'] = att + ' ' + copy_['positive']
164
+ copy_['neutral'] = att + ' ' + copy_['neutral']
165
+ copy_['unconditional'] = att + ' ' + copy_['unconditional']
166
+ newprompts.append(copy_)
167
+ else:
168
+ newprompts = copy.deepcopy(prompts)
169
+
170
+ print(newprompts)
171
+ print(len(prompts), len(newprompts))
172
+ prompt_settings = [PromptSettings(**prompt) for prompt in newprompts]
173
+
174
+ return prompt_settings
trainscripts/imagesliders/train_lora-scale-xl.py ADDED
@@ -0,0 +1,548 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ref:
2
+ # - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
3
+ # - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py
4
+
5
+ from typing import List, Optional
6
+ import argparse
7
+ import ast
8
+ from pathlib import Path
9
+ import gc, os
10
+ import numpy as np
11
+
12
+ import torch
13
+ from tqdm import tqdm
14
+ from PIL import Image
15
+
16
+
17
+
18
+ import train_util
19
+ import random
20
+ import model_util
21
+ import prompt_util
22
+ from prompt_util import (
23
+ PromptEmbedsCache,
24
+ PromptEmbedsPair,
25
+ PromptSettings,
26
+ PromptEmbedsXL,
27
+ )
28
+ import debug_util
29
+ import config_util
30
+ from config_util import RootConfig
31
+
32
+ import wandb
33
+
34
+ NUM_IMAGES_PER_PROMPT = 1
35
+ from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
36
+
37
+ def flush():
38
+ torch.cuda.empty_cache()
39
+ gc.collect()
40
+
41
+
42
+ def train(
43
+ config: RootConfig,
44
+ prompts: list[PromptSettings],
45
+ device,
46
+ folder_main: str,
47
+ folders,
48
+ scales,
49
+
50
+ ):
51
+ scales = np.array(scales)
52
+ folders = np.array(folders)
53
+ scales_unique = list(scales)
54
+
55
+ metadata = {
56
+ "prompts": ",".join([prompt.json() for prompt in prompts]),
57
+ "config": config.json(),
58
+ }
59
+ save_path = Path(config.save.path)
60
+
61
+ modules = DEFAULT_TARGET_REPLACE
62
+ if config.network.type == "c3lier":
63
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
64
+
65
+ if config.logging.verbose:
66
+ print(metadata)
67
+
68
+ if config.logging.use_wandb:
69
+ wandb.init(project=f"LECO_{config.save.name}", config=metadata)
70
+
71
+ weight_dtype = config_util.parse_precision(config.train.precision)
72
+ save_weight_dtype = config_util.parse_precision(config.train.precision)
73
+
74
+ (
75
+ tokenizers,
76
+ text_encoders,
77
+ unet,
78
+ noise_scheduler,
79
+ vae
80
+ ) = model_util.load_models_xl(
81
+ config.pretrained_model.name_or_path,
82
+ scheduler_name=config.train.noise_scheduler,
83
+ )
84
+
85
+ for text_encoder in text_encoders:
86
+ text_encoder.to(device, dtype=weight_dtype)
87
+ text_encoder.requires_grad_(False)
88
+ text_encoder.eval()
89
+
90
+ unet.to(device, dtype=weight_dtype)
91
+ if config.other.use_xformers:
92
+ unet.enable_xformers_memory_efficient_attention()
93
+ unet.requires_grad_(False)
94
+ unet.eval()
95
+
96
+ vae.to(device)
97
+ vae.requires_grad_(False)
98
+ vae.eval()
99
+
100
+ network = LoRANetwork(
101
+ unet,
102
+ rank=config.network.rank,
103
+ multiplier=1.0,
104
+ alpha=config.network.alpha,
105
+ train_method=config.network.training_method,
106
+ ).to(device, dtype=weight_dtype)
107
+
108
+ optimizer_module = train_util.get_optimizer(config.train.optimizer)
109
+ #optimizer_args
110
+ optimizer_kwargs = {}
111
+ if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0:
112
+ for arg in config.train.optimizer_args.split(" "):
113
+ key, value = arg.split("=")
114
+ value = ast.literal_eval(value)
115
+ optimizer_kwargs[key] = value
116
+
117
+ optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs)
118
+ lr_scheduler = train_util.get_lr_scheduler(
119
+ config.train.lr_scheduler,
120
+ optimizer,
121
+ max_iterations=config.train.iterations,
122
+ lr_min=config.train.lr / 100,
123
+ )
124
+ criteria = torch.nn.MSELoss()
125
+
126
+ print("Prompts")
127
+ for settings in prompts:
128
+ print(settings)
129
+
130
+ # debug
131
+ debug_util.check_requires_grad(network)
132
+ debug_util.check_training_mode(network)
133
+
134
+ cache = PromptEmbedsCache()
135
+ prompt_pairs: list[PromptEmbedsPair] = []
136
+
137
+ with torch.no_grad():
138
+ for settings in prompts:
139
+ print(settings)
140
+ for prompt in [
141
+ settings.target,
142
+ settings.positive,
143
+ settings.neutral,
144
+ settings.unconditional,
145
+ ]:
146
+ if cache[prompt] == None:
147
+ tex_embs, pool_embs = train_util.encode_prompts_xl(
148
+ tokenizers,
149
+ text_encoders,
150
+ [prompt],
151
+ num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
152
+ )
153
+ cache[prompt] = PromptEmbedsXL(
154
+ tex_embs,
155
+ pool_embs
156
+ )
157
+
158
+ prompt_pairs.append(
159
+ PromptEmbedsPair(
160
+ criteria,
161
+ cache[settings.target],
162
+ cache[settings.positive],
163
+ cache[settings.unconditional],
164
+ cache[settings.neutral],
165
+ settings,
166
+ )
167
+ )
168
+
169
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
170
+ del tokenizer, text_encoder
171
+
172
+ flush()
173
+
174
+ pbar = tqdm(range(config.train.iterations))
175
+
176
+ loss = None
177
+
178
+ for i in pbar:
179
+ with torch.no_grad():
180
+ noise_scheduler.set_timesteps(
181
+ config.train.max_denoising_steps, device=device
182
+ )
183
+
184
+ optimizer.zero_grad()
185
+
186
+ prompt_pair: PromptEmbedsPair = prompt_pairs[
187
+ torch.randint(0, len(prompt_pairs), (1,)).item()
188
+ ]
189
+
190
+ # 1 ~ 49 からランダム
191
+ timesteps_to = torch.randint(
192
+ 1, config.train.max_denoising_steps, (1,)
193
+ ).item()
194
+
195
+ height, width = prompt_pair.resolution, prompt_pair.resolution
196
+ if prompt_pair.dynamic_resolution:
197
+ height, width = train_util.get_random_resolution_in_bucket(
198
+ prompt_pair.resolution
199
+ )
200
+
201
+ if config.logging.verbose:
202
+ print("guidance_scale:", prompt_pair.guidance_scale)
203
+ print("resolution:", prompt_pair.resolution)
204
+ print("dynamic_resolution:", prompt_pair.dynamic_resolution)
205
+ if prompt_pair.dynamic_resolution:
206
+ print("bucketed resolution:", (height, width))
207
+ print("batch_size:", prompt_pair.batch_size)
208
+ print("dynamic_crops:", prompt_pair.dynamic_crops)
209
+
210
+
211
+
212
+ scale_to_look = abs(random.choice(list(scales_unique)))
213
+ folder1 = folders[scales==-scale_to_look][0]
214
+ folder2 = folders[scales==scale_to_look][0]
215
+
216
+ ims = os.listdir(f'{folder_main}/{folder1}/')
217
+ ims = [im_ for im_ in ims if '.png' in im_ or '.jpg' in im_ or '.jpeg' in im_ or '.webp' in im_]
218
+ random_sampler = random.randint(0, len(ims)-1)
219
+
220
+ img1 = Image.open(f'{folder_main}/{folder1}/{ims[random_sampler]}').resize((512,512))
221
+ img2 = Image.open(f'{folder_main}/{folder2}/{ims[random_sampler]}').resize((512,512))
222
+
223
+ seed = random.randint(0,2*15)
224
+
225
+ generator = torch.manual_seed(seed)
226
+ denoised_latents_low, low_noise = train_util.get_noisy_image(
227
+ img1,
228
+ vae,
229
+ generator,
230
+ unet,
231
+ noise_scheduler,
232
+ start_timesteps=0,
233
+ total_timesteps=timesteps_to)
234
+ denoised_latents_low = denoised_latents_low.to(device, dtype=weight_dtype)
235
+ low_noise = low_noise.to(device, dtype=weight_dtype)
236
+
237
+ generator = torch.manual_seed(seed)
238
+ denoised_latents_high, high_noise = train_util.get_noisy_image(
239
+ img2,
240
+ vae,
241
+ generator,
242
+ unet,
243
+ noise_scheduler,
244
+ start_timesteps=0,
245
+ total_timesteps=timesteps_to)
246
+ denoised_latents_high = denoised_latents_high.to(device, dtype=weight_dtype)
247
+ high_noise = high_noise.to(device, dtype=weight_dtype)
248
+ noise_scheduler.set_timesteps(1000)
249
+
250
+ add_time_ids = train_util.get_add_time_ids(
251
+ height,
252
+ width,
253
+ dynamic_crops=prompt_pair.dynamic_crops,
254
+ dtype=weight_dtype,
255
+ ).to(device, dtype=weight_dtype)
256
+
257
+
258
+ current_timestep = noise_scheduler.timesteps[
259
+ int(timesteps_to * 1000 / config.train.max_denoising_steps)
260
+ ]
261
+ try:
262
+ # with network: の外では空のLoRAのみが有効になる
263
+ high_latents = train_util.predict_noise_xl(
264
+ unet,
265
+ noise_scheduler,
266
+ current_timestep,
267
+ denoised_latents_high,
268
+ text_embeddings=train_util.concat_embeddings(
269
+ prompt_pair.unconditional.text_embeds,
270
+ prompt_pair.positive.text_embeds,
271
+ prompt_pair.batch_size,
272
+ ),
273
+ add_text_embeddings=train_util.concat_embeddings(
274
+ prompt_pair.unconditional.pooled_embeds,
275
+ prompt_pair.positive.pooled_embeds,
276
+ prompt_pair.batch_size,
277
+ ),
278
+ add_time_ids=train_util.concat_embeddings(
279
+ add_time_ids, add_time_ids, prompt_pair.batch_size
280
+ ),
281
+ guidance_scale=1,
282
+ ).to(device, dtype=torch.float32)
283
+ except:
284
+ flush()
285
+ print(f'Error Occured!: {np.array(img1).shape} {np.array(img2).shape}')
286
+ continue
287
+ # with network: の外では空のLoRAのみが有効になる
288
+
289
+ low_latents = train_util.predict_noise_xl(
290
+ unet,
291
+ noise_scheduler,
292
+ current_timestep,
293
+ denoised_latents_low,
294
+ text_embeddings=train_util.concat_embeddings(
295
+ prompt_pair.unconditional.text_embeds,
296
+ prompt_pair.neutral.text_embeds,
297
+ prompt_pair.batch_size,
298
+ ),
299
+ add_text_embeddings=train_util.concat_embeddings(
300
+ prompt_pair.unconditional.pooled_embeds,
301
+ prompt_pair.neutral.pooled_embeds,
302
+ prompt_pair.batch_size,
303
+ ),
304
+ add_time_ids=train_util.concat_embeddings(
305
+ add_time_ids, add_time_ids, prompt_pair.batch_size
306
+ ),
307
+ guidance_scale=1,
308
+ ).to(device, dtype=torch.float32)
309
+
310
+
311
+
312
+ if config.logging.verbose:
313
+ print("positive_latents:", positive_latents[0, 0, :5, :5])
314
+ print("neutral_latents:", neutral_latents[0, 0, :5, :5])
315
+ print("unconditional_latents:", unconditional_latents[0, 0, :5, :5])
316
+
317
+ network.set_lora_slider(scale=scale_to_look)
318
+ with network:
319
+ target_latents_high = train_util.predict_noise_xl(
320
+ unet,
321
+ noise_scheduler,
322
+ current_timestep,
323
+ denoised_latents_high,
324
+ text_embeddings=train_util.concat_embeddings(
325
+ prompt_pair.unconditional.text_embeds,
326
+ prompt_pair.positive.text_embeds,
327
+ prompt_pair.batch_size,
328
+ ),
329
+ add_text_embeddings=train_util.concat_embeddings(
330
+ prompt_pair.unconditional.pooled_embeds,
331
+ prompt_pair.positive.pooled_embeds,
332
+ prompt_pair.batch_size,
333
+ ),
334
+ add_time_ids=train_util.concat_embeddings(
335
+ add_time_ids, add_time_ids, prompt_pair.batch_size
336
+ ),
337
+ guidance_scale=1,
338
+ ).to(device, dtype=torch.float32)
339
+
340
+ high_latents.requires_grad = False
341
+ low_latents.requires_grad = False
342
+
343
+ loss_high = criteria(target_latents_high, high_noise.to(torch.float32))
344
+ pbar.set_description(f"Loss*1k: {loss_high.item()*1000:.4f}")
345
+ loss_high.backward()
346
+
347
+ # opposite
348
+ network.set_lora_slider(scale=-scale_to_look)
349
+ with network:
350
+ target_latents_low = train_util.predict_noise_xl(
351
+ unet,
352
+ noise_scheduler,
353
+ current_timestep,
354
+ denoised_latents_low,
355
+ text_embeddings=train_util.concat_embeddings(
356
+ prompt_pair.unconditional.text_embeds,
357
+ prompt_pair.neutral.text_embeds,
358
+ prompt_pair.batch_size,
359
+ ),
360
+ add_text_embeddings=train_util.concat_embeddings(
361
+ prompt_pair.unconditional.pooled_embeds,
362
+ prompt_pair.neutral.pooled_embeds,
363
+ prompt_pair.batch_size,
364
+ ),
365
+ add_time_ids=train_util.concat_embeddings(
366
+ add_time_ids, add_time_ids, prompt_pair.batch_size
367
+ ),
368
+ guidance_scale=1,
369
+ ).to(device, dtype=torch.float32)
370
+
371
+
372
+ high_latents.requires_grad = False
373
+ low_latents.requires_grad = False
374
+
375
+ loss_low = criteria(target_latents_low, low_noise.to(torch.float32))
376
+ pbar.set_description(f"Loss*1k: {loss_low.item()*1000:.4f}")
377
+ loss_low.backward()
378
+
379
+
380
+ optimizer.step()
381
+ lr_scheduler.step()
382
+
383
+ del (
384
+ high_latents,
385
+ low_latents,
386
+ target_latents_low,
387
+ target_latents_high,
388
+ )
389
+ flush()
390
+
391
+ if (
392
+ i % config.save.per_steps == 0
393
+ and i != 0
394
+ and i != config.train.iterations - 1
395
+ ):
396
+ print("Saving...")
397
+ save_path.mkdir(parents=True, exist_ok=True)
398
+ network.save_weights(
399
+ save_path / f"{config.save.name}_{i}steps.pt",
400
+ dtype=save_weight_dtype,
401
+ )
402
+
403
+ print("Saving...")
404
+ save_path.mkdir(parents=True, exist_ok=True)
405
+ network.save_weights(
406
+ save_path / f"{config.save.name}_last.pt",
407
+ dtype=save_weight_dtype,
408
+ )
409
+
410
+ del (
411
+ unet,
412
+ noise_scheduler,
413
+ loss,
414
+ optimizer,
415
+ network,
416
+ )
417
+
418
+ flush()
419
+
420
+ print("Done.")
421
+
422
+
423
+ def main(args):
424
+ config_file = args.config_file
425
+
426
+ config = config_util.load_config_from_yaml(config_file)
427
+ if args.name is not None:
428
+ config.save.name = args.name
429
+ attributes = []
430
+ if args.attributes is not None:
431
+ attributes = args.attributes.split(',')
432
+ attributes = [a.strip() for a in attributes]
433
+
434
+ config.network.alpha = args.alpha
435
+ config.network.rank = args.rank
436
+ config.save.name += f'_alpha{args.alpha}'
437
+ config.save.name += f'_rank{config.network.rank }'
438
+ config.save.name += f'_{config.network.training_method}'
439
+ config.save.path += f'/{config.save.name}'
440
+
441
+ prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)
442
+
443
+ device = torch.device(f"cuda:{args.device}")
444
+
445
+ folders = args.folders.split(',')
446
+ folders = [f.strip() for f in folders]
447
+ scales = args.scales.split(',')
448
+ scales = [f.strip() for f in scales]
449
+ scales = [int(s) for s in scales]
450
+
451
+ print(folders, scales)
452
+ if len(scales) != len(folders):
453
+ raise Exception('the number of folders need to match the number of scales')
454
+
455
+ if args.stylecheck is not None:
456
+ check = args.stylecheck.split('-')
457
+
458
+ for i in range(int(check[0]), int(check[1])):
459
+ folder_main = args.folder_main+ f'{i}'
460
+ config.save.name = f'{os.path.basename(folder_main)}'
461
+ config.save.name += f'_alpha{args.alpha}'
462
+ config.save.name += f'_rank{config.network.rank }'
463
+ config.save.path = f'models/{config.save.name}'
464
+ train(config=config, prompts=prompts, device=device, folder_main = folder_main, folders = folders, scales = scales)
465
+ else:
466
+ train(config=config, prompts=prompts, device=device, folder_main = args.folder_main, folders = folders, scales = scales)
467
+
468
+
469
+ if __name__ == "__main__":
470
+ parser = argparse.ArgumentParser()
471
+ parser.add_argument(
472
+ "--config_file",
473
+ required=True,
474
+ help="Config file for training.",
475
+ )
476
+ # config_file 'data/config.yaml'
477
+ parser.add_argument(
478
+ "--alpha",
479
+ type=float,
480
+ required=True,
481
+ help="LoRA weight.",
482
+ )
483
+ # --alpha 1.0
484
+ parser.add_argument(
485
+ "--rank",
486
+ type=int,
487
+ required=False,
488
+ help="Rank of LoRA.",
489
+ default=4,
490
+ )
491
+ # --rank 4
492
+ parser.add_argument(
493
+ "--device",
494
+ type=int,
495
+ required=False,
496
+ default=0,
497
+ help="Device to train on.",
498
+ )
499
+ # --device 0
500
+ parser.add_argument(
501
+ "--name",
502
+ type=str,
503
+ required=False,
504
+ default=None,
505
+ help="Device to train on.",
506
+ )
507
+ # --name 'eyesize_slider'
508
+ parser.add_argument(
509
+ "--attributes",
510
+ type=str,
511
+ required=False,
512
+ default=None,
513
+ help="attritbutes to disentangle (comma seperated string)",
514
+ )
515
+ parser.add_argument(
516
+ "--folder_main",
517
+ type=str,
518
+ required=True,
519
+ help="The folder to check",
520
+ )
521
+
522
+ parser.add_argument(
523
+ "--stylecheck",
524
+ type=str,
525
+ required=False,
526
+ default = None,
527
+ help="The folder to check",
528
+ )
529
+
530
+ parser.add_argument(
531
+ "--folders",
532
+ type=str,
533
+ required=False,
534
+ default = 'verylow, low, high, veryhigh',
535
+ help="folders with different attribute-scaled images",
536
+ )
537
+ parser.add_argument(
538
+ "--scales",
539
+ type=str,
540
+ required=False,
541
+ default = '-2, -1, 1, 2',
542
+ help="scales for different attribute-scaled images",
543
+ )
544
+
545
+
546
+ args = parser.parse_args()
547
+
548
+ main(args)
trainscripts/imagesliders/train_lora-scale.py ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ref:
2
+ # - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566
3
+ # - https://huggingface.co/spaces/baulab/Erasing-Concepts-In-Diffusion/blob/main/train.py
4
+
5
+ from typing import List, Optional
6
+ import argparse
7
+ import ast
8
+ from pathlib import Path
9
+ import gc
10
+
11
+ import torch
12
+ from tqdm import tqdm
13
+ import os, glob
14
+
15
+ from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
16
+ import train_util
17
+ import model_util
18
+ import prompt_util
19
+ from prompt_util import PromptEmbedsCache, PromptEmbedsPair, PromptSettings
20
+ import debug_util
21
+ import config_util
22
+ from config_util import RootConfig
23
+ import random
24
+ import numpy as np
25
+ import wandb
26
+ from PIL import Image
27
+
28
+ def flush():
29
+ torch.cuda.empty_cache()
30
+ gc.collect()
31
+ def prev_step(model_output, timestep, scheduler, sample):
32
+ prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
33
+ alpha_prod_t =scheduler.alphas_cumprod[timestep]
34
+ alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
35
+ beta_prod_t = 1 - alpha_prod_t
36
+ pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
37
+ pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
38
+ prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
39
+ return prev_sample
40
+
41
+ def train(
42
+ config: RootConfig,
43
+ prompts: list[PromptSettings],
44
+ device: int,
45
+ folder_main: str,
46
+ folders,
47
+ scales,
48
+ ):
49
+ scales = np.array(scales)
50
+ folders = np.array(folders)
51
+ scales_unique = list(scales)
52
+
53
+ metadata = {
54
+ "prompts": ",".join([prompt.json() for prompt in prompts]),
55
+ "config": config.json(),
56
+ }
57
+ save_path = Path(config.save.path)
58
+
59
+ modules = DEFAULT_TARGET_REPLACE
60
+ if config.network.type == "c3lier":
61
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
62
+
63
+ if config.logging.verbose:
64
+ print(metadata)
65
+
66
+ if config.logging.use_wandb:
67
+ wandb.init(project=f"LECO_{config.save.name}", config=metadata)
68
+
69
+ weight_dtype = config_util.parse_precision(config.train.precision)
70
+ save_weight_dtype = config_util.parse_precision(config.train.precision)
71
+
72
+ tokenizer, text_encoder, unet, noise_scheduler, vae = model_util.load_models(
73
+ config.pretrained_model.name_or_path,
74
+ scheduler_name=config.train.noise_scheduler,
75
+ v2=config.pretrained_model.v2,
76
+ v_pred=config.pretrained_model.v_pred,
77
+ )
78
+
79
+ text_encoder.to(device, dtype=weight_dtype)
80
+ text_encoder.eval()
81
+
82
+ unet.to(device, dtype=weight_dtype)
83
+ unet.enable_xformers_memory_efficient_attention()
84
+ unet.requires_grad_(False)
85
+ unet.eval()
86
+
87
+ vae.to(device)
88
+ vae.requires_grad_(False)
89
+ vae.eval()
90
+
91
+ network = LoRANetwork(
92
+ unet,
93
+ rank=config.network.rank,
94
+ multiplier=1.0,
95
+ alpha=config.network.alpha,
96
+ train_method=config.network.training_method,
97
+ ).to(device, dtype=weight_dtype)
98
+
99
+ optimizer_module = train_util.get_optimizer(config.train.optimizer)
100
+ #optimizer_args
101
+ optimizer_kwargs = {}
102
+ if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0:
103
+ for arg in config.train.optimizer_args.split(" "):
104
+ key, value = arg.split("=")
105
+ value = ast.literal_eval(value)
106
+ optimizer_kwargs[key] = value
107
+
108
+ optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs)
109
+ lr_scheduler = train_util.get_lr_scheduler(
110
+ config.train.lr_scheduler,
111
+ optimizer,
112
+ max_iterations=config.train.iterations,
113
+ lr_min=config.train.lr / 100,
114
+ )
115
+ criteria = torch.nn.MSELoss()
116
+
117
+ print("Prompts")
118
+ for settings in prompts:
119
+ print(settings)
120
+
121
+ # debug
122
+ debug_util.check_requires_grad(network)
123
+ debug_util.check_training_mode(network)
124
+
125
+ cache = PromptEmbedsCache()
126
+ prompt_pairs: list[PromptEmbedsPair] = []
127
+
128
+ with torch.no_grad():
129
+ for settings in prompts:
130
+ print(settings)
131
+ for prompt in [
132
+ settings.target,
133
+ settings.positive,
134
+ settings.neutral,
135
+ settings.unconditional,
136
+ ]:
137
+ print(prompt)
138
+ if isinstance(prompt, list):
139
+ if prompt == settings.positive:
140
+ key_setting = 'positive'
141
+ else:
142
+ key_setting = 'attributes'
143
+ if len(prompt) == 0:
144
+ cache[key_setting] = []
145
+ else:
146
+ if cache[key_setting] is None:
147
+ cache[key_setting] = train_util.encode_prompts(
148
+ tokenizer, text_encoder, prompt
149
+ )
150
+ else:
151
+ if cache[prompt] == None:
152
+ cache[prompt] = train_util.encode_prompts(
153
+ tokenizer, text_encoder, [prompt]
154
+ )
155
+
156
+ prompt_pairs.append(
157
+ PromptEmbedsPair(
158
+ criteria,
159
+ cache[settings.target],
160
+ cache[settings.positive],
161
+ cache[settings.unconditional],
162
+ cache[settings.neutral],
163
+ settings,
164
+ )
165
+ )
166
+
167
+ del tokenizer
168
+ del text_encoder
169
+
170
+ flush()
171
+
172
+ pbar = tqdm(range(config.train.iterations))
173
+ for i in pbar:
174
+ with torch.no_grad():
175
+ noise_scheduler.set_timesteps(
176
+ config.train.max_denoising_steps, device=device
177
+ )
178
+
179
+ optimizer.zero_grad()
180
+
181
+ prompt_pair: PromptEmbedsPair = prompt_pairs[
182
+ torch.randint(0, len(prompt_pairs), (1,)).item()
183
+ ]
184
+
185
+ # 1 ~ 49 からランダム
186
+ timesteps_to = torch.randint(
187
+ 1, config.train.max_denoising_steps-1, (1,)
188
+ # 1, 25, (1,)
189
+ ).item()
190
+
191
+ height, width = (
192
+ prompt_pair.resolution,
193
+ prompt_pair.resolution,
194
+ )
195
+ if prompt_pair.dynamic_resolution:
196
+ height, width = train_util.get_random_resolution_in_bucket(
197
+ prompt_pair.resolution
198
+ )
199
+
200
+ if config.logging.verbose:
201
+ print("guidance_scale:", prompt_pair.guidance_scale)
202
+ print("resolution:", prompt_pair.resolution)
203
+ print("dynamic_resolution:", prompt_pair.dynamic_resolution)
204
+ if prompt_pair.dynamic_resolution:
205
+ print("bucketed resolution:", (height, width))
206
+ print("batch_size:", prompt_pair.batch_size)
207
+
208
+
209
+
210
+
211
+ scale_to_look = abs(random.choice(list(scales_unique)))
212
+ folder1 = folders[scales==-scale_to_look][0]
213
+ folder2 = folders[scales==scale_to_look][0]
214
+
215
+ ims = os.listdir(f'{folder_main}/{folder1}/')
216
+ ims = [im_ for im_ in ims if '.png' in im_ or '.jpg' in im_ or '.jpeg' in im_ or '.webp' in im_]
217
+ random_sampler = random.randint(0, len(ims)-1)
218
+
219
+ img1 = Image.open(f'{folder_main}/{folder1}/{ims[random_sampler]}').resize((256,256))
220
+ img2 = Image.open(f'{folder_main}/{folder2}/{ims[random_sampler]}').resize((256,256))
221
+
222
+ seed = random.randint(0,2*15)
223
+
224
+ generator = torch.manual_seed(seed)
225
+ denoised_latents_low, low_noise = train_util.get_noisy_image(
226
+ img1,
227
+ vae,
228
+ generator,
229
+ unet,
230
+ noise_scheduler,
231
+ start_timesteps=0,
232
+ total_timesteps=timesteps_to)
233
+ denoised_latents_low = denoised_latents_low.to(device, dtype=weight_dtype)
234
+ low_noise = low_noise.to(device, dtype=weight_dtype)
235
+
236
+ generator = torch.manual_seed(seed)
237
+ denoised_latents_high, high_noise = train_util.get_noisy_image(
238
+ img2,
239
+ vae,
240
+ generator,
241
+ unet,
242
+ noise_scheduler,
243
+ start_timesteps=0,
244
+ total_timesteps=timesteps_to)
245
+ denoised_latents_high = denoised_latents_high.to(device, dtype=weight_dtype)
246
+ high_noise = high_noise.to(device, dtype=weight_dtype)
247
+ noise_scheduler.set_timesteps(1000)
248
+
249
+ current_timestep = noise_scheduler.timesteps[
250
+ int(timesteps_to * 1000 / config.train.max_denoising_steps)
251
+ ]
252
+
253
+ # with network: の外では空のLoRAのみが有効になる
254
+ high_latents = train_util.predict_noise(
255
+ unet,
256
+ noise_scheduler,
257
+ current_timestep,
258
+ denoised_latents_high,
259
+ train_util.concat_embeddings(
260
+ prompt_pair.unconditional,
261
+ prompt_pair.positive,
262
+ prompt_pair.batch_size,
263
+ ),
264
+ guidance_scale=1,
265
+ ).to("cpu", dtype=torch.float32)
266
+ # with network: の外では空のLoRAのみが有効になる
267
+ low_latents = train_util.predict_noise(
268
+ unet,
269
+ noise_scheduler,
270
+ current_timestep,
271
+ denoised_latents_low,
272
+ train_util.concat_embeddings(
273
+ prompt_pair.unconditional,
274
+ prompt_pair.unconditional,
275
+ prompt_pair.batch_size,
276
+ ),
277
+ guidance_scale=1,
278
+ ).to("cpu", dtype=torch.float32)
279
+ if config.logging.verbose:
280
+ print("positive_latents:", positive_latents[0, 0, :5, :5])
281
+ print("neutral_latents:", neutral_latents[0, 0, :5, :5])
282
+ print("unconditional_latents:", unconditional_latents[0, 0, :5, :5])
283
+
284
+ network.set_lora_slider(scale=scale_to_look)
285
+ with network:
286
+ target_latents_high = train_util.predict_noise(
287
+ unet,
288
+ noise_scheduler,
289
+ current_timestep,
290
+ denoised_latents_high,
291
+ train_util.concat_embeddings(
292
+ prompt_pair.unconditional,
293
+ prompt_pair.positive,
294
+ prompt_pair.batch_size,
295
+ ),
296
+ guidance_scale=1,
297
+ ).to("cpu", dtype=torch.float32)
298
+
299
+
300
+ high_latents.requires_grad = False
301
+ low_latents.requires_grad = False
302
+
303
+ loss_high = criteria(target_latents_high, high_noise.cpu().to(torch.float32))
304
+ pbar.set_description(f"Loss*1k: {loss_high.item()*1000:.4f}")
305
+ loss_high.backward()
306
+
307
+
308
+ network.set_lora_slider(scale=-scale_to_look)
309
+ with network:
310
+ target_latents_low = train_util.predict_noise(
311
+ unet,
312
+ noise_scheduler,
313
+ current_timestep,
314
+ denoised_latents_low,
315
+ train_util.concat_embeddings(
316
+ prompt_pair.unconditional,
317
+ prompt_pair.neutral,
318
+ prompt_pair.batch_size,
319
+ ),
320
+ guidance_scale=1,
321
+ ).to("cpu", dtype=torch.float32)
322
+
323
+
324
+ high_latents.requires_grad = False
325
+ low_latents.requires_grad = False
326
+
327
+ loss_low = criteria(target_latents_low, low_noise.cpu().to(torch.float32))
328
+ pbar.set_description(f"Loss*1k: {loss_low.item()*1000:.4f}")
329
+ loss_low.backward()
330
+
331
+ ## NOTICE NO zero_grad between these steps (accumulating gradients)
332
+ #following guidelines from Ostris (https://github.com/ostris/ai-toolkit)
333
+
334
+ optimizer.step()
335
+ lr_scheduler.step()
336
+
337
+ del (
338
+ high_latents,
339
+ low_latents,
340
+ target_latents_low,
341
+ target_latents_high,
342
+ )
343
+ flush()
344
+
345
+ if (
346
+ i % config.save.per_steps == 0
347
+ and i != 0
348
+ and i != config.train.iterations - 1
349
+ ):
350
+ print("Saving...")
351
+ save_path.mkdir(parents=True, exist_ok=True)
352
+ network.save_weights(
353
+ save_path / f"{config.save.name}_{i}steps.pt",
354
+ dtype=save_weight_dtype,
355
+ )
356
+
357
+ print("Saving...")
358
+ save_path.mkdir(parents=True, exist_ok=True)
359
+ network.save_weights(
360
+ save_path / f"{config.save.name}_last.pt",
361
+ dtype=save_weight_dtype,
362
+ )
363
+
364
+ del (
365
+ unet,
366
+ noise_scheduler,
367
+ optimizer,
368
+ network,
369
+ )
370
+
371
+ flush()
372
+
373
+ print("Done.")
374
+
375
+
376
+ def main(args):
377
+ config_file = args.config_file
378
+
379
+ config = config_util.load_config_from_yaml(config_file)
380
+ if args.name is not None:
381
+ config.save.name = args.name
382
+ attributes = []
383
+ if args.attributes is not None:
384
+ attributes = args.attributes.split(',')
385
+ attributes = [a.strip() for a in attributes]
386
+
387
+ config.network.alpha = args.alpha
388
+ config.network.rank = args.rank
389
+ config.save.name += f'_alpha{args.alpha}'
390
+ config.save.name += f'_rank{config.network.rank }'
391
+ config.save.name += f'_{config.network.training_method}'
392
+ config.save.path += f'/{config.save.name}'
393
+
394
+ prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)
395
+ device = torch.device(f"cuda:{args.device}")
396
+
397
+
398
+ folders = args.folders.split(',')
399
+ folders = [f.strip() for f in folders]
400
+ scales = args.scales.split(',')
401
+ scales = [f.strip() for f in scales]
402
+ scales = [int(s) for s in scales]
403
+
404
+ print(folders, scales)
405
+ if len(scales) != len(folders):
406
+ raise Exception('the number of folders need to match the number of scales')
407
+
408
+ if args.stylecheck is not None:
409
+ check = args.stylecheck.split('-')
410
+
411
+ for i in range(int(check[0]), int(check[1])):
412
+ folder_main = args.folder_main+ f'{i}'
413
+ config.save.name = f'{os.path.basename(folder_main)}'
414
+ config.save.name += f'_alpha{args.alpha}'
415
+ config.save.name += f'_rank{config.network.rank }'
416
+ config.save.path = f'models/{config.save.name}'
417
+ train(config=config, prompts=prompts, device=device, folder_main = folder_main)
418
+ else:
419
+ train(config=config, prompts=prompts, device=device, folder_main = args.folder_main, folders = folders, scales = scales)
420
+
421
+ if __name__ == "__main__":
422
+ parser = argparse.ArgumentParser()
423
+ parser.add_argument(
424
+ "--config_file",
425
+ required=False,
426
+ default = 'data/config.yaml',
427
+ help="Config file for training.",
428
+ )
429
+ parser.add_argument(
430
+ "--alpha",
431
+ type=float,
432
+ required=True,
433
+ help="LoRA weight.",
434
+ )
435
+
436
+ parser.add_argument(
437
+ "--rank",
438
+ type=int,
439
+ required=False,
440
+ help="Rank of LoRA.",
441
+ default=4,
442
+ )
443
+
444
+ parser.add_argument(
445
+ "--device",
446
+ type=int,
447
+ required=False,
448
+ default=0,
449
+ help="Device to train on.",
450
+ )
451
+
452
+ parser.add_argument(
453
+ "--name",
454
+ type=str,
455
+ required=False,
456
+ default=None,
457
+ help="Device to train on.",
458
+ )
459
+
460
+ parser.add_argument(
461
+ "--attributes",
462
+ type=str,
463
+ required=False,
464
+ default=None,
465
+ help="attritbutes to disentangle",
466
+ )
467
+
468
+ parser.add_argument(
469
+ "--folder_main",
470
+ type=str,
471
+ required=True,
472
+ help="The folder to check",
473
+ )
474
+
475
+ parser.add_argument(
476
+ "--stylecheck",
477
+ type=str,
478
+ required=False,
479
+ default = None,
480
+ help="The folder to check",
481
+ )
482
+
483
+ parser.add_argument(
484
+ "--folders",
485
+ type=str,
486
+ required=False,
487
+ default = 'verylow, low, high, veryhigh',
488
+ help="folders with different attribute-scaled images",
489
+ )
490
+ parser.add_argument(
491
+ "--scales",
492
+ type=str,
493
+ required=False,
494
+ default = '-2, -1,1, 2',
495
+ help="scales for different attribute-scaled images",
496
+ )
497
+
498
+
499
+ args = parser.parse_args()
500
+
501
+ main(args)
trainscripts/imagesliders/train_util.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+
5
+ from transformers import CLIPTextModel, CLIPTokenizer
6
+ from diffusers import UNet2DConditionModel, SchedulerMixin
7
+ from diffusers.image_processor import VaeImageProcessor
8
+ from model_util import SDXL_TEXT_ENCODER_TYPE
9
+ from diffusers.utils import randn_tensor
10
+
11
+ from tqdm import tqdm
12
+
13
+ UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。
14
+ VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
15
+
16
+ UNET_ATTENTION_TIME_EMBED_DIM = 256 # XL
17
+ TEXT_ENCODER_2_PROJECTION_DIM = 1280
18
+ UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816
19
+
20
+
21
+ def get_random_noise(
22
+ batch_size: int, height: int, width: int, generator: torch.Generator = None
23
+ ) -> torch.Tensor:
24
+ return torch.randn(
25
+ (
26
+ batch_size,
27
+ UNET_IN_CHANNELS,
28
+ height // VAE_SCALE_FACTOR, # 縦と横これであってるのかわからないけど、どっちにしろ大きな問題は発生しないのでこれでいいや
29
+ width // VAE_SCALE_FACTOR,
30
+ ),
31
+ generator=generator,
32
+ device="cpu",
33
+ )
34
+
35
+
36
+ # https://www.crosslabs.org/blog/diffusion-with-offset-noise
37
+ def apply_noise_offset(latents: torch.FloatTensor, noise_offset: float):
38
+ latents = latents + noise_offset * torch.randn(
39
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
40
+ )
41
+ return latents
42
+
43
+
44
+ def get_initial_latents(
45
+ scheduler: SchedulerMixin,
46
+ n_imgs: int,
47
+ height: int,
48
+ width: int,
49
+ n_prompts: int,
50
+ generator=None,
51
+ ) -> torch.Tensor:
52
+ noise = get_random_noise(n_imgs, height, width, generator=generator).repeat(
53
+ n_prompts, 1, 1, 1
54
+ )
55
+
56
+ latents = noise * scheduler.init_noise_sigma
57
+
58
+ return latents
59
+
60
+
61
+ def text_tokenize(
62
+ tokenizer: CLIPTokenizer, # 普通ならひとつ、XLならふたつ!
63
+ prompts: list[str],
64
+ ):
65
+ return tokenizer(
66
+ prompts,
67
+ padding="max_length",
68
+ max_length=tokenizer.model_max_length,
69
+ truncation=True,
70
+ return_tensors="pt",
71
+ ).input_ids
72
+
73
+
74
+ def text_encode(text_encoder: CLIPTextModel, tokens):
75
+ return text_encoder(tokens.to(text_encoder.device))[0]
76
+
77
+
78
+ def encode_prompts(
79
+ tokenizer: CLIPTokenizer,
80
+ text_encoder: CLIPTokenizer,
81
+ prompts: list[str],
82
+ ):
83
+
84
+ text_tokens = text_tokenize(tokenizer, prompts)
85
+ text_embeddings = text_encode(text_encoder, text_tokens)
86
+
87
+
88
+
89
+ return text_embeddings
90
+
91
+
92
+ # https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
93
+ def text_encode_xl(
94
+ text_encoder: SDXL_TEXT_ENCODER_TYPE,
95
+ tokens: torch.FloatTensor,
96
+ num_images_per_prompt: int = 1,
97
+ ):
98
+ prompt_embeds = text_encoder(
99
+ tokens.to(text_encoder.device), output_hidden_states=True
100
+ )
101
+ pooled_prompt_embeds = prompt_embeds[0]
102
+ prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer
103
+
104
+ bs_embed, seq_len, _ = prompt_embeds.shape
105
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
106
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
107
+
108
+ return prompt_embeds, pooled_prompt_embeds
109
+
110
+
111
+ def encode_prompts_xl(
112
+ tokenizers: list[CLIPTokenizer],
113
+ text_encoders: list[SDXL_TEXT_ENCODER_TYPE],
114
+ prompts: list[str],
115
+ num_images_per_prompt: int = 1,
116
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
117
+ # text_encoder and text_encoder_2's penuultimate layer's output
118
+ text_embeds_list = []
119
+ pooled_text_embeds = None # always text_encoder_2's pool
120
+
121
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
122
+ text_tokens_input_ids = text_tokenize(tokenizer, prompts)
123
+ text_embeds, pooled_text_embeds = text_encode_xl(
124
+ text_encoder, text_tokens_input_ids, num_images_per_prompt
125
+ )
126
+
127
+ text_embeds_list.append(text_embeds)
128
+
129
+ bs_embed = pooled_text_embeds.shape[0]
130
+ pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
131
+ bs_embed * num_images_per_prompt, -1
132
+ )
133
+
134
+ return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
135
+
136
+
137
+ def concat_embeddings(
138
+ unconditional: torch.FloatTensor,
139
+ conditional: torch.FloatTensor,
140
+ n_imgs: int,
141
+ ):
142
+ return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0)
143
+
144
+
145
+ # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L721
146
+ def predict_noise(
147
+ unet: UNet2DConditionModel,
148
+ scheduler: SchedulerMixin,
149
+ timestep: int, # 現在のタイムステップ
150
+ latents: torch.FloatTensor,
151
+ text_embeddings: torch.FloatTensor, # uncond な text embed と cond な text embed を結合したもの
152
+ guidance_scale=7.5,
153
+ ) -> torch.FloatTensor:
154
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
155
+ latent_model_input = torch.cat([latents] * 2)
156
+
157
+ latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
158
+
159
+ # predict the noise residual
160
+ noise_pred = unet(
161
+ latent_model_input,
162
+ timestep,
163
+ encoder_hidden_states=text_embeddings,
164
+ ).sample
165
+
166
+ # perform guidance
167
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
168
+ guided_target = noise_pred_uncond + guidance_scale * (
169
+ noise_pred_text - noise_pred_uncond
170
+ )
171
+
172
+ return guided_target
173
+
174
+
175
+
176
+ # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
177
+ @torch.no_grad()
178
+ def diffusion(
179
+ unet: UNet2DConditionModel,
180
+ scheduler: SchedulerMixin,
181
+ latents: torch.FloatTensor, # ただのノイズだけのlatents
182
+ text_embeddings: torch.FloatTensor,
183
+ total_timesteps: int = 1000,
184
+ start_timesteps=0,
185
+ **kwargs,
186
+ ):
187
+ # latents_steps = []
188
+
189
+ for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]):
190
+ noise_pred = predict_noise(
191
+ unet, scheduler, timestep, latents, text_embeddings, **kwargs
192
+ )
193
+
194
+ # compute the previous noisy sample x_t -> x_t-1
195
+ latents = scheduler.step(noise_pred, timestep, latents).prev_sample
196
+
197
+ # return latents_steps
198
+ return latents
199
+
200
+ @torch.no_grad()
201
+ def get_noisy_image(
202
+ img,
203
+ vae,
204
+ generator,
205
+ unet: UNet2DConditionModel,
206
+ scheduler: SchedulerMixin,
207
+ total_timesteps: int = 1000,
208
+ start_timesteps=0,
209
+
210
+ **kwargs,
211
+ ):
212
+ # latents_steps = []
213
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
214
+ image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
215
+
216
+ image = img
217
+ im_orig = image
218
+ device = vae.device
219
+ image = image_processor.preprocess(image).to(device)
220
+
221
+ init_latents = vae.encode(image).latent_dist.sample(None)
222
+ init_latents = vae.config.scaling_factor * init_latents
223
+
224
+ init_latents = torch.cat([init_latents], dim=0)
225
+
226
+ shape = init_latents.shape
227
+
228
+ noise = randn_tensor(shape, generator=generator, device=device)
229
+
230
+ time_ = total_timesteps
231
+ timestep = scheduler.timesteps[time_:time_+1]
232
+ # get latents
233
+ init_latents = scheduler.add_noise(init_latents, noise, timestep)
234
+
235
+ return init_latents, noise
236
+
237
+
238
+ def rescale_noise_cfg(
239
+ noise_cfg: torch.FloatTensor, noise_pred_text, guidance_rescale=0.0
240
+ ):
241
+ """
242
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
243
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
244
+ """
245
+ std_text = noise_pred_text.std(
246
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
247
+ )
248
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
249
+ # rescale the results from guidance (fixes overexposure)
250
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
251
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
252
+ noise_cfg = (
253
+ guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
254
+ )
255
+
256
+ return noise_cfg
257
+
258
+
259
+ def predict_noise_xl(
260
+ unet: UNet2DConditionModel,
261
+ scheduler: SchedulerMixin,
262
+ timestep: int, # 現在のタイムステップ
263
+ latents: torch.FloatTensor,
264
+ text_embeddings: torch.FloatTensor, # uncond な text embed と cond な text embed を結合したもの
265
+ add_text_embeddings: torch.FloatTensor, # pooled なやつ
266
+ add_time_ids: torch.FloatTensor,
267
+ guidance_scale=7.5,
268
+ guidance_rescale=0.7,
269
+ ) -> torch.FloatTensor:
270
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
271
+ latent_model_input = torch.cat([latents] * 2)
272
+
273
+ latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
274
+
275
+ added_cond_kwargs = {
276
+ "text_embeds": add_text_embeddings,
277
+ "time_ids": add_time_ids,
278
+ }
279
+
280
+ # predict the noise residual
281
+ noise_pred = unet(
282
+ latent_model_input,
283
+ timestep,
284
+ encoder_hidden_states=text_embeddings,
285
+ added_cond_kwargs=added_cond_kwargs,
286
+ ).sample
287
+
288
+ # perform guidance
289
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
290
+ guided_target = noise_pred_uncond + guidance_scale * (
291
+ noise_pred_text - noise_pred_uncond
292
+ )
293
+
294
+ # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
295
+ noise_pred = rescale_noise_cfg(
296
+ noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
297
+ )
298
+
299
+ return guided_target
300
+
301
+
302
+ @torch.no_grad()
303
+ def diffusion_xl(
304
+ unet: UNet2DConditionModel,
305
+ scheduler: SchedulerMixin,
306
+ latents: torch.FloatTensor, # ただのノイズだけのlatents
307
+ text_embeddings: tuple[torch.FloatTensor, torch.FloatTensor],
308
+ add_text_embeddings: torch.FloatTensor, # pooled なやつ
309
+ add_time_ids: torch.FloatTensor,
310
+ guidance_scale: float = 1.0,
311
+ total_timesteps: int = 1000,
312
+ start_timesteps=0,
313
+ ):
314
+ # latents_steps = []
315
+
316
+ for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]):
317
+ noise_pred = predict_noise_xl(
318
+ unet,
319
+ scheduler,
320
+ timestep,
321
+ latents,
322
+ text_embeddings,
323
+ add_text_embeddings,
324
+ add_time_ids,
325
+ guidance_scale=guidance_scale,
326
+ guidance_rescale=0.7,
327
+ )
328
+
329
+ # compute the previous noisy sample x_t -> x_t-1
330
+ latents = scheduler.step(noise_pred, timestep, latents).prev_sample
331
+
332
+ # return latents_steps
333
+ return latents
334
+
335
+
336
+ # for XL
337
+ def get_add_time_ids(
338
+ height: int,
339
+ width: int,
340
+ dynamic_crops: bool = False,
341
+ dtype: torch.dtype = torch.float32,
342
+ ):
343
+ if dynamic_crops:
344
+ # random float scale between 1 and 3
345
+ random_scale = torch.rand(1).item() * 2 + 1
346
+ original_size = (int(height * random_scale), int(width * random_scale))
347
+ # random position
348
+ crops_coords_top_left = (
349
+ torch.randint(0, original_size[0] - height, (1,)).item(),
350
+ torch.randint(0, original_size[1] - width, (1,)).item(),
351
+ )
352
+ target_size = (height, width)
353
+ else:
354
+ original_size = (height, width)
355
+ crops_coords_top_left = (0, 0)
356
+ target_size = (height, width)
357
+
358
+ # this is expected as 6
359
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
360
+
361
+ # this is expected as 2816
362
+ passed_add_embed_dim = (
363
+ UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids) # 256 * 6
364
+ + TEXT_ENCODER_2_PROJECTION_DIM # + 1280
365
+ )
366
+ if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
367
+ raise ValueError(
368
+ f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
369
+ )
370
+
371
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
372
+ return add_time_ids
373
+
374
+
375
+ def get_optimizer(name: str):
376
+ name = name.lower()
377
+
378
+ if name.startswith("dadapt"):
379
+ import dadaptation
380
+
381
+ if name == "dadaptadam":
382
+ return dadaptation.DAdaptAdam
383
+ elif name == "dadaptlion":
384
+ return dadaptation.DAdaptLion
385
+ else:
386
+ raise ValueError("DAdapt optimizer must be dadaptadam or dadaptlion")
387
+
388
+ elif name.endswith("8bit"): # 検証してない
389
+ import bitsandbytes as bnb
390
+
391
+ if name == "adam8bit":
392
+ return bnb.optim.Adam8bit
393
+ elif name == "lion8bit":
394
+ return bnb.optim.Lion8bit
395
+ else:
396
+ raise ValueError("8bit optimizer must be adam8bit or lion8bit")
397
+
398
+ else:
399
+ if name == "adam":
400
+ return torch.optim.Adam
401
+ elif name == "adamw":
402
+ return torch.optim.AdamW
403
+ elif name == "lion":
404
+ from lion_pytorch import Lion
405
+
406
+ return Lion
407
+ elif name == "prodigy":
408
+ import prodigyopt
409
+
410
+ return prodigyopt.Prodigy
411
+ else:
412
+ raise ValueError("Optimizer must be adam, adamw, lion or Prodigy")
413
+
414
+
415
+ def get_lr_scheduler(
416
+ name: Optional[str],
417
+ optimizer: torch.optim.Optimizer,
418
+ max_iterations: Optional[int],
419
+ lr_min: Optional[float],
420
+ **kwargs,
421
+ ):
422
+ if name == "cosine":
423
+ return torch.optim.lr_scheduler.CosineAnnealingLR(
424
+ optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
425
+ )
426
+ elif name == "cosine_with_restarts":
427
+ return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
428
+ optimizer, T_0=max_iterations // 10, T_mult=2, eta_min=lr_min, **kwargs
429
+ )
430
+ elif name == "step":
431
+ return torch.optim.lr_scheduler.StepLR(
432
+ optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
433
+ )
434
+ elif name == "constant":
435
+ return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
436
+ elif name == "linear":
437
+ return torch.optim.lr_scheduler.LinearLR(
438
+ optimizer, factor=0.5, total_iters=max_iterations // 100, **kwargs
439
+ )
440
+ else:
441
+ raise ValueError(
442
+ "Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
443
+ )
444
+
445
+
446
+ def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> tuple[int, int]:
447
+ max_resolution = bucket_resolution
448
+ min_resolution = bucket_resolution // 2
449
+
450
+ step = 64
451
+
452
+ min_step = min_resolution // step
453
+ max_step = max_resolution // step
454
+
455
+ height = torch.randint(min_step, max_step, (1,)).item() * step
456
+ width = torch.randint(min_step, max_step, (1,)).item() * step
457
+
458
+ return height, width
trainscripts/textsliders/__init__.py ADDED
File without changes
trainscripts/textsliders/config_util.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, Optional
2
+
3
+ import yaml
4
+
5
+ from pydantic import BaseModel
6
+ import torch
7
+
8
+ from lora import TRAINING_METHODS
9
+
10
+ PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"]
11
+ NETWORK_TYPES = Literal["lierla", "c3lier"]
12
+
13
+
14
+ class PretrainedModelConfig(BaseModel):
15
+ name_or_path: str
16
+ v2: bool = False
17
+ v_pred: bool = False
18
+
19
+ clip_skip: Optional[int] = None
20
+
21
+
22
+ class NetworkConfig(BaseModel):
23
+ type: NETWORK_TYPES = "lierla"
24
+ rank: int = 4
25
+ alpha: float = 1.0
26
+
27
+ training_method: TRAINING_METHODS = "full"
28
+
29
+
30
+ class TrainConfig(BaseModel):
31
+ precision: PRECISION_TYPES = "bfloat16"
32
+ noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim"
33
+
34
+ iterations: int = 500
35
+ lr: float = 1e-4
36
+ optimizer: str = "adamw"
37
+ optimizer_args: str = ""
38
+ lr_scheduler: str = "constant"
39
+
40
+ max_denoising_steps: int = 50
41
+
42
+
43
+ class SaveConfig(BaseModel):
44
+ name: str = "untitled"
45
+ path: str = "./output"
46
+ per_steps: int = 200
47
+ precision: PRECISION_TYPES = "float32"
48
+
49
+
50
+ class LoggingConfig(BaseModel):
51
+ use_wandb: bool = False
52
+
53
+ verbose: bool = False
54
+
55
+
56
+ class OtherConfig(BaseModel):
57
+ use_xformers: bool = False
58
+
59
+
60
+ class RootConfig(BaseModel):
61
+ prompts_file: str
62
+ pretrained_model: PretrainedModelConfig
63
+
64
+ network: NetworkConfig
65
+
66
+ train: Optional[TrainConfig]
67
+
68
+ save: Optional[SaveConfig]
69
+
70
+ logging: Optional[LoggingConfig]
71
+
72
+ other: Optional[OtherConfig]
73
+
74
+
75
+ def parse_precision(precision: str) -> torch.dtype:
76
+ if precision == "fp32" or precision == "float32":
77
+ return torch.float32
78
+ elif precision == "fp16" or precision == "float16":
79
+ return torch.float16
80
+ elif precision == "bf16" or precision == "bfloat16":
81
+ return torch.bfloat16
82
+
83
+ raise ValueError(f"Invalid precision type: {precision}")
84
+
85
+
86
+ def load_config_from_yaml(config_path: str) -> RootConfig:
87
+ with open(config_path, "r") as f:
88
+ config = yaml.load(f, Loader=yaml.FullLoader)
89
+
90
+ root = RootConfig(**config)
91
+
92
+ if root.train is None:
93
+ root.train = TrainConfig()
94
+
95
+ if root.save is None:
96
+ root.save = SaveConfig()
97
+
98
+ if root.logging is None:
99
+ root.logging = LoggingConfig()
100
+
101
+ if root.other is None:
102
+ root.other = OtherConfig()
103
+
104
+ return root
trainscripts/textsliders/data/config-xl.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prompts_file: "trainscripts/textsliders/data/prompts-xl.yaml"
2
+ pretrained_model:
3
+ name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" # you can also use .ckpt or .safetensors models
4
+ v2: false # true if model is v2.x
5
+ v_pred: false # true if model uses v-prediction
6
+ network:
7
+ type: "c3lier" # or "c3lier" or "lierla"
8
+ rank: 4
9
+ alpha: 1.0
10
+ training_method: "noxattn"
11
+ train:
12
+ precision: "bfloat16"
13
+ noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
14
+ iterations: 1000
15
+ lr: 0.0002
16
+ optimizer: "AdamW"
17
+ lr_scheduler: "constant"
18
+ max_denoising_steps: 50
19
+ save:
20
+ name: "temp"
21
+ path: "./models"
22
+ per_steps: 500
23
+ precision: "bfloat16"
24
+ logging:
25
+ use_wandb: false
26
+ verbose: false
27
+ other:
28
+ use_xformers: true
trainscripts/textsliders/data/config.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prompts_file: "trainscripts/textsliders/data/prompts.yaml"
2
+ pretrained_model:
3
+ name_or_path: "CompVis/stable-diffusion-v1-4" # you can also use .ckpt or .safetensors models
4
+ v2: false # true if model is v2.x
5
+ v_pred: false # true if model uses v-prediction
6
+ network:
7
+ type: "c3lier" # or "c3lier" or "lierla"
8
+ rank: 4
9
+ alpha: 1.0
10
+ training_method: "noxattn"
11
+ train:
12
+ precision: "bfloat16"
13
+ noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
14
+ iterations: 1000
15
+ lr: 0.0002
16
+ optimizer: "AdamW"
17
+ lr_scheduler: "constant"
18
+ max_denoising_steps: 50
19
+ save:
20
+ name: "temp"
21
+ path: "./models"
22
+ per_steps: 500
23
+ precision: "bfloat16"
24
+ logging:
25
+ use_wandb: false
26
+ verbose: false
27
+ other:
28
+ use_xformers: true
trainscripts/textsliders/data/prompts-xl.yaml ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ####################################################################################################### AGE SLIDER
2
+ # - target: "male person" # what word for erasing the positive concept from
3
+ # positive: "male person, very old" # concept to erase
4
+ # unconditional: "male person, very young" # word to take the difference from the positive concept
5
+ # neutral: "male person" # starting point for conditioning the target
6
+ # action: "enhance" # erase or enhance
7
+ # guidance_scale: 4
8
+ # resolution: 512
9
+ # dynamic_resolution: false
10
+ # batch_size: 1
11
+ # - target: "female person" # what word for erasing the positive concept from
12
+ # positive: "female person, very old" # concept to erase
13
+ # unconditional: "female person, very young" # word to take the difference from the positive concept
14
+ # neutral: "female person" # starting point for conditioning the target
15
+ # action: "enhance" # erase or enhance
16
+ # guidance_scale: 4
17
+ # resolution: 512
18
+ # dynamic_resolution: false
19
+ # batch_size: 1
20
+ ####################################################################################################### MUSCULAR SLIDER
21
+ # - target: "male person" # what word for erasing the positive concept from
22
+ # positive: "male person, muscular, strong, biceps, greek god physique, body builder" # concept to erase
23
+ # unconditional: "male person, lean, thin, weak, slender, skinny, scrawny" # word to take the difference from the positive concept
24
+ # neutral: "male person" # starting point for conditioning the target
25
+ # action: "enhance" # erase or enhance
26
+ # guidance_scale: 4
27
+ # resolution: 512
28
+ # dynamic_resolution: false
29
+ # batch_size: 1
30
+ # - target: "female person" # what word for erasing the positive concept from
31
+ # positive: "female person, muscular, strong, biceps, greek god physique, body builder" # concept to erase
32
+ # unconditional: "female person, lean, thin, weak, slender, skinny, scrawny" # word to take the difference from the positive concept
33
+ # neutral: "female person" # starting point for conditioning the target
34
+ # action: "enhance" # erase or enhance
35
+ # guidance_scale: 4
36
+ # resolution: 512
37
+ # dynamic_resolution: false
38
+ # batch_size: 1
39
+ ####################################################################################################### CURLY HAIR SLIDER
40
+ # - target: "male person" # what word for erasing the positive concept from
41
+ # positive: "male person, curly hair, wavy hair" # concept to erase
42
+ # unconditional: "male person, straight hair" # word to take the difference from the positive concept
43
+ # neutral: "male person" # starting point for conditioning the target
44
+ # action: "enhance" # erase or enhance
45
+ # guidance_scale: 4
46
+ # resolution: 512
47
+ # dynamic_resolution: false
48
+ # batch_size: 1
49
+ # - target: "female person" # what word for erasing the positive concept from
50
+ # positive: "female person, curly hair, wavy hair" # concept to erase
51
+ # unconditional: "female person, straight hair" # word to take the difference from the positive concept
52
+ # neutral: "female person" # starting point for conditioning the target
53
+ # action: "enhance" # erase or enhance
54
+ # guidance_scale: 4
55
+ # resolution: 512
56
+ # dynamic_resolution: false
57
+ # batch_size: 1
58
+ ####################################################################################################### BEARD SLIDER
59
+ # - target: "male person" # what word for erasing the positive concept from
60
+ # positive: "male person, with beard" # concept to erase
61
+ # unconditional: "male person, clean shaven" # word to take the difference from the positive concept
62
+ # neutral: "male person" # starting point for conditioning the target
63
+ # action: "enhance" # erase or enhance
64
+ # guidance_scale: 4
65
+ # resolution: 512
66
+ # dynamic_resolution: false
67
+ # batch_size: 1
68
+ # - target: "female person" # what word for erasing the positive concept from
69
+ # positive: "female person, with beard, lipstick and feminine" # concept to erase
70
+ # unconditional: "female person, clean shaven" # word to take the difference from the positive concept
71
+ # neutral: "female person" # starting point for conditioning the target
72
+ # action: "enhance" # erase or enhance
73
+ # guidance_scale: 4
74
+ # resolution: 512
75
+ # dynamic_resolution: false
76
+ # batch_size: 1
77
+ ####################################################################################################### MAKEUP SLIDER
78
+ # - target: "male person" # what word for erasing the positive concept from
79
+ # positive: "male person, with makeup, cosmetic, concealer, mascara" # concept to erase
80
+ # unconditional: "male person, barefaced, ugly" # word to take the difference from the positive concept
81
+ # neutral: "male person" # starting point for conditioning the target
82
+ # action: "enhance" # erase or enhance
83
+ # guidance_scale: 4
84
+ # resolution: 512
85
+ # dynamic_resolution: false
86
+ # batch_size: 1
87
+ # - target: "female person" # what word for erasing the positive concept from
88
+ # positive: "female person, with makeup, cosmetic, concealer, mascara, lipstick" # concept to erase
89
+ # unconditional: "female person, barefaced, ugly" # word to take the difference from the positive concept
90
+ # neutral: "female person" # starting point for conditioning the target
91
+ # action: "enhance" # erase or enhance
92
+ # guidance_scale: 4
93
+ # resolution: 512
94
+ # dynamic_resolution: false
95
+ # batch_size: 1
96
+ ####################################################################################################### SURPRISED SLIDER
97
+ # - target: "male person" # what word for erasing the positive concept from
98
+ # positive: "male person, with shocked look, surprised, stunned, amazed" # concept to erase
99
+ # unconditional: "male person, dull, uninterested, bored, incurious" # word to take the difference from the positive concept
100
+ # neutral: "male person" # starting point for conditioning the target
101
+ # action: "enhance" # erase or enhance
102
+ # guidance_scale: 4
103
+ # resolution: 512
104
+ # dynamic_resolution: false
105
+ # batch_size: 1
106
+ # - target: "female person" # what word for erasing the positive concept from
107
+ # positive: "female person, with shocked look, surprised, stunned, amazed" # concept to erase
108
+ # unconditional: "female person, dull, uninterested, bored, incurious" # word to take the difference from the positive concept
109
+ # neutral: "female person" # starting point for conditioning the target
110
+ # action: "enhance" # erase or enhance
111
+ # guidance_scale: 4
112
+ # resolution: 512
113
+ # dynamic_resolution: false
114
+ # batch_size: 1
115
+ ####################################################################################################### OBESE SLIDER
116
+ # - target: "male person" # what word for erasing the positive concept from
117
+ # positive: "male person, fat, chubby, overweight, obese" # concept to erase
118
+ # unconditional: "male person, lean, fit, slim, slender" # word to take the difference from the positive concept
119
+ # neutral: "male person" # starting point for conditioning the target
120
+ # action: "enhance" # erase or enhance
121
+ # guidance_scale: 4
122
+ # resolution: 512
123
+ # dynamic_resolution: false
124
+ # batch_size: 1
125
+ # - target: "female person" # what word for erasing the positive concept from
126
+ # positive: "female person, fat, chubby, overweight, obese" # concept to erase
127
+ # unconditional: "female person, lean, fit, slim, slender" # word to take the difference from the positive concept
128
+ # neutral: "female person" # starting point for conditioning the target
129
+ # action: "enhance" # erase or enhance
130
+ # guidance_scale: 4
131
+ # resolution: 512
132
+ # dynamic_resolution: false
133
+ # batch_size: 1
134
+ ####################################################################################################### PROFESSIONAL SLIDER
135
+ # - target: "male person" # what word for erasing the positive concept from
136
+ # positive: "male person, professionally dressed, stylised hair, clean face" # concept to erase
137
+ # unconditional: "male person, casually dressed, messy hair, unkempt face" # word to take the difference from the positive concept
138
+ # neutral: "male person" # starting point for conditioning the target
139
+ # action: "enhance" # erase or enhance
140
+ # guidance_scale: 4
141
+ # resolution: 512
142
+ # dynamic_resolution: false
143
+ # batch_size: 1
144
+ # - target: "female person" # what word for erasing the positive concept from
145
+ # positive: "female person, professionally dressed, stylised hair, clean face" # concept to erase
146
+ # unconditional: "female person, casually dressed, messy hair, unkempt face" # word to take the difference from the positive concept
147
+ # neutral: "female person" # starting point for conditioning the target
148
+ # action: "enhance" # erase or enhance
149
+ # guidance_scale: 4
150
+ # resolution: 512
151
+ # dynamic_resolution: false
152
+ # batch_size: 1
153
+ ####################################################################################################### GLASSES SLIDER
154
+ # - target: "male person" # what word for erasing the positive concept from
155
+ # positive: "male person, wearing glasses" # concept to erase
156
+ # unconditional: "male person" # word to take the difference from the positive concept
157
+ # neutral: "male person" # starting point for conditioning the target
158
+ # action: "enhance" # erase or enhance
159
+ # guidance_scale: 4
160
+ # resolution: 512
161
+ # dynamic_resolution: false
162
+ # batch_size: 1
163
+ # - target: "female person" # what word for erasing the positive concept from
164
+ # positive: "female person, wearing glasses" # concept to erase
165
+ # unconditional: "female person" # word to take the difference from the positive concept
166
+ # neutral: "female person" # starting point for conditioning the target
167
+ # action: "enhance" # erase or enhance
168
+ # guidance_scale: 4
169
+ # resolution: 512
170
+ # dynamic_resolution: false
171
+ # batch_size: 1
172
+ ####################################################################################################### ASTRONAUGHT SLIDER
173
+ # - target: "astronaught" # what word for erasing the positive concept from
174
+ # positive: "astronaught, with orange colored spacesuit" # concept to erase
175
+ # unconditional: "astronaught" # word to take the difference from the positive concept
176
+ # neutral: "astronaught" # starting point for conditioning the target
177
+ # action: "enhance" # erase or enhance
178
+ # guidance_scale: 4
179
+ # resolution: 512
180
+ # dynamic_resolution: false
181
+ # batch_size: 1
182
+ ####################################################################################################### SMILING SLIDER
183
+ # - target: "male person" # what word for erasing the positive concept from
184
+ # positive: "male person, smiling" # concept to erase
185
+ # unconditional: "male person, frowning" # word to take the difference from the positive concept
186
+ # neutral: "male person" # starting point for conditioning the target
187
+ # action: "enhance" # erase or enhance
188
+ # guidance_scale: 4
189
+ # resolution: 512
190
+ # dynamic_resolution: false
191
+ # batch_size: 1
192
+ # - target: "female person" # what word for erasing the positive concept from
193
+ # positive: "female person, smiling" # concept to erase
194
+ # unconditional: "female person, frowning" # word to take the difference from the positive concept
195
+ # neutral: "female person" # starting point for conditioning the target
196
+ # action: "enhance" # erase or enhance
197
+ # guidance_scale: 4
198
+ # resolution: 512
199
+ # dynamic_resolution: false
200
+ # batch_size: 1
201
+ ####################################################################################################### CAR COLOR SLIDER
202
+ # - target: "car" # what word for erasing the positive concept from
203
+ # positive: "car, white color" # concept to erase
204
+ # unconditional: "car, black color" # word to take the difference from the positive concept
205
+ # neutral: "car" # starting point for conditioning the target
206
+ # action: "enhance" # erase or enhance
207
+ # guidance_scale: 4
208
+ # resolution: 512
209
+ # dynamic_resolution: false
210
+ # batch_size: 1
211
+ ####################################################################################################### DETAILS SLIDER
212
+ # - target: "" # what word for erasing the positive concept from
213
+ # positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality, hyper realistic" # concept to erase
214
+ # unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept
215
+ # neutral: "" # starting point for conditioning the target
216
+ # action: "enhance" # erase or enhance
217
+ # guidance_scale: 4
218
+ # resolution: 512
219
+ # dynamic_resolution: false
220
+ # batch_size: 1
221
+ ####################################################################################################### CARTOON SLIDER
222
+ # - target: "male person" # what word for erasing the positive concept from
223
+ # positive: "male person, cartoon style, pixar style, animated style" # concept to erase
224
+ # unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept
225
+ # neutral: "male person" # starting point for conditioning the target
226
+ # action: "enhance" # erase or enhance
227
+ # guidance_scale: 4
228
+ # resolution: 512
229
+ # dynamic_resolution: false
230
+ # batch_size: 1
231
+ # - target: "female person" # what word for erasing the positive concept from
232
+ # positive: "female person, cartoon style, pixar style, animated style" # concept to erase
233
+ # unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept
234
+ # neutral: "female person" # starting point for conditioning the target
235
+ # action: "enhance" # erase or enhance
236
+ # guidance_scale: 4
237
+ # resolution: 512
238
+ # dynamic_resolution: false
239
+ # batch_size: 1
240
+ ####################################################################################################### CLAY SLIDER
241
+ # - target: "male person" # what word for erasing the positive concept from
242
+ # positive: "male person, clay style, made out of clay, clay sculpture" # concept to erase
243
+ # unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept
244
+ # neutral: "male person" # starting point for conditioning the target
245
+ # action: "enhance" # erase or enhance
246
+ # guidance_scale: 4
247
+ # resolution: 512
248
+ # dynamic_resolution: false
249
+ # batch_size: 1
250
+ # - target: "female person" # what word for erasing the positive concept from
251
+ # positive: "female person, clay style, made out of clay, clay sculpture" # concept to erase
252
+ # unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept
253
+ # neutral: "female person" # starting point for conditioning the target
254
+ # action: "enhance" # erase or enhance
255
+ # guidance_scale: 4
256
+ # resolution: 512
257
+ # dynamic_resolution: false
258
+ # batch_size: 1
259
+ ####################################################################################################### SCULPTURE SLIDER
260
+ - target: "male person" # what word for erasing the positive concept from
261
+ positive: "male person, cement sculpture, cement greek statue style" # concept to erase
262
+ unconditional: "male person, realistic, hyper realistic" # word to take the difference from the positive concept
263
+ neutral: "male person" # starting point for conditioning the target
264
+ action: "enhance" # erase or enhance
265
+ guidance_scale: 4
266
+ resolution: 512
267
+ dynamic_resolution: false
268
+ batch_size: 1
269
+ - target: "female person" # what word for erasing the positive concept from
270
+ positive: "female person, cement sculpture, cement greek statue style" # concept to erase
271
+ unconditional: "female person, realistic, hyper realistic" # word to take the difference from the positive concept
272
+ neutral: "female person" # starting point for conditioning the target
273
+ action: "enhance" # erase or enhance
274
+ guidance_scale: 4
275
+ resolution: 512
276
+ dynamic_resolution: false
277
+ batch_size: 1
278
+ ####################################################################################################### METAL SLIDER
279
+ # - target: "" # what word for erasing the positive concept from
280
+ # positive: "made out of metal, metallic style, iron, copper, platinum metal," # concept to erase
281
+ # unconditional: "wooden style, made out of wood" # word to take the difference from the positive concept
282
+ # neutral: "" # starting point for conditioning the target
283
+ # action: "enhance" # erase or enhance
284
+ # guidance_scale: 4
285
+ # resolution: 512
286
+ # dynamic_resolution: false
287
+ # batch_size: 1
288
+ ####################################################################################################### FESTIVE SLIDER
289
+ # - target: "" # what word for erasing the positive concept from
290
+ # positive: "festive, colorful banners, confetti, indian festival decorations, chinese festival decorations, fireworks, parade, cherry, gala, happy, celebrations" # concept to erase
291
+ # unconditional: "dull, dark, sad, desserted, empty, alone" # word to take the difference from the positive concept
292
+ # neutral: "" # starting point for conditioning the target
293
+ # action: "enhance" # erase or enhance
294
+ # guidance_scale: 4
295
+ # resolution: 512
296
+ # dynamic_resolution: false
297
+ # batch_size: 1
298
+ ####################################################################################################### TROPICAL SLIDER
299
+ # - target: "" # what word for erasing the positive concept from
300
+ # positive: "tropical, beach, sunny, hot" # concept to erase
301
+ # unconditional: "arctic, winter, snow, ice, iceburg, snowfall" # word to take the difference from the positive concept
302
+ # neutral: "" # starting point for conditioning the target
303
+ # action: "enhance" # erase or enhance
304
+ # guidance_scale: 4
305
+ # resolution: 512
306
+ # dynamic_resolution: false
307
+ # batch_size: 1
308
+ ####################################################################################################### MODERN SLIDER
309
+ # - target: "" # what word for erasing the positive concept from
310
+ # positive: "modern, futuristic style, trendy, stylish, swank" # concept to erase
311
+ # unconditional: "ancient, classic style, regal, vintage" # word to take the difference from the positive concept
312
+ # neutral: "" # starting point for conditioning the target
313
+ # action: "enhance" # erase or enhance
314
+ # guidance_scale: 4
315
+ # resolution: 512
316
+ # dynamic_resolution: false
317
+ # batch_size: 1
318
+ ####################################################################################################### BOKEH SLIDER
319
+ # - target: "" # what word for erasing the positive concept from
320
+ # positive: "blurred background, narrow DOF, bokeh effect" # concept to erase
321
+ # # unconditional: "high detail background, 8k, intricate, detailed, high resolution background, high res, high quality background" # word to take the difference from the positive concept
322
+ # unconditional: ""
323
+ # neutral: "" # starting point for conditioning the target
324
+ # action: "enhance" # erase or enhance
325
+ # guidance_scale: 4
326
+ # resolution: 512
327
+ # dynamic_resolution: false
328
+ # batch_size: 1
329
+ ####################################################################################################### LONG HAIR SLIDER
330
+ # - target: "male person" # what word for erasing the positive concept from
331
+ # positive: "male person, with long hair" # concept to erase
332
+ # unconditional: "male person, with short hair" # word to take the difference from the positive concept
333
+ # neutral: "male person" # starting point for conditioning the target
334
+ # action: "enhance" # erase or enhance
335
+ # guidance_scale: 4
336
+ # resolution: 512
337
+ # dynamic_resolution: false
338
+ # batch_size: 1
339
+ # - target: "female person" # what word for erasing the positive concept from
340
+ # positive: "female person, with long hair" # concept to erase
341
+ # unconditional: "female person, with short hair" # word to take the difference from the positive concept
342
+ # neutral: "female person" # starting point for conditioning the target
343
+ # action: "enhance" # erase or enhance
344
+ # guidance_scale: 4
345
+ # resolution: 512
346
+ # dynamic_resolution: false
347
+ # batch_size: 1
348
+ ####################################################################################################### NEGPROMPT SLIDER
349
+ # - target: "" # what word for erasing the positive concept from
350
+ # positive: "cartoon, cgi, render, illustration, painting, drawing, bad quality, grainy, low resolution" # concept to erase
351
+ # unconditional: ""
352
+ # neutral: "" # starting point for conditioning the target
353
+ # action: "erase" # erase or enhance
354
+ # guidance_scale: 4
355
+ # resolution: 512
356
+ # dynamic_resolution: false
357
+ # batch_size: 1
358
+ ####################################################################################################### EXPENSIVE FOOD SLIDER
359
+ # - target: "food" # what word for erasing the positive concept from
360
+ # positive: "food, expensive and fine dining" # concept to erase
361
+ # unconditional: "food, cheap and low quality" # word to take the difference from the positive concept
362
+ # neutral: "food" # starting point for conditioning the target
363
+ # action: "enhance" # erase or enhance
364
+ # guidance_scale: 4
365
+ # resolution: 512
366
+ # dynamic_resolution: false
367
+ # batch_size: 1
368
+ ####################################################################################################### COOKED FOOD SLIDER
369
+ # - target: "food" # what word for erasing the positive concept from
370
+ # positive: "food, cooked, baked, roasted, fried" # concept to erase
371
+ # unconditional: "food, raw, uncooked, fresh, undone" # word to take the difference from the positive concept
372
+ # neutral: "food" # starting point for conditioning the target
373
+ # action: "enhance" # erase or enhance
374
+ # guidance_scale: 4
375
+ # resolution: 512
376
+ # dynamic_resolution: false
377
+ # batch_size: 1
378
+ ####################################################################################################### MEAT FOOD SLIDER
379
+ # - target: "food" # what word for erasing the positive concept from
380
+ # positive: "food, meat, steak, fish, non-vegetrian, beef, lamb, pork, chicken, salmon" # concept to erase
381
+ # unconditional: "food, vegetables, fruits, leafy-vegetables, greens, vegetarian, vegan, tomatoes, onions, carrots" # word to take the difference from the positive concept
382
+ # neutral: "food" # starting point for conditioning the target
383
+ # action: "enhance" # erase or enhance
384
+ # guidance_scale: 4
385
+ # resolution: 512
386
+ # dynamic_resolution: false
387
+ # batch_size: 1
388
+ ####################################################################################################### WEATHER SLIDER
389
+ # - target: "" # what word for erasing the positive concept from
390
+ # positive: "snowy, winter, cold, ice, snowfall, white" # concept to erase
391
+ # unconditional: "hot, summer, bright, sunny" # word to take the difference from the positive concept
392
+ # neutral: "" # starting point for conditioning the target
393
+ # action: "enhance" # erase or enhance
394
+ # guidance_scale: 4
395
+ # resolution: 512
396
+ # dynamic_resolution: false
397
+ # batch_size: 1
398
+ ####################################################################################################### NIGHT/DAY SLIDER
399
+ # - target: "" # what word for erasing the positive concept from
400
+ # positive: "night time, dark, darkness, pitch black, nighttime" # concept to erase
401
+ # unconditional: "day time, bright, sunny, daytime, sunlight" # word to take the difference from the positive concept
402
+ # neutral: "" # starting point for conditioning the target
403
+ # action: "enhance" # erase or enhance
404
+ # guidance_scale: 4
405
+ # resolution: 512
406
+ # dynamic_resolution: false
407
+ # batch_size: 1
408
+ ####################################################################################################### INDOOR/OUTDOOR SLIDER
409
+ # - target: "" # what word for erasing the positive concept from
410
+ # positive: "indoor, inside a room, inside, interior" # concept to erase
411
+ # unconditional: "outdoor, outside, open air, exterior" # word to take the difference from the positive concept
412
+ # neutral: "" # starting point for conditioning the target
413
+ # action: "enhance" # erase or enhance
414
+ # guidance_scale: 4
415
+ # resolution: 512
416
+ # dynamic_resolution: false
417
+ # batch_size: 1
418
+ ####################################################################################################### GOODHANDS SLIDER
419
+ # - target: "" # what word for erasing the positive concept from
420
+ # positive: "realistic hands, realistic limbs, perfect limbs, perfect hands, 5 fingers, five fingers, hyper realisitc hands" # concept to erase
421
+ # unconditional: "poorly drawn limbs, distorted limbs, poorly rendered hands,bad anatomy, disfigured, mutated body parts, bad composition" # word to take the difference from the positive concept
422
+ # neutral: "" # starting point for conditioning the target
423
+ # action: "enhance" # erase or enhance
424
+ # guidance_scale: 4
425
+ # resolution: 512
426
+ # dynamic_resolution: false
427
+ # batch_size: 1
428
+ ####################################################################################################### RUSTY CAR SLIDER
429
+ # - target: "car" # what word for erasing the positive concept from
430
+ # positive: "car, rusty conditioned" # concept to erase
431
+ # unconditional: "car, mint condition, brand new, shiny" # word to take the difference from the positive concept
432
+ # neutral: "car" # starting point for conditioning the target
433
+ # action: "enhance" # erase or enhance
434
+ # guidance_scale: 4
435
+ # resolution: 512
436
+ # dynamic_resolution: false
437
+ # batch_size: 1
438
+ ####################################################################################################### RUSTY CAR SLIDER
439
+ # - target: "car" # what word for erasing the positive concept from
440
+ # positive: "car, damaged, broken headlights, dented car, with scrapped paintwork" # concept to erase
441
+ # unconditional: "car, mint condition, brand new, shiny" # word to take the difference from the positive concept
442
+ # neutral: "car" # starting point for conditioning the target
443
+ # action: "enhance" # erase or enhance
444
+ # guidance_scale: 4
445
+ # resolution: 512
446
+ # dynamic_resolution: false
447
+ # batch_size: 1
448
+ ####################################################################################################### CLUTTERED ROOM SLIDER
449
+ # - target: "room" # what word for erasing the positive concept from
450
+ # positive: "room, cluttered, disorganized, dirty, jumbled, scattered" # concept to erase
451
+ # unconditional: "room, super organized, clean, ordered, neat, tidy" # word to take the difference from the positive concept
452
+ # neutral: "room" # starting point for conditioning the target
453
+ # action: "enhance" # erase or enhance
454
+ # guidance_scale: 4
455
+ # resolution: 512
456
+ # dynamic_resolution: false
457
+ # batch_size: 1
458
+ ####################################################################################################### HANDS SLIDER
459
+ # - target: "hands" # what word for erasing the positive concept from
460
+ # positive: "realistic hands, five fingers, 8k hyper realistic hands" # concept to erase
461
+ # unconditional: "poorly drawn hands, distorted hands, amputed fingers" # word to take the difference from the positive concept
462
+ # neutral: "hands" # starting point for conditioning the target
463
+ # action: "enhance" # erase or enhance
464
+ # guidance_scale: 4
465
+ # resolution: 512
466
+ # dynamic_resolution: false
467
+ # batch_size: 1
468
+ ####################################################################################################### HANDS SLIDER
469
+ # - target: "female person" # what word for erasing the positive concept from
470
+ # positive: "female person, with a surprised look" # concept to erase
471
+ # unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept
472
+ # neutral: "female person" # starting point for conditioning the target
473
+ # action: "enhance" # erase or enhance
474
+ # guidance_scale: 4
475
+ # resolution: 512
476
+ # dynamic_resolution: false
477
+ # batch_size: 1
trainscripts/textsliders/data/prompts.yaml ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ - target: "male person" # what word for erasing the positive concept from
2
+ positive: "male person, very old" # concept to erase
3
+ unconditional: "male person, very young" # word to take the difference from the positive concept
4
+ neutral: "male person" # starting point for conditioning the target
5
+ action: "enhance" # erase or enhance
6
+ guidance_scale: 4
7
+ resolution: 512
8
+ dynamic_resolution: false
9
+ batch_size: 1
10
+ - target: "female person" # what word for erasing the positive concept from
11
+ positive: "female person, very old" # concept to erase
12
+ unconditional: "female person, very young" # word to take the difference from the positive concept
13
+ neutral: "female person" # starting point for conditioning the target
14
+ action: "enhance" # erase or enhance
15
+ guidance_scale: 4
16
+ resolution: 512
17
+ dynamic_resolution: false
18
+ batch_size: 1
19
+ # - target: "" # what word for erasing the positive concept from
20
+ # positive: "a group of people" # concept to erase
21
+ # unconditional: "a person" # word to take the difference from the positive concept
22
+ # neutral: "" # starting point for conditioning the target
23
+ # action: "enhance" # erase or enhance
24
+ # guidance_scale: 4
25
+ # resolution: 512
26
+ # dynamic_resolution: false
27
+ # batch_size: 1
28
+ # - target: "" # what word for erasing the positive concept from
29
+ # positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality" # concept to erase
30
+ # unconditional: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality" # word to take the difference from the positive concept
31
+ # neutral: "" # starting point for conditioning the target
32
+ # action: "enhance" # erase or enhance
33
+ # guidance_scale: 4
34
+ # resolution: 512
35
+ # dynamic_resolution: false
36
+ # batch_size: 1
37
+ # - target: "" # what word for erasing the positive concept from
38
+ # positive: "blurred background, narrow DOF, bokeh effect" # concept to erase
39
+ # # unconditional: "high detail background, 8k, intricate, detailed, high resolution background, high res, high quality background" # word to take the difference from the positive concept
40
+ # unconditional: ""
41
+ # neutral: "" # starting point for conditioning the target
42
+ # action: "enhance" # erase or enhance
43
+ # guidance_scale: 4
44
+ # resolution: 512
45
+ # dynamic_resolution: false
46
+ # batch_size: 1
47
+ # - target: "food" # what word for erasing the positive concept from
48
+ # positive: "food, expensive and fine dining" # concept to erase
49
+ # unconditional: "food, cheap and low quality" # word to take the difference from the positive concept
50
+ # neutral: "food" # starting point for conditioning the target
51
+ # action: "enhance" # erase or enhance
52
+ # guidance_scale: 4
53
+ # resolution: 512
54
+ # dynamic_resolution: false
55
+ # batch_size: 1
56
+ # - target: "room" # what word for erasing the positive concept from
57
+ # positive: "room, dirty disorganised and cluttered" # concept to erase
58
+ # unconditional: "room, neat organised and clean" # word to take the difference from the positive concept
59
+ # neutral: "room" # starting point for conditioning the target
60
+ # action: "enhance" # erase or enhance
61
+ # guidance_scale: 4
62
+ # resolution: 512
63
+ # dynamic_resolution: false
64
+ # batch_size: 1
65
+ # - target: "male person" # what word for erasing the positive concept from
66
+ # positive: "male person, with a surprised look" # concept to erase
67
+ # unconditional: "male person, with a disinterested look" # word to take the difference from the positive concept
68
+ # neutral: "male person" # starting point for conditioning the target
69
+ # action: "enhance" # erase or enhance
70
+ # guidance_scale: 4
71
+ # resolution: 512
72
+ # dynamic_resolution: false
73
+ # batch_size: 1
74
+ # - target: "female person" # what word for erasing the positive concept from
75
+ # positive: "female person, with a surprised look" # concept to erase
76
+ # unconditional: "female person, with a disinterested look" # word to take the difference from the positive concept
77
+ # neutral: "female person" # starting point for conditioning the target
78
+ # action: "enhance" # erase or enhance
79
+ # guidance_scale: 4
80
+ # resolution: 512
81
+ # dynamic_resolution: false
82
+ # batch_size: 1
83
+ # - target: "sky" # what word for erasing the positive concept from
84
+ # positive: "peaceful sky" # concept to erase
85
+ # unconditional: "sky" # word to take the difference from the positive concept
86
+ # neutral: "sky" # starting point for conditioning the target
87
+ # action: "enhance" # erase or enhance
88
+ # guidance_scale: 4
89
+ # resolution: 512
90
+ # dynamic_resolution: false
91
+ # batch_size: 1
92
+ # - target: "sky" # what word for erasing the positive concept from
93
+ # positive: "chaotic dark sky" # concept to erase
94
+ # unconditional: "sky" # word to take the difference from the positive concept
95
+ # neutral: "sky" # starting point for conditioning the target
96
+ # action: "erase" # erase or enhance
97
+ # guidance_scale: 4
98
+ # resolution: 512
99
+ # dynamic_resolution: false
100
+ # batch_size: 1
101
+ # - target: "person" # what word for erasing the positive concept from
102
+ # positive: "person, very young" # concept to erase
103
+ # unconditional: "person" # word to take the difference from the positive concept
104
+ # neutral: "person" # starting point for conditioning the target
105
+ # action: "erase" # erase or enhance
106
+ # guidance_scale: 4
107
+ # resolution: 512
108
+ # dynamic_resolution: false
109
+ # batch_size: 1
110
+ # overweight
111
+ # - target: "art" # what word for erasing the positive concept from
112
+ # positive: "realistic art" # concept to erase
113
+ # unconditional: "art" # word to take the difference from the positive concept
114
+ # neutral: "art" # starting point for conditioning the target
115
+ # action: "enhance" # erase or enhance
116
+ # guidance_scale: 4
117
+ # resolution: 512
118
+ # dynamic_resolution: false
119
+ # batch_size: 1
120
+ # - target: "art" # what word for erasing the positive concept from
121
+ # positive: "abstract art" # concept to erase
122
+ # unconditional: "art" # word to take the difference from the positive concept
123
+ # neutral: "art" # starting point for conditioning the target
124
+ # action: "erase" # erase or enhance
125
+ # guidance_scale: 4
126
+ # resolution: 512
127
+ # dynamic_resolution: false
128
+ # batch_size: 1
129
+ # sky
130
+ # - target: "weather" # what word for erasing the positive concept from
131
+ # positive: "bright pleasant weather" # concept to erase
132
+ # unconditional: "weather" # word to take the difference from the positive concept
133
+ # neutral: "weather" # starting point for conditioning the target
134
+ # action: "enhance" # erase or enhance
135
+ # guidance_scale: 4
136
+ # resolution: 512
137
+ # dynamic_resolution: false
138
+ # batch_size: 1
139
+ # - target: "weather" # what word for erasing the positive concept from
140
+ # positive: "dark gloomy weather" # concept to erase
141
+ # unconditional: "weather" # word to take the difference from the positive concept
142
+ # neutral: "weather" # starting point for conditioning the target
143
+ # action: "erase" # erase or enhance
144
+ # guidance_scale: 4
145
+ # resolution: 512
146
+ # dynamic_resolution: false
147
+ # batch_size: 1
148
+ # hair
149
+ # - target: "person" # what word for erasing the positive concept from
150
+ # positive: "person with long hair" # concept to erase
151
+ # unconditional: "person" # word to take the difference from the positive concept
152
+ # neutral: "person" # starting point for conditioning the target
153
+ # action: "enhance" # erase or enhance
154
+ # guidance_scale: 4
155
+ # resolution: 512
156
+ # dynamic_resolution: false
157
+ # batch_size: 1
158
+ # - target: "person" # what word for erasing the positive concept from
159
+ # positive: "person with short hair" # concept to erase
160
+ # unconditional: "person" # word to take the difference from the positive concept
161
+ # neutral: "person" # starting point for conditioning the target
162
+ # action: "erase" # erase or enhance
163
+ # guidance_scale: 4
164
+ # resolution: 512
165
+ # dynamic_resolution: false
166
+ # batch_size: 1
167
+ # - target: "girl" # what word for erasing the positive concept from
168
+ # positive: "baby girl" # concept to erase
169
+ # unconditional: "girl" # word to take the difference from the positive concept
170
+ # neutral: "girl" # starting point for conditioning the target
171
+ # action: "enhance" # erase or enhance
172
+ # guidance_scale: -4
173
+ # resolution: 512
174
+ # dynamic_resolution: false
175
+ # batch_size: 1
176
+ # - target: "boy" # what word for erasing the positive concept from
177
+ # positive: "old man" # concept to erase
178
+ # unconditional: "boy" # word to take the difference from the positive concept
179
+ # neutral: "boy" # starting point for conditioning the target
180
+ # action: "enhance" # erase or enhance
181
+ # guidance_scale: 4
182
+ # resolution: 512
183
+ # dynamic_resolution: false
184
+ # batch_size: 1
185
+ # - target: "boy" # what word for erasing the positive concept from
186
+ # positive: "baby boy" # concept to erase
187
+ # unconditional: "boy" # word to take the difference from the positive concept
188
+ # neutral: "boy" # starting point for conditioning the target
189
+ # action: "enhance" # erase or enhance
190
+ # guidance_scale: -4
191
+ # resolution: 512
192
+ # dynamic_resolution: false
193
+ # batch_size: 1
trainscripts/textsliders/debug_util.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # デバッグ用...
2
+
3
+ import torch
4
+
5
+
6
+ def check_requires_grad(model: torch.nn.Module):
7
+ for name, module in list(model.named_modules())[:5]:
8
+ if len(list(module.parameters())) > 0:
9
+ print(f"Module: {name}")
10
+ for name, param in list(module.named_parameters())[:2]:
11
+ print(f" Parameter: {name}, Requires Grad: {param.requires_grad}")
12
+
13
+
14
+ def check_training_mode(model: torch.nn.Module):
15
+ for name, module in list(model.named_modules())[:5]:
16
+ print(f"Module: {name}, Training Mode: {module.training}")
trainscripts/textsliders/flush.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import torch
2
+ import gc
3
+
4
+ torch.cuda.empty_cache()
5
+ gc.collect()
trainscripts/textsliders/generate_images_xl.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import argparse
4
+ import os, json, random
5
+ import pandas as pd
6
+ import matplotlib.pyplot as plt
7
+ import glob, re
8
+
9
+ from safetensors.torch import load_file
10
+ import matplotlib.image as mpimg
11
+ import copy
12
+ import gc
13
+ from transformers import CLIPTextModel, CLIPTokenizer
14
+
15
+ import diffusers
16
+ from diffusers import DiffusionPipeline
17
+ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
18
+ from diffusers.loaders import AttnProcsLayers
19
+ from diffusers.models.attention_processor import LoRAAttnProcessor, AttentionProcessor
20
+ from typing import Any, Dict, List, Optional, Tuple, Union
21
+ from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
22
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
23
+ import inspect
24
+ import os
25
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
26
+ from diffusers.pipelines import StableDiffusionXLPipeline
27
+ import random
28
+
29
+ import torch
30
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
31
+ import re
32
+ import argparse
33
+
34
+ def flush():
35
+ torch.cuda.empty_cache()
36
+ gc.collect()
37
+
38
+ @torch.no_grad()
39
+ def call(
40
+ self,
41
+ prompt: Union[str, List[str]] = None,
42
+ prompt_2: Optional[Union[str, List[str]]] = None,
43
+ height: Optional[int] = None,
44
+ width: Optional[int] = None,
45
+ num_inference_steps: int = 50,
46
+ denoising_end: Optional[float] = None,
47
+ guidance_scale: float = 5.0,
48
+ negative_prompt: Optional[Union[str, List[str]]] = None,
49
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
50
+ num_images_per_prompt: Optional[int] = 1,
51
+ eta: float = 0.0,
52
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
53
+ latents: Optional[torch.FloatTensor] = None,
54
+ prompt_embeds: Optional[torch.FloatTensor] = None,
55
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
56
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
57
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
58
+ output_type: Optional[str] = "pil",
59
+ return_dict: bool = True,
60
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
61
+ callback_steps: int = 1,
62
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
63
+ guidance_rescale: float = 0.0,
64
+ original_size: Optional[Tuple[int, int]] = None,
65
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
66
+ target_size: Optional[Tuple[int, int]] = None,
67
+ negative_original_size: Optional[Tuple[int, int]] = None,
68
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
69
+ negative_target_size: Optional[Tuple[int, int]] = None,
70
+
71
+ network=None,
72
+ start_noise=None,
73
+ scale=None,
74
+ unet=None,
75
+ ):
76
+ r"""
77
+ Function invoked when calling the pipeline for generation.
78
+
79
+ Args:
80
+ prompt (`str` or `List[str]`, *optional*):
81
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
82
+ instead.
83
+ prompt_2 (`str` or `List[str]`, *optional*):
84
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
85
+ used in both text-encoders
86
+ height (`int`, *optional*, defaults to unet.config.sample_size * self.vae_scale_factor):
87
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
88
+ Anything below 512 pixels won't work well for
89
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
90
+ and checkpoints that are not specifically fine-tuned on low resolutions.
91
+ width (`int`, *optional*, defaults to unet.config.sample_size * self.vae_scale_factor):
92
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
93
+ Anything below 512 pixels won't work well for
94
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
95
+ and checkpoints that are not specifically fine-tuned on low resolutions.
96
+ num_inference_steps (`int`, *optional*, defaults to 50):
97
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
98
+ expense of slower inference.
99
+ denoising_end (`float`, *optional*):
100
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
101
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
102
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
103
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
104
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
105
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
106
+ guidance_scale (`float`, *optional*, defaults to 5.0):
107
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
108
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
109
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
110
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
111
+ usually at the expense of lower image quality.
112
+ negative_prompt (`str` or `List[str]`, *optional*):
113
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
114
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
115
+ less than `1`).
116
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
117
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
118
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
119
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
120
+ The number of images to generate per prompt.
121
+ eta (`float`, *optional*, defaults to 0.0):
122
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
123
+ [`schedulers.DDIMScheduler`], will be ignored for others.
124
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
125
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
126
+ to make generation deterministic.
127
+ latents (`torch.FloatTensor`, *optional*):
128
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
129
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
130
+ tensor will ge generated by sampling using the supplied random `generator`.
131
+ prompt_embeds (`torch.FloatTensor`, *optional*):
132
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
133
+ provided, text embeddings will be generated from `prompt` input argument.
134
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
135
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
136
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
137
+ argument.
138
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
139
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
140
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
141
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
142
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
143
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
144
+ input argument.
145
+ output_type (`str`, *optional*, defaults to `"pil"`):
146
+ The output format of the generate image. Choose between
147
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
148
+ return_dict (`bool`, *optional*, defaults to `True`):
149
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
150
+ of a plain tuple.
151
+ callback (`Callable`, *optional*):
152
+ A function that will be called every `callback_steps` steps during inference. The function will be
153
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
154
+ callback_steps (`int`, *optional*, defaults to 1):
155
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
156
+ called at every step.
157
+ cross_attention_kwargs (`dict`, *optional*):
158
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
159
+ `self.processor` in
160
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
161
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
162
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
163
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
164
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
165
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
166
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
167
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
168
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
169
+ explained in section 2.2 of
170
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
171
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
172
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
173
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
174
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
175
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
176
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
177
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
178
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
179
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
180
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
181
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
182
+ micro-conditioning as explained in section 2.2 of
183
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
184
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
185
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
186
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
187
+ micro-conditioning as explained in section 2.2 of
188
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
189
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
190
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
191
+ To negatively condition the generation process based on a target image resolution. It should be as same
192
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
193
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
194
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
195
+
196
+ Examples:
197
+
198
+ Returns:
199
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
200
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
201
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
202
+ """
203
+ # 0. Default height and width to unet
204
+ height = height or self.default_sample_size * self.vae_scale_factor
205
+ width = width or self.default_sample_size * self.vae_scale_factor
206
+
207
+ original_size = original_size or (height, width)
208
+ target_size = target_size or (height, width)
209
+
210
+ # 1. Check inputs. Raise error if not correct
211
+ self.check_inputs(
212
+ prompt,
213
+ prompt_2,
214
+ height,
215
+ width,
216
+ callback_steps,
217
+ negative_prompt,
218
+ negative_prompt_2,
219
+ prompt_embeds,
220
+ negative_prompt_embeds,
221
+ pooled_prompt_embeds,
222
+ negative_pooled_prompt_embeds,
223
+ )
224
+
225
+ # 2. Define call parameters
226
+ if prompt is not None and isinstance(prompt, str):
227
+ batch_size = 1
228
+ elif prompt is not None and isinstance(prompt, list):
229
+ batch_size = len(prompt)
230
+ else:
231
+ batch_size = prompt_embeds.shape[0]
232
+
233
+ device = self._execution_device
234
+
235
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
236
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
237
+ # corresponds to doing no classifier free guidance.
238
+ do_classifier_free_guidance = guidance_scale > 1.0
239
+
240
+ # 3. Encode input prompt
241
+ text_encoder_lora_scale = (
242
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
243
+ )
244
+ (
245
+ prompt_embeds,
246
+ negative_prompt_embeds,
247
+ pooled_prompt_embeds,
248
+ negative_pooled_prompt_embeds,
249
+ ) = self.encode_prompt(
250
+ prompt=prompt,
251
+ prompt_2=prompt_2,
252
+ device=device,
253
+ num_images_per_prompt=num_images_per_prompt,
254
+ do_classifier_free_guidance=do_classifier_free_guidance,
255
+ negative_prompt=negative_prompt,
256
+ negative_prompt_2=negative_prompt_2,
257
+ prompt_embeds=prompt_embeds,
258
+ negative_prompt_embeds=negative_prompt_embeds,
259
+ pooled_prompt_embeds=pooled_prompt_embeds,
260
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
261
+ lora_scale=text_encoder_lora_scale,
262
+ )
263
+
264
+ # 4. Prepare timesteps
265
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
266
+
267
+ timesteps = self.scheduler.timesteps
268
+
269
+ # 5. Prepare latent variables
270
+ num_channels_latents = unet.config.in_channels
271
+ latents = self.prepare_latents(
272
+ batch_size * num_images_per_prompt,
273
+ num_channels_latents,
274
+ height,
275
+ width,
276
+ prompt_embeds.dtype,
277
+ device,
278
+ generator,
279
+ latents,
280
+ )
281
+
282
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
283
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
284
+
285
+ # 7. Prepare added time ids & embeddings
286
+ add_text_embeds = pooled_prompt_embeds
287
+ add_time_ids = self._get_add_time_ids(
288
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
289
+ )
290
+ if negative_original_size is not None and negative_target_size is not None:
291
+ negative_add_time_ids = self._get_add_time_ids(
292
+ negative_original_size,
293
+ negative_crops_coords_top_left,
294
+ negative_target_size,
295
+ dtype=prompt_embeds.dtype,
296
+ )
297
+ else:
298
+ negative_add_time_ids = add_time_ids
299
+
300
+ if do_classifier_free_guidance:
301
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
302
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
303
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
304
+
305
+ prompt_embeds = prompt_embeds.to(device)
306
+ add_text_embeds = add_text_embeds.to(device)
307
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
308
+
309
+ # 8. Denoising loop
310
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
311
+
312
+ # 7.1 Apply denoising_end
313
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
314
+ discrete_timestep_cutoff = int(
315
+ round(
316
+ self.scheduler.config.num_train_timesteps
317
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
318
+ )
319
+ )
320
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
321
+ timesteps = timesteps[:num_inference_steps]
322
+
323
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
324
+ for i, t in enumerate(timesteps):
325
+ if t>start_noise:
326
+ network.set_lora_slider(scale=0)
327
+ else:
328
+ network.set_lora_slider(scale=scale)
329
+ # expand the latents if we are doing classifier free guidance
330
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
331
+
332
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
333
+
334
+ # predict the noise residual
335
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
336
+ with network:
337
+ noise_pred = unet(
338
+ latent_model_input,
339
+ t,
340
+ encoder_hidden_states=prompt_embeds,
341
+ cross_attention_kwargs=cross_attention_kwargs,
342
+ added_cond_kwargs=added_cond_kwargs,
343
+ return_dict=False,
344
+ )[0]
345
+
346
+ # perform guidance
347
+ if do_classifier_free_guidance:
348
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
349
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
350
+
351
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
352
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
353
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
354
+
355
+ # compute the previous noisy sample x_t -> x_t-1
356
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
357
+
358
+ # call the callback, if provided
359
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
360
+ progress_bar.update()
361
+ if callback is not None and i % callback_steps == 0:
362
+ callback(i, t, latents)
363
+
364
+ if not output_type == "latent":
365
+ # make sure the VAE is in float32 mode, as it overflows in float16
366
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
367
+
368
+ if needs_upcasting:
369
+ self.upcast_vae()
370
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
371
+
372
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
373
+
374
+ # cast back to fp16 if needed
375
+ if needs_upcasting:
376
+ self.vae.to(dtype=torch.float16)
377
+ else:
378
+ image = latents
379
+
380
+ if not output_type == "latent":
381
+ # apply watermark if available
382
+ if self.watermark is not None:
383
+ image = self.watermark.apply_watermark(image)
384
+
385
+ image = self.image_processor.postprocess(image, output_type=output_type)
386
+
387
+ # Offload all models
388
+ # self.maybe_free_model_hooks()
389
+
390
+ if not return_dict:
391
+ return (image,)
392
+
393
+ return StableDiffusionXLPipelineOutput(images=image)
394
+
395
+
396
+ def sorted_nicely( l ):
397
+ convert = lambda text: float(text) if text.replace('-','').replace('.','').isdigit() else text
398
+ alphanum_key = lambda key: [convert(c) for c in re.split('(-?[0-9]+.?[0-9]+?)', key) ]
399
+ return sorted(l, key = alphanum_key)
400
+
401
+ def flush():
402
+ torch.cuda.empty_cache()
403
+ gc.collect()
404
+
405
+
406
+ if __name__=='__main__':
407
+
408
+ device = 'cuda:0'
409
+ StableDiffusionXLPipeline.__call__ = call
410
+ pipe = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0')
411
+
412
+ # pipe.__call__ = call
413
+ pipe = pipe.to(device)
414
+
415
+
416
+ parser = argparse.ArgumentParser(
417
+ prog = 'generateImages',
418
+ description = 'Generate Images using Diffusers Code')
419
+ parser.add_argument('--model_name', help='name of model', type=str, required=True)
420
+ parser.add_argument('--prompts_path', help='path to csv file with prompts', type=str, required=True)
421
+ parser.add_argument('--negative_prompts', help='negative prompt', type=str, required=False, default=None)
422
+ parser.add_argument('--save_path', help='folder where to save images', type=str, required=True)
423
+ parser.add_argument('--base', help='version of stable diffusion to use', type=str, required=False, default='1.4')
424
+ parser.add_argument('--guidance_scale', help='guidance to run eval', type=float, required=False, default=7.5)
425
+ parser.add_argument('--image_size', help='image size used to train', type=int, required=False, default=512)
426
+ parser.add_argument('--till_case', help='continue generating from case_number', type=int, required=False, default=1000000)
427
+ parser.add_argument('--from_case', help='continue generating from case_number', type=int, required=False, default=0)
428
+ parser.add_argument('--num_samples', help='number of samples per prompt', type=int, required=False, default=5)
429
+ parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, required=False, default=50)
430
+ parser.add_argument('--rank', help='rank of the LoRA', type=int, required=False, default=4)
431
+ parser.add_argument('--start_noise', help='what time stamp to flip to edited model', type=int, required=False, default=750)
432
+
433
+ args = parser.parse_args()
434
+ lora_weight = args.model_name
435
+ csv_path = args.prompts_path
436
+ save_path = args.save_path
437
+ start_noise = args.start_noise
438
+ from_case = args.from_case
439
+ till_case = args.till_case
440
+
441
+ weight_dtype = torch.float16
442
+ num_images_per_prompt = 1
443
+ scales = [-2, -1, 0, 1, 2]
444
+ scales = [-1, -.5, 0, .5, 1]
445
+ scales = [-2]
446
+ df = pd.read_csv(csv_path)
447
+
448
+ for scale in scales:
449
+ os.makedirs(f'{save_path}/{os.path.basename(lora_weight)}/{scale}', exist_ok=True)
450
+
451
+ prompts = list(df['prompt'])
452
+ seeds = list(df['evaluation_seed'])
453
+ case_numbers = list(df['case_number'])
454
+ pipe = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',torch_dtype=torch.float16,)
455
+
456
+ # pipe.__call__ = call
457
+ pipe = pipe.to(device)
458
+ unet = pipe.unet
459
+ if 'full' in lora_weight:
460
+ train_method = 'full'
461
+ elif 'noxattn' in lora_weight:
462
+ train_method = 'noxattn'
463
+ else:
464
+ train_method = 'noxattn'
465
+
466
+ network_type = "c3lier"
467
+ if train_method == 'xattn':
468
+ network_type = 'lierla'
469
+
470
+ modules = DEFAULT_TARGET_REPLACE
471
+ if network_type == "c3lier":
472
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
473
+ import os
474
+ model_name = lora_weight
475
+
476
+ name = os.path.basename(model_name)
477
+ rank = 1
478
+ alpha = 4
479
+ if 'rank4' in lora_weight:
480
+ rank = 4
481
+ if 'rank8' in lora_weight:
482
+ rank = 8
483
+ if 'alpha1' in lora_weight:
484
+ alpha = 1.0
485
+ network = LoRANetwork(
486
+ unet,
487
+ rank=rank,
488
+ multiplier=1.0,
489
+ alpha=alpha,
490
+ train_method=train_method,
491
+ ).to(device, dtype=weight_dtype)
492
+ network.load_state_dict(torch.load(lora_weight))
493
+
494
+ for idx, prompt in enumerate(prompts):
495
+ seed = seeds[idx]
496
+ case_number = case_numbers[idx]
497
+
498
+ if not (case_number>=from_case and case_number<=till_case):
499
+ continue
500
+ if os.path.exists(f'{save_path}/{os.path.basename(lora_weight)}/{scale}/{case_number}_{idx}.png'):
501
+ continue
502
+ print(prompt, seed)
503
+ for scale in scales:
504
+ generator = torch.manual_seed(seed)
505
+ images = pipe(prompt, num_images_per_prompt=args.num_samples, num_inference_steps=50, generator=generator, network=network, start_noise=start_noise, scale=scale, unet=unet).images
506
+ for idx, im in enumerate(images):
507
+ im.save(f'{save_path}/{os.path.basename(lora_weight)}/{scale}/{case_number}_{idx}.png')
508
+ del unet, network, pipe
509
+ unet = None
510
+ network = None
511
+ pipe = None
512
+ torch.cuda.empty_cache()
513
+ flush()
trainscripts/textsliders/lora.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ref:
2
+ # - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
3
+ # - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
4
+
5
+ import os
6
+ import math
7
+ from typing import Optional, List, Type, Set, Literal
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from diffusers import UNet2DConditionModel
12
+ from safetensors.torch import save_file
13
+
14
+
15
+ UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
16
+ # "Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2
17
+ "Attention"
18
+ ]
19
+ UNET_TARGET_REPLACE_MODULE_CONV = [
20
+ "ResnetBlock2D",
21
+ "Downsample2D",
22
+ "Upsample2D",
23
+ "DownBlock2D",
24
+ "UpBlock2D",
25
+
26
+ ] # locon, 3clier
27
+
28
+ LORA_PREFIX_UNET = "lora_unet"
29
+
30
+ DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
31
+
32
+ TRAINING_METHODS = Literal[
33
+ "noxattn", # train all layers except x-attns and time_embed layers
34
+ "innoxattn", # train all layers except self attention layers
35
+ "selfattn", # ESD-u, train only self attention layers
36
+ "xattn", # ESD-x, train only x attention layers
37
+ "full", # train all layers
38
+ "xattn-strict", # q and k values
39
+ "noxattn-hspace",
40
+ "noxattn-hspace-last",
41
+ # "xlayer",
42
+ # "outxattn",
43
+ # "outsattn",
44
+ # "inxattn",
45
+ # "inmidsattn",
46
+ # "selflayer",
47
+ ]
48
+
49
+
50
+ class LoRAModule(nn.Module):
51
+ """
52
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ lora_name,
58
+ org_module: nn.Module,
59
+ multiplier=1.0,
60
+ lora_dim=4,
61
+ alpha=1,
62
+ ):
63
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
64
+ super().__init__()
65
+ self.lora_name = lora_name
66
+ self.lora_dim = lora_dim
67
+
68
+ if "Linear" in org_module.__class__.__name__:
69
+ in_dim = org_module.in_features
70
+ out_dim = org_module.out_features
71
+ self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
72
+ self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
73
+
74
+ elif "Conv" in org_module.__class__.__name__: # 一応
75
+ in_dim = org_module.in_channels
76
+ out_dim = org_module.out_channels
77
+
78
+ self.lora_dim = min(self.lora_dim, in_dim, out_dim)
79
+ if self.lora_dim != lora_dim:
80
+ print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
81
+
82
+ kernel_size = org_module.kernel_size
83
+ stride = org_module.stride
84
+ padding = org_module.padding
85
+ self.lora_down = nn.Conv2d(
86
+ in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
87
+ )
88
+ self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
89
+
90
+ if type(alpha) == torch.Tensor:
91
+ alpha = alpha.detach().numpy()
92
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
93
+ self.scale = alpha / self.lora_dim
94
+ self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
95
+
96
+ # same as microsoft's
97
+ nn.init.kaiming_uniform_(self.lora_down.weight, a=1)
98
+ nn.init.zeros_(self.lora_up.weight)
99
+
100
+ self.multiplier = multiplier
101
+ self.org_module = org_module # remove in applying
102
+
103
+ def apply_to(self):
104
+ self.org_forward = self.org_module.forward
105
+ self.org_module.forward = self.forward
106
+ del self.org_module
107
+
108
+ def forward(self, x):
109
+ return (
110
+ self.org_forward(x)
111
+ + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
112
+ )
113
+
114
+
115
+ class LoRANetwork(nn.Module):
116
+ def __init__(
117
+ self,
118
+ unet: UNet2DConditionModel,
119
+ rank: int = 4,
120
+ multiplier: float = 1.0,
121
+ alpha: float = 1.0,
122
+ train_method: TRAINING_METHODS = "full",
123
+ ) -> None:
124
+ super().__init__()
125
+ self.lora_scale = 1
126
+ self.multiplier = multiplier
127
+ self.lora_dim = rank
128
+ self.alpha = alpha
129
+
130
+ # LoRAのみ
131
+ self.module = LoRAModule
132
+
133
+ # unetのloraを作る
134
+ self.unet_loras = self.create_modules(
135
+ LORA_PREFIX_UNET,
136
+ unet,
137
+ DEFAULT_TARGET_REPLACE,
138
+ self.lora_dim,
139
+ self.multiplier,
140
+ train_method=train_method,
141
+ )
142
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
143
+
144
+ # assertion 名前の被りがないか確認しているようだ
145
+ lora_names = set()
146
+ for lora in self.unet_loras:
147
+ assert (
148
+ lora.lora_name not in lora_names
149
+ ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
150
+ lora_names.add(lora.lora_name)
151
+
152
+ # 適用する
153
+ for lora in self.unet_loras:
154
+ lora.apply_to()
155
+ self.add_module(
156
+ lora.lora_name,
157
+ lora,
158
+ )
159
+
160
+ del unet
161
+
162
+ torch.cuda.empty_cache()
163
+
164
+ def create_modules(
165
+ self,
166
+ prefix: str,
167
+ root_module: nn.Module,
168
+ target_replace_modules: List[str],
169
+ rank: int,
170
+ multiplier: float,
171
+ train_method: TRAINING_METHODS,
172
+ ) -> list:
173
+ loras = []
174
+ names = []
175
+ for name, module in root_module.named_modules():
176
+ if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last": # Cross Attention と Time Embed 以外学習
177
+ if "attn2" in name or "time_embed" in name:
178
+ continue
179
+ elif train_method == "innoxattn": # Cross Attention 以外学習
180
+ if "attn2" in name:
181
+ continue
182
+ elif train_method == "selfattn": # Self Attention のみ学習
183
+ if "attn1" not in name:
184
+ continue
185
+ elif train_method == "xattn" or train_method == "xattn-strict": # Cross Attention のみ学習
186
+ if "attn2" not in name:
187
+ continue
188
+ elif train_method == "full": # 全部学習
189
+ pass
190
+ else:
191
+ raise NotImplementedError(
192
+ f"train_method: {train_method} is not implemented."
193
+ )
194
+ if module.__class__.__name__ in target_replace_modules:
195
+ for child_name, child_module in module.named_modules():
196
+ if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
197
+ if train_method == 'xattn-strict':
198
+ if 'out' in child_name:
199
+ continue
200
+ if train_method == 'noxattn-hspace':
201
+ if 'mid_block' not in name:
202
+ continue
203
+ if train_method == 'noxattn-hspace-last':
204
+ if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name:
205
+ continue
206
+ lora_name = prefix + "." + name + "." + child_name
207
+ lora_name = lora_name.replace(".", "_")
208
+ # print(f"{lora_name}")
209
+ lora = self.module(
210
+ lora_name, child_module, multiplier, rank, self.alpha
211
+ )
212
+ # print(name, child_name)
213
+ # print(child_module.weight.shape)
214
+ if lora_name not in names:
215
+ loras.append(lora)
216
+ names.append(lora_name)
217
+ # print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
218
+ return loras
219
+
220
+ def prepare_optimizer_params(self):
221
+ all_params = []
222
+
223
+ if self.unet_loras: # 実質これしかない
224
+ params = []
225
+ [params.extend(lora.parameters()) for lora in self.unet_loras]
226
+ param_data = {"params": params}
227
+ all_params.append(param_data)
228
+
229
+ return all_params
230
+
231
+ def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
232
+ state_dict = self.state_dict()
233
+
234
+ if dtype is not None:
235
+ for key in list(state_dict.keys()):
236
+ v = state_dict[key]
237
+ v = v.detach().clone().to("cpu").to(dtype)
238
+ state_dict[key] = v
239
+
240
+ # for key in list(state_dict.keys()):
241
+ # if not key.startswith("lora"):
242
+ # # lora以外除外
243
+ # del state_dict[key]
244
+
245
+ if os.path.splitext(file)[1] == ".safetensors":
246
+ save_file(state_dict, file, metadata)
247
+ else:
248
+ torch.save(state_dict, file)
249
+ def set_lora_slider(self, scale):
250
+ self.lora_scale = scale
251
+
252
+ def __enter__(self):
253
+ for lora in self.unet_loras:
254
+ lora.multiplier = 1.0 * self.lora_scale
255
+
256
+ def __exit__(self, exc_type, exc_value, tb):
257
+ for lora in self.unet_loras:
258
+ lora.multiplier = 0