waidhoferj committed
Commit 42c4703
Parent: 248f682

updated models
app.py CHANGED
@@ -1,23 +1,113 @@
 from pathlib import Path
 import gradio as gr
 import numpy as np
-from models.residual import DancePredictor
 import os
 from functools import cache
 from pathlib import Path
-CONFIG_FILE = Path("models/config/dance-predictor.yaml")
+from models.audio_spectrogram_transformer import AST, ASTExtractorWrapper
+from models.training_environment import TrainingEnvironment
+import torch
+from torch import nn
+import yaml
+import torchaudio
+
+CONFIG_FILE = Path("models/config/train_local.yaml")
+MODEL_CLS = AST
+EXTRACTOR = ASTExtractorWrapper
+
+
+class DancePredictor:
+    def __init__(
+        self,
+        weight_path: str,
+        labels: list[str],
+        expected_duration=6,
+        threshold=0.5,
+        resample_frequency=16000,
+        device="cpu",
+    ):
+        super().__init__()
+
+        self.expected_duration = expected_duration
+        self.threshold = threshold
+        self.resample_frequency = resample_frequency
+
+        self.labels = np.array(labels)
+        self.device = device
+        self.model = self.get_model(weight_path)
+        self.extractor = ASTExtractorWrapper()
+
+    def get_model(self, weight_path: str) -> nn.Module:
+        weights = torch.load(weight_path, map_location=self.device)["state_dict"]
+        model = AST(self.labels).to(self.device)
+        for key in list(weights):
+            weights[
+                key.replace(
+                    "model.",
+                    "",
+                )
+            ] = weights.pop(key)
+        model.load_state_dict(weights, strict=False)
+        return model.to(self.device).eval()
+
+    @classmethod
+    def from_config(cls, config_path: str) -> "DancePredictor":
+        with open(config_path, "r") as f:
+            config = yaml.safe_load(f)
+        weight_path = config["checkpoint"]
+        labels = sorted(config["dance_ids"])
+        expected_duration = 6
+        threshold = 0.5
+        resample_frequency = 16000
+        device = "mps"
+        return DancePredictor(
+            weight_path,
+            labels,
+            expected_duration,
+            threshold,
+            resample_frequency,
+            device,
+        )
+
+    @torch.no_grad()
+    def __call__(self, waveform: np.ndarray, sample_rate: int) -> dict[str, float]:
+        if waveform.ndim == 1:
+            waveform = np.stack([waveform, waveform]).T
+        waveform = torch.from_numpy(waveform.T)
+        waveform = torchaudio.functional.apply_codec(
+            waveform, sample_rate, "wav", channels_first=True
+        )
+
+        waveform = torchaudio.functional.resample(
+            waveform, sample_rate, self.resample_frequency
+        )
+        waveform = waveform[
+            :, : self.resample_frequency * self.expected_duration
+        ]  # TODO PAD
+        features = self.extractor(waveform)
+        features = features.unsqueeze(0).to(self.device)
+        results = self.model(features)
+        results = nn.functional.softmax(results.squeeze(0), dim=0)
+        results = results.detach().cpu().numpy()
+
+        result_mask = results > self.threshold
+        probs = results[result_mask]
+        dances = self.labels[result_mask]
+
+        return {dance: float(prob) for dance, prob in zip(dances, probs)}


 @cache
-def get_model(config_path:str) -> DancePredictor:
+def get_model(config_path: str) -> DancePredictor:
     model = DancePredictor.from_config(config_path)
     return model

+
 def predict(audio: tuple[int, np.ndarray]) -> list[str]:
     sample_rate, waveform = audio
-
+
     model = get_model(CONFIG_FILE)
-    results = model(waveform,sample_rate)
+    results = model(waveform, sample_rate)
     return results if len(results) else "Dance Not Found"


@@ -25,34 +115,36 @@ def demo():
     title = "Dance Classifier"
     description = "What should I dance to this song? Pass some audio to the Dance Classifier find out!"
     song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
-    example_audio = [str(song) for song in song_samples.iterdir() if song.name[0] != '.']
+    example_audio = [
+        str(song) for song in song_samples.iterdir() if song.name[0] != "."
+    ]
     all_dances = get_model(CONFIG_FILE).labels
-
+
     recording_interface = gr.Interface(
         fn=predict,
         description="Record at least **6 seconds** of the song.",
         inputs=gr.Audio(source="microphone", label="Song Recording"),
         outputs=gr.Label(label="Dances"),
-        examples=example_audio
+        examples=example_audio,
     )
     uploading_interface = gr.Interface(
         fn=predict,
         inputs=gr.Audio(label="Song Audio File"),
         outputs=gr.Label(label="Dances"),
-        examples=example_audio
+        examples=example_audio,
    )
-
+
     with gr.Blocks() as app:
         gr.Markdown(f"# {title}")
         gr.Markdown(description)
-        gr.TabbedInterface([uploading_interface, recording_interface], ["Upload Song", "Record Song"])
+        gr.TabbedInterface(
+            [uploading_interface, recording_interface], ["Upload Song", "Record Song"]
+        )
         with gr.Accordion("See all dances", open=False):
             gr.Markdown("\n".join(f"- {dance}" for dance in all_dances))

-
-
     return app


 if __name__ == "__main__":
-    demo().launch()
+    demo().launch()
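A minimal sketch of exercising the new AST-backed DancePredictor through app.py's predict path. It assumes the repo is importable and that the checkpoint referenced by models/config/train_local.yaml exists locally; the random waveform is only a stand-in for a Gradio microphone recording.

import numpy as np

from app import CONFIG_FILE, get_model

model = get_model(CONFIG_FILE)  # cached DancePredictor.from_config(CONFIG_FILE)
sample_rate = 44100
waveform = np.random.randn(sample_rate * 6).astype(np.float32)  # stand-in recording
scores = model(waveform, sample_rate)  # dict[str, float], filtered by threshold=0.5
print(scores or "Dance Not Found")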
models/audio_spectrogram_transformer.py CHANGED
@@ -88,13 +88,17 @@ def train_lightning_ast(config: dict):
         target_classes=TARGET_CLASSES,
         **config["data_module"],
     )
-
     model = AST(TARGET_CLASSES).to(DEVICE)
     label_weights = data.get_label_weights().to(DEVICE)
     criterion = nn.CrossEntropyLoss(
         label_weights
     )  # LabelWeightedBCELoss(label_weights)
-    train_env = TrainingEnvironment(model, criterion, config)
+    if "checkpoint" in config:
+        train_env = TrainingEnvironment.load_from_checkpoint(
+            config["checkpoint"], criterion=criterion, model=model, config=config
+        )
+    else:
+        train_env = TrainingEnvironment(model, criterion, config)
     callbacks = [
         # cb.LearningRateFinder(update_attr=True),
         cb.EarlyStopping("val/loss", patience=5),
models/config/train_local.yaml CHANGED
@@ -1,4 +1,5 @@
-training_fn: audio_spectrogram_transformer.train_lightning_ast
+training_fn: wav2vec2.train_huggingface
+checkpoint: lightning_logs/version_172/checkpoints/epoch=3-step=4572.ckpt
 device: mps
 seed: 42
 dance_ids: &dance_ids
@@ -23,10 +24,10 @@ data_module:
   test_proportion: 0.2

   datasets:
-    preprocessing.dataset.BestBallroomDataset:
-      audio_dir: data/ballroom-songs
-      class_list: *dance_ids
-      audio_window_jitter: 0.7
+    # preprocessing.dataset.BestBallroomDataset:
+    #   audio_dir: data/ballroom-songs
+    #   class_list: *dance_ids
+    #   audio_window_jitter: 0.7

     preprocessing.dataset.Music4DanceDataset:
       song_data_path: data/songs_cleaned.csv
@@ -49,7 +50,7 @@ trainer:
   log_every_n_steps: 15
   accelerator: gpu
   max_epochs: 50
-  min_epochs: 7
+  min_epochs: 2
   fast_dev_run: False
   # gradient_clip_val: 0.5
   # overfit_batches: 1
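The config now selects training_fn: wav2vec2.train_huggingface and points checkpoint at a Lightning checkpoint to resume from. A hypothetical sketch of how a training entry point could dispatch on these keys; the actual train script is not part of this commit, so the importlib resolution below is an assumption.

import importlib

import yaml

with open("models/config/train_local.yaml") as f:
    config = yaml.safe_load(f)

module_name, fn_name = config["training_fn"].rsplit(".", 1)  # "wav2vec2", "train_huggingface"
train_fn = getattr(importlib.import_module(f"models.{module_name}"), fn_name)
train_fn(config)  # train_lightning_ast would likewise read config["checkpoint"] to resume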
models/residual.py CHANGED
@@ -107,70 +107,6 @@ class ResBlock(nn.Module):
         return out


-class DancePredictor:
-    def __init__(
-        self,
-        weight_path: str,
-        labels: list[str],
-        expected_duration=6,
-        threshold=0.5,
-        resample_frequency=16000,
-        device="cpu",
-    ):
-        super().__init__()
-
-        self.expected_duration = expected_duration
-        self.threshold = threshold
-        self.resample_frequency = resample_frequency
-        self.preprocess_waveform = WaveformPreprocessing(
-            resample_frequency * expected_duration
-        )
-        self.audio_to_spectrogram = lambda x: x  # TODO: Fix
-        self.labels = np.array(labels)
-        self.device = device
-        self.model = self.get_model(weight_path)
-
-    def get_model(self, weight_path: str) -> nn.Module:
-        weights = torch.load(weight_path, map_location=self.device)["state_dict"]
-        model = ResidualDancer(n_classes=len(self.labels))
-        for key in list(weights):
-            weights[key.replace("model.", "")] = weights.pop(key)
-        model.load_state_dict(weights)
-        return model.to(self.device).eval()
-
-    @classmethod
-    def from_config(cls, config_path: str) -> "DancePredictor":
-        with open(config_path, "r") as f:
-            config = yaml.safe_load(f)
-        return DancePredictor(**config)
-
-    @torch.no_grad()
-    def __call__(self, waveform: np.ndarray, sample_rate: int) -> dict[str, float]:
-        if len(waveform.shape) > 1 and waveform.shape[1] < waveform.shape[0]:
-            waveform = waveform.transpose(1, 0)
-        elif len(waveform.shape) == 1:
-            waveform = np.expand_dims(waveform, 0)
-        waveform = torch.from_numpy(waveform.astype("int16"))
-        waveform = torchaudio.functional.apply_codec(
-            waveform, sample_rate, "wav", channels_first=True
-        )
-
-        waveform = torchaudio.functional.resample(
-            waveform, sample_rate, self.resample_frequency
-        )
-        waveform = self.preprocess_waveform(waveform)
-        spectrogram = self.audio_to_spectrogram(waveform)
-        spectrogram = spectrogram.unsqueeze(0).to(self.device)
-
-        results = self.model(spectrogram)
-        results = results.squeeze(0).detach().cpu().numpy()
-        result_mask = results > self.threshold
-        probs = results[result_mask]
-        dances = self.labels[result_mask]
-
-        return {dance: float(prob) for dance, prob in zip(dances, probs)}
-
-
 def train_residual_dancer(config: dict):
     TARGET_CLASSES = config["dance_ids"]
     DEVICE = config["device"]
models/training_environment.py CHANGED
@@ -17,10 +17,12 @@ class TrainingEnvironment(pl.LightningModule):
         *args,
         **kwargs,
     ):
-        super().__init__(*args, **kwargs)
+        super().__init__()
         self.model = model
         self.criterion = criterion
-        self.learning_rate = learning_rate
+        self.learning_rate = config["training_environment"].get(
+            "learning_rate", learning_rate
+        )
         self.experiment_loggers = load_loggers(
             config["training_environment"].get("loggers", {})
         )
@@ -64,7 +66,7 @@ class TrainingEnvironment(pl.LightningModule):
             preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
         )
         metrics["val/loss"] = self.criterion(preds, y)
-        self.log_dict(metrics, prog_bar=True)
+        self.log_dict(metrics, prog_bar=True, sync_dist=True)

     def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
         x, y = batch
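With this change, TrainingEnvironment takes its learning rate from the training_environment block of the config when present. An illustrative sketch of a config fragment exercising the override, using the call pattern from train_lightning_ast; the values and the placeholder model/criterion are assumptions, and the constructor may read further keys not visible in this hunk.

from torch import nn

from models.training_environment import TrainingEnvironment

config = {
    "training_environment": {
        "learning_rate": 3e-5,  # overrides the constructor's default learning_rate
        "loggers": {},          # no experiment loggers in this sketch
    },
}
model = nn.Linear(128, 4)          # placeholder model
criterion = nn.CrossEntropyLoss()  # placeholder criterion
train_env = TrainingEnvironment(model, criterion, config)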
models/wav2vec2.py CHANGED
@@ -7,14 +7,13 @@ from transformers import AutoModelForAudioClassification, TrainingArguments, Tra

 from preprocessing.dataset import (
     HuggingFaceDatasetWrapper,
-    BestBallroomDataset,
     get_datasets,
 )
 from preprocessing.pipelines import WaveformTrainingPipeline

 from .utils import get_id_label_mapping, compute_hf_metrics

-MODEL_CHECKPOINT = "facebook/wav2vec2-base"
+MODEL_CHECKPOINT = "m3hrdadfi/wav2vec2-base-100k-voxpopuli-gtzan-music"


 class Wav2VecFeatureExtractor:
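MODEL_CHECKPOINT now starts from a wav2vec2 base already fine-tuned on GTZAN music rather than the generic facebook/wav2vec2-base. A sketch of how such a checkpoint is typically loaded for audio classification with transformers; the label list is a placeholder and train_huggingface's exact setup is not shown here.

from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

MODEL_CHECKPOINT = "m3hrdadfi/wav2vec2-base-100k-voxpopuli-gtzan-music"
labels = ["swing", "salsa", "tango"]  # placeholder dance ids
id2label = {i: label for i, label in enumerate(labels)}
label2id = {label: i for i, label in id2label.items()}

extractor = AutoFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForAudioClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,  # re-initialize the classifier head for the new labels
)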