# Semantic Transformer

### Libraries

In [1]:
import torch
import multiprocessing
from audiolm_pytorch import HubertWithKmeans, MusicLMSoundStream
from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer
from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer
from audiolm_pytorch import FineTransformer, FineTransformerTrainer
from musiclm_pytorch import MuLaNEmbedQuantizer
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer
import gc 

In [2]:
# Files passed to HubertWithKmeans (checkpoint_path / kmeans_path arguments) in this notebook.
checkpoint_path = './models/hubert/hubert_base_ls960.pt'
kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'
# Passed to the trainer as its `folder` argument.
# NOTE(review): despite the "output" in the name this looks like the training-audio folder - confirm.
audio_output_dir = './audio'
# Settings forwarded unchanged to SemanticTransformerTrainer.
batch_size = 1
# Evaluates to 10240. NOTE(review): unit is presumably audio samples - TODO confirm against the trainer docs.
data_max_length = 320 * 32
num_train_steps = 1000

In [3]:
# --- MuLaN: audio tower + text tower -----------------------------------------
# Construction order (audio, then text) is kept as-is: each constructor draws
# random initial weights, so reordering would change the initialisation.
audio_transformer = AudioSpectrogramTransformer(
    dim=512, depth=6, heads=8, dim_head=64,
    spec_n_fft=128, spec_win_length=24, spec_aug_stretch_factor=0.8,
)
text_transformer = TextTransformer(dim=512, depth=6, heads=8, dim_head=64)
mulan = MuLaN(audio_transformer=audio_transformer, text_transformer=text_transformer)

# --- Conditioning quantizer ---------------------------------------------------
# One conditioning dimension per namespace; 1024 matches the `dim` given to the
# SemanticTransformer built later in this cell.
# NOTE(review): `mulan` is built above with no training step or checkpoint load
# visible in this notebook, so it is randomly initialised here - confirm intent.
quantizer = MuLaNEmbedQuantizer(
    mulan=mulan,
    namespaces=('semantic', 'coarse', 'fine'),
    conditioning_dims=(1024,) * 3,
)

# --- Smoke test: conditioning embeddings for the 'semantic' namespace ---------
# Random input of shape (2, 1024). Per the upstream example the result is
# expected to be (2, 8, 1024), 8 being the number of quantizers - not verified here.
wavs = torch.randn(2, 1024)
conds = quantizer(wavs=wavs, namespace='semantic')

# SemanticTransformer
def train_semantic_transformer(save_path='semantic_transformer.pth'):
    """Train a SemanticTransformer and save its weights.

    Builds a HubertWithKmeans tokenizer and a SemanticTransformer, trains the
    transformer with SemanticTransformerTrainer, then writes the transformer's
    ``state_dict`` to ``save_path``.

    Reads these module-level names: ``checkpoint_path``, ``kmeans_path``,
    ``quantizer``, ``audio_output_dir``, ``batch_size``, ``data_max_length``
    and ``num_train_steps``.

    Args:
        save_path: Destination file for the trained weights. Defaults to
            'semantic_transformer.pth', the path the function always used.

    Returns:
        None.

    Raises:
        Whatever model construction, ``trainer.train()`` or ``torch.save``
        raise (including KeyboardInterrupt); the weights are not written in
        that case, but this function's references are still dropped.
    """
    # Pick the device once so the model is described in a single place instead
    # of being duplicated across a CUDA branch and a CPU branch.
    use_cuda = torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'

    # Pre-bind so the cleanup in `finally` is safe even if construction fails.
    wav2vec = semantic_transformer = trainer = None
    try:
        wav2vec = HubertWithKmeans(
            checkpoint_path=checkpoint_path,
            kmeans_path=kmeans_path
        )

        semantic_transformer = SemanticTransformer(
            num_semantic_tokens=wav2vec.codebook_size,
            dim=1024,
            depth=6,
            audio_text_condition=True
        ).to(device)

        trainer = SemanticTransformerTrainer(
            transformer=semantic_transformer,
            wav2vec=wav2vec,
            audio_conditioner=quantizer,
            folder=audio_output_dir,
            batch_size=batch_size,
            data_max_length=data_max_length,
            num_train_steps=num_train_steps
        )

        trainer.train()
        torch.save(semantic_transformer.state_dict(), save_path)
        print(f"save {save_path}")
    finally:
        # Drop this function's references on every exit path, not only on
        # success: after an exception the traceback keeps this frame alive, so
        # without rebinding the models would stay referenced from it.
        wav2vec = semantic_transformer = trainer = None
        gc.collect()
        if use_cuda:
            # Return cached, now-unused GPU blocks to the driver.
            torch.cuda.empty_cache()




# Run the training defined above; this call blocks until trainer.train() returns (or is interrupted).
train_semantic_transformer()

spectrogram yielded shape of (65, 86), but had to be cropped to (64, 80) to be patchified for transformer
ANTLR runtime and generated code versions disagree: 4.9.3!=4.8
ANTLR runtime and generated code versions disagree: 4.9.3!=4.8
training with dataset of 4806 samples and validating with randomly splitted 253 samples
0: loss: 6.5572309494018555
0: valid loss 6.723005294799805
0: saving model to results
1: loss: 6.5375285148620605
2: loss: 5.515031337738037
3: loss: 0.6989991664886475
4: loss: 0.016623886302113533
5: loss: 6.3969268798828125
6: loss: 0.8643577098846436
7: loss: 0.008508207276463509
8: loss: 0.00020680516900029033
9: loss: 8.900370597839355
10: loss: 0.00010900969209615141
11: loss: 0.0001591881300555542
12: loss: 8.055902481079102
13: loss: 0.0009496303973719478
14: loss: 0.0027423782739788294
15: loss: 0.0009589337860234082
16: loss: 7.296541690826416
17: loss: 0.0005210856324993074
18: loss: 0.0008424322586506605
19: loss: 5.571179389953613
20: loss: 0.00309458118863

KeyboardInterrupt: 