15 Commits

8 changed files with 291 additions and 157 deletions

.gitignore (vendored, 1 line added)

@@ -166,3 +166,4 @@ dataset/
 old-output/
 output/
 *.wav
+models/

AudioUtils.py (new file, 18 lines)

@@ -0,0 +1,18 @@
+import torch
+import torch.nn.functional as F
+
+def stereo_tensor_to_mono(waveform):
+    if waveform.shape[0] > 1:
+        # Average across channels
+        mono_waveform = torch.mean(waveform, dim=0, keepdim=True)
+    else:
+        # Already mono
+        mono_waveform = waveform
+    return mono_waveform
+
+def stretch_tensor(tensor, target_length):
+    scale_factor = target_length / tensor.size(1)
+    tensor = F.interpolate(tensor, scale_factor=scale_factor, mode='linear', align_corners=False)
+    return tensor
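For orientation, a minimal sketch (not part of the commit) of how the new mono helper behaves; variable names below are illustrative. Note that stretch_tensor calls F.interpolate with mode='linear', which expects a 3D (batch, channels, length) input.

# Hypothetical smoke test for AudioUtils; names and shapes are illustrative only.
import torch
import AudioUtils

stereo = torch.randn(2, 44100)                   # (channels, samples)
mono = AudioUtils.stereo_tensor_to_mono(stereo)  # channels averaged
print(mono.shape)                                # torch.Size([1, 44100])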

data.py (63 lines changed)

@@ -1,50 +1,53 @@
 from torch.utils.data import Dataset
 import torch.nn.functional as F
+import torch
 import torchaudio
 import os
 import random
+import torchaudio.transforms as T
+import AudioUtils

 class AudioDataset(Dataset):
-    audio_sample_rates = [8000, 11025, 16000, 22050]
+    audio_sample_rates = [11025]
+    MAX_LENGTH = 44100 # Define your desired maximum length here

-    def __init__(self, input_dir, target_duration=None, padding_mode='constant', padding_value=0.0):
-        self.input_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.wav')]
-        self.target_duration = target_duration # Duration in seconds or None if not set
-        self.padding_mode = padding_mode
-        self.padding_value = padding_value
+    def __init__(self, input_dir, device):
+        self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')]
+        self.device = device

     def __len__(self):
         return len(self.input_files)

     def __getitem__(self, idx):
-        high_quality_wav, sr_original = torchaudio.load(self.input_files[idx], normalize=True)
-        sample_rate = random.choice(self.audio_sample_rates)
-        resample_transform = torchaudio.transforms.Resample(sr_original, sample_rate)
-        low_quality_wav = resample_transform(high_quality_wav)
-        low_quality_wav = low_quality_wav
-
-        # Calculate target length based on desired duration and 16000 Hz
-        if self.target_duration is not None:
-            target_length = int(self.target_duration * 44100)
-        else:
-            # Calculate duration of original high quality audio
-            target_length = high_quality_wav.size(1)
-
-        # Pad both to the calculated target length
-        high_quality_wav = self.stretch_tensor(high_quality_wav, target_length)
-        low_quality_wav = self.stretch_tensor(low_quality_wav, target_length)
-
-        return low_quality_wav, high_quality_wav
-
-    def stretch_tensor(self, tensor, target_length):
-        current_length = tensor.size(1)
-        scale_factor = target_length / current_length
-        # Resample the tensor using linear interpolation
-        tensor = F.interpolate(tensor.unsqueeze(0), scale_factor=scale_factor, mode='linear', align_corners=False).squeeze(0)
-        return tensor
+        # Load high-quality audio
+        high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True)
+
+        # Generate low-quality audio with random downsampling
+        mangled_sample_rate = random.choice(self.audio_sample_rates)
+        resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
+        low_quality_audio = resample_transform_low(high_quality_audio)
+        resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
+        low_quality_audio = resample_transform_high(low_quality_audio)
+
+        high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio)
+        low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio)
+
+        # Pad or truncate high-quality audio
+        if high_quality_audio.shape[1] < self.MAX_LENGTH:
+            padding = self.MAX_LENGTH - high_quality_audio.shape[1]
+            high_quality_audio = F.pad(high_quality_audio, (0, padding))
+        elif high_quality_audio.shape[1] > self.MAX_LENGTH:
+            high_quality_audio = high_quality_audio[:, :self.MAX_LENGTH]
+
+        # Pad or truncate low-quality audio
+        if low_quality_audio.shape[1] < self.MAX_LENGTH:
+            padding = self.MAX_LENGTH - low_quality_audio.shape[1]
+            low_quality_audio = F.pad(low_quality_audio, (0, padding))
+        elif low_quality_audio.shape[1] > self.MAX_LENGTH:
+            low_quality_audio = low_quality_audio[:, :self.MAX_LENGTH]
+
+        high_quality_audio = high_quality_audio.to(self.device)
+        low_quality_audio = low_quality_audio.to(self.device)
+
+        return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
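A minimal sketch of the new __getitem__ contract once batched (not part of the commit; it assumes ./dataset/good contains enough .wav files, and the variable names are illustrative):

# Each item is now ((high_quality, original_sr), (low_quality, mangled_sr)).
from torch.utils.data import DataLoader
from data import AudioDataset

dataset = AudioDataset("./dataset/good", device="cpu")
loader = DataLoader(dataset, batch_size=4, shuffle=True)

(hq, hq_rates), (lq, lq_rates) = next(iter(loader))
print(hq.shape, lq.shape)  # torch.Size([4, 1, 44100]) each: mono, padded/cut to MAX_LENGTH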

discriminator.py

@@ -1,24 +1,58 @@
+import torch
 import torch.nn as nn
+import torch.nn.utils as utils
+
+def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, spectral_norm=True):
+    padding = (kernel_size // 2) * dilation
+    conv_layer = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding)
+    if spectral_norm:
+        conv_layer = utils.spectral_norm(conv_layer)
+    return nn.Sequential(
+        conv_layer,
+        nn.LeakyReLU(0.2, inplace=True),
+        nn.BatchNorm1d(out_channels)
+    )
+
+class AttentionBlock(nn.Module):
+    def __init__(self, channels):
+        super(AttentionBlock, self).__init__()
+        self.attention = nn.Sequential(
+            nn.Conv1d(channels, channels // 4, kernel_size=1),
+            nn.ReLU(),
+            nn.Conv1d(channels // 4, channels, kernel_size=1),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        attention_weights = self.attention(x)
+        return x * attention_weights
+
 class SISUDiscriminator(nn.Module):
-    def __init__(self):
+    def __init__(self, layers=4): #Increased base layer count
         super(SISUDiscriminator, self).__init__()
         self.model = nn.Sequential(
-            nn.Conv1d(2, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 256, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(256, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 64, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(64, 1, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
+            discriminator_block(1, layers, kernel_size=3, stride=1), #Aggressive downsampling
+            discriminator_block(layers, layers * 2, kernel_size=5, stride=2),
+            discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=4),
+            #AttentionBlock(layers * 4), #Added attention
+            #discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=4),
+            #AttentionBlock(layers * 8), #Added attention
+            #discriminator_block(layers * 8, layers * 16, kernel_size=5, dilation=8),
+            #discriminator_block(layers * 16, layers * 16, kernel_size=3, dilation=1),
+            #discriminator_block(layers * 16, layers * 8, kernel_size=3, dilation=2),
+            #discriminator_block(layers * 8, layers * 4, kernel_size=3, dilation=1),
+            discriminator_block(layers * 4, layers * 2, kernel_size=5, stride=2),
+            discriminator_block(layers * 2, layers, kernel_size=3, stride=1),
+            discriminator_block(layers, 1, kernel_size=3, stride=1, spectral_norm=False) #last layer no spectral norm.
         )
-        self.global_avg_pool = nn.AdaptiveAvgPool1d(1) # Output size (1,)
+        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
+        self.sigmoid = nn.Sigmoid()

     def forward(self, x):
         x = self.model(x)
         x = self.global_avg_pool(x)
-        x = x.view(-1, 1) # Flatten to (batch_size, 1)
+        x = x.view(-1, 1)
+        x = self.sigmoid(x)
         return x
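A quick shape check in the spirit of the removed test.py, updated for the new single-channel input (a sketch, not part of the commit):

import torch
from discriminator import SISUDiscriminator

discriminator = SISUDiscriminator(layers=4)
test_input = torch.randn(1, 1, 44100)  # (batch, channels=1, samples) - mono now
output = discriminator(test_input)
print(output.shape)  # torch.Size([1, 1]), sigmoid output in (0, 1)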

generator.py

@@ -1,23 +1,52 @@
 import torch.nn as nn

-class SISUGenerator(nn.Module):
-    def __init__(self, upscale_scale=1): # No noise_dim parameter
-        super(SISUGenerator, self).__init__()
-        self.model = nn.Sequential(
-            nn.Conv1d(2, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 256, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-
-            nn.Upsample(scale_factor=upscale_scale, mode='nearest'),
-
-            nn.Conv1d(256, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 64, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(64, 2, kernel_size=3, padding=1),
-            nn.Tanh()
-        )
-
-    def forward(self, x):
-        return self.model(x)
+def conv_block(in_channels, out_channels, kernel_size=3, dilation=1):
+    return nn.Sequential(
+        nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, dilation=dilation, padding=(kernel_size // 2) * dilation),
+        nn.BatchNorm1d(out_channels),
+        nn.PReLU()
+    )
+
+class AttentionBlock(nn.Module):
+    def __init__(self, channels):
+        super(AttentionBlock, self).__init__()
+        self.attention = nn.Sequential(
+            nn.Conv1d(channels, channels // 4, kernel_size=1),
+            nn.ReLU(),
+            nn.Conv1d(channels // 4, channels, kernel_size=1),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        attention_weights = self.attention(x)
+        return x * attention_weights
+
+class ResidualInResidualBlock(nn.Module):
+    def __init__(self, channels, num_convs=3):
+        super(ResidualInResidualBlock, self).__init__()
+        self.conv_layers = nn.Sequential(*[conv_block(channels, channels) for _ in range(num_convs)])
+        self.attention = AttentionBlock(channels)
+
+    def forward(self, x):
+        residual = x
+        x = self.conv_layers(x)
+        x = self.attention(x)
+        return x + residual
+
+class SISUGenerator(nn.Module):
+    def __init__(self, layer=4, num_rirb=4): #increased base layer and rirb amounts
+        super(SISUGenerator, self).__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv1d(1, layer, kernel_size=7, padding=3),
+            nn.BatchNorm1d(layer),
+            nn.PReLU(),
+        )
+        self.rir_blocks = nn.Sequential(*[ResidualInResidualBlock(layer) for _ in range(num_rirb)])
+        self.final_layer = nn.Conv1d(layer, 1, kernel_size=3, padding=1)
+
+    def forward(self, x):
+        residual = x
+        x = self.conv1(x)
+        x = self.rir_blocks(x)
+        x = self.final_layer(x)
+        return x + residual
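The rewritten generator is now a residual mono-to-mono network, so output length matches input length. A sketch of a quick shape check (not part of the commit):

import torch
from generator import SISUGenerator

generator = SISUGenerator(layer=4, num_rirb=4)
noisy = torch.randn(1, 1, 44100)  # mono input, matching MAX_LENGTH in data.py
restored = generator(noisy)
print(restored.shape)             # torch.Size([1, 1, 44100]) - residual output, same length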

requirements.txt

@@ -1,12 +1,14 @@
-filelock>=3.16.1
-fsspec>=2024.10.0
-Jinja2>=3.1.4
-MarkupSafe>=2.1.5
-mpmath>=1.3.0
-networkx>=3.4.2
-numpy>=2.1.2
-pillow>=11.0.0
-setuptools>=70.2.0
-sympy>=1.13.1
-tqdm>=4.67.1
-typing_extensions>=4.12.2
+filelock==3.16.1
+fsspec==2024.10.0
+Jinja2==3.1.4
+MarkupSafe==2.1.5
+mpmath==1.3.0
+networkx==3.4.2
+numpy==2.2.3
+pytorch-triton-rocm==3.2.0+git4b3bb1f8
+setuptools==70.2.0
+sympy==1.13.3
+torch==2.7.0.dev20250226+rocm6.3
+torchaudio==2.6.0.dev20250226+rocm6.3
+tqdm==4.67.1
+typing_extensions==4.12.2
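The pins now target PyTorch nightly ROCm builds, which are not published on PyPI; installing them typically requires pointing pip at the PyTorch nightly index, roughly as below (the exact index URL is an assumption, not stated in the commit):

pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.3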

test.py (deleted, 10 lines)

@@ -1,10 +0,0 @@
-import torch.nn as nn
-import torch
-
-from discriminator import SISUDiscriminator
-
-discriminator = SISUDiscriminator()
-test_input = torch.randn(1, 2, 1000) # Example input (batch_size, channels, frames)
-output = discriminator(test_input)
-print(output)
-print("Output shape:", output.shape)

Training script

@@ -6,40 +6,120 @@ import torch.nn.functional as F
 import torchaudio
 import tqdm
+import argparse
+import math
+import os

 from torch.utils.data import random_split
 from torch.utils.data import DataLoader

+import AudioUtils
 from data import AudioDataset
 from generator import SISUGenerator
 from discriminator import SISUDiscriminator

-# Check for CUDA availability
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+import torchaudio.transforms as T
+
+# Init script argument parser
+parser = argparse.ArgumentParser(description="Training script")
+parser.add_argument("--generator", type=str, default=None,
+                    help="Path to the generator model file")
+parser.add_argument("--discriminator", type=str, default=None,
+                    help="Path to the discriminator model file")
+parser.add_argument("--device", type=str, default="cpu", help="Select device")
+parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
+parser.add_argument("--debug", action="store_true", help="Print debug logs")
+
+args = parser.parse_args()
+
+device = torch.device(args.device if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")

+mfcc_transform = T.MFCC(
+    sample_rate=44100,
+    n_mfcc=20,
+    melkwargs={'n_fft': 2048, 'hop_length': 256}
+).to(device)
+
+def gpu_mfcc_loss(y_true, y_pred):
+    mfccs_true = mfcc_transform(y_true)
+    mfccs_pred = mfcc_transform(y_pred)
+    min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
+    mfccs_true = mfccs_true[:, :, :min_len]
+    mfccs_pred = mfccs_pred[:, :, :min_len]
+    loss = torch.mean((mfccs_true - mfccs_pred)**2)
+    return loss
+
+def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
+    optimizer_d.zero_grad()
+
+    # Forward pass for real samples
+    discriminator_decision_from_real = discriminator(high_quality[0])
+    d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)
+
+    # Forward pass for fake samples (from generator output)
+    generator_output = generator(low_quality[0])
+    discriminator_decision_from_fake = discriminator(generator_output.detach())
+    d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)
+
+    # Combine real and fake losses
+    d_loss = (d_loss_real + d_loss_fake) / 2.0
+
+    # Backward pass and optimization
+    d_loss.backward()
+    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
+    optimizer_d.step()
+
+    return d_loss
+
+def generator_train(low_quality, high_quality, real_labels):
+    optimizer_g.zero_grad()
+
+    # Forward pass for fake samples (from generator output)
+    generator_output = generator(low_quality[0])
+
+    #mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)
+
+    discriminator_decision = discriminator(generator_output)
+    adversarial_loss = criterion_g(discriminator_decision, real_labels)
+
+    #combined_loss = adversarial_loss + 0.5 * mfcc_l
+
+    adversarial_loss.backward()
+    optimizer_g.step()
+
+    #return (generator_output, combined_loss, adversarial_loss, mfcc_l)
+    return (generator_output, adversarial_loss)
+
+debug = args.debug
+
 # Initialize dataset and dataloader
 dataset_dir = './dataset/good'
-dataset = AudioDataset(dataset_dir, target_duration=2.0)
-dataset_size = len(dataset)
-train_size = int(dataset_size * .9)
-val_size = int(dataset_size-train_size)
-train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
-train_data_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
-val_data_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)
+dataset = AudioDataset(dataset_dir, device)
+
+# ========= SINGLE =========
+
+train_data_loader = DataLoader(dataset, batch_size=256, shuffle=True)

 # Initialize models and move them to device
 generator = SISUGenerator()
 discriminator = SISUDiscriminator()

+epoch: int = args.epoch
+
 generator = generator.to(device)
 discriminator = discriminator.to(device)

+if args.generator is not None:
+    generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
+
+if args.discriminator is not None:
+    discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
+
 # Loss
-criterion_g = nn.L1Loss()
-criterion_d = nn.BCEWithLogitsLoss()
+criterion_g = nn.MSELoss()
+criterion_d = nn.BCELoss()

 # Optimizers
 optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
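One note on the loss swap: BCELoss on probabilities only matches the old BCEWithLogitsLoss on raw scores because the reworked discriminator now ends in a sigmoid. A tiny check of that equivalence (illustrative, not part of the commit):

import torch
import torch.nn as nn

logits = torch.randn(4, 1)
labels = torch.ones(4, 1)
with_logits = nn.BCEWithLogitsLoss()(logits, labels)
plain_bce = nn.BCELoss()(torch.sigmoid(logits), labels)
print(torch.allclose(with_logits, plain_bce))  # True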
@@ -49,87 +129,64 @@ optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.99
 scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
 scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)

-def snr(y_true, y_pred):
-    noise = y_true - y_pred
-    signal_power = torch.mean(y_true ** 2)
-    noise_power = torch.mean(noise ** 2)
-    snr_db = 10 * torch.log10(signal_power / noise_power)
-    return snr_db
-
-def discriminator_train(discriminator, optimizer, criterion, generator, real_labels, fake_labels, high_quality, low_quality):
-    optimizer.zero_grad()
-
-    discriminator_decision_from_real = discriminator(high_quality)
-    d_loss_real = criterion(discriminator_decision_from_real, real_labels)
-
-    generator_output = generator(low_quality)
-    discriminator_decision_from_fake = discriminator(generator_output.detach())
-    d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels)
-
-    d_loss = (d_loss_real + d_loss_fake) / 2.0
-
-    d_loss.backward()
-    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) #Gradient Clipping
-    optimizer.step()
-
-    # print(f"Discriminator Loss: {d_loss.item():.4f}, Mean Real Logit: {discriminator_decision_from_real.mean().item():.2f}, Mean Fake Logit: {discriminator_decision_from_fake.mean().item():.2f}")
+models_dir = "models"
+os.makedirs(models_dir, exist_ok=True)

 def start_training():
-    # Training loop
-    # discriminator_epochs = 1000
-    generator_epochs = 500
+    generator_epochs = 5000
     for generator_epoch in range(generator_epochs):
-        low_quality_audio = torch.empty((1))
-        high_quality_audio = torch.empty((1))
-        ai_enhanced_audio = torch.empty((1))
+        low_quality_audio = (torch.empty((1)), 1)
+        high_quality_audio = (torch.empty((1)), 1)
+        ai_enhanced_audio = (torch.empty((1)), 1)

-        # Training
-        for low_quality, high_quality in tqdm.tqdm(train_data_loader, desc=f"Epoch {generator_epoch+1}/{generator_epochs}"):
-            high_quality = high_quality.to(device)
-            low_quality = low_quality.to(device)
-
-            batch_size = high_quality.size(0)
+        times_correct = 0
+
+        # ========= TRAINING =========
+        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
+        # for high_quality_clip, low_quality_clip in train_data_loader:
+            high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
+            low_quality_sample = (low_quality_clip[0], low_quality_clip[1])
+
+            # ========= LABELS =========
+            batch_size = high_quality_clip[0].size(0)
             real_labels = torch.ones(batch_size, 1).to(device)
             fake_labels = torch.zeros(batch_size, 1).to(device)

-            # Train Discriminator
+            # ========= DISCRIMINATOR =========
             discriminator.train()
-            for _ in range(3):
-                discriminator_train(discriminator, optimizer_d, criterion_d, generator, real_labels, fake_labels, high_quality, low_quality)
-
-            # Train Generator
+            d_loss = discriminator_train(high_quality_sample, low_quality_sample, real_labels, fake_labels)
+
+            # ========= GENERATOR =========
             generator.train()
-            optimizer_g.zero_grad()
-
-            # Generator loss: how well fake data fools the discriminator
-            generator_output = generator(low_quality)
-            discriminator_decision = discriminator(generator_output) # No detach here
-            g_loss = criterion_g(discriminator_decision, real_labels) # Train generator to produce real-like outputs
-
-            g_loss.backward()
-            optimizer_g.step()
-
-            low_quality_audio = low_quality
-            high_quality_audio = high_quality
-            ai_enhanced_audio = generator_output
-
-        metric = snr(high_quality_audio, ai_enhanced_audio)
-        print(f"Generator metric {metric}!")
-        scheduler_g.step(metric)
+            #generator_output, combined_loss, adversarial_loss, mfcc_l = generator_train(low_quality_sample, high_quality_sample, real_labels)
+            generator_output, adversarial_loss = generator_train(low_quality_sample, high_quality_sample, real_labels)
+
+            if debug:
+                print(d_loss, adversarial_loss)
+            scheduler_d.step(d_loss)
+            scheduler_g.step(adversarial_loss)
+
+            # ========= SAVE LATEST AUDIO =========
+            high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
+            low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0])
+            ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0])
+
+        new_epoch = generator_epoch+epoch

         if generator_epoch % 10 == 0:
-            print(f"Saved epoch {generator_epoch}!")
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0].cpu(), 44100)
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu(), 44100)
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0].cpu(), 44100)
-
-        if generator_epoch % 50 == 0:
-            torch.save(discriminator.state_dict(), "discriminator.pt")
-            torch.save(generator.state_dict(), "generator.pt")
-
-    torch.save(discriminator.state_dict(), "discriminator.pt")
-    torch.save(generator.state_dict(), "generator.pt")
+            print(f"Saved epoch {new_epoch}!")
+            torchaudio.save(f"./output/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
+            torchaudio.save(f"./output/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu(), ai_enhanced_audio[1])
+            torchaudio.save(f"./output/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu(), high_quality_audio[1])
+
+            if debug:
+                print(generator.state_dict().keys())
+                print(discriminator.state_dict().keys())
+            torch.save(discriminator.state_dict(), f"{models_dir}/discriminator_epoch_{new_epoch}.pt")
+            torch.save(generator.state_dict(), f"{models_dir}/generator_epoch_{new_epoch}.pt")
+
+    torch.save(discriminator, "models/epoch-5000-discriminator.pt")
+    torch.save(generator, "models/epoch-5000-generator.pt")
     print("Training complete!")

 start_training()
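For completeness, a hypothetical invocation of the updated script using the new flags; the script's filename is not shown in this diff, so train.py below is an assumption, and the checkpoint paths are illustrative.

# Resume training on GPU from saved checkpoints; all flags come from the argparse block above.
python train.py --device cuda --epoch 50 \
    --generator models/generator_epoch_50.pt \
    --discriminator models/discriminator_epoch_50.pt \
    --debug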