61 lines
1.9 KiB
Python
61 lines
1.9 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
|
|
import torch.nn.functional as F
|
|
import torchaudio
|
|
import tqdm
|
|
|
|
import argparse
|
|
import math
|
|
import os
|
|
|
|
import AudioUtils
|
|
from generator import SISUGenerator
|
|
|
|
|
|
# Init script argument parser
parser = argparse.ArgumentParser(description="Inference script")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--model", type=str, help="Model to use for upscaling")
parser.add_argument("--clip_length", type=int, default=1024, help="Internal clip length, leave unspecified if unsure")
# Input and output paths are mandatory: without them start() would only fail
# later, inside torchaudio.load/save, with a far less helpful error.
parser.add_argument("-i", "--input", type=str, required=True, help="Input audio file")
parser.add_argument("-o", "--output", type=str, required=True, help="Output audio file")

args = parser.parse_args()

# The requested device is used only when CUDA is available; otherwise the script
# falls back to CPU. NOTE(review): this also forces CPU for non-CUDA devices
# (e.g. "mps") on machines without CUDA.
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

generator = SISUGenerator()

# Path handed directly to torch.load (a saved state_dict file).
models_dir = args.model
clip_length = args.clip_length
input_audio = args.input
output_audio = args.output

if models_dir:
    # weights_only=True restricts unpickling to tensors/primitive containers.
    generator.load_state_dict(torch.load(models_dir, map_location=device, weights_only=True))
else:
    # Without weights the generator runs with its freshly initialized parameters.
    print("Generator model (--model) isn't specified. Do you have the trained model? If not you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)")

generator = generator.to(device)
# This script only runs inference: switch any dropout/batch-norm layers the
# model may contain to their evaluation behavior.
generator.eval()
|
|
|
|
def start():
    """Upscale ``input_audio`` with ``generator`` and save it to ``output_audio``.

    Uses the module-level settings ``input_audio``, ``output_audio``,
    ``clip_length``, ``device`` and ``generator``. The audio is loaded,
    converted to mono, split into clips of ``clip_length``, and passed through
    the generator one clip at a time. The outputs are then stitched back
    together and written at the input file's original sample rate.
    """
    audio, original_sample_rate = torchaudio.load(input_audio, normalize=True)

    # To Mono!
    audio = AudioUtils.stereo_tensor_to_mono(audio)

    splitted_audio = AudioUtils.split_audio(audio, clip_length)

    processed_audio = []
    # Inference only: with autograd disabled, each clip's computation graph is
    # not kept alive for the whole run. Clips are moved to the device one at a
    # time instead of all up front.
    with torch.no_grad():
        for clip in tqdm.tqdm(splitted_audio, desc="Processing..."):
            processed_audio.append(generator(clip.to(device)))

    reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)

    print(f"Saving {output_audio}!")
    torchaudio.save(output_audio, reconstructed_audio.cpu().detach(), original_sample_rate)
|
|
|
|
# Run the upscaling pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    start()
|