Project Beatrice
Initial commit
d95893e
raw
history blame
No virus
115 kB
# %% [markdown]
# ## Settings
# %%
import argparse
import gc
import json
import math
import os
import shutil
import warnings
from collections import defaultdict
from copy import deepcopy
from fractions import Fraction
from functools import partial
from pathlib import Path
from random import Random
from typing import BinaryIO, Literal, Optional, Union
import numpy as np
import pyworld
import torch
import torch.nn as nn
import torchaudio
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
# Fail at import time unless torchaudio reports the "soundfile" backend.
assert "soundfile" in torchaudio.list_audio_backends()
# NOTE: this is not the version of the module itself.
PARAPHERNALIA_VERSION = "2.0.0-alpha.2"
def is_notebook() -> bool:
    """Return ``True`` when this module's globals define ``get_ipython``."""
    module_namespace = globals()
    return "get_ipython" in module_namespace
def repo_root() -> Path:
    """Return the closest ancestor directory that contains a ``.git`` directory.

    The search starts from this file's location, or from a placeholder path
    under the current working directory when running inside a notebook.

    Raises:
        RuntimeError: if no ancestor contains a ``.git`` directory.
    """
    if is_notebook():
        start = Path.cwd() / "dummy"
    else:
        start = Path(__file__)
    assert start.is_absolute(), start
    root = next(
        (ancestor for ancestor in start.parents if (ancestor / ".git").is_dir()),
        None,
    )
    if root is None:
        raise RuntimeError("Repository root is not found.")
    return root
# Hyperparameters.
# Anything that changes from one training run to the next (training data,
# output directory, ...) is deliberately NOT included here.
dict_default_hparams = {
    # train
    "learning_rate": 1e-4,
    "min_learning_rate": 5e-6,
    "adam_betas": [0.8, 0.99],
    "adam_eps": 1e-6,
    "batch_size": 8,
    "grad_weight_mel": 3.0,  # grad_weight values should mean the same thing as long as their ratios are equal
    "grad_weight_adv": 1.0,
    "grad_weight_fm": 1.0,
    "grad_balancer_ema_decay": 0.995,
    "use_amp": True,
    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to 1 instead of raising TypeError at import time.
    "num_workers": min(os.cpu_count() or 1, max(16, (os.cpu_count() or 1) - 1)),
    "n_steps": 3000000,
    "warmup_steps": 10000,
    "in_sample_rate": 16000,  # must not be changed
    "out_sample_rate": 24000,  # must not be changed
    "wav_length": 4 * 24000,  # 4s
    "segment_length": 100,  # 1s
    # data
    "phone_extractor_file": "notebooks/003b/checkpoint_03000000.pt",  # TODO
    # "phone_extractor_file": "",
    "pitch_estimator_file": "notebooks/034pre/008_1_checkpoint_00300000.pt",  # TODO
    # "pitch_estimator_file": "",
    "in_ir_wav_dir": "../cavorite-ball-hf/data/ir",  # TODO
    # "in_ir_wav_dir": "data/ir",
    "in_noise_wav_dir": "../DNS-Challenge/datasets_fullband/noise_fullband",  # TODO
    # "in_noise_wav_dir": "data/noise",
    "in_test_wav_dir": "../cavorite-ball-hf/data/test",  # TODO
    # "in_test_wav_dir": "data/test",
    "pretrained_file": None,
    # model
    "hidden_channels": 256,  # must not be changed when fine-tuning; changing it requires matching changes on the inference side
    "san": False,  # must not be changed when fine-tuning
}
if __name__ == "__main__":
    # Check that the defaults embedded in this script are in sync with
    # assets/default_config.json; every discrepancy is reported as a warning.
    default_config_file = repo_root() / "assets/default_config.json"
    if default_config_file.is_file():
        with open(default_config_file, encoding="utf-8") as f:
            default_config: dict = json.load(f)
        for key, value in dict_default_hparams.items():
            if key not in default_config:
                warnings.warn(f"{key} not found in default_config.json.")
            else:
                if value != default_config[key]:
                    warnings.warn(
                        f"{key} differs between default_config.json ({default_config[key]}) and internal default hparams ({value})."
                    )
                # Drop matched keys so that only unknown keys remain afterwards.
                del default_config[key]
        for key in default_config:
            warnings.warn(f"{key} found in default_config.json is unknown.")
    else:
        # Fixed: the message used to misspell the file name ("dafualt_config.json").
        warnings.warn("default_config.json not found.")
def prepare_training_configs_for_experiment() -> tuple[dict, Path, Path, bool]:
    """Build the training configuration used when running from a notebook.

    Returns:
        ``(h, in_wav_dataset_dir, out_dir, resume)`` where ``h`` is a copy of
        the default hyperparameters, ``out_dir`` is derived from the notebook
        file name and ``resume`` is always ``False``.
    """
    import ipynbname
    from IPython import get_ipython

    h = deepcopy(dict_default_hparams)
    in_wav_dataset_dir = repo_root() / "../../data/processed/libritts_r_200"

    try:
        notebook_name = ipynbname.name()
    except FileNotFoundError:
        # Fall back to the notebook path stored in the IPython user namespace.
        vscode_notebook_path = get_ipython().user_ns["__vsc_ipynb_file__"]
        notebook_name = Path(vscode_notebook_path).name

    # Keep only what precedes the first "." and, within that, the first "_".
    experiment_id = notebook_name.partition(".")[0].partition("_")[0]
    out_dir = repo_root() / "notebooks" / experiment_id
    return h, in_wav_dataset_dir, out_dir, False
def prepare_training_configs() -> tuple[dict, Path, Path, bool]:
    """Build the training configuration from command-line arguments.

    ``data_dir`` and ``out_dir`` may be given either in the config file or on
    the command line; the command line takes precedence.  A relative path in
    the config file is resolved against the repository root, whereas a path
    given on the command line is used as passed (i.e. relative to the current
    directory).

    Returns:
        ``(h, in_wav_dataset_dir, out_dir, resume)``; ``h`` only contains keys
        that exist in ``dict_default_hparams``.

    Raises:
        ValueError: if ``data_dir`` or ``out_dir`` is specified in neither place.
    """
    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument("-c", "--config", type=Path, help="Path to the config file.")
    parser.add_argument("-d", "--data_dir", type=Path, help="Directory containing the training data.")
    parser.add_argument("-o", "--out_dir", type=Path, help="Output directory.")
    parser.add_argument("-r", "--resume", action="store_true", help="Resume training.")
    # fmt: on
    cli = parser.parse_args()

    # Hyperparameters: start from the config file when one is given, then fill
    # in every missing key with its built-in default.
    if cli.config is not None:
        with open(cli.config, encoding="utf-8") as config_file:
            h = json.load(config_file)
    else:
        h = deepcopy(dict_default_hparams)
    for name, default_value in dict_default_hparams.items():
        if name in h:
            continue
        h[name] = default_value
        warnings.warn(
            f"{name} is not specified in the config file. Using the default value."
        )

    # Training data directory.
    if cli.data_dir is not None:
        in_wav_dataset_dir = cli.data_dir
    elif "data_dir" not in h:
        raise ValueError(
            "data_dir must be specified. "
            "For example `python3 beatrice_trainer -d my_training_data_dir -o my_output_dir`."
        )
    else:
        in_wav_dataset_dir = repo_root() / Path(h["data_dir"])
        del h["data_dir"]

    # Output directory.
    if cli.out_dir is not None:
        out_dir = cli.out_dir
    elif "out_dir" not in h:
        raise ValueError(
            "out_dir must be specified. "
            "For example `python3 beatrice_trainer -d my_training_data_dir -o my_output_dir`."
        )
    else:
        out_dir = repo_root() / Path(h["out_dir"])
        del h["out_dir"]

    # Anything that is not a known hyperparameter is dropped with a warning.
    unknown_keys = [name for name in h if name not in dict_default_hparams]
    for key in unknown_keys:
        warnings.warn(f"`{key}` specified in the config file will be ignored.")
        del h[key]

    return h, in_wav_dataset_dir, out_dir, cli.resume
class AttrDict(dict):
    """A ``dict`` whose items can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        # Use the mapping itself as the instance namespace so that attribute
        # access and item access share the same storage.
        object.__setattr__(self, "__dict__", self)
# %% [markdown]
# ## Phone Extractor
# %%
def dump_params(params: Optional[torch.Tensor], f: BinaryIO):
    """Write the raw values of ``params`` to ``f`` and flush it.

    Elements are written in row-major (logical) order using the platform's
    native byte order.  ``bfloat16`` tensors are written as their 16-bit
    patterns (2 bytes per element); every other dtype is written as stored.

    Args:
        params: Tensor to serialize.  ``None`` (e.g. a missing bias) is
            accepted and writes nothing.
        f: Writable binary stream.
    """
    if params is None:
        return
    # Move to host memory and force a dense row-major layout so that tensors on
    # other devices and non-contiguous views (e.g. transposed weights) work.
    tensor = params.detach().cpu().contiguous()
    if tensor.dtype == torch.bfloat16:
        # NumPy has no bfloat16, so reinterpret the 16-bit patterns as int16.
        tensor = tensor.view(torch.int16)
    f.write(tensor.numpy().tobytes())
    f.flush()
def dump_layer(layer: nn.Module, f: BinaryIO):
    """Serialize the parameters of ``layer`` to the binary stream ``f``.

    Layers that provide their own ``dump`` method are delegated to; otherwise
    the parameters of the supported built-in layer types are written with
    ``dump_params`` in a fixed order.  Unsupported layers trip an assertion.
    """
    dump = partial(dump_params, f=f)

    if hasattr(layer, "dump"):
        layer.dump(f)
        return

    if isinstance(layer, (nn.Linear, nn.Conv1d, nn.LayerNorm)):
        dump(layer.weight)
        dump(layer.bias)
        return

    if isinstance(layer, nn.ConvTranspose1d):
        # The first two weight axes are swapped before writing.
        dump(layer.weight.transpose(0, 1))
        dump(layer.bias)
        return

    if isinstance(layer, nn.GRU):
        # Layer 0 is always written; further layers are written for as long as
        # the corresponding ``weight_ih_l{i}`` attribute exists.
        layer_index = 0
        while True:
            for kind in ("ih", "hh"):
                dump(getattr(layer, f"weight_{kind}_l{layer_index}"))
                dump(getattr(layer, f"bias_{kind}_l{layer_index}"))
            layer_index += 1
            if layer_index >= 99999:
                break
            if not hasattr(layer, f"weight_ih_l{layer_index}"):
                break
        return

    if isinstance(layer, nn.Embedding):
        dump(layer.weight)
        return

    if isinstance(layer, nn.ModuleList):
        for child in layer:
            dump_layer(child, f)
        return

    assert False, layer
class CausalConv1d(nn.Conv1d):
    """``nn.Conv1d`` padded on both sides and trimmed at the end of the output.

    The convolution is padded with ``(kernel_size - 1) * dilation - delay``
    samples on each side and the last
    ``(kernel_size - 1) * dilation - 2 * delay`` output samples are dropped.
    With ``delay=0`` and ``stride=1`` each output sample therefore depends only
    on current and past input samples.

    Raises:
        ValueError: if ``2 * delay`` exceeds ``(kernel_size - 1) * dilation``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        delay: int = 0,
    ):
        receptive_span = (kernel_size - 1) * dilation
        # Number of trailing output samples that forward() discards.
        self.trim = receptive_span - 2 * delay
        if self.trim < 0:
            raise ValueError
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=receptive_span - delay,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        padded_output = super().forward(input)
        if self.trim != 0:
            return padded_output[:, :, : -self.trim]
        return padded_output
