diff --git a/.gitignore b/.gitignore
index 5e49b42..493e1db 100644
--- a/.gitignore
+++ b/.gitignore
@@ -158,4 +158,9 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
\ No newline at end of file
+#.idea/
+
+# Project based files
+backup/
+dataset/
+old-output/
diff --git a/data.py b/data.py
new file mode 100644
index 0000000..214b6e5
--- /dev/null
+++ b/data.py
@@ -0,0 +1,49 @@
+import torch
+from torch.utils.data import Dataset
+import torchaudio
+import os
+
+class AudioDataset(Dataset):
+    def __init__(self, input_dir, target_duration=None, padding_mode='constant', padding_value=0.0):
+        self.input_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.wav')]
+        self.target_duration = target_duration  # Duration in seconds or None if not set
+        self.padding_mode = padding_mode
+        self.padding_value = padding_value
+
+    def __len__(self):
+        return len(self.input_files)
+
+    def __getitem__(self, idx):
+        # Load audio samples using torchaudio
+        high_quality_wav, sr_original = torchaudio.load(self.input_files[idx], normalize=True)
+
+        # Downsample to 16000 Hz to create the low-quality version
+        resample_transform = torchaudio.transforms.Resample(sr_original, 16000)
+        low_quality_wav = resample_transform(high_quality_wav)
+
+        # Calculate target length in samples if target_duration is specified
+        if self.target_duration is not None:
+            target_length = int(self.target_duration * 16000)  # Assuming 16000 Hz as target sample rate
+        else:
+            target_length = high_quality_wav.size(1)
+
+        # Pad high_quality_wav and low_quality_wav to target_length
+        high_quality_wav = self.pad_tensor(high_quality_wav, target_length)
+        low_quality_wav = self.pad_tensor(low_quality_wav, target_length)
+
+        return high_quality_wav, low_quality_wav
+
+    def pad_tensor(self, tensor, target_length):
+        """Pad tensor to target length along the time dimension (dim=1)."""
+        current_length = tensor.size(1)
+
+        if current_length < target_length:
+            # Pad on the right to reach the target length
+            padding_amount = target_length - current_length
+            padding = (0, padding_amount)  # (left_pad, right_pad) for 1D padding
+            tensor = torch.nn.functional.pad(tensor, padding, mode=self.padding_mode, value=self.padding_value)
+        else:
+            # If tensor is longer than target, truncate it
+            tensor = tensor[:, :target_length]
+
+        return tensor
diff --git a/discriminator.py b/discriminator.py
new file mode 100644
index 0000000..e0083d4
--- /dev/null
+++ b/discriminator.py
@@ -0,0 +1,23 @@
+import torch.nn as nn
+
+class SISUDiscriminator(nn.Module):
+    def __init__(self):
+        super(SISUDiscriminator, self).__init__()
+        self.model = nn.Sequential(
+            nn.Conv1d(2, 64, kernel_size=4, stride=2, padding=1),  # Now accepts 2 input channels
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(64, 128, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm1d(128),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(128, 256, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm1d(256),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(256, 512, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm1d(512),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(512, 1, kernel_size=4, stride=1, padding=0),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        return self.model(x)
diff --git a/generator.py b/generator.py
new file mode 100644
index 0000000..b9401d7
--- /dev/null
+++ b/generator.py
@@ -0,0 +1,24 @@
+import torch.nn as nn
+
+class SISUGenerator(nn.Module):
+    def __init__(self):  # No noise_dim parameter
+        super(SISUGenerator, self).__init__()
+        self.model = nn.Sequential(
+            nn.Conv1d(2, 64, kernel_size=7, stride=1, padding=3),  # Input 2 channels (low-quality audio)
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(64, 64, kernel_size=7, stride=1, padding=3),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(128, 128, kernel_size=5, stride=1, padding=2),
+            nn.LeakyReLU(0.2),
+            nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
+            nn.LeakyReLU(0.2),
+            nn.Conv1d(64, 2, kernel_size=3, stride=1, padding=1),  # Output 2 channels (high-quality audio)
+            nn.Tanh()
+        )
+
+    def forward(self, x):
+        return self.model(x)
diff --git a/output.wav b/output.wav
new file mode 100644
index 0000000..90b8499
Binary files /dev/null and b/output.wav differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..dc0c01f
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,12 @@
+filelock>=3.16.1
+fsspec>=2024.10.0
+Jinja2>=3.1.4
+MarkupSafe>=2.1.5
+mpmath>=1.3.0
+networkx>=3.4.2
+numpy>=2.1.2
+pillow>=11.0.0
+setuptools>=70.2.0
+sympy>=1.13.1
+tqdm>=4.67.1
+typing_extensions>=4.12.2
diff --git a/training.py b/training.py
new file mode 100644
index 0000000..23a918a
--- /dev/null
+++ b/training.py
@@ -0,0 +1,112 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torchaudio
+import tqdm
+
+from torch.utils.data import random_split
+from torch.utils.data import DataLoader
+
+from data import AudioDataset
+from generator import SISUGenerator
+from discriminator import SISUDiscriminator
+
+# Check for CUDA availability
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
+# Initialize dataset and dataloader
+dataset_dir = './dataset/good'
+dataset = AudioDataset(dataset_dir, target_duration=2.0)  # 2 second target duration
+
+dataset_size = len(dataset)
+train_size = int(dataset_size * .9)
+val_size = int(dataset_size - train_size)
+
+train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+
+train_data_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
+val_data_loader = DataLoader(val_dataset, batch_size=4, shuffle=True)
+
+# Initialize models and move them to device
+generator = SISUGenerator()
+discriminator = SISUDiscriminator()
+
+generator = generator.to(device)
+discriminator = discriminator.to(device)
+
+# Loss and optimizers
+criterion = nn.MSELoss()  # Use Mean Squared Error loss
+optimizer_g = optim.Adam(generator.parameters(), lr=0.0005, betas=(0.5, 0.999))
+optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
+
+# Learning rate scheduler
+scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.1, patience=5)
+
+# Training loop
+num_epochs = 500
+for epoch in range(num_epochs):
+    latest_crap_audio = torch.empty((2, 3), dtype=torch.int64)
+    for high_quality, low_quality in tqdm.tqdm(train_data_loader):
+        # Check for NaN values in input tensors
+        if torch.isnan(low_quality).any() or torch.isnan(high_quality).any():
+            continue
+
+        high_quality = high_quality.to(device)
+        low_quality = low_quality.to(device)
+
+        batch_size = low_quality.size(0)
+
+        # Labels
+        real_labels = torch.ones(batch_size, 1).to(device)
+        fake_labels = torch.zeros(batch_size, 1).to(device)
+
+        # Train Discriminator
+        optimizer_d.zero_grad()
+        outputs = discriminator(high_quality)
+        d_loss_real = criterion(outputs, real_labels)
+        d_loss_real.backward()
+
+        resampled_audio = generator(low_quality)
+
+        outputs = discriminator(resampled_audio.detach())
+        d_loss_fake = criterion(outputs, fake_labels)
+        d_loss_fake.backward()
+
+
+        # Gradient clipping for discriminator
+        clip_value = 2.0
+        for param in discriminator.parameters():
+            if param.grad is not None:
+                param.grad.clamp_(-clip_value, clip_value)
+
+        optimizer_d.step()
+
+        d_loss = d_loss_real + d_loss_fake
+
+        # Train Generator
+        optimizer_g.zero_grad()
+        outputs = discriminator(resampled_audio)
+        g_loss = criterion(outputs, real_labels)
+        g_loss.backward()
+
+        # Gradient clipping for generator
+        clip_value = 1.0
+        for param in generator.parameters():
+            if param.grad is not None:
+                param.grad.clamp_(-clip_value, clip_value)
+
+        optimizer_g.step()
+
+        scheduler.step(d_loss + g_loss)
+        latest_crap_audio = resampled_audio
+
+    if epoch % 10 == 0:
+        print(latest_crap_audio.size())
+        torchaudio.save(f"./epoch-{epoch}-audio.wav", latest_crap_audio[0].detach().cpu(), 16000)  # generator output follows the 16 kHz low-quality input
+    print(f'Epoch [{epoch+1}/{num_epochs}]')
+
+torch.save(generator.state_dict(), "generator.pt")
+torch.save(discriminator.state_dict(), "discriminator.pt")
+
+print("Training complete!")
diff --git a/use.py b/use.py
new file mode 100644
index 0000000..59996f8
--- /dev/null
+++ b/use.py
@@ -0,0 +1,37 @@
+import torch
+import torchaudio
+from generator import SISUGenerator
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Initialize models and move them to device
+generator = SISUGenerator()
+generator.load_state_dict(torch.load("generator.pt", weights_only=True))
+generator.to(device)
+generator.eval()
+
+def generate_audio(input_audio_path, output_audio_path):
+    # Load and preprocess input audio
+    low_quality_wav, sr_b = torchaudio.load(input_audio_path)
+    low_quality_wav = low_quality_wav.to(device)
+
+    # Normalize audio
+    low_quality_wav = normalize(low_quality_wav)
+
+    # Flatten the input if necessary
+    low_quality_wav = low_quality_wav.view(low_quality_wav.size(0), -1)
+
+    fake_audio = generator(low_quality_wav)
+
+    torchaudio.save(output_audio_path, fake_audio.detach().cpu(), sr_b)  # save the enhanced audio at the input sample rate
+
+    print(f"Generated audio saved to {output_audio_path}")
+    return low_quality_wav
+
+def normalize(wav):
+    return wav / torch.max(torch.abs(wav))
+
+# Example usage
+input_audio_path = "/mnt/games/Home/Downloads/SISU/sample_3_023.wav"
+output_audio_path = "/mnt/games/Home/Downloads/SISU/godtier_audio.wav"
+generate_audio(input_audio_path, output_audio_path)
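Not part of the commit above: a minimal smoke test sketch, assuming the files land as data.py, generator.py, and discriminator.py and that ./dataset/good contains stereo .wav files (both networks are built around 2-channel Conv1d layers). It only checks that one dataset item flows through the generator and discriminator with the expected shapes.

import torch

from data import AudioDataset
from generator import SISUGenerator
from discriminator import SISUDiscriminator

# One dataset item: the high-quality audio at its original rate and a 16 kHz
# low-quality copy, both padded/truncated to 2.0 s * 16000 = 32000 samples.
dataset = AudioDataset('./dataset/good', target_duration=2.0)
high_quality, low_quality = dataset[0]

generator = SISUGenerator()
discriminator = SISUDiscriminator()

with torch.no_grad():
    fake = generator(low_quality.unsqueeze(0))   # (1, 2, 32000) -> (1, 2, 32000)
    score = discriminator(fake)                  # (1, 1, L) patch-style scores in (0, 1)

print(fake.shape, score.shape)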