⚗️ | Experimenting still...

NikkeDoy 2024-12-25 00:09:57 +02:00
parent 1000692f32
commit eca71ff5ea
6 changed files with 167 additions and 149 deletions

AudioUtils.py (new file)

@@ -0,0 +1,18 @@
import torch
import torch.nn.functional as F

def stereo_tensor_to_mono(waveform):
    if waveform.shape[0] > 1:
        # Average across channels
        mono_waveform = torch.mean(waveform, dim=0, keepdim=True)
    else:
        # Already mono
        mono_waveform = waveform
    return mono_waveform

def stretch_tensor(tensor, target_length):
    scale_factor = target_length / tensor.size(1)
    # F.interpolate with mode='linear' expects a 3D (batch, channels, length)
    # input, so add a batch dimension for the call and drop it afterwards.
    tensor = F.interpolate(tensor.unsqueeze(0), scale_factor=scale_factor, mode='linear', align_corners=False).squeeze(0)
    return tensor
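
A quick sanity check of the two helpers (a sketch, not part of the commit; shapes assume one second of 44.1 kHz audio):

import torch
import AudioUtils

stereo = torch.randn(2, 44100)                      # (channels, samples)
mono = AudioUtils.stereo_tensor_to_mono(stereo)
print(mono.shape)                                   # torch.Size([1, 44100])

stretched = AudioUtils.stretch_tensor(mono, 88200)  # target twice the length
print(stretched.shape)                              # torch.Size([1, 88200]) for an exact 2x ratio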

data.py

@@ -1,49 +1,31 @@
 from torch.utils.data import Dataset
 import torch.nn.functional as F
-import torch
 import torchaudio
 import os
 import random
-import torchaudio.transforms as T
+import AudioUtils

 class AudioDataset(Dataset):
-    audio_sample_rates = [8000, 11025, 16000, 22050]
+    #audio_sample_rates = [8000, 11025, 16000, 22050]
+    audio_sample_rates = [11025]

-    def __init__(self, input_dir, target_duration=None, padding_mode='constant', padding_value=0.0):
+    def __init__(self, input_dir):
         self.input_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.wav')]
-        self.target_duration = target_duration # Duration in seconds or None if not set
-        self.padding_mode = padding_mode
-        self.padding_value = padding_value

     def __len__(self):
         return len(self.input_files)

     def __getitem__(self, idx):
-        # Load high-quality audio
         high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True)

-        # Generate low-quality audio with random downsampling
         mangled_sample_rate = random.choice(self.audio_sample_rates)
         resample_transform = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
         low_quality_audio = resample_transform(high_quality_audio)

-        # Calculate target length based on desired duration and 16000 Hz
-        # if self.target_duration is not None:
-        #     target_length = int(self.target_duration * 44100)
-        # else:
-        #     # Calculate duration of original high quality audio
-        #     target_length = high_quality_wav.size(1)
-
-        # Pad both to the calculated target length
-        # high_quality_wav = self.stretch_tensor(high_quality_wav, target_length)
-        # low_quality_wav = self.stretch_tensor(low_quality_wav, target_length)
-
-        return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
-
-    def stretch_tensor(self, tensor, target_length):
-        current_length = tensor.size(1)
-        scale_factor = target_length / current_length
-
-        # Resample the tensor using linear interpolation
-        tensor = F.interpolate(tensor.unsqueeze(0), scale_factor=scale_factor, mode='linear', align_corners=False).squeeze(0)
-        return tensor
+        return (AudioUtils.stereo_tensor_to_mono(high_quality_audio), original_sample_rate), (AudioUtils.stereo_tensor_to_mono(low_quality_audio), mangled_sample_rate)
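
The dataset now yields (tensor, sample_rate) pairs for both versions, converted to mono through AudioUtils. A minimal usage sketch (shapes depend on the clip; the values below assume a 5-second 44.1 kHz file):

from data import AudioDataset

dataset = AudioDataset("./dataset/good")
(hq, hq_rate), (lq, lq_rate) = dataset[0]
print(hq.shape, hq_rate)  # e.g. torch.Size([1, 220500]) 44100 (mono, original rate)
print(lq.shape, lq_rate)  # e.g. torch.Size([1, 55125]) 11025 (mono, downsampled)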

discriminator.py

@@ -3,22 +3,28 @@ import torch.nn as nn
 class SISUDiscriminator(nn.Module):
     def __init__(self):
         super(SISUDiscriminator, self).__init__()
+        layers = 32
         self.model = nn.Sequential(
-            nn.Conv1d(2, 128, kernel_size=3, padding=1),
-            #nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 256, kernel_size=3, padding=1),
+            nn.Conv1d(1, layers, kernel_size=5, stride=2, padding=2),
+            nn.BatchNorm1d(layers),
             nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(256, 128, kernel_size=3, padding=1),
-            #nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 64, kernel_size=3, padding=1),
-            #nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(64, 1, kernel_size=3, padding=1),
-            #nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(layers, layers * 2, kernel_size=5, stride=2, padding=2),
+            nn.BatchNorm1d(layers * 2),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(layers * 2, layers * 4, kernel_size=5, stride=2, padding=2),
+            nn.BatchNorm1d(layers * 4),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(layers * 4, layers * 8, kernel_size=5, stride=2, padding=2),
+            nn.BatchNorm1d(layers * 8),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(layers * 8, 1, kernel_size=3, padding=1),
         )
-        self.global_avg_pool = nn.AdaptiveAvgPool1d(1) # Output size (1,)
+        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
+        self.sigmoid = nn.Sigmoid()

     def forward(self, x):
         x = self.model(x)
         x = self.global_avg_pool(x)
-        x = x.view(-1, 1) # Flatten to (batch_size, 1)
+        x = x.view(-1, 1)
+        x = self.sigmoid(x)
         return x
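
With the added Sigmoid and the global average pooling, the discriminator now maps mono input of any length to one per-clip probability, which matches the switch from BCEWithLogitsLoss to BCELoss in training.py. A shape probe (a sketch, assuming a batch of four 1-second clips):

import torch
from discriminator import SISUDiscriminator

disc = SISUDiscriminator()
x = torch.randn(4, 1, 44100)           # (batch, channels=1, samples)
out = disc(x)
print(out.shape)                       # torch.Size([4, 1])
print(out.min() >= 0, out.max() <= 1)  # both True, thanks to the final Sigmoid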

generator.py

@@ -1,39 +1,32 @@
 import torch.nn as nn

 class SISUGenerator(nn.Module):
-    def __init__(self, upscale_scale=1):
+    def __init__(self, upscale_scale=4): # No noise_dim parameter
         super(SISUGenerator, self).__init__()
-        self.layers1 = nn.Sequential(
-            nn.Conv1d(2, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True), # Activation
-            nn.BatchNorm1d(128), # Batch Norm
-            nn.Conv1d(128, 256, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True), # Activation
-            nn.BatchNorm1d(256), # Batch Norm
+        layer = 32
+        # Convolution layers
+        self.conv1 = nn.Sequential(
+            nn.Conv1d(1, layer * 2, kernel_size=7, padding=1),
+            nn.PReLU(),
+            nn.Conv1d(layer * 2, layer * 5, kernel_size=5, padding=1),
+            nn.PReLU(),
+            nn.Conv1d(layer * 5, layer * 5, kernel_size=3, padding=1),
+            nn.PReLU()
         )
-        self.layers2 = nn.Sequential(
-            nn.Conv1d(256, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True), # Activation
-            nn.BatchNorm1d(128), # Batch Norm
-            nn.Conv1d(128, 64, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True), # Activation
-            nn.BatchNorm1d(64), # Batch Norm
-            nn.Conv1d(64, upscale_scale * 2, kernel_size=3, padding=1), # Output channels scaled
+
+        # Transposed convolution for upsampling
+        self.upsample = nn.ConvTranspose1d(layer * 5, layer * 5, kernel_size=upscale_scale, stride=upscale_scale)
+
+        self.conv2 = nn.Sequential(
+            nn.Conv1d(layer * 5, layer * 5, kernel_size=3, padding=1),
+            nn.PReLU(),
+            nn.Conv1d(layer * 5, layer * 2, kernel_size=5, padding=1),
+            nn.PReLU(),
+            nn.Conv1d(layer * 2, 1, kernel_size=7, padding=1)
         )
-        self.upscale_factor = upscale_scale

-    def pixel_shuffle_1d(self, input, upscale_factor):
-        batch_size, channels, in_width = input.size()
-        out_width = in_width * upscale_factor
-        input_view = input.contiguous().view(batch_size, channels // upscale_factor, upscale_factor, in_width)
-        shuffle_out = input_view.permute(0, 1, 3, 2).contiguous()
-        return shuffle_out.view(batch_size, channels // upscale_factor, out_width)
-
-    def forward(self, x, scale):
-        x = self.layers1(x)
-        upsample = nn.Upsample(scale_factor=scale, mode='nearest')
-        x = upsample(x)
-        x = self.layers2(x)
-        x = self.pixel_shuffle_1d(x, self.upscale_factor)
+    def forward(self, x, upscale_scale=4):
+        x = self.conv1(x)
+        x = self.upsample(x)
+        x = self.conv2(x)
         return x
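
Two things worth noting about the new generator, probed in the sketch below: the kernel_size=7/padding=1 and kernel_size=5/padding=1 convolutions are not length-preserving (they trim 4 and 2 samples respectively), so the output comes out slightly shorter than input_length * 4; and the upscale_scale argument of forward is unused, since the ConvTranspose1d stride is fixed in __init__.

import torch
from generator import SISUGenerator

gen = SISUGenerator(upscale_scale=4)
x = torch.randn(1, 1, 11025)  # 1 second of mono audio at 11025 Hz
y = gen(x, 4)
print(y.shape)  # torch.Size([1, 1, 44070]): conv1 trims 6 samples, the transposed
                # conv multiplies by 4, conv2 trims 6 more (11019 * 4 - 6 = 44070)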

requirements.txt

@@ -1,12 +1,14 @@
-filelock>=3.16.1
-fsspec>=2024.10.0
-Jinja2>=3.1.4
-MarkupSafe>=2.1.5
-mpmath>=1.3.0
-networkx>=3.4.2
-numpy>=2.1.2
-pillow>=11.0.0
-setuptools>=70.2.0
-sympy>=1.13.1
-tqdm>=4.67.1
-typing_extensions>=4.12.2
+filelock==3.16.1
+fsspec==2024.10.0
+Jinja2==3.1.4
+MarkupSafe==2.1.5
+mpmath==1.3.0
+networkx==3.4.2
+numpy==2.2.1
+pytorch-triton-rocm==3.2.0+git0d4682f0
+setuptools==70.2.0
+sympy==1.13.1
+torch==2.6.0.dev20241222+rocm6.2.4
+torchaudio==2.6.0.dev20241222+rocm6.2.4
+tqdm==4.67.1
+typing_extensions==4.12.2
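
The pins now target a PyTorch nightly ROCm build, which is not published on PyPI, so installing presumably needs the PyTorch nightly wheel index as an extra source. A sketch (the exact index URL is an assumption inferred from the +rocm6.2.4 suffix):

pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2.4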

training.py

@@ -6,66 +6,73 @@ import torch.nn.functional as F
 import torchaudio
 import tqdm
+import argparse
+import math

 from torch.utils.data import random_split
 from torch.utils.data import DataLoader

+import AudioUtils
 from data import AudioDataset
 from generator import SISUGenerator
 from discriminator import SISUDiscriminator

-# Mel Spectrogram Loss
-class MelSpectrogramLoss(nn.Module):
-    def __init__(self, sample_rate=44100, n_fft=2048, hop_length=512, n_mels=128):
-        super(MelSpectrogramLoss, self).__init__()
-        self.mel_transform = torchaudio.transforms.MelSpectrogram(
-            sample_rate=sample_rate,
-            n_fft=n_fft,
-            hop_length=hop_length,
-            n_mels=n_mels
-        ).to(device) # Move to device
-
-    def forward(self, y_pred, y_true):
-        mel_pred = self.mel_transform(y_pred)
-        mel_true = self.mel_transform(y_true)
-        return F.l1_loss(mel_pred, mel_true)
-
-def snr(y_true, y_pred):
-    noise = y_true - y_pred
-    signal_power = torch.mean(y_true ** 2)
-    noise_power = torch.mean(noise ** 2)
-    snr_db = 10 * torch.log10(signal_power / noise_power)
-    return snr_db
+def perceptual_loss(y_true, y_pred):
+    return torch.mean((y_true - y_pred) ** 2)

-def discriminator_train(high_quality, low_quality, scale, real_labels, fake_labels):
+def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
     optimizer_d.zero_grad()

-    discriminator_decision_from_real = discriminator(high_quality)
-    # TODO: Experiment with criterions HERE!
+    # Forward pass for real samples
+    discriminator_decision_from_real = discriminator(high_quality[0])
     d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)

-    generator_output = generator(low_quality, scale)
-    discriminator_decision_from_fake = discriminator(generator_output.detach())
-    # TODO: Experiment with criterions HERE!
+    integer_scale = math.ceil(high_quality[1]/low_quality[1])
+
+    # Forward pass for fake samples (from generator output)
+    generator_output = generator(low_quality[0], integer_scale)
+    resample_transform = torchaudio.transforms.Resample(low_quality[1] * integer_scale, high_quality[1]).to(device)
+    resampled = resample_transform(generator_output.detach())
+    discriminator_decision_from_fake = discriminator(resampled)
     d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)

+    # Combine real and fake losses
     d_loss = (d_loss_real + d_loss_fake) / 2.0

+    # Backward pass and optimization
     d_loss.backward()
-    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) #Gradient Clipping
+    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
     optimizer_d.step()

     return d_loss

-def generator_train(low_quality, scale, real_labels):
+def generator_train(low_quality, real_labels, target_sample_rate=44100):
     optimizer_g.zero_grad()

-    generator_output = generator(low_quality, scale)
-    discriminator_decision = discriminator(generator_output)
-    # TODO: Fix this shit
+    scale = math.ceil(target_sample_rate/low_quality[1])
+
+    # Forward pass for fake samples (from generator output)
+    generator_output = generator(low_quality[0], scale)
+    resample_transform = torchaudio.transforms.Resample(low_quality[1] * scale, target_sample_rate).to(device)
+    resampled = resample_transform(generator_output)
+    discriminator_decision = discriminator(resampled)
     g_loss = criterion_g(discriminator_decision, real_labels)

     g_loss.backward()
     optimizer_g.step()

-    return generator_output
+    return resampled

+# Init script argument parser
+parser = argparse.ArgumentParser(description="Training script")
+parser.add_argument("--generator", type=str, default=None,
+                    help="Path to the generator model file")
+parser.add_argument("--discriminator", type=str, default=None,
+                    help="Path to the discriminator model file")
+
+args = parser.parse_args()

 # Check for CUDA availability
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
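
Both helpers share the same resampling trick: the generator can only upsample by an integer factor, so the factor is rounded up with math.ceil and torchaudio's Resample then corrects the nominal rate to the exact target. A worked sketch of the arithmetic (plan_upscale is a hypothetical helper, for illustration only):

import math

def plan_upscale(lq_rate, hq_rate):
    integer_scale = math.ceil(hq_rate / lq_rate)  # factor the generator applies
    nominal_rate = lq_rate * integer_scale        # rate the generator output is treated as
    return integer_scale, nominal_rate, hq_rate   # Resample(nominal_rate, hq_rate) fixes the rest

print(plan_upscale(11025, 44100))  # (4, 44100, 44100) -> Resample is a no-op
print(plan_upscale(16000, 44100))  # (3, 48000, 44100) -> 48 kHz resampled down to 44.1 kHz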
@@ -73,28 +80,38 @@ print(f"Using device: {device}")
 # Initialize dataset and dataloader
 dataset_dir = './dataset/good'
-dataset = AudioDataset(dataset_dir, target_duration=2.0)
+dataset = AudioDataset(dataset_dir)

-dataset_size = len(dataset)
-train_size = int(dataset_size * .9)
-val_size = int(dataset_size-train_size)
-
-train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
-
-train_data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
-val_data_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)
+# ========= MULTIPLE =========
+
+# dataset_size = len(dataset)
+# train_size = int(dataset_size * .9)
+# val_size = int(dataset_size-train_size)
+
+#train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+
+# train_data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
+# val_data_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)
+
+# ========= SINGLE =========
+
+train_data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

 # Initialize models and move them to device
 generator = SISUGenerator()
 discriminator = SISUDiscriminator()

+if args.generator is not None:
+    generator.load_state_dict(torch.load(args.generator, weights_only=True))
+if args.discriminator is not None:
+    discriminator.load_state_dict(torch.load(args.discriminator, weights_only=True))
+
 generator = generator.to(device)
 discriminator = discriminator.to(device)

 # Loss
 criterion_g = nn.L1Loss()
-criterion_g_mel = MelSpectrogramLoss().to(device)
-criterion_d = nn.BCEWithLogitsLoss()
+criterion_d = nn.BCELoss()

 # Optimizers
 optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
@@ -109,39 +126,40 @@ def start_training():
     # Training loop

     # ========= DISCRIMINATOR PRE-TRAINING =========
-    discriminator_epochs = 1
-    for discriminator_epoch in range(discriminator_epochs):
-        # ========= TRAINING =========
-        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {discriminator_epoch+1}/{discriminator_epochs}"):
-            high_quality_sample = high_quality_clip[0].to(device)
-            low_quality_sample = low_quality_clip[0].to(device)
-            scale = high_quality_clip[0].shape[2]/low_quality_clip[0].shape[2]
-
-            # ========= LABELS =========
-            batch_size = high_quality_sample.size(0)
-            real_labels = torch.ones(batch_size, 1).to(device)
-            fake_labels = torch.zeros(batch_size, 1).to(device)
-
-            # ========= DISCRIMINATOR =========
-            discriminator.train()
-            discriminator_train(high_quality_sample, low_quality_sample, scale, real_labels, fake_labels)
-
-    torch.save(discriminator.state_dict(), "models/discriminator-single-shot-pre-train.pt")
+    # discriminator_epochs = 1
+    # for discriminator_epoch in range(discriminator_epochs):
+    #     # ========= TRAINING =========
+    #     for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {discriminator_epoch+1}/{discriminator_epochs}"):
+    #         high_quality_sample = high_quality_clip[0].to(device)
+    #         low_quality_sample = low_quality_clip[0].to(device)
+    #         scale = high_quality_clip[0].shape[2]/low_quality_clip[0].shape[2]
+
+    #         # ========= LABELS =========
+    #         batch_size = high_quality_sample.size(0)
+    #         real_labels = torch.ones(batch_size, 1).to(device)
+    #         fake_labels = torch.zeros(batch_size, 1).to(device)
+
+    #         # ========= DISCRIMINATOR =========
+    #         discriminator.train()
+    #         discriminator_train(high_quality_sample, low_quality_sample, scale, real_labels, fake_labels)
+
+    # torch.save(discriminator.state_dict(), "models/discriminator-single-shot-pre-train.pt")

-    generator_epochs = 500
+    generator_epochs = 5000
     for generator_epoch in range(generator_epochs):
         low_quality_audio = (torch.empty((1)), 1)
         high_quality_audio = (torch.empty((1)), 1)
         ai_enhanced_audio = (torch.empty((1)), 1)

-        times_correct = 0
-
         # ========= TRAINING =========
         for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {generator_epoch+1}/{generator_epochs}"):
-            high_quality_sample = high_quality_clip[0].to(device)
-            low_quality_sample = low_quality_clip[0].to(device)
-            scale = high_quality_clip[0].shape[2]/low_quality_clip[0].shape[2]
+        # for high_quality_clip, low_quality_clip in train_data_loader:
+            high_quality_sample = (high_quality_clip[0].to(device), high_quality_clip[1])
+            low_quality_sample = (low_quality_clip[0].to(device), low_quality_clip[1])

             # ========= LABELS =========
             batch_size = high_quality_clip[0].size(0)
@@ -150,21 +168,20 @@ def start_training():
             # ========= DISCRIMINATOR =========
             discriminator.train()
-            for _ in range(3):
-                discriminator_train(high_quality_sample, low_quality_sample, scale, real_labels, fake_labels)
+            discriminator_train(high_quality_sample, low_quality_sample, real_labels, fake_labels)

             # ========= GENERATOR =========
             generator.train()
-            generator_output = generator_train(low_quality_sample, scale, real_labels)
+            generator_output = generator_train(low_quality_sample, real_labels, high_quality_sample[1])

             # ========= SAVE LATEST AUDIO =========
             high_quality_audio = high_quality_clip
             low_quality_audio = low_quality_clip
             ai_enhanced_audio = (generator_output, high_quality_clip[1])

-            metric = snr(high_quality_audio[0].to(device), ai_enhanced_audio[0])
-            print(f"Generator metric {metric}!")
-            scheduler_g.step(metric)
+            #metric = snr(high_quality_audio[0].to(device), ai_enhanced_audio[0])
+            #print(f"Generator metric {metric}!")
+            #scheduler_g.step(metric)

         if generator_epoch % 10 == 0:
             print(f"Saved epoch {generator_epoch}!")