class ConvNeXtBlock(nn.Module):
    """Residual block: depthwise causal conv -> norm -> two pointwise linears.

    The block output is ``input + gamma * pwconv2(gelu(pwconv1(norm(dwconv(input)))))``
    where ``gamma`` is a learnable per-channel scale.  With
    ``use_weight_norm=True`` the layer norm is replaced by an identity and the
    three weighted layers are wrapped with weight normalization.
    """

    _WEIGHTED_LAYERS = ("dwconv", "pwconv1", "pwconv2")

    def __init__(
        self,
        channels: int,
        intermediate_channels: int,
        layer_scale_init_value: float,
        kernel_size: int = 7,
        use_weight_norm: bool = False,
    ):
        super().__init__()
        self.use_weight_norm = use_weight_norm
        self.dwconv = CausalConv1d(
            channels, channels, kernel_size=kernel_size, groups=channels
        )
        self.norm = nn.Identity() if use_weight_norm else nn.LayerNorm(channels)
        self.pwconv1 = nn.Linear(channels, intermediate_channels)
        self.pwconv2 = nn.Linear(intermediate_channels, channels)
        self.gamma = nn.Parameter(torch.full((channels,), layer_scale_init_value))
        if use_weight_norm:
            for layer_name in self._WEIGHTED_LAYERS:
                setattr(self, layer_name, weight_norm(getattr(self, layer_name)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, channels, length]
        residual = x
        hidden = self.dwconv(x).transpose(1, 2)  # -> [batch_size, length, channels]
        hidden = self.pwconv1(self.norm(hidden))
        hidden = self.pwconv2(F.gelu(hidden, approximate="tanh"))
        # The scale and the residual sum are applied in place.
        hidden.mul_(self.gamma)
        out = hidden.transpose(1, 2)  # -> [batch_size, channels, length]
        out.add_(residual)
        return out

    def remove_weight_norm(self):
        """Strip weight normalization from the weighted layers, if applied."""
        if not self.use_weight_norm:
            return
        for layer_name in self._WEIGHTED_LAYERS:
            remove_weight_norm(getattr(self, layer_name))

    def merge_weights(self):
        """Fold the norm's affine parameters and ``gamma`` into the linears.

        Afterwards the layer norm has unit weight / zero bias and ``gamma`` is
        all ones, while the block computes the same function as before.
        """
        if not self.use_weight_norm:
            layer_norm, first = self.norm, self.pwconv1
            first.bias.data += (
                layer_norm.bias.data.unsqueeze(0) * first.weight.data
            ).sum(1)
            first.weight.data *= layer_norm.weight.data.unsqueeze(0)
            layer_norm.bias.data.fill_(0.0)
            layer_norm.weight.data.fill_(1.0)
        second = self.pwconv2
        second.weight.data *= self.gamma.data.unsqueeze(1)
        second.bias.data *= self.gamma.data
        self.gamma.data.fill_(1.0)

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write ``dwconv``, ``pwconv1`` and ``pwconv2`` to a path or stream."""
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as stream:
                self.dump(stream)
            return
        if not hasattr(f, "write"):
            raise TypeError
        for layer in (self.dwconv, self.pwconv1, self.pwconv2):
            dump_layer(layer, f)
class ConvNeXtStack(nn.Module):
    """Embedding convolution followed by a stack of ``ConvNeXtBlock`` modules.

    Input and output are laid out as ``[batch_size, channels, length]``.  With
    ``use_weight_norm=True`` both layer norms are replaced by identities and
    the embedding convolution is weight-normalized.
    """

    def __init__(
        self,
        in_channels: int,
        channels: int,
        intermediate_channels: int,
        n_blocks: int,
        delay: int,
        embed_kernel_size: int,
        kernel_size: int,
        use_weight_norm: bool = False,
    ):
        super().__init__()
        assert delay * 2 + 1 <= embed_kernel_size
        self.use_weight_norm = use_weight_norm

        embed = CausalConv1d(in_channels, channels, embed_kernel_size, delay=delay)
        self.embed = weight_norm(embed) if use_weight_norm else embed
        self.norm = nn.Identity() if use_weight_norm else nn.LayerNorm(channels)

        blocks = []
        for _ in range(n_blocks):
            blocks.append(
                ConvNeXtBlock(
                    channels=channels,
                    intermediate_channels=intermediate_channels,
                    layer_scale_init_value=1.0 / n_blocks,
                    kernel_size=kernel_size,
                    use_weight_norm=use_weight_norm,
                )
            )
        self.convnext = nn.ModuleList(blocks)

        self.final_layer_norm = (
            nn.Identity() if use_weight_norm else nn.LayerNorm(channels)
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Re-initialize conv / linear layers (truncated normal weight, zero bias)."""
        if not isinstance(m, (nn.Conv1d, nn.Linear)):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.embed(x)
        # The norms act on the channel axis, hence the transposes around them.
        hidden = self.norm(hidden.transpose(1, 2)).transpose(1, 2)
        for block in self.convnext:
            hidden = block(hidden)
        return self.final_layer_norm(hidden.transpose(1, 2)).transpose(1, 2)

    def remove_weight_norm(self):
        """Strip weight normalization from the embedding and every block."""
        if self.use_weight_norm:
            remove_weight_norm(self.embed)
        for block in self.convnext:
            block.remove_weight_norm()

    def merge_weights(self):
        """Call ``merge_weights`` on every block."""
        for block in self.convnext:
            block.merge_weights()

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write the stack to a path or stream.

        The two layer norms are only written when weight norm is not in use.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as stream:
                self.dump(stream)
            return
        if not hasattr(f, "write"):
            raise TypeError
        has_layer_norms = not self.use_weight_norm
        dump_layer(self.embed, f)
        if has_layer_norms:
            dump_layer(self.norm, f)
        dump_layer(self.convnext, f)
        if has_layer_norms:
            dump_layer(self.final_layer_norm, f)
class FeatureExtractor(nn.Module):
    """Six weight-normalized, bias-free strided convolutions with GELU.

    The strides are 5, 2, 2, 2, 2, 2, i.e. an overall downsampling factor of
    160, and the channel count grows from ``hidden_channels // 8`` up to
    ``hidden_channels``.
    """

    _N_LAYERS = 6

    def __init__(self, hidden_channels: int):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride) for conv0 .. conv5
        layer_specs = [
            (1, hidden_channels // 8, 10, 5),
            (hidden_channels // 8, hidden_channels // 4, 3, 2),
            (hidden_channels // 4, hidden_channels // 2, 3, 2),
            (hidden_channels // 2, hidden_channels, 3, 2),
            (hidden_channels, hidden_channels, 3, 2),
            (hidden_channels, hidden_channels, 2, 2),
        ]
        for index, (n_in, n_out, kernel, stride) in enumerate(layer_specs):
            conv = nn.Conv1d(n_in, n_out, kernel, stride, bias=False)
            setattr(self, f"conv{index}", weight_norm(conv))

    def _convs(self) -> list:
        """Return ``conv0`` .. ``conv5`` in order."""
        return [getattr(self, f"conv{index}") for index in range(self._N_LAYERS)]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, 1, wav_length]
        if x.size(2) % 160 != 0:
            warnings.warn("wav_length % 160 != 0")
        hidden = F.pad(x, (40, 40))
        for conv in self._convs():
            hidden = F.gelu(conv(hidden), approximate="tanh")
        # [batch_size, hidden_channels, wav_length / 160]
        return hidden

    def remove_weight_norm(self):
        """Strip weight normalization from all six convolutions."""
        for conv in self._convs():
            remove_weight_norm(conv)

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write ``conv0`` .. ``conv5`` in order to a path or stream."""
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as stream:
                self.dump(stream)
            return
        if not hasattr(f, "write"):
            raise TypeError
        for conv in self._convs():
            dump_layer(conv, f)
class FeatureProjection(nn.Module):
    """Channel-wise layer norm, 1x1 convolution and dropout (p=0.1)."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.norm = nn.LayerNorm(in_channels)
        self.projection = nn.Conv1d(in_channels, out_channels, 1)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, channels, length]; the norm acts on the channel axis.
        normalized = self.norm(x.transpose(1, 2)).transpose(1, 2)
        return self.dropout(self.projection(normalized))

    def merge_weights(self):
        """Fold the layer norm's affine parameters into the projection.

        Afterwards the norm has unit weight and zero bias while the module
        computes the same function as before.
        """
        layer_norm, conv = self.norm, self.projection
        bias_shift = (
            (layer_norm.bias.data[None, :, None] * conv.weight.data).sum(1).squeeze(1)
        )
        conv.bias.data += bias_shift
        conv.weight.data *= layer_norm.weight.data[None, :, None]
        layer_norm.bias.data.fill_(0.0)
        layer_norm.weight.data.fill_(1.0)

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write the projection layer (only) to a path or stream."""
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as stream:
                self.dump(stream)
            return
        if not hasattr(f, "write"):
            raise TypeError
        dump_layer(self.projection, f)
class PhoneExtractor(nn.Module):
    """Maps a waveform to frame-level phone features.

    Pipeline (see ``forward``): ``FeatureExtractor`` -> ``FeatureProjection``
    -> a 3-layer GRU whose output is added to the projected features ->
    ``ConvNeXtStack`` backbone -> weight-normalized 1x1 conv head.
    """

    def __init__(
        self,
        phone_channels: int = 256,
        hidden_channels: int = 256,
        backbone_embed_kernel_size: int = 7,
        kernel_size: int = 17,
        n_blocks: int = 8,
        cardinality: int = 256,
    ):
        super().__init__()
        self.feature_extractor = FeatureExtractor(hidden_channels)
        self.feature_projection = FeatureProjection(hidden_channels, hidden_channels)
        self.n_speaker_encoder_layers = 3
        self.speaker_encoder = nn.GRU(
            hidden_channels,
            hidden_channels,
            self.n_speaker_encoder_layers,
            batch_first=True,
        )
        # Apply weight norm to weight_ih_l{i} and weight_hh_l{i} of every GRU layer.
        for i in range(self.n_speaker_encoder_layers):
            for input_char in "ih":
                self.speaker_encoder = weight_norm(
                    self.speaker_encoder, f"weight_{input_char}h_l{i}"
                )
        self.backbone = ConvNeXtStack(
            in_channels=hidden_channels,
            channels=hidden_channels,
            intermediate_channels=hidden_channels * 3,
            n_blocks=n_blocks,
            delay=0,
            embed_kernel_size=backbone_embed_kernel_size,
            kernel_size=kernel_size,
        )
        self.head = weight_norm(nn.Conv1d(hidden_channels, phone_channels, 1))
        # NOTE(review): these embeddings are not used by any method of this
        # class and are not written by dump() — presumably used elsewhere.
        self.label_embeddings = nn.ModuleList(
            [
                nn.Embedding(cardinality, phone_channels),
                nn.Embedding(cardinality, phone_channels),
            ]
        )

    def forward(
        self, x: torch.Tensor, return_stats: bool = True
    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]:
        """Compute phone features for a batch of waveforms.

        Args:
            x: ``[batch_size, 1, wav_length]``.
            return_stats: When true, also return a dict with ``feature_norm``
                and ``code_norm``.

        Returns:
            ``phone`` (``[batch_size, phone_channels, length]``) alone, or the
            tuple ``(phone, stats)`` when ``return_stats`` is true.
        """
        # x: [batch_size, 1, wav_length]
        stats = {}
        # [batch_size, 1, wav_length] -> [batch_size, feature_extractor_hidden_channels, length]
        x = self.feature_extractor(x)
        if return_stats:
            # NOTE: unlike "code_norm" below, this value is stored as a 0-dim
            # tensor (no .item()).
            stats["feature_norm"] = x.detach().norm(dim=1).mean()
        # [batch_size, feature_extractor_hidden_channels, length] -> [batch_size, hidden_channels, length]
        x = self.feature_projection(x)
        # [batch_size, hidden_channels, length] -> [batch_size, length, hidden_channels]
        g, _ = self.speaker_encoder(x.transpose(1, 2))
        if self.training:
            # Training only: for every element of g, pick the value of the same
            # channel at a random frame between (t - shuffle_size) and t, where
            # shuffle_size is drawn per batch item from [0, 50).
            batch_size, length, _ = g.size()
            shuffle_sizes_for_each_data = torch.randint(
                0, 50, (batch_size,), device=g.device
            )
            max_indices = torch.arange(length, device=g.device)[None, :, None]
            min_indices = (
                max_indices - shuffle_sizes_for_each_data[:, None, None]
            ).clamp_(min=0)
            with torch.cuda.amp.autocast(False):
                indices = (
                    torch.rand(g.size(), device=g.device)
                    * (max_indices - min_indices + 1)
                ).long() + min_indices
            assert indices.min() >= 0, indices.min()
            assert indices.max() < length, (indices.max(), length)
            g = g.gather(1, indices)
        # [batch_size, length, hidden_channels] -> [batch_size, hidden_channels, length]
        g = g.transpose(1, 2).contiguous()
        # [batch_size, hidden_channels, length]
        x = self.backbone(x + g)
        # [batch_size, hidden_channels, length] -> [batch_size, phone_channels, length]
        phone = self.head(F.gelu(x, approximate="tanh"))
        results = [phone]
        if return_stats:
            stats["code_norm"] = phone.detach().norm(dim=1).mean().item()
            results.append(stats)
        if len(results) == 1:
            return results[0]
        return tuple(results)

    @torch.inference_mode()
    def units(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``forward`` under inference mode and return channel-last features."""
        # x: [batch_size, 1, wav_length]
        # [batch_size, 1, wav_length] -> [batch_size, phone_channels, length]
        phone = self.forward(x, return_stats=False)
        # [batch_size, phone_channels, length] -> [batch_size, length, phone_channels]
        phone = phone.transpose(1, 2)
        # [batch_size, length, phone_channels]
        return phone

    def remove_weight_norm(self):
        """Strip weight norm from the feature extractor, the GRU and the head."""
        self.feature_extractor.remove_weight_norm()
        for i in range(self.n_speaker_encoder_layers):
            for input_char in "ih":
                remove_weight_norm(self.speaker_encoder, f"weight_{input_char}h_l{i}")
        remove_weight_norm(self.head)

    def merge_weights(self):
        """Fold normalization parameters in the projection and the backbone."""
        self.feature_projection.merge_weights()
        self.backbone.merge_weights()

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write the sub-modules, in a fixed order, to a path or stream.

        Raises:
            TypeError: if ``f`` is neither a path nor an object with ``write``.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        dump_layer(self.feature_extractor, f)
        dump_layer(self.feature_projection, f)
        dump_layer(self.speaker_encoder, f)
        dump_layer(self.backbone, f)
        dump_layer(self.head, f)
# %% [markdown]
# ## Pitch Estimator
# %%
def extract_pitch_features(
    y: torch.Tensor,  # [..., wav_length]
    hop_length: int = 160,  # 10ms
    win_length: int = 560,  # 35ms
    max_corr_period: int = 256,  # 16ms, 62.5Hz (16000 / 256)
    corr_win_length: int = 304,  # 19ms
    instfreq_features_cutoff_bin: int = 64,  # 1828Hz (16000 * 64 / 560)
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute frame-level features used for pitch estimation.

    Args:
        y: Waveform, ``[..., wav_length]``.
        hop_length: Frame shift in samples.
        win_length: Frame length in samples; must equal
            ``max_corr_period + corr_win_length``.
        max_corr_period: Number of lags of the difference function.
        corr_win_length: Window length used for the difference function.
        instfreq_features_cutoff_bin: Number of low-frequency bins kept.

    Returns:
        ``(instfreq_features, corr_diff, energy)`` with shapes
        ``[..., instfreq_features_cutoff_bin * 3, n_frames]``,
        ``[..., max_corr_period, n_frames]`` and ``[..., 1, n_frames]``.
    """
    assert max_corr_period + corr_win_length == win_length
    # Pad both ends.
    padding_length = (win_length - hop_length) // 2
    y = F.pad(y, (padding_length, padding_length))
    # Split into frames.  NOTE: this is an overlapping *view* of ``y``; it must
    # never be modified in place.
    # [..., win_length, n_frames]
    y_frames = y.unfold(-1, win_length, hop_length).transpose_(-2, -1)
    # Complex spectrogram
    # Complex[..., (win_length // 2 + 1), n_frames]
    spec: torch.Tensor = torch.fft.rfft(y_frames, n=win_length, dim=-2)
    # Complex[..., instfreq_features_cutoff_bin, n_frames]
    spec = spec[..., :instfreq_features_cutoff_bin, :]
    # Log spectrogram (log10 of the magnitude)
    log_power_spec = spec.abs().add_(1e-5).log10_()
    # Frame-to-frame difference of the instantaneous phase;
    # the value at time 0 is 0.
    delta_spec = spec[..., :, 1:] * spec[..., :, :-1].conj()
    delta_spec /= delta_spec.abs().add_(1e-5)
    delta_spec = torch.cat(
        [torch.zeros_like(delta_spec[..., :, :1]), delta_spec], dim=-1
    )
    # [..., instfreq_features_cutoff_bin * 3, n_frames]
    instfreq_features = torch.cat(
        [log_power_spec, delta_spec.real, delta_spec.imag], dim=-2
    )
    # Autocorrelation
    # (Using the LPC residual instead would be worth trying some day.)
    # The original plan was to multiply this by 2.0 / corr_win_length, but the
    # value is proportional to the squared amplitude and no good way of
    # standardizing its variance for NN input came to mind, so that was dropped.
    flipped_y_frames = y_frames.flip((-2,))
    a = torch.fft.rfft(flipped_y_frames, n=win_length, dim=-2)
    b = torch.fft.rfft(y_frames[..., -corr_win_length:, :], n=win_length, dim=-2)
    # [..., max_corr_period, n_frames]
    corr = torch.fft.irfft(a * b, n=win_length, dim=-2)[..., corr_win_length:, :]
    # Energy terms (``flip`` returned a copy, so in-place ops are safe here)
    energy = flipped_y_frames.square_().cumsum_(-2)
    energy0 = energy[..., corr_win_length - 1 : corr_win_length, :]
    energy = energy[..., corr_win_length:, :] - energy[..., :-corr_win_length, :]
    # Difference function
    corr_diff = (energy0 + energy).sub_(corr.mul_(2.0))
    assert corr_diff.min() >= -1e-3, corr_diff.min()
    corr_diff.clamp_(min=0.0)  # guard against rounding error
    # Standardization
    corr_diff *= 2.0 / corr_win_length
    corr_diff.sqrt_()
    # Energy used as an input to the conversion model.
    # Fixed: the window used to be applied with in-place ``mul_``/``square_``
    # directly on ``y_frames``.  Because neighbouring frames share memory, each
    # sample was modified several times and the energy was corrupted.  The
    # multiplication is now out of place (the product is a fresh tensor, so the
    # subsequent in-place square is safe).
    window = torch.signal.windows.cosine(win_length, device=y.device)[..., None]
    energy = (y_frames * window).square_().sum(-2, keepdim=True)
    energy.clamp_(min=1e-3).log10_()  # >= -3; about 2.15 for a sine wave of amplitude 1
    energy *= 0.5  # >= -1.5; about 1.07 for a sine wave of amplitude 1; a difference of 1 is 20 dB in amplitude
    return (
        instfreq_features,  # [..., instfreq_features_cutoff_bin * 3, n_frames]
        corr_diff,  # [..., max_corr_period, n_frames]
        energy,  # [..., 1, n_frames]
    )
class PitchEstimator(nn.Module):
    """Predicts per-frame pitch logits (plus an energy feature) from a waveform.

    Two feature streams from ``extract_pitch_features`` are each embedded with
    two 1x1 convolutions, summed, and passed through a ``ConvNeXtStack``
    backbone and a 1x1 conv head with ``pitch_channels`` outputs.
    """

    def __init__(
        self,
        input_instfreq_channels: int = 192,
        input_corr_channels: int = 256,
        pitch_channels: int = 384,
        channels: int = 192,
        intermediate_channels: int = 192 * 3,
        n_blocks: int = 6,
        delay: int = 1,  # 10 ms; 22.5 ms together with the feature extraction
        embed_kernel_size: int = 3,
        kernel_size: int = 33,
        bins_per_octave: int = 96,
    ):
        super().__init__()
        self.bins_per_octave = bins_per_octave
        self.instfreq_embed_0 = nn.Conv1d(input_instfreq_channels, channels, 1)
        self.instfreq_embed_1 = nn.Conv1d(channels, channels, 1)
        self.corr_embed_0 = nn.Conv1d(input_corr_channels, channels, 1)
        self.corr_embed_1 = nn.Conv1d(channels, channels, 1)
        self.backbone = ConvNeXtStack(
            channels,
            channels,
            intermediate_channels,
            n_blocks,
            delay,
            embed_kernel_size,
            kernel_size,
        )
        self.head = nn.Conv1d(channels, pitch_channels, 1)

    def forward(self, wav: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(pitch_logits, energy)`` for ``wav``.

        Args:
            wav: ``[batch_size, 1, wav_length]``.

        Returns:
            ``[batch_size, pitch_channels, length]`` and
            ``[batch_size, 1, length]``.
        """
        # wav: [batch_size, 1, wav_length]
        # [batch_size, input_instfreq_channels, length],
        # [batch_size, input_corr_channels, length]
        # Feature extraction runs with CUDA autocast disabled.
        with torch.cuda.amp.autocast(False):
            instfreq_features, corr_diff, energy = extract_pitch_features(
                wav.squeeze(1),
                hop_length=160,
                win_length=560,
                max_corr_period=256,
                corr_win_length=304,
                instfreq_features_cutoff_bin=64,
            )
        instfreq_features = F.gelu(
            self.instfreq_embed_0(instfreq_features), approximate="tanh"
        )
        instfreq_features = self.instfreq_embed_1(instfreq_features)
        corr_diff = F.gelu(self.corr_embed_0(corr_diff), approximate="tanh")
        corr_diff = self.corr_embed_1(corr_diff)
        # [batch_size, channels, length]
        x = instfreq_features + corr_diff  # NOTE (original author): an activation function was forgotten here
        x = self.backbone(x)
        # [batch_size, pitch_channels, length]
        x = self.head(x)
        return x, energy

    def sample_pitch(
        self, pitch: torch.Tensor, band_width: int = 48, return_features: bool = False
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Pick one pitch bin per frame from pitch logits.

        The softmax over channels is summed over sliding bands of
        ``band_width`` consecutive bins; the band with the largest sum is
        selected and the most probable bin inside that band is returned.
        Channel 0 (treated as "unvoiced") is excluded from the selection.

        Args:
            pitch: Logits, ``[batch_size, pitch_channels, length]``.
            band_width: Width of the sliding band in bins.
            return_features: When true, also return
                ``[batch_size, 3, length]`` features: the channel-0
                probability, and the ratios of the band sums located
                ``bins_per_octave`` bins below / above the selected band to
                the selected band's sum (0 where that band would fall outside
                the valid range).

        Returns:
            ``Long[batch_size, length]`` bin indices, optionally with the
            feature tensor.
        """
        # pitch: [batch_size, pitch_channels, length]
        # The returned pitch values never include 0.
        batch_size, pitch_channels, length = pitch.size()
        pitch = pitch.softmax(1)
        if return_features:
            unvoiced_proba = pitch[:, :1, :].clone()
        # Large negative value so that no band containing channel 0 can win.
        pitch[:, 0, :] = -100.0
        pitch = (
            pitch.transpose(1, 2)
            .contiguous()
            .view(batch_size * length, 1, pitch_channels)
        )
        # Sliding-window sum over the channel axis.
        band_pitch = F.conv1d(
            pitch,
            torch.ones((1, 1, 1), device=pitch.device).expand(1, 1, band_width),
        )
        # [batch_size * length, 1, pitch_channels - band_width + 1] -> Long[batch_size * length, 1]
        quantized_band_pitch = band_pitch.argmax(2)
        if return_features:
            # [batch_size * length, 1]
            band_proba = band_pitch.gather(2, quantized_band_pitch[:, :, None])
            # [batch_size * length, 1]
            half_pitch_band_proba = band_pitch.gather(
                2,
                (quantized_band_pitch - self.bins_per_octave).clamp_(min=1)[:, :, None],
            )
            half_pitch_band_proba[quantized_band_pitch <= self.bins_per_octave] = 0.0
            half_pitch_proba = (half_pitch_band_proba / (band_proba + 1e-6)).view(
                batch_size, 1, length
            )
            # [batch_size * length, 1]
            double_pitch_band_proba = band_pitch.gather(
                2,
                (quantized_band_pitch + self.bins_per_octave).clamp_(
                    max=pitch_channels - band_width
                )[:, :, None],
            )
            double_pitch_band_proba[
                quantized_band_pitch
                > pitch_channels - band_width - self.bins_per_octave
            ] = 0.0
            double_pitch_proba = (double_pitch_band_proba / (band_proba + 1e-6)).view(
                batch_size, 1, length
            )
        # Long[1, pitch_channels]
        mask = torch.arange(pitch_channels, device=pitch.device)[None, :]
        # bool[batch_size * length, pitch_channels]
        # True for the bins that lie inside the selected band.
        mask = (quantized_band_pitch <= mask) & (
            mask < quantized_band_pitch + band_width
        )
        # Long[batch_size, length]
        quantized_pitch = (pitch.squeeze(1) * mask).argmax(1).view(batch_size, length)
        if return_features:
            features = torch.cat(
                [unvoiced_proba, half_pitch_proba, double_pitch_proba], dim=1
            )
            # Long[batch_size, length], [batch_size, 3, length]
            return quantized_pitch, features
        else:
            return quantized_pitch

    def merge_weights(self):
        """Fold normalization parameters inside the backbone."""
        self.backbone.merge_weights()

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write all layers, in a fixed order, to a path or stream.

        Raises:
            TypeError: if ``f`` is neither a path nor an object with ``write``.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError
        dump_layer(self.instfreq_embed_0, f)
        dump_layer(self.instfreq_embed_1, f)
        dump_layer(self.corr_embed_0, f)
        dump_layer(self.corr_embed_1, f)
        dump_layer(self.backbone, f)
        dump_layer(self.head, f)
# %% [markdown]
# ## Vocoder
# %%
def overlap_add(
    ir: torch.Tensor,
    pitch: torch.Tensor,
    hop_length: int = 240,
    delay: int = 0,
) -> torch.Tensor:
    """Overlap-add per-frame impulse responses at pitch marks.

    A phase in [0, 1) is accumulated from ``pitch``; every sample at which the
    phase wraps around is a pitch mark.  The impulse response of the frame that
    contains the mark is shifted by the sub-sample position of the wrap
    (as a phase rotation in the frequency domain) and added to the output.

    Args:
        ir: ``[batch_size, ir_length, length]``.
        pitch: ``[batch_size, length * hop_length]``; divided by 24000.0 to get
            the per-sample phase increment (presumably Hz at a 24 kHz sample
            rate — confirm against the caller).
        hop_length: Number of samples per frame of ``ir``.
        delay: Offset, in samples, of the returned window; must satisfy
            ``0 <= delay < ir_length``.

    Returns:
        ``[batch_size, length * hop_length]``.  The result is not deterministic
        because the initial phase is drawn at random.
    """
    # print("ir, pitch: ", ir.dtype, pitch.dtype)
    batch_size, ir_length, length = ir.size()
    assert pitch.size() == (batch_size, length * hop_length)
    assert 0 <= delay < ir_length, (delay, ir_length)
    # Phase is represented in [0, 1).
    normalized_freq = pitch / 24000.0
    # Set the initial phase at random.
    normalized_freq[:, 0] = torch.rand(batch_size, device=pitch.device)
    # The phase is accumulated in float64 with autocast disabled.
    with torch.cuda.amp.autocast(enabled=False):
        phase = (normalized_freq.double().cumsum_(1) % 1.0).float()
    # Find where to place the impulse responses (where the phase wraps around).
    # [n_pitchmarks], [n_pitchmarks]
    indices0, indices1 = torch.nonzero(phase[:, :-1] > phase[:, 1:], as_tuple=True)
    # Find the fractional part (phase delay) of each position.
    numer = 1.0 - phase[indices0, indices1]
    # [n_pitchmarks]
    fractional_part = numer / (numer + phase[indices0, indices1 + 1])
    # Look up the values to add.
    # [n_pitchmarks, ir_length]
    values = ir[indices0, :, indices1 // hop_length]
    # Delay the phase,
    # assuming ``values`` is in the time domain.
    # Complex[n_pitchmarks, ir_length / 2 + 1]
    values = torch.fft.rfft(values, n=ir_length, dim=1)
    # Amount of phase delay
    # [n_pitchmarks, ir_length / 2 + 1]
    delay_phase = (
        torch.arange(ir_length // 2 + 1, device=pitch.device, dtype=torch.float32)[
            None, :
        ]
        / -ir_length
        * fractional_part[:, None]
    )
    # Complex[n_pitchmarks, ir_length / 2 + 1]
    delay_phase = torch.polar(torch.ones_like(delay_phase), delay_phase * math.tau)
    # values *= delay_phase
    values = values * delay_phase
    # [n_pitchmarks, ir_length]
    values = torch.fft.irfft(values, n=ir_length, dim=1)
    # Flatten the values to add down to individual samples.
    # [n_pitchmarks * ir_length]
    values = values.ravel()
    # Long[n_pitchmarks * ir_length]
    indices0 = indices0[:, None].expand(-1, ir_length).ravel()
    # Long[n_pitchmarks * ir_length]
    indices1 = (
        indices1[:, None] + torch.arange(ir_length, device=pitch.device)
    ).ravel()
    # Perform the overlap-add; the buffer has ir_length extra samples so that
    # responses starting near the end still fit.
    overlap_added_signal = torch.zeros(
        (batch_size, length * hop_length + ir_length), device=pitch.device
    )
    # print("overlap_added_signal, values: ", overlap_added_signal.dtype, values.dtype)
    overlap_added_signal.index_put_((indices0, indices1), values, accumulate=True)
    # Keep length * hop_length samples starting at ``delay``.
    overlap_added_signal = overlap_added_signal[:, delay : -ir_length + delay]
    # Convolving overlapped sincs with ir might have needed fewer FFTs.
    return overlap_added_signal
def generate_noise(aperiodicity: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Generate uniform noise spectrally shaped by ``aperiodicity``.

    Args:
        aperiodicity: ``[batch_size, hop_length, length]`` per-frame gains for
            frequency bins ``1 .. hop_length`` (the DC bin is zeroed).

    Returns:
        ``(noise, excitation)``, both ``[batch_size, length * hop_length]``:
        the shaped noise and the raw uniform excitation it was derived from.
    """
    # aperiodicity: [batch_size, hop_length, length]
    batch_size, hop_length, length = aperiodicity.size()
    frame_length = 2 * hop_length
    total_length = (length + 1) * hop_length

    # Uniform excitation in [-0.5, 0.5).
    excitation = torch.rand(batch_size, total_length, device=aperiodicity.device)
    excitation -= 0.5

    # Analysis with a rectangular window.
    # Complex[batch_size, hop_length + 1, length]
    noise = torch.stft(
        excitation,
        n_fft=frame_length,
        hop_length=hop_length,
        center=False,
        return_complex=True,
    )
    assert noise.size(2) == aperiodicity.size(2), (
        noise.size(),
        aperiodicity.size(),
    )
    noise[:, 0, :] = 0.0
    noise[:, 1:, :] *= aperiodicity

    # Synthesis with a Hann window.  Note that torch.istft cannot be used here
    # because it applies the optimal synthesis window.
    # [batch_size, 2 * hop_length, length]
    noise = torch.fft.irfft(noise, n=frame_length, dim=1)
    synthesis_window = torch.hann_window(frame_length, device=noise.device)
    noise *= synthesis_window[None, :, None]

    # Overlap-add the frames.
    # [batch_size, (length + 1) * hop_length]
    noise = F.fold(
        noise,
        (1, total_length),
        (1, frame_length),
        stride=(1, hop_length),
    ).squeeze_((1, 2))

    # Drop half a hop at each end of both signals.
    head = hop_length // 2
    tail = -hop_length // 2
    noise = noise[:, head:tail]
    excitation = excitation[:, head:tail]
    return noise, excitation  # [batch_size, length * hop_length]
class GradientEqualizerFunction(torch.autograd.Function):
    """Compensate for gradients growing as the signal norm shrinks.

    The forward pass is the identity; the backward pass multiplies the incoming
    gradient by ``sqrt(2) * rms + 0.1``, where ``rms`` is the root mean square
    of the forward input over its last axis.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, 1, length]
        mean_square = x.square().mean(dim=2, keepdim=True)
        ctx.save_for_backward(mean_square.sqrt_())
        return x

    @staticmethod
    def backward(ctx, dx: torch.Tensor) -> torch.Tensor:
        # dx: [batch_size, 1, length]
        (rms,) = ctx.saved_tensors
        gain = math.sqrt(2.0) * rms + 0.1
        return dx * gain
class PseudoDDSPVocoder(nn.Module):
    """Vocoder that sums a pitch-synchronous periodic part and shaped noise.

    A pre-net feeds two generator branches: one produces per-frame impulse
    responses (512 taps) that are overlap-added at pitch marks, the other
    produces per-frame aperiodicity used to shape uniform noise.
    """

    def __init__(
        self,
        channels: int,
        hop_length: int = 240,
        n_pre_blocks: int = 4,
    ):
        super().__init__()
        self.hop_length = hop_length
        self.prenet = ConvNeXtStack(
            in_channels=channels,
            channels=channels,
            intermediate_channels=channels * 3,
            n_blocks=n_pre_blocks,
            delay=2,  # 20 ms delay
            embed_kernel_size=7,
            kernel_size=33,
        )
        # Both generator branches share the same architecture.
        generator_config = dict(
            in_channels=channels,
            channels=channels,
            intermediate_channels=channels * 3,
            n_blocks=2,
            delay=0,
            embed_kernel_size=3,
            kernel_size=33,
            use_weight_norm=True,
        )
        self.ir_generator = ConvNeXtStack(**generator_config)
        self.ir_generator_post = weight_norm(nn.Conv1d(channels, 512, 1, bias=False))
        self.aperiodicity_generator = ConvNeXtStack(**generator_config)
        self.aperiodicity_generator_post = weight_norm(
            nn.Conv1d(channels, hop_length, 1, bias=False)
        )

    def forward(
        self, x: torch.Tensor, pitch: torch.Tensor
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Synthesize a waveform.

        Args:
            x: ``[batch_size, channels, length]``.
            pitch: ``[batch_size, length]``.

        Returns:
            The waveform ``[batch_size, 1, length * hop_length]`` and a dict of
            detached intermediate signals (``periodic_signal``,
            ``aperiodic_signal``, ``noise_excitation``).
        """
        hidden = self.prenet(x)

        # Periodic branch.
        # [batch_size, 512, length]
        ir = self.ir_generator_post(F.elu(self.ir_generator(hidden), inplace=True))
        # Nearest-neighbour upsampling of the pitch to sample resolution.
        # [batch_size, length * hop_length]
        upsampled_pitch = torch.repeat_interleave(pitch, self.hop_length, dim=1)
        # [batch_size, length * hop_length]
        periodic_signal = overlap_add(ir, upsampled_pitch, self.hop_length, delay=120)

        # Aperiodic branch.
        # [batch_size, hop_length, length]
        aperiodicity = self.aperiodicity_generator_post(
            F.elu(self.aperiodicity_generator(hidden), inplace=True)
        )
        # [batch_size, length * hop_length], [batch_size, length * hop_length]
        aperiodic_signal, noise_excitation = generate_noise(aperiodicity)

        # [batch_size, 1, length * hop_length]
        mixed = (periodic_signal + aperiodic_signal)[:, None, :]
        y_g_hat = GradientEqualizerFunction.apply(mixed)
        intermediates = {
            "periodic_signal": periodic_signal.detach(),
            "aperiodic_signal": aperiodic_signal.detach(),
            "noise_excitation": noise_excitation.detach(),
        }
        return y_g_hat, intermediates

    def remove_weight_norm(self):
        """Strip weight normalization from every sub-module."""
        self.prenet.remove_weight_norm()
        self.ir_generator.remove_weight_norm()
        remove_weight_norm(self.ir_generator_post)
        self.aperiodicity_generator.remove_weight_norm()
        remove_weight_norm(self.aperiodicity_generator_post)

    def merge_weights(self):
        """Call ``merge_weights`` on the three ``ConvNeXtStack`` sub-modules."""
        for stack in (self.prenet, self.ir_generator, self.aperiodicity_generator):
            stack.merge_weights()

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write all sub-modules, in a fixed order, to a path or stream."""
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as stream:
                self.dump(stream)
            return
        if not hasattr(f, "write"):
            raise TypeError
        ordered_layers = (
            self.prenet,
            self.ir_generator,
            self.ir_generator_post,
            self.aperiodicity_generator,
            self.aperiodicity_generator_post,
        )
        for layer in ordered_layers:
            dump_layer(layer, f)
def slice_segments(
    x: torch.Tensor, start_indices: torch.Tensor, segment_length: int
) -> torch.Tensor:
    """Cut one fixed-length segment per batch item along the last axis.

    Args:
        x: ``[batch_size, channels, length]``.
        start_indices: ``Long[batch_size]`` first index of each segment.
        segment_length: Number of samples per segment.

    Returns:
        ``[batch_size, channels, segment_length]``.
    """
    batch_size, channels, _ = x.size()
    offsets = torch.arange(segment_length, device=start_indices.device)
    # [batch_size, 1, segment_length]
    positions = start_indices.unsqueeze(1).unsqueeze(2) + offsets
    # [batch_size, channels, segment_length]
    positions = positions.expand(batch_size, channels, segment_length)
    return torch.gather(x, 2, positions)
class ConverterNetwork(nn.Module):
    """Generator that converts a waveform to a target speaker.

    Two frozen front-ends (``phone_extractor`` and ``pitch_estimator``) turn
    the input waveform into phone units, pitch logits and energy. These are
    embedded, summed with speaker and formant-shift embeddings, and fed to a
    ``PseudoDDSPVocoder`` that synthesizes the output waveform.
    """

    def __init__(
        self,
        phone_extractor: PhoneExtractor,
        pitch_estimator: PitchEstimator,
        n_speakers: int,
        hidden_channels: int,
    ):
        """Build the embeddings, the vocoder and the mel-spectrogram transform.

        Args:
            phone_extractor: Front-end; put in eval mode with gradients disabled.
            pitch_estimator: Front-end; put in eval mode with gradients disabled.
            n_speakers: Number of rows of the speaker embedding table.
            hidden_channels: Channel width shared by all embeddings and the vocoder.
        """
        super().__init__()
        # Held in a plain dict, so nn.Module does not register them as
        # submodules (presumably to keep them out of parameters()/state_dict
        # -- confirm against the training loop).
        self.frozen_modules = {
            "phone_extractor": phone_extractor.eval().requires_grad_(False),
            "pitch_estimator": pitch_estimator.eval().requires_grad_(False),
        }
        self.embed_phone = nn.Conv1d(256, hidden_channels, 1)
        self.embed_quantized_pitch = nn.Embedding(384, hidden_channels)
        # Fixed sinusoidal table: even columns hold sin, odd columns hold cos.
        phase = (
            torch.arange(384, dtype=torch.float)[:, None]
            * (
                torch.arange(0, hidden_channels, 2, dtype=torch.float)
                * (-math.log(10000.0) / hidden_channels)
            ).exp_()
        )
        self.embed_quantized_pitch.weight.data[:, 0::2] = phase.sin()
        self.embed_quantized_pitch.weight.data[:, 1::2] = phase.cos_()
        self.embed_quantized_pitch.weight.requires_grad_(False)
        self.embed_pitch_features = nn.Conv1d(4, hidden_channels, 1)
        self.embed_speaker = nn.Embedding(n_speakers, hidden_channels)
        self.embed_formant_shift = nn.Embedding(9, hidden_channels)
        self.vocoder = PseudoDDSPVocoder(
            channels=hidden_channels,
            hop_length=240,
            n_pre_blocks=4,
        )
        self.melspectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=24000,
            n_fft=1024,
            win_length=720,
            hop_length=128,
            n_mels=80,
            power=2,  # may be a cause of instability
            norm="slaney",
            mel_scale="slaney",
        )

    def _get_resampler(
        self, orig_freq, new_freq, device, cache={}
    ) -> torchaudio.transforms.Resample:
        """Return a memoized ``Resample(orig_freq, new_freq)`` living on ``device``.

        The mutable default ``cache`` is intentional: it is the memo shared by
        every call. The device is part of the key so that a resampler created
        for one device is never handed back for another.
        """
        key = orig_freq, new_freq, str(device)
        resampler = cache.get(key)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_freq, new_freq).to(device)
            cache[key] = resampler
        return resampler

    def forward(
        self,
        x: torch.Tensor,
        target_speaker_id: torch.Tensor,
        formant_shift_semitone: torch.Tensor,
        pitch_shift_semitone: Optional[torch.Tensor] = None,
        slice_start_indices: Optional[torch.Tensor] = None,
        slice_segment_length: Optional[int] = None,
        return_stats: bool = False,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]:
        """Synthesize the converted waveform for ``x``.

        Args:
            x: [batch_size, 1, wav_length] input waveform.
            target_speaker_id: Long[batch_size] rows of the speaker embedding.
            formant_shift_semitone: [batch_size]; mapped to an embedding index
                with ``round((value + 2) * 2)``.
            pitch_shift_semitone: Optional [batch_size] shift added to the
                quantized pitch bins; bins equal to 0 are left unchanged.
            slice_start_indices: Optional [batch_size]; when given, only a
                window of the embedded features is passed to the vocoder.
            slice_segment_length: Window length; required together with
                ``slice_start_indices``.
            return_stats: If true, also return the vocoder's stats dict.

        Returns:
            The vocoder output, or ``(output, stats)`` when ``return_stats``.
        """
        # x: [batch_size, 1, wav_length]
        # target_speaker_id: Long[batch_size]
        # formant_shift_semitone: [batch_size]
        # pitch_shift_semitone: [batch_size]
        # slice_start_indices: [batch_size]

        batch_size, _, _ = x.size()

        with torch.inference_mode():
            phone_extractor: PhoneExtractor = self.frozen_modules["phone_extractor"]
            pitch_estimator: PitchEstimator = self.frozen_modules["pitch_estimator"]
            # [batch_size, 1, wav_length] -> [batch_size, phone_channels, length]
            phone = phone_extractor.units(x).transpose(1, 2)
            # [batch_size, 1, wav_length] -> [batch_size, pitch_channels, length], [batch_size, 1, length]
            pitch, energy = pitch_estimator(x)

            # augmentation: re-estimate pitch/energy on a randomly
            # pitch-shifted copy of each item and map the result back
            if self.training:
                # [batch_size, pitch_channels - 1]
                weights = pitch.softmax(1)[:, 1:, :].mean(2)
                # [batch_size]
                mean_pitch = (
                    weights * torch.arange(1, 384, device=weights.device)
                ).sum(1) / weights.sum(1)
                mean_pitch = mean_pitch.round_().long()
                target_pitch = torch.randint_like(mean_pitch, 64, 257)
                shift = target_pitch - mean_pitch
                shift_ratio = (
                    2.0 ** (shift.float() / pitch_estimator.bins_per_octave)
                ).tolist()
                shift = []
                interval_length = 100  # 1s
                interval_zeros = torch.zeros(
                    (1, 1, interval_length * 160), device=x.device
                )
                concatenated_shifted_x = []
                offsets = [0]
                for i in range(batch_size):
                    shift_ratio_i = shift_ratio[i]
                    # Approximate the ratio by a small fraction so it can be
                    # expressed as an integer-rate Resample; the shift in bins
                    # is then recomputed from the rounded ratio.
                    shift_ratio_fraction_i = Fraction.from_float(
                        shift_ratio_i
                    ).limit_denominator(30)
                    shift_numer_i = shift_ratio_fraction_i.numerator
                    shift_denom_i = shift_ratio_fraction_i.denominator
                    shift_ratio_i = shift_numer_i / shift_denom_i
                    shift_i = int(
                        round(
                            math.log2(shift_ratio_i) * pitch_estimator.bins_per_octave
                        )
                    )
                    shift.append(shift_i)
                    shift_ratio[i] = shift_ratio_i
                    # [1, 1, wav_length / shift_ratio]
                    with torch.cuda.amp.autocast(False):
                        shifted_x_i = self._get_resampler(
                            shift_numer_i, shift_denom_i, x.device
                        )(x[i])[None]
                    # Pad to a whole number of 160-sample frames.
                    if shifted_x_i.size(2) % 160 != 0:
                        shifted_x_i = F.pad(
                            shifted_x_i,
                            (0, 160 - shifted_x_i.size(2) % 160),
                            mode="reflect",
                        )
                    assert shifted_x_i.size(2) % 160 == 0
                    offsets.append(
                        offsets[-1] + interval_length + shifted_x_i.size(2) // 160
                    )
                    concatenated_shifted_x.extend([interval_zeros, shifted_x_i])
                if offsets[-1] % 256 != 0:
                    # Equal lengths seem to let some cache kick in and run
                    # faster, so pad to a multiple of 256 frames to reduce
                    # the number of distinct lengths.
                    concatenated_shifted_x.append(
                        torch.zeros(
                            (1, 1, (256 - offsets[-1] % 256) * 160), device=x.device
                        )
                    )
                # [batch_size, 1, sum(wav_length) + batch_size * 16000]
                concatenated_shifted_x = torch.cat(concatenated_shifted_x, dim=2)
                assert concatenated_shifted_x.size(2) % (256 * 160) == 0
                # [1, pitch_channels, length / shift_ratio], [1, 1, length / shift_ratio]
                concatenated_pitch, concatenated_energy = pitch_estimator(
                    concatenated_shifted_x
                )
                for i in range(batch_size):
                    shift_i = shift[i]
                    shift_ratio_i = shift_ratio[i]
                    # Skip the silent interval that precedes item i.
                    left = offsets[i] + interval_length
                    right = offsets[i + 1]
                    pitch_i = concatenated_pitch[:, :, left:right]
                    energy_i = concatenated_energy[:, :, left:right]
                    # Stretch back to the original time scale.
                    pitch_i = F.interpolate(
                        pitch_i,
                        scale_factor=shift_ratio_i,
                        mode="linear",
                        align_corners=False,
                    )
                    energy_i = F.interpolate(
                        energy_i,
                        scale_factor=shift_ratio_i,
                        mode="linear",
                        align_corners=False,
                    )
                    assert pitch_i.size(2) == energy_i.size(2)
                    assert abs(pitch_i.size(2) - pitch.size(2)) <= 10
                    length = min(pitch_i.size(2), pitch.size(2))

                    # Undo the shift along the bin axis; bin 0 is copied as
                    # is and bins with no source are filled with -10.0.
                    if shift_i > 0:
                        pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length]
                        pitch[i : i + 1, 1:-shift_i, :length] = pitch_i[
                            :, 1 + shift_i :, :length
                        ]
                        pitch[i : i + 1, -shift_i:, :length] = -10.0
                    elif shift_i < 0:
                        pitch[i : i + 1, :1, :length] = pitch_i[:, :1, :length]
                        pitch[i : i + 1, 1 : 1 - shift_i, :length] = -10.0
                        pitch[i : i + 1, 1 - shift_i :, :length] = pitch_i[
                            :, 1:shift_i, :length
                        ]
                    energy[i : i + 1, :, :length] = energy_i[:, :, :length]

            # [batch_size, pitch_channels, length] -> Long[batch_size, length], [batch_size, 3, length]
            quantized_pitch, pitch_features = pitch_estimator.sample_pitch(
                pitch, return_features=True
            )
            if pitch_shift_semitone is not None:
                quantized_pitch = torch.where(
                    quantized_pitch == 0,
                    quantized_pitch,
                    (
                        quantized_pitch
                        + (
                            pitch_shift_semitone[:, None]
                            * (pitch_estimator.bins_per_octave / 12)
                        )
                        .round_()
                        .long()
                    ).clamp_(1, 383),
                )
            pitch = 55.0 * 2.0 ** (
                quantized_pitch.float() / pitch_estimator.bins_per_octave
            )
            # phone looks ahead by 2.5 ms, whereas energy looks ahead by
            # 12.5 ms and pitch_features by 22.5 ms, so shift them to line
            # up with phone.
            energy = F.pad(energy[:, :, :-1], (1, 0), mode="reflect")
            quantized_pitch = F.pad(quantized_pitch[:, :-2], (2, 0), mode="reflect")
            pitch_features = F.pad(pitch_features[:, :, :-2], (2, 0), mode="reflect")
            # [batch_size, 1, length], [batch_size, 3, length] -> [batch_size, 4, length]
            pitch_features = torch.cat([energy, pitch_features], dim=1)
            formant_shift_indices = (
                ((formant_shift_semitone + 2.0) * 2.0).round_().long()
            )

        # Copies made outside inference mode, before the tensors are used by
        # the trainable layers below.
        phone = phone.clone()
        quantized_pitch = quantized_pitch.clone()
        pitch_features = pitch_features.clone()
        formant_shift_indices = formant_shift_indices.clone()
        pitch = pitch.clone()

        # [batch_size, hidden_channels, length]
        x = (
            self.embed_phone(phone)
            + self.embed_quantized_pitch(quantized_pitch).transpose(1, 2)
            + self.embed_pitch_features(pitch_features)
            + (
                self.embed_speaker(target_speaker_id)[:, :, None]
                + self.embed_formant_shift(formant_shift_indices)[:, :, None]
            )
        )
        if slice_start_indices is not None:
            assert slice_segment_length is not None
            # [batch_size, hidden_channels, length] -> [batch_size, hidden_channels, segment_length]
            x = slice_segments(x, slice_start_indices, slice_segment_length)
        x = F.silu(x, inplace=True)
        # [batch_size, hidden_channels, segment_length] -> [batch_size, 1, segment_length * 240]
        y_g_hat, stats = self.vocoder(x, pitch)
        if return_stats:
            return y_g_hat, stats
        else:
            return y_g_hat

    def _normalize_melsp(self, x):
        """Return ``0.5 * log(x)`` clamped from below at ``log(1e-5)``."""
        return x.log().mul(0.5).clamp_(min=math.log(1e-5))

    def forward_and_compute_loss(
        self,
        noisy_wavs_16k: torch.Tensor,
        target_speaker_id: torch.Tensor,
        formant_shift_semitone: torch.Tensor,
        slice_start_indices: torch.Tensor,
        slice_segment_length: int,
        y_all: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the generator on the full input and compute the mel L1 loss.

        The mel loss is taken over the full-length output; the returned ``y``
        and ``y_hat`` are the windows selected by ``slice_start_indices``
        (in units of 240 output samples).

        Returns:
            ``(y, y_hat, y_hat_all, loss_mel)``.
        """
        # noisy_wavs_16k: [batch_size, 1, wav_length]
        # target_speaker_id: Long[batch_size]
        # formant_shift_semitone: [batch_size]
        # slice_start_indices: [batch_size]
        # slice_segment_length: int
        # y_all: [batch_size, 1, wav_length]

        # [batch_size, 1, wav_length] -> [batch_size, 1, wav_length * 240]
        y_hat_all, stats = self(
            noisy_wavs_16k,
            target_speaker_id,
            formant_shift_semitone,
            return_stats=True,
        )

        with torch.cuda.amp.autocast(False):
            melsp_periodic_signal = self.melspectrogram(
                stats["periodic_signal"].float()
            )
            melsp_aperiodic_signal = self.melspectrogram(
                stats["aperiodic_signal"].float()
            )
            melsp_noise_excitation = self.melspectrogram(
                stats["noise_excitation"].float()
            )
            # [1, n_mels, 1]
            # 1/6 ... mean power of uniform random numbers in [-0.5, 0.5]
            # 3/8 ... power attenuation caused by the Hann window
            # 0.5 ... unexplained
            reference_melsp = self.melspectrogram.mel_scale(
                torch.full(
                    (1, self.melspectrogram.n_fft // 2 + 1, 1),
                    (1 / 6) * (3 / 8) * 0.5 * self.melspectrogram.win_length,
                    device=noisy_wavs_16k.device,
                )
            )
            aperiodic_ratio = melsp_aperiodic_signal / (
                melsp_periodic_signal + melsp_aperiodic_signal + 1e-5
            )
            compensation_ratio = reference_melsp / (melsp_noise_excitation + 1e-5)

            melsp_y_hat = self.melspectrogram(y_hat_all.float().squeeze(1))
            # Rescale the aperiodic share of each bin by the ratio between the
            # reference noise spectrum and the actual noise excitation.
            melsp_y_hat = melsp_y_hat * (
                (1.0 - aperiodic_ratio) + aperiodic_ratio * compensation_ratio
            )

            y_hat_mel = self._normalize_melsp(melsp_y_hat)
            # [batch_size, 1, wav_length] -> [batch_size, 1, wav_length * 240]
            y_hat = slice_segments(
                y_hat_all, slice_start_indices * 240, slice_segment_length * 240
            )

            y_mel = self._normalize_melsp(self.melspectrogram(y_all.squeeze(1)))
            # [batch_size, 1, wav_length] -> [batch_size, 1, wav_length * 240]
            y = slice_segments(
                y_all, slice_start_indices * 240, slice_segment_length * 240
            )

            loss_mel = F.l1_loss(y_hat_mel, y_mel)

        return y, y_hat, y_hat_all, loss_mel

    def remove_weight_norm(self):
        """Delegate to the vocoder's ``remove_weight_norm``."""
        self.vocoder.remove_weight_norm()

    def merge_weights(self):
        """Delegate to the vocoder's ``merge_weights``."""
        self.vocoder.merge_weights()

    def dump(self, f: Union[BinaryIO, str, bytes, os.PathLike]):
        """Write the phone, quantized-pitch and pitch-feature embeddings and the vocoder.

        Args:
            f: A writable binary file object, or a path that is opened in
                ``"wb"`` mode.

        Raises:
            TypeError: If ``f`` is neither a path nor an object with ``write``.

        The speaker and formant-shift embeddings are not written here.
        """
        if isinstance(f, (str, bytes, os.PathLike)):
            with open(f, "wb") as f:
                self.dump(f)
            return
        if not hasattr(f, "write"):
            raise TypeError

        dump_layer(self.embed_phone, f)
        dump_layer(self.embed_quantized_pitch, f)
        dump_layer(self.embed_pitch_features, f)
        dump_layer(self.vocoder, f)
# Discriminator
def _normalize(tensor: torch.Tensor, dim: int) -> torch.Tensor:
denom = tensor.norm(p=2.0, dim=dim, keepdim=True).clamp_min(1e-6)
return tensor / denom
class SANConv2d(nn.Conv2d):
    """2-D convolution whose weight is split into a unit direction and a scale.

    ``weight`` stores the kernel normalized per output channel and ``scale``
    stores the per-output-channel norm. The bias has ``in_channels`` entries
    and is added to the input before the convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple[int, int]],
        stride: Union[int, tuple[int, int]] = 1,
        padding: Union[int, tuple[int, int]] = 0,
        dilation: Union[int, tuple[int, int]] = 1,
        bias: bool = True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        """Initialize like ``nn.Conv2d`` (``groups=1``), then re-parameterize.

        ``kernel_size``, ``stride``, ``padding`` and ``dilation`` are passed
        to ``nn.Conv2d`` unchanged, so ints and 2-tuples are both accepted.
        """
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=padding,
            dilation=dilation,
            groups=1,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        # Split the default-initialized kernel into direction and magnitude.
        scale = self.weight.norm(p=2.0, dim=[1, 2, 3], keepdim=True).clamp_min(1e-6)
        self.weight = nn.parameter.Parameter(self.weight / scale.expand_as(self.weight))
        self.scale = nn.parameter.Parameter(scale.view(out_channels))
        if bias:
            # Replaces the out_channels-sized bias created by nn.Conv2d.
            self.bias = nn.parameter.Parameter(
                torch.zeros(in_channels, device=device, dtype=dtype)
            )
        else:
            self.register_parameter("bias", None)

    def forward(
        self, input: torch.Tensor, flg_san_train: bool = False
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Apply the convolution.

        Args:
            input: Input feature map.
            flg_san_train: If true, return two outputs with the same values
                but different gradient paths (see Returns).

        Returns:
            A single tensor, or ``(out_fun, out_dir)`` when ``flg_san_train``:
            ``out_fun`` is computed with the direction detached (gradients
            reach the input and ``scale``), ``out_dir`` with the input and
            ``scale`` detached (gradients reach only the direction).
        """
        if self.bias is not None:
            input = input + self.bias.view(self.in_channels, 1, 1)
        normalized_weight = self._get_normalized_weight()
        scale = self.scale.view(self.out_channels, 1, 1)

        def conv(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
            # Bias was already added to the input above, hence None here.
            return F.conv2d(
                inp,
                weight,
                None,
                self.stride,
                self.padding,
                self.dilation,
                self.groups,
            )

        if flg_san_train:
            out_fun = conv(input, normalized_weight.detach())
            out_dir = conv(input.detach(), normalized_weight)
            return out_fun * scale, out_dir * scale.detach()
        return conv(input, normalized_weight) * scale

    @torch.no_grad()
    def normalize_weight(self):
        """Overwrite ``weight`` in place with its unit-norm version."""
        self.weight.data = self._get_normalized_weight()

    def _get_normalized_weight(self) -> torch.Tensor:
        """Return ``weight`` normalized to unit L2 norm per output channel."""
        return _normalize(self.weight, dim=[1, 2, 3])
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return ``dilation * (kernel_size - 1) // 2``, the padding used by the conv stacks."""
    span = dilation * (kernel_size - 1)
    return span // 2
class DiscriminatorP(nn.Module):
    """Period discriminator: folds the waveform into ``period`` columns and
    runs a stack of 2-D convolutions along the folded time axis."""

    def __init__(
        self, period: int, kernel_size: int = 5, stride: int = 3, san: bool = False
    ):
        """
        Args:
            period: Number of columns the waveform is folded into.
            kernel_size: Kernel height of the five stacked convolutions.
            stride: Stride along the folded time axis of the first four convolutions.
            san: Use ``SANConv2d`` for the final layer instead of a
                weight-normalized ``nn.Conv2d``.
        """
        super().__init__()
        self.period = period
        self.san = san

        kernel = (kernel_size, 1)
        pad = (get_padding(kernel_size, 1), 0)
        widths = [1, 32, 128, 512, 1024]
        # Four strided layers that widen the channels ...
        layers = [
            weight_norm(nn.Conv2d(c_in, c_out, kernel, (stride, 1), pad))
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        # ... followed by one unstrided 1024 -> 1024 layer.
        layers.append(weight_norm(nn.Conv2d(1024, 1024, kernel, 1, pad)))
        self.convs = nn.ModuleList(layers)

        if san:
            self.conv_post = SANConv2d(1024, 1, (3, 1), 1, (1, 0))
        else:
            self.conv_post = weight_norm(nn.Conv2d(1024, 1, (3, 1), 1, (1, 0)))

    def forward(
        self, x: torch.Tensor, flg_san_train: bool = False
    ) -> tuple[
        Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor]
    ]:
        """Score ``x`` and collect the intermediate feature maps.

        Args:
            x: 3-D tensor whose last axis is time.
            flg_san_train: Forwarded to the SAN output layer; when true the
                score is a ``(fun, dir)`` pair.

        Returns:
            ``(score, feature_maps)`` where the score(s) are flattened to
            2-D and ``feature_maps`` holds the output of every layer.
        """
        feature_maps = []

        batch, channels, n_samples = x.shape
        remainder = n_samples % self.period
        if remainder != 0:
            # Reflect-pad the tail so the length is a multiple of the period.
            n_pad = self.period - remainder
            x = F.pad(x, (0, n_pad), "reflect")
            n_samples = n_samples + n_pad
        x = x.view(batch, channels, n_samples // self.period, self.period)

        for conv in self.convs:
            x = F.silu(conv(x), inplace=True)
            feature_maps.append(x)

        if self.san:
            x = self.conv_post(x, flg_san_train=flg_san_train)
        else:
            x = self.conv_post(x)

        if flg_san_train:
            x_fun, x_dir = x
            feature_maps.append(x_fun)
            x = torch.flatten(x_fun, 1, -1), torch.flatten(x_dir, 1, -1)
        else:
            feature_maps.append(x)
            x = torch.flatten(x, 1, -1)
        return x, feature_maps
class DiscriminatorR(nn.Module):
    """Resolution discriminator operating on an STFT magnitude spectrogram."""

    def __init__(self, resolution: list[int], san: bool = False):
        """
        Args:
            resolution: Three values ``(n_fft, hop_length, win_length)`` used
                for the STFT.
            san: Use ``SANConv2d`` for the final layer instead of a
                weight-normalized ``nn.Conv2d``.
        """
        super().__init__()
        self.resolution = resolution
        self.san = san
        assert len(self.resolution) == 3
        self.convs = nn.ModuleList(
            [
                weight_norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
                weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))),
                weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))),
                weight_norm(nn.Conv2d(32, 32, (3, 9), (1, 2), (1, 4))),
                weight_norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
            ]
        )
        if san:
            self.conv_post = SANConv2d(32, 1, (3, 3), padding=(1, 1))
        else:
            self.conv_post = weight_norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def forward(
        self, x: torch.Tensor, flg_san_train: bool = False
    ) -> tuple[
        Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], list[torch.Tensor]
    ]:
        """Score ``x`` and collect the intermediate feature maps.

        Args:
            x: Waveform with a singleton axis 1 (it is squeezed before the STFT).
            flg_san_train: Forwarded to the SAN output layer; when true the
                score is a ``(fun, dir)`` pair.

        Returns:
            ``(score, feature_maps)`` where the score(s) are flattened to
            2-D and ``feature_maps`` holds the output of every layer.
        """
        fmap = []

        # [batch, freq, frames] -> [batch, 1, freq, frames]
        x = self._spectrogram(x).unsqueeze(1)
        for conv in self.convs:
            x = F.silu(conv(x), inplace=True)
            fmap.append(x)
        if self.san:
            x = self.conv_post(x, flg_san_train=flg_san_train)
        else:
            x = self.conv_post(x)
        if flg_san_train:
            x_fun, x_dir = x
            fmap.append(x_fun)
            x = torch.flatten(x_fun, 1, -1), torch.flatten(x_dir, 1, -1)
        else:
            fmap.append(x)
            x = torch.flatten(x, 1, -1)

        return x, fmap

    def _spectrogram(self, x: torch.Tensor) -> torch.Tensor:
        """Return the STFT magnitude of ``x`` at this discriminator's resolution.

        The input is reflect-padded by ``(n_fft - hop_length) // 2`` on both
        sides and transformed with ``center=False``. No window tensor is
        passed to ``torch.stft``. The STFT runs in float32 with autocast
        disabled.
        """
        n_fft, hop_length, win_length = self.resolution
        pad = (n_fft - hop_length) // 2
        x = F.pad(x, (pad, pad), mode="reflect").squeeze(1)
        with torch.cuda.amp.autocast(False):
            mag = torch.stft(
                x.float(),
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                center=False,
                return_complex=True,
            ).abs()

        return mag
