⚗️ | Experimenting, again.

This commit is contained in:
NikkeDoy 2024-12-26 04:00:24 +02:00
parent 2ff45de22d
commit 89f8c68986
4 changed files with 49 additions and 55 deletions

10
data.py
View File

@ -13,7 +13,8 @@ class AudioDataset(Dataset):
audio_sample_rates = [11025]
def __init__(self, input_dir):
self.input_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.wav')]
self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')]
def __len__(self):
return len(self.input_files)
@ -25,7 +26,10 @@ class AudioDataset(Dataset):
# Generate low-quality audio with random downsampling
mangled_sample_rate = random.choice(self.audio_sample_rates)
resample_transform = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
low_quality_audio = resample_transform(high_quality_audio)
resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
low_quality_audio = resample_transform_low(high_quality_audio)
resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
low_quality_audio = resample_transform_high(low_quality_audio)
return (AudioUtils.stereo_tensor_to_mono(high_quality_audio), original_sample_rate), (AudioUtils.stereo_tensor_to_mono(low_quality_audio), mangled_sample_rate)

View File

@ -1,30 +1,31 @@
import torch
import torch.nn as nn
import torch.nn.utils as utils
class SISUDiscriminator(nn.Module):
def __init__(self):
super(SISUDiscriminator, self).__init__()
layers = 32
layers = 8
self.model = nn.Sequential(
nn.Conv1d(1, layers, kernel_size=5, stride=2, padding=2),
utils.spectral_norm(nn.Conv1d(1, layers, kernel_size=7, stride=2, padding=3)),
nn.BatchNorm1d(layers),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv1d(layers, layers * 2, kernel_size=5, stride=2, padding=2),
nn.PReLU(),
nn.Conv1d(layers, layers * 2, kernel_size=7, padding=3),
nn.BatchNorm1d(layers * 2),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv1d(layers * 2, layers * 4, kernel_size=5, stride=2, padding=2),
nn.PReLU(),
nn.Conv1d(layers * 2, layers * 4, kernel_size=5, padding=2),
nn.BatchNorm1d(layers * 4),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv1d(layers * 4, layers * 8, kernel_size=5, stride=2, padding=2),
nn.PReLU(),
nn.Conv1d(layers * 4, layers * 8, kernel_size=3, padding=1),
nn.BatchNorm1d(layers * 8),
nn.LeakyReLU(0.2, inplace=True),
nn.PReLU(),
nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1),
)
self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = x + 0.01 * torch.randn_like(x)
x = self.model(x)
x = self.global_avg_pool(x)
x = x.view(-1, 1)
x = self.sigmoid(x)
return x

View File

@ -1,32 +1,31 @@
import torch.nn as nn
class SISUGenerator(nn.Module):
def __init__(self, upscale_scale=4): # No noise_dim parameter
def __init__(self):
super(SISUGenerator, self).__init__()
layer = 32
# Convolution layers
layer = 16
# Convolution layers with BatchNorm and Residuals
self.conv1 = nn.Sequential(
nn.Conv1d(1, layer * 2, kernel_size=7, padding=1),
nn.Conv1d(1, layer * 2, kernel_size=7, padding=3),
nn.BatchNorm1d(layer * 2),
nn.PReLU(),
nn.Conv1d(layer * 2, layer * 5, kernel_size=5, padding=1),
nn.Conv1d(layer * 2, layer * 5, kernel_size=7, padding=3),
nn.BatchNorm1d(layer * 5),
nn.PReLU(),
nn.Conv1d(layer * 5, layer * 5, kernel_size=3, padding=1),
nn.PReLU()
nn.Conv1d(layer * 5, layer * 5, kernel_size=7, padding=3),
nn.BatchNorm1d(layer * 5),
nn.PReLU(),
)
self.final_layer = nn.Sequential(
nn.Conv1d(layer * 5, layer * 2, kernel_size=5, padding=2),
nn.BatchNorm1d(layer * 2),
nn.PReLU(),
nn.Conv1d(layer * 2, 1, kernel_size=3, padding=1),
# nn.Tanh() # Normalize audio... if needed...
)
# Transposed convolution for upsampling
self.upsample = nn.ConvTranspose1d(layer * 5, layer * 5, kernel_size=upscale_scale, stride=upscale_scale)
self.conv2 = nn.Sequential(
nn.Conv1d(layer * 5, layer * 5, kernel_size=3, padding=1),
nn.PReLU(),
nn.Conv1d(layer * 5, layer * 2, kernel_size=5, padding=1),
nn.PReLU(),
nn.Conv1d(layer * 2, 1, kernel_size=7, padding=1)
)
def forward(self, x, upscale_scale=4):
def forward(self, x):
residual = x
x = self.conv1(x)
x = self.upsample(x)
x = self.conv2(x)
return x
x = self.final_layer(x)
return x + residual

View File

@ -28,14 +28,9 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
discriminator_decision_from_real = discriminator(high_quality[0])
d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)
integer_scale = math.ceil(high_quality[1]/low_quality[1])
# Forward pass for fake samples (from generator output)
generator_output = generator(low_quality[0], integer_scale)
resample_transform = torchaudio.transforms.Resample(low_quality[1] * integer_scale, high_quality[1]).to(device)
resampled = resample_transform(generator_output.detach())
discriminator_decision_from_fake = discriminator(resampled)
generator_output = generator(low_quality[0])
discriminator_decision_from_fake = discriminator(generator_output.detach())
d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)
# Combine real and fake losses
@ -48,22 +43,17 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
return d_loss
def generator_train(low_quality, real_labels, target_sample_rate=44100):
def generator_train(low_quality, real_labels):
optimizer_g.zero_grad()
scale = math.ceil(target_sample_rate/low_quality[1])
# Forward pass for fake samples (from generator output)
generator_output = generator(low_quality[0], scale)
resample_transform = torchaudio.transforms.Resample(low_quality[1] * scale, target_sample_rate).to(device)
resampled = resample_transform(generator_output)
discriminator_decision = discriminator(resampled)
generator_output = generator(low_quality[0])
discriminator_decision = discriminator(generator_output)
g_loss = criterion_g(discriminator_decision, real_labels)
g_loss.backward()
optimizer_g.step()
return resampled
return generator_output
# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
@ -110,7 +100,7 @@ generator = generator.to(device)
discriminator = discriminator.to(device)
# Loss
criterion_g = nn.L1Loss()
criterion_g = nn.MSELoss()
criterion_d = nn.BCELoss()
# Optimizers
@ -172,7 +162,7 @@ def start_training():
# ========= GENERATOR =========
generator.train()
generator_output = generator_train(low_quality_sample, real_labels, high_quality_sample[1])
generator_output = generator_train(low_quality_sample, real_labels)
# ========= SAVE LATEST AUDIO =========
high_quality_audio = high_quality_clip
@ -185,7 +175,7 @@ def start_training():
if generator_epoch % 10 == 0:
print(f"Saved epoch {generator_epoch}!")
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), low_quality_audio[1])
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from what ever that low_quality had to high_quality
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])