⚗️ | Experimenting, again.

2024-12-26 04:00:24 +02:00
parent 2ff45de22d
commit 89f8c68986
4 changed files with 49 additions and 55 deletions

data.py

@@ -13,7 +13,8 @@ class AudioDataset(Dataset):
     audio_sample_rates = [11025]

     def __init__(self, input_dir):
-        self.input_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.wav')]
+        self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')]

     def __len__(self):
         return len(self.input_files)
@@ -25,7 +26,10 @@ class AudioDataset(Dataset):

         # Generate low-quality audio with random downsampling
         mangled_sample_rate = random.choice(self.audio_sample_rates)
-        resample_transform = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
-        low_quality_audio = resample_transform(high_quality_audio)
+        resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
+        low_quality_audio = resample_transform_low(high_quality_audio)
+        resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
+        low_quality_audio = resample_transform_high(low_quality_audio)

         return (AudioUtils.stereo_tensor_to_mono(high_quality_audio), original_sample_rate), (AudioUtils.stereo_tensor_to_mono(low_quality_audio), mangled_sample_rate)
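The new round trip degrades the audio but hands it back at the original sample rate, so the low-quality clip lines up sample-for-sample with its high-quality pair even though the returned tuple still carries mangled_sample_rate. A minimal sketch of that effect, assuming torchaudio is installed (the 44.1 kHz rate is illustrative):

import torch
import torchaudio

wave = torch.randn(2, 44100)                                # one second of stereo noise at 44.1 kHz
down = torchaudio.transforms.Resample(44100, 11025)(wave)   # degrade to 11,025 Hz
up = torchaudio.transforms.Resample(11025, 44100)(down)     # bring it back to the original rate
print(down.shape, up.shape)                                 # torch.Size([2, 11025]) torch.Size([2, 44100])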

[discriminator model file]

@@ -1,30 +1,31 @@
+import torch
 import torch.nn as nn
+import torch.nn.utils as utils

 class SISUDiscriminator(nn.Module):
     def __init__(self):
         super(SISUDiscriminator, self).__init__()
-        layers = 32
+        layers = 8
         self.model = nn.Sequential(
-            nn.Conv1d(1, layers, kernel_size=5, stride=2, padding=2),
+            utils.spectral_norm(nn.Conv1d(1, layers, kernel_size=7, stride=2, padding=3)),
             nn.BatchNorm1d(layers),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(layers, layers * 2, kernel_size=5, stride=2, padding=2),
+            nn.PReLU(),
+            nn.Conv1d(layers, layers * 2, kernel_size=7, padding=3),
             nn.BatchNorm1d(layers * 2),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(layers * 2, layers * 4, kernel_size=5, stride=2, padding=2),
+            nn.PReLU(),
+            nn.Conv1d(layers * 2, layers * 4, kernel_size=5, padding=2),
             nn.BatchNorm1d(layers * 4),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(layers * 4, layers * 8, kernel_size=5, stride=2, padding=2),
+            nn.PReLU(),
+            nn.Conv1d(layers * 4, layers * 8, kernel_size=3, padding=1),
             nn.BatchNorm1d(layers * 8),
-            nn.LeakyReLU(0.2, inplace=True),
+            nn.PReLU(),
             nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1),
         )
         self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
-        self.sigmoid = nn.Sigmoid()

     def forward(self, x):
+        x = x + 0.01 * torch.randn_like(x)
         x = self.model(x)
         x = self.global_avg_pool(x)
         x = x.view(-1, 1)
-        x = self.sigmoid(x)
         return x
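With spectral normalization on the input convolution, noise injected into the input, and the final Sigmoid removed, the discriminator now emits unbounded scores rather than probabilities. For comparison only, a minimal sketch of the loss usually paired with raw scores; note the training script below keeps nn.BCELoss, which expects inputs in [0, 1]:

import torch
import torch.nn as nn

logits = torch.randn(4, 1)                     # raw discriminator outputs, no Sigmoid applied
labels = torch.ones(4, 1)                      # targets for "real" samples
loss = nn.BCEWithLogitsLoss()(logits, labels)  # folds the sigmoid into the loss for stability
print(loss.item())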

[generator model file]

@@ -1,32 +1,31 @@
 import torch.nn as nn

 class SISUGenerator(nn.Module):
-    def __init__(self, upscale_scale=4):  # No noise_dim parameter
+    def __init__(self):
         super(SISUGenerator, self).__init__()
-        layer = 32
-        # Convolution layers
+        layer = 16
+        # Convolution layers with BatchNorm and Residuals
         self.conv1 = nn.Sequential(
-            nn.Conv1d(1, layer * 2, kernel_size=7, padding=1),
+            nn.Conv1d(1, layer * 2, kernel_size=7, padding=3),
+            nn.BatchNorm1d(layer * 2),
             nn.PReLU(),
-            nn.Conv1d(layer * 2, layer * 5, kernel_size=5, padding=1),
+            nn.Conv1d(layer * 2, layer * 5, kernel_size=7, padding=3),
+            nn.BatchNorm1d(layer * 5),
             nn.PReLU(),
-            nn.Conv1d(layer * 5, layer * 5, kernel_size=3, padding=1),
-            nn.PReLU()
+            nn.Conv1d(layer * 5, layer * 5, kernel_size=7, padding=3),
+            nn.BatchNorm1d(layer * 5),
+            nn.PReLU(),
         )
+        self.final_layer = nn.Sequential(
+            nn.Conv1d(layer * 5, layer * 2, kernel_size=5, padding=2),
+            nn.BatchNorm1d(layer * 2),
+            nn.PReLU(),
+            nn.Conv1d(layer * 2, 1, kernel_size=3, padding=1),
+            # nn.Tanh()  # Normalize audio... if needed...
+        )
-        # Transposed convolution for upsampling
-        self.upsample = nn.ConvTranspose1d(layer * 5, layer * 5, kernel_size=upscale_scale, stride=upscale_scale)
-        self.conv2 = nn.Sequential(
-            nn.Conv1d(layer * 5, layer * 5, kernel_size=3, padding=1),
-            nn.PReLU(),
-            nn.Conv1d(layer * 5, layer * 2, kernel_size=5, padding=1),
-            nn.PReLU(),
-            nn.Conv1d(layer * 2, 1, kernel_size=7, padding=1)
-        )

-    def forward(self, x, upscale_scale=4):
+    def forward(self, x):
+        residual = x
         x = self.conv1(x)
-        x = self.upsample(x)
-        x = self.conv2(x)
-        return x
+        x = self.final_layer(x)
+        return x + residual
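Dropping the ConvTranspose1d path makes the generator rate-preserving: every convolution pads to keep the sample count, which the residual connection requires. A quick shape check, assuming the SISUGenerator class above is in scope (batch size and clip length are arbitrary):

import torch

gen = SISUGenerator()
x = torch.randn(2, 1, 4096)    # (batch, channels, samples) of mono audio
y = gen(x)
print(y.shape)                 # torch.Size([2, 1, 4096]); output length matches the input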

[training script]

@@ -28,14 +28,9 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
     discriminator_decision_from_real = discriminator(high_quality[0])
     d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)

-    integer_scale = math.ceil(high_quality[1]/low_quality[1])
     # Forward pass for fake samples (from generator output)
-    generator_output = generator(low_quality[0], integer_scale)
-    resample_transform = torchaudio.transforms.Resample(low_quality[1] * integer_scale, high_quality[1]).to(device)
-    resampled = resample_transform(generator_output.detach())
-    discriminator_decision_from_fake = discriminator(resampled)
+    generator_output = generator(low_quality[0])
+    discriminator_decision_from_fake = discriminator(generator_output.detach())
     d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)

     # Combine real and fake losses
@@ -48,22 +43,17 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
     return d_loss

-def generator_train(low_quality, real_labels, target_sample_rate=44100):
+def generator_train(low_quality, real_labels):
     optimizer_g.zero_grad()

-    scale = math.ceil(target_sample_rate/low_quality[1])
     # Forward pass for fake samples (from generator output)
-    generator_output = generator(low_quality[0], scale)
-    resample_transform = torchaudio.transforms.Resample(low_quality[1] * scale, target_sample_rate).to(device)
-    resampled = resample_transform(generator_output)
-    discriminator_decision = discriminator(resampled)
+    generator_output = generator(low_quality[0])
+    discriminator_decision = discriminator(generator_output)
     g_loss = criterion_g(discriminator_decision, real_labels)

     g_loss.backward()
     optimizer_g.step()
-    return resampled
+    return generator_output

 # Init script argument parser
 parser = argparse.ArgumentParser(description="Training script")
@@ -110,7 +100,7 @@ generator = generator.to(device)
 discriminator = discriminator.to(device)

 # Loss
-criterion_g = nn.L1Loss()
+criterion_g = nn.MSELoss()
 criterion_d = nn.BCELoss()

 # Optimizers
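# 
Measuring the generator's adversarial term with MSE against real_labels makes it a least-squares (LSGAN-style) objective, which also tolerates the now-unbounded discriminator scores. A toy computation with hypothetical score values:

import torch
import torch.nn as nn

criterion_g = nn.MSELoss()
decision = torch.randn(4, 1)                  # raw discriminator scores for generated audio
real_labels = torch.ones(4, 1)
g_loss = criterion_g(decision, real_labels)   # pushes scores toward the "real" label
print(g_loss.item())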
@@ -172,7 +162,7 @@ def start_training():
             # ========= GENERATOR =========
             generator.train()
-            generator_output = generator_train(low_quality_sample, real_labels, high_quality_sample[1])
+            generator_output = generator_train(low_quality_sample, real_labels)

             # ========= SAVE LATEST AUDIO =========
             high_quality_audio = high_quality_clip
@@ -185,7 +175,7 @@ def start_training():
         if generator_epoch % 10 == 0:
             print(f"Saved epoch {generator_epoch}!")
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), low_quality_audio[1])
+            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1])  # the low-quality clip was resampled in data.py back to the high-quality rate, so save it at that rate
             torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
             torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])