class MultiPeriodDiscriminator(nn.Module):
    """Ensemble of three ``DiscriminatorR`` and five ``DiscriminatorP`` with
    helpers that compute the discriminator- and generator-side losses."""

    def __init__(self, san: bool = False):
        """
        Args:
            san: Passed to every sub-discriminator; also selects the SAN
                variants of the losses.
        """
        super().__init__()
        resolutions = [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]]
        periods = [2, 3, 5, 7, 11]
        spectral = [DiscriminatorR(r, san=san) for r in resolutions]
        periodic = [DiscriminatorP(p, san=san) for p in periods]
        self.discriminators = nn.ModuleList(spectral + periodic)
        # Names used as prefixes of the logged loss statistics.
        self.discriminator_names = [
            *(f"R_{n}_{h}_{w}" for n, h, w in resolutions),
            *(f"P_{p}" for p in periods),
        ]
        self.san = san

    def forward(
        self, y: torch.Tensor, y_hat: torch.Tensor, flg_san_train: bool = False
    ) -> tuple[
        list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]],
        list[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]],
        list[list[torch.Tensor]],
        list[list[torch.Tensor]],
    ]:
        """Run every discriminator on real and generated audio in one pass.

        ``y`` and ``y_hat`` are concatenated along the batch axis, pushed
        through each discriminator once, and the results are split back.

        Returns:
            ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)``: per-discriminator scores
            for real / generated input and the matching feature maps. With
            ``flg_san_train`` each score is a ``(fun, dir)`` pair.
        """
        batch_size = y.size(0)
        stacked = torch.cat([y, y_hat])
        y_d_rs, y_d_gs = [], []
        fmap_rs, fmap_gs = [], []
        for disc in self.discriminators:
            scores, fmap = disc(stacked, flg_san_train=flg_san_train)
            if flg_san_train:
                scores_fun, scores_dir = scores
                real_fun, fake_fun = torch.split(scores_fun, batch_size)
                real_dir, fake_dir = torch.split(scores_dir, batch_size)
                y_d_r = real_fun, real_dir
                y_d_g = fake_fun, fake_dir
            else:
                y_d_r, y_d_g = torch.split(scores, batch_size)
            halves = [torch.split(fm, batch_size) for fm in fmap]
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append([real for real, _ in halves])
            fmap_gs.append([fake for _, fake in halves])
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

    def forward_and_compute_discriminator_loss(
        self, y: torch.Tensor, y_hat: torch.Tensor
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """Return the summed discriminator loss and per-discriminator stats.

        Without SAN the loss is the least-squares form
        ``mean((1 - D(y))^2) + mean(D(y_hat)^2)``. With SAN, squared softplus
        terms are used for both the ``fun`` and ``dir`` outputs.
        """
        y_d_rs, y_d_gs, _, _ = self(y, y_hat, flg_san_train=self.san)
        assert len(y_d_gs) == len(y_d_rs) == len(self.discriminators)
        loss = 0.0
        stats = {}
        for real, fake, name in zip(y_d_rs, y_d_gs, self.discriminator_names):
            if self.san:
                real_fun, real_dir = (t.float() for t in real)
                fake_fun, fake_dir = (t.float() for t in fake)
                r_loss = (
                    F.softplus(1.0 - real_fun).square().mean()
                    + F.softplus(1.0 - real_dir).square().mean()
                )
                g_loss = (
                    F.softplus(fake_fun).square().mean()
                    - F.softplus(1.0 - fake_dir).square().mean()
                )
            else:
                r_loss = (1.0 - real.float()).square().mean()
                g_loss = fake.float().square().mean()
            stats[f"{name}_dr_loss"] = r_loss.item()
            stats[f"{name}_dg_loss"] = g_loss.item()
            loss += r_loss + g_loss
        return loss, stats

    def forward_and_compute_generator_loss(
        self, y: torch.Tensor, y_hat: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]:
        """Return ``(adv_loss, fm_loss, stats)`` for the generator update.

        The discriminators are always called with ``flg_san_train=False``
        here, so each score is a single tensor.
        """
        _, y_d_gs, fmap_rs, fmap_gs = self(y, y_hat, flg_san_train=False)
        stats = {}

        # adversarial loss
        adv_loss = 0.0
        for fake, name in zip(y_d_gs, self.discriminator_names):
            margin = 1.0 - fake.float()
            if self.san:
                g_loss = F.softplus(margin).square().mean()
            else:
                g_loss = margin.square().mean()
            stats[f"{name}_gg_loss"] = g_loss.item()
            adv_loss += g_loss

        # feature matching loss (real feature maps are treated as constants)
        fm_loss = 0.0
        for real_maps, fake_maps in zip(fmap_rs, fmap_gs):
            for real, fake in zip(real_maps, fake_maps):
                fm_loss += (real.detach() - fake).abs().mean()

        return adv_loss, fm_loss, stats
# %% [markdown]
# ## Utilities
# %%
class GradBalancer:
    """Adapted from https://github.com/facebookresearch/encodec/blob/main/encodec/balancer.py

    Balances several losses by the norm of their gradients with respect to a
    common tensor (``input``). A bias-corrected exponential moving average of
    each gradient norm is tracked, and each gradient is rescaled before being
    back-propagated through ``input``.
    """

    def __init__(
        self,
        weights: dict[str, float],
        rescale_grads: bool = True,
        total_norm: float = 1.0,
        ema_decay: float = 0.999,
        per_batch_item: bool = True,
    ):
        """
        Args:
            weights: Weight per loss name.
            rescale_grads: If true, scale each gradient to
                ``weight_ratio * total_norm / ema_norm``; otherwise multiply
                it by its raw weight.
            total_norm: Target norm used when ``rescale_grads`` is true.
            ema_decay: Decay of the moving average of the gradient norms.
            per_batch_item: If true, the norm is computed per batch item
                (over all axes but the first) and averaged.
        """
        self.weights = weights
        self.per_batch_item = per_batch_item
        self.total_norm = total_norm
        self.ema_decay = ema_decay
        self.rescale_grads = rescale_grads

        # Running sum of norms and running sum of weights; their ratio is the
        # bias-corrected moving average.
        self.ema_total: dict[str, float] = defaultdict(float)
        self.ema_fix: dict[str, float] = defaultdict(float)

    def backward(
        self,
        losses: dict[str, torch.Tensor],
        input: torch.Tensor,
        scaler: Optional[torch.cuda.amp.GradScaler] = None,
        skip_update_ema: bool = False,
    ) -> dict[str, float]:
        """Back-propagate the balanced combination of ``losses``.

        Args:
            losses: Loss tensors keyed by name; names must exist in ``weights``.
            input: Tensor that all losses depend on.
            scaler: Optional AMP gradient scaler; losses are scaled before
                differentiation and the norms are computed on unscaled grads.
            skip_update_ema: If true, reuse the stored moving averages and
                back-propagate the weighted sum of the losses directly.

        Returns:
            Logged gradient-norm statistics, or ``{}`` if a non-finite
            gradient was found (the raw gradient is then back-propagated
            unchanged and the moving averages are not updated).
        """
        stats = {}
        if skip_update_ema:
            assert len(losses) == len(self.ema_total)
            ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()}
        else:
            # Compute d loss / d input and its norm for every loss.
            norms = {}
            grads = {}
            for name, loss in losses.items():
                if scaler is not None:
                    loss = scaler.scale(loss)
                (grad,) = torch.autograd.grad(loss, [input], retain_graph=True)
                if not grad.isfinite().all():
                    input.backward(grad)
                    return {}
                grad = grad.detach() / (1.0 if scaler is None else scaler.get_scale())
                if self.per_batch_item:
                    dims = tuple(range(1, grad.dim()))
                    ema_norm = grad.norm(dim=dims).mean()
                else:
                    ema_norm = grad.norm()
                norms[name] = float(ema_norm)
                grads[name] = grad

            # Update the moving averages of the norms.
            for key, value in norms.items():
                self.ema_total[key] = self.ema_total[key] * self.ema_decay + value
                self.ema_fix[key] = self.ema_fix[key] * self.ema_decay + 1.0
            ema_norms = {k: tot / self.ema_fix[k] for k, tot in self.ema_total.items()}

        # Logging.
        total_ema_norm = sum(ema_norms.values())
        for k, ema_norm in ema_norms.items():
            stats[f"grad_norm_value_{k}"] = ema_norm
            stats[f"grad_norm_ratio_{k}"] = ema_norm / (total_ema_norm + 1e-12)

        # Relative weight of each loss.
        if self.rescale_grads:
            total_weights = sum([self.weights[k] for k in ema_norms])
            ratios = {k: w / total_weights for k, w in self.weights.items()}

        # Rescale and combine the gradients (or the losses when skipping the
        # EMA update).
        loss = 0.0
        for name, ema_norm in ema_norms.items():
            if self.rescale_grads:
                scale = ratios[name] * self.total_norm / (ema_norm + 1e-12)
            else:
                scale = self.weights[name]
            loss += (losses if skip_update_ema else grads)[name] * scale
        if scaler is not None:
            loss = scaler.scale(loss)
        if skip_update_ema:
            loss.backward()
        else:
            input.backward(loss)
        return stats

    def state_dict(self):
        """Return the moving-average accumulators for checkpointing."""
        return {
            "ema_total": self.ema_total,
            "ema_fix": self.ema_fix,
        }

    def load_state_dict(self, state_dict):
        """Restore the accumulators produced by ``state_dict``.

        The loaded mappings are copied into ``defaultdict(float)`` so that a
        plain ``dict`` in the checkpoint still lets ``backward`` start
        tracking loss names it has not seen before.
        """
        self.ema_total = defaultdict(float, state_dict["ema_total"])
        self.ema_fix = defaultdict(float, state_dict["ema_fix"])
