✨ | Made training bit... spicier.
This commit is contained in:
74
app.py
74
app.py
@@ -1,33 +1,49 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import tqdm
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchcodec
|
||||
import tqdm
|
||||
|
||||
import AudioUtils
|
||||
from generator import SISUGenerator
|
||||
|
||||
|
||||
# Init script argument parser
|
||||
parser = argparse.ArgumentParser(description="Training script")
|
||||
parser.add_argument("--device", type=str, default="cpu", help="Select device")
|
||||
parser.add_argument("--device", type=str, default="cpu", help="Select device")
|
||||
parser.add_argument("--model", type=str, help="Model to use for upscaling")
|
||||
parser.add_argument("--clip_length", type=int, default=1024, help="Internal clip length, leave unspecified if unsure")
|
||||
parser.add_argument(
|
||||
"--clip_length",
|
||||
type=int,
|
||||
default=16384,
|
||||
help="Internal clip length, leave unspecified if unsure",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample_rate", type=int, default=44100, help="Output clip sample rate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bitrate",
|
||||
type=int,
|
||||
default=192000,
|
||||
help="Output clip bitrate",
|
||||
)
|
||||
parser.add_argument("-i", "--input", type=str, help="Input audio file")
|
||||
parser.add_argument("-o", "--output", type=str, help="Output audio file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.sample_rate < 8000:
|
||||
print(
|
||||
"Sample rate cannot be lower than 8000! (44100 is recommended for base models)"
|
||||
)
|
||||
exit()
|
||||
|
||||
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
generator = SISUGenerator()
|
||||
generator = SISUGenerator().to(device)
|
||||
|
||||
generator = torch.compile(generator)
|
||||
|
||||
models_dir = args.model
|
||||
clip_length = args.clip_length
|
||||
@@ -35,17 +51,30 @@ input_audio = args.input
|
||||
output_audio = args.output
|
||||
|
||||
if models_dir:
|
||||
generator.load_state_dict(torch.load(f"{models_dir}", map_location=device, weights_only=True))
|
||||
ckpt = torch.load(models_dir, map_location=device)
|
||||
generator.load_state_dict(ckpt["G"])
|
||||
else:
|
||||
print(f"Generator model (--model) isn't specified. Do you have the trained model? If not you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)")
|
||||
print(
|
||||
"Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)"
|
||||
)
|
||||
|
||||
generator = generator.to(device)
|
||||
|
||||
def start():
|
||||
# To Mono!
|
||||
audio, original_sample_rate = torchaudio.load(input_audio, normalize=True)
|
||||
decoder = torchcodec.decoders.AudioDecoder(input_audio)
|
||||
|
||||
decoded_samples = decoder.get_all_samples()
|
||||
audio = decoded_samples.data
|
||||
original_sample_rate = decoded_samples.sample_rate
|
||||
|
||||
audio = AudioUtils.stereo_tensor_to_mono(audio)
|
||||
|
||||
resample_transform = torchaudio.transforms.Resample(
|
||||
original_sample_rate, args.sample_rate
|
||||
)
|
||||
|
||||
audio = resample_transform(audio)
|
||||
|
||||
splitted_audio = AudioUtils.split_audio(audio, clip_length)
|
||||
splitted_audio_on_device = [t.to(device) for t in splitted_audio]
|
||||
processed_audio = []
|
||||
@@ -55,6 +84,13 @@ def start():
|
||||
|
||||
reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)
|
||||
print(f"Saving {output_audio}!")
|
||||
torchaudio.save(output_audio, reconstructed_audio.cpu().detach(), original_sample_rate)
|
||||
torchaudio.save_with_torchcodec(
|
||||
uri=output_audio,
|
||||
src=reconstructed_audio,
|
||||
sample_rate=args.sample_rate,
|
||||
channels_first=True,
|
||||
compression=args.bitrate,
|
||||
)
|
||||
|
||||
|
||||
start()
|
||||
|
Reference in New Issue
Block a user