⚗️ | Experimenting, again.

commit 89f8c68986
parent 2ff45de22d

data.py (10 changed lines)
@@ -13,7 +13,8 @@ class AudioDataset(Dataset):
     audio_sample_rates = [11025]

     def __init__(self, input_dir):
-        self.input_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.wav')]
+        self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')]
+

     def __len__(self):
         return len(self.input_files)
@@ -25,7 +26,10 @@ class AudioDataset(Dataset):

         # Generate low-quality audio with random downsampling
         mangled_sample_rate = random.choice(self.audio_sample_rates)
-        resample_transform = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
-        low_quality_audio = resample_transform(high_quality_audio)
+        resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
+        low_quality_audio = resample_transform_low(high_quality_audio)
+
+        resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
+        low_quality_audio = resample_transform_high(low_quality_audio)

         return (AudioUtils.stereo_tensor_to_mono(high_quality_audio), original_sample_rate), (AudioUtils.stereo_tensor_to_mono(low_quality_audio), mangled_sample_rate)
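Note on the data.py change: the dataset now scans input_dir recursively via os.walk, and each clip is degraded by resampling down to a low rate and straight back up, so the "low quality" tensor keeps the original sample rate and length. (The returned tuple still labels it with mangled_sample_rate; the save call in training.py's last hunk compensates for that.) A minimal standalone sketch of the round trip, with the rates and a random tensor as illustrative stand-ins:

    import torch
    import torchaudio

    original_sample_rate = 44100
    mangled_sample_rate = 11025  # the only entry in audio_sample_rates

    # One second of stereo noise stands in for a real clip.
    high_quality_audio = torch.randn(2, original_sample_rate)

    # Down to the bottleneck rate, then back up: the high-frequency detail is
    # lost, but the tensor's rate and length match the original again.
    down = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
    up = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
    low_quality_audio = up(down(high_quality_audio))

    print(high_quality_audio.shape, low_quality_audio.shape)  # both torch.Size([2, 44100])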
discriminator.py

@@ -1,30 +1,31 @@
 import torch
 import torch.nn as nn
 import torch.nn.utils as utils

 class SISUDiscriminator(nn.Module):
     def __init__(self):
         super(SISUDiscriminator, self).__init__()
-        layers = 32
+        layers = 8
         self.model = nn.Sequential(
-            nn.Conv1d(1, layers, kernel_size=5, stride=2, padding=2),
+            utils.spectral_norm(nn.Conv1d(1, layers, kernel_size=7, stride=2, padding=3)),
             nn.BatchNorm1d(layers),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(layers, layers * 2, kernel_size=5, stride=2, padding=2),
+            nn.PReLU(),
+            nn.Conv1d(layers, layers * 2, kernel_size=7, padding=3),
             nn.BatchNorm1d(layers * 2),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(layers * 2, layers * 4, kernel_size=5, stride=2, padding=2),
+            nn.PReLU(),
+            nn.Conv1d(layers * 2, layers * 4, kernel_size=5, padding=2),
             nn.BatchNorm1d(layers * 4),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(layers * 4, layers * 8, kernel_size=5, stride=2, padding=2),
+            nn.PReLU(),
+            nn.Conv1d(layers * 4, layers * 8, kernel_size=3, padding=1),
             nn.BatchNorm1d(layers * 8),
-            nn.LeakyReLU(0.2, inplace=True),
+            nn.PReLU(),
             nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1),
         )
         self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
         self.sigmoid = nn.Sigmoid()

     def forward(self, x):
+        x = x + 0.01 * torch.randn_like(x)
         x = self.model(x)
         x = self.global_avg_pool(x)
         x = x.view(-1, 1)
         x = self.sigmoid(x)
         return x
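Note on the discriminator rework: fewer base channels (8 instead of 32), PReLU activations, spectral normalization on the input conv, and a small Gaussian "instance noise" term added in forward are all common tricks for keeping a GAN discriminator from overpowering the generator. A quick shape check, assuming the class exactly as defined above:

    import torch

    model = SISUDiscriminator()
    batch = torch.randn(4, 1, 44100)  # (batch, mono channel, samples)
    scores = model(batch)
    print(scores.shape)  # torch.Size([4, 1]); the sigmoid keeps each score in (0, 1)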
generator.py (43 changed lines)
@@ -1,32 +1,31 @@
 import torch.nn as nn

 class SISUGenerator(nn.Module):
-    def __init__(self, upscale_scale=4): # No noise_dim parameter
+    def __init__(self):
         super(SISUGenerator, self).__init__()
-        layer = 32
-        # Convolution layers
+        layer = 16
+        # Convolution layers with BatchNorm and Residuals
         self.conv1 = nn.Sequential(
-            nn.Conv1d(1, layer * 2, kernel_size=7, padding=1),
+            nn.Conv1d(1, layer * 2, kernel_size=7, padding=3),
+            nn.BatchNorm1d(layer * 2),
             nn.PReLU(),
-            nn.Conv1d(layer * 2, layer * 5, kernel_size=5, padding=1),
+            nn.Conv1d(layer * 2, layer * 5, kernel_size=7, padding=3),
+            nn.BatchNorm1d(layer * 5),
             nn.PReLU(),
-            nn.Conv1d(layer * 5, layer * 5, kernel_size=3, padding=1),
-            nn.PReLU()
+            nn.Conv1d(layer * 5, layer * 5, kernel_size=7, padding=3),
+            nn.BatchNorm1d(layer * 5),
+            nn.PReLU(),
         )
-
-        # Transposed convolution for upsampling
-        self.upsample = nn.ConvTranspose1d(layer * 5, layer * 5, kernel_size=upscale_scale, stride=upscale_scale)
-
-        self.conv2 = nn.Sequential(
-            nn.Conv1d(layer * 5, layer * 5, kernel_size=3, padding=1),
-            nn.PReLU(),
-            nn.Conv1d(layer * 5, layer * 2, kernel_size=5, padding=1),
-            nn.PReLU(),
-            nn.Conv1d(layer * 2, 1, kernel_size=7, padding=1)
-        )
+        self.final_layer = nn.Sequential(
+            nn.Conv1d(layer * 5, layer * 2, kernel_size=5, padding=2),
+            nn.BatchNorm1d(layer * 2),
+            nn.PReLU(),
+            nn.Conv1d(layer * 2, 1, kernel_size=3, padding=1),
+            # nn.Tanh() # Normalize audio... if needed...
+        )

-    def forward(self, x, upscale_scale=4):
+    def forward(self, x):
+        residual = x
         x = self.conv1(x)
-        x = self.upsample(x)
-        x = self.conv2(x)
-        return x
+        x = self.final_layer(x)
+        return x + residual
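Note on the generator rework: the transposed-convolution upsampler is gone; the network now works at a single sample rate and, via `residual = x` and `return x + residual`, only has to learn a correction on top of its input. Every conv is padded to preserve the time axis, which the residual addition requires. A shape check under those assumptions:

    import torch

    gen = SISUGenerator()
    low_quality = torch.randn(4, 1, 44100)  # (batch, channels=1, samples)
    enhanced = gen(low_quality)
    print(enhanced.shape)  # torch.Size([4, 1, 44100]), identical to the input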
training.py (28 changed lines)
@@ -28,14 +28,9 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
     discriminator_decision_from_real = discriminator(high_quality[0])
     d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)

-    integer_scale = math.ceil(high_quality[1]/low_quality[1])
-
     # Forward pass for fake samples (from generator output)
-    generator_output = generator(low_quality[0], integer_scale)
-    resample_transform = torchaudio.transforms.Resample(low_quality[1] * integer_scale, high_quality[1]).to(device)
-    resampled = resample_transform(generator_output.detach())
-
-    discriminator_decision_from_fake = discriminator(resampled)
+    generator_output = generator(low_quality[0])
+    discriminator_decision_from_fake = discriminator(generator_output.detach())
     d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)

     # Combine real and fake losses
@@ -48,22 +43,17 @@ def discriminator_train(high_quality, low_quality, real_labels, fake_labels):

     return d_loss

-def generator_train(low_quality, real_labels, target_sample_rate=44100):
+def generator_train(low_quality, real_labels):
     optimizer_g.zero_grad()

-    scale = math.ceil(target_sample_rate/low_quality[1])
-
     # Forward pass for fake samples (from generator output)
-    generator_output = generator(low_quality[0], scale)
-    resample_transform = torchaudio.transforms.Resample(low_quality[1] * scale, target_sample_rate).to(device)
-    resampled = resample_transform(generator_output)
-
-    discriminator_decision = discriminator(resampled)
+    generator_output = generator(low_quality[0])
+    discriminator_decision = discriminator(generator_output)
     g_loss = criterion_g(discriminator_decision, real_labels)

     g_loss.backward()
     optimizer_g.step()
-    return resampled
+    return generator_output

 # Init script argument parser
 parser = argparse.ArgumentParser(description="Training script")
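Note on the training simplification: with a same-rate generator, both functions can feed the generator output straight to the discriminator, so the math.ceil scale computation and the torchaudio resampling steps disappear. A condensed, hypothetical sketch of how the two functions now compose per batch (optimizer_d and the loss-combining step are assumed from the surrounding code, not shown in this diff):

    def train_step(high_quality, low_quality, real_labels, fake_labels):
        # Discriminator: real clips vs. detached generator output, no resampling step.
        optimizer_d.zero_grad()
        d_loss = criterion_d(discriminator(high_quality[0]), real_labels) \
               + criterion_d(discriminator(generator(low_quality[0]).detach()), fake_labels)
        d_loss.backward()
        optimizer_d.step()

        # Generator: push the discriminator's decision toward "real".
        optimizer_g.zero_grad()
        g_loss = criterion_g(discriminator(generator(low_quality[0])), real_labels)
        g_loss.backward()
        optimizer_g.step()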
@@ -110,7 +100,7 @@ generator = generator.to(device)
 discriminator = discriminator.to(device)

 # Loss
-criterion_g = nn.L1Loss()
+criterion_g = nn.MSELoss()
 criterion_d = nn.BCELoss()

 # Optimizers
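Note on the loss swap: with criterion_g as MSE against real_labels, the generator objective becomes least-squares-style on the discriminator's sigmoid scores, penalizing confidently-rejected fakes quadratically rather than linearly. A toy comparison with made-up scores:

    import torch
    import torch.nn as nn

    scores = torch.tensor([0.1, 0.4, 0.9])  # discriminator outputs for generated audio
    real_labels = torch.ones(3)
    print(nn.MSELoss()(scores, real_labels))  # 0.3933 = (0.81 + 0.36 + 0.01) / 3
    print(nn.L1Loss()(scores, real_labels))   # 0.5333 = (0.9 + 0.6 + 0.1) / 3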
@@ -172,7 +162,7 @@ def start_training():

         # ========= GENERATOR =========
         generator.train()
-        generator_output = generator_train(low_quality_sample, real_labels, high_quality_sample[1])
+        generator_output = generator_train(low_quality_sample, real_labels)

         # ========= SAVE LATEST AUDIO =========
         high_quality_audio = high_quality_clip
@@ -185,7 +175,7 @@ def start_training():

         if generator_epoch % 10 == 0:
             print(f"Saved epoch {generator_epoch}!")
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), low_quality_audio[1])
+            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- because the clip was already resampled in data.py from whatever rate low_quality had back up to high_quality's rate
             torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
             torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])
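Note on the save fix: torchaudio.save does not resample; it just stamps the given rate into the file header. Since the low-quality clip was already resampled back up in data.py, writing it with low_quality_audio[1] (11025) would mislabel 44.1 kHz samples and make playback roughly 4x slow and pitched down. A minimal illustration (path and tensor are illustrative):

    import torch
    import torchaudio

    waveform = torch.randn(1, 44100)  # one second of audio at 44.1 kHz
    torchaudio.save("./output/example.wav", waveform, 44100)   # correct header
    # torchaudio.save("./output/example.wav", waveform, 11025) # would play ~4x slow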