✨ | Added app.py script so the model can be used.
AudioUtils.py
@@ -50,3 +50,22 @@ def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 128) -> list[torch
     chunks = list(torch.split(audio_tensor, chunk_size, dim=split_dim))
 
     return chunks
+
+
+def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
+    if not chunks:
+        return torch.empty(0)
+
+    if len(chunks) == 1 and chunks[0].dim() == 0:
+        return chunks[0]
+
+    concat_dim = -1
+
+    try:
+        reconstructed_tensor = torch.cat(chunks, dim=concat_dim)
+    except RuntimeError as e:
+        raise RuntimeError(
+            f"Failed to concatenate audio chunks. Ensure chunks have compatible shapes "
+            f"for concatenation along dimension {concat_dim}. Original error: {e}"
+        )
+    return reconstructed_tensor
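
For context, a minimal round-trip sketch of the two helpers, assuming this hunk lives in AudioUtils.py and that split_dim resolves to the last (sample) dimension; the input tensor is made up for illustration:

    import torch
    import AudioUtils

    # Hypothetical mono clip: (channels=1, samples=88200), i.e. 2 s at 44.1 kHz.
    audio = torch.randn(1, 88200)

    # split_audio yields (1, 256)-shaped chunks (the last one may be shorter).
    chunks = AudioUtils.split_audio(audio, chunk_size=256)

    # reconstruct_audio is torch.cat along dim=-1, so the round trip is lossless.
    restored = AudioUtils.reconstruct_audio(chunks)
    assert torch.equal(restored, audio)

torch.split only slices, it never pads, so concatenating the chunks back along the same dimension reproduces the original tensor exactly.
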
app.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+import torch.nn.functional as F
+import torchaudio
+import tqdm
+
+import argparse
+import math
+import os
+
+import AudioUtils
+from generator import SISUGenerator
+
+
+# Init script argument parser
+parser = argparse.ArgumentParser(description="Upscaling script")
+parser.add_argument("--device", type=str, default="cpu", help="Select device")
+parser.add_argument("--model", type=str, help="Model to use for upscaling")
+parser.add_argument("--clip_length", type=int, default=256, help="Internal clip length, leave unspecified if unsure")
+parser.add_argument("-i", "--input", type=str, help="Input audio file")
+parser.add_argument("-o", "--output", type=str, help="Output audio file")
+
+args = parser.parse_args()
+
+device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
+generator = SISUGenerator()
+
+models_dir = args.model
+clip_length = args.clip_length
+input_audio = args.input
+output_audio = args.output
+
+if models_dir:
+    generator.load_state_dict(torch.load(f"{models_dir}", map_location=device, weights_only=True))
+else:
+    print("Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)")
+
+generator = generator.to(device)
+
+def start():
+    # To Mono!
+    audio, original_sample_rate = torchaudio.load(input_audio, normalize=True)
+    audio = AudioUtils.stereo_tensor_to_mono(audio)
+
+    splitted_audio = AudioUtils.split_audio(audio, clip_length)
+    splitted_audio_on_device = [t.to(device) for t in splitted_audio]
+    processed_audio = []
+
+    for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."):
+        processed_audio.append(generator(clip))
+
+    reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)
+    print(f"Saving {output_audio}!")
+    torchaudio.save(output_audio, reconstructed_audio.cpu().detach(), original_sample_rate)
+
+start()
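
Once a trained checkpoint exists, the script would be invoked along these lines (the checkpoint path is illustrative; the flags come from the parser above):

    python app.py --device cuda --model models/generator.pt -i input.wav -o output.wav

One detail worth tightening: start() runs the generator with autograd enabled, so every chunk keeps its computation graph alive until the final .detach(). A sketch of an inference-mode variant, using the same names as the script above:

    def start():
        # Load and collapse to mono, as in the original script.
        audio, original_sample_rate = torchaudio.load(input_audio, normalize=True)
        audio = AudioUtils.stereo_tensor_to_mono(audio)

        chunks = [t.to(device) for t in AudioUtils.split_audio(audio, clip_length)]

        generator.eval()  # disable training-only behavior (dropout etc.), if any
        processed_audio = []
        with torch.no_grad():  # skip autograd bookkeeping during inference
            for clip in tqdm.tqdm(chunks, desc="Processing..."):
                processed_audio.append(generator(clip))

        reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)
        print(f"Saving {output_audio}!")
        torchaudio.save(output_audio, reconstructed_audio.cpu(), original_sample_rate)
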
training.py (18 lines changed)
@@ -69,14 +69,14 @@ debug = args.debug
 # Initialize dataset and dataloader
 dataset_dir = './dataset/good'
 dataset = AudioDataset(dataset_dir, device)
-models_dir = "models"
+models_dir = "./models"
 os.makedirs(models_dir, exist_ok=True)
-audio_output_dir = "output"
+audio_output_dir = "./output"
 os.makedirs(audio_output_dir, exist_ok=True)
 
 # ========= SINGLE =========
 
-train_data_loader = DataLoader(dataset, batch_size=1024, shuffle=True)
+train_data_loader = DataLoader(dataset, batch_size=8192, shuffle=True, num_workers=24)
 
 
 # ========= MODELS =========
@@ -85,17 +85,17 @@ generator = SISUGenerator()
 discriminator = SISUDiscriminator()
 
 epoch: int = args.epoch
-epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
 
 if args.continue_training:
-    generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
-    discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
-    epoch = epoch_from_file["epoch"] + 1
-else:
     if args.generator is not None:
         generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
-    if args.discriminator is not None:
+    elif args.discriminator is not None:
         discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
+    else:
+        generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
+        discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
+    epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
+    epoch = epoch_from_file["epoch"] + 1
 
 generator = generator.to(device)
 discriminator = discriminator.to(device)