class QualityTester(nn.Module):
    """Scores converted audio with the UTMOS predictor loaded via ``torch.hub``."""

    def __init__(self):
        """Load ``utmos22_strong`` from ``tarepan/SpeechMOS:v1.0.0`` in eval mode."""
        super().__init__()
        predictor = torch.hub.load(
            "tarepan/SpeechMOS:v1.0.0", "utmos22_strong", trust_repo=True
        )
        self.utmos = predictor.eval()

    @torch.inference_mode()
    def compute_mos(self, wav: torch.Tensor) -> dict[str, list[float]]:
        """Return ``{"utmos": scores}`` for ``wav``; the predictor is called with ``sr=16000``."""
        scores = self.utmos(wav, sr=16000)
        return {"utmos": scores.tolist()}

    def test(
        self, converted_wav: torch.Tensor, source_wav: torch.Tensor
    ) -> dict[str, list[float]]:
        """Return all metrics for one batch.

        ``source_wav`` is accepted but not used by any current metric.
        """
        # [batch_size, wav_length]
        metrics = {}
        metrics.update(self.compute_mos(converted_wav))
        return metrics

    def test_many(
        self, converted_wavs: list[torch.Tensor], source_wavs: list[torch.Tensor]
    ) -> tuple[dict[str, float], dict[str, list[float]]]:
        """Evaluate several batches.

        Returns:
            ``(means, values)``: the mean of every metric over all items and
            the per-item values, both keyed by metric name.
        """
        # list[batch_size, wav_length]
        results = defaultdict(list)
        assert len(converted_wavs) == len(source_wavs)
        for converted_wav, source_wav in zip(converted_wavs, source_wavs):
            batch_metrics = self.test(converted_wav, source_wav)
            for metric_name, values in batch_metrics.items():
                results[metric_name].extend(values)
        means = {
            metric_name: sum(values) / len(values)
            for metric_name, values in results.items()
        }
        return means, results
