diff --git a/data.py b/data.py
index 67abb91..fe41126 100644
--- a/data.py
+++ b/data.py
@@ -1,9 +1,13 @@
-import torch
 from torch.utils.data import Dataset
+import torch.nn.functional as F
 import torchaudio
 import os
+import random
+
 
 class AudioDataset(Dataset):
+    audio_sample_rates = [8000, 11025, 16000, 22050]
+
     def __init__(self, input_dir, target_duration=None, padding_mode='constant', padding_value=0.0):
         self.input_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.wav')]
         self.target_duration = target_duration # Duration in seconds or None if not set
@@ -17,29 +21,30 @@ class AudioDataset(Dataset):
     def __getitem__(self, idx):
         high_quality_wav, sr_original = torchaudio.load(self.input_files[idx], normalize=True)
 
-        resample_transform = torchaudio.transforms.Resample(sr_original, 16000)
+        sample_rate = random.choice(self.audio_sample_rates)
+        resample_transform = torchaudio.transforms.Resample(sr_original, sample_rate)
         low_quality_wav = resample_transform(high_quality_wav)
+        low_quality_wav = -low_quality_wav
 
         # Calculate target length based on desired duration and 16000 Hz
         if self.target_duration is not None:
             target_length = int(self.target_duration * 44100)
         else:
             # Calculate duration of original high quality audio
-            duration_original = high_quality_wav.shape[1] / sr_original
-            target_length = int(duration_original * 16000)
+            target_length = high_quality_wav.size(1)
 
         # Pad both to the calculated target length
-        high_quality_wav = self.pad_tensor(high_quality_wav, target_length)
-        low_quality_wav = self.pad_tensor(low_quality_wav, target_length)
+        high_quality_wav = self.stretch_tensor(high_quality_wav, target_length)
+        low_quality_wav = self.stretch_tensor(low_quality_wav, target_length)
+
         return low_quality_wav, high_quality_wav
 
-    def pad_tensor(self, tensor, target_length):
+    def stretch_tensor(self, tensor, target_length):
         current_length = tensor.size(1)
 
-        if current_length < target_length:
-            padding_amount = target_length - current_length
-            padding = (0, padding_amount)
-            tensor = torch.nn.functional.pad(tensor, padding, mode=self.padding_mode, value=self.padding_value)
-        elif current_length > target_length:
-            tensor = tensor[:, :target_length]
+        scale_factor = target_length / current_length
+
+        # Resample the tensor using linear interpolation
+        tensor = F.interpolate(tensor.unsqueeze(0), scale_factor=scale_factor, mode='linear', align_corners=False).squeeze(0)
 
         return tensor
diff --git a/generator.py b/generator.py
index b9401d7..8757aaf 100644
--- a/generator.py
+++ b/generator.py
@@ -1,23 +1,17 @@
 import torch.nn as nn
 
 class SISUGenerator(nn.Module):
-    def __init__(self): # No noise_dim parameter
+    def __init__(self, upscale_scale=1): # No noise_dim parameter
         super(SISUGenerator, self).__init__()
         self.model = nn.Sequential(
-            nn.Conv1d(2, 64, kernel_size=7, stride=1, padding=3), # Input 2 channels (low-quality audio)
-            nn.LeakyReLU(0.2),
-            nn.Conv1d(64, 64, kernel_size=7, stride=1, padding=3),
-            nn.LeakyReLU(0.2),
-            nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2),
-            nn.LeakyReLU(0.2),
-            nn.Conv1d(128, 128, kernel_size=5, stride=1, padding=2),
-            nn.LeakyReLU(0.2),
-            nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1),
-            nn.LeakyReLU(0.2),
-            nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
-            nn.LeakyReLU(0.2),
-            nn.Conv1d(64, 2, kernel_size=3, stride=1, padding=1), # Output 2 channels (high-quality audio)
-            nn.Tanh()
+            nn.Conv1d(2, 128, kernel_size=3, padding=1),
+            nn.Conv1d(128, 256, kernel_size=3, padding=1),
+
+            nn.Upsample(scale_factor=upscale_scale, mode='nearest'),
+
+            nn.Conv1d(256, 128, kernel_size=3, padding=1),
+            nn.Conv1d(128, 64, kernel_size=3, padding=1),
+            nn.Conv1d(64, 2, kernel_size=3, padding=1)
         )
 
     def forward(self, x):
diff --git a/training.py b/training.py
index 7a30eb8..828b21b 100644
--- a/training.py
+++ b/training.py
@@ -17,7 +17,7 @@ print(f"Using device: {device}")
 
 # Initialize dataset and dataloader
 dataset_dir = './dataset/good'
-dataset = AudioDataset(dataset_dir, target_duration=2.0) # 5 seconds target duration
+dataset = AudioDataset(dataset_dir, target_duration=2.0)
 
 dataset_size = len(dataset)
 train_size = int(dataset_size * .9)
@@ -35,9 +35,12 @@ discriminator = SISUDiscriminator()
 generator = generator.to(device)
 discriminator = discriminator.to(device)
 
-# Loss and optimizers
-criterion = nn.MSELoss() # Use Mean Squared Error loss
-optimizer_g = optim.Adam(generator.parameters(), lr=0.0005, betas=(0.5, 0.999))
+# Loss
+criterion_g = nn.L1Loss() # Perceptual Loss (L1 instead of MSE)
+criterion_d = nn.MSELoss() # Can keep MSE for discriminator (optional)
+
+# Optimizers
+optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999)) # Reduced learning rate
 optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
 
 # Learning rate scheduler
@@ -45,11 +48,16 @@ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min',
 
 # Training loop
 num_epochs = 500
-lambda_gp = 10
+
 
 for epoch in range(num_epochs):
-    original = torch.empty((2))
-    crap_audio = torch.empty((2))
-    for low_quality, high_quality in tqdm.tqdm(train_data_loader):
+    low_quality_audio = torch.empty((1))
+    high_quality_audio = torch.empty((1))
+    ai_enhanced_audio = torch.empty((1))
+    total_d_loss = 0
+    total_g_loss = 0
+
+    # Training
+    for low_quality, high_quality in tqdm.tqdm(train_data_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
         high_quality = high_quality.to(device)
         low_quality = low_quality.to(device)
 
@@ -57,32 +65,44 @@ for epoch in range(num_epochs):
         real_labels = torch.ones(batch_size, 1).to(device)
         fake_labels = torch.zeros(batch_size, 1).to(device)
 
-        real_outputs = discriminator(high_quality)
-        fake_outputs = discriminator(generator(low_quality))
+        ###### Train Discriminator ######
+        discriminator.train()
+        optimizer_d.zero_grad()
 
-        d_loss_real = criterion(real_outputs, real_labels)
-        d_loss_fake = criterion(fake_outputs, fake_labels)
-        d_loss = (d_loss_real + d_loss_fake) * 0.5
+        # 1. Real data
+        real_outputs = discriminator(high_quality)
+        d_loss_real = criterion_d(real_outputs, real_labels)
+
+        # 2. Fake data
+        fake_audio = generator(low_quality)
+        fake_outputs = discriminator(fake_audio.detach()) # Detach to stop gradient flow to the generator
+        d_loss_fake = criterion_d(fake_outputs, fake_labels)
+
+        d_loss = (d_loss_real + d_loss_fake) / 2.0 # Without gradient penalty
 
         d_loss.backward()
         optimizer_d.step()
+        total_d_loss += d_loss.item()
 
-        # Train Generator
+        generator.train()
         optimizer_g.zero_grad()
-        fake_audio = generator(low_quality)
-        fake_outputs = discriminator(fake_audio)
-        g_loss = criterion(fake_outputs, real_labels)
+
+        # Generator loss: how well fake data fools the discriminator
+        fake_outputs = discriminator(fake_audio) # No detach here
+        g_loss = criterion_g(fake_outputs, real_labels) # Train generator to produce real-like outputs
+
         g_loss.backward()
         optimizer_g.step()
+        total_g_loss += g_loss.item()
 
-        original = high_quality
-        crap_audio = fake_audio
+        low_quality_audio = low_quality
+        high_quality_audio = high_quality
+        ai_enhanced_audio = fake_audio
 
     if epoch % 10 == 0:
-        print(crap_audio.size())
-        torchaudio.save(f"./epoch-{epoch}-audio.wav", crap_audio[0].cpu(), 44100)
-        torchaudio.save(f"./epoch-{epoch}-audio-orig.wav", original[0].cpu(), 44100)
-
-        print(f'Epoch [{epoch+1}/{num_epochs}]')
+        print(f"Saved epoch {epoch}!")
+        torchaudio.save(f"./output/epoch-{epoch}-audio-crap.wav", low_quality_audio[0].cpu(), 44100)
+        torchaudio.save(f"./output/epoch-{epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu(), 44100)
+        torchaudio.save(f"./output/epoch-{epoch}-audio-orig.wav", high_quality_audio[0].cpu(), 44100)
 
 torch.save(generator.state_dict(), "generator.pt")
 torch.save(discriminator.state_dict(), "discriminator.pt")
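
Quick sanity check (not part of the patch): a minimal sketch exercising the reworked SISUGenerator and the interpolation-based length matching introduced by AudioDataset.stretch_tensor. It assumes SISUGenerator.forward simply returns self.model(x) and that the class is importable from generator.py as in this diff; the dummy tensors, shapes, and the 2.0 s / 44100 Hz target are illustrative only.

# Sketch only; assumes SISUGenerator.forward returns self.model(x).
import torch
import torch.nn.functional as F

from generator import SISUGenerator

generator = SISUGenerator(upscale_scale=1)
generator.eval()

# Fake stereo clip shaped (batch=1, channels=2, samples=16000), matching Conv1d(2, 128, ...).
dummy = torch.randn(1, 2, 16000)
with torch.no_grad():
    out = generator(dummy)
print(out.shape)  # torch.Size([1, 2, 16000]) when upscale_scale=1 (Upsample is then a no-op)

# Length matching as in AudioDataset.stretch_tensor: linear interpolation to a fixed
# sample count (here 2.0 s at 44100 Hz, the target used by the dataset in this diff).
clip = torch.randn(2, 12345)                      # (channels, samples)
target_length = int(2.0 * 44100)
scale_factor = target_length / clip.size(1)
stretched = F.interpolate(clip.unsqueeze(0), scale_factor=scale_factor,
                          mode='linear', align_corners=False).squeeze(0)
print(stretched.shape)  # roughly (2, 88200); scale_factor rounding can be off by a sample

With upscale_scale=1 the generator preserves the input length, which is why the dataset stretches both the low-quality and high-quality waveforms to the same target_length before returning them.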