From 660b41aef8b3be235284082451c165f7f66405e0 Mon Sep 17 00:00:00 2001
From: NikkeDoy
Date: Sun, 4 May 2025 22:48:57 +0300
Subject: [PATCH 01/11] :alembic: | Real-time testing...

---
 AudioUtils.py | 34 ++++++++++++++++
 data.py | 32 ++++++---------
 training.py | 102 +++++++++++++++++++++++++---------------------
 training_utils.py | 14 +++----
 4 files changed, 107 insertions(+), 75 deletions(-)

diff --git a/AudioUtils.py b/AudioUtils.py
index 04f75db..f4866dd 100644
--- a/AudioUtils.py
+++ b/AudioUtils.py
@@ -16,3 +16,37 @@ def stretch_tensor(tensor, target_length):
     tensor = F.interpolate(tensor, scale_factor=scale_factor, mode='linear', align_corners=False)

     return tensor
+
+def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128):
+    current_length = audio_tensor.shape[-1]
+
+    if current_length < target_length:
+        padding_needed = target_length - current_length
+
+        padding_tuple = (0, padding_needed)
+        padded_audio_tensor = F.pad(audio_tensor, padding_tuple, mode='constant', value=0)
+    else:
+        padded_audio_tensor = audio_tensor
+
+    return padded_audio_tensor
+
+def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 128) -> list[torch.Tensor]:
+    if not isinstance(chunk_size, int) or chunk_size <= 0:
+        raise ValueError("chunk_size must be a positive integer.")
+
+    # Handle scalar tensor edge case if necessary
+    if audio_tensor.dim() == 0:
+        return [audio_tensor] if audio_tensor.numel() > 0 else []
+
+    # Identify the dimension to split (usually the last one, representing time/samples)
+    split_dim = -1
+    num_samples = audio_tensor.shape[split_dim]
+
+    if num_samples == 0:
+        return [] # Return empty list if the dimension to split is empty
+
+    # Use torch.split to divide the tensor into chunks
+    # It handles the last chunk being potentially smaller automatically.
+ chunks = list(torch.split(audio_tensor, chunk_size, dim=split_dim)) + + return chunks diff --git a/data.py b/data.py index bc7574f..88364b6 100644 --- a/data.py +++ b/data.py @@ -21,33 +21,25 @@ class AudioDataset(Dataset): def __getitem__(self, idx): # Load high-quality audio high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True) + # Change to mono + high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio) # Generate low-quality audio with random downsampling mangled_sample_rate = random.choice(self.audio_sample_rates) - resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate) - low_quality_audio = resample_transform_low(high_quality_audio) + resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate) resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate) + + low_quality_audio = resample_transform_low(high_quality_audio) low_quality_audio = resample_transform_high(low_quality_audio) - high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio) - low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio) - # Pad or truncate high-quality audio - if high_quality_audio.shape[1] < self.MAX_LENGTH: - padding = self.MAX_LENGTH - high_quality_audio.shape[1] - high_quality_audio = F.pad(high_quality_audio, (0, padding)) - elif high_quality_audio.shape[1] > self.MAX_LENGTH: - high_quality_audio = high_quality_audio[:, :self.MAX_LENGTH] + splitted_high_quality_audio = AudioUtils.split_audio(high_quality_audio, 128) + splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(splitted_high_quality_audio[-1], 128) + splitted_high_quality_audio = [tensor.to(self.device) for tensor in splitted_high_quality_audio] - # Pad or truncate low-quality audio - if low_quality_audio.shape[1] < self.MAX_LENGTH: - padding = self.MAX_LENGTH - low_quality_audio.shape[1] - low_quality_audio = F.pad(low_quality_audio, (0, padding)) - elif low_quality_audio.shape[1] > self.MAX_LENGTH: - low_quality_audio = low_quality_audio[:, :self.MAX_LENGTH] + splitted_low_quality_audio = AudioUtils.split_audio(low_quality_audio, 128) + splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(splitted_low_quality_audio[-1], 128) + splitted_low_quality_audio = [tensor.to(self.device) for tensor in splitted_low_quality_audio] - high_quality_audio = high_quality_audio.to(self.device) - low_quality_audio = low_quality_audio.to(self.device) - - return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate) + return (splitted_high_quality_audio, original_sample_rate), (splitted_low_quality_audio, mangled_sample_rate) diff --git a/training.py b/training.py index 01ea749..f6ab2f4 100644 --- a/training.py +++ b/training.py @@ -43,11 +43,11 @@ print(f"Using device: {device}") # Parameters sample_rate = 44100 -n_fft = 2048 -hop_length = 256 +n_fft = 128 +hop_length = 128 win_length = n_fft -n_mels = 128 -n_mfcc = 20 # If using MFCC +n_mels = 40 +n_mfcc = 13 # If using MFCC mfcc_transform = T.MFCC( sample_rate, @@ -76,7 +76,7 @@ os.makedirs(audio_output_dir, exist_ok=True) # ========= SINGLE ========= -train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True) +train_data_loader = DataLoader(dataset, batch_size=1, shuffle=True) # ========= MODELS ========= @@ -115,61 +115,69 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min' def start_training(): generator_epochs = 
5000 for generator_epoch in range(generator_epochs): - low_quality_audio = (torch.empty((1)), 1) - high_quality_audio = (torch.empty((1)), 1) - ai_enhanced_audio = (torch.empty((1)), 1) + high_quality_audio = ([torch.empty((1))], 1) + low_quality_audio = ([torch.empty((1))], 1) + ai_enhanced_audio = ([torch.empty((1))], 1) times_correct = 0 # ========= TRAINING ========= - for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"): - # for high_quality_clip, low_quality_clip in train_data_loader: - high_quality_sample = (high_quality_clip[0], high_quality_clip[1]) - low_quality_sample = (low_quality_clip[0], low_quality_clip[1]) + for high_quality_data, low_quality_data in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"): + ## Data structure: + # [[float..., float..., float...], sample_rate] # ========= LABELS ========= - batch_size = high_quality_clip[0].size(0) + + batch_size = high_quality_data[0][0].size(0) real_labels = torch.ones(batch_size, 1).to(device) fake_labels = torch.zeros(batch_size, 1).to(device) - # ========= DISCRIMINATOR ========= - discriminator.train() - d_loss = discriminator_train( - high_quality_sample, - low_quality_sample, - real_labels, - fake_labels, - discriminator, - generator, - criterion_d, - optimizer_d - ) + high_quality_audio = high_quality_data + low_quality_audio = low_quality_data - # ========= GENERATOR ========= - generator.train() - generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train( - low_quality_sample, - high_quality_sample, - real_labels, - generator, - discriminator, - criterion_d, - optimizer_g, - device, - mel_transform, - stft_transform, - mfcc_transform - ) + ai_enhanced_outputs = [] - if debug: - print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}") - scheduler_d.step(d_loss.detach()) - scheduler_g.step(adversarial_loss.detach()) + for high_quality_sample, low_quality_sample in tqdm.tqdm(zip(high_quality_data[0], low_quality_data[0]), desc=f"Processing audio clip.. 
Length: {len(high_quality_data[0])}"): + # ========= DISCRIMINATOR ========= + discriminator.train() + d_loss = discriminator_train( + high_quality_sample, + low_quality_sample, + real_labels, + fake_labels, + discriminator, + generator, + criterion_d, + optimizer_d + ) + + # ========= GENERATOR ========= + generator.train() + generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train( + low_quality_sample, + high_quality_sample, + real_labels, + generator, + discriminator, + criterion_d, + optimizer_g, + device, + mel_transform, + stft_transform, + mfcc_transform + ) + + ai_enhanced_outputs.append(generator_output) + + if debug: + print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}") + scheduler_d.step(d_loss.detach()) + scheduler_g.step(adversarial_loss.detach()) # ========= SAVE LATEST AUDIO ========= - high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0]) - low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0]) - ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0]) + high_quality_audio = (torch.cat(high_quality_data[0]), high_quality_data[1]) + low_quality_audio = (torch.cat(low_quality_data[0]), low_quality_data[1]) + ai_enhanced_audio = (torch.cat(ai_enhanced_outputs), high_quality_data[1]) new_epoch = generator_epoch+epoch diff --git a/training_utils.py b/training_utils.py index 6f26f58..c7d43e5 100644 --- a/training_utils.py +++ b/training_utils.py @@ -20,12 +20,10 @@ def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tenso mel_spec_true = mel_transform(y_true) mel_spec_pred = mel_transform(y_pred) - # Ensure same time dimension length (due to potential framing differences) min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1]) mel_spec_true = mel_spec_true[..., :min_len] mel_spec_pred = mel_spec_pred[..., :min_len] - # L1 Loss (Mean Absolute Error) loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred)) return loss @@ -69,11 +67,11 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels, dis optimizer.zero_grad() # Forward pass for real samples - discriminator_decision_from_real = discriminator(high_quality[0]) + discriminator_decision_from_real = discriminator(high_quality) d_loss_real = criterion(discriminator_decision_from_real, real_labels) with torch.no_grad(): - generator_output = generator(low_quality[0]) + generator_output = generator(low_quality) discriminator_decision_from_fake = discriminator(generator_output) d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake)) @@ -105,7 +103,7 @@ def generator_train( ): g_optimizer.zero_grad() - generator_output = generator(low_quality[0]) + generator_output = generator(low_quality) discriminator_decision = discriminator(generator_output) adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision)) @@ -116,15 +114,15 @@ def generator_train( # Calculate Mel L1 Loss if weight is positive if lambda_mel_l1 > 0: - mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality[0], generator_output) + mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality, generator_output) # Calculate Log STFT L1 Loss if weight is positive if lambda_log_stft > 0: - log_stft_l1 = 
log_stft_magnitude_loss(stft_transform, high_quality[0], generator_output) + log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality, generator_output) # Calculate MFCC Loss if weight is positive if lambda_mfcc > 0: - mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality[0], generator_output) + mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality, generator_output) mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1 log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1 From b1e18443ba1c090b1bbd517d6fb183f5d9872bc6 Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Sun, 4 May 2025 23:56:14 +0300 Subject: [PATCH 02/11] :sparkles: | Added support for .mp3 and .flac loading... --- data.py | 55 ++++++++++++++-------------- file_utils.py | 10 +++--- training.py | 99 +++++++++++++++++++++++++-------------------------- 3 files changed, 83 insertions(+), 81 deletions(-) diff --git a/data.py b/data.py index 88364b6..6d64af5 100644 --- a/data.py +++ b/data.py @@ -5,41 +5,42 @@ import torchaudio import os import random import torchaudio.transforms as T +import tqdm import AudioUtils class AudioDataset(Dataset): audio_sample_rates = [11025] - MAX_LENGTH = 44100 # Define your desired maximum length here def __init__(self, input_dir, device): - self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')] self.device = device + input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav') or f.endswith('.mp3') or f.endswith('.flac')] + + data = [] + for audio_clip in tqdm.tqdm(input_files, desc=f"Processing {len(input_files)} audio file(s)"): + audio, original_sample_rate = torchaudio.load(audio_clip, normalize=True) + audio = AudioUtils.stereo_tensor_to_mono(audio) + + # Generate low-quality audio with random downsampling + mangled_sample_rate = random.choice(self.audio_sample_rates) + resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate) + resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate) + + low_audio = resample_transform_low(audio) + low_audio = resample_transform_high(low_audio) + + splitted_high_quality_audio = AudioUtils.split_audio(audio, 128) + splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(splitted_high_quality_audio[-1], 128) + + splitted_low_quality_audio = AudioUtils.split_audio(low_audio, 128) + splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(splitted_low_quality_audio[-1], 128) + + for high_quality_sample, low_quality_sample in zip(splitted_high_quality_audio, splitted_low_quality_audio): + data.append(((high_quality_sample, low_quality_sample), (original_sample_rate, mangled_sample_rate))) + + self.audio_data = data def __len__(self): - return len(self.input_files) + return len(self.audio_data) def __getitem__(self, idx): - # Load high-quality audio - high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True) - # Change to mono - high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio) - - # Generate low-quality audio with random downsampling - mangled_sample_rate = random.choice(self.audio_sample_rates) - - resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate) - resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate) - - 
low_quality_audio = resample_transform_low(high_quality_audio) - low_quality_audio = resample_transform_high(low_quality_audio) - - - splitted_high_quality_audio = AudioUtils.split_audio(high_quality_audio, 128) - splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(splitted_high_quality_audio[-1], 128) - splitted_high_quality_audio = [tensor.to(self.device) for tensor in splitted_high_quality_audio] - - splitted_low_quality_audio = AudioUtils.split_audio(low_quality_audio, 128) - splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(splitted_low_quality_audio[-1], 128) - splitted_low_quality_audio = [tensor.to(self.device) for tensor in splitted_low_quality_audio] - - return (splitted_high_quality_audio, original_sample_rate), (splitted_low_quality_audio, mangled_sample_rate) + return self.audio_data[idx] diff --git a/file_utils.py b/file_utils.py index a723688..98f70bc 100644 --- a/file_utils.py +++ b/file_utils.py @@ -2,20 +2,22 @@ import json filepath = "my_data.json" -def write_data(filepath, data): +def write_data(filepath, data, debug=False): try: with open(filepath, 'w') as f: json.dump(data, f, indent=4) # Use indent for pretty formatting - print(f"Data written to '{filepath}'") + if debug: + print(f"Data written to '{filepath}'") except Exception as e: print(f"Error writing to file: {e}") -def read_data(filepath): +def read_data(filepath, debug=False): try: with open(filepath, 'r') as f: data = json.load(f) - print(f"Data read from '{filepath}'") + if debug: + print(f"Data read from '{filepath}'") return data except FileNotFoundError: print(f"File not found: {filepath}") diff --git a/training.py b/training.py index f6ab2f4..ab9b35b 100644 --- a/training.py +++ b/training.py @@ -76,7 +76,7 @@ os.makedirs(audio_output_dir, exist_ok=True) # ========= SINGLE ========= -train_data_loader = DataLoader(dataset, batch_size=1, shuffle=True) +train_data_loader = DataLoader(dataset, batch_size=256, shuffle=True) # ========= MODELS ========= @@ -122,70 +122,69 @@ def start_training(): times_correct = 0 # ========= TRAINING ========= - for high_quality_data, low_quality_data in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"): + for training_data in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"): ## Data structure: - # [[float..., float..., float...], sample_rate] + # [[[float..., float..., float...], [float..., float..., float...]], [original_sample_rate, mangled_sample_rate]] # ========= LABELS ========= + good_quality_data = training_data[0][0].to(device) + bad_quality_data = training_data[0][1].to(device) + original_sample_rate = training_data[1][0] + mangled_sample_rate = training_data[1][1] - batch_size = high_quality_data[0][0].size(0) + batch_size = good_quality_data.size(0) real_labels = torch.ones(batch_size, 1).to(device) fake_labels = torch.zeros(batch_size, 1).to(device) - high_quality_audio = high_quality_data - low_quality_audio = low_quality_data + high_quality_audio = (good_quality_data, original_sample_rate) + low_quality_audio = (bad_quality_data, mangled_sample_rate) - ai_enhanced_outputs = [] + # ========= DISCRIMINATOR ========= + discriminator.train() + d_loss = discriminator_train( + good_quality_data, + bad_quality_data, + real_labels, + fake_labels, + discriminator, + generator, + criterion_d, + optimizer_d + ) - for high_quality_sample, low_quality_sample in tqdm.tqdm(zip(high_quality_data[0], low_quality_data[0]), 
desc=f"Processing audio clip.. Length: {len(high_quality_data[0])}"): - # ========= DISCRIMINATOR ========= - discriminator.train() - d_loss = discriminator_train( - high_quality_sample, - low_quality_sample, - real_labels, - fake_labels, - discriminator, - generator, - criterion_d, - optimizer_d - ) + # ========= GENERATOR ========= + generator.train() + generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train( + bad_quality_data, + good_quality_data, + real_labels, + generator, + discriminator, + criterion_d, + optimizer_g, + device, + mel_transform, + stft_transform, + mfcc_transform + ) - # ========= GENERATOR ========= - generator.train() - generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train( - low_quality_sample, - high_quality_sample, - real_labels, - generator, - discriminator, - criterion_d, - optimizer_g, - device, - mel_transform, - stft_transform, - mfcc_transform - ) - - ai_enhanced_outputs.append(generator_output) - - if debug: - print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}") - scheduler_d.step(d_loss.detach()) - scheduler_g.step(adversarial_loss.detach()) + if debug: + print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}") + scheduler_d.step(d_loss.detach()) + scheduler_g.step(adversarial_loss.detach()) # ========= SAVE LATEST AUDIO ========= - high_quality_audio = (torch.cat(high_quality_data[0]), high_quality_data[1]) - low_quality_audio = (torch.cat(low_quality_data[0]), low_quality_data[1]) - ai_enhanced_audio = (torch.cat(ai_enhanced_outputs), high_quality_data[1]) + high_quality_audio = (good_quality_data, original_sample_rate) + low_quality_audio = (bad_quality_data, original_sample_rate) + ai_enhanced_audio = (generator_output, original_sample_rate) new_epoch = generator_epoch+epoch - if generator_epoch % 25 == 0: - print(f"Saved epoch {new_epoch}!") - torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu().detach(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again. - torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1]) - torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1]) + # if generator_epoch % 25 == 0: + # print(f"Saved epoch {new_epoch}!") + # torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0][-1].cpu().detach(), high_quality_audio[1][-1]) + # torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0][-1].cpu().detach(), high_quality_audio[1][-1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again. 
+ # torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0][-1].cpu().detach(), high_quality_audio[1][-1]) #if debug: # print(generator.state_dict().keys()) From a135c765da2b99685a6cf42e64cdf34d3121ce17 Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Mon, 5 May 2025 00:50:56 +0300 Subject: [PATCH 03/11] :bug: | Misc fixes... --- data.py | 10 +++++----- training.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/data.py b/data.py index 6d64af5..59986f1 100644 --- a/data.py +++ b/data.py @@ -11,7 +11,7 @@ import AudioUtils class AudioDataset(Dataset): audio_sample_rates = [11025] - def __init__(self, input_dir, device): + def __init__(self, input_dir, device, clip_length = 256): self.device = device input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav') or f.endswith('.mp3') or f.endswith('.flac')] @@ -28,11 +28,11 @@ class AudioDataset(Dataset): low_audio = resample_transform_low(audio) low_audio = resample_transform_high(low_audio) - splitted_high_quality_audio = AudioUtils.split_audio(audio, 128) - splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(splitted_high_quality_audio[-1], 128) + splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length) + splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(splitted_high_quality_audio[-1], clip_length) - splitted_low_quality_audio = AudioUtils.split_audio(low_audio, 128) - splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(splitted_low_quality_audio[-1], 128) + splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length) + splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(splitted_low_quality_audio[-1], clip_length) for high_quality_sample, low_quality_sample in zip(splitted_high_quality_audio, splitted_low_quality_audio): data.append(((high_quality_sample, low_quality_sample), (original_sample_rate, mangled_sample_rate))) diff --git a/training.py b/training.py index ab9b35b..5ccabc7 100644 --- a/training.py +++ b/training.py @@ -76,7 +76,7 @@ os.makedirs(audio_output_dir, exist_ok=True) # ========= SINGLE ========= -train_data_loader = DataLoader(dataset, batch_size=256, shuffle=True) +train_data_loader = DataLoader(dataset, batch_size=1024, shuffle=True) # ========= MODELS ========= From 2ded03713d38e4f0e45e11cb945c5ffda7e90a64 Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Fri, 6 Jun 2025 22:10:06 +0300 Subject: [PATCH 04/11] :sparkles: | Added app.py script so the model can be used. --- AudioUtils.py | 19 ++++++++++++++++ app.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++ training.py | 18 ++++++++-------- 3 files changed, 88 insertions(+), 9 deletions(-) create mode 100644 app.py diff --git a/AudioUtils.py b/AudioUtils.py index f4866dd..f45efb5 100644 --- a/AudioUtils.py +++ b/AudioUtils.py @@ -50,3 +50,22 @@ def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 128) -> list[torch chunks = list(torch.split(audio_tensor, chunk_size, dim=split_dim)) return chunks + +def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor: + if not chunks: + return torch.empty(0) + + if len(chunks) == 1 and chunks[0].dim() == 0: + return chunks[0] + + concat_dim = -1 + + try: + reconstructed_tensor = torch.cat(chunks, dim=concat_dim) + except RuntimeError as e: + raise RuntimeError( + f"Failed to concatenate audio chunks. Ensure chunks have compatible shapes " + f"for concatenation along dimension {concat_dim}. 
Original error: {e}" + ) + + return reconstructed_tensor diff --git a/app.py b/app.py new file mode 100644 index 0000000..a24486b --- /dev/null +++ b/app.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +import torch.optim as optim + +import torch.nn.functional as F +import torchaudio +import tqdm + +import argparse +import math +import os + +import AudioUtils +from generator import SISUGenerator + + +# Init script argument parser +parser = argparse.ArgumentParser(description="Training script") +parser.add_argument("--device", type=str, default="cpu", help="Select device") +parser.add_argument("--model", type=str, help="Model to use for upscaling") +parser.add_argument("--clip_length", type=int, default=256, help="Internal clip length, leave unspecified if unsure") +parser.add_argument("-i", "--input", type=str, help="Input audio file") +parser.add_argument("-o", "--output", type=str, help="Output audio file") + +args = parser.parse_args() + +device = torch.device(args.device if torch.cuda.is_available() else "cpu") +print(f"Using device: {device}") + +generator = SISUGenerator() + +models_dir = args.model +clip_length = args.clip_length +input_audio = args.input +output_audio = args.output + +if models_dir: + generator.load_state_dict(torch.load(f"{models_dir}", map_location=device, weights_only=True)) +else: + print(f"Generator model (--model) isn't specified. Do you have the trained model? If not you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)") + +generator = generator.to(device) + +def start(): + # To Mono! + audio, original_sample_rate = torchaudio.load(input_audio, normalize=True) + audio = AudioUtils.stereo_tensor_to_mono(audio) + + splitted_audio = AudioUtils.split_audio(audio, clip_length) + splitted_audio_on_device = [t.to(device) for t in splitted_audio] + processed_audio = [] + + for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."): + processed_audio.append(generator(clip)) + + reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio) + print(f"Saving {output_audio}!") + torchaudio.save(output_audio, reconstructed_audio.cpu().detach(), original_sample_rate) + +start() diff --git a/training.py b/training.py index 5ccabc7..52275d5 100644 --- a/training.py +++ b/training.py @@ -69,14 +69,14 @@ debug = args.debug # Initialize dataset and dataloader dataset_dir = './dataset/good' dataset = AudioDataset(dataset_dir, device) -models_dir = "models" +models_dir = "./models" os.makedirs(models_dir, exist_ok=True) -audio_output_dir = "output" +audio_output_dir = "./output" os.makedirs(audio_output_dir, exist_ok=True) # ========= SINGLE ========= -train_data_loader = DataLoader(dataset, batch_size=1024, shuffle=True) +train_data_loader = DataLoader(dataset, batch_size=8192, shuffle=True, num_workers=24) # ========= MODELS ========= @@ -85,17 +85,17 @@ generator = SISUGenerator() discriminator = SISUDiscriminator() epoch: int = args.epoch -epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json") if args.continue_training: - generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) - discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) - epoch = epoch_from_file["epoch"] + 1 -else: if args.generator is not None: generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True)) - if args.discriminator is not None: + elif args.discriminator is not None: 
discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True)) + else: + generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) + discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) + epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json") + epoch = epoch_from_file["epoch"] + 1 generator = generator.to(device) discriminator = discriminator.to(device) From 03fdc050cc4c13d2dcec8d7cfea3ca27ae7ae557 Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Sat, 7 Jun 2025 20:43:52 +0300 Subject: [PATCH 05/11] :zap: | Made training bit faster. --- app.py | 2 +- data.py | 2 +- training.py | 47 +++++++++++++++++++++++++---------------------- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/app.py b/app.py index a24486b..ed51803 100644 --- a/app.py +++ b/app.py @@ -18,7 +18,7 @@ from generator import SISUGenerator parser = argparse.ArgumentParser(description="Training script") parser.add_argument("--device", type=str, default="cpu", help="Select device") parser.add_argument("--model", type=str, help="Model to use for upscaling") -parser.add_argument("--clip_length", type=int, default=256, help="Internal clip length, leave unspecified if unsure") +parser.add_argument("--clip_length", type=int, default=1024, help="Internal clip length, leave unspecified if unsure") parser.add_argument("-i", "--input", type=str, help="Input audio file") parser.add_argument("-o", "--output", type=str, help="Output audio file") diff --git a/data.py b/data.py index 59986f1..c3e1047 100644 --- a/data.py +++ b/data.py @@ -11,7 +11,7 @@ import AudioUtils class AudioDataset(Dataset): audio_sample_rates = [11025] - def __init__(self, input_dir, device, clip_length = 256): + def __init__(self, input_dir, device, clip_length = 1024): self.device = device input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav') or f.endswith('.mp3') or f.endswith('.flac')] diff --git a/training.py b/training.py index 52275d5..1be713c 100644 --- a/training.py +++ b/training.py @@ -43,27 +43,38 @@ print(f"Using device: {device}") # Parameters sample_rate = 44100 -n_fft = 128 -hop_length = 128 +n_fft = 1024 win_length = n_fft +hop_length = n_fft // 4 n_mels = 40 -n_mfcc = 13 # If using MFCC +n_mfcc = 13 mfcc_transform = T.MFCC( - sample_rate, - n_mfcc, - melkwargs = {'n_fft': n_fft, 'hop_length': hop_length} + sample_rate=sample_rate, + n_mfcc=n_mfcc, + melkwargs={ + 'n_fft': n_fft, + 'hop_length': hop_length, + 'win_length': win_length, + 'n_mels': n_mels, + 'power': 1.0, + } ).to(device) mel_transform = T.MelSpectrogram( - sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, - win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel + sample_rate=sample_rate, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + n_mels=n_mels, + power=1.0 # Magnitude Mel ).to(device) stft_transform = T.Spectrogram( - n_fft=n_fft, win_length=win_length, hop_length=hop_length + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length ).to(device) - debug = args.debug # Initialize dataset and dataloader @@ -76,7 +87,7 @@ os.makedirs(audio_output_dir, exist_ok=True) # ========= SINGLE ========= -train_data_loader = DataLoader(dataset, batch_size=8192, shuffle=True, num_workers=24) +train_data_loader = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=24) # ========= MODELS ========= @@ -94,6 
+105,7 @@ if args.continue_training: else: generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) + epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json") epoch = epoch_from_file["epoch"] + 1 @@ -178,19 +190,10 @@ def start_training(): low_quality_audio = (bad_quality_data, original_sample_rate) ai_enhanced_audio = (generator_output, original_sample_rate) - new_epoch = generator_epoch+epoch - - # if generator_epoch % 25 == 0: - # print(f"Saved epoch {new_epoch}!") - # torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0][-1].cpu().detach(), high_quality_audio[1][-1]) - # torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0][-1].cpu().detach(), high_quality_audio[1][-1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again. - # torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0][-1].cpu().detach(), high_quality_audio[1][-1]) - - #if debug: - # print(generator.state_dict().keys()) - # print(discriminator.state_dict().keys()) torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt") torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt") + + new_epoch = generator_epoch+epoch Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch}) From ff38cefdd372e53eb68d14b36bca59534c94c4d2 Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Sun, 8 Jun 2025 18:14:31 +0300 Subject: [PATCH 06/11] :bug: | Fix loading wrong model. --- training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/training.py b/training.py index 1be713c..ab2b1e5 100644 --- a/training.py +++ b/training.py @@ -87,7 +87,7 @@ os.makedirs(audio_output_dir, exist_ok=True) # ========= SINGLE ========= -train_data_loader = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=24) +train_data_loader = DataLoader(dataset, batch_size=2048, shuffle=True, num_workers=24) # ========= MODELS ========= @@ -104,7 +104,7 @@ if args.continue_training: discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True)) else: generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) - discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) + discriminator.load_state_dict(torch.load(f"{models_dir}/temp_discriminator.pt", map_location=device, weights_only=True)) epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json") epoch = epoch_from_file["epoch"] + 1 From 0bc8fc279210c6791735a58d0e445fcf1bb54bbf Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Wed, 10 Sep 2025 19:52:53 +0300 Subject: [PATCH 07/11] :sparkles: | Made training bit... spicier. 
--- app.py | 74 ++++-- data.py | 59 +++-- discriminator.py | 92 +++++++- generator.py | 12 +- training.py | 387 ++++++++++++++++++------------- training_utils.py | 198 ++++++++-------- utils/MultiResolutionSTFTLoss.py | 62 +++++ utils/__init__.py | 0 8 files changed, 581 insertions(+), 303 deletions(-) create mode 100644 utils/MultiResolutionSTFTLoss.py create mode 100644 utils/__init__.py diff --git a/app.py b/app.py index ed51803..8c669a9 100644 --- a/app.py +++ b/app.py @@ -1,33 +1,49 @@ -import torch -import torch.nn as nn -import torch.optim as optim - -import torch.nn.functional as F -import torchaudio -import tqdm - import argparse -import math -import os + +import torch +import torchaudio +import torchcodec +import tqdm import AudioUtils from generator import SISUGenerator - # Init script argument parser parser = argparse.ArgumentParser(description="Training script") -parser.add_argument("--device", type=str, default="cpu", help="Select device") +parser.add_argument("--device", type=str, default="cpu", help="Select device") parser.add_argument("--model", type=str, help="Model to use for upscaling") -parser.add_argument("--clip_length", type=int, default=1024, help="Internal clip length, leave unspecified if unsure") +parser.add_argument( + "--clip_length", + type=int, + default=16384, + help="Internal clip length, leave unspecified if unsure", +) +parser.add_argument( + "--sample_rate", type=int, default=44100, help="Output clip sample rate" +) +parser.add_argument( + "--bitrate", + type=int, + default=192000, + help="Output clip bitrate", +) parser.add_argument("-i", "--input", type=str, help="Input audio file") parser.add_argument("-o", "--output", type=str, help="Output audio file") args = parser.parse_args() +if args.sample_rate < 8000: + print( + "Sample rate cannot be lower than 8000! (44100 is recommended for base models)" + ) + exit() + device = torch.device(args.device if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") -generator = SISUGenerator() +generator = SISUGenerator().to(device) + +generator = torch.compile(generator) models_dir = args.model clip_length = args.clip_length @@ -35,17 +51,30 @@ input_audio = args.input output_audio = args.output if models_dir: - generator.load_state_dict(torch.load(f"{models_dir}", map_location=device, weights_only=True)) + ckpt = torch.load(models_dir, map_location=device) + generator.load_state_dict(ckpt["G"]) else: - print(f"Generator model (--model) isn't specified. Do you have the trained model? If not you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)") + print( + "Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)" + ) -generator = generator.to(device) def start(): # To Mono! 
- audio, original_sample_rate = torchaudio.load(input_audio, normalize=True) + decoder = torchcodec.decoders.AudioDecoder(input_audio) + + decoded_samples = decoder.get_all_samples() + audio = decoded_samples.data + original_sample_rate = decoded_samples.sample_rate + audio = AudioUtils.stereo_tensor_to_mono(audio) + resample_transform = torchaudio.transforms.Resample( + original_sample_rate, args.sample_rate + ) + + audio = resample_transform(audio) + splitted_audio = AudioUtils.split_audio(audio, clip_length) splitted_audio_on_device = [t.to(device) for t in splitted_audio] processed_audio = [] @@ -55,6 +84,13 @@ def start(): reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio) print(f"Saving {output_audio}!") - torchaudio.save(output_audio, reconstructed_audio.cpu().detach(), original_sample_rate) + torchaudio.save_with_torchcodec( + uri=output_audio, + src=reconstructed_audio, + sample_rate=args.sample_rate, + channels_first=True, + compression=args.bitrate, + ) + start() diff --git a/data.py b/data.py index c3e1047..c9f4e76 100644 --- a/data.py +++ b/data.py @@ -1,41 +1,68 @@ -from torch.utils.data import Dataset -import torch.nn.functional as F -import torch -import torchaudio import os import random -import torchaudio.transforms as T + +import torchaudio +import torchcodec.decoders as decoders import tqdm +from torch.utils.data import Dataset + import AudioUtils + class AudioDataset(Dataset): audio_sample_rates = [11025] - def __init__(self, input_dir, device, clip_length = 1024): - self.device = device - input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav') or f.endswith('.mp3') or f.endswith('.flac')] + def __init__(self, input_dir, clip_length=16384): + input_files = [ + os.path.join(root, f) + for root, _, files in os.walk(input_dir) + for f in files + if f.endswith(".wav") or f.endswith(".mp3") or f.endswith(".flac") + ] data = [] - for audio_clip in tqdm.tqdm(input_files, desc=f"Processing {len(input_files)} audio file(s)"): - audio, original_sample_rate = torchaudio.load(audio_clip, normalize=True) + for audio_clip in tqdm.tqdm( + input_files, desc=f"Processing {len(input_files)} audio file(s)" + ): + decoder = decoders.AudioDecoder(audio_clip) + + decoded_samples = decoder.get_all_samples() + audio = decoded_samples.data + original_sample_rate = decoded_samples.sample_rate + audio = AudioUtils.stereo_tensor_to_mono(audio) # Generate low-quality audio with random downsampling mangled_sample_rate = random.choice(self.audio_sample_rates) - resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate) - resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate) + resample_transform_low = torchaudio.transforms.Resample( + original_sample_rate, mangled_sample_rate + ) + resample_transform_high = torchaudio.transforms.Resample( + mangled_sample_rate, original_sample_rate + ) low_audio = resample_transform_low(audio) low_audio = resample_transform_high(low_audio) splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length) - splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(splitted_high_quality_audio[-1], clip_length) + splitted_high_quality_audio[-1] = AudioUtils.pad_tensor( + splitted_high_quality_audio[-1], clip_length + ) splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length) - splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(splitted_low_quality_audio[-1], clip_length) + 
splitted_low_quality_audio[-1] = AudioUtils.pad_tensor( + splitted_low_quality_audio[-1], clip_length + ) - for high_quality_sample, low_quality_sample in zip(splitted_high_quality_audio, splitted_low_quality_audio): - data.append(((high_quality_sample, low_quality_sample), (original_sample_rate, mangled_sample_rate))) + for high_quality_sample, low_quality_sample in zip( + splitted_high_quality_audio, splitted_low_quality_audio + ): + data.append( + ( + (high_quality_sample, low_quality_sample), + (original_sample_rate, mangled_sample_rate), + ) + ) self.audio_data = data diff --git a/discriminator.py b/discriminator.py index dfd0126..ce2e84c 100644 --- a/discriminator.py +++ b/discriminator.py @@ -1,8 +1,16 @@ -import torch import torch.nn as nn import torch.nn.utils as utils -def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, spectral_norm=True, use_instance_norm=True): + +def discriminator_block( + in_channels, + out_channels, + kernel_size=3, + stride=1, + dilation=1, + spectral_norm=True, + use_instance_norm=True, +): padding = (kernel_size // 2) * dilation conv_layer = nn.Conv1d( in_channels, @@ -10,7 +18,7 @@ def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dila kernel_size=kernel_size, stride=stride, dilation=dilation, - padding=padding + padding=padding, ) if spectral_norm: @@ -24,6 +32,7 @@ def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dila return nn.Sequential(*layers) + class AttentionBlock(nn.Module): def __init__(self, channels): super(AttentionBlock, self).__init__() @@ -31,27 +40,86 @@ class AttentionBlock(nn.Module): nn.Conv1d(channels, channels // 4, kernel_size=1), nn.ReLU(inplace=True), nn.Conv1d(channels // 4, channels, kernel_size=1), - nn.Sigmoid() + nn.Sigmoid(), ) def forward(self, x): attention_weights = self.attention(x) return x * attention_weights + class SISUDiscriminator(nn.Module): def __init__(self, base_channels=16): super(SISUDiscriminator, self).__init__() layers = base_channels self.model = nn.Sequential( - discriminator_block(1, layers, kernel_size=7, stride=1, spectral_norm=True, use_instance_norm=False), - discriminator_block(layers, layers * 2, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers * 2, layers * 4, kernel_size=5, stride=1, dilation=2, spectral_norm=True, use_instance_norm=True), + discriminator_block( + 1, + layers, + kernel_size=7, + stride=1, + spectral_norm=True, + use_instance_norm=False, + ), + discriminator_block( + layers, + layers * 2, + kernel_size=5, + stride=2, + spectral_norm=True, + use_instance_norm=True, + ), + discriminator_block( + layers * 2, + layers * 4, + kernel_size=5, + stride=1, + dilation=2, + spectral_norm=True, + use_instance_norm=True, + ), AttentionBlock(layers * 4), - discriminator_block(layers * 4, layers * 8, kernel_size=5, stride=1, dilation=4, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers * 8, layers * 4, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers * 4, layers * 2, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers * 2, layers, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True), - discriminator_block(layers, 1, kernel_size=3, stride=1, spectral_norm=False, use_instance_norm=False) + discriminator_block( + layers * 4, + layers * 8, + kernel_size=5, + stride=1, + dilation=4, + spectral_norm=True, + 
use_instance_norm=True, + ), + discriminator_block( + layers * 8, + layers * 4, + kernel_size=5, + stride=2, + spectral_norm=True, + use_instance_norm=True, + ), + discriminator_block( + layers * 4, + layers * 2, + kernel_size=3, + stride=1, + spectral_norm=True, + use_instance_norm=True, + ), + discriminator_block( + layers * 2, + layers, + kernel_size=3, + stride=1, + spectral_norm=True, + use_instance_norm=True, + ), + discriminator_block( + layers, + 1, + kernel_size=3, + stride=1, + spectral_norm=False, + use_instance_norm=False, + ), ) self.global_avg_pool = nn.AdaptiveAvgPool1d(1) diff --git a/generator.py b/generator.py index a53feb7..0240860 100644 --- a/generator.py +++ b/generator.py @@ -1,6 +1,6 @@ -import torch import torch.nn as nn + def conv_block(in_channels, out_channels, kernel_size=3, dilation=1): return nn.Sequential( nn.Conv1d( @@ -8,29 +8,32 @@ def conv_block(in_channels, out_channels, kernel_size=3, dilation=1): out_channels, kernel_size=kernel_size, dilation=dilation, - padding=(kernel_size // 2) * dilation + padding=(kernel_size // 2) * dilation, ), nn.InstanceNorm1d(out_channels), - nn.PReLU() + nn.PReLU(), ) + class AttentionBlock(nn.Module): """ Simple Channel Attention Block. Learns to weight channels based on their importance. """ + def __init__(self, channels): super(AttentionBlock, self).__init__() self.attention = nn.Sequential( nn.Conv1d(channels, channels // 4, kernel_size=1), nn.ReLU(inplace=True), nn.Conv1d(channels // 4, channels, kernel_size=1), - nn.Sigmoid() + nn.Sigmoid(), ) def forward(self, x): attention_weights = self.attention(x) return x * attention_weights + class ResidualInResidualBlock(nn.Module): def __init__(self, channels, num_convs=3): super(ResidualInResidualBlock, self).__init__() @@ -47,6 +50,7 @@ class ResidualInResidualBlock(nn.Module): x = self.attention(x) return x + residual + class SISUGenerator(nn.Module): def __init__(self, channels=16, num_rirb=4, alpha=1.0): super(SISUGenerator, self).__init__() diff --git a/training.py b/training.py index ab2b1e5..8876740 100644 --- a/training.py +++ b/training.py @@ -1,65 +1,74 @@ +import argparse +import os + import torch import torch.nn as nn import torch.optim as optim - -import torch.nn.functional as F -import torchaudio +import torchaudio.transforms as T import tqdm - -import argparse - -import math - -import os - -from torch.utils.data import random_split +from torch.amp import GradScaler, autocast from torch.utils.data import DataLoader -import AudioUtils +import training_utils from data import AudioDataset -from generator import SISUGenerator from discriminator import SISUDiscriminator - +from generator import SISUGenerator from training_utils import discriminator_train, generator_train -import file_utils as Data - -import torchaudio.transforms as T - -# Init script argument parser -parser = argparse.ArgumentParser(description="Training script") -parser.add_argument("--generator", type=str, default=None, - help="Path to the generator model file") -parser.add_argument("--discriminator", type=str, default=None, - help="Path to the discriminator model file") -parser.add_argument("--device", type=str, default="cpu", help="Select device") -parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning") -parser.add_argument("--debug", action="store_true", help="Print debug logs") -parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models") +# --------------------------- +# 
Argument parsing +# --------------------------- +parser = argparse.ArgumentParser(description="Training script (safer defaults)") +parser.add_argument("--resume", action="store_true", help="Resume training") +parser.add_argument( + "--device", type=str, default="cuda", help="Device (cuda, cpu, mps)" +) +parser.add_argument( + "--epochs", type=int, default=5000, help="Number of training epochs" +) +parser.add_argument("--batch_size", type=int, default=8, help="Batch size") +parser.add_argument("--num_workers", type=int, default=2, help="DataLoader num_workers") +parser.add_argument("--debug", action="store_true", help="Print debug logs") +parser.add_argument( + "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA" +) args = parser.parse_args() -device = torch.device(args.device if torch.cuda.is_available() else "cpu") +# --------------------------- +# Device setup +# --------------------------- +# Use requested device only if available +device = torch.device( + args.device if (args.device != "cuda" or torch.cuda.is_available()) else "cpu" +) print(f"Using device: {device}") +# sensible performance flags +if device.type == "cuda": + torch.backends.cudnn.benchmark = True + # optional: torch.set_float32_matmul_precision("high") +debug = args.debug -# Parameters +# --------------------------- +# Audio transforms +# --------------------------- sample_rate = 44100 n_fft = 1024 win_length = n_fft hop_length = n_fft // 4 -n_mels = 40 -n_mfcc = 13 +n_mels = 96 +# n_mfcc = 13 -mfcc_transform = T.MFCC( - sample_rate=sample_rate, - n_mfcc=n_mfcc, - melkwargs={ - 'n_fft': n_fft, - 'hop_length': hop_length, - 'win_length': win_length, - 'n_mels': n_mels, - 'power': 1.0, - } -).to(device) +# mfcc_transform = T.MFCC( +# sample_rate=sample_rate, +# n_mfcc=n_mfcc, +# melkwargs=dict( +# n_fft=n_fft, +# hop_length=hop_length, +# win_length=win_length, +# n_mels=n_mels, +# power=1.0, +# ), +# ).to(device) mel_transform = T.MelSpectrogram( sample_rate=sample_rate, @@ -67,138 +76,198 @@ mel_transform = T.MelSpectrogram( hop_length=hop_length, win_length=win_length, n_mels=n_mels, - power=1.0 # Magnitude Mel + power=1.0, ).to(device) stft_transform = T.Spectrogram( - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length + n_fft=n_fft, win_length=win_length, hop_length=hop_length ).to(device) -debug = args.debug -# Initialize dataset and dataloader -dataset_dir = './dataset/good' -dataset = AudioDataset(dataset_dir, device) -models_dir = "./models" -os.makedirs(models_dir, exist_ok=True) -audio_output_dir = "./output" -os.makedirs(audio_output_dir, exist_ok=True) +# training_utils.init(mel_transform, stft_transform, mfcc_transform) +training_utils.init(mel_transform, stft_transform) -# ========= SINGLE ========= +# --------------------------- +# Dataset / DataLoader +# --------------------------- +dataset_dir = "./dataset/good" +dataset = AudioDataset(dataset_dir) -train_data_loader = DataLoader(dataset, batch_size=2048, shuffle=True, num_workers=24) +train_loader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.num_workers, + pin_memory=True, + persistent_workers=True, +) +# --------------------------- +# Models +# --------------------------- +generator = SISUGenerator().to(device) +discriminator = SISUDiscriminator().to(device) -# ========= MODELS ========= +generator = torch.compile(generator) +discriminator = torch.compile(discriminator) -generator = SISUGenerator() -discriminator = SISUDiscriminator() - -epoch: int = args.epoch - -if 
args.continue_training: - if args.generator is not None: - generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True)) - elif args.discriminator is not None: - discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True)) - else: - generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True)) - discriminator.load_state_dict(torch.load(f"{models_dir}/temp_discriminator.pt", map_location=device, weights_only=True)) - - epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json") - epoch = epoch_from_file["epoch"] + 1 - -generator = generator.to(device) -discriminator = discriminator.to(device) - -# Loss +# --------------------------- +# Losses / Optimizers / Scalers +# --------------------------- criterion_g = nn.BCEWithLogitsLoss() criterion_d = nn.BCEWithLogitsLoss() -# Optimizers -optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999)) -optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999)) +optimizer_g = optim.AdamW( + generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001 +) +optimizer_d = optim.AdamW( + discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001 +) -# Scheduler -scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5) -scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5) +# Use modern GradScaler signature; choose device_type based on runtime device. +scaler = GradScaler(device=device) -def start_training(): - generator_epochs = 5000 - for generator_epoch in range(generator_epochs): - high_quality_audio = ([torch.empty((1))], 1) - low_quality_audio = ([torch.empty((1))], 1) - ai_enhanced_audio = ([torch.empty((1))], 1) +scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer_g, mode="min", factor=0.5, patience=5 +) +scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer_d, mode="min", factor=0.5, patience=5 +) - times_correct = 0 - - # ========= TRAINING ========= - for training_data in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"): - ## Data structure: - # [[[float..., float..., float...], [float..., float..., float...]], [original_sample_rate, mangled_sample_rate]] - - # ========= LABELS ========= - good_quality_data = training_data[0][0].to(device) - bad_quality_data = training_data[0][1].to(device) - original_sample_rate = training_data[1][0] - mangled_sample_rate = training_data[1][1] - - batch_size = good_quality_data.size(0) - real_labels = torch.ones(batch_size, 1).to(device) - fake_labels = torch.zeros(batch_size, 1).to(device) - - high_quality_audio = (good_quality_data, original_sample_rate) - low_quality_audio = (bad_quality_data, mangled_sample_rate) - - # ========= DISCRIMINATOR ========= - discriminator.train() - d_loss = discriminator_train( - good_quality_data, - bad_quality_data, - real_labels, - fake_labels, - discriminator, - generator, - criterion_d, - optimizer_d - ) - - # ========= GENERATOR ========= - generator.train() - generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train( - bad_quality_data, - good_quality_data, - real_labels, - generator, - discriminator, - criterion_d, - optimizer_g, - device, - mel_transform, - stft_transform, - mfcc_transform - ) - - if debug: 
- print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}") - scheduler_d.step(d_loss.detach()) - scheduler_g.step(adversarial_loss.detach()) - - # ========= SAVE LATEST AUDIO ========= - high_quality_audio = (good_quality_data, original_sample_rate) - low_quality_audio = (bad_quality_data, original_sample_rate) - ai_enhanced_audio = (generator_output, original_sample_rate) - - torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt") - torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt") - - new_epoch = generator_epoch+epoch - Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch}) +# --------------------------- +# Checkpoint helpers +# --------------------------- +models_dir = "./models" +os.makedirs(models_dir, exist_ok=True) - torch.save(discriminator, "models/epoch-5000-discriminator.pt") - torch.save(generator, "models/epoch-5000-generator.pt") - print("Training complete!") +def save_ckpt(path, epoch): + torch.save( + { + "epoch": epoch, + "G": generator.state_dict(), + "D": discriminator.state_dict(), + "optG": optimizer_g.state_dict(), + "optD": optimizer_d.state_dict(), + "scaler": scaler.state_dict(), + "schedG": scheduler_g.state_dict(), + "schedD": scheduler_d.state_dict(), + }, + path, + ) -start_training() + +start_epoch = 0 +if args.resume: + ckpt = torch.load(os.path.join(models_dir, "last.pt"), map_location=device) + generator.load_state_dict(ckpt["G"]) + discriminator.load_state_dict(ckpt["D"]) + optimizer_g.load_state_dict(ckpt["optG"]) + optimizer_d.load_state_dict(ckpt["optD"]) + scaler.load_state_dict(ckpt["scaler"]) + scheduler_g.load_state_dict(ckpt["schedG"]) + scheduler_d.load_state_dict(ckpt["schedD"]) + start_epoch = ckpt.get("epoch", 1) + +# --------------------------- +# Training loop (safer) +# --------------------------- + +if not train_loader or not train_loader.batch_size: + print("There is no data to train with! 
Exiting...") + exit() + +max_batch = max(1, train_loader.batch_size) +real_buf = torch.full((max_batch, 1), 0.9, device=device) # label smoothing +fake_buf = torch.zeros(max_batch, 1, device=device) + +try: + for epoch in range(start_epoch, args.epochs): + generator.train() + discriminator.train() + + running_d, running_g, steps = 0.0, 0.0, 0 + + for i, ( + (high_quality, low_quality), + (high_sample_rate, low_sample_rate), + ) in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")): + batch_size = high_quality.size(0) + + high_quality = high_quality.to(device, non_blocking=True) + low_quality = low_quality.to(device, non_blocking=True) + + real_labels = real_buf[:batch_size] + fake_labels = fake_buf[:batch_size] + + # --- Discriminator --- + optimizer_d.zero_grad(set_to_none=True) + with autocast(device_type=device.type): + d_loss = discriminator_train( + high_quality, + low_quality, + real_labels, + fake_labels, + discriminator, + generator, + criterion_d, + ) + + scaler.scale(d_loss).backward() + scaler.unscale_(optimizer_d) + torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0) + scaler.step(optimizer_d) + + # --- Generator --- + optimizer_g.zero_grad(set_to_none=True) + with autocast(device_type=device.type): + g_out, g_total, g_adv = generator_train( + low_quality, + high_quality, + real_labels, + generator, + discriminator, + criterion_d, + ) + + scaler.scale(g_total).backward() + scaler.unscale_(optimizer_g) + torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0) + scaler.step(optimizer_g) + + scaler.update() + + running_d += float(d_loss.detach().cpu().item()) + running_g += float(g_total.detach().cpu().item()) + steps += 1 + + # epoch averages & schedulers + if steps == 0: + print("No steps in epoch (empty dataloader?). 
Exiting.") + break + + mean_d = running_d / steps + mean_g = running_g / steps + + scheduler_d.step(mean_d) + scheduler_g.step(mean_g) + + save_ckpt(os.path.join(models_dir, "last.pt"), epoch) + print(f"Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}") + +except Exception: + try: + save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch) + print(f"Saved crash checkpoint for epoch {epoch}") + except Exception as e: + print("Failed saving crash checkpoint:", e) + raise + +try: + torch.save(generator.state_dict(), os.path.join(models_dir, "final_generator.pt")) + torch.save( + discriminator.state_dict(), os.path.join(models_dir, "final_discriminator.pt") + ) +except Exception as e: + print("Failed to save final states:", e) + +print("Training finished.") diff --git a/training_utils.py b/training_utils.py index c7d43e5..02403af 100644 --- a/training_utils.py +++ b/training_utils.py @@ -1,89 +1,88 @@ import torch -import torch.nn as nn -import torch.optim as optim - -import torchaudio import torchaudio.transforms as T -def gpu_mfcc_loss(mfcc_transform, y_true, y_pred): - mfccs_true = mfcc_transform(y_true) - mfccs_pred = mfcc_transform(y_pred) +from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss - min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2]) - mfccs_true = mfccs_true[:, :, :min_len] - mfccs_pred = mfccs_pred[:, :, :min_len] +mel_transform: T.MelSpectrogram +stft_transform: T.Spectrogram +# mfcc_transform: T.MFCC - loss = torch.mean((mfccs_true - mfccs_pred)**2) - return loss -def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: - mel_spec_true = mel_transform(y_true) - mel_spec_pred = mel_transform(y_pred) +# def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram, mfcc_trans: T.MFCC): +# """Initializes the global transform variables for the module.""" +# global mel_transform, stft_transform, mfcc_transform +# mel_transform = mel_trans +# stft_transform = stft_trans +# mfcc_transform = mfcc_trans - min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1]) - mel_spec_true = mel_spec_true[..., :min_len] - mel_spec_pred = mel_spec_pred[..., :min_len] - loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred)) - return loss +def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram): + """Initializes the global transform variables for the module.""" + global mel_transform, stft_transform + mel_transform = mel_trans + stft_transform = stft_trans -def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: - mel_spec_true = mel_transform(y_true) - mel_spec_pred = mel_transform(y_pred) - min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1]) - mel_spec_true = mel_spec_true[..., :min_len] - mel_spec_pred = mel_spec_pred[..., :min_len] +# def mfcc_loss(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: +# """Computes the Mean Squared Error (MSE) loss on MFCCs.""" +# mfccs_true = mfcc_transform(y_true) +# mfccs_pred = mfcc_transform(y_pred) +# return F.mse_loss(mfccs_pred, mfccs_true) - loss = torch.mean((mel_spec_true - mel_spec_pred)**2) - return loss -def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: - stft_mag_true = stft_transform(y_true) - stft_mag_pred = stft_transform(y_pred) +# def mel_spectrogram_loss( +# y_true: torch.Tensor, y_pred: torch.Tensor, loss_type: str = "l1" +# ) -> torch.Tensor: +# 
"""Calculates L1 or L2 loss on the Mel Spectrogram.""" +# mel_spec_true = mel_transform(y_true) +# mel_spec_pred = mel_transform(y_pred) +# if loss_type == "l1": +# return F.l1_loss(mel_spec_pred, mel_spec_true) +# elif loss_type == "l2": +# return F.mse_loss(mel_spec_pred, mel_spec_true) +# else: +# raise ValueError("loss_type must be 'l1' or 'l2'") - min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1]) - stft_mag_true = stft_mag_true[..., :min_len] - stft_mag_pred = stft_mag_pred[..., :min_len] - loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps))) - return loss +# def log_stft_magnitude_loss( +# y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7 +# ) -> torch.Tensor: +# """Calculates L1 loss on the log STFT magnitude.""" +# stft_mag_true = stft_transform(y_true) +# stft_mag_pred = stft_transform(y_pred) +# return F.l1_loss(torch.log(stft_mag_pred + eps), torch.log(stft_mag_true + eps)) -def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: - stft_mag_true = stft_transform(y_true) - stft_mag_pred = stft_transform(y_pred) - min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1]) - stft_mag_true = stft_mag_true[..., :min_len] - stft_mag_pred = stft_mag_pred[..., :min_len] +stft_loss_fn = MultiResolutionSTFTLoss( + fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240] +) - norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1)) - norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1)) - loss = torch.mean(norm_diff / (norm_true + eps)) - return loss - -def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer): - optimizer.zero_grad() - - # Forward pass for real samples +def discriminator_train( + high_quality, + low_quality, + real_labels, + fake_labels, + discriminator, + generator, + criterion, +): discriminator_decision_from_real = discriminator(high_quality) d_loss_real = criterion(discriminator_decision_from_real, real_labels) with torch.no_grad(): generator_output = generator(low_quality) discriminator_decision_from_fake = discriminator(generator_output) - d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake)) + d_loss_fake = criterion( + discriminator_decision_from_fake, + fake_labels.expand_as(discriminator_decision_from_fake), + ) d_loss = (d_loss_real + d_loss_fake) / 2.0 - d_loss.backward() - # Optional: Gradient Clipping (can be helpful) - # nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping - optimizer.step() - return d_loss + def generator_train( low_quality, high_quality, @@ -91,52 +90,65 @@ def generator_train( generator, discriminator, adv_criterion, - g_optimizer, - device, - mel_transform: T.MelSpectrogram, - stft_transform: T.Spectrogram, - mfcc_transform: T.MFCC, lambda_adv: float = 1.0, - lambda_mel_l1: float = 10.0, - lambda_log_stft: float = 1.0, - lambda_mfcc: float = 1.0 + lambda_feat: float = 10.0, + lambda_stft: float = 2.5, ): - g_optimizer.zero_grad() - generator_output = generator(low_quality) discriminator_decision = discriminator(generator_output) - adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision)) + # adversarial_loss = adv_criterion( + # discriminator_decision, real_labels.expand_as(discriminator_decision) + # ) + 
adversarial_loss = adv_criterion(discriminator_decision, real_labels) - mel_l1 = 0.0 - log_stft_l1 = 0.0 - mfcc_l = 0.0 + combined_loss = lambda_adv * adversarial_loss - # Calculate Mel L1 Loss if weight is positive - if lambda_mel_l1 > 0: - mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality, generator_output) + stft_losses = stft_loss_fn(high_quality, generator_output) + stft_loss = stft_losses["total"] - # Calculate Log STFT L1 Loss if weight is positive - if lambda_log_stft > 0: - log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality, generator_output) + combined_loss = (lambda_adv * adversarial_loss) + (lambda_stft * stft_loss) - # Calculate MFCC Loss if weight is positive - if lambda_mfcc > 0: - mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality, generator_output) + return generator_output, combined_loss, adversarial_loss - mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1 - log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1 - mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l - combined_loss = (lambda_adv * adversarial_loss) + \ - (lambda_mel_l1 * mel_l1_tensor) + \ - (lambda_log_stft * log_stft_l1_tensor) + \ - (lambda_mfcc * mfcc_l_tensor) +# def generator_train( +# low_quality, +# high_quality, +# real_labels, +# generator, +# discriminator, +# adv_criterion, +# lambda_adv: float = 1.0, +# lambda_mel_l1: float = 10.0, +# lambda_log_stft: float = 1.0, - combined_loss.backward() - # Optional: Gradient Clipping - # nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0) - g_optimizer.step() +# ): +# generator_output = generator(low_quality) - # 6. Return values for logging - return generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor +# discriminator_decision = discriminator(generator_output) +# adversarial_loss = adv_criterion( +# discriminator_decision, real_labels.expand_as(discriminator_decision) +# ) + +# combined_loss = lambda_adv * adversarial_loss + +# if lambda_mel_l1 > 0: +# mel_l1_loss = mel_spectrogram_loss(high_quality, generator_output, "l1") +# combined_loss += lambda_mel_l1 * mel_l1_loss +# else: +# mel_l1_loss = torch.tensor(0.0, device=low_quality.device) # For logging + +# if lambda_log_stft > 0: +# log_stft_loss = log_stft_magnitude_loss(high_quality, generator_output) +# combined_loss += lambda_log_stft * log_stft_loss +# else: +# log_stft_loss = torch.tensor(0.0, device=low_quality.device) + +# if lambda_mfcc > 0: +# mfcc_loss_val = mfcc_loss(high_quality, generator_output) +# combined_loss += lambda_mfcc * mfcc_loss_val +# else: +# mfcc_loss_val = torch.tensor(0.0, device=low_quality.device) + +# return generator_output, combined_loss, adversarial_loss diff --git a/utils/MultiResolutionSTFTLoss.py b/utils/MultiResolutionSTFTLoss.py new file mode 100644 index 0000000..5712fc3 --- /dev/null +++ b/utils/MultiResolutionSTFTLoss.py @@ -0,0 +1,62 @@ +from typing import Dict, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio.transforms as T + + +class MultiResolutionSTFTLoss(nn.Module): + """ + Computes a loss based on multiple STFT resolutions, including both + spectral convergence and log STFT magnitude components. 
+ """ + + def __init__( + self, + fft_sizes: List[int] = [1024, 2048, 512], + hop_sizes: List[int] = [120, 240, 50], + win_lengths: List[int] = [600, 1200, 240], + eps: float = 1e-7, + ): + super().__init__() + self.stft_transforms = nn.ModuleList( + [ + T.Spectrogram( + n_fft=n_fft, win_length=win_len, hop_length=hop_len, power=None + ) + for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths) + ] + ) + self.eps = eps + + def forward( + self, y_true: torch.Tensor, y_pred: torch.Tensor + ) -> Dict[str, torch.Tensor]: + sc_loss = 0.0 # Spectral Convergence Loss + mag_loss = 0.0 # Log STFT Magnitude Loss + + for stft in self.stft_transforms: + stft.to(y_pred.device) # Ensure transform is on the correct device + + # Get complex STFTs + stft_true = stft(y_true) + stft_pred = stft(y_pred) + + # Get magnitudes + stft_mag_true = torch.abs(stft_true) + stft_mag_pred = torch.abs(stft_pred) + + # --- Spectral Convergence Loss --- + # || |S_true| - |S_pred| ||_F / || |S_true| ||_F + norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1)) + norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1)) + sc_loss += torch.mean(norm_diff / (norm_true + self.eps)) + + # --- Log STFT Magnitude Loss --- + mag_loss += F.l1_loss( + torch.log(stft_mag_pred + self.eps), torch.log(stft_mag_true + self.eps) + ) + + total_loss = sc_loss + mag_loss + return {"total": total_loss, "sc": sc_loss, "mag": mag_loss} diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 From 3f23242d6fbedfd8052a132193ec7c99eb6b8fbb Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Sat, 4 Oct 2025 22:38:11 +0300 Subject: [PATCH 08/11] :alembic: | Added some stupid ways for training + some makeup --- AudioUtils.py | 94 +++++++----- __init__.py | 0 app.py | 1 + data.py | 34 +++-- discriminator.py | 68 +-------- file_utils.py | 30 ---- generator.py | 9 +- requirements.txt | 12 -- training.py | 246 ++++++++++++++----------------- training_utils.py | 154 ------------------- utils/MultiResolutionSTFTLoss.py | 59 +++++--- utils/TrainingTools.py | 60 ++++++++ 12 files changed, 304 insertions(+), 463 deletions(-) create mode 100644 __init__.py delete mode 100644 file_utils.py delete mode 100644 requirements.txt delete mode 100644 training_utils.py create mode 100644 utils/TrainingTools.py diff --git a/AudioUtils.py b/AudioUtils.py index f45efb5..183dc36 100644 --- a/AudioUtils.py +++ b/AudioUtils.py @@ -1,71 +1,97 @@ import torch import torch.nn.functional as F -def stereo_tensor_to_mono(waveform): + +def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor: + """ + Convert stereo (C, N) to mono (1, N). Ensures a channel dimension. + """ + if waveform.dim() == 1: + waveform = waveform.unsqueeze(0) # (N,) -> (1, N) + if waveform.shape[0] > 1: - # Average across channels - mono_waveform = torch.mean(waveform, dim=0, keepdim=True) + mono_waveform = torch.mean(waveform, dim=0, keepdim=True) # (1, N) else: - # Already mono mono_waveform = waveform return mono_waveform -def stretch_tensor(tensor, target_length): - scale_factor = target_length / tensor.size(1) - tensor = F.interpolate(tensor, scale_factor=scale_factor, mode='linear', align_corners=False) +def stretch_tensor(tensor: torch.Tensor, target_length: int) -> torch.Tensor: + """ + Stretch audio along time dimension to target_length. + Input assumed (1, N). Returns (1, target_length). 
+ """ + if tensor.dim() == 1: + tensor = tensor.unsqueeze(0) # ensure (1, N) - return tensor + tensor = tensor.unsqueeze(0) # (1, 1, N) for interpolate + stretched = F.interpolate( + tensor, size=target_length, mode="linear", align_corners=False + ) + return stretched.squeeze(0) # back to (1, target_length) + + +def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128) -> torch.Tensor: + """ + Pad to fixed length. Input assumed (1, N). Returns (1, target_length). + """ + if audio_tensor.dim() == 1: + audio_tensor = audio_tensor.unsqueeze(0) -def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128): current_length = audio_tensor.shape[-1] - if current_length < target_length: padding_needed = target_length - current_length - padding_tuple = (0, padding_needed) - padded_audio_tensor = F.pad(audio_tensor, padding_tuple, mode='constant', value=0) + padded_audio_tensor = F.pad( + audio_tensor, padding_tuple, mode="constant", value=0 + ) else: - padded_audio_tensor = audio_tensor + padded_audio_tensor = audio_tensor[..., :target_length] # crop if too long return padded_audio_tensor -def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 128) -> list[torch.Tensor]: + +def split_audio( + audio_tensor: torch.Tensor, chunk_size: int = 128 +) -> list[torch.Tensor]: + """ + Split into chunks of (1, chunk_size). + """ if not isinstance(chunk_size, int) or chunk_size <= 0: raise ValueError("chunk_size must be a positive integer.") - # Handle scalar tensor edge case if necessary - if audio_tensor.dim() == 0: - return [audio_tensor] if audio_tensor.numel() > 0 else [] - - # Identify the dimension to split (usually the last one, representing time/samples) - split_dim = -1 - num_samples = audio_tensor.shape[split_dim] + if audio_tensor.dim() == 1: + audio_tensor = audio_tensor.unsqueeze(0) + num_samples = audio_tensor.shape[-1] if num_samples == 0: - return [] # Return empty list if the dimension to split is empty - - # Use torch.split to divide the tensor into chunks - # It handles the last chunk being potentially smaller automatically. - chunks = list(torch.split(audio_tensor, chunk_size, dim=split_dim)) + return [] + chunks = list(torch.split(audio_tensor, chunk_size, dim=-1)) return chunks + def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor: + """ + Reconstruct audio from chunks. Returns (1, N). + """ if not chunks: - return torch.empty(0) - - if len(chunks) == 1 and chunks[0].dim() == 0: - return chunks[0] - - concat_dim = -1 + return torch.empty(1, 0) + chunks = [c if c.dim() == 2 else c.unsqueeze(0) for c in chunks] try: - reconstructed_tensor = torch.cat(chunks, dim=concat_dim) + reconstructed_tensor = torch.cat(chunks, dim=-1) except RuntimeError as e: raise RuntimeError( f"Failed to concatenate audio chunks. Ensure chunks have compatible shapes " - f"for concatenation along dimension {concat_dim}. Original error: {e}" + f"for concatenation along dim -1. 
Original error: {e}" ) return reconstructed_tensor + + +def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + max_val = torch.max(torch.abs(audio_tensor)) + if max_val < eps: + return audio_tensor # silence, skip normalization + return audio_tensor / max_val diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app.py b/app.py index 8c669a9..006cba5 100644 --- a/app.py +++ b/app.py @@ -68,6 +68,7 @@ def start(): original_sample_rate = decoded_samples.sample_rate audio = AudioUtils.stereo_tensor_to_mono(audio) + audio = AudioUtils.normalize(audio) resample_transform = torchaudio.transforms.Resample( original_sample_rate, args.sample_rate diff --git a/data.py b/data.py index c9f4e76..a2ddb71 100644 --- a/data.py +++ b/data.py @@ -12,12 +12,15 @@ import AudioUtils class AudioDataset(Dataset): audio_sample_rates = [11025] - def __init__(self, input_dir, clip_length=16384): + def __init__(self, input_dir, clip_length: int = 8000, normalize: bool = True): + self.clip_length = clip_length + self.normalize = normalize + input_files = [ - os.path.join(root, f) - for root, _, files in os.walk(input_dir) - for f in files - if f.endswith(".wav") or f.endswith(".mp3") or f.endswith(".flac") + os.path.join(input_dir, f) + for f in os.listdir(input_dir) + if os.path.isfile(os.path.join(input_dir, f)) + and f.lower().endswith((".wav", ".mp3", ".flac")) ] data = [] @@ -25,14 +28,15 @@ class AudioDataset(Dataset): input_files, desc=f"Processing {len(input_files)} audio file(s)" ): decoder = decoders.AudioDecoder(audio_clip) - decoded_samples = decoder.get_all_samples() - audio = decoded_samples.data + + audio = decoded_samples.data.float() # ensure float32 original_sample_rate = decoded_samples.sample_rate audio = AudioUtils.stereo_tensor_to_mono(audio) + if normalize: + audio = AudioUtils.normalize(audio) - # Generate low-quality audio with random downsampling mangled_sample_rate = random.choice(self.audio_sample_rates) resample_transform_low = torchaudio.transforms.Resample( original_sample_rate, mangled_sample_rate @@ -41,25 +45,27 @@ class AudioDataset(Dataset): mangled_sample_rate, original_sample_rate ) - low_audio = resample_transform_low(audio) - low_audio = resample_transform_high(low_audio) + low_audio = resample_transform_high(resample_transform_low(audio)) splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length) + splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length) + + if not splitted_high_quality_audio or not splitted_low_quality_audio: + continue # skip empty or invalid clips + splitted_high_quality_audio[-1] = AudioUtils.pad_tensor( splitted_high_quality_audio[-1], clip_length ) - - splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length) splitted_low_quality_audio[-1] = AudioUtils.pad_tensor( splitted_low_quality_audio[-1], clip_length ) - for high_quality_sample, low_quality_sample in zip( + for high_quality_data, low_quality_data in zip( splitted_high_quality_audio, splitted_low_quality_audio ): data.append( ( - (high_quality_sample, low_quality_sample), + (high_quality_data, low_quality_data), (original_sample_rate, mangled_sample_rate), ) ) diff --git a/discriminator.py b/discriminator.py index ce2e84c..5e8442b 100644 --- a/discriminator.py +++ b/discriminator.py @@ -49,74 +49,18 @@ class AttentionBlock(nn.Module): class SISUDiscriminator(nn.Module): - def __init__(self, base_channels=16): + def __init__(self, layers=32): super(SISUDiscriminator, self).__init__() 
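
Not part of the patch itself, just a quick round-trip check of the chunk helpers introduced above (split_audio, pad_tensor, reconstruct_audio), assuming the mono (1, N) layout AudioUtils works with; the random waveform and the 8000-sample chunk size are placeholders (8000 matches the new AudioDataset clip_length default).

import torch

import AudioUtils

wave = torch.randn(1, 20000)                            # placeholder mono audio, (channels, samples)
chunks = AudioUtils.split_audio(wave, chunk_size=8000)  # -> [(1, 8000), (1, 8000), (1, 4000)]
chunks[-1] = AudioUtils.pad_tensor(chunks[-1], 8000)    # zero-pad the short tail chunk
rebuilt = AudioUtils.reconstruct_audio(chunks)          # concatenate back along the time axis

assert rebuilt.shape == (1, 24000)                      # three full 8000-sample chunks
assert torch.equal(rebuilt[..., :20000], wave)          # original samples survive the round trip
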
- layers = base_channels self.model = nn.Sequential( - discriminator_block( - 1, - layers, - kernel_size=7, - stride=1, - spectral_norm=True, - use_instance_norm=False, - ), - discriminator_block( - layers, - layers * 2, - kernel_size=5, - stride=2, - spectral_norm=True, - use_instance_norm=True, - ), - discriminator_block( - layers * 2, - layers * 4, - kernel_size=5, - stride=1, - dilation=2, - spectral_norm=True, - use_instance_norm=True, - ), + discriminator_block(1, layers, kernel_size=7, stride=1), + discriminator_block(layers, layers * 2, kernel_size=5, stride=2), + discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=2), AttentionBlock(layers * 4), - discriminator_block( - layers * 4, - layers * 8, - kernel_size=5, - stride=1, - dilation=4, - spectral_norm=True, - use_instance_norm=True, - ), - discriminator_block( - layers * 8, - layers * 4, - kernel_size=5, - stride=2, - spectral_norm=True, - use_instance_norm=True, - ), - discriminator_block( - layers * 4, - layers * 2, - kernel_size=3, - stride=1, - spectral_norm=True, - use_instance_norm=True, - ), + discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=4), + discriminator_block(layers * 8, layers * 2, kernel_size=5, stride=2), discriminator_block( layers * 2, - layers, - kernel_size=3, - stride=1, - spectral_norm=True, - use_instance_norm=True, - ), - discriminator_block( - layers, 1, - kernel_size=3, - stride=1, spectral_norm=False, use_instance_norm=False, ), diff --git a/file_utils.py b/file_utils.py deleted file mode 100644 index 98f70bc..0000000 --- a/file_utils.py +++ /dev/null @@ -1,30 +0,0 @@ -import json - -filepath = "my_data.json" - -def write_data(filepath, data, debug=False): - try: - with open(filepath, 'w') as f: - json.dump(data, f, indent=4) # Use indent for pretty formatting - if debug: - print(f"Data written to '{filepath}'") - except Exception as e: - print(f"Error writing to file: {e}") - - -def read_data(filepath, debug=False): - try: - with open(filepath, 'r') as f: - data = json.load(f) - if debug: - print(f"Data read from '{filepath}'") - return data - except FileNotFoundError: - print(f"File not found: {filepath}") - return None - except json.JSONDecodeError: - print(f"Error decoding JSON from file: {filepath}") - return None - except Exception as e: - print(f"Error reading from file: {e}") - return None diff --git a/generator.py b/generator.py index 0240860..b6d2204 100644 --- a/generator.py +++ b/generator.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn @@ -52,7 +53,7 @@ class ResidualInResidualBlock(nn.Module): class SISUGenerator(nn.Module): - def __init__(self, channels=16, num_rirb=4, alpha=1.0): + def __init__(self, channels=16, num_rirb=4, alpha=1): super(SISUGenerator, self).__init__() self.alpha = alpha @@ -66,7 +67,9 @@ class SISUGenerator(nn.Module): *[ResidualInResidualBlock(channels) for _ in range(num_rirb)] ) - self.final_layer = nn.Conv1d(channels, 1, kernel_size=3, padding=1) + self.final_layer = nn.Sequential( + nn.Conv1d(channels, 1, kernel_size=3, padding=1), nn.Tanh() + ) def forward(self, x): residual_input = x @@ -75,4 +78,4 @@ class SISUGenerator(nn.Module): learned_residual = self.final_layer(x_rirb_out) output = residual_input + self.alpha * learned_residual - return output + return torch.tanh(output) diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 21f6bef..0000000 --- a/requirements.txt +++ /dev/null @@ -1,12 +0,0 @@ -filelock==3.16.1 -fsspec==2024.10.0 -Jinja2==3.1.4 -MarkupSafe==2.1.5 -mpmath==1.3.0 
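
A minimal sanity sketch for the patched SISUGenerator above, assuming the default channels=16 / num_rirb=4 configuration; the batch shape and 8000-sample length are arbitrary placeholders. Because every convolution uses matching padding, the clip length is preserved, and the trailing torch.tanh keeps the enhanced signal in [-1, 1].

import torch

from generator import SISUGenerator

generator = SISUGenerator()                  # channels=16, num_rirb=4, alpha=1
low_quality = torch.randn(2, 1, 8000)        # (batch, channels=1, samples), placeholder input

with torch.no_grad():
    enhanced = generator(low_quality)

assert enhanced.shape == low_quality.shape   # fully convolutional: length is unchanged
assert enhanced.abs().max() <= 1.0           # final torch.tanh bounds the residual output
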
-networkx==3.4.2 -numpy==2.2.3 -pillow==11.0.0 -setuptools==70.2.0 -sympy==1.13.3 -tqdm==4.67.1 -typing_extensions==4.12.2 diff --git a/training.py b/training.py index 8876740..0e0caf8 100644 --- a/training.py +++ b/training.py @@ -4,25 +4,20 @@ import os import torch import torch.nn as nn import torch.optim as optim -import torchaudio.transforms as T import tqdm -from torch.amp import GradScaler, autocast -from torch.utils.data import DataLoader +from accelerate import Accelerator +from torch.utils.data import DataLoader, DistributedSampler -import training_utils from data import AudioDataset from discriminator import SISUDiscriminator from generator import SISUGenerator -from training_utils import discriminator_train, generator_train +from utils.TrainingTools import discriminator_train, generator_train # --------------------------- # Argument parsing # --------------------------- parser = argparse.ArgumentParser(description="Training script (safer defaults)") parser.add_argument("--resume", action="store_true", help="Resume training") -parser.add_argument( - "--device", type=str, default="cuda", help="Device (cuda, cpu, mps)" -) parser.add_argument( "--epochs", type=int, default=5000, help="Number of training epochs" ) @@ -35,86 +30,54 @@ parser.add_argument( args = parser.parse_args() # --------------------------- -# Device setup +# Init accelerator # --------------------------- -# Use requested device only if available -device = torch.device( - args.device if (args.device != "cuda" or torch.cuda.is_available()) else "cpu" -) -print(f"Using device: {device}") -# sensible performance flags -if device.type == "cuda": - torch.backends.cudnn.benchmark = True - # optional: torch.set_float32_matmul_precision("high") -debug = args.debug -# --------------------------- -# Audio transforms -# --------------------------- -sample_rate = 44100 -n_fft = 1024 -win_length = n_fft -hop_length = n_fft // 4 -n_mels = 96 -# n_mfcc = 13 - -# mfcc_transform = T.MFCC( -# sample_rate=sample_rate, -# n_mfcc=n_mfcc, -# melkwargs=dict( -# n_fft=n_fft, -# hop_length=hop_length, -# win_length=win_length, -# n_mels=n_mels, -# power=1.0, -# ), -# ).to(device) - -mel_transform = T.MelSpectrogram( - sample_rate=sample_rate, - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - n_mels=n_mels, - power=1.0, -).to(device) - -stft_transform = T.Spectrogram( - n_fft=n_fft, win_length=win_length, hop_length=hop_length -).to(device) - -# training_utils.init(mel_transform, stft_transform, mfcc_transform) -training_utils.init(mel_transform, stft_transform) - -# --------------------------- -# Dataset / DataLoader -# --------------------------- -dataset_dir = "./dataset/good" -dataset = AudioDataset(dataset_dir) - -train_loader = DataLoader( - dataset, - batch_size=args.batch_size, - shuffle=True, - num_workers=args.num_workers, - pin_memory=True, - persistent_workers=True, -) +accelerator = Accelerator(mixed_precision="bf16") # --------------------------- # Models # --------------------------- -generator = SISUGenerator().to(device) -discriminator = SISUDiscriminator().to(device) +generator = SISUGenerator() +discriminator = SISUDiscriminator() + +accelerator.print("๐Ÿ”จ | Compiling models...") generator = torch.compile(generator) discriminator = torch.compile(discriminator) +accelerator.print("โœ… | Compiling done!") + +# --------------------------- +# Dataset / DataLoader +# --------------------------- +accelerator.print("๐Ÿ“Š | Fetching dataset...") +dataset = AudioDataset("./dataset") + +sampler = 
DistributedSampler(dataset) if accelerator.num_processes > 1 else None +pin_memory = torch.cuda.is_available() and not args.no_pin_memory + +train_loader = DataLoader( + dataset, + sampler=sampler, + batch_size=args.batch_size, + shuffle=(sampler is None), + num_workers=args.num_workers, + pin_memory=pin_memory, + persistent_workers=pin_memory, +) + +if not train_loader or not train_loader.batch_size or train_loader.batch_size == 0: + accelerator.print("๐Ÿชน | There is no data to train with! Exiting...") + exit() + +loader_batch_size = train_loader.batch_size + +accelerator.print("โœ… | Dataset fetched!") + # --------------------------- # Losses / Optimizers / Scalers # --------------------------- -criterion_g = nn.BCEWithLogitsLoss() -criterion_d = nn.BCEWithLogitsLoss() optimizer_g = optim.AdamW( generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001 @@ -123,9 +86,6 @@ optimizer_d = optim.AdamW( discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001 ) -# Use modern GradScaler signature; choose device_type based on runtime device. -scaler = GradScaler(device=device) - scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer_g, mode="min", factor=0.5, patience=5 ) @@ -133,6 +93,17 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer_d, mode="min", factor=0.5, patience=5 ) +criterion_g = nn.BCEWithLogitsLoss() +criterion_d = nn.MSELoss() + +# --------------------------- +# Prepare accelerator +# --------------------------- + +generator, discriminator, optimizer_g, optimizer_d, train_loader = accelerator.prepare( + generator, discriminator, optimizer_g, optimizer_d, train_loader +) + # --------------------------- # Checkpoint helpers # --------------------------- @@ -141,44 +112,45 @@ os.makedirs(models_dir, exist_ok=True) def save_ckpt(path, epoch): - torch.save( - { - "epoch": epoch, - "G": generator.state_dict(), - "D": discriminator.state_dict(), - "optG": optimizer_g.state_dict(), - "optD": optimizer_d.state_dict(), - "scaler": scaler.state_dict(), - "schedG": scheduler_g.state_dict(), - "schedD": scheduler_d.state_dict(), - }, - path, - ) + accelerator.wait_for_everyone() + if accelerator.is_main_process: + accelerator.save( + { + "epoch": epoch, + "G": accelerator.unwrap_model(generator).state_dict(), + "D": accelerator.unwrap_model(discriminator).state_dict(), + "optG": optimizer_g.state_dict(), + "optD": optimizer_d.state_dict(), + "schedG": scheduler_g.state_dict(), + "schedD": scheduler_d.state_dict(), + }, + path, + ) start_epoch = 0 if args.resume: - ckpt = torch.load(os.path.join(models_dir, "last.pt"), map_location=device) - generator.load_state_dict(ckpt["G"]) - discriminator.load_state_dict(ckpt["D"]) + ckpt_path = os.path.join(models_dir, "last.pt") + ckpt = torch.load(ckpt_path) + + accelerator.unwrap_model(generator).load_state_dict(ckpt["G"]) + accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"]) optimizer_g.load_state_dict(ckpt["optG"]) optimizer_d.load_state_dict(ckpt["optD"]) - scaler.load_state_dict(ckpt["scaler"]) scheduler_g.load_state_dict(ckpt["schedG"]) scheduler_d.load_state_dict(ckpt["schedD"]) + start_epoch = ckpt.get("epoch", 1) + accelerator.print(f"๐Ÿ” | Resumed from epoch {start_epoch}!") -# --------------------------- -# Training loop (safer) -# --------------------------- +real_buf = torch.full( + (loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32 +) +fake_buf = torch.zeros( + (loader_batch_size, 1), device=accelerator.device, 
dtype=torch.float32 +) -if not train_loader or not train_loader.batch_size: - print("There is no data to train with! Exiting...") - exit() - -max_batch = max(1, train_loader.batch_size) -real_buf = torch.full((max_batch, 1), 0.9, device=device) # label smoothing -fake_buf = torch.zeros(max_batch, 1, device=device) +accelerator.print("๐Ÿ‹๏ธ | Started training...") try: for epoch in range(start_epoch, args.epochs): @@ -193,15 +165,12 @@ try: ) in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")): batch_size = high_quality.size(0) - high_quality = high_quality.to(device, non_blocking=True) - low_quality = low_quality.to(device, non_blocking=True) - - real_labels = real_buf[:batch_size] - fake_labels = fake_buf[:batch_size] + real_labels = real_buf[:batch_size].to(accelerator.device) + fake_labels = fake_buf[:batch_size].to(accelerator.device) # --- Discriminator --- optimizer_d.zero_grad(set_to_none=True) - with autocast(device_type=device.type): + with accelerator.autocast(): d_loss = discriminator_train( high_quality, low_quality, @@ -212,15 +181,14 @@ try: criterion_d, ) - scaler.scale(d_loss).backward() - scaler.unscale_(optimizer_d) - torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0) - scaler.step(optimizer_d) + accelerator.backward(d_loss) + torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1) + optimizer_d.step() # --- Generator --- optimizer_g.zero_grad(set_to_none=True) - with autocast(device_type=device.type): - g_out, g_total, g_adv = generator_train( + with accelerator.autocast(): + g_total, g_adv = generator_train( low_quality, high_quality, real_labels, @@ -229,20 +197,32 @@ try: criterion_d, ) - scaler.scale(g_total).backward() - scaler.unscale_(optimizer_g) - torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0) - scaler.step(optimizer_g) + accelerator.backward(g_total) + torch.nn.utils.clip_grad_norm_(generator.parameters(), 1) + optimizer_g.step() - scaler.update() + d_val = accelerator.gather(d_loss.detach()).mean() + g_val = accelerator.gather(g_total.detach()).mean() + + if torch.isfinite(d_val): + running_d += d_val.item() + else: + accelerator.print( + f"๐Ÿซฅ | NaN in discriminator loss at step {i}, skipping update." + ) + + if torch.isfinite(g_val): + running_g += g_val.item() + else: + accelerator.print( + f"๐Ÿซฅ | NaN in generator loss at step {i}, skipping update." + ) - running_d += float(d_loss.detach().cpu().item()) - running_g += float(g_total.detach().cpu().item()) steps += 1 # epoch averages & schedulers if steps == 0: - print("No steps in epoch (empty dataloader?). Exiting.") + accelerator.print("๐Ÿชน | No steps in epoch (empty dataloader?). 
Exiting.") break mean_d = running_d / steps @@ -252,22 +232,14 @@ try: scheduler_g.step(mean_g) save_ckpt(os.path.join(models_dir, "last.pt"), epoch) - print(f"Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}") + accelerator.print(f"๐Ÿค | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}") except Exception: try: save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch) - print(f"Saved crash checkpoint for epoch {epoch}") + accelerator.print(f"๐Ÿ’พ | Saved crash checkpoint for epoch {epoch}") except Exception as e: - print("Failed saving crash checkpoint:", e) + accelerator.print("๐Ÿ˜ฌ | Failed saving crash checkpoint:", e) raise -try: - torch.save(generator.state_dict(), os.path.join(models_dir, "final_generator.pt")) - torch.save( - discriminator.state_dict(), os.path.join(models_dir, "final_discriminator.pt") - ) -except Exception as e: - print("Failed to save final states:", e) - -print("Training finished.") +accelerator.print("๐Ÿ | Training finished.") diff --git a/training_utils.py b/training_utils.py deleted file mode 100644 index 02403af..0000000 --- a/training_utils.py +++ /dev/null @@ -1,154 +0,0 @@ -import torch -import torchaudio.transforms as T - -from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss - -mel_transform: T.MelSpectrogram -stft_transform: T.Spectrogram -# mfcc_transform: T.MFCC - - -# def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram, mfcc_trans: T.MFCC): -# """Initializes the global transform variables for the module.""" -# global mel_transform, stft_transform, mfcc_transform -# mel_transform = mel_trans -# stft_transform = stft_trans -# mfcc_transform = mfcc_trans - - -def init(mel_trans: T.MelSpectrogram, stft_trans: T.Spectrogram): - """Initializes the global transform variables for the module.""" - global mel_transform, stft_transform - mel_transform = mel_trans - stft_transform = stft_trans - - -# def mfcc_loss(y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: -# """Computes the Mean Squared Error (MSE) loss on MFCCs.""" -# mfccs_true = mfcc_transform(y_true) -# mfccs_pred = mfcc_transform(y_pred) -# return F.mse_loss(mfccs_pred, mfccs_true) - - -# def mel_spectrogram_loss( -# y_true: torch.Tensor, y_pred: torch.Tensor, loss_type: str = "l1" -# ) -> torch.Tensor: -# """Calculates L1 or L2 loss on the Mel Spectrogram.""" -# mel_spec_true = mel_transform(y_true) -# mel_spec_pred = mel_transform(y_pred) -# if loss_type == "l1": -# return F.l1_loss(mel_spec_pred, mel_spec_true) -# elif loss_type == "l2": -# return F.mse_loss(mel_spec_pred, mel_spec_true) -# else: -# raise ValueError("loss_type must be 'l1' or 'l2'") - - -# def log_stft_magnitude_loss( -# y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7 -# ) -> torch.Tensor: -# """Calculates L1 loss on the log STFT magnitude.""" -# stft_mag_true = stft_transform(y_true) -# stft_mag_pred = stft_transform(y_pred) -# return F.l1_loss(torch.log(stft_mag_pred + eps), torch.log(stft_mag_true + eps)) - - -stft_loss_fn = MultiResolutionSTFTLoss( - fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240] -) - - -def discriminator_train( - high_quality, - low_quality, - real_labels, - fake_labels, - discriminator, - generator, - criterion, -): - discriminator_decision_from_real = discriminator(high_quality) - d_loss_real = criterion(discriminator_decision_from_real, real_labels) - - with torch.no_grad(): - generator_output = generator(low_quality) - discriminator_decision_from_fake = discriminator(generator_output) - d_loss_fake = 
criterion( - discriminator_decision_from_fake, - fake_labels.expand_as(discriminator_decision_from_fake), - ) - - d_loss = (d_loss_real + d_loss_fake) / 2.0 - - return d_loss - - -def generator_train( - low_quality, - high_quality, - real_labels, - generator, - discriminator, - adv_criterion, - lambda_adv: float = 1.0, - lambda_feat: float = 10.0, - lambda_stft: float = 2.5, -): - generator_output = generator(low_quality) - - discriminator_decision = discriminator(generator_output) - # adversarial_loss = adv_criterion( - # discriminator_decision, real_labels.expand_as(discriminator_decision) - # ) - adversarial_loss = adv_criterion(discriminator_decision, real_labels) - - combined_loss = lambda_adv * adversarial_loss - - stft_losses = stft_loss_fn(high_quality, generator_output) - stft_loss = stft_losses["total"] - - combined_loss = (lambda_adv * adversarial_loss) + (lambda_stft * stft_loss) - - return generator_output, combined_loss, adversarial_loss - - -# def generator_train( -# low_quality, -# high_quality, -# real_labels, -# generator, -# discriminator, -# adv_criterion, -# lambda_adv: float = 1.0, -# lambda_mel_l1: float = 10.0, -# lambda_log_stft: float = 1.0, - -# ): -# generator_output = generator(low_quality) - -# discriminator_decision = discriminator(generator_output) -# adversarial_loss = adv_criterion( -# discriminator_decision, real_labels.expand_as(discriminator_decision) -# ) - -# combined_loss = lambda_adv * adversarial_loss - -# if lambda_mel_l1 > 0: -# mel_l1_loss = mel_spectrogram_loss(high_quality, generator_output, "l1") -# combined_loss += lambda_mel_l1 * mel_l1_loss -# else: -# mel_l1_loss = torch.tensor(0.0, device=low_quality.device) # For logging - -# if lambda_log_stft > 0: -# log_stft_loss = log_stft_magnitude_loss(high_quality, generator_output) -# combined_loss += lambda_log_stft * log_stft_loss -# else: -# log_stft_loss = torch.tensor(0.0, device=low_quality.device) - -# if lambda_mfcc > 0: -# mfcc_loss_val = mfcc_loss(high_quality, generator_output) -# combined_loss += lambda_mfcc * mfcc_loss_val -# else: -# mfcc_loss_val = torch.tensor(0.0, device=low_quality.device) - -# return generator_output, combined_loss, adversarial_loss diff --git a/utils/MultiResolutionSTFTLoss.py b/utils/MultiResolutionSTFTLoss.py index 5712fc3..eab6355 100644 --- a/utils/MultiResolutionSTFTLoss.py +++ b/utils/MultiResolutionSTFTLoss.py @@ -8,8 +8,9 @@ import torchaudio.transforms as T class MultiResolutionSTFTLoss(nn.Module): """ - Computes a loss based on multiple STFT resolutions, including both - spectral convergence and log STFT magnitude components. + Multi-resolution STFT loss. + Combines spectral convergence loss and log-magnitude loss + across multiple STFT resolutions. 
""" def __init__( @@ -20,43 +21,67 @@ class MultiResolutionSTFTLoss(nn.Module): eps: float = 1e-7, ): super().__init__() - self.stft_transforms = nn.ModuleList( - [ - T.Spectrogram( - n_fft=n_fft, win_length=win_len, hop_length=hop_len, power=None - ) - for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths) - ] - ) + self.eps = eps + self.n_resolutions = len(fft_sizes) + + self.stft_transforms = nn.ModuleList() + for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths): + window = torch.hann_window(win_len) + stft = T.Spectrogram( + n_fft=n_fft, + hop_length=hop_len, + win_length=win_len, + window_fn=lambda _: window, + power=None, # Keep complex output + center=True, + pad_mode="reflect", + normalized=False, + ) + self.stft_transforms.append(stft) def forward( self, y_true: torch.Tensor, y_pred: torch.Tensor ) -> Dict[str, torch.Tensor]: - sc_loss = 0.0 # Spectral Convergence Loss - mag_loss = 0.0 # Log STFT Magnitude Loss + """ + Args: + y_true: (B, T) or (B, 1, T) waveform + y_pred: (B, T) or (B, 1, T) waveform + """ + # Ensure correct shape (B, T) + if y_true.dim() == 3 and y_true.size(1) == 1: + y_true = y_true.squeeze(1) + if y_pred.dim() == 3 and y_pred.size(1) == 1: + y_pred = y_pred.squeeze(1) + + sc_loss = 0.0 + mag_loss = 0.0 for stft in self.stft_transforms: - stft.to(y_pred.device) # Ensure transform is on the correct device + stft = stft.to(y_pred.device) - # Get complex STFTs + # Complex STFTs: (B, F, T, 2) stft_true = stft(y_true) stft_pred = stft(y_pred) - # Get magnitudes + # Magnitudes stft_mag_true = torch.abs(stft_true) stft_mag_pred = torch.abs(stft_pred) # --- Spectral Convergence Loss --- - # || |S_true| - |S_pred| ||_F / || |S_true| ||_F norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1)) norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1)) sc_loss += torch.mean(norm_diff / (norm_true + self.eps)) # --- Log STFT Magnitude Loss --- mag_loss += F.l1_loss( - torch.log(stft_mag_pred + self.eps), torch.log(stft_mag_true + self.eps) + torch.log(stft_mag_pred + self.eps), + torch.log(stft_mag_true + self.eps), ) + # Average across resolutions + sc_loss /= self.n_resolutions + mag_loss /= self.n_resolutions total_loss = sc_loss + mag_loss + return {"total": total_loss, "sc": sc_loss, "mag": mag_loss} diff --git a/utils/TrainingTools.py b/utils/TrainingTools.py new file mode 100644 index 0000000..d1959a0 --- /dev/null +++ b/utils/TrainingTools.py @@ -0,0 +1,60 @@ +import torch + +# In case if needed again... 
+# from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss +# +# stft_loss_fn = MultiResolutionSTFTLoss( +# fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240] +# ) + + +def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor: + absolute_difference = torch.abs(input_one - input_two) + return torch.mean(absolute_difference) + + +def discriminator_train( + high_quality, + low_quality, + high_labels, + low_labels, + discriminator, + generator, + criterion, +): + decision_high = discriminator(high_quality) + d_loss_high = criterion(decision_high, high_labels) + # print(f"Is this real?: {discriminator_decision_from_real} | {d_loss_real}") + + decision_low = discriminator(low_quality) + d_loss_low = criterion(decision_low, low_labels) + # print(f"Is this real?: {discriminator_decision_from_fake} | {d_loss_fake}") + + with torch.no_grad(): + generator_quality = generator(low_quality) + decision_gen = discriminator(generator_quality) + d_loss_gen = criterion(decision_gen, low_labels) + + noise = torch.rand_like(high_quality) * 0.08 + decision_noise = discriminator(high_quality + noise) + d_loss_noise = criterion(decision_noise, low_labels) + + d_loss = (d_loss_high + d_loss_low + d_loss_gen + d_loss_noise) / 4.0 + + return d_loss + + +def generator_train( + low_quality, high_quality, real_labels, generator, discriminator, adv_criterion +): + generator_output = generator(low_quality) + + discriminator_decision = discriminator(generator_output) + adversarial_loss = adv_criterion(discriminator_decision, real_labels) + + # Signal similarity + similarity_loss = signal_mae(generator_output, high_quality) + + combined_loss = adversarial_loss + (similarity_loss * 100) + + return combined_loss, adversarial_loss From 782a3bab28e2484893eee1cb6b5dae54b979030a Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Tue, 18 Nov 2025 21:34:59 +0200 Subject: [PATCH 09/11] :alembic: | More architectural changes --- AudioUtils.py | 86 +++++------------------------ app.py | 55 ++++++++++++++---- data.py | 62 +++++++++------------ discriminator.py | 65 ++++++++++------------ generator.py | 95 +++++++++++++++++++++++--------- training.py | 33 +++++++---- utils/MultiResolutionSTFTLoss.py | 43 ++++----------- utils/TrainingTools.py | 60 ++++++++++---------- 8 files changed, 245 insertions(+), 254 deletions(-) diff --git a/AudioUtils.py b/AudioUtils.py index 183dc36..ff6a24f 100644 --- a/AudioUtils.py +++ b/AudioUtils.py @@ -3,95 +3,39 @@ import torch.nn.functional as F def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor: - """ - Convert stereo (C, N) to mono (1, N). Ensures a channel dimension. - """ - if waveform.dim() == 1: - waveform = waveform.unsqueeze(0) # (N,) -> (1, N) - - if waveform.shape[0] > 1: - mono_waveform = torch.mean(waveform, dim=0, keepdim=True) # (1, N) - else: - mono_waveform = waveform - return mono_waveform + mono_tensor = torch.mean(waveform, dim=0, keepdim=True) + return mono_tensor -def stretch_tensor(tensor: torch.Tensor, target_length: int) -> torch.Tensor: - """ - Stretch audio along time dimension to target_length. - Input assumed (1, N). Returns (1, target_length). 
- """ - if tensor.dim() == 1: - tensor = tensor.unsqueeze(0) # ensure (1, N) +def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 512) -> torch.Tensor: + padding_amount = target_length - audio_tensor.size(-1) + if padding_amount <= 0: + return audio_tensor - tensor = tensor.unsqueeze(0) # (1, 1, N) for interpolate - stretched = F.interpolate( - tensor, size=target_length, mode="linear", align_corners=False - ) - return stretched.squeeze(0) # back to (1, target_length) - - -def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128) -> torch.Tensor: - """ - Pad to fixed length. Input assumed (1, N). Returns (1, target_length). - """ - if audio_tensor.dim() == 1: - audio_tensor = audio_tensor.unsqueeze(0) - - current_length = audio_tensor.shape[-1] - if current_length < target_length: - padding_needed = target_length - current_length - padding_tuple = (0, padding_needed) - padded_audio_tensor = F.pad( - audio_tensor, padding_tuple, mode="constant", value=0 - ) - else: - padded_audio_tensor = audio_tensor[..., :target_length] # crop if too long + padded_audio_tensor = F.pad(audio_tensor, (0, padding_amount)) return padded_audio_tensor -def split_audio( - audio_tensor: torch.Tensor, chunk_size: int = 128 -) -> list[torch.Tensor]: - """ - Split into chunks of (1, chunk_size). - """ - if not isinstance(chunk_size, int) or chunk_size <= 0: - raise ValueError("chunk_size must be a positive integer.") +def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 512, pad_last_tensor: bool = False) -> list[torch.Tensor]: + chunks = list(torch.split(audio_tensor, chunk_size, dim=1)) - if audio_tensor.dim() == 1: - audio_tensor = audio_tensor.unsqueeze(0) + if pad_last_tensor: + last_chunk = chunks[-1] - num_samples = audio_tensor.shape[-1] - if num_samples == 0: - return [] + if last_chunk.size(-1) < chunk_size: + chunks[-1] = pad_tensor(last_chunk, chunk_size) - chunks = list(torch.split(audio_tensor, chunk_size, dim=-1)) return chunks def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor: - """ - Reconstruct audio from chunks. Returns (1, N). - """ - if not chunks: - return torch.empty(1, 0) - - chunks = [c if c.dim() == 2 else c.unsqueeze(0) for c in chunks] - try: - reconstructed_tensor = torch.cat(chunks, dim=-1) - except RuntimeError as e: - raise RuntimeError( - f"Failed to concatenate audio chunks. Ensure chunks have compatible shapes " - f"for concatenation along dim -1. 
Original error: {e}" - ) - + reconstructed_tensor = torch.cat(chunks, dim=-1) return reconstructed_tensor def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: max_val = torch.max(torch.abs(audio_tensor)) if max_val < eps: - return audio_tensor # silence, skip normalization + return audio_tensor return audio_tensor / max_val diff --git a/app.py b/app.py index 006cba5..0bacee2 100644 --- a/app.py +++ b/app.py @@ -4,6 +4,7 @@ import torch import torchaudio import torchcodec import tqdm +from accelerate import Accelerator import AudioUtils from generator import SISUGenerator @@ -15,7 +16,7 @@ parser.add_argument("--model", type=str, help="Model to use for upscaling") parser.add_argument( "--clip_length", type=int, - default=16384, + default=8000, help="Internal clip length, leave unspecified if unsure", ) parser.add_argument( @@ -38,21 +39,44 @@ if args.sample_rate < 8000: ) exit() -device = torch.device(args.device if torch.cuda.is_available() else "cpu") -print(f"Using device: {device}") +# --------------------------- +# Init accelerator +# --------------------------- -generator = SISUGenerator().to(device) +accelerator = Accelerator(mixed_precision="bf16") + +# --------------------------- +# Models +# --------------------------- +generator = SISUGenerator() + +accelerator.print("๐Ÿ”จ | Compiling models...") generator = torch.compile(generator) +accelerator.print("โœ… | Compiling done!") + +# --------------------------- +# Prepare accelerator +# --------------------------- + +generator = accelerator.prepare(generator) + +# --------------------------- +# Checkpoint helpers +# --------------------------- + models_dir = args.model clip_length = args.clip_length input_audio = args.input output_audio = args.output if models_dir: - ckpt = torch.load(models_dir, map_location=device) - generator.load_state_dict(ckpt["G"]) + ckpt = torch.load(models_dir) + + accelerator.unwrap_model(generator).load_state_dict(ckpt["G"]) + + accelerator.print("๐Ÿ’พ | Loaded model!") else: print( "Generator model (--model) isn't specified. Do you have the trained model? 
If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)" @@ -67,7 +91,8 @@ def start(): audio = decoded_samples.data original_sample_rate = decoded_samples.sample_rate - audio = AudioUtils.stereo_tensor_to_mono(audio) + # Support for multichannel audio + # audio = AudioUtils.stereo_tensor_to_mono(audio) audio = AudioUtils.normalize(audio) resample_transform = torchaudio.transforms.Resample( @@ -77,14 +102,20 @@ def start(): audio = resample_transform(audio) splitted_audio = AudioUtils.split_audio(audio, clip_length) - splitted_audio_on_device = [t.to(device) for t in splitted_audio] + splitted_audio_on_device = [t.view(1, t.shape[0], t.shape[-1]).to(accelerator.device) for t in splitted_audio] processed_audio = [] - - for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."): - processed_audio.append(generator(clip)) + with torch.no_grad(): + for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."): + channels = [] + for audio_channel in torch.split(clip, 1, dim=1): + output_piece = generator(audio_channel) + channels.append(output_piece.detach().cpu()) + output_clip = torch.cat(channels, dim=1) + processed_audio.append(output_clip) reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio) - print(f"Saving {output_audio}!") + reconstructed_audio = reconstructed_audio.squeeze(0) + print(f"๐Ÿ”Š | Saving {output_audio}!") torchaudio.save_with_torchcodec( uri=output_audio, src=reconstructed_audio, diff --git a/data.py b/data.py index a2ddb71..13b0ca8 100644 --- a/data.py +++ b/data.py @@ -1,6 +1,7 @@ import os import random +import torch import torchaudio import torchcodec.decoders as decoders import tqdm @@ -10,9 +11,9 @@ import AudioUtils class AudioDataset(Dataset): - audio_sample_rates = [11025] + audio_sample_rates = [8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100] - def __init__(self, input_dir, clip_length: int = 8000, normalize: bool = True): + def __init__(self, input_dir, clip_length: int = 512, normalize: bool = True): self.clip_length = clip_length self.normalize = normalize @@ -30,45 +31,20 @@ class AudioDataset(Dataset): decoder = decoders.AudioDecoder(audio_clip) decoded_samples = decoder.get_all_samples() - audio = decoded_samples.data.float() # ensure float32 + audio = decoded_samples.data.float() original_sample_rate = decoded_samples.sample_rate - audio = AudioUtils.stereo_tensor_to_mono(audio) if normalize: audio = AudioUtils.normalize(audio) - mangled_sample_rate = random.choice(self.audio_sample_rates) - resample_transform_low = torchaudio.transforms.Resample( - original_sample_rate, mangled_sample_rate - ) - resample_transform_high = torchaudio.transforms.Resample( - mangled_sample_rate, original_sample_rate - ) + splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length, True) - low_audio = resample_transform_high(resample_transform_low(audio)) + if not splitted_high_quality_audio: + continue - splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length) - splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length) - - if not splitted_high_quality_audio or not splitted_low_quality_audio: - continue # skip empty or invalid clips - - splitted_high_quality_audio[-1] = AudioUtils.pad_tensor( - splitted_high_quality_audio[-1], clip_length - ) - splitted_low_quality_audio[-1] = AudioUtils.pad_tensor( - splitted_low_quality_audio[-1], clip_length - ) - - for high_quality_data, low_quality_data in zip( - splitted_high_quality_audio, splitted_low_quality_audio - ): - 
data.append( - ( - (high_quality_data, low_quality_data), - (original_sample_rate, mangled_sample_rate), - ) - ) + for splitted_audio_clip in splitted_high_quality_audio: + for audio_clip in torch.split(splitted_audio_clip, 1): + data.append((audio_clip, original_sample_rate)) self.audio_data = data @@ -76,4 +52,20 @@ class AudioDataset(Dataset): return len(self.audio_data) def __getitem__(self, idx): - return self.audio_data[idx] + audio_clip = self.audio_data[idx] + mangled_sample_rate = random.choice(self.audio_sample_rates) + + resample_transform_low = torchaudio.transforms.Resample( + audio_clip[1], mangled_sample_rate + ) + + resample_transform_high = torchaudio.transforms.Resample( + mangled_sample_rate, audio_clip[1] + ) + + low_audio_clip = resample_transform_high(resample_transform_low(audio_clip[0])) + if audio_clip[0].shape[1] < low_audio_clip.shape[1]: + low_audio_clip = low_audio_clip[:, :audio_clip[0].shape[1]] + elif audio_clip[0].shape[1] > low_audio_clip.shape[1]: + low_audio_clip = AudioUtils.pad_tensor(low_audio_clip, self.clip_length) + return ((audio_clip[0], low_audio_clip), (audio_clip[1], mangled_sample_rate)) diff --git a/discriminator.py b/discriminator.py index 5e8442b..69de5ce 100644 --- a/discriminator.py +++ b/discriminator.py @@ -5,32 +5,25 @@ import torch.nn.utils as utils def discriminator_block( in_channels, out_channels, - kernel_size=3, + kernel_size=15, stride=1, - dilation=1, - spectral_norm=True, - use_instance_norm=True, + dilation=1 ): - padding = (kernel_size // 2) * dilation + padding = dilation * (kernel_size - 1) // 2 + conv_layer = nn.Conv1d( in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, - padding=padding, + padding=padding ) - if spectral_norm: - conv_layer = utils.spectral_norm(conv_layer) + conv_layer = utils.spectral_norm(conv_layer) + leaky_relu = nn.LeakyReLU(0.2) - layers = [conv_layer] - layers.append(nn.LeakyReLU(0.2, inplace=True)) - - if use_instance_norm: - layers.append(nn.InstanceNorm1d(out_channels)) - - return nn.Sequential(*layers) + return nn.Sequential(conv_layer, leaky_relu) class AttentionBlock(nn.Module): @@ -38,38 +31,40 @@ class AttentionBlock(nn.Module): super(AttentionBlock, self).__init__() self.attention = nn.Sequential( nn.Conv1d(channels, channels // 4, kernel_size=1), - nn.ReLU(inplace=True), + nn.ReLU(), nn.Conv1d(channels // 4, channels, kernel_size=1), nn.Sigmoid(), ) def forward(self, x): attention_weights = self.attention(x) - return x * attention_weights + return x + (x * attention_weights) class SISUDiscriminator(nn.Module): - def __init__(self, layers=32): + def __init__(self, layers=8): super(SISUDiscriminator, self).__init__() - self.model = nn.Sequential( - discriminator_block(1, layers, kernel_size=7, stride=1), - discriminator_block(layers, layers * 2, kernel_size=5, stride=2), - discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=2), + self.discriminator_blocks = nn.Sequential( + # 1 -> 32 + discriminator_block(2, layers), + AttentionBlock(layers), + # 32 -> 64 + discriminator_block(layers, layers * 2, dilation=2), + # 64 -> 128 + discriminator_block(layers * 2, layers * 4, dilation=4), AttentionBlock(layers * 4), - discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=4), - discriminator_block(layers * 8, layers * 2, kernel_size=5, stride=2), - discriminator_block( - layers * 2, - 1, - spectral_norm=False, - use_instance_norm=False, - ), + # 128 -> 256 + discriminator_block(layers * 4, layers * 8, stride=4), + # 256 -> 512 + # 
discriminator_block(layers * 8, layers * 16, stride=4) ) - self.global_avg_pool = nn.AdaptiveAvgPool1d(1) + self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1) + + self.avg_pool = nn.AdaptiveAvgPool1d(1) def forward(self, x): - x = self.model(x) - x = self.global_avg_pool(x) - x = x.view(x.size(0), -1) - return x + x = self.discriminator_blocks(x) + x = self.final_conv(x) + x = self.avg_pool(x) + return x.squeeze(2) diff --git a/generator.py b/generator.py index b6d2204..bc994ac 100644 --- a/generator.py +++ b/generator.py @@ -2,25 +2,23 @@ import torch import torch.nn as nn -def conv_block(in_channels, out_channels, kernel_size=3, dilation=1): +def GeneratorBlock(in_channels, out_channels, kernel_size=3, stride=1, dilation=1): + padding = (kernel_size - 1) // 2 * dilation return nn.Sequential( nn.Conv1d( in_channels, out_channels, kernel_size=kernel_size, + stride=stride, dilation=dilation, - padding=(kernel_size // 2) * dilation, + padding=padding ), nn.InstanceNorm1d(out_channels), - nn.PReLU(), + nn.PReLU(num_parameters=1, init=0.1), ) class AttentionBlock(nn.Module): - """ - Simple Channel Attention Block. Learns to weight channels based on their importance. - """ - def __init__(self, channels): super(AttentionBlock, self).__init__() self.attention = nn.Sequential( @@ -32,7 +30,7 @@ class AttentionBlock(nn.Module): def forward(self, x): attention_weights = self.attention(x) - return x * attention_weights + return x + (x * attention_weights) class ResidualInResidualBlock(nn.Module): @@ -40,7 +38,7 @@ class ResidualInResidualBlock(nn.Module): super(ResidualInResidualBlock, self).__init__() self.conv_layers = nn.Sequential( - *[conv_block(channels, channels) for _ in range(num_convs)] + *[GeneratorBlock(channels, channels) for _ in range(num_convs)] ) self.attention = AttentionBlock(channels) @@ -51,31 +49,74 @@ class ResidualInResidualBlock(nn.Module): x = self.attention(x) return x + residual +def UpsampleBlock(in_channels, out_channels): + return nn.Sequential( + nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=4, + stride=2, + padding=1 + ), + nn.InstanceNorm1d(out_channels), + nn.PReLU(num_parameters=1, init=0.1) + ) class SISUGenerator(nn.Module): - def __init__(self, channels=16, num_rirb=4, alpha=1): + def __init__(self, channels=32, num_rirb=1): super(SISUGenerator, self).__init__() - self.alpha = alpha - self.conv1 = nn.Sequential( - nn.Conv1d(1, channels, kernel_size=7, padding=3), - nn.InstanceNorm1d(channels), - nn.PReLU(), + self.first_conv = GeneratorBlock(1, channels) + + self.downsample = GeneratorBlock(channels, channels * 2, stride=2) + self.downsample_attn = AttentionBlock(channels * 2) + self.downsample_2 = GeneratorBlock(channels * 2, channels * 4, stride=2) + self.downsample_2_attn = AttentionBlock(channels * 4) + + self.rirb = ResidualInResidualBlock(channels * 4) + # self.rirb = nn.Sequential( + # *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)] + # ) + + self.upsample = UpsampleBlock(channels * 4, channels * 2) + self.upsample_attn = AttentionBlock(channels * 2) + self.compress_1 = GeneratorBlock(channels * 4, channels * 2) + + self.upsample_2 = UpsampleBlock(channels * 2, channels) + self.upsample_2_attn = AttentionBlock(channels) + self.compress_2 = GeneratorBlock(channels * 2, channels) + + self.final_conv = nn.Sequential( + nn.Conv1d(channels, 1, kernel_size=7, padding=3), + nn.Tanh() ) - self.rir_blocks = nn.Sequential( - *[ResidualInResidualBlock(channels) for _ in range(num_rirb)] 
- ) - - self.final_layer = nn.Sequential( - nn.Conv1d(channels, 1, kernel_size=3, padding=1), nn.Tanh() - ) def forward(self, x): residual_input = x - x = self.conv1(x) - x_rirb_out = self.rir_blocks(x) - learned_residual = self.final_layer(x_rirb_out) - output = residual_input + self.alpha * learned_residual + x1 = self.first_conv(x) - return torch.tanh(output) + x2 = self.downsample(x1) + x2 = self.downsample_attn(x2) + + x3 = self.downsample_2(x2) + x3 = self.downsample_2_attn(x3) + + x_rirb = self.rirb(x3) + + up1 = self.upsample(x_rirb) + up1 = self.upsample_attn(up1) + + cat1 = torch.cat((up1, x2), dim=1) + comp1 = self.compress_1(cat1) + + up2 = self.upsample_2(comp1) + up2 = self.upsample_2_attn(up2) + + cat2 = torch.cat((up2, x1), dim=1) + comp2 = self.compress_2(cat2) + + learned_residual = self.final_conv(comp2) + + output = residual_input + learned_residual + return output diff --git a/training.py b/training.py index 0e0caf8..6f962f3 100644 --- a/training.py +++ b/training.py @@ -1,4 +1,5 @@ import argparse +import datetime import os import torch @@ -52,7 +53,7 @@ accelerator.print("โœ… | Compiling done!") # Dataset / DataLoader # --------------------------- accelerator.print("๐Ÿ“Š | Fetching dataset...") -dataset = AudioDataset("./dataset") +dataset = AudioDataset("./dataset", 8192) sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None pin_memory = torch.cuda.is_available() and not args.no_pin_memory @@ -93,7 +94,6 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer_d, mode="min", factor=0.5, patience=5 ) -criterion_g = nn.BCEWithLogitsLoss() criterion_d = nn.MSELoss() # --------------------------- @@ -143,12 +143,8 @@ if args.resume: start_epoch = ckpt.get("epoch", 1) accelerator.print(f"๐Ÿ” | Resumed from epoch {start_epoch}!") -real_buf = torch.full( - (loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32 -) -fake_buf = torch.zeros( - (loader_batch_size, 1), device=accelerator.device, dtype=torch.float32 -) +real_buf = torch.full((loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32) +fake_buf = torch.zeros((loader_batch_size, 1), device=accelerator.device, dtype=torch.float32) accelerator.print("๐Ÿ‹๏ธ | Started training...") @@ -157,35 +153,45 @@ try: generator.train() discriminator.train() + discriminator_time = 0 + generator_time = 0 + running_d, running_g, steps = 0.0, 0.0, 0 + progress_bar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch} | D {discriminator_time}ฮผs | G {generator_time}ฮผs") + for i, ( (high_quality, low_quality), (high_sample_rate, low_sample_rate), - ) in enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")): + ) in enumerate(progress_bar): batch_size = high_quality.size(0) real_labels = real_buf[:batch_size].to(accelerator.device) fake_labels = fake_buf[:batch_size].to(accelerator.device) + with accelerator.autocast(): + generator_output = generator(low_quality) + # --- Discriminator --- + d_time = datetime.datetime.now() optimizer_d.zero_grad(set_to_none=True) with accelerator.autocast(): d_loss = discriminator_train( high_quality, - low_quality, + low_quality.detach(), real_labels, fake_labels, discriminator, - generator, criterion_d, + generator_output.detach() ) accelerator.backward(d_loss) - torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1) optimizer_d.step() + discriminator_time = (datetime.datetime.now() - d_time).microseconds # --- Generator --- + g_time = datetime.datetime.now() optimizer_g.zero_grad(set_to_none=True) with 
accelerator.autocast(): g_total, g_adv = generator_train( @@ -195,11 +201,13 @@ try: generator, discriminator, criterion_d, + generator_output ) accelerator.backward(g_total) torch.nn.utils.clip_grad_norm_(generator.parameters(), 1) optimizer_g.step() + generator_time = (datetime.datetime.now() - g_time).microseconds d_val = accelerator.gather(d_loss.detach()).mean() g_val = accelerator.gather(g_total.detach()).mean() @@ -219,6 +227,7 @@ try: ) steps += 1 + progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}ฮผs | G {generator_time}ฮผs") # epoch averages & schedulers if steps == 0: diff --git a/utils/MultiResolutionSTFTLoss.py b/utils/MultiResolutionSTFTLoss.py index eab6355..560191a 100644 --- a/utils/MultiResolutionSTFTLoss.py +++ b/utils/MultiResolutionSTFTLoss.py @@ -7,18 +7,13 @@ import torchaudio.transforms as T class MultiResolutionSTFTLoss(nn.Module): - """ - Multi-resolution STFT loss. - Combines spectral convergence loss and log-magnitude loss - across multiple STFT resolutions. - """ - def __init__( self, - fft_sizes: List[int] = [1024, 2048, 512], - hop_sizes: List[int] = [120, 240, 50], - win_lengths: List[int] = [600, 1200, 240], + fft_sizes: List[int] = [512, 1024, 2048, 4096, 8192], + hop_sizes: List[int] = [64, 128, 256, 512, 1024], + win_lengths: List[int] = [256, 512, 1024, 2048, 4096], eps: float = 1e-7, + center: bool = True ): super().__init__() @@ -26,15 +21,14 @@ class MultiResolutionSTFTLoss(nn.Module): self.n_resolutions = len(fft_sizes) self.stft_transforms = nn.ModuleList() - for n_fft, hop_len, win_len in zip(fft_sizes, hop_sizes, win_lengths): - window = torch.hann_window(win_len) + for i, (n_fft, hop_len, win_len) in enumerate(zip(fft_sizes, hop_sizes, win_lengths)): stft = T.Spectrogram( n_fft=n_fft, hop_length=hop_len, win_length=win_len, - window_fn=lambda _: window, - power=None, # Keep complex output - center=True, + window_fn=torch.hann_window, + power=None, + center=center, pad_mode="reflect", normalized=False, ) @@ -43,12 +37,6 @@ class MultiResolutionSTFTLoss(nn.Module): def forward( self, y_true: torch.Tensor, y_pred: torch.Tensor ) -> Dict[str, torch.Tensor]: - """ - Args: - y_true: (B, T) or (B, 1, T) waveform - y_pred: (B, T) or (B, 1, T) waveform - """ - # Ensure correct shape (B, T) if y_true.dim() == 3 and y_true.size(1) == 1: y_true = y_true.squeeze(1) if y_pred.dim() == 3 and y_pred.size(1) == 1: @@ -58,28 +46,21 @@ class MultiResolutionSTFTLoss(nn.Module): mag_loss = 0.0 for stft in self.stft_transforms: - stft = stft.to(y_pred.device) - - # Complex STFTs: (B, F, T, 2) + stft.window = stft.window.to(y_true.device) stft_true = stft(y_true) stft_pred = stft(y_pred) - # Magnitudes stft_mag_true = torch.abs(stft_true) stft_mag_pred = torch.abs(stft_pred) - # --- Spectral Convergence Loss --- norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1)) norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1)) sc_loss += torch.mean(norm_diff / (norm_true + self.eps)) - # --- Log STFT Magnitude Loss --- - mag_loss += F.l1_loss( - torch.log(stft_mag_pred + self.eps), - torch.log(stft_mag_true + self.eps), - ) + log_mag_pred = torch.log(stft_mag_pred + self.eps) + log_mag_true = torch.log(stft_mag_true + self.eps) + mag_loss += F.l1_loss(log_mag_pred, log_mag_true) - # Average across resolutions sc_loss /= self.n_resolutions mag_loss /= self.n_resolutions total_loss = sc_loss + mag_loss diff --git a/utils/TrainingTools.py b/utils/TrainingTools.py index d1959a0..581a890 100644 --- a/utils/TrainingTools.py +++ 
b/utils/TrainingTools.py @@ -1,12 +1,17 @@ import torch -# In case if needed again... -# from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss -# -# stft_loss_fn = MultiResolutionSTFTLoss( -# fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240] -# ) +from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss +# stft_loss_fn = MultiResolutionSTFTLoss( +# fft_sizes=[512, 1024, 2048, 4096], +# hop_sizes=[128, 256, 512, 1024], +# win_lengths=[512, 1024, 2048, 4096] +# ) +stft_loss_fn = MultiResolutionSTFTLoss( + fft_sizes=[512, 1024, 2048], + hop_sizes=[64, 128, 256], + win_lengths=[256, 512, 1024] +) def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor: absolute_difference = torch.abs(input_one - input_two) @@ -19,42 +24,35 @@ def discriminator_train( high_labels, low_labels, discriminator, - generator, criterion, + generator_output ): - decision_high = discriminator(high_quality) - d_loss_high = criterion(decision_high, high_labels) - # print(f"Is this real?: {discriminator_decision_from_real} | {d_loss_real}") - decision_low = discriminator(low_quality) - d_loss_low = criterion(decision_low, low_labels) - # print(f"Is this real?: {discriminator_decision_from_fake} | {d_loss_fake}") + real_pair = torch.cat((low_quality, high_quality), dim=1) + decision_real = discriminator(real_pair) + d_loss_real = criterion(decision_real, high_labels) - with torch.no_grad(): - generator_quality = generator(low_quality) - decision_gen = discriminator(generator_quality) - d_loss_gen = criterion(decision_gen, low_labels) - - noise = torch.rand_like(high_quality) * 0.08 - decision_noise = discriminator(high_quality + noise) - d_loss_noise = criterion(decision_noise, low_labels) - - d_loss = (d_loss_high + d_loss_low + d_loss_gen + d_loss_noise) / 4.0 + fake_pair = torch.cat((low_quality, generator_output), dim=1) + decision_fake = discriminator(fake_pair) + d_loss_fake = criterion(decision_fake, low_labels) + d_loss = (d_loss_real + d_loss_fake) / 2.0 return d_loss def generator_train( - low_quality, high_quality, real_labels, generator, discriminator, adv_criterion -): - generator_output = generator(low_quality) + low_quality, high_quality, real_labels, generator, discriminator, adv_criterion, generator_output): - discriminator_decision = discriminator(generator_output) + fake_pair = torch.cat((low_quality, generator_output), dim=1) + + discriminator_decision = discriminator(fake_pair) adversarial_loss = adv_criterion(discriminator_decision, real_labels) - # Signal similarity - similarity_loss = signal_mae(generator_output, high_quality) - - combined_loss = adversarial_loss + (similarity_loss * 100) + mae_loss = signal_mae(generator_output, high_quality) + stft_loss = stft_loss_fn(high_quality, generator_output)["total"] + lambda_mae = 10.0 + lambda_stft = 2.5 + lambda_adv = 2.5 + combined_loss = (lambda_mae * mae_loss) + (lambda_stft * stft_loss) + (lambda_adv * adversarial_loss) return combined_loss, adversarial_loss From bf0a6e58e96616e69003e3b223e56bd95a4bc370 Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Thu, 4 Dec 2025 14:22:48 +0200 Subject: [PATCH 10/11] :alembic: | Added MultiPeriodDiscriminator implementation from HiFi-GAN --- discriminator.py | 142 ++++++++++++++++++++++++----------------- generator.py | 40 ++++++------ training.py | 44 ++++++------- utils/TrainingTools.py | 115 ++++++++++++++++++++++++--------- 4 files changed, 210 insertions(+), 131 deletions(-) diff --git a/discriminator.py b/discriminator.py index 
69de5ce..8e9e2ec 100644 --- a/discriminator.py +++ b/discriminator.py @@ -1,70 +1,98 @@ +import torch import torch.nn as nn import torch.nn.utils as utils +import numpy as np +class PatchEmbedding(nn.Module): + """ + Converts raw audio into a sequence of embeddings (tokens). + Small patch_size = Higher Precision (more tokens, finer detail). + Large patch_size = Lower Precision (fewer tokens, more global). + """ + def __init__(self, in_channels, embed_dim, patch_size, spectral_norm=True): + super().__init__() + # We use a Conv1d with stride=patch_size to create non-overlapping patches + self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) -def discriminator_block( - in_channels, - out_channels, - kernel_size=15, - stride=1, - dilation=1 -): - padding = dilation * (kernel_size - 1) // 2 - - conv_layer = nn.Conv1d( - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - padding=padding - ) - - conv_layer = utils.spectral_norm(conv_layer) - leaky_relu = nn.LeakyReLU(0.2) - - return nn.Sequential(conv_layer, leaky_relu) - - -class AttentionBlock(nn.Module): - def __init__(self, channels): - super(AttentionBlock, self).__init__() - self.attention = nn.Sequential( - nn.Conv1d(channels, channels // 4, kernel_size=1), - nn.ReLU(), - nn.Conv1d(channels // 4, channels, kernel_size=1), - nn.Sigmoid(), - ) + if spectral_norm: + self.proj = utils.spectral_norm(self.proj) def forward(self, x): - attention_weights = self.attention(x) - return x + (x * attention_weights) + # x shape: (batch, 1, 8000) + x = self.proj(x) # shape: (batch, embed_dim, num_patches) + x = x.transpose(1, 2) # shape: (batch, num_patches, embed_dim) + return x +class TransformerDiscriminator(nn.Module): + def __init__( + self, + audio_length=8000, + patch_size=16, # Lower this for higher precision (e.g., 8 or 16) + embed_dim=128, # Dimension of the transformer tokens + depth=4, # Number of Transformer blocks + heads=4, # Number of attention heads + mlp_dim=256, # Hidden dimension of the feed-forward layer + spectral_norm=True + ): + super().__init__() -class SISUDiscriminator(nn.Module): - def __init__(self, layers=8): - super(SISUDiscriminator, self).__init__() - self.discriminator_blocks = nn.Sequential( - # 1 -> 32 - discriminator_block(2, layers), - AttentionBlock(layers), - # 32 -> 64 - discriminator_block(layers, layers * 2, dilation=2), - # 64 -> 128 - discriminator_block(layers * 2, layers * 4, dilation=4), - AttentionBlock(layers * 4), - # 128 -> 256 - discriminator_block(layers * 4, layers * 8, stride=4), - # 256 -> 512 - # discriminator_block(layers * 8, layers * 16, stride=4) + # 1. Calculate sequence length + self.num_patches = audio_length // patch_size + + # 2. Patch Embedding (Tokenizer) + self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm) + + # 3. Class Token (like in BERT/ViT) to aggregate global info + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + + # 4. Positional Embedding (Learnable) + self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim)) + + # 5. Transformer Encoder + encoder_layer = nn.TransformerEncoderLayer( + d_model=embed_dim, + nhead=heads, + dim_feedforward=mlp_dim, + dropout=0.1, + activation='gelu', + batch_first=True ) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth) - self.final_conv = nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1) + # 6. 
Final Classification Head + self.norm = nn.LayerNorm(embed_dim) + self.head = nn.Linear(embed_dim, 1) - self.avg_pool = nn.AdaptiveAvgPool1d(1) + if spectral_norm: + self.head = utils.spectral_norm(self.head) + + # Initialize weights + self._init_weights() + + def _init_weights(self): + nn.init.normal_(self.cls_token, std=0.02) + nn.init.normal_(self.pos_embed, std=0.02) def forward(self, x): - x = self.discriminator_blocks(x) - x = self.final_conv(x) - x = self.avg_pool(x) - return x.squeeze(2) + b, c, t = x.shape + + # --- 1. Tokenize Audio --- + x = self.patch_embed(x) # (Batch, Num_Patches, Embed_Dim) + + # --- 2. Add CLS Token --- + cls_tokens = self.cls_token.expand(b, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) # (Batch, Num_Patches + 1, Embed_Dim) + + # --- 3. Add Positional Embeddings --- + x = x + self.pos_embed + + # --- 4. Transformer Layers --- + x = self.transformer(x) + + # --- 5. Classification (Use only CLS token) --- + cls_output = x[:, 0] # Take the first token + cls_output = self.norm(cls_output) + + score = self.head(cls_output) # (Batch, 1) + + return score diff --git a/generator.py b/generator.py index bc994ac..15279b1 100644 --- a/generator.py +++ b/generator.py @@ -1,19 +1,20 @@ import torch import torch.nn as nn - +from torch.nn.utils.parametrizations import weight_norm def GeneratorBlock(in_channels, out_channels, kernel_size=3, stride=1, dilation=1): padding = (kernel_size - 1) // 2 * dilation + return nn.Sequential( - nn.Conv1d( + + weight_norm(nn.Conv1d( in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding - ), - nn.InstanceNorm1d(out_channels), + )), nn.PReLU(num_parameters=1, init=0.1), ) @@ -22,9 +23,9 @@ class AttentionBlock(nn.Module): def __init__(self, channels): super(AttentionBlock, self).__init__() self.attention = nn.Sequential( - nn.Conv1d(channels, channels // 4, kernel_size=1), + weight_norm(nn.Conv1d(channels, channels // 4, kernel_size=1)), nn.ReLU(inplace=True), - nn.Conv1d(channels // 4, channels, kernel_size=1), + weight_norm(nn.Conv1d(channels // 4, channels, kernel_size=1)), nn.Sigmoid(), ) @@ -49,21 +50,21 @@ class ResidualInResidualBlock(nn.Module): x = self.attention(x) return x + residual -def UpsampleBlock(in_channels, out_channels): +def UpsampleBlock(in_channels, out_channels, scale_factor=2): return nn.Sequential( - nn.ConvTranspose1d( + nn.Upsample(scale_factor=scale_factor, mode='nearest'), + weight_norm(nn.Conv1d( in_channels=in_channels, out_channels=out_channels, - kernel_size=4, - stride=2, + kernel_size=3, + stride=1, padding=1 - ), - nn.InstanceNorm1d(out_channels), + )), nn.PReLU(num_parameters=1, init=0.1) ) class SISUGenerator(nn.Module): - def __init__(self, channels=32, num_rirb=1): + def __init__(self, channels=32, num_rirb=4): super(SISUGenerator, self).__init__() self.first_conv = GeneratorBlock(1, channels) @@ -73,10 +74,9 @@ class SISUGenerator(nn.Module): self.downsample_2 = GeneratorBlock(channels * 2, channels * 4, stride=2) self.downsample_2_attn = AttentionBlock(channels * 4) - self.rirb = ResidualInResidualBlock(channels * 4) - # self.rirb = nn.Sequential( - # *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)] - # ) + self.rirb = nn.Sequential( + *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)] + ) self.upsample = UpsampleBlock(channels * 4, channels * 2) self.upsample_attn = AttentionBlock(channels * 2) @@ -87,13 +87,15 @@ class SISUGenerator(nn.Module): self.compress_2 = GeneratorBlock(channels * 2, channels) 
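        # Note on channel sizes in this decoder: compress_1 takes channels * 4 inputs
        # because the upsampled path (channels * 2) is concatenated with the skip
        # tensor x2 (channels * 2); compress_2 likewise fuses up2 (channels) with the
        # skip tensor x1 (channels) back down to `channels` before the final
        # 1-channel projection below.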
self.final_conv = nn.Sequential( - nn.Conv1d(channels, 1, kernel_size=7, padding=3), + weight_norm(nn.Conv1d(channels, 1, kernel_size=7, padding=3)), nn.Tanh() ) def forward(self, x): residual_input = x + + # Encoding x1 = self.first_conv(x) x2 = self.downsample(x1) @@ -102,8 +104,10 @@ class SISUGenerator(nn.Module): x3 = self.downsample_2(x2) x3 = self.downsample_2_attn(x3) + # Bottleneck (Deep Residual processing) x_rirb = self.rirb(x3) + # Decoding with Skip Connections up1 = self.upsample(x_rirb) up1 = self.upsample_attn(up1) diff --git a/training.py b/training.py index 6f962f3..35733be 100644 --- a/training.py +++ b/training.py @@ -3,7 +3,6 @@ import datetime import os import torch -import torch.nn as nn import torch.optim as optim import tqdm from accelerate import Accelerator @@ -23,7 +22,7 @@ parser.add_argument( "--epochs", type=int, default=5000, help="Number of training epochs" ) parser.add_argument("--batch_size", type=int, default=8, help="Batch size") -parser.add_argument("--num_workers", type=int, default=2, help="DataLoader num_workers") +parser.add_argument("--num_workers", type=int, default=4, help="DataLoader num_workers") # Increased workers slightly parser.add_argument("--debug", action="store_true", help="Print debug logs") parser.add_argument( "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA" @@ -94,8 +93,6 @@ scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer_d, mode="min", factor=0.5, patience=5 ) -criterion_d = nn.MSELoss() - # --------------------------- # Prepare accelerator # --------------------------- @@ -131,23 +128,25 @@ def save_ckpt(path, epoch): start_epoch = 0 if args.resume: ckpt_path = os.path.join(models_dir, "last.pt") - ckpt = torch.load(ckpt_path) + if os.path.exists(ckpt_path): + ckpt = torch.load(ckpt_path) - accelerator.unwrap_model(generator).load_state_dict(ckpt["G"]) - accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"]) - optimizer_g.load_state_dict(ckpt["optG"]) - optimizer_d.load_state_dict(ckpt["optD"]) - scheduler_g.load_state_dict(ckpt["schedG"]) - scheduler_d.load_state_dict(ckpt["schedD"]) + accelerator.unwrap_model(generator).load_state_dict(ckpt["G"]) + accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"]) + optimizer_g.load_state_dict(ckpt["optG"]) + optimizer_d.load_state_dict(ckpt["optD"]) + scheduler_g.load_state_dict(ckpt["schedG"]) + scheduler_d.load_state_dict(ckpt["schedD"]) - start_epoch = ckpt.get("epoch", 1) - accelerator.print(f"๐Ÿ” | Resumed from epoch {start_epoch}!") - -real_buf = torch.full((loader_batch_size, 1), 1, device=accelerator.device, dtype=torch.float32) -fake_buf = torch.zeros((loader_batch_size, 1), device=accelerator.device, dtype=torch.float32) + start_epoch = ckpt.get("epoch", 1) + accelerator.print(f"๐Ÿ” | Resumed from epoch {start_epoch}!") + else: + accelerator.print("โš ๏ธ | Resume requested but no checkpoint found. 
Starting fresh.") accelerator.print("๐Ÿ‹๏ธ | Started training...") +smallest_loss = float('inf') + try: for epoch in range(start_epoch, args.epochs): generator.train() @@ -164,11 +163,6 @@ try: (high_quality, low_quality), (high_sample_rate, low_sample_rate), ) in enumerate(progress_bar): - batch_size = high_quality.size(0) - - real_labels = real_buf[:batch_size].to(accelerator.device) - fake_labels = fake_buf[:batch_size].to(accelerator.device) - with accelerator.autocast(): generator_output = generator(low_quality) @@ -179,10 +173,7 @@ try: d_loss = discriminator_train( high_quality, low_quality.detach(), - real_labels, - fake_labels, discriminator, - criterion_d, generator_output.detach() ) @@ -197,10 +188,8 @@ try: g_total, g_adv = generator_train( low_quality, high_quality, - real_labels, generator, discriminator, - criterion_d, generator_output ) @@ -241,6 +230,9 @@ try: scheduler_g.step(mean_g) save_ckpt(os.path.join(models_dir, "last.pt"), epoch) + if smallest_loss > mean_g: + smallest_loss = mean_g + save_ckpt(os.path.join(models_dir, "latest-smallest_loss.pt"), epoch) accelerator.print(f"๐Ÿค | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}") except Exception: diff --git a/utils/TrainingTools.py b/utils/TrainingTools.py index 581a890..7ed2d0f 100644 --- a/utils/TrainingTools.py +++ b/utils/TrainingTools.py @@ -1,58 +1,113 @@ import torch - +import torch.nn.functional as F from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss -# stft_loss_fn = MultiResolutionSTFTLoss( -# fft_sizes=[512, 1024, 2048, 4096], -# hop_sizes=[128, 256, 512, 1024], -# win_lengths=[512, 1024, 2048, 4096] -# ) + stft_loss_fn = MultiResolutionSTFTLoss( fft_sizes=[512, 1024, 2048], hop_sizes=[64, 128, 256], win_lengths=[256, 512, 1024] ) -def signal_mae(input_one: torch.Tensor, input_two: torch.Tensor) -> torch.Tensor: - absolute_difference = torch.abs(input_one - input_two) - return torch.mean(absolute_difference) +def feature_matching_loss(fmap_r, fmap_g): + """ + Computes L1 distance between real and fake feature maps. + """ + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + # Stop gradient on real features to save memory/computation + rl = rl.detach() + loss += torch.mean(torch.abs(rl - gl)) + + # Scale by number of feature maps to keep loss magnitude reasonable + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + """ + Least Squares GAN Loss (LSGAN) for the Discriminator. + Objective: Real -> 1, Fake -> 0 + """ + loss = 0 + r_losses = [] + g_losses = [] + + # Iterate over both MPD and MSD outputs + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + # Real should be 1.0 + r_loss = torch.mean((dr - 1) ** 2) + # Fake should be 0.0 + g_loss = torch.mean(dg ** 2) + + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_adv_loss(disc_generated_outputs): + """ + Least Squares GAN Loss for the Generator. + Objective: Fake -> 1 (Fool the discriminator) + """ + loss = 0 + for dg in zip(disc_generated_outputs): + dg = dg[0] # Unpack tuple + loss += torch.mean((dg - 1) ** 2) + return loss def discriminator_train( high_quality, low_quality, - high_labels, - low_labels, discriminator, - criterion, generator_output ): + # 1. 
Forward pass through the Ensemble Discriminator + # Note: We pass inputs separately now: (Real_Target, Fake_Candidate) + # We detach generator_output because we are only optimizing D here + y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach()) - real_pair = torch.cat((low_quality, high_quality), dim=1) - decision_real = discriminator(real_pair) - d_loss_real = criterion(decision_real, high_labels) + # 2. Calculate Loss (LSGAN) + d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs) - fake_pair = torch.cat((low_quality, generator_output), dim=1) - decision_fake = discriminator(fake_pair) - d_loss_fake = criterion(decision_fake, low_labels) - - d_loss = (d_loss_real + d_loss_fake) / 2.0 return d_loss def generator_train( - low_quality, high_quality, real_labels, generator, discriminator, adv_criterion, generator_output): + low_quality, + high_quality, + generator, + discriminator, + generator_output +): + # 1. Forward pass through Discriminator + # We do NOT detach generator_output here, we need gradients for G + y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output) - fake_pair = torch.cat((low_quality, generator_output), dim=1) + # 2. Adversarial Loss (Try to fool D into thinking G is Real) + loss_gen_adv = generator_adv_loss(y_d_gs) - discriminator_decision = discriminator(fake_pair) - adversarial_loss = adv_criterion(discriminator_decision, real_labels) + # 3. Feature Matching Loss (Force G to match internal features of D) + loss_fm = feature_matching_loss(fmap_rs, fmap_gs) - mae_loss = signal_mae(generator_output, high_quality) + # 4. Mel-Spectrogram / STFT Loss (Audio Quality) stft_loss = stft_loss_fn(high_quality, generator_output)["total"] - lambda_mae = 10.0 - lambda_stft = 2.5 - lambda_adv = 2.5 - combined_loss = (lambda_mae * mae_loss) + (lambda_stft * stft_loss) + (lambda_adv * adversarial_loss) - return combined_loss, adversarial_loss + # ----------------------------------------- + # 5. Combine Losses + # ----------------------------------------- + # Recommended weights for HiFi-GAN/EnCodec style architectures: + # STFT is dominant (45), FM provides stability (2), Adv provides texture (1) + + lambda_stft = 45.0 + lambda_fm = 2.0 + lambda_adv = 1.0 + + combined_loss = (lambda_stft * stft_loss) + \ + (lambda_fm * loss_fm) + \ + (lambda_adv * loss_gen_adv) + + return combined_loss, loss_gen_adv From e3e555794e2906d474e48c2ada5bde6d3cdf0bfa Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Sat, 6 Dec 2025 18:04:18 +0200 Subject: [PATCH 11/11] :alembic: | Added MultiPeriodDiscriminator implementation from HiFi-GAN --- discriminator.py | 247 +++++++++++++++++++++++++++-------------- training.py | 41 +++---- utils/TrainingTools.py | 24 +--- 3 files changed, 187 insertions(+), 125 deletions(-) diff --git a/discriminator.py b/discriminator.py index 8e9e2ec..0572d17 100644 --- a/discriminator.py +++ b/discriminator.py @@ -1,98 +1,179 @@ import torch import torch.nn as nn -import torch.nn.utils as utils -import numpy as np +import torch.nn.functional as F +from torch.nn.utils.parametrizations import weight_norm, spectral_norm -class PatchEmbedding(nn.Module): - """ - Converts raw audio into a sequence of embeddings (tokens). - Small patch_size = Higher Precision (more tokens, finer detail). - Large patch_size = Lower Precision (fewer tokens, more global). 
- """ - def __init__(self, in_channels, embed_dim, patch_size, spectral_norm=True): - super().__init__() - # We use a Conv1d with stride=patch_size to create non-overlapping patches - self.proj = nn.Conv1d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) +# ------------------------------------------------------------------- +# 1. Multi-Period Discriminator (MPD) +# Captures periodic structures (pitch/timbre) by folding audio. +# ------------------------------------------------------------------- - if spectral_norm: - self.proj = utils.spectral_norm(self.proj) +class DiscriminatorP(nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + + # Use spectral_norm for stability, or weight_norm for performance + norm_f = spectral_norm if use_spectral_norm else weight_norm + + # We use 2D convs because we "fold" the 1D audio into 2D (Period x Time) + self.convs = nn.ModuleList([ + norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(2, 0))), + norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(2, 0))), + norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(2, 0))), + norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(2, 0))), + norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + + self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) def forward(self, x): - # x shape: (batch, 1, 8000) - x = self.proj(x) # shape: (batch, embed_dim, num_patches) - x = x.transpose(1, 2) # shape: (batch, num_patches, embed_dim) - return x + fmap = [] -class TransformerDiscriminator(nn.Module): - def __init__( - self, - audio_length=8000, - patch_size=16, # Lower this for higher precision (e.g., 8 or 16) - embed_dim=128, # Dimension of the transformer tokens - depth=4, # Number of Transformer blocks - heads=4, # Number of attention heads - mlp_dim=256, # Hidden dimension of the feed-forward layer - spectral_norm=True - ): - super().__init__() - - # 1. Calculate sequence length - self.num_patches = audio_length // patch_size - - # 2. Patch Embedding (Tokenizer) - self.patch_embed = PatchEmbedding(1, embed_dim, patch_size, spectral_norm) - - # 3. Class Token (like in BERT/ViT) to aggregate global info - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - - # 4. Positional Embedding (Learnable) - self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim)) - - # 5. Transformer Encoder - encoder_layer = nn.TransformerEncoderLayer( - d_model=embed_dim, - nhead=heads, - dim_feedforward=mlp_dim, - dropout=0.1, - activation='gelu', - batch_first=True - ) - self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth) - - # 6. Final Classification Head - self.norm = nn.LayerNorm(embed_dim) - self.head = nn.Linear(embed_dim, 1) - - if spectral_norm: - self.head = utils.spectral_norm(self.head) - - # Initialize weights - self._init_weights() - - def _init_weights(self): - nn.init.normal_(self.cls_token, std=0.02) - nn.init.normal_(self.pos_embed, std=0.02) - - def forward(self, x): + # 1d to 2d conversion: [B, C, T] -> [B, C, T/P, P] b, c, t = x.shape + if t % self.period != 0: # Pad if not divisible by period + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad - # --- 1. 
Tokenize Audio --- - x = self.patch_embed(x) # (Batch, Num_Patches, Embed_Dim) + x = x.view(b, c, t // self.period, self.period) - # --- 2. Add CLS Token --- - cls_tokens = self.cls_token.expand(b, -1, -1) - x = torch.cat((cls_tokens, x), dim=1) # (Batch, Num_Patches + 1, Embed_Dim) + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, 0.1) + fmap.append(x) # Store feature map for Feature Matching Loss - # --- 3. Add Positional Embeddings --- - x = x + self.pos_embed + x = self.conv_post(x) + fmap.append(x) - # --- 4. Transformer Layers --- - x = self.transformer(x) + # Flatten back to 1D for score + x = torch.flatten(x, 1, -1) - # --- 5. Classification (Use only CLS token) --- - cls_output = x[:, 0] # Take the first token - cls_output = self.norm(cls_output) + return x, fmap - score = self.head(cls_output) # (Batch, 1) - return score +class MultiPeriodDiscriminator(nn.Module): + def __init__(self, periods=[2, 3, 5, 7, 11]): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(p) for p in periods + ]) + + def forward(self, y, y_hat): + y_d_rs = [] # Real scores + y_d_gs = [] # Generated (Fake) scores + fmap_rs = [] # Real feature maps + fmap_gs = [] # Generated (Fake) feature maps + + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +# ------------------------------------------------------------------- +# 2. Multi-Scale Discriminator (MSD) +# Captures structure at different audio resolutions (raw, x0.5, x0.25). +# ------------------------------------------------------------------- + +class DiscriminatorS(nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = spectral_norm if use_spectral_norm else weight_norm + + # Standard 1D Convolutions with large receptive field + self.convs = nn.ModuleList([ + norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)), + norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, 0.1) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + return x, fmap + + +class MultiScaleDiscriminator(nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + # 3 Scales: Original, Downsampled x2, Downsampled x4 + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + nn.AvgPool1d(4, 2, padding=2), + nn.AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for i, d in enumerate(self.discriminators): + if i != 0: + # Downsample input for subsequent discriminators + y = self.meanpools[i-1](y) + y_hat = self.meanpools[i-1](y_hat) + + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +# 
------------------------------------------------------------------- +# 3. Master Wrapper +# Combines MPD and MSD into one class to fit your training script. +# ------------------------------------------------------------------- + +class SISUDiscriminator(nn.Module): + def __init__(self): + super(SISUDiscriminator, self).__init__() + self.mpd = MultiPeriodDiscriminator() + self.msd = MultiScaleDiscriminator() + + def forward(self, y, y_hat): + # Return format: + # scores_real, scores_fake, features_real, features_fake + + # Run Multi-Period + mpd_y_d_rs, mpd_y_d_gs, mpd_fmap_rs, mpd_fmap_gs = self.mpd(y, y_hat) + + # Run Multi-Scale + msd_y_d_rs, msd_y_d_gs, msd_fmap_rs, msd_fmap_gs = self.msd(y, y_hat) + + # Combine all results + return ( + mpd_y_d_rs + msd_y_d_rs, # All real scores + mpd_y_d_gs + msd_y_d_gs, # All fake scores + mpd_fmap_rs + msd_fmap_rs, # All real feature maps + mpd_fmap_gs + msd_fmap_gs # All fake feature maps + ) diff --git a/training.py b/training.py index 35733be..8107b03 100644 --- a/training.py +++ b/training.py @@ -3,6 +3,7 @@ import datetime import os import torch +import torch.nn as nn import torch.optim as optim import tqdm from accelerate import Accelerator @@ -39,10 +40,13 @@ accelerator = Accelerator(mixed_precision="bf16") # Models # --------------------------- generator = SISUGenerator() +# Note: SISUDiscriminator is now an Ensemble (MPD + MSD) discriminator = SISUDiscriminator() accelerator.print("๐Ÿ”จ | Compiling models...") +# Torch compile is great, but if you hit errors with the new List/Tuple outputs +# of the discriminator, you might need to disable it for D. generator = torch.compile(generator) discriminator = torch.compile(discriminator) @@ -108,21 +112,24 @@ models_dir = "./models" os.makedirs(models_dir, exist_ok=True) -def save_ckpt(path, epoch): +def save_ckpt(path, epoch, loss=None, is_best=False): accelerator.wait_for_everyone() if accelerator.is_main_process: - accelerator.save( - { - "epoch": epoch, - "G": accelerator.unwrap_model(generator).state_dict(), - "D": accelerator.unwrap_model(discriminator).state_dict(), - "optG": optimizer_g.state_dict(), - "optD": optimizer_d.state_dict(), - "schedG": scheduler_g.state_dict(), - "schedD": scheduler_d.state_dict(), - }, - path, - ) + state = { + "epoch": epoch, + "G": accelerator.unwrap_model(generator).state_dict(), + "D": accelerator.unwrap_model(discriminator).state_dict(), + "optG": optimizer_g.state_dict(), + "optD": optimizer_d.state_dict(), + "schedG": scheduler_g.state_dict(), + "schedD": scheduler_d.state_dict() + } + + accelerator.save(state, os.path.join(models_dir, "last.pt")) + + if is_best: + accelerator.save(state, os.path.join(models_dir, "best.pt")) + accelerator.print(f"๐ŸŒŸ | New best model saved with G Loss: {loss:.4f}") start_epoch = 0 @@ -143,9 +150,8 @@ if args.resume: else: accelerator.print("โš ๏ธ | Resume requested but no checkpoint found. Starting fresh.") -accelerator.print("๐Ÿ‹๏ธ | Started training...") -smallest_loss = float('inf') +accelerator.print("๐Ÿ‹๏ธ | Started training...") try: for epoch in range(start_epoch, args.epochs): @@ -172,7 +178,6 @@ try: with accelerator.autocast(): d_loss = discriminator_train( high_quality, - low_quality.detach(), discriminator, generator_output.detach() ) @@ -218,7 +223,6 @@ try: steps += 1 progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}ฮผs | G {generator_time}ฮผs") - # epoch averages & schedulers if steps == 0: accelerator.print("๐Ÿชน | No steps in epoch (empty dataloader?). 
Exiting.") break @@ -230,9 +234,6 @@ try: scheduler_g.step(mean_g) save_ckpt(os.path.join(models_dir, "last.pt"), epoch) - if smallest_loss > mean_g: - smallest_loss = mean_g - save_ckpt(os.path.join(models_dir, "latest-smallest_loss.pt"), epoch) accelerator.print(f"๐Ÿค | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}") except Exception: diff --git a/utils/TrainingTools.py b/utils/TrainingTools.py index 7ed2d0f..cd0350a 100644 --- a/utils/TrainingTools.py +++ b/utils/TrainingTools.py @@ -2,13 +2,14 @@ import torch import torch.nn.functional as F from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss - +# Keep STFT settings as is stft_loss_fn = MultiResolutionSTFTLoss( fft_sizes=[512, 1024, 2048], hop_sizes=[64, 128, 256], win_lengths=[256, 512, 1024] ) + def feature_matching_loss(fmap_r, fmap_g): """ Computes L1 distance between real and fake feature maps. @@ -16,11 +17,9 @@ def feature_matching_loss(fmap_r, fmap_g): loss = 0 for dr, dg in zip(fmap_r, fmap_g): for rl, gl in zip(dr, dg): - # Stop gradient on real features to save memory/computation rl = rl.detach() loss += torch.mean(torch.abs(rl - gl)) - # Scale by number of feature maps to keep loss magnitude reasonable return loss * 2 @@ -33,11 +32,8 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs): r_losses = [] g_losses = [] - # Iterate over both MPD and MSD outputs for dr, dg in zip(disc_real_outputs, disc_generated_outputs): - # Real should be 1.0 r_loss = torch.mean((dr - 1) ** 2) - # Fake should be 0.0 g_loss = torch.mean(dg ** 2) loss += (r_loss + g_loss) @@ -61,16 +57,11 @@ def generator_adv_loss(disc_generated_outputs): def discriminator_train( high_quality, - low_quality, discriminator, generator_output ): - # 1. Forward pass through the Ensemble Discriminator - # Note: We pass inputs separately now: (Real_Target, Fake_Candidate) - # We detach generator_output because we are only optimizing D here y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach()) - # 2. Calculate Loss (LSGAN) d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs) return d_loss @@ -83,25 +74,14 @@ def generator_train( discriminator, generator_output ): - # 1. Forward pass through Discriminator - # We do NOT detach generator_output here, we need gradients for G y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output) - # 2. Adversarial Loss (Try to fool D into thinking G is Real) loss_gen_adv = generator_adv_loss(y_d_gs) - # 3. Feature Matching Loss (Force G to match internal features of D) loss_fm = feature_matching_loss(fmap_rs, fmap_gs) - # 4. Mel-Spectrogram / STFT Loss (Audio Quality) stft_loss = stft_loss_fn(high_quality, generator_output)["total"] - # ----------------------------------------- - # 5. Combine Losses - # ----------------------------------------- - # Recommended weights for HiFi-GAN/EnCodec style architectures: - # STFT is dominant (45), FM provides stability (2), Adv provides texture (1) - lambda_stft = 45.0 lambda_fm = 2.0 lambda_adv = 1.0