def compute_grad_norm(
    model: nn.Module, return_stats: bool = False
) -> Union[float, tuple[float, dict[str, float]]]:
    """Compute the global L2 norm of the gradients of ``model``.

    Parameters whose ``grad`` is ``None`` are skipped.

    Args:
        model: Module whose parameter gradients are inspected.
        return_stats: If true, also return the per-parameter norms.

    Returns:
        The total norm, or ``(total_norm, stats)`` when ``return_stats`` is
        true, where ``stats`` maps ``"grad_norm_<parameter name>"`` to that
        parameter's gradient norm.
    """
    squared_sum = 0.0
    stats = {}
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        param_norm = p.grad.data.norm(2.0).item()
        squared_sum += param_norm**2
        if return_stats:
            stats[f"grad_norm_{name}"] = param_norm
    total_norm = math.sqrt(squared_sum)
    if return_stats:
        return total_norm, stats
    return total_norm
def compute_mean_f0(
    files: list[Path], method: Literal["dio", "harvest"] = "dio"
) -> float:
    """Return the geometric mean of the positive F0 values found in ``files``.

    Args:
        files: Audio files to analyze.
        method: pyworld estimator to use, ``"dio"`` or ``"harvest"``.

    Returns:
        ``exp(mean(log f0))`` over all frames with ``f0 > 0``, or NaN when
        there is no such frame (including when ``files`` is empty).

    Raises:
        ValueError: If ``method`` is not a known estimator (raised once the
            first file has been loaded).
    """
    log_f0_total = 0.0
    voiced_frames = 0
    for path in files:
        wav, sr = torchaudio.load(path, backend="soundfile")
        samples = wav.ravel().numpy().astype(np.float64)
        if method == "dio":
            f0, _ = pyworld.dio(samples, sr)
        elif method == "harvest":
            f0, _ = pyworld.harvest(samples, sr)
        else:
            raise ValueError(f"Invalid method: {method}")
        # Keep only frames with a positive F0 estimate.
        positive_f0 = f0[f0 > 0]
        log_f0_total += float(np.log(positive_f0).sum())
        voiced_frames += len(positive_f0)
    if voiced_frames == 0:
        return math.nan
    return math.exp(log_f0_total / voiced_frames)
# %% [markdown]
# ## Dataset
# %%
def get_resampler(
    sr_before: int, sr_after: int, device="cpu", cache={}
) -> torchaudio.transforms.Resample:
    """Return a memoized ``Resample(sr_before, sr_after)`` placed on ``device``.

    The mutable default ``cache`` is intentional: it is the memo shared by
    all calls. ``device`` is converted to its string form for the cache key.
    """
    device_name = device if isinstance(device, str) else str(device)
    key = (sr_before, sr_after, device_name)
    resampler = cache.get(key)
    if resampler is None:
        resampler = torchaudio.transforms.Resample(sr_before, sr_after).to(
            device_name
        )
        cache[key] = resampler
    return resampler
def convolve(signal: torch.Tensor, ir: torch.Tensor) -> torch.Tensor:
    """Convolve ``signal`` with ``ir`` along the last axis using the FFT.

    The FFT size is a power of two large enough to hold the full linear
    convolution. The result is truncated to the length of ``signal``.
    """
    n_signal = signal.size(-1)
    fft_size = 1 << (n_signal + ir.size(-1) - 2).bit_length()
    spectrum = torch.fft.rfft(signal, n=fft_size) * torch.fft.rfft(ir, n=fft_size)
    full = torch.fft.irfft(spectrum, n=fft_size)
    return full[..., :n_signal]
