From 2ded03713d38e4f0e45e11cb945c5ffda7e90a64 Mon Sep 17 00:00:00 2001
From: NikkeDoy
Date: Fri, 6 Jun 2025 22:10:06 +0300
Subject: [PATCH] :sparkles: | Added app.py script so the model can be used.

---
 AudioUtils.py | 19 ++++++++++++++++
 app.py        | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++
 training.py   | 18 ++++++++--------
 3 files changed, 88 insertions(+), 9 deletions(-)
 create mode 100644 app.py

diff --git a/AudioUtils.py b/AudioUtils.py
index f4866dd..f45efb5 100644
--- a/AudioUtils.py
+++ b/AudioUtils.py
@@ -50,3 +50,22 @@ def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 128) -> list[torch
     chunks = list(torch.split(audio_tensor, chunk_size, dim=split_dim))
 
     return chunks
+
+def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
+    if not chunks:
+        return torch.empty(0)
+
+    if len(chunks) == 1 and chunks[0].dim() == 0:
+        return chunks[0]
+
+    concat_dim = -1
+
+    try:
+        reconstructed_tensor = torch.cat(chunks, dim=concat_dim)
+    except RuntimeError as e:
+        raise RuntimeError(
+            f"Failed to concatenate audio chunks. Ensure chunks have compatible shapes "
+            f"for concatenation along dimension {concat_dim}. Original error: {e}"
+        ) from e
+
+    return reconstructed_tensor
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..a24486b
--- /dev/null
+++ b/app.py
@@ -0,0 +1,60 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+import torch.nn.functional as F
+import torchaudio
+import tqdm
+
+import argparse
+import math
+import os
+
+import AudioUtils
+from generator import SISUGenerator
+
+
+# Init script argument parser
+parser = argparse.ArgumentParser(description="SISU audio upscaling (inference) script")
+parser.add_argument("--device", type=str, default="cpu", help="Select device")
+parser.add_argument("--model", type=str, help="Model to use for upscaling")
+parser.add_argument("--clip_length", type=int, default=256, help="Internal clip length, leave unspecified if unsure")
+parser.add_argument("-i", "--input", type=str, help="Input audio file")
+parser.add_argument("-o", "--output", type=str, help="Output audio file")
+
+args = parser.parse_args()
+
+device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
+generator = SISUGenerator()
+
+models_dir = args.model
+clip_length = args.clip_length
+input_audio = args.input
+output_audio = args.output
+
+if models_dir:
+    generator.load_state_dict(torch.load(models_dir, map_location=device, weights_only=True))
+else:
+    print("Generator model (--model) not specified. The generator will run with untrained weights; train a model first or pass a path to an existing checkpoint.")
+
+generator = generator.to(device).eval()
+
+def start():
+    # Load the input audio and convert it to mono
+    audio, original_sample_rate = torchaudio.load(input_audio, normalize=True)
+    audio = AudioUtils.stereo_tensor_to_mono(audio)
+
+    splitted_audio = AudioUtils.split_audio(audio, clip_length)
+    splitted_audio_on_device = [t.to(device) for t in splitted_audio]
+    processed_audio = []
+
+    for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."):
+        processed_audio.append(generator(clip))
+
+    reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)
+    print(f"Saving {output_audio}!")
+    torchaudio.save(output_audio, reconstructed_audio.cpu().detach(), original_sample_rate)
+
+start()
diff --git a/training.py b/training.py
index 5ccabc7..52275d5 100644
--- a/training.py
+++ b/training.py
@@ -69,14 +69,14 @@ debug = args.debug
 
 # Initialize dataset and dataloader
 dataset_dir = './dataset/good'
 dataset = AudioDataset(dataset_dir, device)
-models_dir = "models"
+models_dir = "./models"
 os.makedirs(models_dir, exist_ok=True)
-audio_output_dir = "output"
+audio_output_dir = "./output"
 os.makedirs(audio_output_dir, exist_ok=True)
 
 # ========= SINGLE =========
-train_data_loader = DataLoader(dataset, batch_size=1024, shuffle=True)
+train_data_loader = DataLoader(dataset, batch_size=8192, shuffle=True, num_workers=24)
 
 
 # ========= MODELS =========
@@ -85,17 +85,17 @@ generator = SISUGenerator()
 discriminator = SISUDiscriminator()
 
 epoch: int = args.epoch
-epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
 
 if args.continue_training:
-    generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
-    discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
-    epoch = epoch_from_file["epoch"] + 1
-else:
     if args.generator is not None:
         generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
-    if args.discriminator is not None:
+    elif args.discriminator is not None:
         discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
+    else:
+        generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
+        discriminator.load_state_dict(torch.load(f"{models_dir}/temp_discriminator.pt", map_location=device, weights_only=True))
+        epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
+        epoch = epoch_from_file["epoch"] + 1
 
 generator = generator.to(device)
 discriminator = discriminator.to(device)
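
For reference, the new script would be invoked along the lines of
python app.py --device cuda --model models/temp_generator.pt -i input.wav -o output.wav
with the checkpoint and audio paths being placeholders. The sketch below walks through the same split -> upscale -> reconstruct round-trip that start() performs, run under eval() and torch.inference_mode() as a standard inference precaution (these wrappers are an assumption of the sketch, not part of the patch). It also assumes SISUGenerator accepts the chunk shape produced by AudioUtils.split_audio.

# Minimal round-trip sketch; SISUGenerator input shape and all file paths
# below are illustrative assumptions.
import torch
import torchaudio

import AudioUtils
from generator import SISUGenerator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = SISUGenerator()
generator.load_state_dict(torch.load("models/temp_generator.pt", map_location=device, weights_only=True))
generator = generator.to(device).eval()

audio, sample_rate = torchaudio.load("input.wav", normalize=True)  # (channels, samples)
audio = AudioUtils.stereo_tensor_to_mono(audio)                    # collapse to a single channel

with torch.inference_mode():                                       # no autograd graph during inference
    chunks = AudioUtils.split_audio(audio, 256)                    # fixed-length clips
    upscaled = [generator(chunk.to(device)).cpu() for chunk in chunks]

restored = AudioUtils.reconstruct_audio(upscaled)                  # concatenate along the time axis
torchaudio.save("output.wav", restored, sample_rate)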