from pathlib import Path
import gradio as gr
import numpy as np
import torch
from preprocessing.preprocess import AudioPipeline
from dancer_net.dancer_net import ShortChunkCNN
import os
import json
from functools import cache
import pandas as pd
@cache
def get_model(device) -> tuple[ShortChunkCNN, np.ndarray]:
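    """Load the trained ShortChunkCNN and its sorted class labels from the checkpoint directory."""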
model_path = "logs/20221226-230930"
weights = os.path.join(model_path, "dancer_net.pt")
config_path = os.path.join(model_path, "config.json")
with open(config_path) as f:
config = json.load(f)
labels = np.array(sorted(config["classes"]))
model = ShortChunkCNN(n_class=len(labels))
    model.load_state_dict(torch.load(weights, map_location=device))
model = model.to(device).eval()
return model, labels
@cache
def get_pipeline(sample_rate:int) -> AudioPipeline:
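    """Build (and cache) an AudioPipeline configured for the recording's sample rate."""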
return AudioPipeline(input_freq=sample_rate)
@cache
def get_dance_map() -> dict:
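    """Map dance ids from data/dance_mapping.csv to human-readable dance names."""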
df = pd.read_csv("data/dance_mapping.csv")
return df.set_index("id").to_dict()["name"]
def predict(audio: tuple[int, np.ndarray]) -> dict[str, float] | str:
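    """Predict which dances fit the recorded clip.

    Returns a {dance name: probability} dict for every class scoring above
    the threshold, or a fallback string when none do.
    """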
sample_rate, waveform = audio
expected_duration = 6
threshold = 0.5
sample_len = sample_rate * expected_duration
    # Prefer Apple's MPS backend when it's available; fall back to CPU otherwise.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
audio_pipeline = get_pipeline(sample_rate)
model, labels = get_model(device)
if sample_len > len(waveform):
        raise gr.Error(f"You must record for at least {expected_duration} seconds")
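    # Gradio supplies audio as (samples, channels); downmix stereo to mono.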
if len(waveform.shape) > 1 and waveform.shape[1] > 1:
        waveform = waveform.transpose(1, 0)
waveform = waveform.mean(axis=0, keepdims=True)
else:
waveform = np.expand_dims(waveform, 0)
    waveform = waveform[:, :sample_len]
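    # Min-max scale the clip into the [-1, 1] range before feature extraction.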
waveform = (waveform - waveform.min()) / (waveform.max() - waveform.min()) * 2 - 1
waveform = waveform.astype("float32")
waveform = torch.from_numpy(waveform)
spectrogram = audio_pipeline(waveform)
spectrogram = spectrogram.unsqueeze(0).to(device)
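    # Run the network without gradient tracking for cheaper inference.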
with torch.no_grad():
results = model(spectrogram)
dance_mapping = get_dance_map()
results = results.squeeze(0).detach().cpu().numpy()
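    # Keep only the dances whose score clears the confidence threshold.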
result_mask = results > threshold
probs = results[result_mask]
dances = labels[result_mask]
    if not len(dances):
        return "Couldn't find a dance."
    return {dance_mapping[dance_id]: float(prob) for dance_id, prob in zip(dances, probs)}
def demo():
title = "Dance Classifier"
description = "Record 6 seconds of a song and find out what dance fits the music."
with gr.Blocks() as app:
gr.Markdown(f"# {title}")
gr.Markdown(description)
with gr.Tab("Record Song"):
mic_audio = gr.Audio(source="microphone", label="Song Recording")
mic_submit = gr.Button("Predict")
        with gr.Tab("Upload Song"):
audio_file = gr.Audio(label="Song Audio File")
audio_file_submit = gr.Button("Predict")
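        # Surface the bundled sample songs (skipping hidden files) as clickable examples.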
song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
example_audio = [str(song) for song in song_samples.iterdir() if song.name[0] != '.']
labels = gr.Label(label="Dances")
gr.Markdown("## Examples")
gr.Examples(
examples=example_audio,
inputs=audio_file,
outputs=labels,
fn=predict,
)
audio_file_submit.click(fn=predict, inputs=audio_file, outputs=labels)
mic_submit.click(fn=predict, inputs=mic_audio, outputs=labels)
return app
if __name__ == "__main__":
demo().launch()