def random_filter(audio: torch.Tensor) -> torch.Tensor:
    """Filter every row of ``audio`` with its own random second-order filter.

    Both coefficient sets start with 1.0; the remaining two coefficients of
    each set are drawn uniformly from [-0.375, 0.375).

    Args:
        audio: 2-D tensor; one filter is drawn per row.
    """
    assert audio.ndim == 2
    coeffs = torch.rand(audio.size(0), 6) * 0.75 - 0.375
    # Columns 0-2 and 3-5 are the two coefficient sets passed to lfilter.
    a_coeffs = coeffs[:, :3]
    b_coeffs = coeffs[:, 3:]
    a_coeffs[:, 0] = 1.0
    b_coeffs[:, 0] = 1.0
    return torchaudio.functional.lfilter(audio, a_coeffs, b_coeffs, clamp=False)
def get_noise(
    n_samples: int, sample_rate: float, files: list[Union[str, bytes, os.PathLike]]
) -> torch.Tensor:
    """Assemble ``n_samples`` of augmented noise from randomly chosen files.

    Files are drawn (with replacement) until enough audio is collected. Each
    one is resampled to ``sample_rate`` times a random factor, randomly
    filtered and peak-normalized to 0.99. A random window of ``n_samples``
    is then cut from the concatenation.

    Args:
        n_samples: Length of the returned noise.
        sample_rate: Nominal sample rate the factors are applied to.
        files: Candidate single-channel noise files.

    Returns:
        Tensor of shape ``(1, n_samples)``.
    """
    rate_factors = [0.9, 0.95, 1.0, 1.05, 1.1]
    pieces = []
    collected = 0
    while collected < n_samples:
        path = files[torch.randint(0, len(files), ())]
        wav, sr = torchaudio.load(path, backend="soundfile")
        assert wav.size(0) == 1
        factor = rate_factors[torch.randint(0, len(rate_factors), ())]
        target_rate = int(round(sample_rate * factor))
        wav = get_resampler(sr, target_rate)(wav)
        wav = random_filter(wav)
        wav *= 0.99 / (wav.abs().max() + 1e-5)
        pieces.append(wav)
        collected += wav.size(1)
    start = torch.randint(0, collected - n_samples + 1, ())
    wav = torch.cat(pieces, dim=1)[:, start : start + n_samples]
    assert wav.size() == (1, n_samples), wav.size()
    return wav
