10 Commits

SHA1 Message Date
1717e7a008 ⚗️ | Experimenting... 2025-02-10 19:35:50 +02:00
fb7b624c87 ⚗️ | Experimenting with very small model. 2025-02-10 12:44:42 +02:00
0790a0d3da ⚗️ | Experimenting with smaller architecture. 2025-01-25 16:48:10 +02:00
f615b39ded ⚗️ | Experimenting with larger model architecture. 2025-01-08 15:33:18 +02:00
89f8c68986 ⚗️ | Experimenting, again. 2024-12-26 04:00:24 +02:00
2ff45de22d 🔥 | Removed unnecessary test file. 2024-12-25 00:10:45 +02:00
eca71ff5ea ⚗️ | Experimenting still... 2024-12-25 00:09:57 +02:00
1000692f32 ⚗️ | Experimenting with other generator architectures. 2024-12-21 23:54:11 +02:00
de72ee31ea 🔥 | Removed unnecessary models. 2024-12-21 23:28:34 +02:00
70e20f53d4 ⚗️ | Experiment with other layer layouts. 2024-12-21 23:27:38 +02:00
8 changed files with 226 additions and 152 deletions

.gitignore vendored (1 addition)

@@ -166,3 +166,4 @@ dataset/
 old-output/
 output/
 *.wav
+models/

AudioUtils.py (new file, 18 additions)

@@ -0,0 +1,18 @@
import torch
import torch.nn.functional as F

def stereo_tensor_to_mono(waveform):
    if waveform.shape[0] > 1:
        # Average across channels
        mono_waveform = torch.mean(waveform, dim=0, keepdim=True)
    else:
        # Already mono
        mono_waveform = waveform
    return mono_waveform

def stretch_tensor(tensor, target_length):
    scale_factor = target_length / tensor.size(1)
    # F.interpolate's 'linear' mode expects a 3D (batch, channels, length) input,
    # so the (channels, length) tensor is wrapped and unwrapped around the call.
    tensor = F.interpolate(tensor.unsqueeze(0), scale_factor=scale_factor, mode='linear', align_corners=False).squeeze(0)
    return tensor
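
A quick usage sketch of the two helpers (the shapes assume torchaudio's (channels, samples) layout; the values here are illustrative, not part of the commit):

import torch
from AudioUtils import stereo_tensor_to_mono, stretch_tensor

stereo = torch.randn(2, 16000)           # fake 1-second stereo clip at 16 kHz
mono = stereo_tensor_to_mono(stereo)     # channel average -> torch.Size([1, 16000])
stretched = stretch_tensor(mono, 32000)  # linear interpolation -> torch.Size([1, 32000])
print(mono.shape, stretched.shape)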

data.py (62 lines changed)

@@ -1,50 +1,52 @@
 from torch.utils.data import Dataset
 import torch.nn.functional as F
+import torch
 import torchaudio
 import os
 import random
+from AudioUtils import stereo_tensor_to_mono, stretch_tensor
 
 class AudioDataset(Dataset):
-    audio_sample_rates = [8000, 11025, 16000, 22050]
+    audio_sample_rates = [11025]
 
-    def __init__(self, input_dir, target_duration=None, padding_mode='constant', padding_value=0.0):
-        self.input_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.wav')]
-        self.target_duration = target_duration  # Duration in seconds or None if not set
-        self.padding_mode = padding_mode
-        self.padding_value = padding_value
+    def __init__(self, input_dir):
+        self.input_files = [
+            os.path.join(root, f)
+            for root, _, files in os.walk(input_dir)
+            for f in files if f.endswith('.wav')
+        ]
 
     def __len__(self):
         return len(self.input_files)
 
     def __getitem__(self, idx):
-        high_quality_wav, sr_original = torchaudio.load(self.input_files[idx], normalize=True)
-
-        sample_rate = random.choice(self.audio_sample_rates)
-        resample_transform = torchaudio.transforms.Resample(sr_original, sample_rate)
-        low_quality_wav = resample_transform(high_quality_wav)
-        low_quality_wav = low_quality_wav
-
-        # Calculate target length based on desired duration and 16000 Hz
-        if self.target_duration is not None:
-            target_length = int(self.target_duration * 44100)
-        else:
-            # Calculate duration of original high quality audio
-            target_length = high_quality_wav.size(1)
-
-        # Pad both to the calculated target length
-        high_quality_wav = self.stretch_tensor(high_quality_wav, target_length)
-        low_quality_wav = self.stretch_tensor(low_quality_wav, target_length)
-
-        return low_quality_wav, high_quality_wav
-
-    def stretch_tensor(self, tensor, target_length):
+        # Load high-quality audio
+        high_quality_path = self.input_files[idx]
+        high_quality_audio, original_sample_rate = torchaudio.load(high_quality_path)
+        high_quality_audio = stereo_tensor_to_mono(high_quality_audio)
+
+        # Generate low-quality audio with random downsampling
+        mangled_sample_rate = random.choice(self.audio_sample_rates)
+        resample_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
+        low_quality_audio = resample_low(high_quality_audio)
+        resample_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
+        low_quality_audio = resample_high(low_quality_audio)
+
+        # Pad or truncate to match a fixed length
+        target_length = 44100  # Adjust this based on your data
+        high_quality_audio = self.pad_or_truncate(high_quality_audio, target_length)
+        low_quality_audio = self.pad_or_truncate(low_quality_audio, target_length)
+
+        return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
+
+    def pad_or_truncate(self, tensor, target_length):
         current_length = tensor.size(1)
-        scale_factor = target_length / current_length
-
-        # Resample the tensor using linear interpolation
-        tensor = F.interpolate(tensor.unsqueeze(0), scale_factor=scale_factor, mode='linear', align_corners=False).squeeze(0)
+        if current_length < target_length:
+            # Pad with zeros
+            padding = target_length - current_length
+            tensor = F.pad(tensor, (0, padding))
+        else:
+            # Truncate to target length
+            tensor = tensor[:, :target_length]
         return tensor
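
For readers following along, a minimal sketch of the new __getitem__ contract as seen downstream (sizes are illustrative; PyTorch's default collate stacks the waveforms and gathers the per-item sample rates into a tensor, which is why the training script indexes clip[0] and clip[1]):

from torch.utils.data import DataLoader
from data import AudioDataset

dataset = AudioDataset('./dataset/good')  # walks the tree for .wav files
(hq, hq_rate), (lq, lq_rate) = dataset[0]
print(hq.shape, hq_rate, lq.shape, lq_rate)  # e.g. torch.Size([1, 44100]) 44100 torch.Size([1, 44100]) 11025

loader = DataLoader(dataset, batch_size=16, shuffle=True)
(hq_batch, hq_rates), (lq_batch, lq_rates) = next(iter(loader))
print(hq_batch.shape)  # torch.Size([16, 1, 44100]); hq_rates is a tensor of 16 ints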

discriminator.py

@@ -1,24 +1,38 @@
+import torch
 import torch.nn as nn
+import torch.nn.utils as utils
+
+def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
+    padding = (kernel_size // 2) * dilation
+    return nn.Sequential(
+        utils.spectral_norm(
+            nn.Conv1d(in_channels, out_channels,
+                      kernel_size=kernel_size,
+                      stride=stride,
+                      dilation=dilation,
+                      padding=padding
+            )
+        ),
+        nn.BatchNorm1d(out_channels),
+        nn.LeakyReLU(0.2, inplace=True)
+    )
 
 class SISUDiscriminator(nn.Module):
     def __init__(self):
         super(SISUDiscriminator, self).__init__()
+        layers = 4
         self.model = nn.Sequential(
-            nn.Conv1d(2, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 256, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(256, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 64, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(64, 1, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
+            discriminator_block(1, layers, kernel_size=7, stride=2, dilation=1),
+            discriminator_block(layers, layers * 2, kernel_size=5, stride=2, dilation=1),
+            discriminator_block(layers * 2, layers * 4, kernel_size=3, dilation=4),
+            discriminator_block(layers * 4, layers * 4, kernel_size=5, dilation=8),
+            discriminator_block(layers * 4, layers * 2, kernel_size=3, dilation=16),
+            discriminator_block(layers * 2, layers, kernel_size=5, dilation=2),
+            discriminator_block(layers, 1, kernel_size=3, stride=1)
         )
-        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)  # Output size (1,)
+        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
 
     def forward(self, x):
         x = self.model(x)
         x = self.global_avg_pool(x)
-        x = x.view(-1, 1)  # Flatten to (batch_size, 1)
-        return x
+        return x.view(-1, 1)
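
Since test.py was removed in this range and the input layer changed from 2 channels to 1, a minimal replacement sanity check might look like this (batch size and clip length are arbitrary):

import torch
from discriminator import SISUDiscriminator

discriminator = SISUDiscriminator()
test_input = torch.randn(4, 1, 44100)  # (batch, channels, samples), mono now
output = discriminator(test_input)
print(output.shape)  # torch.Size([4, 1]): one unbounded score per clip, no sigmoid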

generator.py

@@ -1,23 +1,41 @@
 import torch.nn as nn
 
+def conv_residual_block(in_channels, out_channels, kernel_size=3, dilation=1):
+    padding = (kernel_size // 2) * dilation
+    return nn.Sequential(
+        nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=padding),
+        nn.BatchNorm1d(out_channels),
+        nn.PReLU(),
+        nn.Conv1d(out_channels, out_channels, kernel_size, dilation=dilation, padding=padding),
+        nn.BatchNorm1d(out_channels)
+    )
+
 class SISUGenerator(nn.Module):
-    def __init__(self, upscale_scale=1):  # No noise_dim parameter
+    def __init__(self):
         super(SISUGenerator, self).__init__()
-        self.model = nn.Sequential(
-            nn.Conv1d(2, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 256, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-
-            nn.Upsample(scale_factor=upscale_scale, mode='nearest'),
-
-            nn.Conv1d(256, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 64, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(64, 2, kernel_size=3, padding=1),
-            nn.Tanh()
+        layers = 4
+        self.conv1 = nn.Sequential(
+            nn.Conv1d(1, layers, kernel_size=7, padding=3),
+            nn.BatchNorm1d(layers),
+            nn.PReLU()
+        )
+        self.conv_blocks = nn.Sequential(
+            conv_residual_block(layers, layers, kernel_size=3, dilation=1),
+            conv_residual_block(layers, layers * 2, kernel_size=5, dilation=2),
+            conv_residual_block(layers * 2, layers * 4, kernel_size=3, dilation=16),
+            conv_residual_block(layers * 4, layers * 2, kernel_size=5, dilation=8),
+            conv_residual_block(layers * 2, layers, kernel_size=5, dilation=2),
+            conv_residual_block(layers, layers, kernel_size=3, dilation=1)
+        )
+        self.final_layer = nn.Sequential(
+            nn.Conv1d(layers, 1, kernel_size=3, padding=1)
         )
 
     def forward(self, x):
-        return self.model(x)
+        residual = x
+        x = self.conv1(x)
+        x = self.conv_blocks(x) + x  # Adding residual connection after blocks
+        x = self.final_layer(x)
+        return x + residual
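
The rewritten generator is fully convolutional with "same" padding throughout (padding = (kernel_size // 2) * dilation, stride 1), so output shape matches input shape and the global residual is a plain elementwise add. A quick sketch, with illustrative sizes:

import torch
from generator import SISUGenerator

generator = SISUGenerator()
noisy = torch.randn(4, 1, 44100)  # mono clips
enhanced = generator(noisy)
print(enhanced.shape)  # torch.Size([4, 1, 44100]), same shape as the input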

requirements.txt

@@ -1,12 +1,14 @@
-filelock>=3.16.1
-fsspec>=2024.10.0
-Jinja2>=3.1.4
-MarkupSafe>=2.1.5
-mpmath>=1.3.0
-networkx>=3.4.2
-numpy>=2.1.2
-pillow>=11.0.0
-setuptools>=70.2.0
-sympy>=1.13.1
-tqdm>=4.67.1
-typing_extensions>=4.12.2
+filelock==3.16.1
+fsspec==2024.10.0
+Jinja2==3.1.4
+MarkupSafe==2.1.5
+mpmath==1.3.0
+networkx==3.4.2
+numpy==2.2.1
+pytorch-triton-rocm==3.2.0+git0d4682f0
+setuptools==70.2.0
+sympy==1.13.1
+torch==2.6.0.dev20241222+rocm6.2.4
+torchaudio==2.6.0.dev20241222+rocm6.2.4
+tqdm==4.67.1
+typing_extensions==4.12.2

test.py (deleted, 10 deletions)

@@ -1,10 +0,0 @@
import torch.nn as nn
import torch
from discriminator import SISUDiscriminator

discriminator = SISUDiscriminator()
test_input = torch.randn(1, 2, 1000)  # Example input (batch_size, channels, frames)
output = discriminator(test_input)
print(output)
print("Output shape:", output.shape)

training script

@@ -6,40 +6,96 @@ import torch.nn.functional as F
 import torchaudio
 import tqdm
 
+import argparse
+import math
+
 from torch.utils.data import random_split
 from torch.utils.data import DataLoader
 
+import AudioUtils
 from data import AudioDataset
 from generator import SISUGenerator
 from discriminator import SISUDiscriminator
 
+def perceptual_loss(y_true, y_pred):
+    return torch.mean((y_true - y_pred) ** 2)
+
+def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
+    optimizer_d.zero_grad()
+
+    # Forward pass for real samples
+    discriminator_decision_from_real = discriminator(high_quality[0])
+    d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)
+
+    # Forward pass for fake samples (from generator output)
+    generator_output = generator(low_quality[0])
+    discriminator_decision_from_fake = discriminator(generator_output.detach())
+    d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)
+
+    # Combine real and fake losses
+    d_loss = (d_loss_real + d_loss_fake) / 2.0
+
+    # Backward pass and optimization
+    d_loss.backward()
+    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)  # Gradient clipping
+    optimizer_d.step()
+
+    return d_loss
+
+def generator_train(low_quality, real_labels):
+    optimizer_g.zero_grad()
+
+    # Forward pass for fake samples (from generator output)
+    generator_output = generator(low_quality[0])
+    discriminator_decision = discriminator(generator_output)
+    g_loss = criterion_g(discriminator_decision, real_labels)
+
+    g_loss.backward()
+    optimizer_g.step()
+
+    return generator_output
+
+def first(objects):
+    if len(objects) >= 1:
+        return objects[0]
+    return objects
+
+# Init script argument parser
+parser = argparse.ArgumentParser(description="Training script")
+parser.add_argument("--generator", type=str, default=None,
+                    help="Path to the generator model file")
+parser.add_argument("--discriminator", type=str, default=None,
+                    help="Path to the discriminator model file")
+
+args = parser.parse_args()
+
 # Check for CUDA availability
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
 
 # Initialize dataset and dataloader
 dataset_dir = './dataset/good'
-dataset = AudioDataset(dataset_dir, target_duration=2.0)
-
-dataset_size = len(dataset)
-train_size = int(dataset_size * .9)
-val_size = int(dataset_size - train_size)
-
-train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
-
-train_data_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
-val_data_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)
+dataset = AudioDataset(dataset_dir)
+
+# ========= SINGLE =========
+
+train_data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
 
 # Initialize models and move them to device
 generator = SISUGenerator()
 discriminator = SISUDiscriminator()
 
+if args.generator is not None:
+    generator.load_state_dict(torch.load(args.generator, weights_only=True))
+if args.discriminator is not None:
+    discriminator.load_state_dict(torch.load(args.discriminator, weights_only=True))
+
 generator = generator.to(device)
 discriminator = discriminator.to(device)
 
 # Loss
-criterion_g = nn.L1Loss()
-criterion_d = nn.BCEWithLogitsLoss()
+criterion_g = nn.MSELoss()
+criterion_d = nn.BCELoss()
 
 # Optimizers
 optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
@@ -49,87 +105,60 @@ optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
 scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
 scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)
 
-def snr(y_true, y_pred):
-    noise = y_true - y_pred
-    signal_power = torch.mean(y_true ** 2)
-    noise_power = torch.mean(noise ** 2)
-    snr_db = 10 * torch.log10(signal_power / noise_power)
-    return snr_db
-
-def discriminator_train(discriminator, optimizer, criterion, generator, real_labels, fake_labels, high_quality, low_quality):
-    optimizer.zero_grad()
-
-    discriminator_decision_from_real = discriminator(high_quality)
-    d_loss_real = criterion(discriminator_decision_from_real, real_labels)
-
-    generator_output = generator(low_quality)
-    discriminator_decision_from_fake = discriminator(generator_output.detach())
-    d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels)
-
-    d_loss = (d_loss_real + d_loss_fake) / 2.0
-
-    d_loss.backward()
-    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)  # Gradient clipping
-    optimizer.step()
-    # print(f"Discriminator Loss: {d_loss.item():.4f}, Mean Real Logit: {discriminator_decision_from_real.mean().item():.2f}, Mean Fake Logit: {discriminator_decision_from_fake.mean().item():.2f}")
-
 def start_training():
-    # Training loop
-    # discriminator_epochs = 1000
-    generator_epochs = 500
+    generator_epochs = 5000
     for generator_epoch in range(generator_epochs):
-        low_quality_audio = torch.empty((1))
-        high_quality_audio = torch.empty((1))
-        ai_enhanced_audio = torch.empty((1))
-
-        # Training
-        for low_quality, high_quality in tqdm.tqdm(train_data_loader, desc=f"Epoch {generator_epoch+1}/{generator_epochs}"):
-            high_quality = high_quality.to(device)
-            low_quality = low_quality.to(device)
-
-            batch_size = high_quality.size(0)
+        low_quality_audio = (torch.empty((1)), 1)
+        high_quality_audio = (torch.empty((1)), 1)
+        ai_enhanced_audio = (torch.empty((1)), 1)
+
+        times_correct = 0
+
+        # ========= TRAINING =========
+        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {generator_epoch+1}/{generator_epochs}"):
+        # for high_quality_clip, low_quality_clip in train_data_loader:
+            high_quality_sample = (high_quality_clip[0].to(device), high_quality_clip[1])
+            low_quality_sample = (low_quality_clip[0].to(device), low_quality_clip[1])
+
+            # ========= LABELS =========
+            batch_size = high_quality_clip[0].size(0)
             real_labels = torch.ones(batch_size, 1).to(device)
             fake_labels = torch.zeros(batch_size, 1).to(device)
 
-            # Train Discriminator
+            # ========= DISCRIMINATOR =========
             discriminator.train()
-            for _ in range(3):
-                discriminator_train(discriminator, optimizer_d, criterion_d, generator, real_labels, fake_labels, high_quality, low_quality)
-
-            # Train Generator
+            discriminator_train(high_quality_sample, low_quality_sample, real_labels, fake_labels)
+
+            # ========= GENERATOR =========
             generator.train()
-            optimizer_g.zero_grad()
-
-            # Generator loss: how well fake data fools the discriminator
-            generator_output = generator(low_quality)
-            discriminator_decision = discriminator(generator_output)  # No detach here
-            g_loss = criterion_g(discriminator_decision, real_labels)  # Train generator to produce real-like outputs
-
-            g_loss.backward()
-            optimizer_g.step()
-
-            low_quality_audio = low_quality
-            high_quality_audio = high_quality
-            ai_enhanced_audio = generator_output
-
-        metric = snr(high_quality_audio, ai_enhanced_audio)
-        print(f"Generator metric {metric}!")
-        scheduler_g.step(metric)
+            generator_output = generator_train(low_quality_sample, real_labels)
+
+            # ========= SAVE LATEST AUDIO =========
+            high_quality_audio = (first(high_quality_clip[0]), high_quality_clip[1][0])
+            low_quality_audio = (first(low_quality_clip[0]), low_quality_clip[1][0])
+            ai_enhanced_audio = (first(generator_output[0]), high_quality_clip[1][0])
+
+        print(high_quality_audio)
+        print(f"Saved epoch {generator_epoch}!")
+        torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1])  # <-- Because audio clip was resampled in data.py from original to crap and to original again.
+        torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
+        torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])
+
+        #metric = snr(high_quality_audio[0].to(device), ai_enhanced_audio[0])
+        #print(f"Generator metric {metric}!")
+        #scheduler_g.step(metric)
 
         if generator_epoch % 10 == 0:
             print(f"Saved epoch {generator_epoch}!")
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0].cpu(), 44100)
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu(), 44100)
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0].cpu(), 44100)
-
-        if generator_epoch % 50 == 0:
-            torch.save(discriminator.state_dict(), "discriminator.pt")
-            torch.save(generator.state_dict(), "generator.pt")
-
-    torch.save(discriminator.state_dict(), "discriminator.pt")
-    torch.save(generator.state_dict(), "generator.pt")
+            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1])
+            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
+            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])
+
+            torch.save(discriminator.state_dict(), "models/current-epoch-discriminator.pt")
+            torch.save(generator.state_dict(), "models/current-epoch-generator.pt")
+
+    torch.save(discriminator.state_dict(), "models/epoch-5000-discriminator.pt")
+    torch.save(generator.state_dict(), "models/epoch-5000-generator.pt")
     print("Training complete!")
 
 start_training()
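
One detail worth flagging in this range: the loss moved from nn.BCEWithLogitsLoss to nn.BCELoss, while the new discriminator ends in a plain convolution with no sigmoid. nn.BCELoss requires its inputs to lie in [0, 1], so feeding it raw scores will error at runtime. A minimal standalone sketch of the two numerically equivalent, safe pairings (not code from the repo):

import torch
import torch.nn as nn

scores = torch.randn(16, 1)  # raw, unbounded discriminator outputs (logits)
labels = torch.ones(16, 1)

# Pairing 1: raw logits with the logits-aware loss
loss_a = nn.BCEWithLogitsLoss()(scores, labels)

# Pairing 2: sigmoid-squashed probabilities with plain BCE
loss_b = nn.BCELoss()(torch.sigmoid(scores), labels)

print(loss_a.item(), loss_b.item())  # equal up to floating-point precision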