SunilGopal commited on
Commit
1e37f65
1 Parent(s): 9657edf

Upload 4 files

Browse files
coarse_transformer.ipynb ADDED
@@ -0,0 +1,632 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Coarse Transformer"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "### Libraries:"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 1,
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "import torch\n",
24
+ "from audiolm_pytorch import HubertWithKmeans\n",
25
+ "from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer\n",
26
+ "from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer\n",
27
+ "from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerTrainer\n",
28
+ "from audiolm_pytorch import AudioLMSoundStream, AudioLM, MusicLMSoundStream\n",
29
+ "import gc\n",
30
+ "from musiclm_pytorch import MuLaNEmbedQuantizer\n",
31
+ "from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": 2,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "checkpoint_path = './models/hubert/hubert_base_ls960.pt'\n",
41
+ "kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'\n",
42
+ "\n",
43
+ "audio_output_dir = './audio'\n",
44
+ "batch_size = 1\n",
45
+ "data_max_length = 320 * 32\n",
46
+ "num_train_steps = 1000"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 3,
52
+ "metadata": {},
53
+ "outputs": [
54
+ {
55
+ "name": "stdout",
56
+ "output_type": "stream",
57
+ "text": [
58
+ "spectrogram yielded shape of (65, 86), but had to be cropped to (64, 80) to be patchified for transformer\n"
59
+ ]
60
+ }
61
+ ],
62
+ "source": [
63
+ "audio_transformer = AudioSpectrogramTransformer(\n",
64
+ " dim = 512,\n",
65
+ " depth = 6,\n",
66
+ " heads = 8,\n",
67
+ " dim_head = 64,\n",
68
+ " spec_n_fft = 128,\n",
69
+ " spec_win_length = 24,\n",
70
+ " spec_aug_stretch_factor = 0.8\n",
71
+ ")\n",
72
+ "\n",
73
+ "text_transformer = TextTransformer(\n",
74
+ " dim = 512,\n",
75
+ " depth = 6,\n",
76
+ " heads = 8,\n",
77
+ " dim_head = 64\n",
78
+ ")\n",
79
+ "\n",
80
+ "mulan = MuLaN(\n",
81
+ " audio_transformer = audio_transformer,\n",
82
+ " text_transformer = text_transformer\n",
83
+ ")\n",
84
+ "\n",
85
+ "quantizer = MuLaNEmbedQuantizer(\n",
86
+ " mulan = mulan, \n",
87
+ " conditioning_dims = (1024, 1024, 1024), \n",
88
+ " namespaces = ('semantic', 'coarse', 'fine')\n",
89
+ ")\n",
90
+ "wavs = torch.randn(2, 1024)\n",
91
+ "conds = quantizer(wavs = wavs, namespace = 'semantic')"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 4,
97
+ "metadata": {},
98
+ "outputs": [
99
+ {
100
+ "name": "stdout",
101
+ "output_type": "stream",
102
+ "text": [
103
+ "ANTLR runtime and generated code versions disagree: 4.9.3!=4.8\n",
104
+ "ANTLR runtime and generated code versions disagree: 4.9.3!=4.8\n",
105
+ "training with dataset of 4806 samples and validating with randomly splitted 253 samples\n",
106
+ "0: loss: 90.55248260498047\n",
107
+ "0: valid loss 28.765926361083984\n",
108
+ "0: saving model to results\n",
109
+ "1: loss: 39.71841812133789\n",
110
+ "2: loss: 89.22168731689453\n",
111
+ "3: loss: 64.72769927978516\n",
112
+ "4: loss: 46.61131286621094\n",
113
+ "5: loss: 71.61656951904297\n",
114
+ "6: loss: 51.03081130981445\n",
115
+ "7: loss: 41.790443420410156\n",
116
+ "8: loss: 53.92983627319336\n",
117
+ "9: loss: 34.468536376953125\n",
118
+ "10: loss: 33.230533599853516\n",
119
+ "11: loss: 39.82740020751953\n",
120
+ "12: loss: 25.284324645996094\n",
121
+ "13: loss: 28.97213363647461\n",
122
+ "14: loss: 30.330350875854492\n",
123
+ "15: loss: 29.048341751098633\n",
124
+ "16: loss: 22.92132568359375\n",
125
+ "17: loss: 19.784038543701172\n",
126
+ "18: loss: 24.917173385620117\n",
127
+ "19: loss: 21.861900329589844\n",
128
+ "20: loss: 21.64893913269043\n",
129
+ "21: loss: 19.426795959472656\n",
130
+ "22: loss: 16.47875213623047\n",
131
+ "23: loss: 14.150989532470703\n",
132
+ "24: loss: 16.4312686920166\n",
133
+ "25: loss: 10.732200622558594\n",
134
+ "26: loss: 9.64625358581543\n",
135
+ "27: loss: 13.40906047821045\n",
136
+ "28: loss: 8.942117691040039\n",
137
+ "29: loss: 14.944022178649902\n",
138
+ "30: loss: 17.149667739868164\n",
139
+ "31: loss: 8.965814590454102\n",
140
+ "32: loss: 10.492903709411621\n",
141
+ "33: loss: 11.236382484436035\n",
142
+ "34: loss: 10.356119155883789\n",
143
+ "35: loss: 9.816141128540039\n",
144
+ "36: loss: 11.789191246032715\n",
145
+ "37: loss: 10.450325012207031\n",
146
+ "38: loss: 18.911396026611328\n",
147
+ "39: loss: 8.278931617736816\n",
148
+ "40: loss: 10.884782791137695\n",
149
+ "41: loss: 8.885784149169922\n",
150
+ "42: loss: 9.226049423217773\n",
151
+ "43: loss: 10.362125396728516\n",
152
+ "44: loss: 4.0845770835876465\n",
153
+ "45: loss: 9.664544105529785\n",
154
+ "46: loss: 9.46312427520752\n",
155
+ "47: loss: 9.138323783874512\n",
156
+ "48: loss: 7.396448135375977\n",
157
+ "49: loss: 7.293612480163574\n",
158
+ "50: loss: 10.331693649291992\n",
159
+ "51: loss: 7.775559425354004\n",
160
+ "52: loss: 7.011277198791504\n",
161
+ "53: loss: 6.324047565460205\n",
162
+ "54: loss: 5.501199245452881\n",
163
+ "55: loss: 4.69442081451416\n",
164
+ "56: loss: 4.073971748352051\n",
165
+ "57: loss: 4.142904758453369\n",
166
+ "58: loss: 4.585968017578125\n",
167
+ "59: loss: 4.700481414794922\n",
168
+ "60: loss: 5.152374267578125\n",
169
+ "61: loss: 8.181085586547852\n",
170
+ "62: loss: 6.7371416091918945\n",
171
+ "63: loss: 10.67423152923584\n",
172
+ "64: loss: 5.926950454711914\n",
173
+ "65: loss: 5.470860004425049\n",
174
+ "66: loss: 4.630016803741455\n",
175
+ "67: loss: 5.366561412811279\n",
176
+ "68: loss: 11.271105766296387\n",
177
+ "69: loss: 6.516841411590576\n",
178
+ "70: loss: 7.9438066482543945\n",
179
+ "71: loss: 5.358776092529297\n",
180
+ "72: loss: 5.713461875915527\n",
181
+ "73: loss: 7.075550556182861\n",
182
+ "74: loss: 5.229584217071533\n",
183
+ "75: loss: 5.103419303894043\n",
184
+ "76: loss: 4.516308307647705\n",
185
+ "77: loss: 7.4682488441467285\n",
186
+ "78: loss: 7.275866508483887\n",
187
+ "79: loss: 5.846785545349121\n",
188
+ "80: loss: 5.688624382019043\n",
189
+ "81: loss: 5.150119781494141\n",
190
+ "82: loss: 4.671944618225098\n",
191
+ "83: loss: 8.293455123901367\n",
192
+ "84: loss: 7.202897071838379\n",
193
+ "85: loss: 4.38778018951416\n",
194
+ "86: loss: 4.410329818725586\n",
195
+ "87: loss: 4.341781139373779\n",
196
+ "88: loss: 4.000961780548096\n",
197
+ "89: loss: 4.009156703948975\n",
198
+ "90: loss: 3.562082052230835\n",
199
+ "91: loss: 3.641108989715576\n",
200
+ "92: loss: 5.916473388671875\n",
201
+ "93: loss: 4.046755790710449\n",
202
+ "94: loss: 6.699942111968994\n",
203
+ "95: loss: 6.139719009399414\n",
204
+ "96: loss: 10.71791934967041\n",
205
+ "97: loss: 4.094853401184082\n",
206
+ "98: loss: 6.08973503112793\n",
207
+ "99: loss: 9.11803150177002\n",
208
+ "100: loss: 8.486052513122559\n",
209
+ "100: valid loss 4.0021281242370605\n",
210
+ "101: loss: 4.0021281242370605\n",
211
+ "102: loss: 3.346961736679077\n",
212
+ "103: loss: 3.15854549407959\n",
213
+ "104: loss: 2.5357956886291504\n",
214
+ "105: loss: 5.492861270904541\n",
215
+ "106: loss: 2.7623958587646484\n",
216
+ "107: loss: 2.9482226371765137\n",
217
+ "108: loss: 6.3801493644714355\n",
218
+ "109: loss: 4.1293463706970215\n",
219
+ "110: loss: 3.566096067428589\n",
220
+ "111: loss: 3.569946527481079\n",
221
+ "112: loss: 3.762925624847412\n",
222
+ "113: loss: 6.147146701812744\n",
223
+ "114: loss: 5.933719635009766\n",
224
+ "115: loss: 6.800720691680908\n",
225
+ "116: loss: 2.86614990234375\n",
226
+ "117: loss: 3.0812878608703613\n",
227
+ "118: loss: 3.110222101211548\n",
228
+ "119: loss: 4.000320911407471\n",
229
+ "120: loss: 3.2422871589660645\n",
230
+ "121: loss: 3.7775020599365234\n",
231
+ "122: loss: 3.595900774002075\n",
232
+ "123: loss: 2.73819637298584\n",
233
+ "124: loss: 3.4981672763824463\n",
234
+ "125: loss: 5.3726325035095215\n",
235
+ "126: loss: 3.0014798641204834\n",
236
+ "127: loss: 3.5963802337646484\n",
237
+ "128: loss: 2.8306686878204346\n",
238
+ "129: loss: 2.5162878036499023\n",
239
+ "130: loss: 2.685560941696167\n",
240
+ "131: loss: 6.374442100524902\n",
241
+ "132: loss: 7.788975715637207\n",
242
+ "133: loss: 2.897576332092285\n",
243
+ "134: loss: 3.333127737045288\n",
244
+ "135: loss: 3.436774253845215\n",
245
+ "136: loss: 4.979071617126465\n",
246
+ "137: loss: 4.120012283325195\n",
247
+ "138: loss: 3.7855355739593506\n",
248
+ "139: loss: 4.324587345123291\n",
249
+ "140: loss: 3.4336843490600586\n",
250
+ "141: loss: 2.6801435947418213\n",
251
+ "142: loss: 3.359581470489502\n",
252
+ "143: loss: 5.4692182540893555\n",
253
+ "144: loss: 5.773078918457031\n",
254
+ "145: loss: 4.27987813949585\n",
255
+ "146: loss: 7.247451305389404\n",
256
+ "147: loss: 6.170166492462158\n",
257
+ "148: loss: 4.961609840393066\n",
258
+ "149: loss: 4.028770923614502\n",
259
+ "150: loss: 2.90120005607605\n",
260
+ "151: loss: 1.9893661737442017\n",
261
+ "152: loss: 1.652574062347412\n",
262
+ "153: loss: 2.374600887298584\n",
263
+ "154: loss: 2.1045265197753906\n",
264
+ "155: loss: 6.417508125305176\n",
265
+ "156: loss: 5.273669719696045\n",
266
+ "157: loss: 6.238985538482666\n",
267
+ "158: loss: 3.8025736808776855\n",
268
+ "159: loss: 6.6854705810546875\n",
269
+ "160: loss: 2.5476467609405518\n",
270
+ "161: loss: 6.810393810272217\n",
271
+ "162: loss: 2.2033159732818604\n",
272
+ "163: loss: 1.9863100051879883\n",
273
+ "164: loss: 4.976431369781494\n",
274
+ "165: loss: 3.899188756942749\n",
275
+ "166: loss: 4.68454647064209\n",
276
+ "167: loss: 2.4539690017700195\n",
277
+ "168: loss: 6.830282688140869\n",
278
+ "169: loss: 1.7942843437194824\n",
279
+ "170: loss: 1.242318868637085\n",
280
+ "171: loss: 5.012855052947998\n",
281
+ "172: loss: 1.6154134273529053\n",
282
+ "173: loss: 1.5895756483078003\n",
283
+ "174: loss: 5.240614891052246\n",
284
+ "175: loss: 1.8958660364151\n",
285
+ "176: loss: 2.1411402225494385\n",
286
+ "177: loss: 5.932228088378906\n",
287
+ "178: loss: 2.7539122104644775\n",
288
+ "179: loss: 6.218499660491943\n",
289
+ "180: loss: 2.991704225540161\n",
290
+ "181: loss: 3.378645896911621\n",
291
+ "182: loss: 2.719741106033325\n",
292
+ "183: loss: 2.5844321250915527\n",
293
+ "184: loss: 5.851257801055908\n",
294
+ "185: loss: 2.239989995956421\n",
295
+ "186: loss: 5.5589141845703125\n",
296
+ "187: loss: 3.11521053314209\n",
297
+ "188: loss: 2.5269265174865723\n",
298
+ "189: loss: 2.181260824203491\n",
299
+ "190: loss: 1.8941911458969116\n",
300
+ "191: loss: 5.106175422668457\n",
301
+ "192: loss: 3.5514838695526123\n",
302
+ "193: loss: 3.233003854751587\n",
303
+ "194: loss: 2.55694317817688\n",
304
+ "195: loss: 6.5134053230285645\n",
305
+ "196: loss: 6.311967372894287\n",
306
+ "197: loss: 2.3541362285614014\n",
307
+ "198: loss: 6.195401668548584\n",
308
+ "199: loss: 3.013007879257202\n",
309
+ "200: loss: 2.53104567527771\n",
310
+ "200: valid loss 1.895339846611023\n",
311
+ "201: loss: 7.572109699249268\n",
312
+ "202: loss: 1.946860909461975\n",
313
+ "203: loss: 1.6077873706817627\n",
314
+ "204: loss: 1.5050052404403687\n",
315
+ "205: loss: 1.1216596364974976\n",
316
+ "206: loss: 1.017206072807312\n",
317
+ "207: loss: 7.081823825836182\n",
318
+ "208: loss: 1.1608872413635254\n",
319
+ "209: loss: 0.728882908821106\n",
320
+ "210: loss: 0.514722466468811\n",
321
+ "211: loss: 0.6075964570045471\n",
322
+ "212: loss: 0.7593868970870972\n",
323
+ "213: loss: 0.6465023159980774\n",
324
+ "214: loss: 8.1160888671875\n",
325
+ "215: loss: 0.8256340622901917\n",
326
+ "216: loss: 0.5982277393341064\n",
327
+ "217: loss: 7.202335834503174\n",
328
+ "218: loss: 4.8967790603637695\n",
329
+ "219: loss: 2.037604331970215\n",
330
+ "220: loss: 1.7443571090698242\n",
331
+ "221: loss: 0.8838777542114258\n",
332
+ "222: loss: 0.7871264219284058\n",
333
+ "223: loss: 5.985363483428955\n",
334
+ "224: loss: 3.6808922290802\n",
335
+ "225: loss: 4.453125476837158\n",
336
+ "226: loss: 4.137350559234619\n",
337
+ "227: loss: 1.5606231689453125\n",
338
+ "228: loss: 5.764791488647461\n",
339
+ "229: loss: 1.2394036054611206\n",
340
+ "230: loss: 1.1438194513320923\n",
341
+ "231: loss: 0.5560073852539062\n",
342
+ "232: loss: 5.746810436248779\n",
343
+ "233: loss: 4.34252405166626\n",
344
+ "234: loss: 6.079676628112793\n",
345
+ "235: loss: 4.213600158691406\n",
346
+ "236: loss: 1.1661522388458252\n",
347
+ "237: loss: 7.770791053771973\n",
348
+ "238: loss: 3.6331183910369873\n",
349
+ "239: loss: 6.657710552215576\n",
350
+ "240: loss: 4.314018249511719\n",
351
+ "241: loss: 3.964081048965454\n",
352
+ "242: loss: 3.4643802642822266\n",
353
+ "243: loss: 3.2389814853668213\n",
354
+ "244: loss: 5.009263515472412\n",
355
+ "245: loss: 5.4173903465271\n",
356
+ "246: loss: 3.464853048324585\n",
357
+ "247: loss: 2.690930128097534\n",
358
+ "248: loss: 5.482550621032715\n",
359
+ "249: loss: 1.500435709953308\n",
360
+ "250: loss: 1.207865834236145\n",
361
+ "251: loss: 6.162202835083008\n",
362
+ "252: loss: 0.5159206986427307\n",
363
+ "253: loss: 0.352285772562027\n",
364
+ "254: loss: 0.28347644209861755\n",
365
+ "255: loss: 0.2998739182949066\n",
366
+ "256: loss: 7.412589073181152\n",
367
+ "257: loss: 1.0271281003952026\n",
368
+ "258: loss: 0.5622831583023071\n",
369
+ "259: loss: 6.975170135498047\n",
370
+ "260: loss: 0.050237879157066345\n",
371
+ "261: loss: 9.500787734985352\n",
372
+ "262: loss: 1.1100494861602783\n",
373
+ "263: loss: 10.5401029586792\n",
374
+ "264: loss: 7.637964725494385\n",
375
+ "265: loss: 1.5384433269500732\n",
376
+ "266: loss: 0.6748937368392944\n",
377
+ "267: loss: 0.38336750864982605\n",
378
+ "268: loss: 0.1832476705312729\n",
379
+ "269: loss: 7.080984115600586\n",
380
+ "270: loss: 6.806582927703857\n",
381
+ "271: loss: 6.216980457305908\n",
382
+ "272: loss: 8.122699737548828\n",
383
+ "273: loss: 2.344430685043335\n",
384
+ "274: loss: 5.185897350311279\n",
385
+ "275: loss: 5.136538982391357\n",
386
+ "276: loss: 4.847122669219971\n",
387
+ "277: loss: 3.447641372680664\n",
388
+ "278: loss: 1.9696052074432373\n",
389
+ "279: loss: 6.129249095916748\n",
390
+ "280: loss: 1.4744977951049805\n",
391
+ "281: loss: 4.836997032165527\n",
392
+ "282: loss: 4.361396789550781\n",
393
+ "283: loss: 4.975046157836914\n",
394
+ "284: loss: 5.6431074142456055\n",
395
+ "285: loss: 8.127538681030273\n",
396
+ "286: loss: 7.203218460083008\n",
397
+ "287: loss: 2.408040761947632\n",
398
+ "288: loss: 1.7607803344726562\n",
399
+ "289: loss: 1.1752283573150635\n",
400
+ "290: loss: 5.39897346496582\n",
401
+ "291: loss: 0.8753417134284973\n",
402
+ "292: loss: 6.104700088500977\n",
403
+ "293: loss: 0.8714774250984192\n",
404
+ "294: loss: 5.633414268493652\n",
405
+ "295: loss: 1.0734435319900513\n",
406
+ "296: loss: 0.5978174209594727\n",
407
+ "297: loss: 0.6240620613098145\n",
408
+ "298: loss: 0.3799970746040344\n",
409
+ "299: loss: 5.793654441833496\n",
410
+ "300: loss: 4.920631408691406\n",
411
+ "300: valid loss 0.5733768343925476\n",
412
+ "301: loss: 0.5733768343925476\n",
413
+ "302: loss: 0.35356906056404114\n",
414
+ "303: loss: 6.0288190841674805\n",
415
+ "304: loss: 0.17994554340839386\n",
416
+ "305: loss: 6.07096004486084\n",
417
+ "306: loss: 0.798763632774353\n",
418
+ "307: loss: 0.30721110105514526\n",
419
+ "308: loss: 0.35866862535476685\n",
420
+ "309: loss: 6.664376258850098\n",
421
+ "310: loss: 10.371112823486328\n",
422
+ "311: loss: 1.5442111492156982\n",
423
+ "312: loss: 0.5046924948692322\n",
424
+ "313: loss: 0.02138896845281124\n",
425
+ "314: loss: 11.088417053222656\n",
426
+ "315: loss: 0.2801823616027832\n",
427
+ "316: loss: 1.6325680017471313\n",
428
+ "317: loss: 1.042490005493164\n",
429
+ "318: loss: 0.19980621337890625\n",
430
+ "319: loss: 6.208798408508301\n",
431
+ "320: loss: 2.2923152446746826\n",
432
+ "321: loss: 1.5293265581130981\n",
433
+ "322: loss: 5.384918212890625\n",
434
+ "323: loss: 0.5806372165679932\n",
435
+ "324: loss: 0.11083264648914337\n",
436
+ "325: loss: 6.474861145019531\n",
437
+ "326: loss: 6.7361063957214355\n",
438
+ "327: loss: 6.07684850692749\n",
439
+ "328: loss: 0.1449495404958725\n",
440
+ "329: loss: 0.24492450058460236\n",
441
+ "330: loss: 0.0179277490824461\n",
442
+ "331: loss: 5.866001605987549\n",
443
+ "332: loss: 0.14012691378593445\n",
444
+ "333: loss: 0.14467062056064606\n",
445
+ "334: loss: 0.01395170483738184\n",
446
+ "335: loss: 0.04150881618261337\n",
447
+ "336: loss: 0.07648518681526184\n",
448
+ "337: loss: 9.367613792419434\n",
449
+ "338: loss: 8.372873306274414\n",
450
+ "339: loss: 0.6273093223571777\n",
451
+ "340: loss: 0.11360179632902145\n",
452
+ "341: loss: 0.02351052314043045\n",
453
+ "342: loss: 0.06904540210962296\n",
454
+ "343: loss: 0.02174321562051773\n",
455
+ "344: loss: 0.11702124029397964\n",
456
+ "345: loss: 0.061455100774765015\n",
457
+ "346: loss: 0.03193430230021477\n",
458
+ "347: loss: 0.33268794417381287\n",
459
+ "348: loss: 0.053275030106306076\n",
460
+ "349: loss: 0.009291582740843296\n",
461
+ "350: loss: 0.18401774764060974\n",
462
+ "351: loss: 0.30571281909942627\n",
463
+ "352: loss: 17.913070678710938\n",
464
+ "353: loss: 0.2126859426498413\n",
465
+ "354: loss: 0.6229326128959656\n",
466
+ "355: loss: 11.214807510375977\n",
467
+ "356: loss: 0.15888328850269318\n",
468
+ "357: loss: 0.662460446357727\n",
469
+ "358: loss: 7.345875263214111\n",
470
+ "359: loss: 7.803595066070557\n",
471
+ "360: loss: 1.2322083711624146\n",
472
+ "361: loss: 0.7014895081520081\n",
473
+ "362: loss: 0.10298460721969604\n",
474
+ "363: loss: 8.574231147766113\n",
475
+ "364: loss: 0.03108447603881359\n",
476
+ "365: loss: 0.6616091728210449\n",
477
+ "366: loss: 4.938299655914307\n",
478
+ "367: loss: 5.479018688201904\n",
479
+ "368: loss: 6.740688800811768\n",
480
+ "369: loss: 3.110865831375122\n",
481
+ "370: loss: 4.795236587524414\n",
482
+ "371: loss: 1.8502461910247803\n",
483
+ "372: loss: 3.737464427947998\n",
484
+ "373: loss: 1.9333598613739014\n",
485
+ "374: loss: 7.145735740661621\n",
486
+ "375: loss: 1.3372946977615356\n",
487
+ "376: loss: 5.683573246002197\n",
488
+ "377: loss: 1.204305648803711\n",
489
+ "378: loss: 0.9289284348487854\n",
490
+ "379: loss: 5.174688339233398\n",
491
+ "380: loss: 1.458616852760315\n",
492
+ "381: loss: 0.9457168579101562\n",
493
+ "382: loss: 0.4627819359302521\n",
494
+ "383: loss: 0.2658665180206299\n",
495
+ "384: loss: 4.429558753967285\n",
496
+ "385: loss: 1.2449607849121094\n",
497
+ "386: loss: 1.3288488388061523\n",
498
+ "387: loss: 6.628821849822998\n",
499
+ "388: loss: 0.4825551211833954\n",
500
+ "389: loss: 0.6510865688323975\n",
501
+ "390: loss: 0.36395493149757385\n",
502
+ "391: loss: 0.18036174774169922\n",
503
+ "392: loss: 0.3237663209438324\n",
504
+ "393: loss: 6.840792655944824\n",
505
+ "394: loss: 1.6587960720062256\n",
506
+ "395: loss: 7.458000659942627\n",
507
+ "396: loss: 0.8729283809661865\n",
508
+ "397: loss: 0.6731876134872437\n",
509
+ "398: loss: 0.1747300773859024\n",
510
+ "399: loss: 0.5882076621055603\n",
511
+ "400: loss: 0.6982569098472595\n",
512
+ "400: valid loss 0.4763210713863373\n",
513
+ "401: loss: 0.4763210713863373\n",
514
+ "402: loss: 0.46096739172935486\n",
515
+ "403: loss: 4.166454792022705\n",
516
+ "404: loss: 0.44991931319236755\n",
517
+ "405: loss: 4.830379009246826\n",
518
+ "406: loss: 0.5408239364624023\n",
519
+ "407: loss: 0.2607786953449249\n",
520
+ "408: loss: 0.13067474961280823\n",
521
+ "409: loss: 4.062631130218506\n",
522
+ "410: loss: 5.5028300285339355\n",
523
+ "411: loss: 1.2942296266555786\n",
524
+ "412: loss: 1.4390389919281006\n",
525
+ "413: loss: 5.374651908874512\n",
526
+ "414: loss: 1.2929461002349854\n",
527
+ "415: loss: 0.643798291683197\n",
528
+ "416: loss: 0.6353816986083984\n",
529
+ "417: loss: 5.8032636642456055\n",
530
+ "418: loss: 3.3737053871154785\n",
531
+ "419: loss: 1.8712362051010132\n",
532
+ "420: loss: 1.0622261762619019\n",
533
+ "421: loss: 0.8681365847587585\n",
534
+ "422: loss: 0.6761938333511353\n",
535
+ "423: loss: 4.074782371520996\n",
536
+ "424: loss: 0.4106965661048889\n"
537
+ ]
538
+ },
539
+ {
540
+ "ename": "KeyboardInterrupt",
541
+ "evalue": "",
542
+ "output_type": "error",
543
+ "traceback": [
544
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
545
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
546
+ "Cell \u001b[1;32mIn[4], line 49\u001b[0m\n\u001b[0;32m 46\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m coarse_transformer, trainer, wav2vec, soundstream\n\u001b[0;32m 47\u001b[0m gc\u001b[38;5;241m.\u001b[39mcollect()\n\u001b[1;32m---> 49\u001b[0m \u001b[43mtrain_coarse_transformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
547
+ "Cell \u001b[1;32mIn[4], line 43\u001b[0m, in \u001b[0;36mtrain_coarse_transformer\u001b[1;34m()\u001b[0m\n\u001b[0;32m 23\u001b[0m coarse_transformer \u001b[38;5;241m=\u001b[39m CoarseTransformer(\n\u001b[0;32m 24\u001b[0m num_semantic_tokens\u001b[38;5;241m=\u001b[39mwav2vec\u001b[38;5;241m.\u001b[39mcodebook_size,\n\u001b[0;32m 25\u001b[0m codebook_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1024\u001b[39m,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 29\u001b[0m audio_text_condition\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m 30\u001b[0m )\n\u001b[0;32m 32\u001b[0m trainer \u001b[38;5;241m=\u001b[39m CoarseTransformerTrainer(\n\u001b[0;32m 33\u001b[0m transformer\u001b[38;5;241m=\u001b[39mcoarse_transformer,\n\u001b[0;32m 34\u001b[0m codec\u001b[38;5;241m=\u001b[39msoundstream,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 40\u001b[0m num_train_steps\u001b[38;5;241m=\u001b[39mnum_train_steps\n\u001b[0;32m 41\u001b[0m )\n\u001b[1;32m---> 43\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 44\u001b[0m torch\u001b[38;5;241m.\u001b[39msave(coarse_transformer\u001b[38;5;241m.\u001b[39mstate_dict(), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcoarse_transformer.pth\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m 45\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msave coarse_transformer.pth\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
548
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\trainer.py:1302\u001b[0m, in \u001b[0;36mCoarseTransformerTrainer.train\u001b[1;34m(self, log_fn)\u001b[0m\n\u001b[0;32m 1299\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtrain\u001b[39m(\u001b[38;5;28mself\u001b[39m, log_fn \u001b[38;5;241m=\u001b[39m noop):\n\u001b[0;32m 1301\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msteps \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_train_steps:\n\u001b[1;32m-> 1302\u001b[0m logs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1303\u001b[0m log_fn(logs)\n\u001b[0;32m 1305\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprint(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtraining complete\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
549
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\trainer.py:1244\u001b[0m, in \u001b[0;36mCoarseTransformerTrainer.train_step\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 1238\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mautocast(), context():\n\u001b[0;32m 1239\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_wrapper(\n\u001b[0;32m 1240\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mdata_kwargs,\n\u001b[0;32m 1241\u001b[0m return_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m 1242\u001b[0m )\n\u001b[1;32m-> 1244\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m/\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgrad_accum_every\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1246\u001b[0m accum_log(logs, {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgrad_accum_every})\n\u001b[0;32m 1248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m exists(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_grad_norm):\n",
550
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\accelerate\\accelerator.py:2151\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[1;34m(self, loss, **kwargs)\u001b[0m\n\u001b[0;32m 2149\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlomo_backward(loss, learning_rate)\n\u001b[0;32m 2150\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 2151\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
551
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\_tensor.py:525\u001b[0m, in \u001b[0;36mTensor.backward\u001b[1;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[0;32m 515\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m 516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[0;32m 517\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[0;32m 518\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 523\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[0;32m 524\u001b[0m )\n\u001b[1;32m--> 525\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 526\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[0;32m 527\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
552
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\autograd\\__init__.py:267\u001b[0m, in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[0;32m 262\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[0;32m 264\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[0;32m 265\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[0;32m 266\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[1;32m--> 267\u001b[0m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 268\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
553
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\autograd\\graph.py:744\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[1;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[0;32m 742\u001b[0m unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[0;32m 743\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 744\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[0;32m 745\u001b[0m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[0;32m 746\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[0;32m 747\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[0;32m 748\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
554
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
555
+ ]
556
+ }
557
+ ],
558
+ "source": [
559
+ "def train_coarse_transformer():\n",
560
+ " wav2vec = HubertWithKmeans(\n",
561
+ " checkpoint_path=checkpoint_path,\n",
562
+ " kmeans_path=kmeans_path\n",
563
+ " )\n",
564
+ " soundstream = MusicLMSoundStream(\n",
565
+ " codebook_size=1024, # Add this line to specify the codebook size\n",
566
+ " strides=(3, 4, 5, 8),\n",
567
+ " target_sample_hz=24000,\n",
568
+ " rq_num_quantizers=8\n",
569
+ " )\n",
570
+ "\n",
571
+ " if torch.cuda.is_available():\n",
572
+ " coarse_transformer = CoarseTransformer(\n",
573
+ " num_semantic_tokens=wav2vec.codebook_size,\n",
574
+ " codebook_size=1024,\n",
575
+ " num_coarse_quantizers=4,\n",
576
+ " dim=1024,\n",
577
+ " depth=6,\n",
578
+ " audio_text_condition=True\n",
579
+ " ).cuda()\n",
580
+ " else:\n",
581
+ " coarse_transformer = CoarseTransformer(\n",
582
+ " num_semantic_tokens=wav2vec.codebook_size,\n",
583
+ " codebook_size=1024,\n",
584
+ " num_coarse_quantizers=4,\n",
585
+ " dim=1024,\n",
586
+ " depth=6,\n",
587
+ " audio_text_condition=True\n",
588
+ " )\n",
589
+ "\n",
590
+ " trainer = CoarseTransformerTrainer(\n",
591
+ " transformer=coarse_transformer,\n",
592
+ " codec=soundstream,\n",
593
+ " wav2vec=wav2vec,\n",
594
+ " audio_conditioner=quantizer,\n",
595
+ " folder=audio_output_dir,\n",
596
+ " batch_size=batch_size,\n",
597
+ " data_max_length=data_max_length,\n",
598
+ " num_train_steps=num_train_steps\n",
599
+ " )\n",
600
+ "\n",
601
+ " trainer.train()\n",
602
+ " torch.save(coarse_transformer.state_dict(), 'coarse_transformer.pth')\n",
603
+ " print(\"save coarse_transformer.pth\")\n",
604
+ " del coarse_transformer, trainer, wav2vec, soundstream\n",
605
+ " gc.collect()\n",
606
+ "\n",
607
+ "train_coarse_transformer()"
608
+ ]
609
+ }
610
+ ],
611
+ "metadata": {
612
+ "kernelspec": {
613
+ "display_name": "myenv",
614
+ "language": "python",
615
+ "name": "python3"
616
+ },
617
+ "language_info": {
618
+ "codemirror_mode": {
619
+ "name": "ipython",
620
+ "version": 3
621
+ },
622
+ "file_extension": ".py",
623
+ "mimetype": "text/x-python",
624
+ "name": "python",
625
+ "nbconvert_exporter": "python",
626
+ "pygments_lexer": "ipython3",
627
+ "version": "3.11.2"
628
+ }
629
+ },
630
+ "nbformat": 4,
631
+ "nbformat_minor": 2
632
+ }
fine_transformer.ipynb ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Fine Transformer"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "### Libraries:"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 1,
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "import torch\n",
24
+ "from audiolm_pytorch import HubertWithKmeans\n",
25
+ "from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer\n",
26
+ "from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer\n",
27
+ "from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerTrainer\n",
28
+ "from audiolm_pytorch import AudioLMSoundStream, AudioLM, MusicLMSoundStream\n",
29
+ "from musiclm_pytorch import MuLaNEmbedQuantizer\n",
30
+ "from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer\n",
31
+ "import gc\n",
32
+ "from nltk.tokenize import word_tokenize\n",
33
+ "import nltk\n",
34
+ "import librosa\n",
35
+ "import numpy as np\n",
36
+ "import pickle"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 2,
42
+ "metadata": {},
43
+ "outputs": [
44
+ {
45
+ "name": "stderr",
46
+ "output_type": "stream",
47
+ "text": [
48
+ "[nltk_data] Downloading package punkt to\n",
49
+ "[nltk_data] C:\\Users\\hp\\AppData\\Roaming\\nltk_data...\n",
50
+ "[nltk_data] Package punkt is already up-to-date!\n"
51
+ ]
52
+ }
53
+ ],
54
+ "source": [
55
+ "nltk.download('punkt')\n",
56
+ "checkpoint_path = './models/hubert/hubert_base_ls960.pt'\n",
57
+ "kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'\n",
58
+ "\n",
59
+ "audio_output_dir = './audio'\n",
60
+ "batch_size = 1\n",
61
+ "data_max_length = 320 * 32\n",
62
+ "num_train_steps = 1000"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": 3,
68
+ "metadata": {},
69
+ "outputs": [
70
+ {
71
+ "name": "stdout",
72
+ "output_type": "stream",
73
+ "text": [
74
+ "training with dataset of 4806 samples and validating with randomly splitted 253 samples\n",
75
+ "spectrogram yielded shape of (65, 841), but had to be cropped to (64, 832) to be patchified for transformer\n",
76
+ "0: loss: 103.04938507080078\n",
77
+ "0: valid loss 11.681041717529297\n",
78
+ "0: saving model to results\n",
79
+ "training complete\n",
80
+ "save fine_transformer.pth\n"
81
+ ]
82
+ }
83
+ ],
84
+ "source": [
85
+ "audio_transformer = AudioSpectrogramTransformer(\n",
86
+ " dim = 512,\n",
87
+ " depth = 6,\n",
88
+ " heads = 8,\n",
89
+ " dim_head = 64,\n",
90
+ " spec_n_fft = 128,\n",
91
+ " spec_win_length = 24,\n",
92
+ " spec_aug_stretch_factor = 0.8\n",
93
+ ")\n",
94
+ "\n",
95
+ "text_transformer = TextTransformer(\n",
96
+ " dim = 512,\n",
97
+ " depth = 6,\n",
98
+ " heads = 8,\n",
99
+ " dim_head = 64\n",
100
+ ")\n",
101
+ "\n",
102
+ "mulan = MuLaN(\n",
103
+ " audio_transformer = audio_transformer,\n",
104
+ " text_transformer = text_transformer\n",
105
+ ")\n",
106
+ "\n",
107
+ "quantizer = MuLaNEmbedQuantizer(\n",
108
+ " mulan = mulan, \n",
109
+ " conditioning_dims = (1024, 1024, 1024), \n",
110
+ " namespaces = ('semantic', 'coarse', 'fine')\n",
111
+ ")\n",
112
+ "\n",
113
+ "\n",
114
+ "def train_fine_transformer():\n",
115
+ " soundstream = MusicLMSoundStream(\n",
116
+ " codebook_size=1024, \n",
117
+ " strides=(3, 4, 5, 8),\n",
118
+ " target_sample_hz=24000,\n",
119
+ " rq_num_quantizers=8\n",
120
+ " )\n",
121
+ "\n",
122
+ " if torch.cuda.is_available():\n",
123
+ " fine_transformer = FineTransformer(\n",
124
+ " num_coarse_quantizers = 4,\n",
125
+ " num_fine_quantizers = 4,\n",
126
+ " codebook_size = 1024,\n",
127
+ " dim = 1024,\n",
128
+ " depth = 6,\n",
129
+ " audio_text_condition = True\n",
130
+ " ).cuda()\n",
131
+ " else:\n",
132
+ " fine_transformer = FineTransformer(\n",
133
+ " num_coarse_quantizers = 4,\n",
134
+ " num_fine_quantizers = 4,\n",
135
+ " codebook_size = 1024,\n",
136
+ " dim = 1024,\n",
137
+ " depth = 6,\n",
138
+ " audio_text_condition = True\n",
139
+ " )\n",
140
+ "\n",
141
+ " trainer = FineTransformerTrainer(\n",
142
+ " transformer=fine_transformer,\n",
143
+ " codec=soundstream,\n",
144
+ " folder=audio_output_dir,\n",
145
+ " batch_size=batch_size,\n",
146
+ " data_max_length=data_max_length,\n",
147
+ " num_train_steps=num_train_steps,\n",
148
+ " audio_conditioner = quantizer\n",
149
+ " )\n",
150
+ "\n",
151
+ " trainer.train()\n",
152
+ " torch.save(fine_transformer.state_dict(), 'fine_transformer.pth')\n",
153
+ " print(\"save fine_transformer.pth\")\n",
154
+ " del fine_transformer, trainer, soundstream\n",
155
+ " gc.collect()\n",
156
+ "\n",
157
+ "\n",
158
+ "train_fine_transformer()"
159
+ ]
160
+ }
161
+ ],
162
+ "metadata": {
163
+ "kernelspec": {
164
+ "display_name": "myenv",
165
+ "language": "python",
166
+ "name": "python3"
167
+ },
168
+ "language_info": {
169
+ "codemirror_mode": {
170
+ "name": "ipython",
171
+ "version": 3
172
+ },
173
+ "file_extension": ".py",
174
+ "mimetype": "text/x-python",
175
+ "name": "python",
176
+ "nbconvert_exporter": "python",
177
+ "pygments_lexer": "ipython3",
178
+ "version": "3.11.2"
179
+ }
180
+ },
181
+ "nbformat": 4,
182
+ "nbformat_minor": 2
183
+ }
musiclm.ipynb ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# AudioLM"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "### Libraries:"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 2,
20
+ "metadata": {},
21
+ "outputs": [
22
+ {
23
+ "name": "stderr",
24
+ "output_type": "stream",
25
+ "text": [
26
+ "2024-07-26 16:06:09 | WARNING | xformers | WARNING[XFORMERS]: xFormers can't load C++/CUDA extensions. xFormers was built for:\n",
27
+ " PyTorch 2.1.0+cu121 with CUDA 1201 (you have 2.1.0+cpu)\n",
28
+ " Python 3.11.6 (you have 3.11.2)\n",
29
+ " Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)\n",
30
+ " Memory-efficient attention, SwiGLU, sparse and more won't be available.\n",
31
+ " Set XFORMERS_MORE_DETAILS=1 for more details\n",
32
+ "2024-07-26 16:06:09 | WARNING | xformers | Triton is not available, some optimizations will not be enabled.\n",
33
+ "This is just a warning: triton is not available\n"
34
+ ]
35
+ }
36
+ ],
37
+ "source": [
38
+ "import torch\n",
39
+ "from audiolm_pytorch import HubertWithKmeans\n",
40
+ "from audiolm_pytorch import SemanticTransformer\n",
41
+ "from audiolm_pytorch import CoarseTransformer\n",
42
+ "from audiolm_pytorch import FineTransformer\n",
43
+ "from audiolm_pytorch import AudioLMSoundStream, AudioLM\n",
44
+ "from musiclm_pytorch import MuLaNEmbedQuantizer\n",
45
+ "from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": 3,
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "checkpoint_path = './models/hubert/hubert_base_ls960.pt'\n",
55
+ "kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 5,
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "wav2vec = HubertWithKmeans(checkpoint_path=checkpoint_path, kmeans_path=kmeans_path)\n",
65
+ "\n",
66
+ "soundstream = AudioLMSoundStream(\n",
67
+ " codebook_size=1024, # Add this line to specify the codebook size\n",
68
+ " strides=(2, 4, 5, 8),\n",
69
+ " target_sample_hz=16000,\n",
70
+ " rq_num_quantizers=8\n",
71
+ ")\n",
72
+ "\n",
73
+ "\n",
74
+ "if torch.cuda.is_available():\n",
75
+ " semantic_transformer = SemanticTransformer(\n",
76
+ " num_semantic_tokens=wav2vec.codebook_size,\n",
77
+ " dim=1024,\n",
78
+ " depth=6,\n",
79
+ " audio_text_condition=True\n",
80
+ " ).cuda()\n",
81
+ "\n",
82
+ " coarse_transformer = CoarseTransformer(\n",
83
+ " num_semantic_tokens=wav2vec.codebook_size,\n",
84
+ " codebook_size=1024,\n",
85
+ " num_coarse_quantizers=4, # Consistent with training\n",
86
+ " dim=1024,\n",
87
+ " depth=6,\n",
88
+ " audio_text_condition=True\n",
89
+ " ).cuda()\n",
90
+ "\n",
91
+ " fine_transformer = FineTransformer(\n",
92
+ " num_coarse_quantizers=4, # Consistent with training\n",
93
+ " num_fine_quantizers=4,\n",
94
+ " codebook_size=1024,\n",
95
+ " dim=1024,\n",
96
+ " depth=6,\n",
97
+ " audio_text_condition=True\n",
98
+ " ).cuda()\n",
99
+ "else:\n",
100
+ " semantic_transformer = SemanticTransformer(\n",
101
+ " num_semantic_tokens=wav2vec.codebook_size,\n",
102
+ " dim=1024,\n",
103
+ " depth=6,\n",
104
+ " audio_text_condition=True\n",
105
+ " )\n",
106
+ "\n",
107
+ " coarse_transformer = CoarseTransformer(\n",
108
+ " num_semantic_tokens=wav2vec.codebook_size,\n",
109
+ " codebook_size=1024,\n",
110
+ " num_coarse_quantizers=4, # Consistent with training\n",
111
+ " dim=1024,\n",
112
+ " depth=6,\n",
113
+ " audio_text_condition=True\n",
114
+ " )\n",
115
+ "\n",
116
+ " fine_transformer = FineTransformer(\n",
117
+ " num_coarse_quantizers=4, # Consistent with training\n",
118
+ " num_fine_quantizers=4,\n",
119
+ " codebook_size=1024,\n",
120
+ " dim=1024,\n",
121
+ " depth=6,\n",
122
+ " audio_text_condition=True\n",
123
+ " )\n",
124
+ "\n",
125
+ "semantic_transformer.load_state_dict(torch.load('semantic_transformer.pth'))\n",
126
+ "coarse_transformer.load_state_dict(torch.load('coarse_transformer.pth'))\n",
127
+ "fine_transformer.load_state_dict(torch.load('fine_transformer.pth'))\n",
128
+ "\n",
129
+ "audiolm = AudioLM(\n",
130
+ " wav2vec=wav2vec,\n",
131
+ " codec=soundstream,\n",
132
+ " semantic_transformer=semantic_transformer,\n",
133
+ " coarse_transformer=coarse_transformer,\n",
134
+ " fine_transformer=fine_transformer\n",
135
+ ")\n"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "markdown",
140
+ "metadata": {},
141
+ "source": [
142
+ "# MuLaN"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "code",
147
+ "execution_count": 6,
148
+ "metadata": {},
149
+ "outputs": [],
150
+ "source": [
151
+ "audio_transformer = AudioSpectrogramTransformer(\n",
152
+ " dim = 512,\n",
153
+ " depth = 6,\n",
154
+ " heads = 8,\n",
155
+ " dim_head = 64,\n",
156
+ " spec_n_fft = 128,\n",
157
+ " spec_win_length = 24,\n",
158
+ " spec_aug_stretch_factor = 0.8\n",
159
+ ")\n",
160
+ "\n",
161
+ "text_transformer = TextTransformer(\n",
162
+ " dim = 512,\n",
163
+ " depth = 6,\n",
164
+ " heads = 8,\n",
165
+ " dim_head = 64\n",
166
+ ")\n",
167
+ "\n",
168
+ "mulan = MuLaN(\n",
169
+ " audio_transformer = audio_transformer,\n",
170
+ " text_transformer = text_transformer\n",
171
+ ")\n",
172
+ "\n",
173
+ "quantizer = MuLaNEmbedQuantizer(\n",
174
+ " mulan = mulan, \n",
175
+ " conditioning_dims = (1024, 1024, 1024), \n",
176
+ " namespaces = ('semantic', 'coarse', 'fine')\n",
177
+ ")\n"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "markdown",
182
+ "metadata": {},
183
+ "source": [
184
+ "# MusicLM"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": 7,
190
+ "metadata": {},
191
+ "outputs": [],
192
+ "source": [
193
+ "from musiclm_pytorch import MusicLM\n",
194
+ "\n",
195
+ "if torch.cuda.is_available():\n",
196
+ " musiclm = MusicLM(\n",
197
+ " audio_lm = audiolm,\n",
198
+ " mulan_embed_quantizer = quantizer\n",
199
+ " ).cuda()\n",
200
+ "else:\n",
201
+ " musiclm = MusicLM(\n",
202
+ " audio_lm = audiolm,\n",
203
+ " mulan_embed_quantizer = quantizer\n",
204
+ " )"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "markdown",
209
+ "metadata": {},
210
+ "source": [
211
+ "# Inference:"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": 10,
217
+ "metadata": {},
218
+ "outputs": [
219
+ {
220
+ "name": "stdout",
221
+ "output_type": "stream",
222
+ "text": [
223
+ " 31 / 403\r"
224
+ ]
225
+ },
226
+ {
227
+ "ename": "KeyboardInterrupt",
228
+ "evalue": "",
229
+ "output_type": "error",
230
+ "traceback": [
231
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
232
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
233
+ "Cell \u001b[1;32mIn[10], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mMusiclm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\n\u001b[0;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mcrazy EDM, heavy bang\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 3\u001b[0m \u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[0;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mprogress\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 5\u001b[0m display_audio(res, \u001b[38;5;241m32000\u001b[39m)\n",
234
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\models\\genmodel.py:161\u001b[0m, in \u001b[0;36mBaseGenModel.generate\u001b[1;34m(self, descriptions, progress, return_tokens)\u001b[0m\n\u001b[0;32m 159\u001b[0m attributes, prompt_tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_prepare_tokens_and_attributes(descriptions, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m 160\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m prompt_tokens \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m--> 161\u001b[0m tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_generate_tokens\u001b[49m\u001b[43m(\u001b[49m\u001b[43mattributes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprompt_tokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprogress\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 162\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m return_tokens:\n\u001b[0;32m 163\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgenerate_audio(tokens), tokens\n",
235
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\models\\musicgen.py:256\u001b[0m, in \u001b[0;36mMusicGen._generate_tokens\u001b[1;34m(self, attributes, prompt_tokens, progress)\u001b[0m\n\u001b[0;32m 253\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mduration \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_duration:\n\u001b[0;32m 254\u001b[0m \u001b[38;5;66;03m# generate by sampling from LM, simple case.\u001b[39;00m\n\u001b[0;32m 255\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mautocast:\n\u001b[1;32m--> 256\u001b[0m gen_tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 257\u001b[0m \u001b[43m \u001b[49m\u001b[43mprompt_tokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattributes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 258\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallback\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_gen_len\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtotal_gen_len\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgeneration_params\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 260\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 261\u001b[0m \u001b[38;5;66;03m# now this gets a bit messier, we need to handle prompts,\u001b[39;00m\n\u001b[0;32m 262\u001b[0m \u001b[38;5;66;03m# melody conditioning etc.\u001b[39;00m\n\u001b[0;32m 263\u001b[0m ref_wavs \u001b[38;5;241m=\u001b[39m 
[attr\u001b[38;5;241m.\u001b[39mwav[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mself_wav\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m attr \u001b[38;5;129;01min\u001b[39;00m attributes]\n",
236
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\utils\\_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[0;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[1;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
237
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\models\\lm.py:510\u001b[0m, in \u001b[0;36mLMModel.generate\u001b[1;34m(self, prompt, conditions, num_samples, max_gen_len, use_sampling, temp, top_k, top_p, cfg_coef, two_step_cfg, remove_prompts, check, callback, **kwargs)\u001b[0m\n\u001b[0;32m 508\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (curr_sequence \u001b[38;5;241m==\u001b[39m unknown_token)\u001b[38;5;241m.\u001b[39many()\n\u001b[0;32m 509\u001b[0m \u001b[38;5;66;03m# sample next token from the model, next token shape is [B, K, 1]\u001b[39;00m\n\u001b[1;32m--> 510\u001b[0m next_token \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sample_next_token\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 511\u001b[0m \u001b[43m \u001b[49m\u001b[43mcurr_sequence\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcfg_conditions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43munconditional_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43muse_sampling\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtop_k\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtop_p\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 512\u001b[0m \u001b[43m \u001b[49m\u001b[43mcfg_coef\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcfg_coef\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtwo_step_cfg\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtwo_step_cfg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 513\u001b[0m \u001b[38;5;66;03m# ensure the tokens that should be masked are properly set to special_token_id\u001b[39;00m\n\u001b[0;32m 514\u001b[0m \u001b[38;5;66;03m# as the model never output special_token_id\u001b[39;00m\n\u001b[0;32m 515\u001b[0m valid_mask \u001b[38;5;241m=\u001b[39m 
mask[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, offset:offset\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mexpand(B, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n",
238
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\models\\lm.py:369\u001b[0m, in \u001b[0;36mLMModel._sample_next_token\u001b[1;34m(self, sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p, cfg_coef, two_step_cfg)\u001b[0m\n\u001b[0;32m 366\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m condition_tensors:\n\u001b[0;32m 367\u001b[0m \u001b[38;5;66;03m# Preparing for CFG, predicting both conditional and unconditional logits.\u001b[39;00m\n\u001b[0;32m 368\u001b[0m sequence \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat([sequence, sequence], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m--> 369\u001b[0m all_logits \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 370\u001b[0m \u001b[43m \u001b[49m\u001b[43msequence\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 371\u001b[0m \u001b[43m \u001b[49m\u001b[43mconditions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcondition_tensors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcondition_tensors\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 372\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m condition_tensors:\n\u001b[0;32m 373\u001b[0m cond_logits, uncond_logits \u001b[38;5;241m=\u001b[39m all_logits\u001b[38;5;241m.\u001b[39msplit(B, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m) \u001b[38;5;66;03m# [B, K, T, card]\u001b[39;00m\n",
239
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
240
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
241
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\models\\lm.py:257\u001b[0m, in \u001b[0;36mLMModel.forward\u001b[1;34m(self, sequence, conditions, condition_tensors, stage)\u001b[0m\n\u001b[0;32m 253\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m conditions, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mShouldn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt pass both conditions and condition_tensors.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 255\u001b[0m input_, cross_attention_input \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfuser(input_, condition_tensors)\n\u001b[1;32m--> 257\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcross_attention_src\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcross_attention_input\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 258\u001b[0m \u001b[43m \u001b[49m\u001b[43msrc_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattn_mask_per_stage\u001b[49m\u001b[43m[\u001b[49m\u001b[43mstage\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mstage\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m>\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 259\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mout_norm:\n\u001b[0;32m 260\u001b[0m out \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mout_norm(out)\n",
242
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
243
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
244
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\modules\\transformer.py:708\u001b[0m, in \u001b[0;36mStreamingTransformer.forward\u001b[1;34m(self, x, *args, **kwargs)\u001b[0m\n\u001b[0;32m 705\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpositional_scale \u001b[38;5;241m*\u001b[39m pos_emb\n\u001b[0;32m 707\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m layer \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[1;32m--> 708\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply_layer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlayer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 710\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_is_streaming:\n\u001b[0;32m 711\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_streaming_state[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124moffsets\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m offsets \u001b[38;5;241m+\u001b[39m T\n",
245
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\modules\\transformer.py:665\u001b[0m, in \u001b[0;36mStreamingTransformer._apply_layer\u001b[1;34m(self, layer, *args, **kwargs)\u001b[0m\n\u001b[0;32m 663\u001b[0m method \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcheckpointing\n\u001b[0;32m 664\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m method \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m--> 665\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mlayer\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 666\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m method \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtorch\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m 667\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m torch_checkpoint(layer, \u001b[38;5;241m*\u001b[39margs, use_reentrant\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
246
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
247
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
248
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\modules\\transformer.py:563\u001b[0m, in \u001b[0;36mStreamingTransformerLayer.forward\u001b[1;34m(self, src, src_mask, src_key_padding_mask, cross_attention_src)\u001b[0m\n\u001b[0;32m 559\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer_scale_1(\n\u001b[0;32m 560\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sa_block(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm1(x), src_mask, src_key_padding_mask))\n\u001b[0;32m 561\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cross_attention_src \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 562\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer_scale_cross(\n\u001b[1;32m--> 563\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_cross_attention_block\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 564\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnorm_cross\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcross_attention_src\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[0;32m 565\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer_scale_2(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ff_block(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm2(x)))\n\u001b[0;32m 566\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
249
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\modules\\transformer.py:546\u001b[0m, in \u001b[0;36mStreamingTransformerLayer._cross_attention_block\u001b[1;34m(self, src, cross_attention_src)\u001b[0m\n\u001b[0;32m 544\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcross_attention \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m 545\u001b[0m \u001b[38;5;66;03m# queries are from src, keys and values from cross_attention_src.\u001b[39;00m\n\u001b[1;32m--> 546\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcross_attention\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 547\u001b[0m \u001b[43m \u001b[49m\u001b[43msrc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcross_attention_src\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcross_attention_src\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mneed_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 548\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout_cross(x)\n",
250
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
251
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
252
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\modules\\transformer.py:356\u001b[0m, in \u001b[0;36mStreamingMultiheadAttention.forward\u001b[1;34m(self, query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights, is_causal)\u001b[0m\n\u001b[0;32m 354\u001b[0m q \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39mlinear(query, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_proj_weight[:dim], bias_q)\n\u001b[0;32m 355\u001b[0m \u001b[38;5;66;03m# todo: when streaming, we could actually save k, v and check the shape actually match.\u001b[39;00m\n\u001b[1;32m--> 356\u001b[0m k \u001b[38;5;241m=\u001b[39m \u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunctional\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43min_proj_weight\u001b[49m\u001b[43m[\u001b[49m\u001b[43mdim\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias_k\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 357\u001b[0m v \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39mlinear(value, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_proj_weight[\u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m dim:], bias_v)\n\u001b[0;32m 358\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mqk_layer_norm \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n",
253
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
254
+ ]
255
+ }
256
+ ],
257
+ "source": [
258
+ "music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples = 4)"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": null,
264
+ "metadata": {},
265
+ "outputs": [],
266
+ "source": [
267
+ "torch.save(music, 'generated_music.pt')"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": null,
273
+ "metadata": {},
274
+ "outputs": [],
275
+ "source": [
276
+ "import torchaudio\n",
277
+ "output_path = \"out.wav\"\n",
278
+ "sample_rate = 44100\n",
279
+ "torchaudio.save(output_path, music.cpu() , sample_rate)"
280
+ ]
281
+ }
282
+ ],
283
+ "metadata": {
284
+ "kernelspec": {
285
+ "display_name": "myenv",
286
+ "language": "python",
287
+ "name": "python3"
288
+ },
289
+ "language_info": {
290
+ "codemirror_mode": {
291
+ "name": "ipython",
292
+ "version": 3
293
+ },
294
+ "file_extension": ".py",
295
+ "mimetype": "text/x-python",
296
+ "name": "python",
297
+ "nbconvert_exporter": "python",
298
+ "pygments_lexer": "ipython3",
299
+ "version": "3.11.2"
300
+ }
301
+ },
302
+ "nbformat": 4,
303
+ "nbformat_minor": 2
304
+ }
semantic_transformer.ipynb ADDED
@@ -0,0 +1,851 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Semantic Transformer"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "### Libraries"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 1,
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "import torch\n",
24
+ "import multiprocessing\n",
25
+ "from audiolm_pytorch import HubertWithKmeans, MusicLMSoundStream\n",
26
+ "from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer\n",
27
+ "from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer\n",
28
+ "from audiolm_pytorch import FineTransformer, FineTransformerTrainer\n",
29
+ "from musiclm_pytorch import MuLaNEmbedQuantizer\n",
30
+ "from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer\n",
31
+ "import gc "
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": 2,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "checkpoint_path = './models/hubert/hubert_base_ls960.pt'\n",
41
+ "kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'\n",
42
+ "audio_output_dir = './audio'\n",
43
+ "batch_size = 1\n",
44
+ "data_max_length = 320 * 32\n",
45
+ "num_train_steps = 1000"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": 3,
51
+ "metadata": {},
52
+ "outputs": [
53
+ {
54
+ "name": "stdout",
55
+ "output_type": "stream",
56
+ "text": [
57
+ "spectrogram yielded shape of (65, 86), but had to be cropped to (64, 80) to be patchified for transformer\n",
58
+ "ANTLR runtime and generated code versions disagree: 4.9.3!=4.8\n",
59
+ "ANTLR runtime and generated code versions disagree: 4.9.3!=4.8\n",
60
+ "training with dataset of 4806 samples and validating with randomly splitted 253 samples\n",
61
+ "0: loss: 6.5572309494018555\n",
62
+ "0: valid loss 6.723005294799805\n",
63
+ "0: saving model to results\n",
64
+ "1: loss: 6.5375285148620605\n",
65
+ "2: loss: 5.515031337738037\n",
66
+ "3: loss: 0.6989991664886475\n",
67
+ "4: loss: 0.016623886302113533\n",
68
+ "5: loss: 6.3969268798828125\n",
69
+ "6: loss: 0.8643577098846436\n",
70
+ "7: loss: 0.008508207276463509\n",
71
+ "8: loss: 0.00020680516900029033\n",
72
+ "9: loss: 8.900370597839355\n",
73
+ "10: loss: 0.00010900969209615141\n",
74
+ "11: loss: 0.0001591881300555542\n",
75
+ "12: loss: 8.055902481079102\n",
76
+ "13: loss: 0.0009496303973719478\n",
77
+ "14: loss: 0.0027423782739788294\n",
78
+ "15: loss: 0.0009589337860234082\n",
79
+ "16: loss: 7.296541690826416\n",
80
+ "17: loss: 0.0005210856324993074\n",
81
+ "18: loss: 0.0008424322586506605\n",
82
+ "19: loss: 5.571179389953613\n",
83
+ "20: loss: 0.003094581188634038\n",
84
+ "21: loss: 0.0019461463671177626\n",
85
+ "22: loss: 5.488490104675293\n",
86
+ "23: loss: 4.800296783447266\n",
87
+ "24: loss: 4.962136268615723\n",
88
+ "25: loss: 5.943732738494873\n",
89
+ "26: loss: 0.006312617566436529\n",
90
+ "27: loss: 4.396454334259033\n",
91
+ "28: loss: 0.012498963624238968\n",
92
+ "29: loss: 0.0049488842487335205\n",
93
+ "30: loss: 0.0011625693878158927\n",
94
+ "31: loss: 3.445856809616089\n",
95
+ "32: loss: 0.000534387887455523\n",
96
+ "33: loss: 0.000711498549208045\n",
97
+ "34: loss: 0.0009514373959973454\n",
98
+ "35: loss: 0.001239188713952899\n",
99
+ "36: loss: 8.732012748718262\n",
100
+ "37: loss: 0.0009216524777002633\n",
101
+ "38: loss: 0.0006809335318394005\n",
102
+ "39: loss: 0.000797786982730031\n",
103
+ "40: loss: 4.916833400726318\n",
104
+ "41: loss: 0.0010107718408107758\n",
105
+ "42: loss: 0.0008451942121610045\n",
106
+ "43: loss: 3.160980701446533\n",
107
+ "44: loss: 0.0008387335110455751\n",
108
+ "45: loss: 0.0010360947344452143\n",
109
+ "46: loss: 0.001215349417179823\n",
110
+ "47: loss: 5.990973949432373\n",
111
+ "48: loss: 0.0017369053093716502\n",
112
+ "49: loss: 6.410669803619385\n",
113
+ "50: loss: 0.003450337564572692\n",
114
+ "51: loss: 0.003860922297462821\n",
115
+ "52: loss: 0.002359303878620267\n",
116
+ "53: loss: 0.001058467198163271\n",
117
+ "54: loss: 0.00047752217506058514\n",
118
+ "55: loss: 0.00025489379186183214\n",
119
+ "56: loss: 0.00016276698443107307\n",
120
+ "57: loss: 7.828070163726807\n",
121
+ "58: loss: 0.00011652028479147702\n",
122
+ "59: loss: 4.505963325500488\n",
123
+ "60: loss: 0.00013153781765140593\n",
124
+ "61: loss: 0.00015024915046524256\n",
125
+ "62: loss: 0.00017777853645384312\n",
126
+ "63: loss: 8.09732437133789\n",
127
+ "64: loss: 0.00041875039460137486\n",
128
+ "65: loss: 0.0009824583539739251\n",
129
+ "66: loss: 0.001990197692066431\n",
130
+ "67: loss: 5.392111778259277\n",
131
+ "68: loss: 0.0017270153621211648\n",
132
+ "69: loss: 0.0010434042196720839\n",
133
+ "70: loss: 0.0005951145431026816\n",
134
+ "71: loss: 0.00037293724017217755\n",
135
+ "72: loss: 0.00025969729176722467\n",
136
+ "73: loss: 7.013213157653809\n",
137
+ "74: loss: 3.807203531265259\n",
138
+ "75: loss: 0.00026780215557664633\n",
139
+ "76: loss: 0.00031897667213343084\n",
140
+ "77: loss: 0.0003657388442661613\n",
141
+ "78: loss: 5.076975345611572\n",
142
+ "79: loss: 0.001055362867191434\n",
143
+ "80: loss: 0.0010116726625710726\n",
144
+ "81: loss: 0.0017484871204942465\n",
145
+ "82: loss: 0.0018696936313062906\n",
146
+ "83: loss: 5.30266809463501\n",
147
+ "84: loss: 5.457505226135254\n",
148
+ "85: loss: 0.0012204349040985107\n",
149
+ "86: loss: 3.2936503887176514\n",
150
+ "87: loss: 0.0020471797324717045\n",
151
+ "88: loss: 0.0026046710554510355\n",
152
+ "89: loss: 0.0026721167378127575\n",
153
+ "90: loss: 0.0024667021352797747\n",
154
+ "91: loss: 5.0201215744018555\n",
155
+ "92: loss: 4.591504096984863\n",
156
+ "93: loss: 0.0025711969938129187\n",
157
+ "94: loss: 0.002706416416913271\n",
158
+ "95: loss: 0.0024713831953704357\n",
159
+ "96: loss: 0.002004373585805297\n",
160
+ "97: loss: 0.001489203074015677\n",
161
+ "98: loss: 0.0010426173685118556\n",
162
+ "99: loss: 8.796974182128906\n",
163
+ "100: loss: 0.0005365900578908622\n",
164
+ "100: valid loss 5.255128860473633\n",
165
+ "101: loss: 0.0004417159070726484\n",
166
+ "102: loss: 4.595282554626465\n",
167
+ "103: loss: 0.000659952696878463\n",
168
+ "104: loss: 0.0008260122267529368\n",
169
+ "105: loss: 0.0009083786280825734\n",
170
+ "106: loss: 4.042155742645264\n",
171
+ "107: loss: 4.17121696472168\n",
172
+ "108: loss: 0.0007671767962165177\n",
173
+ "109: loss: 4.022541522979736\n",
174
+ "110: loss: 3.5455234050750732\n",
175
+ "111: loss: 0.001035561435855925\n",
176
+ "112: loss: 0.0012967187212780118\n",
177
+ "113: loss: 7.237168312072754\n",
178
+ "114: loss: 3.522667407989502\n",
179
+ "115: loss: 0.004003542009741068\n",
180
+ "116: loss: 0.0040553268045187\n",
181
+ "117: loss: 0.0029700316954404116\n",
182
+ "118: loss: 0.0019125432008877397\n",
183
+ "119: loss: 3.4947195053100586\n",
184
+ "120: loss: 0.001095975050702691\n",
185
+ "121: loss: 0.0009612821158953011\n",
186
+ "122: loss: 0.000824352668132633\n",
187
+ "123: loss: 3.3077425956726074\n",
188
+ "124: loss: 0.0007418203167617321\n",
189
+ "125: loss: 0.0007488489500246942\n",
190
+ "126: loss: 0.0007235489320009947\n",
191
+ "127: loss: 3.426555633544922\n",
192
+ "128: loss: 0.0006980476318858564\n",
193
+ "129: loss: 0.0006986281368881464\n",
194
+ "130: loss: 0.0006706370622850955\n",
195
+ "131: loss: 0.0006185953388921916\n",
196
+ "132: loss: 4.421964645385742\n",
197
+ "133: loss: 0.0006264401017688215\n",
198
+ "134: loss: 0.0006876335828565061\n",
199
+ "135: loss: 0.0007215599762275815\n",
200
+ "136: loss: 0.0007203654968179762\n",
201
+ "137: loss: 0.0006922150496393442\n",
202
+ "138: loss: 0.0006356032681651413\n",
203
+ "139: loss: 3.7695367336273193\n",
204
+ "140: loss: 0.0006305422284640372\n",
205
+ "141: loss: 0.0006744156708009541\n",
206
+ "142: loss: 0.0006895355763845146\n",
207
+ "143: loss: 3.770907402038574\n",
208
+ "144: loss: 0.000908360059838742\n",
209
+ "145: loss: 0.0011299465550109744\n",
210
+ "146: loss: 0.0012696339981630445\n",
211
+ "147: loss: 0.0012722468236461282\n",
212
+ "148: loss: 3.8808021545410156\n",
213
+ "149: loss: 3.783026695251465\n",
214
+ "150: loss: 0.002035590121522546\n",
215
+ "151: loss: 0.0026034933980554342\n",
216
+ "152: loss: 0.0024936539120972157\n",
217
+ "153: loss: 0.0018582777120172977\n",
218
+ "154: loss: 2.8572535514831543\n",
219
+ "155: loss: 0.001062657218426466\n",
220
+ "156: loss: 0.0008821044466458261\n",
221
+ "157: loss: 0.0007058316841721535\n",
222
+ "158: loss: 0.0005539683043025434\n",
223
+ "159: loss: 5.476413726806641\n",
224
+ "160: loss: 0.00043070572428405285\n",
225
+ "161: loss: 0.00042034301441162825\n",
226
+ "162: loss: 0.0004015824815724045\n",
227
+ "163: loss: 0.0003759717510547489\n",
228
+ "164: loss: 0.00034577338374219835\n",
229
+ "165: loss: 3.9209775924682617\n",
230
+ "166: loss: 0.0003425567992962897\n",
231
+ "167: loss: 0.00036322552477940917\n",
232
+ "168: loss: 0.00037287475424818695\n",
233
+ "169: loss: 7.9045209884643555\n",
234
+ "170: loss: 0.0004473228473216295\n",
235
+ "171: loss: 0.0005134259699843824\n",
236
+ "172: loss: 2.9501657485961914\n",
237
+ "173: loss: 0.0008285943185910583\n",
238
+ "174: loss: 0.00113466486800462\n",
239
+ "175: loss: 0.0013167448341846466\n",
240
+ "176: loss: 0.0014080735854804516\n",
241
+ "177: loss: 4.0473408699035645\n",
242
+ "178: loss: 0.0016744763124734163\n",
243
+ "179: loss: 0.0016492144204676151\n",
244
+ "180: loss: 4.165207386016846\n",
245
+ "181: loss: 0.0017677460564300418\n",
246
+ "182: loss: 0.0018474040552973747\n",
247
+ "183: loss: 0.0017496442887932062\n",
248
+ "184: loss: 3.3882288932800293\n",
249
+ "185: loss: 0.0018872346263378859\n",
250
+ "186: loss: 3.0333187580108643\n",
251
+ "187: loss: 0.0028638774529099464\n",
252
+ "188: loss: 3.709534168243408\n",
253
+ "189: loss: 6.904417991638184\n",
254
+ "190: loss: 0.006619434338063002\n",
255
+ "191: loss: 0.00595641415566206\n",
256
+ "192: loss: 5.050801753997803\n",
257
+ "193: loss: 3.7556490898132324\n",
258
+ "194: loss: 0.002467694692313671\n",
259
+ "195: loss: 0.002025420544669032\n",
260
+ "196: loss: 0.001494809053838253\n",
261
+ "197: loss: 0.0010330628138035536\n",
262
+ "198: loss: 0.0006917425780557096\n",
263
+ "199: loss: 0.0004644835426006466\n",
264
+ "200: loss: 4.2029547691345215\n",
265
+ "200: valid loss 0.00025821171584539115\n",
266
+ "201: loss: 4.2771501541137695\n",
267
+ "202: loss: 3.7102839946746826\n",
268
+ "203: loss: 3.7408058643341064\n",
269
+ "204: loss: 0.0003981325135100633\n",
270
+ "205: loss: 0.0005507581518031657\n",
271
+ "206: loss: 4.065332889556885\n",
272
+ "207: loss: 0.0011804178357124329\n",
273
+ "208: loss: 0.0017080714460462332\n",
274
+ "209: loss: 0.0021062048617750406\n",
275
+ "210: loss: 0.0021494715474545956\n",
276
+ "211: loss: 6.465389251708984\n",
277
+ "212: loss: 0.0029505854472517967\n",
278
+ "213: loss: 3.367213010787964\n",
279
+ "214: loss: 0.27502918243408203\n",
280
+ "215: loss: 0.9933775663375854\n",
281
+ "216: loss: 0.810478925704956\n",
282
+ "217: loss: 0.4562891721725464\n",
283
+ "218: loss: 0.24387648701667786\n",
284
+ "219: loss: 0.11290910840034485\n",
285
+ "220: loss: 0.019248925149440765\n",
286
+ "221: loss: 0.0021138046868145466\n",
287
+ "222: loss: 5.169565200805664\n",
288
+ "223: loss: 0.0008601757581345737\n",
289
+ "224: loss: 3.9269232749938965\n",
290
+ "225: loss: 0.0007863161154091358\n",
291
+ "226: loss: 0.00024547570501454175\n",
292
+ "227: loss: 4.449281215667725\n",
293
+ "228: loss: 0.00019524114031810313\n",
294
+ "229: loss: 5.162830829620361\n",
295
+ "230: loss: 0.0005567128537222743\n",
296
+ "231: loss: 4.195521831512451\n",
297
+ "232: loss: 3.7389187812805176\n",
298
+ "233: loss: 5.919421672821045\n",
299
+ "234: loss: 6.7034173011779785\n",
300
+ "235: loss: 5.353506088256836\n",
301
+ "236: loss: 2.4018566608428955\n",
302
+ "237: loss: 3.7457311153411865\n",
303
+ "238: loss: 0.17652225494384766\n",
304
+ "239: loss: 4.564880847930908\n",
305
+ "240: loss: 0.027039170265197754\n",
306
+ "241: loss: 0.005270962603390217\n",
307
+ "242: loss: 0.0015485308831557631\n",
308
+ "243: loss: 0.0010360399028286338\n",
309
+ "244: loss: 0.0007773903198540211\n",
310
+ "245: loss: 6.206174850463867\n",
311
+ "246: loss: 6.409456253051758\n",
312
+ "247: loss: 0.04051050543785095\n",
313
+ "248: loss: 0.0017684113699942827\n",
314
+ "249: loss: 0.00044090740266256034\n",
315
+ "250: loss: 5.761023044586182\n",
316
+ "251: loss: 0.00016311556100845337\n",
317
+ "252: loss: 0.0001715785765554756\n",
318
+ "253: loss: 0.00019523760420270264\n",
319
+ "254: loss: 0.00023307953961193562\n",
320
+ "255: loss: 0.00028373271925374866\n",
321
+ "256: loss: 4.927147388458252\n",
322
+ "257: loss: 4.228280544281006\n",
323
+ "258: loss: 0.0011933923233300447\n",
324
+ "259: loss: 0.005215882323682308\n",
325
+ "260: loss: 0.0013388781808316708\n",
326
+ "261: loss: 4.206026554107666\n",
327
+ "262: loss: 0.0034830207005143166\n",
328
+ "263: loss: 4.173500061035156\n",
329
+ "264: loss: 0.007450783159583807\n",
330
+ "265: loss: 4.5892510414123535\n",
331
+ "266: loss: 0.006880312692373991\n",
332
+ "267: loss: 4.572935104370117\n",
333
+ "268: loss: 0.002904222346842289\n",
334
+ "269: loss: 3.2348222732543945\n",
335
+ "270: loss: 4.376621723175049\n",
336
+ "271: loss: 3.573988914489746\n",
337
+ "272: loss: 0.0010127610294148326\n",
338
+ "273: loss: 9.308874130249023\n",
339
+ "274: loss: 4.688360214233398\n",
340
+ "275: loss: 3.9581832885742188\n",
341
+ "276: loss: 0.01065391581505537\n",
342
+ "277: loss: 0.0067514535039663315\n",
343
+ "278: loss: 0.003611961379647255\n",
344
+ "279: loss: 0.001811509020626545\n",
345
+ "280: loss: 0.0009013370145112276\n",
346
+ "281: loss: 4.266546726226807\n",
347
+ "282: loss: 5.132745742797852\n",
348
+ "283: loss: 0.000957090116571635\n",
349
+ "284: loss: 0.0015025322791188955\n",
350
+ "285: loss: 6.258731842041016\n",
351
+ "286: loss: 5.029386043548584\n",
352
+ "287: loss: 0.007954631000757217\n",
353
+ "288: loss: 0.0050008054822683334\n",
354
+ "289: loss: 0.001655810745432973\n",
355
+ "290: loss: 5.501289367675781\n",
356
+ "291: loss: 4.655749797821045\n",
357
+ "292: loss: 4.383106231689453\n",
358
+ "293: loss: 0.000304496381431818\n",
359
+ "294: loss: 0.0003326725563965738\n",
360
+ "295: loss: 0.00035310350358486176\n",
361
+ "296: loss: 5.683162212371826\n",
362
+ "297: loss: 0.0004622728156391531\n",
363
+ "298: loss: 4.067113399505615\n",
364
+ "299: loss: 0.0008154112147167325\n",
365
+ "300: loss: 0.00108420941978693\n",
366
+ "300: valid loss 0.0013179074740037322\n",
367
+ "301: loss: 0.0013179074740037322\n",
368
+ "302: loss: 4.358561992645264\n",
369
+ "303: loss: 5.026749610900879\n",
370
+ "304: loss: 0.002862808993086219\n",
371
+ "305: loss: 0.003396229352802038\n",
372
+ "306: loss: 5.530904293060303\n",
373
+ "307: loss: 0.0035779180470854044\n",
374
+ "308: loss: 0.003205555956810713\n",
375
+ "309: loss: 4.112671852111816\n",
376
+ "310: loss: 3.6920313835144043\n",
377
+ "311: loss: 0.0026951604522764683\n",
378
+ "312: loss: 0.0026851999573409557\n",
379
+ "313: loss: 3.3092551231384277\n",
380
+ "314: loss: 0.0024079573340713978\n",
381
+ "315: loss: 0.0022026696242392063\n",
382
+ "316: loss: 0.0018284200923517346\n",
383
+ "317: loss: 0.0014258958399295807\n",
384
+ "318: loss: 0.0010761057492345572\n",
385
+ "319: loss: 0.0008039181702770293\n",
386
+ "320: loss: 0.0006038622814230621\n",
387
+ "321: loss: 0.00046244796249084175\n",
388
+ "322: loss: 5.89370059967041\n",
389
+ "323: loss: 0.00031747910543344915\n",
390
+ "324: loss: 0.00028221303364261985\n",
391
+ "325: loss: 0.00025451104738749564\n",
392
+ "326: loss: 0.00023175252135843039\n",
393
+ "327: loss: 0.00021364034910220653\n",
394
+ "328: loss: 3.906613826751709\n",
395
+ "329: loss: 3.844726085662842\n",
396
+ "330: loss: 0.00023705456987954676\n",
397
+ "331: loss: 0.0002663657069206238\n",
398
+ "332: loss: 0.0002947220054920763\n",
399
+ "333: loss: 6.28004264831543\n",
400
+ "334: loss: 0.0003821635036729276\n",
401
+ "335: loss: 3.633335828781128\n",
402
+ "336: loss: 0.0005681345355696976\n",
403
+ "337: loss: 6.994467735290527\n",
404
+ "338: loss: 7.915759086608887\n",
405
+ "339: loss: 0.0026061832904815674\n",
406
+ "340: loss: 0.0048998151905834675\n",
407
+ "341: loss: 0.004243680741637945\n",
408
+ "342: loss: 0.0025005636271089315\n",
409
+ "343: loss: 4.005818843841553\n",
410
+ "344: loss: 0.0011636920971795917\n",
411
+ "345: loss: 0.0009634271846152842\n",
412
+ "346: loss: 0.0008427661377936602\n",
413
+ "347: loss: 0.0007607618463225663\n",
414
+ "348: loss: 0.0006956492434255779\n",
415
+ "349: loss: 4.547393798828125\n",
416
+ "350: loss: 0.0006480301963165402\n",
417
+ "351: loss: 0.0006520788883790374\n",
418
+ "352: loss: 0.0006446384941227734\n",
419
+ "353: loss: 4.283820629119873\n",
420
+ "354: loss: 0.0007140468223951757\n",
421
+ "355: loss: 0.000788742327131331\n",
422
+ "356: loss: 0.0008332571596838534\n",
423
+ "357: loss: 0.0008390303701162338\n",
424
+ "358: loss: 0.000806896947324276\n",
425
+ "359: loss: 4.646646976470947\n",
426
+ "360: loss: 0.0021708165295422077\n",
427
+ "361: loss: 0.0009108624653890729\n",
428
+ "362: loss: 3.9582133293151855\n",
429
+ "363: loss: 3.3569955825805664\n",
430
+ "364: loss: 0.002499263733625412\n",
431
+ "365: loss: 4.646510601043701\n",
432
+ "366: loss: 0.0032457842025905848\n",
433
+ "367: loss: 0.0033331059385091066\n",
434
+ "368: loss: 0.00275675137527287\n",
435
+ "369: loss: 0.0020243506878614426\n",
436
+ "370: loss: 4.458893775939941\n",
437
+ "371: loss: 5.930361270904541\n",
438
+ "372: loss: 4.287806510925293\n",
439
+ "373: loss: 3.365216016769409\n",
440
+ "374: loss: 0.011499284766614437\n",
441
+ "375: loss: 0.0031067240051925182\n",
442
+ "376: loss: 0.003569819498807192\n",
443
+ "377: loss: 0.0032246895134449005\n",
444
+ "378: loss: 0.0023426800034940243\n",
445
+ "379: loss: 0.0016774036921560764\n",
446
+ "380: loss: 0.0010665183654055\n",
447
+ "381: loss: 0.0007539619691669941\n",
448
+ "382: loss: 3.873556137084961\n",
449
+ "383: loss: 0.08063449710607529\n",
450
+ "384: loss: 0.0005400768714025617\n",
451
+ "385: loss: 0.000518861401360482\n",
452
+ "386: loss: 0.00048329788842238486\n",
453
+ "387: loss: 4.2107648849487305\n",
454
+ "388: loss: 4.465734481811523\n",
455
+ "389: loss: 0.000529197626747191\n",
456
+ "390: loss: 3.872891664505005\n",
457
+ "391: loss: 5.214785099029541\n",
458
+ "392: loss: 4.345657825469971\n",
459
+ "393: loss: 0.0016826370265334845\n",
460
+ "394: loss: 0.0024580529425293207\n",
461
+ "395: loss: 0.002994671929627657\n",
462
+ "396: loss: 0.002981696743518114\n",
463
+ "397: loss: 0.002537172520533204\n",
464
+ "398: loss: 0.001975367311388254\n",
465
+ "399: loss: 0.0014994062948971987\n",
466
+ "400: loss: 0.0011500928085297346\n",
467
+ "400: valid loss 0.0009022268350236118\n",
468
+ "401: loss: 5.212808132171631\n",
469
+ "402: loss: 0.0008533270447514951\n",
470
+ "403: loss: 0.0008498210809193552\n",
471
+ "404: loss: 0.0008541711140424013\n",
472
+ "405: loss: 3.912627696990967\n",
473
+ "406: loss: 0.0008917151135392487\n",
474
+ "407: loss: 0.0009278871002607048\n",
475
+ "408: loss: 3.4623196125030518\n",
476
+ "409: loss: 0.0011483340058475733\n",
477
+ "410: loss: 0.0014651089441031218\n",
478
+ "411: loss: 3.501060962677002\n",
479
+ "412: loss: 4.905694484710693\n",
480
+ "413: loss: 0.0025538327172398567\n",
481
+ "414: loss: 0.0019650040194392204\n",
482
+ "415: loss: 0.001453581964597106\n",
483
+ "416: loss: 4.282127857208252\n",
484
+ "417: loss: 0.001117513864301145\n",
485
+ "418: loss: 3.2745401859283447\n",
486
+ "419: loss: 3.0665171146392822\n",
487
+ "420: loss: 0.001583368401043117\n",
488
+ "421: loss: 0.0018978181760758162\n",
489
+ "422: loss: 5.070369720458984\n",
490
+ "423: loss: 0.0025998111814260483\n",
491
+ "424: loss: 0.0028609540313482285\n",
492
+ "425: loss: 2.7316229343414307\n",
493
+ "426: loss: 0.003324385266751051\n",
494
+ "427: loss: 0.00243724649772048\n",
495
+ "428: loss: 0.0020084292627871037\n",
496
+ "429: loss: 0.001639676047489047\n",
497
+ "430: loss: 0.0012756038922816515\n",
498
+ "431: loss: 0.0010202551493421197\n",
499
+ "432: loss: 0.0008382818195968866\n",
500
+ "433: loss: 3.9101459980010986\n",
501
+ "434: loss: 3.4464950561523438\n",
502
+ "435: loss: 4.598957538604736\n",
503
+ "436: loss: 6.656869888305664\n",
504
+ "437: loss: 2.557544469833374\n",
505
+ "438: loss: 1.769715666770935\n",
506
+ "439: loss: 0.8786362409591675\n",
507
+ "440: loss: 0.09529905021190643\n",
508
+ "441: loss: 3.9526867866516113\n",
509
+ "442: loss: 3.4567954540252686\n",
510
+ "443: loss: 0.28547608852386475\n",
511
+ "444: loss: 0.1331639289855957\n",
512
+ "445: loss: 0.01748904585838318\n",
513
+ "446: loss: 3.7364015579223633\n",
514
+ "447: loss: 1.6454107761383057\n",
515
+ "448: loss: 0.007931341417133808\n",
516
+ "449: loss: 0.0017749288817867637\n",
517
+ "450: loss: 3.6518070697784424\n",
518
+ "451: loss: 3.056483507156372\n",
519
+ "452: loss: 0.0008364453678950667\n",
520
+ "453: loss: 0.0009152528364211321\n",
521
+ "454: loss: 0.0009797721868380904\n",
522
+ "455: loss: 4.194733142852783\n",
523
+ "456: loss: 0.0013897174503654242\n",
524
+ "457: loss: 0.0018761098617687821\n",
525
+ "458: loss: 0.0020015202462673187\n",
526
+ "459: loss: 9.263550758361816\n",
527
+ "460: loss: 0.0025061527267098427\n",
528
+ "461: loss: 0.003998400643467903\n",
529
+ "462: loss: 0.0031979954801499844\n",
530
+ "463: loss: 0.0009064731420949101\n",
531
+ "464: loss: 3.1668450832366943\n",
532
+ "465: loss: 6.006053924560547\n",
533
+ "466: loss: 0.0006406777538359165\n",
534
+ "467: loss: 0.0009267539135180414\n",
535
+ "468: loss: 0.0012060123262926936\n",
536
+ "469: loss: 0.0013315295800566673\n",
537
+ "470: loss: 3.5539376735687256\n",
538
+ "471: loss: 3.4590916633605957\n",
539
+ "472: loss: 0.0017678193980827928\n",
540
+ "473: loss: 0.00218581547960639\n",
541
+ "474: loss: 0.0025737383402884007\n",
542
+ "475: loss: 2.97592830657959\n",
543
+ "476: loss: 0.0032222135923802853\n",
544
+ "477: loss: 0.0020487091969698668\n",
545
+ "478: loss: 3.0420033931732178\n",
546
+ "479: loss: 0.001554043497890234\n",
547
+ "480: loss: 0.001528518507257104\n",
548
+ "481: loss: 0.001422215485945344\n",
549
+ "482: loss: 0.0012641653884202242\n",
550
+ "483: loss: 0.0010866222437471151\n",
551
+ "484: loss: 7.149199962615967\n",
552
+ "485: loss: 0.0010687584290280938\n",
553
+ "486: loss: 0.0012197017204016447\n",
554
+ "487: loss: 0.001343191834166646\n",
555
+ "488: loss: 0.0013996028574183583\n",
556
+ "489: loss: 0.001371717662550509\n",
557
+ "490: loss: 3.68569278717041\n",
558
+ "491: loss: 0.0014253916451707482\n",
559
+ "492: loss: 0.001504680491052568\n",
560
+ "493: loss: 0.0014929386088624597\n",
561
+ "494: loss: 0.0013759569264948368\n",
562
+ "495: loss: 3.385620355606079\n",
563
+ "496: loss: 0.0012212302535772324\n",
564
+ "497: loss: 0.0011952322674915195\n",
565
+ "498: loss: 3.1083197593688965\n",
566
+ "499: loss: 8.146794319152832\n",
567
+ "500: loss: 3.8151681423187256\n",
568
+ "500: valid loss 3.2241313457489014\n",
569
+ "501: loss: 0.002565972041338682\n",
570
+ "502: loss: 4.1275224685668945\n",
571
+ "503: loss: 0.004586916882544756\n",
572
+ "504: loss: 3.6200292110443115\n",
573
+ "505: loss: 0.004917770624160767\n",
574
+ "506: loss: 0.0035543786361813545\n",
575
+ "507: loss: 0.002198878675699234\n",
576
+ "508: loss: 3.9696688652038574\n",
577
+ "509: loss: 0.0012150105321779847\n",
578
+ "510: loss: 3.0237858295440674\n",
579
+ "511: loss: 0.0016711285570636392\n",
580
+ "512: loss: 0.0017911652103066444\n",
581
+ "513: loss: 0.001645330572500825\n",
582
+ "514: loss: 3.3689823150634766\n",
583
+ "515: loss: 0.0014145843451842666\n",
584
+ "516: loss: 0.0013438486494123936\n",
585
+ "517: loss: 0.0011701782932505012\n",
586
+ "518: loss: 0.0009688445716165006\n",
587
+ "519: loss: 0.0007915324531495571\n",
588
+ "520: loss: 4.113221645355225\n",
589
+ "521: loss: 0.0006360645638778806\n",
590
+ "522: loss: 0.0006149905384518206\n",
591
+ "523: loss: 8.360527038574219\n",
592
+ "524: loss: 0.0006234433385543525\n",
593
+ "525: loss: 0.0006739232921972871\n",
594
+ "526: loss: 0.0007281479192897677\n",
595
+ "527: loss: 0.000767726160120219\n",
596
+ "528: loss: 0.000772368221078068\n",
597
+ "529: loss: 0.0007228502072393894\n",
598
+ "530: loss: 0.0006368369213305414\n",
599
+ "531: loss: 3.732311725616455\n",
600
+ "532: loss: 5.932078838348389\n",
601
+ "533: loss: 3.5892159938812256\n",
602
+ "534: loss: 5.249965667724609\n",
603
+ "535: loss: 7.211183071136475\n",
604
+ "536: loss: 4.0714263916015625\n",
605
+ "537: loss: 3.1499719619750977\n",
606
+ "538: loss: 0.1844794750213623\n",
607
+ "539: loss: 3.4192230701446533\n",
608
+ "540: loss: 0.011980107054114342\n",
609
+ "541: loss: 0.010612019337713718\n",
610
+ "542: loss: 0.0045662750490009785\n",
611
+ "543: loss: 0.005457601509988308\n",
612
+ "544: loss: 0.015783555805683136\n",
613
+ "545: loss: 0.0013816619757562876\n",
614
+ "546: loss: 8.18481731414795\n",
615
+ "547: loss: 0.0006438567652367055\n",
616
+ "548: loss: 0.000572906865272671\n",
617
+ "549: loss: 10.10994815826416\n",
618
+ "550: loss: 0.003346000798046589\n",
619
+ "551: loss: 0.0006713962065987289\n",
620
+ "552: loss: 0.00026078836526721716\n",
621
+ "553: loss: 11.756505012512207\n",
622
+ "554: loss: 7.101832389831543\n",
623
+ "555: loss: 0.00021459207346197218\n",
624
+ "556: loss: 0.00025998923229053617\n",
625
+ "557: loss: 0.0003112201811745763\n",
626
+ "558: loss: 14.851192474365234\n",
627
+ "559: loss: 0.0004224810691084713\n",
628
+ "560: loss: 0.00047494613681919873\n",
629
+ "561: loss: 0.000519308028742671\n",
630
+ "562: loss: 0.0005509845213964581\n",
631
+ "563: loss: 0.0005668219528160989\n",
632
+ "564: loss: 14.569344520568848\n",
633
+ "565: loss: 6.4913740158081055\n",
634
+ "566: loss: 0.0008433411712758243\n",
635
+ "567: loss: 8.495502471923828\n",
636
+ "568: loss: 0.0019402098841965199\n",
637
+ "569: loss: 0.0035519124940037727\n",
638
+ "570: loss: 0.006841914728283882\n",
639
+ "571: loss: 4.089066982269287\n",
640
+ "572: loss: 5.491721153259277\n",
641
+ "573: loss: 3.87937331199646\n",
642
+ "574: loss: 0.03460773825645447\n",
643
+ "575: loss: 0.015647828578948975\n",
644
+ "576: loss: 0.002720448188483715\n",
645
+ "577: loss: 6.188972473144531\n",
646
+ "578: loss: 0.0008381525985896587\n",
647
+ "579: loss: 0.0008579537970945239\n",
648
+ "580: loss: 0.0008331844583153725\n",
649
+ "581: loss: 7.444668769836426\n",
650
+ "582: loss: 0.0013645365834236145\n",
651
+ "583: loss: 0.0018909723730757833\n",
652
+ "584: loss: 4.148159503936768\n",
653
+ "585: loss: 6.465692043304443\n",
654
+ "586: loss: 0.0040971520356833935\n",
655
+ "587: loss: 0.015496809035539627\n",
656
+ "588: loss: 0.0011185817420482635\n",
657
+ "589: loss: 0.00048535081441514194\n",
658
+ "590: loss: 0.0002821610542014241\n",
659
+ "591: loss: 0.00022055530280340463\n",
660
+ "592: loss: 0.0002070294285658747\n",
661
+ "593: loss: 0.00021876658138353378\n",
662
+ "594: loss: 0.00024527875939384103\n",
663
+ "595: loss: 0.00028197691426612437\n",
664
+ "596: loss: 0.00031235843198373914\n",
665
+ "597: loss: 0.00032129406463354826\n",
666
+ "598: loss: 0.000305092049529776\n",
667
+ "599: loss: 6.581624507904053\n",
668
+ "600: loss: 0.0004181505355518311\n",
669
+ "600: valid loss 0.001562803634442389\n",
670
+ "601: loss: 0.001562803634442389\n",
671
+ "602: loss: 0.0008329854463227093\n",
672
+ "603: loss: 8.43118953704834\n",
673
+ "604: loss: 0.00018880203424487263\n",
674
+ "605: loss: 6.225329399108887\n",
675
+ "606: loss: 0.0001953585451701656\n",
676
+ "607: loss: 0.00031005332130007446\n",
677
+ "608: loss: 6.243394374847412\n",
678
+ "609: loss: 0.002007008297368884\n",
679
+ "610: loss: 0.2842656672000885\n",
680
+ "611: loss: 0.002102950122207403\n",
681
+ "612: loss: 0.0013235295191407204\n",
682
+ "613: loss: 0.0012432391522452235\n",
683
+ "614: loss: 0.0011076040100306273\n",
684
+ "615: loss: 0.0009366637095808983\n",
685
+ "616: loss: 0.0007713991799391806\n",
686
+ "617: loss: 0.0006266268319450319\n",
687
+ "618: loss: 0.0005072436179034412\n",
688
+ "619: loss: 0.00041213506483472884\n",
689
+ "620: loss: 0.0003370844351593405\n",
690
+ "621: loss: 0.0002783465606626123\n",
691
+ "622: loss: 6.750359535217285\n",
692
+ "623: loss: 4.032569408416748\n",
693
+ "624: loss: 4.749107360839844\n",
694
+ "625: loss: 5.599199295043945\n",
695
+ "626: loss: 4.851316452026367\n",
696
+ "627: loss: 0.0012356003280729055\n",
697
+ "628: loss: 0.0019876735750585794\n",
698
+ "629: loss: 0.0022025934886187315\n",
699
+ "630: loss: 0.09389199316501617\n",
700
+ "631: loss: 0.0011942394776269794\n",
701
+ "632: loss: 0.0008771757711656392\n",
702
+ "633: loss: 0.000724500569049269\n",
703
+ "634: loss: 4.850365161895752\n",
704
+ "635: loss: 6.96458101272583\n",
705
+ "636: loss: 3.944305658340454\n",
706
+ "637: loss: 1.573992133140564\n",
707
+ "638: loss: 0.006376080680638552\n",
708
+ "639: loss: 0.004621799103915691\n",
709
+ "640: loss: 0.008686978369951248\n",
710
+ "641: loss: 0.002786734839901328\n",
711
+ "642: loss: 0.0012673415476456285\n",
712
+ "643: loss: 0.0008905518334358931\n"
713
+ ]
714
+ },
715
+ {
716
+ "ename": "KeyboardInterrupt",
717
+ "evalue": "",
718
+ "output_type": "error",
719
+ "traceback": [
720
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
721
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
722
+ "Cell \u001b[1;32mIn[3], line 78\u001b[0m\n\u001b[0;32m 72\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m semantic_transformer, trainer, wav2vec\n\u001b[0;32m 73\u001b[0m gc\u001b[38;5;241m.\u001b[39mcollect()\n\u001b[1;32m---> 78\u001b[0m \u001b[43mtrain_semantic_transformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
723
+ "Cell \u001b[1;32mIn[3], line 69\u001b[0m, in \u001b[0;36mtrain_semantic_transformer\u001b[1;34m()\u001b[0m\n\u001b[0;32m 52\u001b[0m semantic_transformer \u001b[38;5;241m=\u001b[39m SemanticTransformer(\n\u001b[0;32m 53\u001b[0m num_semantic_tokens\u001b[38;5;241m=\u001b[39mwav2vec\u001b[38;5;241m.\u001b[39mcodebook_size,\n\u001b[0;32m 54\u001b[0m dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1024\u001b[39m,\n\u001b[0;32m 55\u001b[0m depth\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m6\u001b[39m,\n\u001b[0;32m 56\u001b[0m audio_text_condition\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m 57\u001b[0m )\n\u001b[0;32m 59\u001b[0m trainer \u001b[38;5;241m=\u001b[39m SemanticTransformerTrainer(\n\u001b[0;32m 60\u001b[0m transformer\u001b[38;5;241m=\u001b[39msemantic_transformer,\n\u001b[0;32m 61\u001b[0m wav2vec\u001b[38;5;241m=\u001b[39mwav2vec,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 66\u001b[0m num_train_steps\u001b[38;5;241m=\u001b[39mnum_train_steps\n\u001b[0;32m 67\u001b[0m )\n\u001b[1;32m---> 69\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 70\u001b[0m torch\u001b[38;5;241m.\u001b[39msave(semantic_transformer\u001b[38;5;241m.\u001b[39mstate_dict(), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msemantic_transformer.pth\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m 71\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msave semantic_transformer.pth\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
724
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\trainer.py:1000\u001b[0m, in \u001b[0;36mSemanticTransformerTrainer.train\u001b[1;34m(self, log_fn)\u001b[0m\n\u001b[0;32m 997\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtrain\u001b[39m(\u001b[38;5;28mself\u001b[39m, log_fn \u001b[38;5;241m=\u001b[39m noop):\n\u001b[0;32m 999\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msteps \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_train_steps:\n\u001b[1;32m-> 1000\u001b[0m logs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1001\u001b[0m log_fn(logs)\n\u001b[0;32m 1003\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprint(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtraining complete\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
725
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\trainer.py:944\u001b[0m, in \u001b[0;36mSemanticTransformerTrainer.train_step\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 941\u001b[0m data_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdata_tuple_to_kwargs(\u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdl_iter))\n\u001b[0;32m 943\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mautocast(), context():\n\u001b[1;32m--> 944\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_wrapper\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdata_kwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_loss\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 946\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mbackward(loss \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgrad_accum_every)\n\u001b[0;32m 948\u001b[0m accum_log(logs, {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgrad_accum_every})\n",
726
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
727
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
728
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\audiolm_pytorch.py:1480\u001b[0m, in \u001b[0;36mSemanticTransformerWrapper.forward\u001b[1;34m(self, semantic_token_ids, raw_wave, text, text_embeds, return_loss, **kwargs)\u001b[0m\n\u001b[0;32m 1478\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m exists(raw_wave)\n\u001b[0;32m 1479\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exists(text) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exists(text_embeds)\n\u001b[1;32m-> 1480\u001b[0m text_embeds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maudio_conditioner\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwavs\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mraw_wave\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnamespace\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43msemantic\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1482\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exists(semantic_token_ids):\n\u001b[0;32m 1483\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m exists(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwav2vec), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mVQWav2Vec must be be provided if given raw wave for training\u001b[39m\u001b[38;5;124m'\u001b[39m\n",
729
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
730
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
731
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\musiclm_pytorch\\musiclm_pytorch.py:872\u001b[0m, in \u001b[0;36mMuLaNEmbedQuantizer.forward\u001b[1;34m(self, wavs, texts, namespace)\u001b[0m\n\u001b[0;32m 869\u001b[0m \u001b[38;5;66;03m# sound and language live in joint embedding space because of contrastive learning\u001b[39;00m\n\u001b[0;32m 871\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m exists(wavs):\n\u001b[1;32m--> 872\u001b[0m latents \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmulan\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_audio_latents\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwavs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 873\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m exists(texts):\n\u001b[0;32m 874\u001b[0m latents \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmulan\u001b[38;5;241m.\u001b[39mget_text_latents(texts)\n",
732
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\musiclm_pytorch\\musiclm_pytorch.py:732\u001b[0m, in \u001b[0;36mMuLaN.get_audio_latents\u001b[1;34m(self, wavs, return_all_layers)\u001b[0m\n\u001b[0;32m 727\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_audio_latents\u001b[39m(\n\u001b[0;32m 728\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m 729\u001b[0m wavs,\n\u001b[0;32m 730\u001b[0m return_all_layers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m 731\u001b[0m ):\n\u001b[1;32m--> 732\u001b[0m audio_embeds, audio_layers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maudio\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwavs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_all_layers\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 733\u001b[0m audio_latents \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maudio_to_latents(audio_embeds)\n\u001b[0;32m 734\u001b[0m out \u001b[38;5;241m=\u001b[39m l2norm(audio_latents)\n",
733
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
734
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
735
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\musiclm_pytorch\\musiclm_pytorch.py:525\u001b[0m, in \u001b[0;36mAudioSpectrogramTransformer.forward\u001b[1;34m(self, x, force_no_patch_dropout, return_all_layers)\u001b[0m\n\u001b[0;32m 521\u001b[0m rel_pos_bias \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdynamic_pos_bias_mlp(rel_dist\u001b[38;5;241m.\u001b[39mfloat())\n\u001b[0;32m 523\u001b[0m \u001b[38;5;66;03m# attention, what else\u001b[39;00m\n\u001b[1;32m--> 525\u001b[0m x, all_layers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrel_pos_bias\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mrel_pos_bias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_all_layers\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 527\u001b[0m \u001b[38;5;66;03m# final global average and norm (most recent papers show this is superior to CLS token)\u001b[39;00m\n\u001b[0;32m 529\u001b[0m x \u001b[38;5;241m=\u001b[39m reduce(x, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mb n d -> b d\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmean\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
736
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
737
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
738
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\musiclm_pytorch\\musiclm_pytorch.py:247\u001b[0m, in \u001b[0;36mTransformer.forward\u001b[1;34m(self, x, rel_pos_bias, mask, return_all_layers)\u001b[0m\n\u001b[0;32m 245\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m attn, ff \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[0;32m 246\u001b[0m x \u001b[38;5;241m=\u001b[39m attn(x, rel_pos_bias \u001b[38;5;241m=\u001b[39m rel_pos_bias, mask \u001b[38;5;241m=\u001b[39m mask) \u001b[38;5;241m+\u001b[39m x\n\u001b[1;32m--> 247\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mff\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m+\u001b[39m x\n\u001b[0;32m 248\u001b[0m layers\u001b[38;5;241m.\u001b[39mappend(x)\n\u001b[0;32m 250\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m return_all_layers:\n",
739
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
740
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
741
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\container.py:217\u001b[0m, in \u001b[0;36mSequential.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 215\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[0;32m 216\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[1;32m--> 217\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 218\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
742
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
743
+ "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
744
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
745
+ ]
746
+ }
747
+ ],
748
+ "source": [
749
+ "audio_transformer = AudioSpectrogramTransformer(\n",
750
+ " dim = 512,\n",
751
+ " depth = 6,\n",
752
+ " heads = 8,\n",
753
+ " dim_head = 64,\n",
754
+ " spec_n_fft = 128,\n",
755
+ " spec_win_length = 24,\n",
756
+ " spec_aug_stretch_factor = 0.8\n",
757
+ ")\n",
758
+ "\n",
759
+ "text_transformer = TextTransformer(\n",
760
+ " dim = 512,\n",
761
+ " depth = 6,\n",
762
+ " heads = 8,\n",
763
+ " dim_head = 64\n",
764
+ ")\n",
765
+ "\n",
766
+ "mulan = MuLaN(\n",
767
+ " audio_transformer = audio_transformer,\n",
768
+ " text_transformer = text_transformer\n",
769
+ ")\n",
770
+ "\n",
771
+ "# setup the quantizer with the namespaced conditioning embeddings, unique per quantizer as well as namespace (per transformer)\n",
772
+ "\n",
773
+ "quantizer = MuLaNEmbedQuantizer(\n",
774
+ " mulan = mulan, # pass in trained mulan from above\n",
775
+ " conditioning_dims = (1024, 1024, 1024), # say all three transformers have model dimensions of 1024\n",
776
+ " namespaces = ('semantic', 'coarse', 'fine')\n",
777
+ ")\n",
778
+ "\n",
779
+ "# now say you want the conditioning embeddings for semantic transformer\n",
780
+ "\n",
781
+ "wavs = torch.randn(2, 1024)\n",
782
+ "conds = quantizer(wavs = wavs, namespace = 'semantic') # (2, 8, 1024) - 8 is number of quantizers\n",
783
+ "\n",
784
+ "# SemanticTransformer: train on the prepared audio and save its weights\n",
785
+ "def train_semantic_transformer():\n",
786
+ " wav2vec = HubertWithKmeans(\n",
787
+ " checkpoint_path=checkpoint_path,\n",
788
+ " kmeans_path=kmeans_path\n",
789
+ " )\n",
790
+ "\n",
791
+ "\n",
792
+ " if torch.cuda.is_available():\n",
793
+ " semantic_transformer = SemanticTransformer(\n",
794
+ " num_semantic_tokens=wav2vec.codebook_size,\n",
795
+ " dim=1024,\n",
796
+ " depth=6,\n",
797
+ " audio_text_condition=True\n",
798
+ " ).cuda()\n",
799
+ " else:\n",
800
+ " semantic_transformer = SemanticTransformer(\n",
801
+ " num_semantic_tokens=wav2vec.codebook_size,\n",
802
+ " dim=1024,\n",
803
+ " depth=6,\n",
804
+ " audio_text_condition=True\n",
805
+ " )\n",
806
+ "\n",
807
+ " trainer = SemanticTransformerTrainer(\n",
808
+ " transformer=semantic_transformer,\n",
809
+ " wav2vec=wav2vec,\n",
810
+ " audio_conditioner=quantizer,\n",
811
+ " folder=audio_output_dir,\n",
812
+ " batch_size=batch_size,\n",
813
+ " data_max_length=data_max_length,\n",
814
+ " num_train_steps=num_train_steps\n",
815
+ " )\n",
816
+ "\n",
817
+ " trainer.train()\n",
818
+ " torch.save(semantic_transformer.state_dict(), 'semantic_transformer.pth')\n",
819
+ "    print(\"saved semantic_transformer.pth\")\n",
820
+ " del semantic_transformer, trainer, wav2vec\n",
821
+ " gc.collect()\n",
822
+ "\n",
823
+ "\n",
824
+ "\n",
825
+ "\n",
826
+ "train_semantic_transformer()"
827
+ ]
828
+ }
829
+ ],
830
+ "metadata": {
831
+ "kernelspec": {
832
+ "display_name": "myenv",
833
+ "language": "python",
834
+ "name": "python3"
835
+ },
836
+ "language_info": {
837
+ "codemirror_mode": {
838
+ "name": "ipython",
839
+ "version": 3
840
+ },
841
+ "file_extension": ".py",
842
+ "mimetype": "text/x-python",
843
+ "name": "python",
844
+ "nbconvert_exporter": "python",
845
+ "pygments_lexer": "ipython3",
846
+ "version": "3.11.2"
847
+ }
848
+ },
849
+ "nbformat": 4,
850
+ "nbformat_minor": 2
851
+ }