# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import multiprocessing
import os
import typing as tp
from functools import partial
from pathlib import Path

import flashy
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig
from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility

from . import base, builders
from ..data.audio import save_spectrograms
from ..metrics.miou import calculate_miou
from ..metrics.pesq import PesqMetric
from ..models.builders import get_watermark_model
from ..modules.watermark import mix, pad
from ..utils import checkpoint
from ..utils.audio_effects import (
    compress_with_encodec,
    get_audio_effects,
    select_audio_effects,
)
from ..utils.samples.manager import SampleManager
from ..utils.utils import get_pool_executor

if tp.TYPE_CHECKING:
    from ..models.watermark import WMModel


def get_encodec_audio_effect(encodec_cfg: DictConfig, sr: int) -> tp.Dict:
    """
    Construct the EnCodec-based compression data augmentation. This method is
    put here instead of in `audiocraft.utils.audio_effects` because it depends
    on the package `audiocraft.solvers`, which is one layer higher than
    `audiocraft.utils`; this avoids a circular dependency for the solvers that
    use `audiocraft.utils.audio_effects` to do the augmentation.
    """
    from ..solvers.compression import CompressionSolver

    codec_model = CompressionSolver.model_from_checkpoint(encodec_cfg.ckpt)
    codec_model.train()
    return {
        f"encodec_nq={n_q}": partial(
            compress_with_encodec,
            model=codec_model,
            n_q=n_q,
            sample_rate=sr,
        )
        for n_q in encodec_cfg.n_qs
    }


def random_message(nbits: int, batch_size: int) -> torch.Tensor:
    """Return a random 0/1 message as a tensor of shape [batch_size, nbits]."""
    if nbits == 0:
        return torch.tensor([])
    return torch.randint(0, 2, (batch_size, nbits))
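# A hedged illustration of `random_message` (values vary per call; only the
# shapes are fixed):
#
#     >>> random_message(4, 2).shape
#     torch.Size([2, 4])
#     >>> random_message(0, 2)
#     tensor([])
#
# The zero-bit case returns an empty tensor so that detection-only setups can
# share the same code path (see the `message.shape[0] != 0` guard in
# `evaluate_augmentations` below).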
class WatermarkSolver(base.StandardSolver):
    """Solver for different watermarking models."""

    def __init__(self, cfg: DictConfig):
        super().__init__(cfg)
        self.rng: torch.Generator  # set at each epoch
        self.model: WMModel
        if hasattr(cfg, "fsdp"):
            assert not getattr(
                cfg.fsdp, "use", False
            ), "FSDP not supported by WatermarkSolver."
        self._init_losses()
        self._init_augmentations()
        self.balancer = builders.get_balancer(self.loss_weights, self.cfg.balancer)
        self.path_specs = os.path.join(self.folder, "spectrograms")
        os.makedirs(self.path_specs, exist_ok=True)

    def _init_losses(self):
        assert hasattr(self.cfg, "losses") and isinstance(
            self.cfg.losses, (DictConfig, tp.Mapping)
        ), "WatermarkSolver must declare training losses in the config"

        self.adv_losses = builders.get_adversarial_losses(self.cfg)  # noqa
        self.register_stateful("adv_losses")
        self.aux_losses = nn.ModuleDict()  # noqa
        self.info_losses = nn.ModuleDict()  # noqa
        self.wm_losses = nn.ModuleDict()  # noqa
        loss_weights = {}
        for loss_name, weight in self.cfg.losses.items():
            # A weight of -1 explicitly skips the loss entirely; a weight of 0
            # means the loss is still computed but only reported as info.
            if weight == -1:
                continue
            if loss_name in ["adv", "feat"]:
                for adv_name, _ in self.adv_losses.items():
                    loss_weights[f"{loss_name}_{adv_name}"] = weight
            elif weight > 0:
                if loss_name[:3] == "wm_":
                    self.wm_losses[loss_name] = builders.get_loss(
                        loss_name, self.cfg
                    ).to(self.device)
                    loss_weights[loss_name] = weight
                else:
                    self.aux_losses[loss_name] = builders.get_loss(
                        loss_name, self.cfg
                    ).to(self.device)
                    loss_weights[loss_name] = weight
            else:
                self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg).to(
                    self.device
                )

        self.loss_weights = loss_weights  # noqa

    def _init_augmentations(self):
        if not hasattr(self.cfg, "aug_weights") or not hasattr(
            self.cfg, "audio_effects"
        ):
            return

        aug_weights = {}
        cfg_audio_effects = dict(self.cfg.audio_effects)

        # Handle the `encodec` augmentation separately as it requires loading a
        # CompressionSolver checkpoint
        encodec_cfg = cfg_audio_effects.pop("encodec", None)
        if encodec_cfg:
            encodec_effects = get_encodec_audio_effect(
                encodec_cfg, self.cfg.sample_rate
            )
            for aug_name in encodec_effects.keys():
                aug_weights[aug_name] = getattr(self.cfg.aug_weights, "encodec", -1)
        else:
            encodec_effects = {}

        other_effects = get_audio_effects(self.cfg)  # noqa
        for name in other_effects.keys():
            aug_weights[name] = self.cfg.aug_weights.get(name, -1)

        self.aug_weights = aug_weights  # noqa
        self.augmentations = {**encodec_effects, **other_effects}  # noqa

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        # best model is the last one for the watermark model, for now
        return None

    def build_model(self):
        """Instantiate model and optimizer."""
        # Model and optimizer
        self.model = get_watermark_model(self.cfg)
        # Need two optimizers?
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        self.register_stateful("model", "optimizer")
        self.register_best_state("model")
        self.register_ema("model")

    def build_dataloaders(self):
        """Instantiate audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg)

    def show(self):
        """Show the watermark model and the employed adversarial losses."""
        self.log_model_summary(self.model)
        self.logger.info("Should print losses here:")
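    # The `crop` method below trains localization. All three transforms share
    # one convention: `mask` is 1 where the watermark is kept and 0 elsewhere.
    # A hedged sketch of the padding branch on a 10-sample signal:
    #
    #     mask = [0 0 0 1 1 1 1 0 0 0]  -> the watermark is kept only in the
    #     middle and the signal is zeroed at both ends; with probability 0.5
    #     the mask is inverted, keeping the watermarked borders instead.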
    def crop(
        self, signal: torch.Tensor, watermark: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Applies a transformation to modify the watermarked signal to train localization.
        It can be one of the following:
            - zero padding: add zeros at the beginning and at the end of the signal
            - crop: keep the watermark on only some randomly chosen windows of the signal
            - shuffle: replace some parts of the audio with other, non-watermarked
              parts from the batch
        In every case the function returns a mask indicating which parts of the
        signal are watermarked and which are not.

        Args:
            signal (torch.Tensor): clean signal
            watermark (torch.Tensor): watermark to apply on the signal
        Returns:
            signal (torch.Tensor): modified signal
            watermark (torch.Tensor): modified watermark
            mask (torch.Tensor): mask indicating which portions are still watermarked
        """
        assert (
            self.cfg.crop.prob + self.cfg.crop.shuffle_prob + self.cfg.crop.pad_prob
            <= 1
        ), (
            f"The sum of the probabilities {self.cfg.crop.prob=} "
            f"{self.cfg.crop.shuffle_prob=} {self.cfg.crop.pad_prob=} "
            "should be less than 1"
        )
        mask = torch.ones_like(watermark)
        p = torch.rand(1)
        if p < self.cfg.crop.pad_prob:  # Pad with some probability
            start = int(torch.rand(1) * 0.33 * watermark.size(-1))
            finish = int((0.66 + torch.rand(1) * 0.33) * watermark.size(-1))
            mask[:, :, :start] = 0
            mask[:, :, finish:] = 0
            if torch.rand(1) > 0.5:
                mask = 1 - mask
            signal *= mask  # pad signal
        elif (
            p
            < self.cfg.crop.prob + self.cfg.crop.pad_prob + self.cfg.crop.shuffle_prob
        ):
            # Define a mask, then crop or shuffle
            mask_size = round(watermark.shape[-1] * self.cfg.crop.size)
            n_windows = int(
                torch.randint(1, self.cfg.crop.max_n_windows + 1, (1,)).item()
            )
            window_size = int(mask_size / n_windows)
            for _ in range(n_windows):  # Create multiple windows in the mask
                mask_start = torch.randint(0, watermark.shape[-1] - window_size, (1,))
                mask[:, :, mask_start: mask_start + window_size] = 0
            # invert the mask half of the time
            if torch.rand(1) > 0.5:
                mask = 1 - mask

            if p < self.cfg.crop.pad_prob + self.cfg.crop.shuffle_prob:
                # shuffle: replace the non-watermarked parts with audio from
                # other items of the batch
                signal_cloned = signal.clone().detach()  # detach to be sure
                shuffle_idx = torch.randint(0, signal.size(0), (signal.size(0),))
                signal = signal * mask + signal_cloned[shuffle_idx] * (1 - mask)

        watermark *= mask  # Apply the mask to the watermark
        return signal, watermark, mask

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
        """Perform one training or valid step on a given batch."""
        x = batch.to(self.device)
        y = x.clone()
        nbits = self.model.nbits
        message = random_message(nbits, y.shape[0]).to(self.device)
        watermark = self.model.get_watermark(x, message=message)
        y, watermark, mask = self.crop(y, watermark)
        y_wm = y + watermark

        if (
            self.cfg.losses.adv != 0 or self.cfg.losses.feat != 0
        ) and self.is_training:  # train the quality adversary
            d_losses: dict = {}
            if (
                len(self.adv_losses) > 0
                and torch.rand(1, generator=self.rng).item()
                <= 1 / self.cfg.adversarial.every
            ):
                for adv_name, adversary in self.adv_losses.items():
                    disc_loss = adversary.train_adv(y_wm, y)
                    d_losses[f"d_{adv_name}"] = disc_loss
                metrics["d_loss"] = torch.sum(torch.stack(list(d_losses.values())))
            metrics.update(d_losses)

        balanced_losses: dict = {}
        other_losses: dict = {}

        # adversarial losses
        if self.cfg.losses.adv != 0 or self.cfg.losses.feat != 0:
            for adv_name, adversary in self.adv_losses.items():
                adv_loss, feat_loss = adversary(y_wm, y)
                balanced_losses[f"adv_{adv_name}"] = adv_loss
                balanced_losses[f"feat_{adv_name}"] = feat_loss

        # auxiliary losses on quality/similarity
        for loss_name, criterion in self.aux_losses.items():
            loss = criterion(y_wm, y)
            balanced_losses[loss_name] = loss

        # apply augmentations
        mode = "all" if self.cfg.select_aug_mode == "all" else "weighted"
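        # Hedged sketch (effect names are illustrative): `select_audio_effects`
        # returns a name -> callable mapping, e.g.
        #     {"encodec_nq=4": partial(...), "lowpass_filter": partial(...)}
        # In "weighted" mode the per-effect `aug_weights` bias the random
        # selection, and `cfg.n_max_aug` presumably caps how many effects are
        # applied in a single step.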
        selected_augs = select_audio_effects(
            self.augmentations,
            self.aug_weights,
            mode=mode,
            max_length=self.cfg.n_max_aug,
        )
        N_augs = len(selected_augs)
        for (
            augmentation_name,
            augmentation_method,
        ) in selected_augs.items():
            # concatenate to call the augmentation function only once
            y_y_wm = torch.cat([y, y_wm], dim=0)
            aug_cat, mask_aug = augmentation_method(y_y_wm, mask=mask)
            aug_y = aug_cat[: y.size(0)]
            aug_y_wm = aug_cat[y.size(0):]
            positive = self.model.detect_watermark(aug_y_wm)
            negative = self.model.detect_watermark(aug_y)
            for loss_name, criterion in self.wm_losses.items():
                loss = criterion(positive, negative, mask_aug, message)
                other_losses[f"{loss_name}_{augmentation_name}"] = loss

        # weighted losses
        metrics.update(balanced_losses)
        metrics.update(other_losses)
        if self.is_training:
            # The watermark losses are not handled by the loss balancer, so
            # they are weighted and backpropagated manually here.
            other_loss = torch.tensor(0.0, device=self.device)
            for name, o_loss in other_losses.items():
                if "wm_detection" in name:
                    # here we include the detection losses for each augmentation
                    other_loss += (self.loss_weights["wm_detection"] / N_augs) * o_loss
                elif "wm_mb" in name:
                    other_loss += (self.loss_weights["wm_mb"] / N_augs) * o_loss
                else:
                    other_loss += self.loss_weights[name] * o_loss
            if other_loss.requires_grad:
                other_loss.backward(retain_graph=True)
                ratio1 = sum(
                    p.grad.data.norm(p=2).pow(2)
                    for p in self.model.parameters()
                    if p.grad is not None
                )
                assert isinstance(ratio1, torch.Tensor)
                metrics["ratio1"] = ratio1.sqrt()

            # balancer losses backward, returns effective training loss
            # with effective weights at the current batch.
            metrics["g_loss"] = self.balancer.backward(balanced_losses, y_wm)
            # add metrics corresponding to weight ratios
            metrics.update(self.balancer.metrics)

            ratio2 = sum(
                p.grad.data.norm(p=2).pow(2)
                for p in self.model.parameters()
                if p.grad is not None
            )
            assert isinstance(ratio2, torch.Tensor)
            metrics["ratio2"] = ratio2.sqrt()

            # optim
            flashy.distrib.sync_model(self.model)
            if self.cfg.optim.max_norm:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.optim.max_norm
                )
            self.optimizer.step()
            self.optimizer.zero_grad()

        # informative losses only
        info_losses: dict = {}
        with torch.no_grad():
            for loss_name, criterion in self.info_losses.items():
                loss = criterion(y_wm, y)
                info_losses[loss_name] = loss
            # pesq
            metrics["pesq"] = tensor_pesq(y_wm, y, sr=self.cfg.sample_rate)
            # max allocated memory, in GB
            metrics["max_mem"] = torch.cuda.max_memory_allocated() / 1e9

        metrics.update(info_losses)
        if self.cfg.losses.adv != 0 or self.cfg.losses.feat != 0:
            # aggregated GAN losses: this is useful to report adv and feat
            # across different adversarial loss setups
            adv_losses = [
                loss
                for loss_name, loss in metrics.items()
                if loss_name.startswith("adv")
            ]
            if len(adv_losses) > 0:
                metrics["adv"] = torch.sum(torch.stack(adv_losses))
            feat_losses = [
                loss
                for loss_name, loss in metrics.items()
                if loss_name.startswith("feat")
            ]
            if len(feat_losses) > 0:
                metrics["feat"] = torch.sum(torch.stack(feat_losses))

        return metrics

    def run_epoch(self):
        # reset the random seed at the beginning of the epoch
        self.rng = torch.Generator()
        self.rng.manual_seed(1234 + self.epoch)
        # run epoch
        super().run_epoch()
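    # Detector output convention assumed by the evaluation below (consistent
    # with the metric helpers at the bottom of this file): `detect_watermark`
    # returns [bsz, 2 + nbits, frames], where channels 0/1 hold the per-frame
    # soft decisions (non-watermarked / watermarked) and channels 2: carry the
    # decoded message bits.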
    def evaluate(self) -> dict:
        """Evaluate stage. Runs audio reconstruction evaluation."""
        self.model.eval()
        evaluate_stage_name = str(self.current_stage)
        loader = self.dataloaders["evaluate"]
        updates = len(loader)
        lp = self.log_progress(
            f"{evaluate_stage_name} inference",
            loader,
            total=updates,
            updates=self.log_updates,
        )
        average = flashy.averager()

        pendings = []
        ctx = multiprocessing.get_context("spawn")
        with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:
            for batch in lp:
                x = batch.to(self.device)
                with torch.no_grad():
                    message = random_message(self.model.nbits, x.shape[0])
                    watermark = self.model.get_watermark(x, message)
                    x_wm = x + watermark

                y_pred = x_wm.cpu()
                y = batch.cpu()  # should already be on CPU but just in case
                pendings.append(
                    pool.submit(
                        evaluate_audio_watermark,
                        y_pred,
                        y,
                        self.cfg,
                    )
                )
                # evaluate augmentations: evaluation is run on all of them
                for (
                    augmentation_name,
                    augmentation_method,
                ) in self.augmentations.items():
                    # if (
                    #     "mp3" in augmentation_name
                    #     and idx >= 8
                    #     and self.cfg.evaluate.every <= 2
                    # ):
                    #     # When evaluating often, do not compute mp3 on the
                    #     # full eval dset to make things faster
                    #     continue
                    with torch.no_grad():
                        aug_positive = self.model.detect_watermark(
                            augmentation_method(x_wm)
                        )
                        aug_negative = self.model.detect_watermark(
                            augmentation_method(x)
                        )
                    pendings.append(
                        pool.submit(
                            evaluate_augmentations,
                            aug_positive.cpu(),
                            aug_negative.cpu(),
                            augmentation_name,
                            message.cpu(),
                        )
                    )
                # end of the evaluation of augmentations

                # evaluate localization on cropping
                for window_size in np.linspace(0.1, 0.9, 9):
                    mixed, true_predictions = mix(x, x_wm, window_size=window_size)
                    model_predictions = self.model.detect_watermark(mixed)
                    pendings.append(
                        pool.submit(
                            evaluate_localizations,
                            model_predictions.cpu(),
                            true_predictions.cpu(),
                            f"crop_{window_size:0.1f}",
                        )
                    )
                    mixed, true_predictions = mix(
                        x, x_wm, window_size=window_size, shuffle=True
                    )
                    model_predictions = self.model.detect_watermark(mixed)
                    pendings.append(
                        pool.submit(
                            evaluate_localizations,
                            model_predictions.cpu(),
                            true_predictions.cpu(),
                            f"shuffle_{window_size:0.1f}",
                        )
                    )
                # evaluate localization on padding
                mixed, true_predictions = pad(x_wm)
                model_predictions = self.model.detect_watermark(mixed)
                pendings.append(
                    pool.submit(
                        evaluate_localizations,
                        model_predictions.cpu(),
                        true_predictions.cpu(),
                        "padding",
                    )
                )
                mixed, true_predictions = pad(x_wm, central=True)
                model_predictions = self.model.detect_watermark(mixed)
                pendings.append(
                    pool.submit(
                        evaluate_localizations,
                        model_predictions.cpu(),
                        true_predictions.cpu(),
                        "central_padding",
                    )
                )
                # end of the localization evaluation

            metrics_lp = self.log_progress(
                f"{evaluate_stage_name} metrics", pendings, updates=self.log_updates
            )
            for pending in metrics_lp:
                metrics = pending.result()
                metrics = average(metrics)

        metrics = flashy.distrib.average_metrics(metrics, len(loader))
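        # Worked example of the reweighting below: an augmentation with eval
        # accuracy 0.98 gets weight max(1 - 0.98, 0.05) = 0.05, while one with
        # accuracy 0.60 gets 0.40, so augmentations the detector already
        # survives well are sampled less often in the next epochs.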
        if self.cfg.select_aug_mode == "use_eval_acc":
            # Adjust the augmentation weights based on the evaluation accuracy:
            # the higher the accuracy, the lower the probability of selecting
            # this augmentation during training.
            for name in self.augmentations.keys():
                if (
                    self.aug_weights[name] != -1
                ):  # keep the weight at -1 for unwanted augmentations
                    # Floor at 0.05 to ensure that an augmentation is never
                    # completely removed during a full epoch.
                    self.aug_weights[name] = max(1 - metrics[f"aug_{name}_acc"], 0.05)
        return metrics

    def generate(self):
        """Generate stage."""
        self.model.eval()
        sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True)
        generate_stage_name = str(self.current_stage)

        loader = self.dataloaders["generate"]
        updates = len(loader)
        lp = self.log_progress(
            generate_stage_name, loader, total=updates, updates=self.log_updates
        )
        path_dir = os.path.join(self.path_specs, f"epoch={self.epoch}")
        os.makedirs(path_dir, exist_ok=True)
        first_batch = True
        for batch in lp:
            reference, _ = batch
            reference = reference.to(self.device)
            with torch.no_grad():
                message = random_message(self.model.nbits, reference.shape[0])
                watermark = self.model.get_watermark(reference, message)
                x_wm = reference + watermark

            reference = reference.cpu()
            sample_manager.add_samples(
                x_wm.cpu(), self.epoch, ground_truth_wavs=reference
            )
            if first_batch and flashy.distrib.is_rank_zero():
                for i in range(reference.size(0)):
                    ys = [
                        reference.cpu()[i].squeeze(0).numpy(),
                        x_wm.cpu()[i].squeeze(0).numpy(),
                        watermark.cpu()[i].squeeze(0).numpy(),
                    ]
                    path = os.path.join(path_dir, f"spec_{i}.pdf")
                    save_spectrograms(
                        ys,
                        names=["Ground Truth", "Audio Watermarked", "Watermark"],
                        sr=self.cfg.sample_rate,
                        path=path,
                    )
                first_batch = False
        flashy.distrib.barrier()

    def load_from_pretrained(self, name: str) -> dict:
        raise ValueError("No pretrained model")
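    # Hedged usage sketch for the loader below (the checkpoint path is
    # hypothetical):
    #
    #     model = WatermarkSolver.model_from_checkpoint(
    #         "/checkpoints/my_sig/checkpoint.th", device="cuda"
    #     )
    #     result = model.detect_watermark(wav)
    #
    # It accepts either a filesystem path or a dora sig and returns the model
    # with its best state loaded, already in eval mode.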
    @staticmethod
    def model_from_checkpoint(
        checkpoint_path: tp.Union[Path, str],
        device: tp.Union[torch.device, str] = "cpu",
    ) -> "WMModel":
        """Instantiate a WatermarkModel from a given checkpoint path or dora sig.

        Args:
            checkpoint_path (Path or str): Path to the checkpoint, or dora sig
                from which the checkpoint is resolved.
            device (torch.device or str): Device on which the model is loaded.
        """
        checkpoint_path = str(checkpoint_path)
        logger = logging.getLogger(__name__)
        logger.info(f"Loading WatermarkModel from checkpoint: {checkpoint_path}")
        _checkpoint_path = checkpoint.resolve_checkpoint_path(
            checkpoint_path, use_fsdp=False
        )
        assert (
            _checkpoint_path is not None
        ), f"Could not resolve WatermarkModel checkpoint path: {checkpoint_path}"
        state = checkpoint.load_checkpoint(_checkpoint_path)
        assert (
            state is not None and "xp.cfg" in state
        ), f"Could not load WatermarkModel from ckpt: {checkpoint_path}"
        cfg = state["xp.cfg"]
        cfg.device = device
        watermarking_model = get_watermark_model(cfg).to(device)

        assert "best_state" in state and state["best_state"] != {}
        assert (
            "exported" not in state
        ), "When loading an exported checkpoint, use the //pretrained/ prefix."
        watermarking_model.load_state_dict(state["best_state"]["model"])
        watermarking_model.eval()
        logger.info("Watermarking model loaded!")
        return watermarking_model


def evaluate_localizations(predictions, true_predictions, name):
    # `predictions` is the output of the detector, shape [bsz, 2, frames];
    # `true_predictions` is the output of the mix method, shape [bsz, 2, frames]
    metrics = {}
    metrics[f"localization_acc_{name}"] = (
        ((predictions[:, 1, :] > 0.5) == true_predictions[:, 1, :])
        .float()
        .mean()
        .item()
    )
    metrics[f"localization_miou_{name}"] = calculate_miou(
        predictions[:, 1, :], true_predictions[:, 1, :]
    )
    return metrics


def evaluate_augmentations(
    positive: torch.Tensor,
    negative: torch.Tensor,
    augmentation_name: str,
    message: torch.Tensor,
) -> dict:
    """Compute the evaluation metrics, keyed by the name of the augmentation
    that was applied before producing the positive and negative outputs."""
    metrics = {}
    metrics[f"aug_{augmentation_name}_acc"] = compute_accuracy(positive, negative)
    metrics[f"aug_{augmentation_name}_fpr"] = compute_FPR(negative)
    metrics[f"aug_{augmentation_name}_fnr"] = compute_FNR(positive)
    if message.shape[0] != 0:
        metrics[f"aug_{augmentation_name}_bit_acc"] = compute_bit_acc(positive, message)
    # also log the accuracy under a shared key, so that averaging yields the
    # overall score across all augmentations
    metrics["all_aug_acc"] = compute_accuracy(positive, negative)
    return metrics


def evaluate_audio_watermark(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    cfg: DictConfig,
) -> dict:
    """Audio reconstruction evaluation method that can be conveniently pickled."""
    metrics = {}
    if cfg.evaluate.metrics.visqol:
        visqol = builders.get_visqol(cfg.metrics.visqol)
        metrics["visqol"] = visqol(y_pred, y, cfg.sample_rate)
    sisnr = ScaleInvariantSignalNoiseRatio().to(y.device)
    stoi = ShortTimeObjectiveIntelligibility(fs=cfg.sample_rate)
    metrics["sisnr"] = sisnr(y_pred, y)
    metrics["stoi"] = stoi(y_pred, y)
    metrics["pesq"] = tensor_pesq(y_pred, y, sr=cfg.sample_rate)
    return metrics


def tensor_pesq(y_pred: torch.Tensor, y: torch.Tensor, sr: int) -> float:
    # pesq raises an error if no speech is detected; that case is caught
    # inside PesqMetric
    return PesqMetric(sr)(y_pred, y).item()


def compute_accuracy(positive, negative):
    N = (positive[:, 1, :].mean(dim=1) > 0.5).sum() + (
        negative[:, 0, :].mean(dim=1) > 0.5
    ).sum()
    acc = N / (2 * positive.size(0))
    return acc


def compute_FPR(negative):
    N = (negative[:, 1, :].mean(dim=1) > 0.5).sum()
    fpr = N / negative.size(0)
    return fpr


def compute_FNR(positive):
    N = (positive[:, 0, :].mean(dim=1) > 0.5).sum()
    fnr = N / positive.size(0)
    return fnr


def _bit_acc(decoded, original):
    bit_acc = (decoded == original).float().mean()
    return bit_acc


def compute_bit_acc(positive, original, mask=None):
    """Compute bit accuracy.

    Args:
        positive: detector outputs [bsz, 2+nbits, time_steps]
        original: original message (0 or 1) [bsz, nbits]
        mask: mask of the watermark [bsz, 1, time_steps]
    """
    decoded = positive[:, 2:, :]  # b 2+nbits t -> b nbits t
    if mask is not None:
        # cut the last dim of `decoded` to keep only the parts where mask == 1
        new_shape = [*decoded.shape[:-1], -1]  # b nbits t -> b nbits -1
        decoded = torch.masked_select(decoded, mask == 1).reshape(new_shape)
    # average the decision over time, then threshold
    decoded = decoded.mean(dim=-1) > 0  # b nbits
    return _bit_acc(decoded, original)
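if __name__ == "__main__":
    # Minimal smoke test for the metric helpers above: a sketch that fabricates
    # "perfect" detector outputs in the [bsz, 2 + nbits, frames] layout assumed
    # by this file. Not part of the training pipeline.
    bsz, nbits, frames = 4, 16, 50
    msg = random_message(nbits, bsz)
    positive = torch.zeros(bsz, 2 + nbits, frames)
    positive[:, 1, :] = 1.0  # "watermarked" channel confidently on
    positive[:, 2:, :] = msg.float().unsqueeze(-1).expand(-1, -1, frames)
    negative = torch.zeros(bsz, 2 + nbits, frames)
    negative[:, 0, :] = 1.0  # "non-watermarked" channel confidently on
    print("acc    :", compute_accuracy(positive, negative).item())  # expect 1.0
    print("fpr    :", compute_FPR(negative).item())  # expect 0.0
    print("fnr    :", compute_FNR(positive).item())  # expect 0.0
    print("bit_acc:", compute_bit_acc(positive, msg).item())  # expect 1.0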