Compare commits

...

31 Commits

Author SHA1 Message Date
571c403b93 ⚗️ | Small fixes here and there 2025-12-11 23:06:38 +02:00
e3e555794e ⚗️ | Added MultiPeriodDiscriminator implementation from HiFi-GAN 2025-12-06 18:04:18 +02:00
bf0a6e58e9 ⚗️ | Added MultiPeriodDiscriminator implementation from HiFi-GAN 2025-12-04 14:22:48 +02:00
782a3bab28 ⚗️ | More architectural changes 2025-11-18 21:34:59 +02:00
3f23242d6f ⚗️ | Added some stupid ways for training + some makeup 2025-10-04 22:38:11 +03:00
0bc8fc2792 | Made training a bit... spicier. 2025-09-10 19:52:53 +03:00
ff38cefdd3 🐛 | Fix loading wrong model. 2025-06-08 18:14:31 +03:00
03fdc050cc | Made training a bit faster. 2025-06-07 20:43:52 +03:00
2ded03713d | Added app.py script so the model can be used. 2025-06-06 22:10:06 +03:00
a135c765da 🐛 | Misc fixes... 2025-05-05 00:50:56 +03:00
b1e18443ba | Added support for .mp3 and .flac loading... 2025-05-04 23:56:14 +03:00
660b41aef8 ⚗️ | Real-time testing... 2025-05-04 22:48:57 +03:00
d70c86c257 | Implemented MFCC and STFT. 2025-04-26 17:03:28 +03:00
c04b072de6 | Added smarter ways that would've been needed from the beginning. 2025-04-16 17:08:13 +03:00
b6d16e4f11 ♻️ | Restructured project code. 2025-04-14 17:51:34 +03:00
nsiltala 3936b6c160 🐛 | Fixed NVIDIA training... again. 2025-04-07 14:49:07 +03:00
fbcd5803b8 🐛 | Fixed training on CPU and NVIDIA hardware. 2025-04-07 02:14:06 +03:00
9394bc6c5a ⚗️ | Fat architecture. Hopefully better results. 2025-04-06 00:05:43 +03:00
f928d8c2cf ⚗️ | More tests. 2025-03-25 21:51:29 +02:00
54338e55a9 ⚗️ | Tests. 2025-03-25 19:50:51 +02:00
7e1c7e935a ⚗️ | Experimenting with other model layouts. 2025-03-15 18:01:19 +02:00
416500f7fc | Removed/Updated dependencies. 2025-02-26 20:15:30 +02:00
8332b0df2d | Added ability to set epoch. 2025-02-26 19:36:43 +02:00
741dcce7b4 ⚗️ | Increase discriminator size and implement mfcc_loss for generator. 2025-02-23 13:52:01 +02:00
fb7b624c87 ⚗️ | Experimenting with very small model. 2025-02-10 12:44:42 +02:00
0790a0d3da ⚗️ | Experimenting with smaller architecture. 2025-01-25 16:48:10 +02:00
f615b39ded ⚗️ | Experimenting with larger model architecture. 2025-01-08 15:33:18 +02:00
89f8c68986 ⚗️ | Experimenting, again. 2024-12-26 04:00:24 +02:00
2ff45de22d 🔥 | Removed unnecessary test file. 2024-12-25 00:10:45 +02:00
eca71ff5ea ⚗️ | Experimenting still... 2024-12-25 00:09:57 +02:00
1000692f32 ⚗️ | Experimenting with other generator architectures. 2024-12-21 23:54:11 +02:00
13 changed files with 888 additions and 240 deletions

AudioUtils.py (new file, +40 lines)

@@ -0,0 +1,40 @@
import torch
import torch.nn.functional as F


def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
    mono_tensor = torch.mean(waveform, dim=0, keepdim=True)
    return mono_tensor


def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 512) -> torch.Tensor:
    current = audio_tensor.size(-1)
    padding_amount = target_length - current
    if padding_amount <= 0:
        return audio_tensor
    return F.pad(audio_tensor, (0, padding_amount))


def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 512, pad_last_tensor: bool = False) -> list[torch.Tensor]:
    chunks = list(torch.split(audio_tensor, chunk_size, dim=1))
    if pad_last_tensor:
        last_chunk = chunks[-1]
        if last_chunk.size(-1) < chunk_size:
            chunks[-1] = pad_tensor(last_chunk, chunk_size)
    return chunks


def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
    reconstructed_tensor = torch.cat(chunks, dim=-1)
    return reconstructed_tensor


def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    max_val = torch.max(torch.abs(audio_tensor))
    if max_val < eps:
        return audio_tensor
    return audio_tensor / max_val
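
A quick round-trip through the new helpers (a sketch, not part of the change; shapes and chunk size are arbitrary):

import torch

import AudioUtils

wave = torch.randn(2, 1000)                        # stereo, 1000 samples
chunks = AudioUtils.split_audio(wave, chunk_size=512, pad_last_tensor=True)
assert all(c.size(-1) == 512 for c in chunks)      # last chunk is zero-padded to 512
restored = AudioUtils.reconstruct_audio(chunks)
assert restored.shape == (2, 1024)                 # the padding survives reconstruction
mono = AudioUtils.stereo_tensor_to_mono(wave)      # [1, 1000]
peak = AudioUtils.normalize(wave).abs().max()      # 1.0 (unless the input is near-silent)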

README.md (modified)

@@ -18,6 +18,7 @@ SISU (Super Ingenious Sound Upscaler) is a project that uses GANs (Generative Ad
 1. **Set Up**:
    - Make sure you have Python installed (version 3.8 or higher).
    - Install needed packages: `pip install -r requirements.txt`
+   - Install the current version of PyTorch (CUDA/ROCm/whatever your device supports)
 2. **Prepare Audio Data**:
    - Put your audio files in the `dataset/good` folder.
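
A quick way to check that the PyTorch install picked up your accelerator (a hypothetical snippet, not part of the repo):

import torch

print(torch.__version__)
print(torch.cuda.is_available())   # True on CUDA builds; ROCm builds also report True here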

__init__.py (new file, empty)

app.py (new file, +128 lines)

@@ -0,0 +1,128 @@
import argparse

import torch
import torchaudio
import torchcodec
import tqdm
from accelerate import Accelerator

import AudioUtils
from generator import SISUGenerator

# Init script argument parser
parser = argparse.ArgumentParser(description="Upscaling script")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--model", type=str, help="Model to use for upscaling")
parser.add_argument(
    "--clip_length",
    type=int,
    default=8000,
    help="Internal clip length, leave unspecified if unsure",
)
parser.add_argument(
    "--sample_rate", type=int, default=44100, help="Output clip sample rate"
)
parser.add_argument(
    "--bitrate",
    type=int,
    default=192000,
    help="Output clip bitrate",
)
parser.add_argument("-i", "--input", type=str, help="Input audio file")
parser.add_argument("-o", "--output", type=str, help="Output audio file")

args = parser.parse_args()

if args.sample_rate < 8000:
    print(
        "Sample rate cannot be lower than 8000! (44100 is recommended for base models)"
    )
    exit()

# ---------------------------
# Init accelerator
# ---------------------------
accelerator = Accelerator(mixed_precision="bf16")

# ---------------------------
# Models
# ---------------------------
generator = SISUGenerator()

accelerator.print("🔨 | Compiling models...")
generator = torch.compile(generator)
accelerator.print("✅ | Compiling done!")

# ---------------------------
# Prepare accelerator
# ---------------------------
generator = accelerator.prepare(generator)

# ---------------------------
# Checkpoint helpers
# ---------------------------
models_dir = args.model
clip_length = args.clip_length
input_audio = args.input
output_audio = args.output

if models_dir:
    ckpt = torch.load(models_dir)
    accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
    accelerator.print("💾 | Loaded model!")
else:
    print(
        "Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)"
    )


def start():
    # To Mono!
    decoder = torchcodec.decoders.AudioDecoder(input_audio)
    decoded_samples = decoder.get_all_samples()
    audio = decoded_samples.data
    original_sample_rate = decoded_samples.sample_rate

    # Support for multichannel audio
    # audio = AudioUtils.stereo_tensor_to_mono(audio)
    audio = AudioUtils.normalize(audio)

    resample_transform = torchaudio.transforms.Resample(
        original_sample_rate, args.sample_rate
    )
    audio = resample_transform(audio)

    splitted_audio = AudioUtils.split_audio(audio, clip_length)
    splitted_audio_on_device = [
        t.view(1, t.shape[0], t.shape[-1]).to(accelerator.device)
        for t in splitted_audio
    ]
    processed_audio = []

    with torch.no_grad():
        for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."):
            channels = []
            for audio_channel in torch.split(clip, 1, dim=1):
                output_piece = generator(audio_channel)
                channels.append(output_piece.detach().cpu())
            output_clip = torch.cat(channels, dim=1)
            processed_audio.append(output_clip)

    reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)
    reconstructed_audio = reconstructed_audio.squeeze(0)
    print(f"🔊 | Saving {output_audio}!")
    torchaudio.save_with_torchcodec(
        uri=output_audio,
        src=reconstructed_audio,
        sample_rate=args.sample_rate,
        channels_first=True,
        compression=args.bitrate,
    )


start()
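
Once a trained checkpoint exists, the script would be invoked along the lines of `python app.py --model models/last.pt -i input.mp3 -o output.wav` (file names hypothetical; `--model`, `-i`, and `-o` are the flags defined in the parser above).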

data.py (modified, 93 lines changed)

@@ -1,49 +1,72 @@
-from torch.utils.data import Dataset
-import torch.nn.functional as F
-import torchaudio
-import os
-import random
-
-class AudioDataset(Dataset):
-    audio_sample_rates = [8000, 11025, 16000, 22050]
-
-    def __init__(self, input_dir, target_duration=None, padding_mode='constant', padding_value=0.0):
-        self.input_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.wav')]
-        self.target_duration = target_duration  # Duration in seconds or None if not set
-        self.padding_mode = padding_mode
-        self.padding_value = padding_value
-
-    def __len__(self):
-        return len(self.input_files)
-
-    def __getitem__(self, idx):
-        high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True)
-        mangled_sample_rate = random.choice(self.audio_sample_rates)
-        resample_transform = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
-        low_quality_audio = resample_transform(high_quality_audio)
-
-        # Calculate target length based on desired duration and 16000 Hz
-        # if self.target_duration is not None:
-        #     target_length = int(self.target_duration * 44100)
-        # else:
-        #     # Calculate duration of original high quality audio
-        #     target_length = high_quality_wav.size(1)
-
-        # Pad both to the calculated target length
-        # high_quality_wav = self.stretch_tensor(high_quality_wav, target_length)
-        # low_quality_wav = self.stretch_tensor(low_quality_wav, target_length)
-
-        return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
-
-    def stretch_tensor(self, tensor, target_length):
-        current_length = tensor.size(1)
-        scale_factor = target_length / current_length
-
-        # Resample the tensor using linear interpolation
-        tensor = F.interpolate(tensor.unsqueeze(0), scale_factor=scale_factor, mode='linear', align_corners=False).squeeze(0)
-
-        return tensor
+import os
+import random
+
+import torch
+import torchaudio
+import torchcodec.decoders as decoders
+import tqdm
+from torch.utils.data import Dataset
+
+import AudioUtils
+
+
+class AudioDataset(Dataset):
+    audio_sample_rates = [8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100]
+
+    def __init__(self, input_dir, clip_length: int = 512, normalize: bool = True):
+        self.clip_length = clip_length
+        self.normalize = normalize
+
+        input_files = [
+            os.path.join(input_dir, f)
+            for f in os.listdir(input_dir)
+            if os.path.isfile(os.path.join(input_dir, f))
+            and f.lower().endswith((".wav", ".mp3", ".flac"))
+        ]
+
+        data = []
+
+        for audio_clip in tqdm.tqdm(
+            input_files, desc=f"Processing {len(input_files)} audio file(s)"
+        ):
+            decoder = decoders.AudioDecoder(audio_clip)
+            decoded_samples = decoder.get_all_samples()
+            audio = decoded_samples.data.float()
+            original_sample_rate = decoded_samples.sample_rate
+
+            if normalize:
+                audio = AudioUtils.normalize(audio)
+
+            splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length, True)
+            if not splitted_high_quality_audio:
+                continue
+
+            for splitted_audio_clip in splitted_high_quality_audio:
+                for audio_clip in torch.split(splitted_audio_clip, 1):
+                    data.append((audio_clip, original_sample_rate))
+
+        self.audio_data = data
+
+    def __len__(self):
+        return len(self.audio_data)
+
+    def __getitem__(self, idx):
+        audio_clip = self.audio_data[idx]
+        mangled_sample_rate = random.choice(self.audio_sample_rates)
+
+        resample_transform_low = torchaudio.transforms.Resample(
+            audio_clip[1], mangled_sample_rate
+        )
+
+        resample_transform_high = torchaudio.transforms.Resample(
+            mangled_sample_rate, audio_clip[1]
+        )
+
+        low_audio_clip = resample_transform_high(resample_transform_low(audio_clip[0]))
+
+        if audio_clip[0].shape[1] < low_audio_clip.shape[1]:
+            low_audio_clip = low_audio_clip[:, :audio_clip[0].shape[1]]
+        elif audio_clip[0].shape[1] > low_audio_clip.shape[1]:
+            target_len = audio_clip[0].shape[1]
+            low_audio_clip = AudioUtils.pad_tensor(low_audio_clip, target_len)
+
+        return ((audio_clip[0], low_audio_clip), (audio_clip[1], mangled_sample_rate))
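
The degradation in __getitem__ is just a down/up resample round trip through a randomly chosen intermediate rate. The same idea in isolation (a sketch; rates and shapes are picked arbitrarily):

import torch
import torchaudio

high_quality = torch.randn(1, 8192)                    # one mono clip, nominally 44.1 kHz
down = torchaudio.transforms.Resample(44100, 16000)    # throws away the top of the spectrum
up = torchaudio.transforms.Resample(16000, 44100)      # brings it back to the original rate
low_quality = up(down(high_quality))
# Resampling can drift by a few samples, hence the trim/pad against the target above.
print(high_quality.shape, low_quality.shape)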

discriminator.py (modified)

@@ -1,24 +1,179 @@
-import torch.nn as nn
-
-class SISUDiscriminator(nn.Module):
-    def __init__(self):
-        super(SISUDiscriminator, self).__init__()
-        self.model = nn.Sequential(
-            nn.Conv1d(2, 128, kernel_size=3, padding=1),
-            #nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 256, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(256, 128, kernel_size=3, padding=1),
-            #nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 64, kernel_size=3, padding=1),
-            #nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(64, 1, kernel_size=3, padding=1),
-            #nn.LeakyReLU(0.2, inplace=True),
-        )
-        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)  # Output size (1,)
-
-    def forward(self, x):
-        x = self.model(x)
-        x = self.global_avg_pool(x)
-        x = x.view(-1, 1)  # Flatten to (batch_size, 1)
-        return x
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.parametrizations import weight_norm, spectral_norm
+
+
+# -------------------------------------------------------------------
+# 1. Multi-Period Discriminator (MPD)
+#    Captures periodic structures (pitch/timbre) by folding audio.
+# -------------------------------------------------------------------
+class DiscriminatorP(nn.Module):
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super(DiscriminatorP, self).__init__()
+        self.period = period
+        self.use_spectral_norm = use_spectral_norm
+
+        # Use spectral_norm for stability, or weight_norm for performance
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+
+        # We use 2D convs because we "fold" the 1D audio into 2D (Period x Time)
+        self.convs = nn.ModuleList([
+            norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(2, 0))),
+            norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(2, 0))),
+            norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(2, 0))),
+            norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(2, 0))),
+            norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+        ])
+        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d conversion: [B, C, T] -> [B, C, T/P, P]
+        b, c, t = x.shape
+        if t % self.period != 0:  # Pad if not divisible by period
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, 0.1)
+            fmap.append(x)  # Store feature map for Feature Matching Loss
+
+        x = self.conv_post(x)
+        fmap.append(x)
+
+        # Flatten back to 1D for score
+        x = torch.flatten(x, 1, -1)
+        return x, fmap
+
+
+class MultiPeriodDiscriminator(nn.Module):
+    def __init__(self, periods=[2, 3, 5, 7, 11]):
+        super(MultiPeriodDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([
+            DiscriminatorP(p) for p in periods
+        ])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []   # Real scores
+        y_d_gs = []   # Generated (Fake) scores
+        fmap_rs = []  # Real feature maps
+        fmap_gs = []  # Generated (Fake) feature maps
+
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+# -------------------------------------------------------------------
+# 2. Multi-Scale Discriminator (MSD)
+#    Captures structure at different audio resolutions (raw, x0.5, x0.25).
+# -------------------------------------------------------------------
+class DiscriminatorS(nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super(DiscriminatorS, self).__init__()
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+
+        # Standard 1D Convolutions with large receptive field
+        self.convs = nn.ModuleList([
+            norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)),
+            norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+            norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+            norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+            norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+            norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
+        ])
+        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        fmap = []
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, 0.1)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+        return x, fmap
+
+
+class MultiScaleDiscriminator(nn.Module):
+    def __init__(self):
+        super(MultiScaleDiscriminator, self).__init__()
+        # 3 Scales: Original, Downsampled x2, Downsampled x4
+        self.discriminators = nn.ModuleList([
+            DiscriminatorS(use_spectral_norm=True),
+            DiscriminatorS(),
+            DiscriminatorS(),
+        ])
+        self.meanpools = nn.ModuleList([
+            nn.AvgPool1d(4, 2, padding=2),
+            nn.AvgPool1d(4, 2, padding=2)
+        ])
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for i, d in enumerate(self.discriminators):
+            if i != 0:
+                # Downsample input for subsequent discriminators
+                y = self.meanpools[i - 1](y)
+                y_hat = self.meanpools[i - 1](y_hat)
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+# -------------------------------------------------------------------
+# 3. Master Wrapper
+#    Combines MPD and MSD into one class to fit your training script.
+# -------------------------------------------------------------------
+class SISUDiscriminator(nn.Module):
+    def __init__(self):
+        super(SISUDiscriminator, self).__init__()
+        self.mpd = MultiPeriodDiscriminator()
+        self.msd = MultiScaleDiscriminator()
+
+    def forward(self, y, y_hat):
+        # Return format:
+        # scores_real, scores_fake, features_real, features_fake
+
+        # Run Multi-Period
+        mpd_y_d_rs, mpd_y_d_gs, mpd_fmap_rs, mpd_fmap_gs = self.mpd(y, y_hat)
+        # Run Multi-Scale
+        msd_y_d_rs, msd_y_d_gs, msd_fmap_rs, msd_fmap_gs = self.msd(y, y_hat)
+
+        # Combine all results
+        return (
+            mpd_y_d_rs + msd_y_d_rs,    # All real scores
+            mpd_y_d_gs + msd_y_d_gs,    # All fake scores
+            mpd_fmap_rs + msd_fmap_rs,  # All real feature maps
+            mpd_fmap_gs + msd_fmap_gs   # All fake feature maps
+        )
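
A minimal smoke test of the combined ensemble (a sketch, not part of the change; the 1×1×8192 shape is arbitrary mono audio):

import torch

from discriminator import SISUDiscriminator

d = SISUDiscriminator()
real = torch.randn(1, 1, 8192)   # [batch, channels, samples]; the new models are mono
fake = torch.randn(1, 1, 8192)
scores_real, scores_fake, fmaps_real, fmaps_fake = d(real, fake)
# 5 period discriminators (MPD) + 3 scale discriminators (MSD) = 8 entries per list
print(len(scores_real), len(scores_fake), len(fmaps_real), len(fmaps_fake))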

generator.py (modified)

@@ -1,27 +1,126 @@
-import torch.nn as nn
-
-class SISUGenerator(nn.Module):
-    def __init__(self, upscale_scale=1):  # No noise_dim parameter
-        super(SISUGenerator, self).__init__()
-        self.layers1 = nn.Sequential(
-            nn.Conv1d(2, 128, kernel_size=3, padding=1),
-            # nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 256, kernel_size=3, padding=1),
-            # nn.LeakyReLU(0.2, inplace=True),
-        )
-        self.layers2 = nn.Sequential(
-            nn.Conv1d(256, 128, kernel_size=3, padding=1),
-            # nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 64, kernel_size=3, padding=1),
-            # nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(64, 2, kernel_size=3, padding=1),
-            # nn.Tanh()
-        )
-
-    def forward(self, x, scale):
-        x = self.layers1(x)
-        upsample = nn.Upsample(scale_factor=scale, mode='nearest')
-        x = upsample(x)
-        x = self.layers2(x)
-        return x
+import torch
+import torch.nn as nn
+from torch.nn.utils.parametrizations import weight_norm
+
+
+def GeneratorBlock(in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
+    padding = (kernel_size - 1) // 2 * dilation
+    return nn.Sequential(
+        weight_norm(nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            dilation=dilation,
+            padding=padding
+        )),
+        nn.PReLU(num_parameters=1, init=0.1),
+    )
+
+
+class AttentionBlock(nn.Module):
+    def __init__(self, channels):
+        super(AttentionBlock, self).__init__()
+        self.attention = nn.Sequential(
+            weight_norm(nn.Conv1d(channels, channels // 4, kernel_size=1)),
+            nn.ReLU(inplace=True),
+            weight_norm(nn.Conv1d(channels // 4, channels, kernel_size=1)),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        attention_weights = self.attention(x)
+        return x + (x * attention_weights)
+
+
+class ResidualInResidualBlock(nn.Module):
+    def __init__(self, channels, num_convs=3):
+        super(ResidualInResidualBlock, self).__init__()
+        self.conv_layers = nn.Sequential(
+            *[GeneratorBlock(channels, channels) for _ in range(num_convs)]
+        )
+        self.attention = AttentionBlock(channels)
+
+    def forward(self, x):
+        residual = x
+        x = self.conv_layers(x)
+        x = self.attention(x)
+        return x + residual
+
+
+def UpsampleBlock(in_channels, out_channels, scale_factor=2):
+    return nn.Sequential(
+        nn.Upsample(scale_factor=scale_factor, mode='nearest'),
+        weight_norm(nn.Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1
+        )),
+        nn.PReLU(num_parameters=1, init=0.1)
+    )
+
+
+class SISUGenerator(nn.Module):
+    def __init__(self, channels=32, num_rirb=4):
+        super(SISUGenerator, self).__init__()
+
+        self.first_conv = GeneratorBlock(1, channels)
+
+        self.downsample = GeneratorBlock(channels, channels * 2, stride=2)
+        self.downsample_attn = AttentionBlock(channels * 2)
+        self.downsample_2 = GeneratorBlock(channels * 2, channels * 4, stride=2)
+        self.downsample_2_attn = AttentionBlock(channels * 4)
+
+        self.rirb = nn.Sequential(
+            *[ResidualInResidualBlock(channels * 4) for _ in range(num_rirb)]
+        )
+
+        self.upsample = UpsampleBlock(channels * 4, channels * 2)
+        self.upsample_attn = AttentionBlock(channels * 2)
+        self.compress_1 = GeneratorBlock(channels * 4, channels * 2)
+
+        self.upsample_2 = UpsampleBlock(channels * 2, channels)
+        self.upsample_2_attn = AttentionBlock(channels)
+        self.compress_2 = GeneratorBlock(channels * 2, channels)
+
+        self.final_conv = nn.Sequential(
+            weight_norm(nn.Conv1d(channels, 1, kernel_size=7, padding=3)),
+            nn.Tanh()
+        )
+
+    def forward(self, x):
+        residual_input = x
+
+        # Encoding
+        x1 = self.first_conv(x)
+        x2 = self.downsample(x1)
+        x2 = self.downsample_attn(x2)
+        x3 = self.downsample_2(x2)
+        x3 = self.downsample_2_attn(x3)
+
+        # Bottleneck (Deep Residual processing)
+        x_rirb = self.rirb(x3)
+
+        # Decoding with Skip Connections
+        up1 = self.upsample(x_rirb)
+        up1 = self.upsample_attn(up1)
+        cat1 = torch.cat((up1, x2), dim=1)
+        comp1 = self.compress_1(cat1)
+
+        up2 = self.upsample_2(comp1)
+        up2 = self.upsample_2_attn(up2)
+        cat2 = torch.cat((up2, x1), dim=1)
+        comp2 = self.compress_2(cat2)
+
+        learned_residual = self.final_conv(comp2)
+        output = residual_input + learned_residual
+
+        return output
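
The rewritten generator is length-preserving and residual: two stride-2 encoder stages are mirrored by two 2x upsamples, and the Tanh output is added back onto the input. A quick shape check (a sketch; the input length should be a multiple of 4 so the skip connections line up, and 8192 matches the training clip length):

import torch

from generator import SISUGenerator

g = SISUGenerator()
x = torch.randn(1, 1, 8192)      # mono clip
with torch.no_grad():
    y = g(x)
print(y.shape)                   # torch.Size([1, 1, 8192]), same as the input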

requirements.txt (modified)

@@ -1,12 +0,0 @@
-filelock>=3.16.1
-fsspec>=2024.10.0
-Jinja2>=3.1.4
-MarkupSafe>=2.1.5
-mpmath>=1.3.0
-networkx>=3.4.2
-numpy>=2.1.2
-pillow>=11.0.0
-setuptools>=70.2.0
-sympy>=1.13.1
-tqdm>=4.67.1
-typing_extensions>=4.12.2

test.py (deleted, -10 lines)

@@ -1,10 +0,0 @@
-import torch.nn as nn
-import torch
-
-from discriminator import SISUDiscriminator
-
-discriminator = SISUDiscriminator()
-test_input = torch.randn(1, 2, 1000)  # Example input (batch_size, channels, frames)
-output = discriminator(test_input)
-print(output)
-print("Output shape:", output.shape)

training.py (modified)

@@ -1,183 +1,251 @@
-import torch
-import torch.nn as nn
-import torch.optim as optim
-import torch.nn.functional as F
-import torchaudio
-import tqdm
-
-from torch.utils.data import random_split
-from torch.utils.data import DataLoader
-
-from data import AudioDataset
-from generator import SISUGenerator
-from discriminator import SISUDiscriminator
-
-# Mel Spectrogram Loss
-class MelSpectrogramLoss(nn.Module):
-    def __init__(self, sample_rate=44100, n_fft=2048, hop_length=512, n_mels=128):
-        super(MelSpectrogramLoss, self).__init__()
-        self.mel_transform = torchaudio.transforms.MelSpectrogram(
-            sample_rate=sample_rate,
-            n_fft=n_fft,
-            hop_length=hop_length,
-            n_mels=n_mels
-        ).to(device)  # Move to device
-
-    def forward(self, y_pred, y_true):
-        mel_pred = self.mel_transform(y_pred)
-        mel_true = self.mel_transform(y_true)
-        return F.l1_loss(mel_pred, mel_true)
-
-def snr(y_true, y_pred):
-    noise = y_true - y_pred
-    signal_power = torch.mean(y_true ** 2)
-    noise_power = torch.mean(noise ** 2)
-    snr_db = 10 * torch.log10(signal_power / noise_power)
-    return snr_db
-
-def discriminator_train(high_quality, low_quality, scale, real_labels, fake_labels):
-    optimizer_d.zero_grad()
-
-    discriminator_decision_from_real = discriminator(high_quality)
-    # TODO: Experiment with criterions HERE!
-    d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)
-
-    generator_output = generator(low_quality, scale)
-    discriminator_decision_from_fake = discriminator(generator_output.detach())
-    # TODO: Experiment with criterions HERE!
-    d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)
-
-    d_loss = (d_loss_real + d_loss_fake) / 2.0
-
-    d_loss.backward()
-    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)  # Gradient Clipping
-    optimizer_d.step()
-
-    return d_loss
-
-def generator_train(low_quality, scale, real_labels):
-    optimizer_g.zero_grad()
-
-    generator_output = generator(low_quality, scale)
-    discriminator_decision = discriminator(generator_output)
-    # TODO: Fix this shit
-    g_loss = criterion_g(discriminator_decision, real_labels)
-
-    g_loss.backward()
-    optimizer_g.step()
-
-    return generator_output
-
-# Check for CUDA availability
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print(f"Using device: {device}")
-
-# Initialize dataset and dataloader
-dataset_dir = './dataset/good'
-dataset = AudioDataset(dataset_dir, target_duration=2.0)
-
-dataset_size = len(dataset)
-train_size = int(dataset_size * .9)
-val_size = int(dataset_size - train_size)
-
-train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
-
-train_data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
-val_data_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)
-
-# Initialize models and move them to device
-generator = SISUGenerator()
-discriminator = SISUDiscriminator()
-
-generator = generator.to(device)
-discriminator = discriminator.to(device)
-
-# Loss
-criterion_g = nn.L1Loss()
-criterion_g_mel = MelSpectrogramLoss().to(device)
-criterion_d = nn.BCEWithLogitsLoss()
-
-# Optimizers
-optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
-optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
-
-# Scheduler
-scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
-scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)
-
-def start_training():
-
-    # Training loop
-
-    # ========= DISCRIMINATOR PRE-TRAINING =========
-    discriminator_epochs = 1
-    for discriminator_epoch in range(discriminator_epochs):
-
-        # ========= TRAINING =========
-        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {discriminator_epoch+1}/{discriminator_epochs}"):
-            high_quality_sample = high_quality_clip[0].to(device)
-            low_quality_sample = low_quality_clip[0].to(device)
-
-            scale = high_quality_clip[0].shape[2] / low_quality_clip[0].shape[2]
-
-            # ========= LABELS =========
-            batch_size = high_quality_sample.size(0)
-            real_labels = torch.ones(batch_size, 1).to(device)
-            fake_labels = torch.zeros(batch_size, 1).to(device)
-
-            # ========= DISCRIMINATOR =========
-            discriminator.train()
-            discriminator_train(high_quality_sample, low_quality_sample, scale, real_labels, fake_labels)
-
-    torch.save(discriminator.state_dict(), "models/discriminator-single-shot-pre-train.pt")
-
-    generator_epochs = 500
-    for generator_epoch in range(generator_epochs):
-        low_quality_audio = (torch.empty((1)), 1)
-        high_quality_audio = (torch.empty((1)), 1)
-        ai_enhanced_audio = (torch.empty((1)), 1)
-
-        # ========= TRAINING =========
-        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {generator_epoch+1}/{generator_epochs}"):
-            high_quality_sample = high_quality_clip[0].to(device)
-            low_quality_sample = low_quality_clip[0].to(device)
-
-            scale = high_quality_clip[0].shape[2] / low_quality_clip[0].shape[2]
-
-            # ========= LABELS =========
-            batch_size = high_quality_clip[0].size(0)
-            real_labels = torch.ones(batch_size, 1).to(device)
-            fake_labels = torch.zeros(batch_size, 1).to(device)
-
-            # ========= DISCRIMINATOR =========
-            discriminator.train()
-            for _ in range(3):
-                discriminator_train(high_quality_sample, low_quality_sample, scale, real_labels, fake_labels)
-
-            # ========= GENERATOR =========
-            generator.train()
-            generator_output = generator_train(low_quality_sample, scale, real_labels)
-
-            # ========= SAVE LATEST AUDIO =========
-            high_quality_audio = high_quality_clip
-            low_quality_audio = low_quality_clip
-            ai_enhanced_audio = (generator_output, high_quality_clip[1])
-
-        metric = snr(high_quality_audio[0].to(device), ai_enhanced_audio[0])
-        print(f"Generator metric {metric}!")
-        scheduler_g.step(metric)
-
-        if generator_epoch % 10 == 0:
-            print(f"Saved epoch {generator_epoch}!")
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), low_quality_audio[1])
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])
-
-        if generator_epoch % 50 == 0:
-            torch.save(discriminator.state_dict(), f"models/epoch-{generator_epoch}-discriminator.pt")
-            torch.save(generator.state_dict(), f"models/epoch-{generator_epoch}-generator.pt")
-
-    torch.save(discriminator.state_dict(), "models/epoch-500-discriminator.pt")
-    torch.save(generator.state_dict(), "models/epoch-500-generator.pt")
-    print("Training complete!")
-
-start_training()
+import argparse
+import datetime
+import os
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import tqdm
+from accelerate import Accelerator
+from torch.utils.data import DataLoader, DistributedSampler
+
+from data import AudioDataset
+from discriminator import SISUDiscriminator
+from generator import SISUGenerator
+from utils.TrainingTools import discriminator_train, generator_train
+
+# ---------------------------
+# Argument parsing
+# ---------------------------
+parser = argparse.ArgumentParser(description="Training script (safer defaults)")
+parser.add_argument("--resume", action="store_true", help="Resume training")
+parser.add_argument(
+    "--epochs", type=int, default=5000, help="Number of training epochs"
+)
+parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
+parser.add_argument("--num_workers", type=int, default=4, help="DataLoader num_workers")  # Increased workers slightly
+parser.add_argument("--debug", action="store_true", help="Print debug logs")
+parser.add_argument(
+    "--no_pin_memory", action="store_true", help="Disable pin_memory even on CUDA"
+)
+args = parser.parse_args()
+
+# ---------------------------
+# Init accelerator
+# ---------------------------
+
+try:
+    accelerator = Accelerator(mixed_precision="bf16")
+except Exception:
+    accelerator = Accelerator(mixed_precision="fp16")
+    accelerator.print("⚠️ | bf16 unavailable — falling back to fp16")
+
+# ---------------------------
+# Models
+# ---------------------------
+
+generator = SISUGenerator()
+# Note: SISUDiscriminator is now an Ensemble (MPD + MSD)
+discriminator = SISUDiscriminator()
+
+accelerator.print("🔨 | Compiling models...")
+# Torch compile is great, but if you hit errors with the new List/Tuple outputs
+# of the discriminator, you might need to disable it for D.
+generator = torch.compile(generator)
+discriminator = torch.compile(discriminator)
+accelerator.print("✅ | Compiling done!")
+
+# ---------------------------
+# Dataset / DataLoader
+# ---------------------------
+accelerator.print("📊 | Fetching dataset...")
+dataset = AudioDataset("./dataset", 8192)
+
+sampler = DistributedSampler(dataset) if accelerator.num_processes > 1 else None
+pin_memory = torch.cuda.is_available() and not args.no_pin_memory
+
+train_loader = DataLoader(
+    dataset,
+    sampler=sampler,
+    batch_size=args.batch_size,
+    shuffle=(sampler is None),
+    num_workers=args.num_workers,
+    pin_memory=pin_memory,
+    persistent_workers=pin_memory,
+)
+
+if not train_loader or not train_loader.batch_size or train_loader.batch_size == 0:
+    accelerator.print("🪹 | There is no data to train with! Exiting...")
+    exit()
+
+loader_batch_size = train_loader.batch_size
+
+accelerator.print("✅ | Dataset fetched!")
+
+# ---------------------------
+# Losses / Optimizers / Scalers
+# ---------------------------
+
+optimizer_g = optim.AdamW(
+    generator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
+)
+optimizer_d = optim.AdamW(
+    discriminator.parameters(), lr=0.0003, betas=(0.5, 0.999), weight_decay=0.0001
+)
+
+scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
+    optimizer_g, mode="min", factor=0.5, patience=5
+)
+scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
+    optimizer_d, mode="min", factor=0.5, patience=5
+)
+
+# ---------------------------
+# Prepare accelerator
+# ---------------------------
+
+generator, discriminator, optimizer_g, optimizer_d, train_loader = accelerator.prepare(
+    generator, discriminator, optimizer_g, optimizer_d, train_loader
+)
+
+# ---------------------------
+# Checkpoint helpers
+# ---------------------------
+models_dir = "./models"
+os.makedirs(models_dir, exist_ok=True)
+
+
+def save_ckpt(path, epoch, loss=None, is_best=False):
+    accelerator.wait_for_everyone()
+    if accelerator.is_main_process:
+        state = {
+            "epoch": epoch,
+            "G": accelerator.unwrap_model(generator).state_dict(),
+            "D": accelerator.unwrap_model(discriminator).state_dict(),
+            "optG": optimizer_g.state_dict(),
+            "optD": optimizer_d.state_dict(),
+            "schedG": scheduler_g.state_dict(),
+            "schedD": scheduler_d.state_dict()
+        }
+        accelerator.save(state, os.path.join(models_dir, "last.pt"))
+        if is_best:
+            accelerator.save(state, os.path.join(models_dir, "best.pt"))
+            accelerator.print(f"🌟 | New best model saved with G Loss: {loss:.4f}")
+
+
+start_epoch = 0
+if args.resume:
+    ckpt_path = os.path.join(models_dir, "last.pt")
+    if os.path.exists(ckpt_path):
+        ckpt = torch.load(ckpt_path)
+        accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
+        accelerator.unwrap_model(discriminator).load_state_dict(ckpt["D"])
+        optimizer_g.load_state_dict(ckpt["optG"])
+        optimizer_d.load_state_dict(ckpt["optD"])
+        scheduler_g.load_state_dict(ckpt["schedG"])
+        scheduler_d.load_state_dict(ckpt["schedD"])
+        start_epoch = ckpt.get("epoch", 1)
+        accelerator.print(f"🔁 | Resumed from epoch {start_epoch}!")
+    else:
+        accelerator.print("⚠️ | Resume requested but no checkpoint found. Starting fresh.")
+
+accelerator.print("🏋️ | Started training...")
+
+try:
+    for epoch in range(start_epoch, args.epochs):
+        generator.train()
+        discriminator.train()
+
+        discriminator_time = 0
+        generator_time = 0
+
+        running_d, running_g, steps = 0.0, 0.0, 0
+
+        progress_bar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
+
+        for i, (
+            (high_quality, low_quality),
+            (high_sample_rate, low_sample_rate),
+        ) in enumerate(progress_bar):
+            with accelerator.autocast():
+                generator_output = generator(low_quality)
+
+            # --- Discriminator ---
+            d_time = datetime.datetime.now()
+            optimizer_d.zero_grad(set_to_none=True)
+
+            with accelerator.autocast():
+                d_loss = discriminator_train(
+                    high_quality,
+                    discriminator,
+                    generator_output.detach()
+                )
+
+            accelerator.backward(d_loss)
+            optimizer_d.step()
+            discriminator_time = (datetime.datetime.now() - d_time).microseconds
+
+            # --- Generator ---
+            g_time = datetime.datetime.now()
+            optimizer_g.zero_grad(set_to_none=True)
+
+            with accelerator.autocast():
+                g_total, g_adv = generator_train(
+                    low_quality,
+                    high_quality,
+                    generator,
+                    discriminator,
+                    generator_output
+                )
+
+            accelerator.backward(g_total)
+            torch.nn.utils.clip_grad_norm_(generator.parameters(), 1)
+            optimizer_g.step()
+            generator_time = (datetime.datetime.now() - g_time).microseconds
+
+            d_val = accelerator.gather(d_loss.detach()).mean()
+            g_val = accelerator.gather(g_total.detach()).mean()
+
+            if torch.isfinite(d_val):
+                running_d += d_val.item()
+            else:
+                accelerator.print(
+                    f"🫥 | NaN in discriminator loss at step {i}, skipping update."
+                )
+
+            if torch.isfinite(g_val):
+                running_g += g_val.item()
+            else:
+                accelerator.print(
+                    f"🫥 | NaN in generator loss at step {i}, skipping update."
+                )
+
+            steps += 1
+            progress_bar.set_description(f"Epoch {epoch} | D {discriminator_time}μs | G {generator_time}μs")
+
+        if steps == 0:
+            accelerator.print("🪹 | No steps in epoch (empty dataloader?). Exiting.")
+            break
+
+        mean_d = running_d / steps
+        mean_g = running_g / steps
+
+        scheduler_d.step(mean_d)
+        scheduler_g.step(mean_g)
+
+        save_ckpt(os.path.join(models_dir, "last.pt"), epoch)
+        accelerator.print(f"🤝 | Epoch {epoch} done | D {mean_d:.4f} | G {mean_g:.4f}")
+
+except Exception:
+    try:
+        save_ckpt(os.path.join(models_dir, "crash_last.pt"), epoch)
+        accelerator.print(f"💾 | Saved crash checkpoint for epoch {epoch}")
+    except Exception as e:
+        accelerator.print("😬 | Failed saving crash checkpoint:", e)
+    raise
+
+accelerator.print("🏁 | Training finished.")

utils/MultiResolutionSTFTLoss.py (new file, +68 lines)

@@ -0,0 +1,68 @@
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as T


class MultiResolutionSTFTLoss(nn.Module):
    def __init__(
        self,
        fft_sizes: List[int] = [512, 1024, 2048, 4096, 8192],
        hop_sizes: List[int] = [64, 128, 256, 512, 1024],
        win_lengths: List[int] = [256, 512, 1024, 2048, 4096],
        eps: float = 1e-7,
        center: bool = True
    ):
        super().__init__()
        self.eps = eps
        self.n_resolutions = len(fft_sizes)

        self.stft_transforms = nn.ModuleList()
        for i, (n_fft, hop_len, win_len) in enumerate(zip(fft_sizes, hop_sizes, win_lengths)):
            stft = T.Spectrogram(
                n_fft=n_fft,
                hop_length=hop_len,
                win_length=win_len,
                window_fn=torch.hann_window,
                power=None,
                center=center,
                pad_mode="reflect",
                normalized=False,
            )
            self.stft_transforms.append(stft)

    def forward(
        self, y_true: torch.Tensor, y_pred: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        if y_true.dim() == 3 and y_true.size(1) == 1:
            y_true = y_true.squeeze(1)
        if y_pred.dim() == 3 and y_pred.size(1) == 1:
            y_pred = y_pred.squeeze(1)

        sc_loss = 0.0
        mag_loss = 0.0

        for stft in self.stft_transforms:
            stft.window = stft.window.to(y_true.device)

            stft_true = stft(y_true)
            stft_pred = stft(y_pred)

            stft_mag_true = torch.abs(stft_true)
            stft_mag_pred = torch.abs(stft_pred)

            norm_true = torch.linalg.norm(stft_mag_true, dim=(-2, -1))
            norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, dim=(-2, -1))
            sc_loss += torch.mean(norm_diff / (norm_true + self.eps))

            log_mag_pred = torch.log(stft_mag_pred + self.eps)
            log_mag_true = torch.log(stft_mag_true + self.eps)
            mag_loss += F.l1_loss(log_mag_pred, log_mag_true)

        sc_loss /= self.n_resolutions
        mag_loss /= self.n_resolutions
        total_loss = sc_loss + mag_loss

        return {"total": total_loss, "sc": sc_loss, "mag": mag_loss}
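
A usage sketch with random tensors (the default five resolutions; [B, 1, T] inputs are squeezed to [B, T] internally):

import torch

from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss

loss_fn = MultiResolutionSTFTLoss()
y_true = torch.randn(4, 1, 8192)
y_pred = torch.randn(4, 1, 8192)
losses = loss_fn(y_true, y_pred)
# Identical inputs would give ~0; the dict exposes the two components separately.
print(losses["total"].item(), losses["sc"].item(), losses["mag"].item())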

utils/TrainingTools.py (new file, +88 lines)

@@ -0,0 +1,88 @@
import torch
import torch.nn.functional as F

from utils.MultiResolutionSTFTLoss import MultiResolutionSTFTLoss

# Keep STFT settings as is
stft_loss_fn = MultiResolutionSTFTLoss(
    fft_sizes=[512, 1024, 2048],
    hop_sizes=[64, 128, 256],
    win_lengths=[256, 512, 1024]
)


def feature_matching_loss(fmap_r, fmap_g):
    """
    Computes L1 distance between real and fake feature maps.
    """
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            rl = rl.detach()
            loss += torch.mean(torch.abs(rl - gl))
    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """
    Least Squares GAN Loss (LSGAN) for the Discriminator.
    Objective: Real -> 1, Fake -> 0
    """
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((dr - 1) ** 2)
        g_loss = torch.mean(dg ** 2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())
    return loss, r_losses, g_losses


def generator_adv_loss(disc_generated_outputs):
    loss = 0.0
    for dg in disc_generated_outputs:
        loss += torch.mean((dg - 1) ** 2)
    return loss


def discriminator_train(
    high_quality,
    discriminator,
    generator_output
):
    y_d_rs, y_d_gs, _, _ = discriminator(high_quality, generator_output.detach())
    d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs)
    return d_loss


def generator_train(
    low_quality,
    high_quality,
    generator,
    discriminator,
    generator_output
):
    y_d_rs, y_d_gs, fmap_rs, fmap_gs = discriminator(high_quality, generator_output)

    loss_gen_adv = generator_adv_loss(y_d_gs)
    loss_fm = feature_matching_loss(fmap_rs, fmap_gs)
    stft_loss = stft_loss_fn(high_quality, generator_output)["total"]

    lambda_stft = 45.0
    lambda_fm = 2.0
    lambda_adv = 1.0

    combined_loss = (lambda_stft * stft_loss) + \
                    (lambda_fm * loss_fm) + \
                    (lambda_adv * loss_gen_adv)

    return combined_loss, loss_gen_adv
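
How the training script consumes these helpers, condensed into one step (a sketch; random tensors stand in for a real batch):

import torch

from discriminator import SISUDiscriminator
from generator import SISUGenerator
from utils.TrainingTools import discriminator_train, generator_train

generator, discriminator = SISUGenerator(), SISUDiscriminator()
high_quality = torch.randn(2, 1, 8192)   # target batch
low_quality = torch.randn(2, 1, 8192)    # degraded batch

generator_output = generator(low_quality)
d_loss = discriminator_train(high_quality, discriminator, generator_output.detach())
g_total, g_adv = generator_train(low_quality, high_quality, generator, discriminator, generator_output)
# g_total = 45.0 * STFT + 2.0 * feature matching + 1.0 * adversarial
print(d_loss.item(), g_total.item(), g_adv.item())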

utils/__init__.py (new file, empty)