27 Commits

Author SHA1 Message Date
ff38cefdd3 🐛 | Fix loading wrong model. 2025-06-08 18:14:31 +03:00
03fdc050cc | Made training a bit faster. 2025-06-07 20:43:52 +03:00
2ded03713d | Added app.py script so the model can be used. 2025-06-06 22:10:06 +03:00
a135c765da 🐛 | Misc fixes... 2025-05-05 00:50:56 +03:00
b1e18443ba | Added support for .mp3 and .flac loading... 2025-05-04 23:56:14 +03:00
660b41aef8 ⚗️ | Real-time testing... 2025-05-04 22:48:57 +03:00
d70c86c257 | Implemented MFCC and STFT. 2025-04-26 17:03:28 +03:00
c04b072de6 | Added smarter ways that would've been needed from the beginning. 2025-04-16 17:08:13 +03:00
b6d16e4f11 ♻️ | Restructured project code. 2025-04-14 17:51:34 +03:00
3936b6c160 🐛 | Fixed NVIDIA training... again. 2025-04-07 14:49:07 +03:00
fbcd5803b8 🐛 | Fixed training on CPU and NVIDIA hardware. 2025-04-07 02:14:06 +03:00
9394bc6c5a ⚗️ | Fat architecture. Hopefully better results. 2025-04-06 00:05:43 +03:00
f928d8c2cf ⚗️ | More tests. 2025-03-25 21:51:29 +02:00
54338e55a9 ⚗️ | Tests. 2025-03-25 19:50:51 +02:00
7e1c7e935a ⚗️ | Experimenting with other model layouts. 2025-03-15 18:01:19 +02:00
416500f7fc | Removed/Updated dependencies. 2025-02-26 20:15:30 +02:00
8332b0df2d | Added ability to set epoch. 2025-02-26 19:36:43 +02:00
741dcce7b4 ⚗️ | Increase discriminator size and implement mfcc_loss for generator. 2025-02-23 13:52:01 +02:00
fb7b624c87 ⚗️ | Experimenting with very small model. 2025-02-10 12:44:42 +02:00
0790a0d3da ⚗️ | Experimenting with smaller architecture. 2025-01-25 16:48:10 +02:00
f615b39ded ⚗️ | Experimenting with larger model architecture. 2025-01-08 15:33:18 +02:00
89f8c68986 ⚗️ | Experimenting, again. 2024-12-26 04:00:24 +02:00
2ff45de22d 🔥 | Removed unnecessary test file. 2024-12-25 00:10:45 +02:00
eca71ff5ea ⚗️ | Experimenting still... 2024-12-25 00:09:57 +02:00
1000692f32 ⚗️ | Experimenting with other generator architectures. 2024-12-21 23:54:11 +02:00
de72ee31ea 🔥 | Removed unnecessary models. 2024-12-21 23:28:34 +02:00
70e20f53d4 ⚗️ | Experiment with other layer layouts. 2024-12-21 23:27:38 +02:00
12 changed files with 616 additions and 166 deletions

1
.gitignore vendored

@@ -166,3 +166,4 @@ dataset/
 old-output/
 output/
 *.wav
+models/

71
AudioUtils.py Normal file

@@ -0,0 +1,71 @@
import torch
import torch.nn.functional as F

def stereo_tensor_to_mono(waveform):
    if waveform.shape[0] > 1:
        # Average across channels
        mono_waveform = torch.mean(waveform, dim=0, keepdim=True)
    else:
        # Already mono
        mono_waveform = waveform
    return mono_waveform

def stretch_tensor(tensor, target_length):
    scale_factor = target_length / tensor.size(1)
    tensor = F.interpolate(tensor, scale_factor=scale_factor, mode='linear', align_corners=False)
    return tensor

def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128):
    current_length = audio_tensor.shape[-1]
    if current_length < target_length:
        padding_needed = target_length - current_length
        padding_tuple = (0, padding_needed)
        padded_audio_tensor = F.pad(audio_tensor, padding_tuple, mode='constant', value=0)
    else:
        padded_audio_tensor = audio_tensor
    return padded_audio_tensor

def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 128) -> list[torch.Tensor]:
    if not isinstance(chunk_size, int) or chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer.")
    # Handle scalar tensor edge case if necessary
    if audio_tensor.dim() == 0:
        return [audio_tensor] if audio_tensor.numel() > 0 else []
    # Identify the dimension to split (usually the last one, representing time/samples)
    split_dim = -1
    num_samples = audio_tensor.shape[split_dim]
    if num_samples == 0:
        return []  # Return empty list if the dimension to split is empty
    # Use torch.split to divide the tensor into chunks.
    # It handles the last chunk being potentially smaller automatically.
    chunks = list(torch.split(audio_tensor, chunk_size, dim=split_dim))
    return chunks

def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
    if not chunks:
        return torch.empty(0)
    if len(chunks) == 1 and chunks[0].dim() == 0:
        return chunks[0]
    concat_dim = -1
    try:
        reconstructed_tensor = torch.cat(chunks, dim=concat_dim)
    except RuntimeError as e:
        raise RuntimeError(
            f"Failed to concatenate audio chunks. Ensure chunks have compatible shapes "
            f"for concatenation along dimension {concat_dim}. Original error: {e}"
        )
    return reconstructed_tensor
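For orientation, here is a small, hypothetical round trip showing how these helpers compose; the 3000-sample waveform and the chunk size of 1024 are made up for illustration and are not part of the commit.

import torch
import AudioUtils

waveform = torch.randn(1, 3000)                       # mono waveform: (channels=1, samples)
chunks = AudioUtils.split_audio(waveform, 1024)       # three chunks: 1024, 1024, 952 samples
chunks[-1] = AudioUtils.pad_tensor(chunks[-1], 1024)  # zero-pad the trailing chunk to 1024
restored = AudioUtils.reconstruct_audio(chunks)       # concatenate along the time axis
print(restored.shape)                                 # torch.Size([1, 3072]), longer than the input because of the padding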

README.md

@@ -18,6 +18,7 @@ SISU (Super Ingenious Sound Upscaler) is a project that uses GANs (Generative Ad
 1. **Set Up**:
    - Make sure you have Python installed (version 3.8 or higher).
    - Install needed packages: `pip install -r requirements.txt`
+   - Install the current version of PyTorch (CUDA/ROCm/whatever your device supports).
 2. **Prepare Audio Data**:
    - Put your audio files in the `dataset/good` folder.
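Since the right PyTorch wheel depends on your backend, a quick generic check (not part of the repository) to confirm the installed build actually sees your accelerator; ROCm builds also report through torch.cuda:

import torch

print(torch.__version__)
print("Accelerator available:", torch.cuda.is_available())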

60
app.py Normal file

@@ -0,0 +1,60 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import tqdm

import argparse
import math
import os

import AudioUtils
from generator import SISUGenerator

# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--model", type=str, help="Model to use for upscaling")
parser.add_argument("--clip_length", type=int, default=1024, help="Internal clip length, leave unspecified if unsure")
parser.add_argument("-i", "--input", type=str, help="Input audio file")
parser.add_argument("-o", "--output", type=str, help="Output audio file")
args = parser.parse_args()

device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

generator = SISUGenerator()

models_dir = args.model
clip_length = args.clip_length
input_audio = args.input
output_audio = args.output

if models_dir:
    generator.load_state_dict(torch.load(f"{models_dir}", map_location=device, weights_only=True))
else:
    print("Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)")

generator = generator.to(device)

def start():
    # To Mono!
    audio, original_sample_rate = torchaudio.load(input_audio, normalize=True)
    audio = AudioUtils.stereo_tensor_to_mono(audio)

    splitted_audio = AudioUtils.split_audio(audio, clip_length)
    splitted_audio_on_device = [t.to(device) for t in splitted_audio]
    processed_audio = []

    for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."):
        processed_audio.append(generator(clip))

    reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)
    print(f"Saving {output_audio}!")
    torchaudio.save(output_audio, reconstructed_audio.cpu().detach(), original_sample_rate)

start()
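The script above is the whole inference path: load, fold to mono, split into fixed-length clips, run each clip through the generator, stitch the result back together. A minimal programmatic sketch of the same flow, assuming a hypothetical trained checkpoint model.pt and an input.wav, with the forward passes wrapped in torch.no_grad() since no gradients are needed at inference time:

import torch
import torchaudio

import AudioUtils
from generator import SISUGenerator

# Hypothetical paths; the chunk size mirrors the script's --clip_length default.
model = SISUGenerator()
model.load_state_dict(torch.load("model.pt", map_location="cpu", weights_only=True))
model.eval()

audio, sample_rate = torchaudio.load("input.wav", normalize=True)
audio = AudioUtils.stereo_tensor_to_mono(audio)

with torch.no_grad():
    processed = [model(clip) for clip in AudioUtils.split_audio(audio, 1024)]

torchaudio.save("output.wav", AudioUtils.reconstruct_audio(processed).cpu(), sample_rate)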

72
data.py

@@ -1,50 +1,46 @@
 from torch.utils.data import Dataset
 import torch.nn.functional as F
+import torch
 import torchaudio
 import os
 import random
+import torchaudio.transforms as T
+import tqdm
+
+import AudioUtils

 class AudioDataset(Dataset):
-    audio_sample_rates = [8000, 11025, 16000, 22050]
+    audio_sample_rates = [11025]

-    def __init__(self, input_dir, target_duration=None, padding_mode='constant', padding_value=0.0):
-        self.input_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.wav')]
-        self.target_duration = target_duration  # Duration in seconds or None if not set
-        self.padding_mode = padding_mode
-        self.padding_value = padding_value
+    def __init__(self, input_dir, device, clip_length=1024):
+        self.device = device
+        input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav') or f.endswith('.mp3') or f.endswith('.flac')]
+
+        data = []
+        for audio_clip in tqdm.tqdm(input_files, desc=f"Processing {len(input_files)} audio file(s)"):
+            audio, original_sample_rate = torchaudio.load(audio_clip, normalize=True)
+            audio = AudioUtils.stereo_tensor_to_mono(audio)
+
+            # Generate low-quality audio with random downsampling
+            mangled_sample_rate = random.choice(self.audio_sample_rates)
+            resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
+            resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
+            low_audio = resample_transform_low(audio)
+            low_audio = resample_transform_high(low_audio)
+
+            splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length)
+            splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(splitted_high_quality_audio[-1], clip_length)
+            splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length)
+            splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(splitted_low_quality_audio[-1], clip_length)
+
+            for high_quality_sample, low_quality_sample in zip(splitted_high_quality_audio, splitted_low_quality_audio):
+                data.append(((high_quality_sample, low_quality_sample), (original_sample_rate, mangled_sample_rate)))
+
+        self.audio_data = data

     def __len__(self):
-        return len(self.input_files)
+        return len(self.audio_data)

     def __getitem__(self, idx):
-        high_quality_wav, sr_original = torchaudio.load(self.input_files[idx], normalize=True)
-
-        sample_rate = random.choice(self.audio_sample_rates)
-        resample_transform = torchaudio.transforms.Resample(sr_original, sample_rate)
-        low_quality_wav = resample_transform(high_quality_wav)
-        low_quality_wav = low_quality_wav
-
-        # Calculate target length based on desired duration and 16000 Hz
-        if self.target_duration is not None:
-            target_length = int(self.target_duration * 44100)
-        else:
-            # Calculate duration of original high quality audio
-            target_length = high_quality_wav.size(1)
-
-        # Pad both to the calculated target length
-        high_quality_wav = self.stretch_tensor(high_quality_wav, target_length)
-        low_quality_wav = self.stretch_tensor(low_quality_wav, target_length)
-
-        return low_quality_wav, high_quality_wav
-
-    def stretch_tensor(self, tensor, target_length):
-        current_length = tensor.size(1)
-        scale_factor = target_length / current_length
-
-        # Resample the tensor using linear interpolation
-        tensor = F.interpolate(tensor.unsqueeze(0), scale_factor=scale_factor, mode='linear', align_corners=False).squeeze(0)
-        return tensor
+        return self.audio_data[idx]
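The reworked dataset precomputes fixed-length clip pairs up front, so __getitem__ is just an index into memory and each item is the nested tuple ((high_quality_clip, low_quality_clip), (original_sample_rate, mangled_sample_rate)). A hypothetical look at one batch, assuming a populated ./dataset/good directory; the batch size of 4 is arbitrary:

from torch.utils.data import DataLoader
from data import AudioDataset

dataset = AudioDataset("./dataset/good", device="cpu", clip_length=1024)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

(high_quality, low_quality), (original_sr, mangled_sr) = next(iter(loader))
print(high_quality.shape)       # torch.Size([4, 1, 1024]): clean clips
print(low_quality.shape)        # torch.Size([4, 1, 1024]): downsampled-then-restored clips
print(original_sr, mangled_sr)  # per-clip sample rates, collated into tensors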

discriminator.py

@@ -1,24 +1,63 @@
+import torch
 import torch.nn as nn
+import torch.nn.utils as utils
+
+def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, spectral_norm=True, use_instance_norm=True):
+    padding = (kernel_size // 2) * dilation
+    conv_layer = nn.Conv1d(
+        in_channels,
+        out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        dilation=dilation,
+        padding=padding
+    )
+    if spectral_norm:
+        conv_layer = utils.spectral_norm(conv_layer)
+    layers = [conv_layer]
+    layers.append(nn.LeakyReLU(0.2, inplace=True))
+    if use_instance_norm:
+        layers.append(nn.InstanceNorm1d(out_channels))
+    return nn.Sequential(*layers)
+
+class AttentionBlock(nn.Module):
+    def __init__(self, channels):
+        super(AttentionBlock, self).__init__()
+        self.attention = nn.Sequential(
+            nn.Conv1d(channels, channels // 4, kernel_size=1),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(channels // 4, channels, kernel_size=1),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        attention_weights = self.attention(x)
+        return x * attention_weights
+
 class SISUDiscriminator(nn.Module):
-    def __init__(self):
+    def __init__(self, base_channels=16):
         super(SISUDiscriminator, self).__init__()
+        layers = base_channels
         self.model = nn.Sequential(
-            nn.Conv1d(2, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 256, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(256, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 64, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(64, 1, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
+            discriminator_block(1, layers, kernel_size=7, stride=1, spectral_norm=True, use_instance_norm=False),
+            discriminator_block(layers, layers * 2, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True),
+            discriminator_block(layers * 2, layers * 4, kernel_size=5, stride=1, dilation=2, spectral_norm=True, use_instance_norm=True),
+            AttentionBlock(layers * 4),
+            discriminator_block(layers * 4, layers * 8, kernel_size=5, stride=1, dilation=4, spectral_norm=True, use_instance_norm=True),
+            discriminator_block(layers * 8, layers * 4, kernel_size=5, stride=2, spectral_norm=True, use_instance_norm=True),
+            discriminator_block(layers * 4, layers * 2, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True),
+            discriminator_block(layers * 2, layers, kernel_size=3, stride=1, spectral_norm=True, use_instance_norm=True),
+            discriminator_block(layers, 1, kernel_size=3, stride=1, spectral_norm=False, use_instance_norm=False)
         )
-        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)  # Output size (1,)
+        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)

     def forward(self, x):
         x = self.model(x)
         x = self.global_avg_pool(x)
-        x = x.view(-1, 1)  # Flatten to (batch_size, 1)
+        x = x.view(x.size(0), -1)
         return x
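Note that the rework also changes the expected input: the first block now takes a single (mono) channel instead of two, and the output is one raw logit per clip (BCEWithLogitsLoss applies the sigmoid during training). A hypothetical shape check in the spirit of the removed test.py, with made-up batch and clip sizes:

import torch
from discriminator import SISUDiscriminator

discriminator = SISUDiscriminator(base_channels=16)
clips = torch.randn(4, 1, 1024)   # (batch, channels=1, samples)
logits = discriminator(clips)
print(logits.shape)               # torch.Size([4, 1]): one unbounded score per clip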

30
file_utils.py Normal file

@@ -0,0 +1,30 @@
import json

filepath = "my_data.json"

def write_data(filepath, data, debug=False):
    try:
        with open(filepath, 'w') as f:
            json.dump(data, f, indent=4)  # Use indent for pretty formatting
        if debug:
            print(f"Data written to '{filepath}'")
    except Exception as e:
        print(f"Error writing to file: {e}")

def read_data(filepath, debug=False):
    try:
        with open(filepath, 'r') as f:
            data = json.load(f)
        if debug:
            print(f"Data read from '{filepath}'")
        return data
    except FileNotFoundError:
        print(f"File not found: {filepath}")
        return None
    except json.JSONDecodeError:
        print(f"Error decoding JSON from file: {filepath}")
        return None
    except Exception as e:
        print(f"Error reading from file: {e}")
        return None
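In this changeset the helpers are used by the training script to persist the epoch counter between runs (models/epoch_data.json). A minimal usage sketch with a made-up filename and value:

import file_utils as Data

Data.write_data("epoch_data.json", {"epoch": 42})
state = Data.read_data("epoch_data.json")
if state is not None:
    print(state["epoch"] + 1)   # 43: the epoch a resumed run would start from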

generator.py

@@ -1,23 +1,74 @@
+import torch
 import torch.nn as nn
-class SISUGenerator(nn.Module):
-    def __init__(self, upscale_scale=1):  # No noise_dim parameter
-        super(SISUGenerator, self).__init__()
-        self.model = nn.Sequential(
-            nn.Conv1d(2, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 256, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Upsample(scale_factor=upscale_scale, mode='nearest'),
-            nn.Conv1d(256, 128, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 64, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(64, 2, kernel_size=3, padding=1),
-            nn.Tanh()
+
+def conv_block(in_channels, out_channels, kernel_size=3, dilation=1):
+    return nn.Sequential(
+        nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            dilation=dilation,
+            padding=(kernel_size // 2) * dilation
+        ),
+        nn.InstanceNorm1d(out_channels),
+        nn.PReLU()
+    )
+
+class AttentionBlock(nn.Module):
+    """
+    Simple Channel Attention Block. Learns to weight channels based on their importance.
+    """
+    def __init__(self, channels):
+        super(AttentionBlock, self).__init__()
+        self.attention = nn.Sequential(
+            nn.Conv1d(channels, channels // 4, kernel_size=1),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(channels // 4, channels, kernel_size=1),
+            nn.Sigmoid()
         )

     def forward(self, x):
-        return self.model(x)
+        attention_weights = self.attention(x)
+        return x * attention_weights
+
+class ResidualInResidualBlock(nn.Module):
+    def __init__(self, channels, num_convs=3):
+        super(ResidualInResidualBlock, self).__init__()
+        self.conv_layers = nn.Sequential(
+            *[conv_block(channels, channels) for _ in range(num_convs)]
+        )
+        self.attention = AttentionBlock(channels)
+
+    def forward(self, x):
+        residual = x
+        x = self.conv_layers(x)
+        x = self.attention(x)
+        return x + residual
+
+class SISUGenerator(nn.Module):
+    def __init__(self, channels=16, num_rirb=4, alpha=1.0):
+        super(SISUGenerator, self).__init__()
+        self.alpha = alpha
+
+        self.conv1 = nn.Sequential(
+            nn.Conv1d(1, channels, kernel_size=7, padding=3),
+            nn.InstanceNorm1d(channels),
+            nn.PReLU(),
+        )
+
+        self.rir_blocks = nn.Sequential(
+            *[ResidualInResidualBlock(channels) for _ in range(num_rirb)]
+        )
+
+        self.final_layer = nn.Conv1d(channels, 1, kernel_size=3, padding=1)
+
+    def forward(self, x):
+        residual_input = x
+        x = self.conv1(x)
+        x_rirb_out = self.rir_blocks(x)
+        learned_residual = self.final_layer(x_rirb_out)
+        output = residual_input + self.alpha * learned_residual
+        return output
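The generator is now a residual refiner: conv1 lifts the mono signal into a channel space, the residual-in-residual blocks process it, and final_layer predicts a correction that is scaled by alpha and added back onto the input, so alpha = 0 reduces the network to an identity mapping. A small illustrative check with made-up shapes:

import torch
from generator import SISUGenerator

clips = torch.randn(2, 1, 1024)   # (batch, channels=1, samples)

identity = SISUGenerator(channels=16, num_rirb=4, alpha=0.0)
refiner = SISUGenerator(channels=16, num_rirb=4, alpha=1.0)

print(torch.equal(identity(clips), clips))   # True: a zero-weighted residual changes nothing
print(refiner(clips).shape)                  # torch.Size([2, 1, 1024]): length and channels are preserved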

requirements.txt

@@ -1,12 +1,12 @@
-filelock>=3.16.1
-fsspec>=2024.10.0
-Jinja2>=3.1.4
-MarkupSafe>=2.1.5
-mpmath>=1.3.0
-networkx>=3.4.2
-numpy>=2.1.2
-pillow>=11.0.0
-setuptools>=70.2.0
-sympy>=1.13.1
-tqdm>=4.67.1
-typing_extensions>=4.12.2
+filelock==3.16.1
+fsspec==2024.10.0
+Jinja2==3.1.4
+MarkupSafe==2.1.5
+mpmath==1.3.0
+networkx==3.4.2
+numpy==2.2.3
+pillow==11.0.0
+setuptools==70.2.0
+sympy==1.13.3
+tqdm==4.67.1
+typing_extensions==4.12.2

10
test.py

@@ -1,10 +0,0 @@
import torch.nn as nn
import torch

from discriminator import SISUDiscriminator

discriminator = SISUDiscriminator()
test_input = torch.randn(1, 2, 1000)  # Example input (batch_size, channels, frames)
output = discriminator(test_input)
print(output)
print("Output shape:", output.shape)

training.py

@@ -6,39 +6,114 @@ import torch.nn.functional as F
 import torchaudio
 import tqdm

+import argparse
+import math
+import os
+
 from torch.utils.data import random_split
 from torch.utils.data import DataLoader

+import AudioUtils
 from data import AudioDataset
 from generator import SISUGenerator
 from discriminator import SISUDiscriminator

-# Check for CUDA availability
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from training_utils import discriminator_train, generator_train
+import file_utils as Data
+
+import torchaudio.transforms as T
+
+# Init script argument parser
+parser = argparse.ArgumentParser(description="Training script")
+parser.add_argument("--generator", type=str, default=None,
+                    help="Path to the generator model file")
+parser.add_argument("--discriminator", type=str, default=None,
+                    help="Path to the discriminator model file")
+parser.add_argument("--device", type=str, default="cpu", help="Select device")
+parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
+parser.add_argument("--debug", action="store_true", help="Print debug logs")
+parser.add_argument("--continue_training", action="store_true", help="Continue training using temp_generator and temp_discriminator models")
+
+args = parser.parse_args()
+
+device = torch.device(args.device if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")

+# Parameters
+sample_rate = 44100
+n_fft = 1024
+win_length = n_fft
+hop_length = n_fft // 4
+n_mels = 40
+n_mfcc = 13
+
+mfcc_transform = T.MFCC(
+    sample_rate=sample_rate,
+    n_mfcc=n_mfcc,
+    melkwargs={
+        'n_fft': n_fft,
+        'hop_length': hop_length,
+        'win_length': win_length,
+        'n_mels': n_mels,
+        'power': 1.0,
+    }
+).to(device)
+
+mel_transform = T.MelSpectrogram(
+    sample_rate=sample_rate,
+    n_fft=n_fft,
+    hop_length=hop_length,
+    win_length=win_length,
+    n_mels=n_mels,
+    power=1.0  # Magnitude Mel
+).to(device)
+
+stft_transform = T.Spectrogram(
+    n_fft=n_fft,
+    win_length=win_length,
+    hop_length=hop_length
+).to(device)
+
+debug = args.debug
+
 # Initialize dataset and dataloader
 dataset_dir = './dataset/good'
-dataset = AudioDataset(dataset_dir, target_duration=2.0)
+dataset = AudioDataset(dataset_dir, device)
+models_dir = "./models"
+os.makedirs(models_dir, exist_ok=True)
+audio_output_dir = "./output"
+os.makedirs(audio_output_dir, exist_ok=True)

-dataset_size = len(dataset)
-train_size = int(dataset_size * .9)
-val_size = int(dataset_size-train_size)
+# ========= SINGLE =========

-train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
-train_data_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
-val_data_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)
+train_data_loader = DataLoader(dataset, batch_size=2048, shuffle=True, num_workers=24)

-# Initialize models and move them to device
+# ========= MODELS =========
+
 generator = SISUGenerator()
 discriminator = SISUDiscriminator()

+epoch: int = args.epoch
+
+if args.continue_training:
+    if args.generator is not None:
+        generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
+    elif args.discriminator is not None:
+        discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
+    else:
+        generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
+        discriminator.load_state_dict(torch.load(f"{models_dir}/temp_discriminator.pt", map_location=device, weights_only=True))
+        epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
+        epoch = epoch_from_file["epoch"] + 1
+
 generator = generator.to(device)
 discriminator = discriminator.to(device)

 # Loss
-criterion_g = nn.L1Loss()
+criterion_g = nn.BCEWithLogitsLoss()
 criterion_d = nn.BCEWithLogitsLoss()

 # Optimizers
@@ -49,87 +124,81 @@ optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.99
 scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
 scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)

-def snr(y_true, y_pred):
-    noise = y_true - y_pred
-    signal_power = torch.mean(y_true ** 2)
-    noise_power = torch.mean(noise ** 2)
-    snr_db = 10 * torch.log10(signal_power / noise_power)
-    return snr_db
-
-def discriminator_train(discriminator, optimizer, criterion, generator, real_labels, fake_labels, high_quality, low_quality):
-    optimizer.zero_grad()
-    discriminator_decision_from_real = discriminator(high_quality)
-    d_loss_real = criterion(discriminator_decision_from_real, real_labels)
-    generator_output = generator(low_quality)
-    discriminator_decision_from_fake = discriminator(generator_output.detach())
-    d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels)
-    d_loss = (d_loss_real + d_loss_fake) / 2.0
-    d_loss.backward()
-    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)  # Gradient Clipping
-    optimizer.step()
-    # print(f"Discriminator Loss: {d_loss.item():.4f}, Mean Real Logit: {discriminator_decision_from_real.mean().item():.2f}, Mean Fake Logit: {discriminator_decision_from_fake.mean().item():.2f}")
-
 def start_training():
+    generator_epochs = 5000

-    # Training loop
-    # discriminator_epochs = 1000
-    generator_epochs = 500
     for generator_epoch in range(generator_epochs):
-        low_quality_audio = torch.empty((1))
-        high_quality_audio = torch.empty((1))
-        ai_enhanced_audio = torch.empty((1))
+        high_quality_audio = ([torch.empty((1))], 1)
+        low_quality_audio = ([torch.empty((1))], 1)
+        ai_enhanced_audio = ([torch.empty((1))], 1)

-        # Training
-        for low_quality, high_quality in tqdm.tqdm(train_data_loader, desc=f"Epoch {generator_epoch+1}/{generator_epochs}"):
-            high_quality = high_quality.to(device)
-            low_quality = low_quality.to(device)
+        times_correct = 0

-            batch_size = high_quality.size(0)
+        # ========= TRAINING =========
+        for training_data in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
+            ## Data structure:
+            # [[[float..., float..., float...], [float..., float..., float...]], [original_sample_rate, mangled_sample_rate]]
+
+            # ========= LABELS =========
+            good_quality_data = training_data[0][0].to(device)
+            bad_quality_data = training_data[0][1].to(device)
+            original_sample_rate = training_data[1][0]
+            mangled_sample_rate = training_data[1][1]
+
+            batch_size = good_quality_data.size(0)
             real_labels = torch.ones(batch_size, 1).to(device)
             fake_labels = torch.zeros(batch_size, 1).to(device)

-            # Train Discriminator
+            high_quality_audio = (good_quality_data, original_sample_rate)
+            low_quality_audio = (bad_quality_data, mangled_sample_rate)
+
+            # ========= DISCRIMINATOR =========
             discriminator.train()
+            d_loss = discriminator_train(
+                good_quality_data,
+                bad_quality_data,
+                real_labels,
+                fake_labels,
+                discriminator,
+                generator,
+                criterion_d,
+                optimizer_d
+            )

-            for _ in range(3):
-                discriminator_train(discriminator, optimizer_d, criterion_d, generator, real_labels, fake_labels, high_quality, low_quality)
-
-            # Train Generator
+            # ========= GENERATOR =========
             generator.train()
-            optimizer_g.zero_grad()
+            generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor = generator_train(
+                bad_quality_data,
+                good_quality_data,
+                real_labels,
+                generator,
+                discriminator,
+                criterion_d,
+                optimizer_g,
+                device,
+                mel_transform,
+                stft_transform,
+                mfcc_transform
+            )

-            # Generator loss: how well fake data fools the discriminator
-            generator_output = generator(low_quality)
-            discriminator_decision = discriminator(generator_output)  # No detach here
-            g_loss = criterion_g(discriminator_decision, real_labels)  # Train generator to produce real-like outputs
-            g_loss.backward()
-            optimizer_g.step()
+            if debug:
+                print(f"D_LOSS: {d_loss.item():.4f}, COMBINED_LOSS: {combined_loss.item():.4f}, ADVERSARIAL_LOSS: {adversarial_loss.item():.4f}, MEL_L1_LOSS: {mel_l1_tensor.item():.4f}, LOG_STFT_L1_LOSS: {log_stft_l1_tensor.item():.4f}, MFCC_LOSS: {mfcc_l_tensor.item():.4f}")
+            scheduler_d.step(d_loss.detach())
+            scheduler_g.step(adversarial_loss.detach())

-            low_quality_audio = low_quality
-            high_quality_audio = high_quality
-            ai_enhanced_audio = generator_output
+            # ========= SAVE LATEST AUDIO =========
+            high_quality_audio = (good_quality_data, original_sample_rate)
+            low_quality_audio = (bad_quality_data, original_sample_rate)
+            ai_enhanced_audio = (generator_output, original_sample_rate)

-            metric = snr(high_quality_audio, ai_enhanced_audio)
-            print(f"Generator metric {metric}!")
-            scheduler_g.step(metric)
+        torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
+        torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")

-        if generator_epoch % 10 == 0:
-            print(f"Saved epoch {generator_epoch}!")
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0].cpu(), 44100)
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu(), 44100)
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0].cpu(), 44100)
+        new_epoch = generator_epoch+epoch
+        Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})

-        if generator_epoch % 50 == 0:
-            torch.save(discriminator.state_dict(), "discriminator.pt")
-            torch.save(generator.state_dict(), "generator.pt")
+    torch.save(discriminator, "models/epoch-5000-discriminator.pt")
+    torch.save(generator, "models/epoch-5000-generator.pt")

-    torch.save(discriminator.state_dict(), "discriminator.pt")
-    torch.save(generator.state_dict(), "generator.pt")
     print("Training complete!")

 start_training()

142
training_utils.py Normal file

@@ -0,0 +1,142 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import torchaudio.transforms as T

def gpu_mfcc_loss(mfcc_transform, y_true, y_pred):
    mfccs_true = mfcc_transform(y_true)
    mfccs_pred = mfcc_transform(y_pred)
    min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
    mfccs_true = mfccs_true[:, :, :min_len]
    mfccs_pred = mfccs_pred[:, :, :min_len]
    loss = torch.mean((mfccs_true - mfccs_pred)**2)
    return loss

def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    mel_spec_true = mel_transform(y_true)
    mel_spec_pred = mel_transform(y_pred)
    min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
    mel_spec_true = mel_spec_true[..., :min_len]
    mel_spec_pred = mel_spec_pred[..., :min_len]
    loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred))
    return loss

def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    mel_spec_true = mel_transform(y_true)
    mel_spec_pred = mel_transform(y_pred)
    min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
    mel_spec_true = mel_spec_true[..., :min_len]
    mel_spec_pred = mel_spec_pred[..., :min_len]
    loss = torch.mean((mel_spec_true - mel_spec_pred)**2)
    return loss

def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    stft_mag_true = stft_transform(y_true)
    stft_mag_pred = stft_transform(y_pred)
    min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
    stft_mag_true = stft_mag_true[..., :min_len]
    stft_mag_pred = stft_mag_pred[..., :min_len]
    loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps)))
    return loss

def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    stft_mag_true = stft_transform(y_true)
    stft_mag_pred = stft_transform(y_pred)
    min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
    stft_mag_true = stft_mag_true[..., :min_len]
    stft_mag_pred = stft_mag_pred[..., :min_len]
    norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1))
    norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1))
    loss = torch.mean(norm_diff / (norm_true + eps))
    return loss

def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer):
    optimizer.zero_grad()

    # Forward pass for real samples
    discriminator_decision_from_real = discriminator(high_quality)
    d_loss_real = criterion(discriminator_decision_from_real, real_labels)

    with torch.no_grad():
        generator_output = generator(low_quality)
    discriminator_decision_from_fake = discriminator(generator_output)
    d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake))

    d_loss = (d_loss_real + d_loss_fake) / 2.0

    d_loss.backward()
    # Optional: Gradient Clipping (can be helpful)
    # nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)  # Gradient Clipping
    optimizer.step()

    return d_loss

def generator_train(
    low_quality,
    high_quality,
    real_labels,
    generator,
    discriminator,
    adv_criterion,
    g_optimizer,
    device,
    mel_transform: T.MelSpectrogram,
    stft_transform: T.Spectrogram,
    mfcc_transform: T.MFCC,
    lambda_adv: float = 1.0,
    lambda_mel_l1: float = 10.0,
    lambda_log_stft: float = 1.0,
    lambda_mfcc: float = 1.0
):
    g_optimizer.zero_grad()

    generator_output = generator(low_quality)
    discriminator_decision = discriminator(generator_output)
    adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision))

    mel_l1 = 0.0
    log_stft_l1 = 0.0
    mfcc_l = 0.0

    # Calculate Mel L1 Loss if weight is positive
    if lambda_mel_l1 > 0:
        mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality, generator_output)

    # Calculate Log STFT L1 Loss if weight is positive
    if lambda_log_stft > 0:
        log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality, generator_output)

    # Calculate MFCC Loss if weight is positive
    if lambda_mfcc > 0:
        mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality, generator_output)

    mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1
    log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1
    mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l

    combined_loss = (lambda_adv * adversarial_loss) + \
                    (lambda_mel_l1 * mel_l1_tensor) + \
                    (lambda_log_stft * log_stft_l1_tensor) + \
                    (lambda_mfcc * mfcc_l_tensor)

    combined_loss.backward()
    # Optional: Gradient Clipping
    # nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
    g_optimizer.step()

    # 6. Return values for logging
    return generator_output, combined_loss, adversarial_loss, mel_l1_tensor, log_stft_l1_tensor, mfcc_l_tensor
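generator_train combines the adversarial term with three spectral reconstruction terms using the lambda_* weights (by default combined = 1.0*adversarial + 10.0*mel_l1 + 1.0*log_stft + 1.0*mfcc). As a standalone sketch (not from the repository), the same losses can be evaluated on dummy audio; the transform parameters mirror those set up in the training script, and the adversarial value is just a placeholder:

import torch
import torchaudio.transforms as T

from training_utils import gpu_mfcc_loss, log_stft_magnitude_loss, mel_spectrogram_l1_loss

sample_rate, n_fft = 44100, 1024
mel_transform = T.MelSpectrogram(sample_rate=sample_rate, n_fft=n_fft, hop_length=n_fft // 4,
                                 win_length=n_fft, n_mels=40, power=1.0)
stft_transform = T.Spectrogram(n_fft=n_fft, win_length=n_fft, hop_length=n_fft // 4)
mfcc_transform = T.MFCC(sample_rate=sample_rate, n_mfcc=13,
                        melkwargs={'n_fft': n_fft, 'hop_length': n_fft // 4,
                                   'win_length': n_fft, 'n_mels': 40, 'power': 1.0})

clean = torch.randn(2, 1, 4096)                     # pretend high-quality clips
degraded = clean + 0.05 * torch.randn_like(clean)   # pretend generator output

adversarial_loss = torch.tensor(0.7)                # placeholder for the BCE-with-logits term
combined_loss = (1.0 * adversarial_loss
                 + 10.0 * mel_spectrogram_l1_loss(mel_transform, clean, degraded)
                 + 1.0 * log_stft_magnitude_loss(stft_transform, clean, degraded)
                 + 1.0 * gpu_mfcc_loss(mfcc_transform, clean, degraded))
print(combined_loss.item())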