{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Semantic Transformer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Libraries" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import multiprocessing\n", "from audiolm_pytorch import HubertWithKmeans, MusicLMSoundStream\n", "from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer\n", "from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer\n", "from audiolm_pytorch import FineTransformer, FineTransformerTrainer\n", "from musiclm_pytorch import MuLaNEmbedQuantizer\n", "from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer\n", "import gc " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "checkpoint_path = './models/hubert/hubert_base_ls960.pt'\n", "kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'\n", "audio_output_dir = './audio'\n", "batch_size = 1\n", "data_max_length = 320 * 32\n", "num_train_steps = 1000" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "spectrogram yielded shape of (65, 86), but had to be cropped to (64, 80) to be patchified for transformer\n", "ANTLR runtime and generated code versions disagree: 4.9.3!=4.8\n", "ANTLR runtime and generated code versions disagree: 4.9.3!=4.8\n", "training with dataset of 4806 samples and validating with randomly splitted 253 samples\n", "0: loss: 6.5572309494018555\n", "0: valid loss 6.723005294799805\n", "0: saving model to results\n", "1: loss: 6.5375285148620605\n", "2: loss: 5.515031337738037\n", "3: loss: 0.6989991664886475\n", "4: loss: 0.016623886302113533\n", "5: loss: 6.3969268798828125\n", "6: loss: 0.8643577098846436\n", "7: loss: 0.008508207276463509\n", "8: loss: 0.00020680516900029033\n", "9: loss: 8.900370597839355\n", "10: loss: 
0.00010900969209615141\n", "11: loss: 0.0001591881300555542\n", "12: loss: 8.055902481079102\n", "13: loss: 0.0009496303973719478\n", "14: loss: 0.0027423782739788294\n", "15: loss: 0.0009589337860234082\n", "16: loss: 7.296541690826416\n", "17: loss: 0.0005210856324993074\n", "18: loss: 0.0008424322586506605\n", "19: loss: 5.571179389953613\n", "20: loss: 0.003094581188634038\n", "21: loss: 0.0019461463671177626\n", "22: loss: 5.488490104675293\n", "23: loss: 4.800296783447266\n", "24: loss: 4.962136268615723\n", "25: loss: 5.943732738494873\n", "26: loss: 0.006312617566436529\n", "27: loss: 4.396454334259033\n", "28: loss: 0.012498963624238968\n", "29: loss: 0.0049488842487335205\n", "30: loss: 0.0011625693878158927\n", "31: loss: 3.445856809616089\n", "32: loss: 0.000534387887455523\n", "33: loss: 0.000711498549208045\n", "34: loss: 0.0009514373959973454\n", "35: loss: 0.001239188713952899\n", "36: loss: 8.732012748718262\n", "37: loss: 0.0009216524777002633\n", "38: loss: 0.0006809335318394005\n", "39: loss: 0.000797786982730031\n", "40: loss: 4.916833400726318\n", "41: loss: 0.0010107718408107758\n", "42: loss: 0.0008451942121610045\n", "43: loss: 3.160980701446533\n", "44: loss: 0.0008387335110455751\n", "45: loss: 0.0010360947344452143\n", "46: loss: 0.001215349417179823\n", "47: loss: 5.990973949432373\n", "48: loss: 0.0017369053093716502\n", "49: loss: 6.410669803619385\n", "50: loss: 0.003450337564572692\n", "51: loss: 0.003860922297462821\n", "52: loss: 0.002359303878620267\n", "53: loss: 0.001058467198163271\n", "54: loss: 0.00047752217506058514\n", "55: loss: 0.00025489379186183214\n", "56: loss: 0.00016276698443107307\n", "57: loss: 7.828070163726807\n", "58: loss: 0.00011652028479147702\n", "59: loss: 4.505963325500488\n", "60: loss: 0.00013153781765140593\n", "61: loss: 0.00015024915046524256\n", "62: loss: 0.00017777853645384312\n", "63: loss: 8.09732437133789\n", "64: loss: 0.00041875039460137486\n", "65: loss: 0.0009824583539739251\n", "66: loss: 
0.001990197692066431\n", "67: loss: 5.392111778259277\n", "68: loss: 0.0017270153621211648\n", "69: loss: 0.0010434042196720839\n", "70: loss: 0.0005951145431026816\n", "71: loss: 0.00037293724017217755\n", "72: loss: 0.00025969729176722467\n", "73: loss: 7.013213157653809\n", "74: loss: 3.807203531265259\n", "75: loss: 0.00026780215557664633\n", "76: loss: 0.00031897667213343084\n", "77: loss: 0.0003657388442661613\n", "78: loss: 5.076975345611572\n", "79: loss: 0.001055362867191434\n", "80: loss: 0.0010116726625710726\n", "81: loss: 0.0017484871204942465\n", "82: loss: 0.0018696936313062906\n", "83: loss: 5.30266809463501\n", "84: loss: 5.457505226135254\n", "85: loss: 0.0012204349040985107\n", "86: loss: 3.2936503887176514\n", "87: loss: 0.0020471797324717045\n", "88: loss: 0.0026046710554510355\n", "89: loss: 0.0026721167378127575\n", "90: loss: 0.0024667021352797747\n", "91: loss: 5.0201215744018555\n", "92: loss: 4.591504096984863\n", "93: loss: 0.0025711969938129187\n", "94: loss: 0.002706416416913271\n", "95: loss: 0.0024713831953704357\n", "96: loss: 0.002004373585805297\n", "97: loss: 0.001489203074015677\n", "98: loss: 0.0010426173685118556\n", "99: loss: 8.796974182128906\n", "100: loss: 0.0005365900578908622\n", "100: valid loss 5.255128860473633\n", "101: loss: 0.0004417159070726484\n", "102: loss: 4.595282554626465\n", "103: loss: 0.000659952696878463\n", "104: loss: 0.0008260122267529368\n", "105: loss: 0.0009083786280825734\n", "106: loss: 4.042155742645264\n", "107: loss: 4.17121696472168\n", "108: loss: 0.0007671767962165177\n", "109: loss: 4.022541522979736\n", "110: loss: 3.5455234050750732\n", "111: loss: 0.001035561435855925\n", "112: loss: 0.0012967187212780118\n", "113: loss: 7.237168312072754\n", "114: loss: 3.522667407989502\n", "115: loss: 0.004003542009741068\n", "116: loss: 0.0040553268045187\n", "117: loss: 0.0029700316954404116\n", "118: loss: 0.0019125432008877397\n", "119: loss: 3.4947195053100586\n", "120: loss: 
0.001095975050702691\n", "121: loss: 0.0009612821158953011\n", "122: loss: 0.000824352668132633\n", "123: loss: 3.3077425956726074\n", "124: loss: 0.0007418203167617321\n", "125: loss: 0.0007488489500246942\n", "126: loss: 0.0007235489320009947\n", "127: loss: 3.426555633544922\n", "128: loss: 0.0006980476318858564\n", "129: loss: 0.0006986281368881464\n", "130: loss: 0.0006706370622850955\n", "131: loss: 0.0006185953388921916\n", "132: loss: 4.421964645385742\n", "133: loss: 0.0006264401017688215\n", "134: loss: 0.0006876335828565061\n", "135: loss: 0.0007215599762275815\n", "136: loss: 0.0007203654968179762\n", "137: loss: 0.0006922150496393442\n", "138: loss: 0.0006356032681651413\n", "139: loss: 3.7695367336273193\n", "140: loss: 0.0006305422284640372\n", "141: loss: 0.0006744156708009541\n", "142: loss: 0.0006895355763845146\n", "143: loss: 3.770907402038574\n", "144: loss: 0.000908360059838742\n", "145: loss: 0.0011299465550109744\n", "146: loss: 0.0012696339981630445\n", "147: loss: 0.0012722468236461282\n", "148: loss: 3.8808021545410156\n", "149: loss: 3.783026695251465\n", "150: loss: 0.002035590121522546\n", "151: loss: 0.0026034933980554342\n", "152: loss: 0.0024936539120972157\n", "153: loss: 0.0018582777120172977\n", "154: loss: 2.8572535514831543\n", "155: loss: 0.001062657218426466\n", "156: loss: 0.0008821044466458261\n", "157: loss: 0.0007058316841721535\n", "158: loss: 0.0005539683043025434\n", "159: loss: 5.476413726806641\n", "160: loss: 0.00043070572428405285\n", "161: loss: 0.00042034301441162825\n", "162: loss: 0.0004015824815724045\n", "163: loss: 0.0003759717510547489\n", "164: loss: 0.00034577338374219835\n", "165: loss: 3.9209775924682617\n", "166: loss: 0.0003425567992962897\n", "167: loss: 0.00036322552477940917\n", "168: loss: 0.00037287475424818695\n", "169: loss: 7.9045209884643555\n", "170: loss: 0.0004473228473216295\n", "171: loss: 0.0005134259699843824\n", "172: loss: 2.9501657485961914\n", "173: loss: 0.0008285943185910583\n", 
"174: loss: 0.00113466486800462\n", "175: loss: 0.0013167448341846466\n", "176: loss: 0.0014080735854804516\n", "177: loss: 4.0473408699035645\n", "178: loss: 0.0016744763124734163\n", "179: loss: 0.0016492144204676151\n", "180: loss: 4.165207386016846\n", "181: loss: 0.0017677460564300418\n", "182: loss: 0.0018474040552973747\n", "183: loss: 0.0017496442887932062\n", "184: loss: 3.3882288932800293\n", "185: loss: 0.0018872346263378859\n", "186: loss: 3.0333187580108643\n", "187: loss: 0.0028638774529099464\n", "188: loss: 3.709534168243408\n", "189: loss: 6.904417991638184\n", "190: loss: 0.006619434338063002\n", "191: loss: 0.00595641415566206\n", "192: loss: 5.050801753997803\n", "193: loss: 3.7556490898132324\n", "194: loss: 0.002467694692313671\n", "195: loss: 0.002025420544669032\n", "196: loss: 0.001494809053838253\n", "197: loss: 0.0010330628138035536\n", "198: loss: 0.0006917425780557096\n", "199: loss: 0.0004644835426006466\n", "200: loss: 4.2029547691345215\n", "200: valid loss 0.00025821171584539115\n", "201: loss: 4.2771501541137695\n", "202: loss: 3.7102839946746826\n", "203: loss: 3.7408058643341064\n", "204: loss: 0.0003981325135100633\n", "205: loss: 0.0005507581518031657\n", "206: loss: 4.065332889556885\n", "207: loss: 0.0011804178357124329\n", "208: loss: 0.0017080714460462332\n", "209: loss: 0.0021062048617750406\n", "210: loss: 0.0021494715474545956\n", "211: loss: 6.465389251708984\n", "212: loss: 0.0029505854472517967\n", "213: loss: 3.367213010787964\n", "214: loss: 0.27502918243408203\n", "215: loss: 0.9933775663375854\n", "216: loss: 0.810478925704956\n", "217: loss: 0.4562891721725464\n", "218: loss: 0.24387648701667786\n", "219: loss: 0.11290910840034485\n", "220: loss: 0.019248925149440765\n", "221: loss: 0.0021138046868145466\n", "222: loss: 5.169565200805664\n", "223: loss: 0.0008601757581345737\n", "224: loss: 3.9269232749938965\n", "225: loss: 0.0007863161154091358\n", "226: loss: 0.00024547570501454175\n", "227: loss: 
4.449281215667725\n", "228: loss: 0.00019524114031810313\n", "229: loss: 5.162830829620361\n", "230: loss: 0.0005567128537222743\n", "231: loss: 4.195521831512451\n", "232: loss: 3.7389187812805176\n", "233: loss: 5.919421672821045\n", "234: loss: 6.7034173011779785\n", "235: loss: 5.353506088256836\n", "236: loss: 2.4018566608428955\n", "237: loss: 3.7457311153411865\n", "238: loss: 0.17652225494384766\n", "239: loss: 4.564880847930908\n", "240: loss: 0.027039170265197754\n", "241: loss: 0.005270962603390217\n", "242: loss: 0.0015485308831557631\n", "243: loss: 0.0010360399028286338\n", "244: loss: 0.0007773903198540211\n", "245: loss: 6.206174850463867\n", "246: loss: 6.409456253051758\n", "247: loss: 0.04051050543785095\n", "248: loss: 0.0017684113699942827\n", "249: loss: 0.00044090740266256034\n", "250: loss: 5.761023044586182\n", "251: loss: 0.00016311556100845337\n", "252: loss: 0.0001715785765554756\n", "253: loss: 0.00019523760420270264\n", "254: loss: 0.00023307953961193562\n", "255: loss: 0.00028373271925374866\n", "256: loss: 4.927147388458252\n", "257: loss: 4.228280544281006\n", "258: loss: 0.0011933923233300447\n", "259: loss: 0.005215882323682308\n", "260: loss: 0.0013388781808316708\n", "261: loss: 4.206026554107666\n", "262: loss: 0.0034830207005143166\n", "263: loss: 4.173500061035156\n", "264: loss: 0.007450783159583807\n", "265: loss: 4.5892510414123535\n", "266: loss: 0.006880312692373991\n", "267: loss: 4.572935104370117\n", "268: loss: 0.002904222346842289\n", "269: loss: 3.2348222732543945\n", "270: loss: 4.376621723175049\n", "271: loss: 3.573988914489746\n", "272: loss: 0.0010127610294148326\n", "273: loss: 9.308874130249023\n", "274: loss: 4.688360214233398\n", "275: loss: 3.9581832885742188\n", "276: loss: 0.01065391581505537\n", "277: loss: 0.0067514535039663315\n", "278: loss: 0.003611961379647255\n", "279: loss: 0.001811509020626545\n", "280: loss: 0.0009013370145112276\n", "281: loss: 4.266546726226807\n", "282: loss: 
5.132745742797852\n", "283: loss: 0.000957090116571635\n", "284: loss: 0.0015025322791188955\n", "285: loss: 6.258731842041016\n", "286: loss: 5.029386043548584\n", "287: loss: 0.007954631000757217\n", "288: loss: 0.0050008054822683334\n", "289: loss: 0.001655810745432973\n", "290: loss: 5.501289367675781\n", "291: loss: 4.655749797821045\n", "292: loss: 4.383106231689453\n", "293: loss: 0.000304496381431818\n", "294: loss: 0.0003326725563965738\n", "295: loss: 0.00035310350358486176\n", "296: loss: 5.683162212371826\n", "297: loss: 0.0004622728156391531\n", "298: loss: 4.067113399505615\n", "299: loss: 0.0008154112147167325\n", "300: loss: 0.00108420941978693\n", "300: valid loss 0.0013179074740037322\n", "301: loss: 0.0013179074740037322\n", "302: loss: 4.358561992645264\n", "303: loss: 5.026749610900879\n", "304: loss: 0.002862808993086219\n", "305: loss: 0.003396229352802038\n", "306: loss: 5.530904293060303\n", "307: loss: 0.0035779180470854044\n", "308: loss: 0.003205555956810713\n", "309: loss: 4.112671852111816\n", "310: loss: 3.6920313835144043\n", "311: loss: 0.0026951604522764683\n", "312: loss: 0.0026851999573409557\n", "313: loss: 3.3092551231384277\n", "314: loss: 0.0024079573340713978\n", "315: loss: 0.0022026696242392063\n", "316: loss: 0.0018284200923517346\n", "317: loss: 0.0014258958399295807\n", "318: loss: 0.0010761057492345572\n", "319: loss: 0.0008039181702770293\n", "320: loss: 0.0006038622814230621\n", "321: loss: 0.00046244796249084175\n", "322: loss: 5.89370059967041\n", "323: loss: 0.00031747910543344915\n", "324: loss: 0.00028221303364261985\n", "325: loss: 0.00025451104738749564\n", "326: loss: 0.00023175252135843039\n", "327: loss: 0.00021364034910220653\n", "328: loss: 3.906613826751709\n", "329: loss: 3.844726085662842\n", "330: loss: 0.00023705456987954676\n", "331: loss: 0.0002663657069206238\n", "332: loss: 0.0002947220054920763\n", "333: loss: 6.28004264831543\n", "334: loss: 0.0003821635036729276\n", "335: loss: 
3.633335828781128\n", "336: loss: 0.0005681345355696976\n", "337: loss: 6.994467735290527\n", "338: loss: 7.915759086608887\n", "339: loss: 0.0026061832904815674\n", "340: loss: 0.0048998151905834675\n", "341: loss: 0.004243680741637945\n", "342: loss: 0.0025005636271089315\n", "343: loss: 4.005818843841553\n", "344: loss: 0.0011636920971795917\n", "345: loss: 0.0009634271846152842\n", "346: loss: 0.0008427661377936602\n", "347: loss: 0.0007607618463225663\n", "348: loss: 0.0006956492434255779\n", "349: loss: 4.547393798828125\n", "350: loss: 0.0006480301963165402\n", "351: loss: 0.0006520788883790374\n", "352: loss: 0.0006446384941227734\n", "353: loss: 4.283820629119873\n", "354: loss: 0.0007140468223951757\n", "355: loss: 0.000788742327131331\n", "356: loss: 0.0008332571596838534\n", "357: loss: 0.0008390303701162338\n", "358: loss: 0.000806896947324276\n", "359: loss: 4.646646976470947\n", "360: loss: 0.0021708165295422077\n", "361: loss: 0.0009108624653890729\n", "362: loss: 3.9582133293151855\n", "363: loss: 3.3569955825805664\n", "364: loss: 0.002499263733625412\n", "365: loss: 4.646510601043701\n", "366: loss: 0.0032457842025905848\n", "367: loss: 0.0033331059385091066\n", "368: loss: 0.00275675137527287\n", "369: loss: 0.0020243506878614426\n", "370: loss: 4.458893775939941\n", "371: loss: 5.930361270904541\n", "372: loss: 4.287806510925293\n", "373: loss: 3.365216016769409\n", "374: loss: 0.011499284766614437\n", "375: loss: 0.0031067240051925182\n", "376: loss: 0.003569819498807192\n", "377: loss: 0.0032246895134449005\n", "378: loss: 0.0023426800034940243\n", "379: loss: 0.0016774036921560764\n", "380: loss: 0.0010665183654055\n", "381: loss: 0.0007539619691669941\n", "382: loss: 3.873556137084961\n", "383: loss: 0.08063449710607529\n", "384: loss: 0.0005400768714025617\n", "385: loss: 0.000518861401360482\n", "386: loss: 0.00048329788842238486\n", "387: loss: 4.2107648849487305\n", "388: loss: 4.465734481811523\n", "389: loss: 0.000529197626747191\n", 
"390: loss: 3.872891664505005\n", "391: loss: 5.214785099029541\n", "392: loss: 4.345657825469971\n", "393: loss: 0.0016826370265334845\n", "394: loss: 0.0024580529425293207\n", "395: loss: 0.002994671929627657\n", "396: loss: 0.002981696743518114\n", "397: loss: 0.002537172520533204\n", "398: loss: 0.001975367311388254\n", "399: loss: 0.0014994062948971987\n", "400: loss: 0.0011500928085297346\n", "400: valid loss 0.0009022268350236118\n", "401: loss: 5.212808132171631\n", "402: loss: 0.0008533270447514951\n", "403: loss: 0.0008498210809193552\n", "404: loss: 0.0008541711140424013\n", "405: loss: 3.912627696990967\n", "406: loss: 0.0008917151135392487\n", "407: loss: 0.0009278871002607048\n", "408: loss: 3.4623196125030518\n", "409: loss: 0.0011483340058475733\n", "410: loss: 0.0014651089441031218\n", "411: loss: 3.501060962677002\n", "412: loss: 4.905694484710693\n", "413: loss: 0.0025538327172398567\n", "414: loss: 0.0019650040194392204\n", "415: loss: 0.001453581964597106\n", "416: loss: 4.282127857208252\n", "417: loss: 0.001117513864301145\n", "418: loss: 3.2745401859283447\n", "419: loss: 3.0665171146392822\n", "420: loss: 0.001583368401043117\n", "421: loss: 0.0018978181760758162\n", "422: loss: 5.070369720458984\n", "423: loss: 0.0025998111814260483\n", "424: loss: 0.0028609540313482285\n", "425: loss: 2.7316229343414307\n", "426: loss: 0.003324385266751051\n", "427: loss: 0.00243724649772048\n", "428: loss: 0.0020084292627871037\n", "429: loss: 0.001639676047489047\n", "430: loss: 0.0012756038922816515\n", "431: loss: 0.0010202551493421197\n", "432: loss: 0.0008382818195968866\n", "433: loss: 3.9101459980010986\n", "434: loss: 3.4464950561523438\n", "435: loss: 4.598957538604736\n", "436: loss: 6.656869888305664\n", "437: loss: 2.557544469833374\n", "438: loss: 1.769715666770935\n", "439: loss: 0.8786362409591675\n", "440: loss: 0.09529905021190643\n", "441: loss: 3.9526867866516113\n", "442: loss: 3.4567954540252686\n", "443: loss: 
0.28547608852386475\n", "444: loss: 0.1331639289855957\n", "445: loss: 0.01748904585838318\n", "446: loss: 3.7364015579223633\n", "447: loss: 1.6454107761383057\n", "448: loss: 0.007931341417133808\n", "449: loss: 0.0017749288817867637\n", "450: loss: 3.6518070697784424\n", "451: loss: 3.056483507156372\n", "452: loss: 0.0008364453678950667\n", "453: loss: 0.0009152528364211321\n", "454: loss: 0.0009797721868380904\n", "455: loss: 4.194733142852783\n", "456: loss: 0.0013897174503654242\n", "457: loss: 0.0018761098617687821\n", "458: loss: 0.0020015202462673187\n", "459: loss: 9.263550758361816\n", "460: loss: 0.0025061527267098427\n", "461: loss: 0.003998400643467903\n", "462: loss: 0.0031979954801499844\n", "463: loss: 0.0009064731420949101\n", "464: loss: 3.1668450832366943\n", "465: loss: 6.006053924560547\n", "466: loss: 0.0006406777538359165\n", "467: loss: 0.0009267539135180414\n", "468: loss: 0.0012060123262926936\n", "469: loss: 0.0013315295800566673\n", "470: loss: 3.5539376735687256\n", "471: loss: 3.4590916633605957\n", "472: loss: 0.0017678193980827928\n", "473: loss: 0.00218581547960639\n", "474: loss: 0.0025737383402884007\n", "475: loss: 2.97592830657959\n", "476: loss: 0.0032222135923802853\n", "477: loss: 0.0020487091969698668\n", "478: loss: 3.0420033931732178\n", "479: loss: 0.001554043497890234\n", "480: loss: 0.001528518507257104\n", "481: loss: 0.001422215485945344\n", "482: loss: 0.0012641653884202242\n", "483: loss: 0.0010866222437471151\n", "484: loss: 7.149199962615967\n", "485: loss: 0.0010687584290280938\n", "486: loss: 0.0012197017204016447\n", "487: loss: 0.001343191834166646\n", "488: loss: 0.0013996028574183583\n", "489: loss: 0.001371717662550509\n", "490: loss: 3.68569278717041\n", "491: loss: 0.0014253916451707482\n", "492: loss: 0.001504680491052568\n", "493: loss: 0.0014929386088624597\n", "494: loss: 0.0013759569264948368\n", "495: loss: 3.385620355606079\n", "496: loss: 0.0012212302535772324\n", "497: loss: 
0.0011952322674915195\n", "498: loss: 3.1083197593688965\n", "499: loss: 8.146794319152832\n", "500: loss: 3.8151681423187256\n", "500: valid loss 3.2241313457489014\n", "501: loss: 0.002565972041338682\n", "502: loss: 4.1275224685668945\n", "503: loss: 0.004586916882544756\n", "504: loss: 3.6200292110443115\n", "505: loss: 0.004917770624160767\n", "506: loss: 0.0035543786361813545\n", "507: loss: 0.002198878675699234\n", "508: loss: 3.9696688652038574\n", "509: loss: 0.0012150105321779847\n", "510: loss: 3.0237858295440674\n", "511: loss: 0.0016711285570636392\n", "512: loss: 0.0017911652103066444\n", "513: loss: 0.001645330572500825\n", "514: loss: 3.3689823150634766\n", "515: loss: 0.0014145843451842666\n", "516: loss: 0.0013438486494123936\n", "517: loss: 0.0011701782932505012\n", "518: loss: 0.0009688445716165006\n", "519: loss: 0.0007915324531495571\n", "520: loss: 4.113221645355225\n", "521: loss: 0.0006360645638778806\n", "522: loss: 0.0006149905384518206\n", "523: loss: 8.360527038574219\n", "524: loss: 0.0006234433385543525\n", "525: loss: 0.0006739232921972871\n", "526: loss: 0.0007281479192897677\n", "527: loss: 0.000767726160120219\n", "528: loss: 0.000772368221078068\n", "529: loss: 0.0007228502072393894\n", "530: loss: 0.0006368369213305414\n", "531: loss: 3.732311725616455\n", "532: loss: 5.932078838348389\n", "533: loss: 3.5892159938812256\n", "534: loss: 5.249965667724609\n", "535: loss: 7.211183071136475\n", "536: loss: 4.0714263916015625\n", "537: loss: 3.1499719619750977\n", "538: loss: 0.1844794750213623\n", "539: loss: 3.4192230701446533\n", "540: loss: 0.011980107054114342\n", "541: loss: 0.010612019337713718\n", "542: loss: 0.0045662750490009785\n", "543: loss: 0.005457601509988308\n", "544: loss: 0.015783555805683136\n", "545: loss: 0.0013816619757562876\n", "546: loss: 8.18481731414795\n", "547: loss: 0.0006438567652367055\n", "548: loss: 0.000572906865272671\n", "549: loss: 10.10994815826416\n", "550: loss: 0.003346000798046589\n", "551: 
loss: 0.0006713962065987289\n", "552: loss: 0.00026078836526721716\n", "553: loss: 11.756505012512207\n", "554: loss: 7.101832389831543\n", "555: loss: 0.00021459207346197218\n", "556: loss: 0.00025998923229053617\n", "557: loss: 0.0003112201811745763\n", "558: loss: 14.851192474365234\n", "559: loss: 0.0004224810691084713\n", "560: loss: 0.00047494613681919873\n", "561: loss: 0.000519308028742671\n", "562: loss: 0.0005509845213964581\n", "563: loss: 0.0005668219528160989\n", "564: loss: 14.569344520568848\n", "565: loss: 6.4913740158081055\n", "566: loss: 0.0008433411712758243\n", "567: loss: 8.495502471923828\n", "568: loss: 0.0019402098841965199\n", "569: loss: 0.0035519124940037727\n", "570: loss: 0.006841914728283882\n", "571: loss: 4.089066982269287\n", "572: loss: 5.491721153259277\n", "573: loss: 3.87937331199646\n", "574: loss: 0.03460773825645447\n", "575: loss: 0.015647828578948975\n", "576: loss: 0.002720448188483715\n", "577: loss: 6.188972473144531\n", "578: loss: 0.0008381525985896587\n", "579: loss: 0.0008579537970945239\n", "580: loss: 0.0008331844583153725\n", "581: loss: 7.444668769836426\n", "582: loss: 0.0013645365834236145\n", "583: loss: 0.0018909723730757833\n", "584: loss: 4.148159503936768\n", "585: loss: 6.465692043304443\n", "586: loss: 0.0040971520356833935\n", "587: loss: 0.015496809035539627\n", "588: loss: 0.0011185817420482635\n", "589: loss: 0.00048535081441514194\n", "590: loss: 0.0002821610542014241\n", "591: loss: 0.00022055530280340463\n", "592: loss: 0.0002070294285658747\n", "593: loss: 0.00021876658138353378\n", "594: loss: 0.00024527875939384103\n", "595: loss: 0.00028197691426612437\n", "596: loss: 0.00031235843198373914\n", "597: loss: 0.00032129406463354826\n", "598: loss: 0.000305092049529776\n", "599: loss: 6.581624507904053\n", "600: loss: 0.0004181505355518311\n", "600: valid loss 0.001562803634442389\n", "601: loss: 0.001562803634442389\n", "602: loss: 0.0008329854463227093\n", "603: loss: 8.43118953704834\n", "604: 
loss: 0.00018880203424487263\n", "605: loss: 6.225329399108887\n", "606: loss: 0.0001953585451701656\n", "607: loss: 0.00031005332130007446\n", "608: loss: 6.243394374847412\n", "609: loss: 0.002007008297368884\n", "610: loss: 0.2842656672000885\n", "611: loss: 0.002102950122207403\n", "612: loss: 0.0013235295191407204\n", "613: loss: 0.0012432391522452235\n", "614: loss: 0.0011076040100306273\n", "615: loss: 0.0009366637095808983\n", "616: loss: 0.0007713991799391806\n", "617: loss: 0.0006266268319450319\n", "618: loss: 0.0005072436179034412\n", "619: loss: 0.00041213506483472884\n", "620: loss: 0.0003370844351593405\n", "621: loss: 0.0002783465606626123\n", "622: loss: 6.750359535217285\n", "623: loss: 4.032569408416748\n", "624: loss: 4.749107360839844\n", "625: loss: 5.599199295043945\n", "626: loss: 4.851316452026367\n", "627: loss: 0.0012356003280729055\n", "628: loss: 0.0019876735750585794\n", "629: loss: 0.0022025934886187315\n", "630: loss: 0.09389199316501617\n", "631: loss: 0.0011942394776269794\n", "632: loss: 0.0008771757711656392\n", "633: loss: 0.000724500569049269\n", "634: loss: 4.850365161895752\n", "635: loss: 6.96458101272583\n", "636: loss: 3.944305658340454\n", "637: loss: 1.573992133140564\n", "638: loss: 0.006376080680638552\n", "639: loss: 0.004621799103915691\n", "640: loss: 0.008686978369951248\n", "641: loss: 0.002786734839901328\n", "642: loss: 0.0012673415476456285\n", "643: loss: 0.0008905518334358931\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "Cell \u001b[1;32mIn[3], line 78\u001b[0m\n\u001b[0;32m 72\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m semantic_transformer, trainer, wav2vec\n\u001b[0;32m 73\u001b[0m gc\u001b[38;5;241m.\u001b[39mcollect()\n\u001b[1;32m---> 78\u001b[0m 
\u001b[43mtrain_semantic_transformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", "Cell \u001b[1;32mIn[3], line 69\u001b[0m, in \u001b[0;36mtrain_semantic_transformer\u001b[1;34m()\u001b[0m\n\u001b[0;32m 52\u001b[0m semantic_transformer \u001b[38;5;241m=\u001b[39m SemanticTransformer(\n\u001b[0;32m 53\u001b[0m num_semantic_tokens\u001b[38;5;241m=\u001b[39mwav2vec\u001b[38;5;241m.\u001b[39mcodebook_size,\n\u001b[0;32m 54\u001b[0m dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1024\u001b[39m,\n\u001b[0;32m 55\u001b[0m depth\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m6\u001b[39m,\n\u001b[0;32m 56\u001b[0m audio_text_condition\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m 57\u001b[0m )\n\u001b[0;32m 59\u001b[0m trainer \u001b[38;5;241m=\u001b[39m SemanticTransformerTrainer(\n\u001b[0;32m 60\u001b[0m transformer\u001b[38;5;241m=\u001b[39msemantic_transformer,\n\u001b[0;32m 61\u001b[0m wav2vec\u001b[38;5;241m=\u001b[39mwav2vec,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 66\u001b[0m num_train_steps\u001b[38;5;241m=\u001b[39mnum_train_steps\n\u001b[0;32m 67\u001b[0m )\n\u001b[1;32m---> 69\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 70\u001b[0m torch\u001b[38;5;241m.\u001b[39msave(semantic_transformer\u001b[38;5;241m.\u001b[39mstate_dict(), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msemantic_transformer.pth\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m 71\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msave semantic_transformer.pth\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\trainer.py:1000\u001b[0m, in \u001b[0;36mSemanticTransformerTrainer.train\u001b[1;34m(self, log_fn)\u001b[0m\n\u001b[0;32m 997\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m 
\u001b[38;5;21mtrain\u001b[39m(\u001b[38;5;28mself\u001b[39m, log_fn \u001b[38;5;241m=\u001b[39m noop):\n\u001b[0;32m 999\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msteps \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_train_steps:\n\u001b[1;32m-> 1000\u001b[0m logs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1001\u001b[0m log_fn(logs)\n\u001b[0;32m 1003\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprint(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtraining complete\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\trainer.py:944\u001b[0m, in \u001b[0;36mSemanticTransformerTrainer.train_step\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 941\u001b[0m data_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdata_tuple_to_kwargs(\u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdl_iter))\n\u001b[0;32m 943\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mautocast(), context():\n\u001b[1;32m--> 944\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_wrapper\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdata_kwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_loss\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 946\u001b[0m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mbackward(loss \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgrad_accum_every)\n\u001b[0;32m 948\u001b[0m accum_log(logs, {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgrad_accum_every})\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m 
(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\audiolm_pytorch.py:1480\u001b[0m, in \u001b[0;36mSemanticTransformerWrapper.forward\u001b[1;34m(self, semantic_token_ids, raw_wave, text, text_embeds, return_loss, **kwargs)\u001b[0m\n\u001b[0;32m 1478\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m exists(raw_wave)\n\u001b[0;32m 1479\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exists(text) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exists(text_embeds)\n\u001b[1;32m-> 1480\u001b[0m text_embeds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maudio_conditioner\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwavs\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mraw_wave\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnamespace\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43msemantic\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1482\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exists(semantic_token_ids):\n\u001b[0;32m 1483\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m exists(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwav2vec), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mVQWav2Vec must be be provided if given raw wave for training\u001b[39m\u001b[38;5;124m'\u001b[39m\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we 
don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\musiclm_pytorch\\musiclm_pytorch.py:872\u001b[0m, in \u001b[0;36mMuLaNEmbedQuantizer.forward\u001b[1;34m(self, wavs, texts, namespace)\u001b[0m\n\u001b[0;32m 869\u001b[0m \u001b[38;5;66;03m# sound and language live in joint embedding space because of contrastive learning\u001b[39;00m\n\u001b[0;32m 871\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m exists(wavs):\n\u001b[1;32m--> 872\u001b[0m latents \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmulan\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_audio_latents\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwavs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 873\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m exists(texts):\n\u001b[0;32m 874\u001b[0m latents \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmulan\u001b[38;5;241m.\u001b[39mget_text_latents(texts)\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\musiclm_pytorch\\musiclm_pytorch.py:732\u001b[0m, in \u001b[0;36mMuLaN.get_audio_latents\u001b[1;34m(self, wavs, return_all_layers)\u001b[0m\n\u001b[0;32m 727\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_audio_latents\u001b[39m(\n\u001b[0;32m 728\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m 729\u001b[0m wavs,\n\u001b[0;32m 730\u001b[0m return_all_layers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m 731\u001b[0m ):\n\u001b[1;32m--> 732\u001b[0m audio_embeds, audio_layers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maudio\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwavs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_all_layers\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 733\u001b[0m audio_latents \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maudio_to_latents(audio_embeds)\n\u001b[0;32m 734\u001b[0m out \u001b[38;5;241m=\u001b[39m l2norm(audio_latents)\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\musiclm_pytorch\\musiclm_pytorch.py:525\u001b[0m, in \u001b[0;36mAudioSpectrogramTransformer.forward\u001b[1;34m(self, x, force_no_patch_dropout, return_all_layers)\u001b[0m\n\u001b[0;32m 521\u001b[0m rel_pos_bias \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdynamic_pos_bias_mlp(rel_dist\u001b[38;5;241m.\u001b[39mfloat())\n\u001b[0;32m 523\u001b[0m \u001b[38;5;66;03m# attention, what else\u001b[39;00m\n\u001b[1;32m--> 525\u001b[0m x, all_layers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrel_pos_bias\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mrel_pos_bias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_all_layers\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 527\u001b[0m \u001b[38;5;66;03m# final global average and norm (most recent papers show this is superior to CLS token)\u001b[39;00m\n\u001b[0;32m 529\u001b[0m x \u001b[38;5;241m=\u001b[39m reduce(x, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mb n d -> b d\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmean\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", "File \u001b[1;32md:\\Sunil\\Mini 
Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m 
_global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\musiclm_pytorch\\musiclm_pytorch.py:247\u001b[0m, in \u001b[0;36mTransformer.forward\u001b[1;34m(self, x, rel_pos_bias, mask, return_all_layers)\u001b[0m\n\u001b[0;32m 245\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m attn, ff \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[0;32m 246\u001b[0m x \u001b[38;5;241m=\u001b[39m attn(x, rel_pos_bias \u001b[38;5;241m=\u001b[39m rel_pos_bias, mask \u001b[38;5;241m=\u001b[39m mask) \u001b[38;5;241m+\u001b[39m x\n\u001b[1;32m--> 247\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mff\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m+\u001b[39m x\n\u001b[0;32m 248\u001b[0m layers\u001b[38;5;241m.\u001b[39mappend(x)\n\u001b[0;32m 250\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m return_all_layers:\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\container.py:217\u001b[0m, in \u001b[0;36mSequential.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 215\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[0;32m 216\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[1;32m--> 217\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 218\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result 
\u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "\u001b[1;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [
    "audio_transformer = AudioSpectrogramTransformer(\n",
    "    dim = 512,\n",
    "    depth = 6,\n",
    "    heads = 8,\n",
    "    dim_head = 64,\n",
    "    spec_n_fft = 128,\n",
    "    spec_win_length = 24,\n",
    "    spec_aug_stretch_factor = 0.8\n",
    ")\n",
    "\n",
    "text_transformer = TextTransformer(\n",
    "    dim = 512,\n",
    "    depth = 6,\n",
    "    heads = 8,\n",
    "    dim_head = 64\n",
    ")\n",
    "\n",
    "# NOTE(review): MuLaN is freshly initialised here and no trained weights are loaded,\n",
    "# so the conditioning embeddings it produces carry no learned audio/text meaning.\n",
    "# Load a trained MuLaN checkpoint before relying on this conditioning.\n",
    "mulan = MuLaN(\n",
    "    audio_transformer = audio_transformer,\n",
    "    text_transformer = text_transformer\n",
    ")\n",
    "\n",
    "# Quantizer producing namespaced conditioning embeddings, unique per quantizer\n",
    "# as well as per namespace (one namespace per transformer).\n",
    "quantizer = MuLaNEmbedQuantizer(\n",
    "    mulan = mulan,\n",
    "    conditioning_dims = (1024, 1024, 1024),  # model dims of the semantic, coarse and fine transformers\n",
    "    namespaces = ('semantic', 'coarse', 'fine')\n",
    ")\n",
    "\n",
    "\n",
    "def train_semantic_transformer():\n",
    "    \"\"\"Train a SemanticTransformer on the audio in `audio_output_dir` and save its weights.\n",
    "\n",
    "    Reads the config globals (`checkpoint_path`, `kmeans_path`, `batch_size`,\n",
    "    `data_max_length`, `num_train_steps`) and conditions on the module-level\n",
    "    MuLaN `quantizer`. Writes the trained state dict to `semantic_transformer.pth`.\n",
    "    \"\"\"\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "    wav2vec = HubertWithKmeans(\n",
    "        checkpoint_path=checkpoint_path,\n",
    "        kmeans_path=kmeans_path\n",
    "    )\n",
    "\n",
    "    # Build once and move to the target device instead of duplicating the constructor per branch.\n",
    "    semantic_transformer = SemanticTransformer(\n",
    "        num_semantic_tokens=wav2vec.codebook_size,\n",
    "        dim=1024,  # keep equal to the 'semantic' entry of the quantizer's conditioning_dims\n",
    "        depth=6,\n",
    "        audio_text_condition=True\n",
    "    ).to(device)\n",
    "\n",
    "    trainer = SemanticTransformerTrainer(\n",
    "        transformer=semantic_transformer,\n",
    "        wav2vec=wav2vec,\n",
    "        audio_conditioner=quantizer,\n",
    "        folder=audio_output_dir,\n",
    "        batch_size=batch_size,\n",
    "        data_max_length=data_max_length,\n",
    "        num_train_steps=num_train_steps\n",
    "    )\n",
    "\n",
    "    trainer.train()\n",
    "    torch.save(semantic_transformer.state_dict(), 'semantic_transformer.pth')\n",
    "    print(\"save semantic_transformer.pth\")\n",
    "\n",
    "    # Drop references so later cells can reclaim the host and GPU memory.\n",
    "    del semantic_transformer, trainer, wav2vec\n",
    "    gc.collect()\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "\n",
    "train_semantic_transformer()"
 ] } ], "metadata": { "kernelspec": { "display_name": "myenv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.2" } }, "nbformat": 4, "nbformat_minor": 2 }