diff --git a/data.py b/data.py
index 71f25dd..dfd5c2b 100644
--- a/data.py
+++ b/data.py
@@ -19,26 +19,25 @@ class AudioDataset(Dataset):
 
     def __getitem__(self, idx):
-        high_quality_wav, sr_original = torchaudio.load(self.input_files[idx], normalize=True)
+        high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True)
 
-        sample_rate = random.choice(self.audio_sample_rates)
-        resample_transform = torchaudio.transforms.Resample(sr_original, sample_rate)
-        low_quality_wav = resample_transform(high_quality_wav)
-        low_quality_wav = low_quality_wav
+        mangled_sample_rate = random.choice(self.audio_sample_rates)
+        resample_transform = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
+        low_quality_audio = resample_transform(high_quality_audio)
 
         # Calculate target length based on desired duration and 16000 Hz
-        if self.target_duration is not None:
-            target_length = int(self.target_duration * 44100)
-        else:
-            # Calculate duration of original high quality audio
-            target_length = high_quality_wav.size(1)
+        # if self.target_duration is not None:
+        #     target_length = int(self.target_duration * 44100)
+        # else:
+        #     # Calculate duration of original high quality audio
+        #     target_length = high_quality_wav.size(1)
 
         # Pad both to the calculated target length
-        high_quality_wav = self.stretch_tensor(high_quality_wav, target_length)
-        low_quality_wav = self.stretch_tensor(low_quality_wav, target_length)
+        # high_quality_wav = self.stretch_tensor(high_quality_wav, target_length)
+        # low_quality_wav = self.stretch_tensor(low_quality_wav, target_length)
 
-        return low_quality_wav, high_quality_wav
+        return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
 
     def stretch_tensor(self, tensor, target_length):
         current_length = tensor.size(1)
diff --git a/discriminator.pt b/discriminator.pt
new file mode 100644
index 0000000..b0962e9
Binary files /dev/null and b/discriminator.pt differ
diff --git a/discriminator.py b/discriminator.py
index 86a35d6..a63d1e9 100644
--- a/discriminator.py
+++ b/discriminator.py
@@ -5,15 +5,15 @@ class SISUDiscriminator(nn.Module):
         super(SISUDiscriminator, self).__init__()
         self.model = nn.Sequential(
             nn.Conv1d(2, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
+            #nn.LeakyReLU(0.2, inplace=True),
             nn.Conv1d(128, 256, kernel_size=3, padding=1),
             nn.LeakyReLU(0.2, inplace=True),
             nn.Conv1d(256, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
+            #nn.LeakyReLU(0.2, inplace=True),
             nn.Conv1d(128, 64, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
+            #nn.LeakyReLU(0.2, inplace=True),
             nn.Conv1d(64, 1, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
+            #nn.LeakyReLU(0.2, inplace=True),
         )
         self.global_avg_pool = nn.AdaptiveAvgPool1d(1) # Output size (1,)
 
diff --git a/generator.pt b/generator.pt
new file mode 100644
index 0000000..044e716
Binary files /dev/null and b/generator.pt differ
diff --git a/generator.py b/generator.py
index 978f02a..08fe584 100644
--- a/generator.py
+++ b/generator.py
@@ -3,21 +3,25 @@ import torch.nn as nn
 class SISUGenerator(nn.Module):
     def __init__(self, upscale_scale=1): # No noise_dim parameter
         super(SISUGenerator, self).__init__()
-        self.model = nn.Sequential(
+        self.layers1 = nn.Sequential(
             nn.Conv1d(2, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
+            # nn.LeakyReLU(0.2, inplace=True),
             nn.Conv1d(128, 256, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-
-            nn.Upsample(scale_factor=upscale_scale, mode='nearest'),
-
-            nn.Conv1d(256, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 64, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(64, 2, kernel_size=3, padding=1),
-            nn.Tanh()
+            # nn.LeakyReLU(0.2, inplace=True),
         )
 
-    def forward(self, x):
-        return self.model(x)
+        self.layers2 = nn.Sequential(
+            nn.Conv1d(256, 128, kernel_size=3, padding=1),
+            # nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(128, 64, kernel_size=3, padding=1),
+            # nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(64, 2, kernel_size=3, padding=1),
+            # nn.Tanh()
+        )
+
+    def forward(self, x, scale):
+        x = self.layers1(x)
+        upsample = nn.Upsample(scale_factor=scale, mode='nearest')
+        x = upsample(x)
+        x = self.layers2(x)
+        return x
diff --git a/training.py b/training.py
index 73751f0..9d16004 100644
--- a/training.py
+++ b/training.py
@@ -13,6 +13,60 @@ from data import AudioDataset
 from generator import SISUGenerator
 from discriminator import SISUDiscriminator
 
+# Mel Spectrogram Loss
+class MelSpectrogramLoss(nn.Module):
+    def __init__(self, sample_rate=44100, n_fft=2048, hop_length=512, n_mels=128):
+        super(MelSpectrogramLoss, self).__init__()
+        self.mel_transform = torchaudio.transforms.MelSpectrogram(
+            sample_rate=sample_rate,
+            n_fft=n_fft,
+            hop_length=hop_length,
+            n_mels=n_mels
+        ).to(device) # Move to device
+
+    def forward(self, y_pred, y_true):
+        mel_pred = self.mel_transform(y_pred)
+        mel_true = self.mel_transform(y_true)
+        return F.l1_loss(mel_pred, mel_true)
+
+def snr(y_true, y_pred):
+    noise = y_true - y_pred
+    signal_power = torch.mean(y_true ** 2)
+    noise_power = torch.mean(noise ** 2)
+    snr_db = 10 * torch.log10(signal_power / noise_power)
+    return snr_db
+
+def discriminator_train(high_quality, low_quality, scale, real_labels, fake_labels):
+    optimizer_d.zero_grad()
+
+    discriminator_decision_from_real = discriminator(high_quality)
+    # TODO: Experiment with criterions HERE!
+    d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)
+
+    generator_output = generator(low_quality, scale)
+    discriminator_decision_from_fake = discriminator(generator_output.detach())
+    # TODO: Experiment with criterions HERE!
+    d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)
+
+    d_loss = (d_loss_real + d_loss_fake) / 2.0
+
+    d_loss.backward()
+    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) #Gradient Clipping
+    optimizer_d.step()
+    return d_loss
+
+def generator_train(low_quality, scale, real_labels):
+    optimizer_g.zero_grad()
+
+    generator_output = generator(low_quality, scale)
+    discriminator_decision = discriminator(generator_output)
+    # TODO: Fix this shit
+    g_loss = criterion_g(discriminator_decision, real_labels)
+
+    g_loss.backward()
+    optimizer_g.step()
+    return generator_output
+
 # Check for CUDA availability
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
@@ -27,8 +81,8 @@ val_size = int(dataset_size-train_size)
 
 train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
 
-train_data_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
-val_data_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)
+train_data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
+val_data_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)
 
 # Initialize models and move them to device
 generator = SISUGenerator()
@@ -39,6 +93,7 @@ discriminator = discriminator.to(device)
 
 # Loss
 criterion_g = nn.L1Loss()
+criterion_g_mel = MelSpectrogramLoss().to(device)
 criterion_d = nn.BCEWithLogitsLoss()
 
 # Optimizers
@@ -49,87 +104,80 @@ optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.99
 scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
 scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)
 
-def snr(y_true, y_pred):
-    noise = y_true - y_pred
-    signal_power = torch.mean(y_true ** 2)
-    noise_power = torch.mean(noise ** 2)
-    snr_db = 10 * torch.log10(signal_power / noise_power)
-    return snr_db
-
-def discriminator_train(discriminator, optimizer, criterion, generator, real_labels, fake_labels, high_quality, low_quality):
-    optimizer.zero_grad()
-
-    discriminator_decision_from_real = discriminator(high_quality)
-    d_loss_real = criterion(discriminator_decision_from_real, real_labels)
-
-    generator_output = generator(low_quality)
-    discriminator_decision_from_fake = discriminator(generator_output.detach())
-    d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels)
-
-    d_loss = (d_loss_real + d_loss_fake) / 2.0
-
-    d_loss.backward()
-    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) #Gradient Clipping
-    optimizer.step()
-    # print(f"Discriminator Loss: {d_loss.item():.4f}, Mean Real Logit: {discriminator_decision_from_real.mean().item():.2f}, Mean Fake Logit: {discriminator_decision_from_fake.mean().item():.2f}")
-
 def start_training():
     # Training loop
-    # discriminator_epochs = 1000
-    generator_epochs = 500
-    for generator_epoch in range(generator_epochs):
-        low_quality_audio = torch.empty((1))
-        high_quality_audio = torch.empty((1))
-        ai_enhanced_audio = torch.empty((1))
-
-        # Training
-        for low_quality, high_quality in tqdm.tqdm(train_data_loader, desc=f"Epoch {generator_epoch+1}/{generator_epochs}"):
-            high_quality = high_quality.to(device)
-            low_quality = low_quality.to(device)
+    # ========= DISCRIMINATOR PRE-TRAINING =========
+    discriminator_epochs = 1
+    for discriminator_epoch in range(discriminator_epochs):
 
-            batch_size = high_quality.size(0)
+        # ========= TRAINING =========
+        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {discriminator_epoch+1}/{discriminator_epochs}"):
+            high_quality_sample = high_quality_clip[0].to(device)
+            low_quality_sample = low_quality_clip[0].to(device)
+
+            scale = high_quality_clip[0].shape[2]/low_quality_clip[0].shape[2]
+
+            # ========= LABELS =========
+            batch_size = high_quality_sample.size(0)
             real_labels = torch.ones(batch_size, 1).to(device)
             fake_labels = torch.zeros(batch_size, 1).to(device)
 
-            # Train Discriminator
+            # ========= DISCRIMINATOR =========
             discriminator.train()
+            discriminator_train(high_quality_sample, low_quality_sample, scale, real_labels, fake_labels)
+    torch.save(discriminator.state_dict(), "models/discriminator-single-shot-pre-train.pt")
+
+    generator_epochs = 500
+    for generator_epoch in range(generator_epochs):
+        low_quality_audio = (torch.empty((1)), 1)
+        high_quality_audio = (torch.empty((1)), 1)
+        ai_enhanced_audio = (torch.empty((1)), 1)
+
+        # ========= TRAINING =========
+        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {generator_epoch+1}/{generator_epochs}"):
+            high_quality_sample = high_quality_clip[0].to(device)
+            low_quality_sample = low_quality_clip[0].to(device)
+
+            scale = high_quality_clip[0].shape[2]/low_quality_clip[0].shape[2]
+
+            # ========= LABELS =========
+            batch_size = high_quality_clip[0].size(0)
+            real_labels = torch.ones(batch_size, 1).to(device)
+            fake_labels = torch.zeros(batch_size, 1).to(device)
+
+            # ========= DISCRIMINATOR =========
+            discriminator.train()
             for _ in range(3):
-                discriminator_train(discriminator, optimizer_d, criterion_d, generator, real_labels, fake_labels, high_quality, low_quality)
+                discriminator_train(high_quality_sample, low_quality_sample, scale, real_labels, fake_labels)
 
-            # Train Generator
+            # ========= GENERATOR =========
             generator.train()
-            optimizer_g.zero_grad()
+            generator_output = generator_train(low_quality_sample, scale, real_labels)
 
-            # Generator loss: how well fake data fools the discriminator
-            generator_output = generator(low_quality)
-            discriminator_decision = discriminator(generator_output) # No detach here
-            g_loss = criterion_g(discriminator_decision, real_labels) # Train generator to produce real-like outputs
+            # ========= SAVE LATEST AUDIO =========
+            high_quality_audio = high_quality_clip
+            low_quality_audio = low_quality_clip
+            ai_enhanced_audio = (generator_output, high_quality_clip[1])
 
-            g_loss.backward()
-            optimizer_g.step()
-
-            low_quality_audio = low_quality
-            high_quality_audio = high_quality
-            ai_enhanced_audio = generator_output
-
-        metric = snr(high_quality_audio, ai_enhanced_audio)
+        metric = snr(high_quality_audio[0].to(device), ai_enhanced_audio[0])
         print(f"Generator metric {metric}!")
         scheduler_g.step(metric)
 
         if generator_epoch % 10 == 0:
             print(f"Saved epoch {generator_epoch}!")
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0].cpu(), 44100)
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu(), 44100)
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0].cpu(), 44100)
+            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), low_quality_audio[1])
+            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
+            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])
 
         if generator_epoch % 50 == 0:
-            torch.save(discriminator.state_dict(), "discriminator.pt")
-            torch.save(generator.state_dict(), "generator.pt")
+            torch.save(discriminator.state_dict(), f"models/epoch-{generator_epoch}-discriminator.pt")
+            torch.save(generator.state_dict(), f"models/epoch-{generator_epoch}-generator.pt")
 
-    torch.save(discriminator.state_dict(), "discriminator.pt")
-    torch.save(generator.state_dict(), "generator.pt")
+    torch.save(discriminator.state_dict(), "models/epoch-500-discriminator.pt")
+    torch.save(generator.state_dict(), "models/epoch-500-generator.pt")
    print("Training complete!")
 
 start_training()
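
A minimal, hypothetical inference sketch (not part of the patch above) showing how the changed pieces fit together at run time: the AudioDataset-style degradation, the scale ratio computed from tensor lengths, and the new SISUGenerator.forward(x, scale) signature. The checkpoint path models/epoch-500-generator.pt comes from training.py above; the input file example.wav, the 16 kHz degradation rate, and the stereo requirement (the first layer is nn.Conv1d(2, ...)) are illustrative assumptions only.

import torch
import torchaudio

from generator import SISUGenerator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the generator weights saved at the end of training.py (path assumed from the script above).
generator = SISUGenerator()
generator.load_state_dict(torch.load("models/epoch-500-generator.pt", map_location=device))
generator = generator.to(device)
generator.eval()

# Load a stereo clip and derive a low-quality version, mirroring AudioDataset.__getitem__.
high_quality_audio, original_sample_rate = torchaudio.load("example.wav", normalize=True)  # hypothetical input file
mangled_sample_rate = 16000  # assumed degradation rate, for illustration only
resample_transform = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
low_quality_audio = resample_transform(high_quality_audio)

with torch.no_grad():
    batch = low_quality_audio.unsqueeze(0).to(device)                  # (1, channels, samples)
    scale = high_quality_audio.shape[1] / low_quality_audio.shape[1]   # same ratio training.py takes from shape[2]
    enhanced = generator(batch, scale)

torchaudio.save("example-enhanced.wav", enhanced.squeeze(0).cpu(), original_sample_rate)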