diff --git a/discriminator.py b/discriminator.py
index 777abf2..dfd0126 100644
--- a/discriminator.py
+++ b/discriminator.py
@@ -39,7 +39,7 @@ class AttentionBlock(nn.Module):
         return x * attention_weights
 
 class SISUDiscriminator(nn.Module):
-    def __init__(self, base_channels=64):
+    def __init__(self, base_channels=16):
         super(SISUDiscriminator, self).__init__()
         layers = base_channels
         self.model = nn.Sequential(
diff --git a/generator.py b/generator.py
index cd4d48c..a53feb7 100644
--- a/generator.py
+++ b/generator.py
@@ -48,7 +48,7 @@ class ResidualInResidualBlock(nn.Module):
         return x + residual
 
 class SISUGenerator(nn.Module):
-    def __init__(self, channels=64, num_rirb=8, alpha=1.0):
+    def __init__(self, channels=16, num_rirb=4, alpha=1.0):
         super(SISUGenerator, self).__init__()
 
         self.alpha = alpha
diff --git a/training.py b/training.py
index db7cb86..01ea749 100644
--- a/training.py
+++ b/training.py
@@ -34,7 +34,7 @@ parser.add_argument("--discriminator", type=str, default=None,
 parser.add_argument("--device", type=str, default="cpu", help="Select device")
 parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
 parser.add_argument("--debug", action="store_true", help="Print debug logs")
-parser.add_argument("--continue_training", type=bool, default=False, help="Continue training using temp_generator and temp_discriminator models")
+parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")
 
 args = parser.parse_args()
 
@@ -60,6 +60,10 @@ mel_transform = T.MelSpectrogram(
     win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
 ).to(device)
 
+stft_transform = T.Spectrogram(
+    n_fft=n_fft, win_length=win_length, hop_length=hop_length
+).to(device)
+
 debug = args.debug
 
 # Initialize dataset and dataloader
@@ -72,7 +76,7 @@ os.makedirs(audio_output_dir, exist_ok=True)
 
 # ========= SINGLE =========
 
-train_data_loader = DataLoader(dataset, batch_size=12, shuffle=True)
+train_data_loader = DataLoader(dataset, batch_size=64, shuffle=True)
 
 # ========= MODELS =========
 
@@ -143,7 +147,7 @@ def start_training():
 
             # ========= GENERATOR =========
             generator.train()
-            generator_output, combined_loss, adversarial_loss, mel_l1_tensor = generator_train(
+            generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train(
                 low_quality_sample,
                 high_quality_sample,
                 real_labels,
@@ -152,11 +156,13 @@
                 criterion_d,
                 optimizer_g,
                 device,
-                mel_transform
+                mel_transform,
+                stft_transform,
+                mfcc_transform
             )
 
             if debug:
-                print(combined_loss, adversarial_loss, mel_l1_tensor)
+                print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}")
             scheduler_d.step(d_loss.detach())
             scheduler_g.step(adversarial_loss.detach())
 
@@ -173,9 +179,9 @@
             torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu().detach(), ai_enhanced_audio[1])
             torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu().detach(), high_quality_audio[1])
 
-        if debug:
-            print(generator.state_dict().keys())
-            print(discriminator.state_dict().keys())
+        #if debug:
+        #    print(generator.state_dict().keys())
+        #    print(discriminator.state_dict().keys())
         torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
         torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
         Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})
diff --git a/training_utils.py b/training_utils.py
index be402d9..6f26f58 100644
--- a/training_utils.py
+++ b/training_utils.py
@@ -37,7 +37,6 @@ def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tenso
     mel_spec_true = mel_spec_true[..., :min_len]
     mel_spec_pred = mel_spec_pred[..., :min_len]
 
-    # L2 Loss (Mean Squared Error)
     loss = torch.mean((mel_spec_true - mel_spec_pred)**2)
     return loss
 
@@ -49,7 +48,6 @@ def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor,
     stft_mag_true = stft_mag_true[..., :min_len]
     stft_mag_pred = stft_mag_pred[..., :min_len]
 
-    # Log Magnitude L1 Loss
     loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps)))
     return loss
 
@@ -61,12 +59,9 @@ def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tenso
     stft_mag_true = stft_mag_true[..., :min_len]
     stft_mag_pred = stft_mag_pred[..., :min_len]
 
-    # Calculate Frobenius norms and the loss
-    # Ensure norms are calculated over frequency and time dims ([..., freq, time])
     norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1))
     norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1))
 
-    # Average loss over the batch
     loss = torch.mean(norm_diff / (norm_true + eps))
     return loss
 
@@ -77,16 +72,13 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels, dis
     discriminator_decision_from_real = discriminator(high_quality[0])
     d_loss_real = criterion(discriminator_decision_from_real, real_labels)
 
-    # Forward pass for fake samples (from generator output)
-    with torch.no_grad(): # Detach generator output within no_grad context
+    with torch.no_grad():
         generator_output = generator(low_quality[0])
-    discriminator_decision_from_fake = discriminator(generator_output) # No need to detach again if inside no_grad
+    discriminator_decision_from_fake = discriminator(generator_output)
     d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake))
 
-    # Combine real and fake losses
     d_loss = (d_loss_real + d_loss_fake) / 2.0
 
-    # Backward pass and optimization
     d_loss.backward()
     # Optional: Gradient Clipping (can be helpful)
     # nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
@@ -100,65 +92,53 @@ def generator_train(
     real_labels,
     generator,
     discriminator,
-    adv_criterion, # Criterion for adversarial loss (e.g., BCEWithLogitsLoss)
+    adv_criterion,
     g_optimizer,
     device,
-    # --- Pass necessary transforms and loss weights ---
-    mel_transform: T.MelSpectrogram, # Example: Pass Mel transform
-    # stft_transform: T.Spectrogram, # Pass STFT transform if using STFT losses
-    # mfcc_transform: T.MFCC, # Pass MFCC transform if using MFCC loss
-    lambda_adv: float = 1.0, # Weight for adversarial loss
-    lambda_mel_l1: float = 10.0, # Example: Weight for Mel L1 loss
-    # lambda_log_stft: float = 0.0, # Set weights > 0 for losses you want to use
-    # lambda_mfcc: float = 0.0
+    mel_transform: T.MelSpectrogram,
+    stft_transform: T.Spectrogram,
+    mfcc_transform: T.MFCC,
+    lambda_adv: float = 1.0,
+    lambda_mel_l1: float = 10.0,
+    lambda_log_stft: float = 1.0,
+    lambda_mfcc: float = 1.0
 ):
     g_optimizer.zero_grad()
 
-    # 1. Generate high-quality audio from low-quality input
     generator_output = generator(low_quality[0])
 
-    # 2. Calculate Adversarial Loss (Generator tries to fool discriminator)
     discriminator_decision = discriminator(generator_output)
-    # Generator wants discriminator to output "real" labels for its fakes
     adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision))
 
-    # 3. Calculate Reconstruction/Spectrogram Loss(es)
-    # --- Choose and calculate the losses you want to include ---
     mel_l1 = 0.0
-    # log_stft_l1 = 0.0
-    # mfcc_l = 0.0
+    log_stft_l1 = 0.0
+    mfcc_l = 0.0
 
     # Calculate Mel L1 Loss if weight is positive
     if lambda_mel_l1 > 0:
         mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality[0], generator_output)
 
-    # # Calculate Log STFT L1 Loss if weight is positive
-    # if lambda_log_stft > 0:
-    #     log_stft_l1 = log_stft_magnitude_loss(stft_transform, hq_audio, generator_output)
+    # Calculate Log STFT L1 Loss if weight is positive
+    if lambda_log_stft > 0:
+        log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality[0], generator_output)
 
-    # # Calculate MFCC Loss if weight is positive
-    # if lambda_mfcc > 0:
-    #     mfcc_l = gpu_mfcc_loss(mfcc_transform, hq_audio, generator_output)
-    # --- End of Loss Calculation Choices ---
+    # Calculate MFCC Loss if weight is positive
+    if lambda_mfcc > 0:
+        mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality[0], generator_output)
 
-
-    # 4. Combine Losses
-    # Make sure calculated losses are tensors even if weights are 0 initially
-    # (or handle appropriately in the sum)
     mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1
-    # log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1
-    # mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l
+    log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1
+    mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l
 
     combined_loss = (lambda_adv * adversarial_loss) + \
-                    (lambda_mel_l1 * mel_l1_tensor)
-    # + (lambda_log_stft * log_stft_l1_tensor) \
-    # + (lambda_mfcc * mfcc_l_tensor)
+                    (lambda_mel_l1 * mel_l1_tensor) + \
+                    (lambda_log_stft * log_stft_l1_tensor) + \
+                    (lambda_mfcc * mfcc_l_tensor)
 
-    # 5. Backward Pass and Optimization
     combined_loss.backward()
     # Optional: Gradient Clipping
    # nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
     g_optimizer.step()
 
     # 6. Return values for logging
-    return generator_output, combined_loss, adversarial_loss, mel_l1_tensor
+    return generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor
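
Note on the new mfcc_transform argument: the updated generator_train(...) call in training.py now passes an mfcc_transform, but none of the hunks above defines one. If it is not already created elsewhere in training.py, it would have to be instantiated alongside the new stft_transform. A minimal sketch, assuming the same sample_rate, n_fft, hop_length, win_length, and n_mels variables that mel_transform already uses; the n_mfcc=40 value is an assumption, not taken from this diff:

    # Hypothetical addition to training.py (not part of the diff above).
    mfcc_transform = T.MFCC(
        sample_rate=sample_rate,  # assumed to exist, since mel_transform needs it too
        n_mfcc=40,                # assumed value
        melkwargs={
            "n_fft": n_fft,
            "hop_length": hop_length,
            "win_length": win_length,
            "n_mels": n_mels,
            "power": 1.0,
        },
    ).to(device)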
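
Note on gpu_mfcc_loss: the re-enabled MFCC branch in generator_train calls gpu_mfcc_loss, which is not shown in this diff. Assuming it mirrors the other spectral losses in training_utils.py (apply the transform to both signals, trim to a common length, average an elementwise distance), a sketch might look like:

    # Sketch only; the real gpu_mfcc_loss is not part of this diff.
    def gpu_mfcc_loss(mfcc_transform: T.MFCC, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
        mfccs_true = mfcc_transform(y_true)
        mfccs_pred = mfcc_transform(y_pred)

        # Trim both to the shorter time axis, as the other loss helpers do.
        min_len = min(mfccs_true.shape[-1], mfccs_pred.shape[-1])
        mfccs_true = mfccs_true[..., :min_len]
        mfccs_pred = mfccs_pred[..., :min_len]

        # Mean L1 distance between the MFCC matrices.
        return torch.mean(torch.abs(mfccs_true - mfccs_pred))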