def get_butterworth_lpf(
    cutoff_freq: int, sample_rate: int, cache={}
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return memoized biquad low-pass coefficients with ``q = sqrt(0.5)``.

    The mutable default ``cache`` is intentional: it memoizes the result per
    ``(cutoff_freq, sample_rate)``.

    Returns:
        ``(b, a)`` with ``b = [b0, b1, b0]`` and ``a = [1.0, a1, a2]``.
    """
    key = (cutoff_freq, sample_rate)
    if key in cache:
        return cache[key]

    q = math.sqrt(0.5)
    omega = math.tau * cutoff_freq / sample_rate
    cos_omega = math.cos(omega)
    alpha = math.sin(omega) / (2.0 * q)
    # All coefficients are normalized by a0 = 1 + alpha.
    a0 = 1.0 + alpha
    b1 = (1.0 - cos_omega) / a0
    b0 = b1 * 0.5
    a1 = -2.0 * cos_omega / a0
    a2 = (1.0 - alpha) / a0
    coefficients = torch.tensor([b0, b1, b0]), torch.tensor([1.0, a1, a2])
    cache[key] = coefficients
    return coefficients
def augment_audio(
    clean: torch.Tensor,
    sample_rate: int,
    noise_files: list[Union[str, bytes, os.PathLike]],
    ir_files: list[Union[str, bytes, os.PathLike]],
) -> torch.Tensor:
    """Degrade ``clean`` with random filtering, reverb, low-pass and additive noise.

    Args:
        clean: Tensor of shape ``(1, wav_length)``.
        sample_rate: Sample rate of ``clean``; impulse responses must match it.
        noise_files: Candidate noise files passed to ``get_noise``.
        ir_files: Candidate impulse-response files; each must hold exactly
            two channels of ``sample_rate`` samples.

    Returns:
        The noisy signal, a 1-D tensor with ``wav_length`` samples.
    """
    # [1, wav_length]
    assert clean.size(0) == 1
    n_samples = clean.size(1)

    # Level of the noise relative to clean in dB (used as 10 ** (snr / 20)).
    snr_candidates = [-20, -25, -30, -35, -40, -45]

    original_clean_rms = clean.square().mean().sqrt_()

    # Fetch noise and stack it below clean: row 0 is clean, row 1 is noise.
    noise = get_noise(n_samples, sample_rate, noise_files)
    signals = torch.cat([clean, noise])

    # Apply a different random filter to clean and to noise.
    signals = random_filter(signals)

    # Apply reverb to clean and noise (with probability 0.5); the two IR
    # channels are applied to clean and noise respectively.
    if torch.rand(()) < 0.5:
        ir_file = ir_files[torch.randint(0, len(ir_files), ())]
        ir, sr = torchaudio.load(ir_file, backend="soundfile")
        assert ir.size() == (2, sr), ir.size()
        assert sr == sample_rate, (sr, sample_rate)
        signals = convolve(signals, ir)

    # Apply the same low-pass filter to clean and noise (with probability 0.2).
    if torch.rand(()) < 0.2:
        # Bring the peak down to 0.8 before filtering if it is above that.
        if signals.abs().max() > 0.8:
            signals /= signals.abs().max() * 1.25
        cutoff_freq_candidates = [2000, 3000, 4000, 6000]
        cutoff_freq = cutoff_freq_candidates[
            torch.randint(0, len(cutoff_freq_candidates), ())
        ]
        b, a = get_butterworth_lpf(cutoff_freq, sample_rate)
        signals = torchaudio.functional.lfilter(signals, a, b, clamp=False)

    # Restore the original RMS of clean.
    clean, noise = signals
    clean_rms = clean.square().mean().sqrt_()
    clean *= original_clean_rms / clean_rms

    # Measure the level of clean and noise with emphasis on peaks
    # (fourth root of the mean fourth power).
    clean_level = clean.square().square_().mean().sqrt_().sqrt_()
    noise_level = noise.square().square_().mean().sqrt_().sqrt_()
    # SNR
    snr = snr_candidates[torch.randint(0, len(snr_candidates), ())]
    # Build the noisy signal.
    noisy = clean + noise * (10.0 ** (snr / 20.0) * clean_level / (noise_level + 1e-5))
    return noisy
class WavDataset(torch.utils.data.Dataset):
    """Dataset of (clean, noisy) waveform pairs with random formant shift.

    ``__getitem__`` returns whole, zero-padded utterances; ``collate`` then
    crops fixed-length training windows and stacks them into a batch.
    """

    def __init__(
        self,
        audio_files: list[tuple[Path, int]],
        in_sample_rate: int = 16000,
        out_sample_rate: int = 24000,
        wav_length: int = 4 * 24000,  # 4s
        segment_length: int = 100,  # 1s
        noise_files: Optional[list[Union[str, bytes, os.PathLike]]] = None,
        ir_files: Optional[list[Union[str, bytes, os.PathLike]]] = None,
    ):
        """
        Args:
            audio_files: ``(path, speaker_id)`` pairs.
            in_sample_rate: Sample rate of the noisy (model input) waveform.
            out_sample_rate: Sample rate of the clean (target) waveform.
            wav_length: Length of the cropped clean waveform in samples at
                ``out_sample_rate``.
            segment_length: Slice length in hops of ``out_sample_rate // 100``.
            noise_files: Noise files for augmentation.
            ir_files: Impulse-response files for augmentation.

        Raises:
            ValueError: If exactly one of ``noise_files`` / ``ir_files`` is ``None``.
        """
        self.audio_files = audio_files
        self.in_sample_rate = in_sample_rate
        self.out_sample_rate = out_sample_rate
        self.wav_length = wav_length
        self.segment_length = segment_length
        self.noise_files = noise_files
        self.ir_files = ir_files

        if (noise_files is None) is not (ir_files is None):
            raise ValueError("noise_files and ir_files must be both None or not None")

        self.in_hop_length = in_sample_rate // 100
        self.out_hop_length = out_sample_rate // 100  # 10 ms steps

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, int, float]:
        """Load one utterance and build its augmented noisy counterpart.

        Returns:
            ``(clean_wav, noisy_wav_16k, speaker_id, formant_shift)`` where
            both waveforms are 1-D and ``formant_shift`` is in semitones.
        """
        file, speaker_id = self.audio_files[index]
        clean_wav, sample_rate = torchaudio.load(file, backend="soundfile")

        formant_shift_candidates = [-2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]
        formant_shift = formant_shift_candidates[
            torch.randint(0, len(formant_shift_candidates), ()).item()
        ]

        # Resample to the output rate and apply the formant shift in one
        # step: the source/target rate ratio is multiplied by
        # 2 ** (formant_shift / 12) and approximated by a fraction.
        resampler_fraction = Fraction(
            sample_rate / self.out_sample_rate * 2.0 ** (formant_shift / 12.0)
        ).limit_denominator(300)
        clean_wav = get_resampler(
            resampler_fraction.numerator, resampler_fraction.denominator
        )(clean_wav)

        assert clean_wav.size(0) == 1
        assert clean_wav.size(1) != 0

        # Zero-pad by wav_length on both sides so that collate can crop a
        # full window around any sample.
        clean_wav = F.pad(clean_wav, (self.wav_length, self.wav_length))

        if self.noise_files is None:
            # Running without augmentation is disabled by this assertion;
            # the line below is only reached when assertions are stripped.
            assert False
            noisy_wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)(
                clean_wav
            )
        else:
            clean_wav_16k = get_resampler(self.out_sample_rate, self.in_sample_rate)(
                clean_wav
            )
            noisy_wav_16k = augment_audio(
                clean_wav_16k, self.in_sample_rate, self.noise_files, self.ir_files
            )

        clean_wav = clean_wav.squeeze_(0)
        noisy_wav_16k = noisy_wav_16k.squeeze_(0)

        # Randomize the volume: the clean peak is set to a value in [0.1, 0.999).
        amplitude = torch.rand(()).item() * 0.899 + 0.1
        factor = amplitude / clean_wav.abs().max()
        clean_wav *= factor
        noisy_wav_16k *= factor
        # Halve both signals until the noisy peak is below 1.0.
        while noisy_wav_16k.abs().max() >= 1.0:
            clean_wav *= 0.5
            noisy_wav_16k *= 0.5

        return clean_wav, noisy_wav_16k, speaker_id, formant_shift

    def __len__(self) -> int:
        """Return the number of audio files."""
        return len(self.audio_files)

    def collate(
        self, batch: list[tuple[torch.Tensor, torch.Tensor, int, float]]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Crop aligned windows from each item and stack them into a batch.

        Returns:
            ``(clean_wavs, noisy_wavs, slice_starts, speaker_ids,
            formant_shifts)``; ``slice_starts`` is the position of the
            training slice inside the cropped window, in hops.
        """
        assert self.wav_length % self.out_hop_length == 0
        length = self.wav_length // self.out_hop_length
        clean_wavs = []
        noisy_wavs = []
        slice_starts = []
        speaker_ids = []
        formant_shifts = []
        for clean_wav, noisy_wav, speaker_id, formant_shift in batch:
            # Pick one non-zero sample at random as the voiced position.
            (voiced,) = clean_wav.nonzero(as_tuple=True)
            assert voiced.numel() != 0
            center = voiced[torch.randint(0, voiced.numel(), ()).item()].item()
            # Choose the slice so that the voiced position is at its center.
            slice_start = center - self.segment_length * self.out_hop_length // 2
            assert slice_start >= 0
            # Randomly cut out wav_length samples so that the slice is
            # contained in the window.
            r = torch.randint(0, length - self.segment_length + 1, ()).item()
            offset = slice_start - r * self.out_hop_length
            clean_wavs.append(clean_wav[offset : offset + self.wav_length])
            # Same window expressed in samples of the input sample rate.
            offset_in_sample_rate = int(
                round(offset * self.in_sample_rate / self.out_sample_rate)
            )
            noisy_wavs.append(
                noisy_wav[
                    offset_in_sample_rate : offset_in_sample_rate
                    + length * self.in_hop_length
                ]
            )
            # From here on the slice start is relative to the window, in hops.
            slice_start = r
            slice_starts.append(slice_start)
            speaker_ids.append(speaker_id)
            formant_shifts.append(formant_shift)
        clean_wavs = torch.stack(clean_wavs)
        noisy_wavs = torch.stack(noisy_wavs)
        slice_starts = torch.tensor(slice_starts)
        speaker_ids = torch.tensor(speaker_ids)
        formant_shifts = torch.tensor(formant_shifts)
        return (
            clean_wavs,  # [batch_size, wav_length]
            noisy_wavs,  # [batch_size, length * in_hop_length]
            slice_starts,  # Long[batch_size]
            speaker_ids,  # Long[batch_size]
            formant_shifts,  # [batch_size], float semitones
        )
# %% [markdown]
# ## Train
# %%
# File suffixes (lower-case, including the dot) that are treated as audio
# when scanning dataset directories; candidates are compared after
# lower-casing their suffix.
AUDIO_FILE_SUFFIXES = {
    ".wav",
    ".aif",
    ".aiff",
    ".fla",
    ".flac",
    ".oga",
    ".ogg",
    ".opus",
    ".mp3",
}
def prepare_training():
# 各種準備をする
# 副作用として、出力ディレクトリと TensorBoard のログファイルなどが生成される
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device={device}")
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
(h, in_wav_dataset_dir, out_dir, resume) = (
prepare_training_configs_for_experiment
if is_notebook()
else prepare_training_configs
)()
print("config:")
print(h)
print()
h = AttrDict(h)
if not in_wav_dataset_dir.is_dir():
raise ValueError(f"{in_wav_dataset_dir} is not found.")
if resume:
latest_checkpoint_file = out_dir / "checkpoint_latest.pt"
if not latest_checkpoint_file.is_file():
raise ValueError(f"{latest_checkpoint_file} is not found.")
else:
if out_dir.is_dir():
if (out_dir / "checkpoint_latest.pt").is_file():
raise ValueError(
f"{out_dir / 'checkpoint_latest.pt'} already exists. "
"Please specify a different output directory, or use --resume option."
)
for file in out_dir.iterdir():
if file.suffix == ".pt":
raise ValueError(
f"{out_dir} already contains model files. "
"Please specify a different output directory."
)
else:
out_dir.mkdir(parents=True)
in_ir_wav_dir = repo_root() / h.in_ir_wav_dir
in_noise_wav_dir = repo_root() / h.in_noise_wav_dir
in_test_wav_dir = repo_root() / h.in_test_wav_dir
assert in_wav_dataset_dir.is_dir(), in_wav_dataset_dir
assert out_dir.is_dir(), out_dir
assert in_ir_wav_dir.is_dir(), in_ir_wav_dir
assert in_noise_wav_dir.is_dir(), in_noise_wav_dir
assert in_test_wav_dir.is_dir(), in_test_wav_dir
# .wav または *.flac のファイルを再帰的に取得
noise_files = sorted(
list(in_noise_wav_dir.rglob("*.wav")) + list(in_noise_wav_dir.rglob("*.flac"))
)
if len(noise_files) == 0:
raise ValueError(f"No audio data found in {in_noise_wav_dir}.")
ir_files = sorted(
list(in_ir_wav_dir.rglob("*.wav")) + list(in_ir_wav_dir.rglob("*.flac"))
)
if len(ir_files) == 0:
raise ValueError(f"No audio data found in {in_ir_wav_dir}.")
# TODO: silence removal etc.
def get_training_filelist(in_wav_dataset_dir: Path):
    """Scan ``in_wav_dataset_dir`` for per-speaker audio files.

    Every sub-directory is treated as one speaker; its audio files are found
    recursively and recognised by a suffix listed in AUDIO_FILE_SUFFIXES
    (compared case-insensitively). Speakers are numbered in sorted directory
    order, skipping directories without enough audio files.

    Returns:
        A 3-tuple ``(speakers, training_filelist, speaker_audio_files)``:
        speaker directory names, ``(file, speaker_id)`` pairs for all files,
        and the list of files of each speaker indexed by speaker id.
    """
    min_data_per_speaker = 1
    speakers: list[str] = []
    training_filelist: list[tuple[Path, int]] = []
    speaker_audio_files: list[list[Path]] = []

    speaker_dirs = [
        entry for entry in sorted(in_wav_dataset_dir.iterdir()) if entry.is_dir()
    ]
    for speaker_dir in speaker_dirs:
        audio_files = [
            path
            for path in sorted(speaker_dir.rglob("*"))
            if path.is_file() and path.suffix.lower() in AUDIO_FILE_SUFFIXES
        ]
        if len(audio_files) < min_data_per_speaker:
            continue
        # The id of a speaker is its position in ``speakers``.
        speaker_id = len(speakers)
        speakers.append(speaker_dir.name)
        for audio_file in audio_files:
            training_filelist.append((audio_file, speaker_id))
        speaker_audio_files.append(audio_files)

    return speakers, training_filelist, speaker_audio_files
# Collect the training data and print a summary of the speakers found.
(
    speakers,
    training_filelist,
    speaker_audio_files,
) = get_training_filelist(in_wav_dataset_dir)
n_speakers = len(speakers)
if not n_speakers:
    raise ValueError(f"No speaker data found in {in_wav_dataset_dir}.")
print(f"{n_speakers=}")
# Right-align the ids to the width of the largest one.
id_width = len(str(n_speakers - 1))
for i, speaker in enumerate(speakers):
    print(f" {i:{id_width}d}: {speaker}")
print()
print(f"{len(training_filelist)=}")
def get_test_filelist(
    in_test_wav_dir: Path,
    n_speakers: int,
    audio_suffixes: Optional[set[str]] = None,
) -> list[tuple[Path, list[int]]]:
    """Pair each test audio file with the target speaker ids to convert it to.

    At most the first 1000 directory entries (in sorted order) are considered,
    and each audio file receives ``min(8, n_speakers)`` target ids. With up to
    8 speakers the ids simply cycle through ``range(n_speakers)``; with more
    speakers they are drawn from successive shuffles (fixed seed) so that all
    speakers get covered across files.

    Args:
        in_test_wav_dir: Directory whose direct children are the test files.
        n_speakers: Number of target speakers.
        audio_suffixes: Lower-case file suffixes accepted as audio. Defaults
            to the module-level AUDIO_FILE_SUFFIXES.

    Returns:
        A list of ``(file, target_ids)`` tuples.
    """
    if audio_suffixes is None:
        audio_suffixes = AUDIO_FILE_SUFFIXES
    max_n_test_files = 1000
    n_targets_per_file = min(8, n_speakers)
    rng = Random(42)

    def generate_target_ids():
        if n_speakers > 8:
            while True:
                order = list(range(n_speakers))
                rng.shuffle(order)
                yield from order
        else:
            while True:
                yield from range(n_speakers)

    # One generator shared by all files. Creating a new generator for every
    # id (as the code used to do) restarts the sequence each time, so every
    # id came out as the first element of a fresh sequence.
    target_id_generator = generate_target_ids()

    test_filelist = []
    for file in sorted(in_test_wav_dir.iterdir())[:max_n_test_files]:
        if file.suffix.lower() not in audio_suffixes:
            continue
        target_ids = [next(target_id_generator) for _ in range(n_targets_per_file)]
        test_filelist.append((file, target_ids))
    return test_filelist
# Build the validation file list and print a short preview of it.
test_filelist = get_test_filelist(in_test_wav_dir, n_speakers)
if not test_filelist:
    # Name the directory that was searched; the list itself is empty here and
    # would only print as "[]".
    warnings.warn(f"No audio data found in {in_test_wav_dir}.")
print(f"{len(test_filelist)=}")
for file, target_ids in test_filelist[:12]:
    print(f" {file}, {target_ids}")
if len(test_filelist) > 12:
    print(" ...")
print()
# Data
# Options forwarded to the dataset; all values come from the hyperparameters
# except the noise / IR file lists collected above.
wav_dataset_options = {
    "in_sample_rate": h.in_sample_rate,
    "out_sample_rate": h.out_sample_rate,
    "wav_length": h.wav_length,
    "segment_length": h.segment_length,
    "noise_files": noise_files,
    "ir_files": ir_files,
}
training_dataset = WavDataset(training_filelist, **wav_dataset_options)
# Shuffled loader that drops the last incomplete batch and uses the
# dataset's own collate function.
training_loader = torch.utils.data.DataLoader(
    training_dataset,
    batch_size=h.batch_size,
    shuffle=True,
    sampler=None,
    num_workers=h.num_workers,
    collate_fn=training_dataset.collate,
    pin_memory=True,
    drop_last=True,
)
# Estimate the mean F0 of every target speaker from at most 10 of its files
# (a fixed-seed random sample when the speaker has more).
print("Computing mean F0s of target speakers...")
speaker_f0s = []
for speaker_id, audio_files in enumerate(speaker_audio_files):
    sampled_files = (
        Random(42).sample(audio_files, 10) if len(audio_files) > 10 else audio_files
    )
    f0 = compute_mean_f0(sampled_files)
    speaker_f0s.append(f0)
    print(f" {speaker_id:3d}: {f0:.1f}Hz", end=",")
    # Break the line after every fifth speaker.
    if (speaker_id + 1) % 5 == 0:
        print()
print("Done.")

print("Computing pitch shifts for test files...")
test_pitch_shifts = []
# fmt: off
# TODO
source_f0s_cache = [275.9,230.0,135.8,129.6,256.4,357.4,463.7,315.3,144.8,119.5,232.4,349.4,444.0,330.7,182.5,272.4,314.9,282.6,171.5,250.0,208.7,317.9,325.3,168.0,320.0,308.3,113.3,196.8,244.7,292.9,381.3,297.6,218.5,286.4,350.8,372.1,276.7,309.9,157.3,160.0,147.3,128.3,332.4,301.4,123.5,308.9,206.5,312.5,186.5,327.7,335.8,116.1,505.2,452.8,135.4,293.3,239.9,368.0,293.0,288.4,184.1,441.3,272.3,278.9,154.2,297.1,360.3,545.7,270.3,412.0,294.5,351.6,336.6,120.1,308.4,236.4,340.5,158.2,170.8,460.8,407.4,302.8,165.9,304.8,154.1,250.9,257.1,206.0,222.6,346.8,378.7,131.8,166.6,109.0,239.6,206.5,214.1,192.5,132.9,487.1,270.4,142.6,117.2,364.6,321.8,130.3,491.6,263.8,346.6,349.0,290.5,264.4,117.1,136.8,146.6,336.0,228.0,242.0,134.3,301.5,414.9,212.4,108.6,321.7,313.4,168.6,222.4,399.4,325.3,175.5,116.8,411.9,303.3,380.6,230.7,127.4,167.6,151.5,332.7,149.3,126.0,356.2,247.9,246.2,396.3,245.9,175.4,290.5,208.1,505.1,480.1,245.1,241.8,269.6,280.3,407.6,140.6,263.5,255.8,195.6,272.2,167.8,383.1,254.7,101.8,245.7,114.1,124.2,349.1,180.3,205.4,242.6,407.6,303.3,272.0,248.3,106.5,263.8,266.3,147.1,201.6,412.0,192.4,122.0,184.8,381.5,243.3,153.9,174.1,323.9,254.5,209.3,241.1,402.6,264.8,333.6,293.1,194.1,124.7,407.6,161.7,134.2,383.3,268.7,286.6,326.4,302.9,183.4,362.3,86.7,198.0,238.9,229.0,244.4,258.4,196.5,170.3,299.5,235.1,219.4,260.6,231.9,269.9,112.2,168.4,239.7,195.7,228.4,190.7,336.6,357.6,185.3,349.7,293.1,229.2,287.0,354.0,245.4,306.5,185.8,369.7,118.6,98.5,269.4,264.4,320.1,159.5,281.3,300.2,99.2,399.2,316.0,279.7,334.7,109.4,162.1,113.0,237.9,107.1,122.3,246.8,303.3,118.2,122.8,139.2,258.1,250.0,188.7,252.8,120.2,272.4,114.2,251.9,198.7,364.1,588.4,269.5,164.6,192.5,313.7,280.3,235.2,264.5,185.2,500.4,324.9,275.0,160.7,173.3,237.3,393.4,286.0,313.2,166.1,278.7,200.4,292.8,260.6,297.0,113.1,228.5,118.4,359.7,127.6,128.3,332.2,367.7,548.5,290.3,273.7,176.5,250.0,219.7,218.9,471.2,355.9,283.9,230.9,290.5,200.0,268.4,260.4,339.9,416.2,171.2,443.1,273.0,118.4,371.9,228.1,295
.3,482.2,391.1,173.0,131.1,112.6,286.9,95.0,349.7,218.4,143.8,344.0,153.4,364.6,329.7,213.3,499.0,162.6,227.5,526.2,151.8,242.6,273.6,107.1,339.8,350.8,324.8,173.9,247.5,401.9,253.9,294.9,281.4,342.1,495.2,141.4,371.3,265.3,403.5,137.3,270.8,143.2,291.4,199.9,274.0,121.3,326.5,143.8,371.3,255.7,392.6,302.8,152.1,332.2,225.1,257.5,470.4,220.3,226.2,101.2,301.4,137.4,274.5,218.1,270.4,238.5,192.6,461.5,219.3,368.8,248.7,316.6,242.5,335.8,274.6,241.5,386.5,232.6,510.6,203.1,291.0,229.6,346.0,262.9,397.0,136.8,189.6,304.7,177.7,351.1,220.2,113.2,301.1,311.6,106.2,105.7,342.1,125.0,252.6,371.2,274.1,152.2,461.0,131.1,211.6,325.3,421.7,291.0,402.4,254.7,304.3,118.5,122.1,436.8,186.6,217.5,356.3,207.1,147.6,390.9,293.7,409.4,151.5,182.5,433.9,302.2,278.0,145.0,345.4,322.4,85.9,237.0,342.1,329.2,546.4,287.2,246.3,249.5,293.2,318.1,271.7,122.7,183.6,135.4,327.0,254.4,367.8,256.8,233.4,268.9,290.4,242.8,178.2,174.8,536.3,171.9,262.6,285.9,306.2,212.1,295.0,180.3,371.8,264.8,318.6,278.4,267.3,437.1,466.0,272.6,168.4,312.9,299.5,424.7,177.9,294.1,274.7,318.9,395.0,548.1,116.7,322.8,188.5,315.3,321.1,280.3,254.2,249.6,274.5,333.2,285.6,370.4,278.7,218.2,106.1,122.3,351.7,254.7,465.2,100.6,176.0,187.6,181.3,277.3,156.0,437.3,322.7,293.5,361.4,237.4,272.0,218.1,236.3,292.1,245.8,335.5,197.7,150.7,210.1,278.5,386.0,198.7,246.8,177.0,337.1,342.3,264.6,412.4,189.0,138.1,274.8,246.6,469.5,183.1,277.8,296.0,376.6,96.3,236.8,102.9,292.1,121.9,238.2,431.0,130.3,195.3,532.3,447.0,265.7,216.3,104.3,124.7,138.6,193.9,322.3,327.1,124.2,398.0,124.7,200.5,130.0,410.0,323.2,251.0,283.8,127.7,156.6,261.3,114.3,316.5,192.8,134.6,230.1,175.8,200.4,156.7,103.4,293.8,475.1,295.1,112.7,608.3,294.5,518.7,114.5,106.1,556.2,108.6,266.8,275.5,406.8,377.2,206.3,127.9,293.7,239.3,323.0,338.4,128.5,501.1,94.3,143.2,261.0,433.6,166.8,281.7,116.2,255.7,172.8,174.7,104.4,164.4,343.8,269.2,177.2,130.2,186.2,471.5,161.1,257.6,270.7,99.4,274.2,209.8,161.3,181.8,241.0,134.6,154.0,366.2,316.4,211.1,478.1,131.8,288.
8,114.4,545.1,303.8,327.2,321.7,98.2,285.0,122.9,150.0,176.6,187.2,322.2,221.6,189.8,174.7,142.8,410.0,349.0,487.4,132.2,147.8,121.3,282.5,288.1,464.8,365.6,289.4,209.9,256.8,129.7,247.0,262.8,266.1,312.2,297.9,221.5,116.0,402.5,157.1,123.8,120.6,113.2,304.2,232.9,468.6,130.3,194.5,266.7,112.6,381.8,389.5,413.6,427.5,126.5,242.2,259.3,161.0,365.0,155.9,460.0,433.9,288.2,360.5,252.0,119.6,373.3,126.1,130.5,284.1,539.1,265.9,245.9,346.3,283.3,162.2,113.4,134.3,181.5,178.7,275.4,111.8,247.4,201.1,311.3,278.1,268.9,401.1,107.0,306.8,167.8,259.2,169.2,127.6,173.3,270.2,153.0,264.3,316.8,341.2,321.4,271.1,297.1,424.9,284.9,123.7,147.8,226.0,267.2,202.1,199.4,219.5,209.4,343.8,108.8,119.4,273.5,176.0,122.4,130.4,322.2,113.4,261.5,281.3,293.4,127.3,281.4,490.3,358.4,296.0,370.4,307.7,196.5,119.9,206.8,282.4,128.8,234.5,239.5,182.8,321.4,285.7,101.9,212.0,254.4,352.1,261.6,145.4,107.6,304.7,111.1,317.6,124.2,290.9,295.9,406.4,369.2,283.2,96.9,188.3,170.5,150.3,463.3,345.0,166.2,261.6,277.9,334.8,341.2,207.2,215.4,332.2,326.7,233.1,322.8,174.2,452.0,327.5,328.6,162.1,188.4,356.0,151.6,318.5,119.5,247.7,191.9,217.4,233.0,130.8,283.1,242.9,259.4,359.0,230.6,302.9,261.1,257.8,308.1,271.7,373.6,333.4,150.8,292.0,343.1,277.5,318.0,142.1,174.5,119.3,382.6,306.9,136.5,215.3,122.4,336.4,311.7,133.5,294.1,158.3,201.7,152.8,246.6,151.3,248.0,231.1,457.5,309.6,325.9,192.9,177.8,113.9,295.8,315.6,134.6,108.5,317.3,121.7,141.0,235.3,417.0,201.9,176.0,229.5,140.8,156.2,312.2,139.3,241.5,360.5,321.1,351.8,341.1,233.3,379.7,129.5,241.7,291.5,434.2,296.1,232.0,237.4,314.7,279.4,304.6,279.9,358.8,289.5,491.9,132.1,469.1,103.9,172.9,516.0,103.8,591.9,351.3,116.9,360.0,431.4,147.4,131.5,218.1,110.6,218.6,231.5,423.5,318.5,167.7,284.0,276.8,149.3,295.1,244.8,290.6,233.8,293.5,131.5,231.5,127.4,328.2,199.7,102.5,257.5,459.1,267.0,149.8,226.9,194.7,297.2,337.5,283.9,130.7,171.9,249.4,278.9,123.6,111.6,293.0,193.0,400.1,412.4,410.8,321.3,380.7,213.4,157.9,307.0,126.0,126.9,188.3,362.5,246.5,295.3,14
0.2,241.1]
# fmt: on
# For every test file, compute the pitch shift (in whole semitones) that
# moves its mean F0 onto the mean F0 of each of its target speakers.
for i, (file, target_ids) in enumerate(tqdm(test_filelist)):
    if i < len(source_f0s_cache):
        source_f0 = source_f0s_cache[i]
    else:
        # The hard-coded cache has no entry for this file; estimate the F0
        # instead of failing with an IndexError.
        # NOTE(review): call signature taken from the line that was commented
        # out here -- confirm compute_mean_f0 accepts method="harvest".
        source_f0 = compute_mean_f0([file], method="harvest")
    # NaN, infinite or non-positive F0 cannot be used in a log ratio: no shift.
    if not (math.isfinite(source_f0) and source_f0 > 0.0):
        test_pitch_shifts.append([0] * len(target_ids))
        continue
    pitch_shifts = []
    for target_id in target_ids:
        target_f0 = speaker_f0s[target_id]
        if math.isfinite(target_f0) and target_f0 > 0.0:
            pitch_shift = int(round(12 * math.log2(target_f0 / source_f0)))
        else:
            pitch_shift = 0
        pitch_shifts.append(pitch_shift)
    test_pitch_shifts.append(pitch_shifts)
print("Done.")
# Models and optimization
# The phone extractor and the pitch estimator are loaded from their own
# checkpoints and frozen (eval mode, no gradients).
phone_extractor = PhoneExtractor().to(device).eval().requires_grad_(False)
phone_extractor_checkpoint = torch.load(
    repo_root() / h.phone_extractor_file, map_location="cpu"
)
phone_extractor_load_result = phone_extractor.load_state_dict(
    phone_extractor_checkpoint["phone_extractor"],
    strict=False,
)
print(phone_extractor_load_result)
del phone_extractor_checkpoint

pitch_estimator = PitchEstimator().to(device).eval().requires_grad_(False)
pitch_estimator_checkpoint = torch.load(
    repo_root() / h.pitch_estimator_file, map_location="cpu"
)
pitch_estimator_load_result = pitch_estimator.load_state_dict(
    pitch_estimator_checkpoint["pitch_estimator"]
)
print(pitch_estimator_load_result)
del pitch_estimator_checkpoint

# Generator and discriminator.
net_g = ConverterNetwork(
    phone_extractor,
    pitch_estimator,
    n_speakers,
    h.hidden_channels,
).to(device)
net_d = MultiPeriodDiscriminator(san=h.san).to(device)

# Both networks are optimized with identically configured AdamW instances.
adamw_options = {
    "lr": h.learning_rate,
    "betas": h.adam_betas,
    "eps": h.adam_eps,
}
optim_g = torch.optim.AdamW(net_g.parameters(), **adamw_options)
optim_d = torch.optim.AdamW(net_d.parameters(), **adamw_options)

grad_scaler = torch.cuda.amp.GradScaler(enabled=h.use_amp)
loss_weights = {
    "loss_mel": h.grad_weight_mel,
    "loss_adv": h.grad_weight_adv,
    "loss_fm": h.grad_weight_fm,
}
grad_balancer = GradBalancer(
    weights=loss_weights,
    ema_decay=h.grad_balancer_ema_decay,
)
# Resampler from the output sample rate back to the input sample rate.
resample_to_in_sample_rate = torchaudio.transforms.Resample(
    h.out_sample_rate, h.in_sample_rate
).to(device)
# Checkpoint loading
# Resuming takes precedence over a pretrained file; with neither, the
# networks keep their fresh initialization.
initial_iteration = 0
checkpoint_file = None
if resume:
    checkpoint_file = latest_checkpoint_file
elif h.pretrained_file is not None:
    checkpoint_file = repo_root() / h.pretrained_file

if checkpoint_file is not None:
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    net_g_state = checkpoint["net_g"]
    embed_key = "embed_speaker.weight"
    checkpoint_n_speakers = len(net_g_state[embed_key])
    if not resume:  # fine-tuning
        mean_speaker_embedding = net_g_state[embed_key].mean(0, keepdim=True)
        # Hard-wired switch: the first strategy is the one in use; the second
        # one is kept for adding speakers to an existing model.
        initialize_all_with_mean = True
        if initialize_all_with_mean:
            # Zeros or random values might work better.
            net_g_state[embed_key] = mean_speaker_embedding[[0] * n_speakers]
        else:  # for adding speakers
            assert n_speakers > checkpoint_n_speakers
            print(
                f"embed_speaker.weight was padded: {checkpoint_n_speakers} -> {n_speakers}"
            )
            net_g_state[embed_key] = F.pad(
                net_g_state[embed_key],
                (0, 0, 0, n_speakers - checkpoint_n_speakers),
            )
            net_g_state[embed_key][checkpoint_n_speakers:] = mean_speaker_embedding
    net_g_load_result = net_g.load_state_dict(net_g_state, strict=False)
    print(net_g_load_result)
    net_d_load_result = net_d.load_state_dict(checkpoint["net_d"], strict=False)
    print(net_d_load_result)
    if resume:
        # The full training state is only restored when continuing a run.
        optim_g.load_state_dict(checkpoint["optim_g"])
        optim_d.load_state_dict(checkpoint["optim_d"])
        grad_balancer.load_state_dict(checkpoint["gradient_balancer"])
        initial_iteration = checkpoint["iteration"]
        grad_scaler.load_state_dict(checkpoint["scaler"])
# Scheduler
def get_cosine_annealing_warmup_scheduler(
    optimizer: torch.optim.Optimizer,
    warmup_epochs: int,
    total_epochs: int,
    min_learning_rate: float,
) -> torch.optim.lr_scheduler.LambdaLR:
    """Create a linear-warmup + cosine-annealing learning-rate scheduler.

    The learning rate rises linearly from 0 to the optimizer's base learning
    rate during the first ``warmup_epochs`` steps, follows half a cosine down
    to ``min_learning_rate`` until ``total_epochs``, and stays at
    ``min_learning_rate`` afterwards.

    Args:
        optimizer: Optimizer whose first parameter group defines the base
            learning rate.
        warmup_epochs: Number of warm-up steps.
        total_epochs: Step at which the minimum learning rate is reached.
        min_learning_rate: Absolute learning rate at the end of the schedule.

    Returns:
        A ``LambdaLR`` scheduler attached to ``optimizer``.
    """
    param_group = optimizer.param_groups[0]
    # LambdaLR scales the *initial* learning rate. When the optimizer state
    # was restored from a checkpoint, "lr" holds the already-decayed value, so
    # prefer "initial_lr" when it is present.
    base_learning_rate = param_group.get("initial_lr", param_group["lr"])
    lr_ratio = min_learning_rate / base_learning_rate
    m = 0.5 * (1.0 - lr_ratio)
    a = 0.5 * (1.0 + lr_ratio)

    def lr_lambda(current_epoch: int) -> float:
        if current_epoch < warmup_epochs:
            return current_epoch / warmup_epochs
        if current_epoch < total_epochs:
            rate = (current_epoch - warmup_epochs) / (total_epochs - warmup_epochs)
            return math.cos(rate * math.pi) * m + a
        # The lambda is a multiplier of the base learning rate, so the floor
        # is the ratio, not the absolute minimum learning rate.
        return lr_ratio

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# One scheduler per optimizer, both with the same schedule.
scheduler_g, scheduler_d = (
    get_cosine_annealing_warmup_scheduler(
        optimizer, h.warmup_steps, h.n_steps, h.min_learning_rate
    )
    for optimizer in (optim_g, optim_d)
)
# Fast-forward the schedulers to the iteration training (re)starts from.
for _ in range(initial_iteration + 1):
    for scheduler in (scheduler_g, scheduler_d):
        scheduler.step()

for network in (net_g, net_d):
    network.train()
# Logging and bookkeeping
# Scalars are accumulated per name in lists and written out by the training loop.
dict_scalars = defaultdict(list)
quality_tester = QualityTester().eval().to(device)
# TensorBoard writer; its log files go into the output directory.
writer = SummaryWriter(out_dir)
# Record which device this run (or resumed run) started on.
writer.add_text(
"log",
f"start training w/ {torch.cuda.get_device_name(device) if torch.cuda.is_available() else 'cpu'}.",
initial_iteration,
)
# A fresh run stores its hyperparameters as config.json in the output directory.
if not resume:
with open(out_dir / "config.json", "w", encoding="utf-8") as f:
json.dump(dict(h), f, indent=4)
# When running as a script (not in a notebook), a copy of this source file
# is placed in the output directory as well.
if not is_notebook():
shutil.copy(__file__, out_dir)
# Everything the training loop needs, in the order the caller unpacks it.
return (
device,
in_wav_dataset_dir,
h,
out_dir,
speakers,
test_filelist,
training_loader,
speaker_f0s,
test_pitch_shifts,
phone_extractor,
pitch_estimator,
net_g,
net_d,
optim_g,
optim_d,
grad_scaler,
grad_balancer,
resample_to_in_sample_rate,
initial_iteration,
scheduler_g,
scheduler_d,
dict_scalars,
quality_tester,
writer,
)
if __name__ == "__main__":
    (
        device,
        in_wav_dataset_dir,
        h,
        out_dir,
        speakers,
        test_filelist,
        training_loader,
        speaker_f0s,
        test_pitch_shifts,
        phone_extractor,
        pitch_estimator,
        net_g,
        net_d,
        optim_g,
        optim_d,
        grad_scaler,
        grad_balancer,
        resample_to_in_sample_rate,
        initial_iteration,
        scheduler_g,
        scheduler_d,
        dict_scalars,
        quality_tester,
        writer,
    ) = prepare_training()

    # Training
    # The batch iterator is created explicitly instead of relying on a
    # NameError being swallowed by a bare ``except`` on the first iteration.
    data_iter = iter(training_loader)
    for iteration in tqdm(range(initial_iteration, h.n_steps)):
        # === 1. Data preprocessing ===
        try:
            batch = next(data_iter)
        except StopIteration:
            # The loader is exhausted: start the next pass over the data.
            # Only StopIteration is handled so that KeyboardInterrupt and
            # data-loading errors are no longer silently retried.
            data_iter = iter(training_loader)
            batch = next(data_iter)
        # Move every tensor of the batch to the training device.
        (
            clean_wavs,
            noisy_wavs_16k,
            slice_starts,
            speaker_ids,
            formant_shift_semitone,
        ) = (x.to(device, non_blocking=True) for x in batch)
# === 2.1 Train the discriminator ===
with torch.cuda.amp.autocast(h.use_amp):
# Generator forward pass. Besides the reference (y) and generated (y_hat)
# signals it returns the mel loss and y_hat_for_backward, which is handed
# to the gradient balancer in step 2.2.
y, y_hat, y_hat_for_backward, loss_mel = net_g.forward_and_compute_loss(
noisy_wavs_16k[:, None, :],
speaker_ids,
formant_shift_semitone,
slice_start_indices=slice_starts,
slice_segment_length=h.segment_length,
y_all=clean_wavs[:, None, :],
)
# Stop immediately if the generator produced non-finite values.
assert y_hat.isfinite().all()
assert loss_mel.isfinite().all()
# Discriminator loss; y_hat is detached so this loss does not
# back-propagate into the generator.
loss_discriminator, discriminator_d_stats = (
net_d.forward_and_compute_discriminator_loss(y, y_hat.detach())
)
optim_d.zero_grad()
grad_scaler.scale(loss_discriminator).backward()
# Unscale before measuring the gradient norm.
grad_scaler.unscale_(optim_d)
grad_norm_d, d_grad_norm_stats = compute_grad_norm(net_d, True)
grad_scaler.step(optim_d)
# === 2.2 Train the generator ===
with torch.cuda.amp.autocast(h.use_amp):
# Discriminator pass on the non-detached y_hat: adversarial and
# feature-matching losses for the generator.
loss_adv, loss_fm, discriminator_g_stats = (
net_d.forward_and_compute_generator_loss(y, y_hat)
)
optim_g.zero_grad()
# The gradient balancer performs the backward pass for the three generator
# losses. skip_update_ema is True on iterations > 10 that are not
# multiples of 5.
gradient_balancer_stats = grad_balancer.backward(
{
"loss_mel": loss_mel,
"loss_adv": loss_adv,
"loss_fm": loss_fm,
},
y_hat_for_backward,
grad_scaler,
skip_update_ema=iteration > 10 and iteration % 5 != 0,
)
grad_scaler.unscale_(optim_g)
grad_norm_g, g_grad_norm_stats = compute_grad_norm(net_g, True)
grad_scaler.step(optim_g)
# The shared scaler is updated once, after both optimizers have stepped.
grad_scaler.update()
# === 3. Logging ===
# Values are appended to per-name lists and averaged when they are written.
dict_scalars["loss_g/loss_mel"].append(loss_mel.item())
dict_scalars["loss_g/loss_fm"].append(loss_fm.item())
dict_scalars["loss_g/loss_adv"].append(loss_adv.item())
dict_scalars["other/grad_scale"].append(grad_scaler.get_scale())
dict_scalars["loss_d/loss_discriminator"].append(loss_discriminator.item())
# x == x is False only for NaN: NaN gradient norms are not recorded.
if grad_norm_d == grad_norm_d:
dict_scalars["other/gradient_norm_d"].append(grad_norm_d)
for name, value in d_grad_norm_stats.items():
dict_scalars[f"~gradient_norm_d/{name}"].append(value)
if grad_norm_g == grad_norm_g:
dict_scalars["other/gradient_norm_g"].append(grad_norm_g)
for name, value in g_grad_norm_stats.items():
dict_scalars[f"~gradient_norm_g/{name}"].append(value)
dict_scalars["other/lr_g"].append(scheduler_g.get_last_lr()[0])
dict_scalars["other/lr_d"].append(scheduler_d.get_last_lr()[0])
# The statistics of both discriminator passes share the same tag prefix.
for k, v in discriminator_d_stats.items():
dict_scalars[f"~loss_discriminator/{k}"].append(v)
for k, v in discriminator_g_stats.items():
dict_scalars[f"~loss_discriminator/{k}"].append(v)
for k, v in gradient_balancer_stats.items():
dict_scalars[f"~gradient_balancer/{k}"].append(v)
# Write the means to TensorBoard on the first iteration and then every 1000
# iterations, emptying the accumulated lists.
if (iteration + 1) % 1000 == 0 or iteration == 0:
for name, scalars in dict_scalars.items():
if scalars:
writer.add_scalar(name, sum(scalars) / len(scalars), iteration + 1)
scalars.clear()
# === 4. Validation ===
# Runs after iterations 1, 10000, 30000, the final one and every 50000th.
if (iteration + 1) % 50000 == 0 or iteration + 1 in {
    1,
    10000,
    30000,
    h.n_steps,
}:
    net_g.eval()
    torch.cuda.empty_cache()

    dict_qualities_all = defaultdict(list)
    n_added_wavs = 0
    with torch.inference_mode():
        for i, ((file, target_ids), pitch_shift_semitones) in enumerate(
            zip(test_filelist, test_pitch_shifts)
        ):
            source_wav, sr = torchaudio.load(file, backend="soundfile")
            source_wav = source_wav.to(device)
            if sr != h.in_sample_rate:
                source_wav = get_resampler(sr, h.in_sample_rate, device)(source_wav)
            source_wav = source_wav.to(device)
            original_source_wav_length = source_wav.size(1)
            # Pad to a multiple of the sample rate: fewer distinct lengths
            # make caching more effective.
            remainder = source_wav.size(1) % h.in_sample_rate
            if remainder != 0:
                source_wav = F.pad(source_wav, (0, h.in_sample_rate - remainder))
            n_targets = len(target_ids)
            # Convert the first channel of the source to all targets at once
            # and trim the output back to the unpadded length.
            converted = net_g(
                source_wav[[0] * n_targets, None],
                torch.tensor(target_ids, device=device),
                torch.tensor([0.0] * n_targets, device=device),  # formant shift
                torch.tensor(
                    [float(p) for p in pitch_shift_semitones], device=device
                ),
            ).squeeze_(1)[:, : original_source_wav_length // 160 * 240]
            if i < 12:
                if iteration == 0:
                    writer.add_audio(
                        f"source/y_{i:02d}",
                        source_wav,
                        iteration + 1,
                        h.in_sample_rate,
                    )
                # About 12 converted samples are logged in total; with fewer
                # than 12 test files several targets are taken per file.
                n_wavs_for_file = min(
                    n_targets, 1 + (12 - i - 1) // len(test_filelist)
                )
                for _ in range(n_wavs_for_file):
                    # Pick a different target for every logged sample.
                    # Previously this index was computed but never used, so
                    # the first conversion was logged repeatedly under one tag.
                    idx_in_batch = n_added_wavs % n_targets
                    writer.add_audio(
                        f"converted/y_hat_{i:02d}_{target_ids[idx_in_batch]:03d}_{pitch_shift_semitones[idx_in_batch]:+02d}",
                        converted[idx_in_batch],
                        iteration + 1,
                        h.out_sample_rate,
                    )
                    n_added_wavs += 1
            converted = resample_to_in_sample_rate(converted)
            quality = quality_tester.test(converted, source_wav)
            for metric_name, values in quality.items():
                dict_qualities_all[metric_name].extend(values)
    # An empty test list is allowed (it only triggers a warning during
    # preparation), so the sanity check must not index into it.
    if test_filelist:
        assert n_added_wavs == min(
            12, len(test_filelist) * len(test_filelist[0][1])
        ), (
            n_added_wavs,
            len(test_filelist),
            len(speakers),
            len(test_filelist[0][1]),
        )
    # Mean of every metric over all conversions ...
    dict_qualities = {
        metric_name: sum(values) / len(values)
        for metric_name, values in dict_qualities_all.items()
        if len(values)
    }
    for metric_name, value in dict_qualities.items():
        writer.add_scalar(f"validation/{metric_name}", value, iteration + 1)
    # ... and every individual value.
    for metric_name, values in dict_qualities_all.items():
        for i, value in enumerate(values):
            writer.add_scalar(
                f"~validation_{metric_name}/{i:03d}", value, iteration + 1
            )
    del dict_qualities, dict_qualities_all

    gc.collect()
    net_g.train()
    torch.cuda.empty_cache()
# === 5. Saving ===
# Same schedule as validation: iterations 1, 10000, 30000, the final one and
# every 50000th.
if (iteration + 1) % 50000 == 0 or iteration + 1 in {
    1,
    10000,
    30000,
    h.n_steps,
}:
    # Checkpoint
    name = f"{in_wav_dataset_dir.name}_{iteration + 1:08d}"
    # Never overwrite an existing checkpoint: use the first free numbered
    # name. (hash(None) is constant within a process, so it cannot make a
    # name unique, and it was appended after the ".pt" extension.)
    checkpoint_file_save = out_dir / f"checkpoint_{name}.pt"
    n_duplicates = 1
    while checkpoint_file_save.exists():
        n_duplicates += 1
        checkpoint_file_save = out_dir / f"checkpoint_{name}_{n_duplicates}.pt"
    torch.save(
        {
            "iteration": iteration + 1,
            "net_g": net_g.state_dict(),
            "phone_extractor": phone_extractor.state_dict(),
            "pitch_estimator": pitch_estimator.state_dict(),
            "net_d": net_d.state_dict(),
            "optim_g": optim_g.state_dict(),
            "optim_d": optim_d.state_dict(),
            "gradient_balancer": grad_balancer.state_dict(),
            "scaler": grad_scaler.state_dict(),
            "h": dict(h),
        },
        checkpoint_file_save,
    )
    # "checkpoint_latest.pt" is the file a resumed run loads.
    shutil.copy(checkpoint_file_save, out_dir / "checkpoint_latest.pt")

    # Files for inference
    paraphernalia_dir = out_dir / f"paraphernalia_{name}"
    n_duplicates = 1
    while paraphernalia_dir.exists():
        n_duplicates += 1
        paraphernalia_dir = out_dir / f"paraphernalia_{name}_{n_duplicates}"
    paraphernalia_dir.mkdir()

    # Each network is copied into a fresh instance, prepared for inference
    # (weight norm removed / weights merged), converted to half precision
    # and dumped.
    phone_extractor_fp16 = PhoneExtractor()
    phone_extractor_fp16.load_state_dict(phone_extractor.state_dict())
    phone_extractor_fp16.remove_weight_norm()
    phone_extractor_fp16.merge_weights()
    phone_extractor_fp16.half()
    phone_extractor_fp16.dump(paraphernalia_dir / "phone_extractor.bin")
    del phone_extractor_fp16

    pitch_estimator_fp16 = PitchEstimator()
    pitch_estimator_fp16.load_state_dict(pitch_estimator.state_dict())
    pitch_estimator_fp16.merge_weights()
    pitch_estimator_fp16.half()
    pitch_estimator_fp16.dump(paraphernalia_dir / "pitch_estimator.bin")
    del pitch_estimator_fp16

    net_g_fp16 = ConverterNetwork(
        nn.Module(), nn.Module(), len(speakers), h.hidden_channels
    )
    net_g_fp16.load_state_dict(net_g.state_dict())
    net_g_fp16.remove_weight_norm()
    net_g_fp16.merge_weights()
    net_g_fp16.half()
    net_g_fp16.dump(paraphernalia_dir / "waveform_generator.bin")
    with open(paraphernalia_dir / "speaker_embeddings.bin", "wb") as f:
        dump_layer(net_g_fp16.embed_speaker, f)
    with open(paraphernalia_dir / "formant_shift_embeddings.bin", "wb") as f:
        dump_layer(net_g_fp16.embed_formant_shift, f)
    del net_g_fp16

    # Model description file. Names are written as escaped basic strings so
    # that quotes or backslashes in a directory name cannot break the TOML;
    # for ordinary names the output is unchanged.
    toml_model_name = json.dumps(name, ensure_ascii=False)
    with open(
        paraphernalia_dir / f"beatrice_paraphernalia_{name}.toml",
        "w",
        encoding="utf-8",
    ) as f:
        f.write(
            f'''[model]
version = "{PARAPHERNALIA_VERSION}"
name = {toml_model_name}
description = """
No description for this model.
このモデルの説明はありません。
"""
'''
        )
        for speaker_id, (speaker, speaker_f0) in enumerate(
            zip(speakers, speaker_f0s)
        ):
            if math.isfinite(speaker_f0) and speaker_f0 > 0.0:
                # F0 in Hz -> MIDI note number (69 = 440 Hz), rounded to 1/8.
                average_pitch = 69.0 + 12.0 * math.log2(speaker_f0 / 440.0)
                average_pitch = round(average_pitch * 8.0) / 8.0
            else:
                # The mean F0 can be NaN, and round(nan) raises. Fall back to
                # an arbitrary placeholder instead of aborting the save.
                average_pitch = 60.0
            toml_speaker_name = json.dumps(speaker, ensure_ascii=False)
            f.write(
                f'''
[voice.{speaker_id}]
name = {toml_speaker_name}
description = """
No description for this voice.
この声の説明はありません。
"""
average_pitch = {average_pitch}
'''
            )
    del paraphernalia_dir

    # TODO: skip the dump when phone_extractor / pitch_estimator are known models
# === 6. Scheduler update ===
# Advance the learning-rate schedulers of the generator and the discriminator.
scheduler_g.step()
scheduler_d.step()
print("Training finished.")