8 Commits

SHA1 Message Date
1717e7a008 ⚗️ | Experimenting... 2025-02-10 19:35:50 +02:00
fb7b624c87 ⚗️ | Experimenting with very small model. 2025-02-10 12:44:42 +02:00
0790a0d3da ⚗️ | Experimenting with smaller architecture. 2025-01-25 16:48:10 +02:00
f615b39ded ⚗️ | Experimenting with larger model architecture. 2025-01-08 15:33:18 +02:00
89f8c68986 ⚗️ | Experimenting, again. 2024-12-26 04:00:24 +02:00
2ff45de22d 🔥 | Removed unnecessary test file. 2024-12-25 00:10:45 +02:00
eca71ff5ea ⚗️ | Experimenting still... 2024-12-25 00:09:57 +02:00
1000692f32 ⚗️ | Experimenting with other generator architectures. 2024-12-21 23:54:11 +02:00
7 changed files with 192 additions and 170 deletions

AudioUtils.py — new file (+18 lines)

import torch
import torch.nn.functional as F


def stereo_tensor_to_mono(waveform):
    if waveform.shape[0] > 1:
        # Average across channels
        mono_waveform = torch.mean(waveform, dim=0, keepdim=True)
    else:
        # Already mono
        mono_waveform = waveform
    return mono_waveform


def stretch_tensor(tensor, target_length):
    scale_factor = target_length / tensor.size(1)
    # F.interpolate with mode='linear' expects a (batch, channels, length) tensor,
    # so add a batch dimension and remove it again.
    tensor = F.interpolate(tensor.unsqueeze(0), scale_factor=scale_factor,
                           mode='linear', align_corners=False).squeeze(0)
    return tensor

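A quick sketch of how the new helpers behave, using a synthetic stereo tensor instead of a real file:

import torch
from AudioUtils import stereo_tensor_to_mono, stretch_tensor

# Fake stereo clip: 2 channels, 1 second at 44.1 kHz
stereo = torch.randn(2, 44100)

mono = stereo_tensor_to_mono(stereo)
print(mono.shape)        # torch.Size([1, 44100])

stretched = stretch_tensor(mono, target_length=88200)
print(stretched.shape)   # torch.Size([1, 88200])
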
data.py — changed (49 → 52 lines). The updated version is shown below. Compared with the previous version: .wav files are now discovered recursively, clips are collapsed to mono, a single 11025 Hz mangling rate replaces the list [8000, 11025, 16000, 22050], the low-quality clip is resampled back to the original rate, and the fixed 44100-sample pad/truncate step was removed from __getitem__.

from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
import torchaudio
import os
import random

from AudioUtils import stereo_tensor_to_mono, stretch_tensor


class AudioDataset(Dataset):
    audio_sample_rates = [11025]

    def __init__(self, input_dir):
        self.input_files = [
            os.path.join(root, f)
            for root, _, files in os.walk(input_dir)
            for f in files if f.endswith('.wav')
        ]

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        # Load high-quality audio and collapse it to mono
        high_quality_path = self.input_files[idx]
        high_quality_audio, original_sample_rate = torchaudio.load(high_quality_path)
        high_quality_audio = stereo_tensor_to_mono(high_quality_audio)

        # Generate low-quality audio with random downsampling,
        # then resample back up so both clips share the original rate
        mangled_sample_rate = random.choice(self.audio_sample_rates)
        resample_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
        low_quality_audio = resample_low(high_quality_audio)

        resample_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
        low_quality_audio = resample_high(low_quality_audio)

        return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)

    def pad_or_truncate(self, tensor, target_length):
        # Kept from the previous version; no longer called from __getitem__.
        current_length = tensor.size(1)
        if current_length < target_length:
            # Pad with zeros on the right
            padding = target_length - current_length
            tensor = F.pad(tensor, (0, padding))
        else:
            # Truncate to the target length
            tensor = tensor[:, :target_length]
        return tensor

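A small usage sketch for the updated dataset; it assumes the ./dataset/good directory referenced by the training script exists and contains .wav files:

from data import AudioDataset

dataset = AudioDataset("./dataset/good")  # same directory the training script uses

(high_quality, original_rate), (low_quality, mangled_rate) = dataset[0]
print(high_quality.shape, original_rate)   # (1, num_samples) and the file's native rate
print(low_quality.shape, mangled_rate)     # roughly the same length, mangled at 11025 Hz
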
discriminator.py — changed (24 → 38 lines). The updated version is shown below. The previous plain Conv1d stack on 2-channel input was replaced by spectral-normalised, batch-normalised blocks on mono (1-channel) input with strides and dilations.

import torch
import torch.nn as nn
import torch.nn.utils as utils


def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
    padding = (kernel_size // 2) * dilation
    return nn.Sequential(
        utils.spectral_norm(
            nn.Conv1d(in_channels, out_channels,
                      kernel_size=kernel_size,
                      stride=stride,
                      dilation=dilation,
                      padding=padding)
        ),
        nn.BatchNorm1d(out_channels),
        nn.LeakyReLU(0.2, inplace=True)
    )


class SISUDiscriminator(nn.Module):
    def __init__(self):
        super(SISUDiscriminator, self).__init__()
        layers = 4
        self.model = nn.Sequential(
            discriminator_block(1, layers, kernel_size=7, stride=2, dilation=1),
            discriminator_block(layers, layers * 2, kernel_size=5, stride=2, dilation=1),
            discriminator_block(layers * 2, layers * 4, kernel_size=3, dilation=4),
            discriminator_block(layers * 4, layers * 4, kernel_size=5, dilation=8),
            discriminator_block(layers * 4, layers * 2, kernel_size=3, dilation=16),
            discriminator_block(layers * 2, layers, kernel_size=5, dilation=2),
            discriminator_block(layers, 1, kernel_size=3, stride=1)
        )
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        x = self.model(x)
        x = self.global_avg_pool(x)
        return x.view(-1, 1)

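A shape check in the spirit of the deleted test.py, adjusted to the new single-channel input; the batch size and clip length are arbitrary:

import torch
from discriminator import SISUDiscriminator

discriminator = SISUDiscriminator()
test_input = torch.randn(4, 1, 44100)  # (batch_size, channels=1, samples)
output = discriminator(test_input)
print(output.shape)  # torch.Size([4, 1])
# Note: there is no final sigmoid, so these scores are unbounded real numbers.
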
generator.py — changed (27 → 41 lines). The updated version is shown below. The previous Conv1d stack with an explicit nn.Upsample(scale) step was replaced by dilated residual blocks on mono input with a global skip connection; the forward pass no longer takes a scale argument.

import torch.nn as nn


def conv_residual_block(in_channels, out_channels, kernel_size=3, dilation=1):
    padding = (kernel_size // 2) * dilation
    return nn.Sequential(
        nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=padding),
        nn.BatchNorm1d(out_channels),
        nn.PReLU(),
        nn.Conv1d(out_channels, out_channels, kernel_size, dilation=dilation, padding=padding),
        nn.BatchNorm1d(out_channels)
    )


class SISUGenerator(nn.Module):
    def __init__(self):
        super(SISUGenerator, self).__init__()
        layers = 4
        self.conv1 = nn.Sequential(
            nn.Conv1d(1, layers, kernel_size=7, padding=3),
            nn.BatchNorm1d(layers),
            nn.PReLU()
        )
        self.conv_blocks = nn.Sequential(
            conv_residual_block(layers, layers, kernel_size=3, dilation=1),
            conv_residual_block(layers, layers * 2, kernel_size=5, dilation=2),
            conv_residual_block(layers * 2, layers * 4, kernel_size=3, dilation=16),
            conv_residual_block(layers * 4, layers * 2, kernel_size=5, dilation=8),
            conv_residual_block(layers * 2, layers, kernel_size=5, dilation=2),
            conv_residual_block(layers, layers, kernel_size=3, dilation=1)
        )
        self.final_layer = nn.Sequential(
            nn.Conv1d(layers, 1, kernel_size=3, padding=1)
        )

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv_blocks(x) + x  # Adding residual connection after blocks
        x = self.final_layer(x)
        return x + residual

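The analogous check for the generator; since every convolution keeps the temporal length and the input is added back as a global residual, the output shape should match the input:

import torch
from generator import SISUGenerator

generator = SISUGenerator()
low_quality = torch.randn(4, 1, 44100)  # (batch_size, channels=1, samples)
enhanced = generator(low_quality)
print(enhanced.shape)  # torch.Size([4, 1, 44100])
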
requirements.txt — changed (12 → 14 lines). The updated version is shown below. Version ranges were replaced with exact pins, pillow was dropped, numpy was bumped to 2.2.1, and ROCm nightly builds of torch, torchaudio and pytorch-triton-rocm were added.

filelock==3.16.1
fsspec==2024.10.0
Jinja2==3.1.4
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.4.2
numpy==2.2.1
pytorch-triton-rocm==3.2.0+git0d4682f0
setuptools==70.2.0
sympy==1.13.1
torch==2.6.0.dev20241222+rocm6.2.4
torchaudio==2.6.0.dev20241222+rocm6.2.4
tqdm==4.67.1
typing_extensions==4.12.2

test.py — deleted (−10 lines). The removed file:

import torch.nn as nn
import torch
from discriminator import SISUDiscriminator
discriminator = SISUDiscriminator()
test_input = torch.randn(1, 2, 1000) # Example input (batch_size, channels, frames)
output = discriminator(test_input)
print(output)
print("Output shape:", output.shape)

Training script — changed. The updated version of the changed hunks is shown below (lines outside the @@ ranges are unchanged and omitted). Compared with the previous version: the Mel-spectrogram loss, the SNR metric, the 90/10 train/validation split, and the one-epoch discriminator pre-training pass were removed; argparse-based loading of saved generator/discriminator weights, a batch size of 16, and a 5000-epoch generator loop were added.

@@ -6,95 +6,96 @@
import torchaudio
import tqdm

import argparse
import math

from torch.utils.data import random_split
from torch.utils.data import DataLoader

import AudioUtils
from data import AudioDataset
from generator import SISUGenerator
from discriminator import SISUDiscriminator

def perceptual_loss(y_true, y_pred):
    # Note: defined here but not used in the training loop below.
    return torch.mean((y_true - y_pred) ** 2)

def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
    optimizer_d.zero_grad()

    # Forward pass for real samples
    discriminator_decision_from_real = discriminator(high_quality[0])
    d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)

    # Forward pass for fake samples (from generator output)
    generator_output = generator(low_quality[0])
    discriminator_decision_from_fake = discriminator(generator_output.detach())
    d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)

    # Combine real and fake losses
    d_loss = (d_loss_real + d_loss_fake) / 2.0

    # Backward pass and optimization
    d_loss.backward()
    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)  # Gradient clipping
    optimizer_d.step()

    return d_loss

def generator_train(low_quality, real_labels):
    optimizer_g.zero_grad()

    # Forward pass for fake samples (from generator output)
    generator_output = generator(low_quality[0])
    discriminator_decision = discriminator(generator_output)
    g_loss = criterion_g(discriminator_decision, real_labels)

    g_loss.backward()
    optimizer_g.step()

    return generator_output

def first(objects):
    if len(objects) >= 1:
        return objects[0]
    return objects

# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
                    help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
                    help="Path to the discriminator model file")
args = parser.parse_args()

# Check for CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize dataset and dataloader
dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir)

# ========= SINGLE =========
train_data_loader = DataLoader(dataset, batch_size=16, shuffle=True)

# Initialize models and move them to device
generator = SISUGenerator()
discriminator = SISUDiscriminator()

if args.generator is not None:
    generator.load_state_dict(torch.load(args.generator, weights_only=True))
if args.discriminator is not None:
    discriminator.load_state_dict(torch.load(args.discriminator, weights_only=True))

generator = generator.to(device)
discriminator = discriminator.to(device)

# Loss
criterion_g = nn.MSELoss()
criterion_d = nn.BCELoss()  # Note: the discriminator output has no sigmoid, so BCELoss can receive values outside [0, 1].

# Optimizers
optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))

@@ -105,43 +106,19 @@
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)

def start_training():
    generator_epochs = 5000
    for generator_epoch in range(generator_epochs):
        low_quality_audio = (torch.empty((1)), 1)
        high_quality_audio = (torch.empty((1)), 1)
        ai_enhanced_audio = (torch.empty((1)), 1)

        # ========= TRAINING =========
        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {generator_epoch+1}/{generator_epochs}"):
            # for high_quality_clip, low_quality_clip in train_data_loader:
            high_quality_sample = (high_quality_clip[0].to(device), high_quality_clip[1])
            low_quality_sample = (low_quality_clip[0].to(device), low_quality_clip[1])

            # ========= LABELS =========
            batch_size = high_quality_clip[0].size(0)

@@ -150,34 +127,38 @@
            # ========= DISCRIMINATOR =========
            discriminator.train()
            discriminator_train(high_quality_sample, low_quality_sample, real_labels, fake_labels)

            # ========= GENERATOR =========
            generator.train()
            generator_output = generator_train(low_quality_sample, real_labels)

            # ========= SAVE LATEST AUDIO =========
            high_quality_audio = (first(high_quality_clip[0]), high_quality_clip[1][0])
            low_quality_audio = (first(low_quality_clip[0]), low_quality_clip[1][0])
            ai_enhanced_audio = (first(generator_output[0]), high_quality_clip[1][0])

        print(f"Saved epoch {generator_epoch}!")
        torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1])  # <-- Because audio clip was resampled in data.py from original to crap and to original again.
        torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
        torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])

        #metric = snr(high_quality_audio[0].to(device), ai_enhanced_audio[0])
        #print(f"Generator metric {metric}!")
        #scheduler_g.step(metric)

        if generator_epoch % 10 == 0:
            print(f"Saved epoch {generator_epoch}!")
            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1])  # <-- Because audio clip was resampled in data.py from original to crap and to original again.
            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])

            torch.save(discriminator.state_dict(), f"models/current-epoch-discriminator.pt")
            torch.save(generator.state_dict(), f"models/current-epoch-generator.pt")

    torch.save(discriminator.state_dict(), "models/epoch-5000-discriminator.pt")
    torch.save(generator.state_dict(), "models/epoch-5000-generator.pt")
    print("Training complete!")

start_training()
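
The commented-out scheduler step above still refers to an snr helper; for reference, the definition removed from the previous version of the script was:

def snr(y_true, y_pred):
    # Signal-to-noise ratio (in dB) between a reference signal and an estimate.
    noise = y_true - y_pred
    signal_power = torch.mean(y_true ** 2)
    noise_power = torch.mean(noise ** 2)
    snr_db = 10 * torch.log10(signal_power / noise_power)
    return snr_db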