| Made training bit... spicier.

This commit is contained in:
2025-09-10 19:52:53 +03:00
parent ff38cefdd3
commit 0bc8fc2792
8 changed files with 581 additions and 303 deletions

74
app.py
View File

@@ -1,33 +1,49 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import tqdm
import argparse
import math
import os
import torch
import torchaudio
import torchcodec
import tqdm
import AudioUtils
from generator import SISUGenerator
# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--model", type=str, help="Model to use for upscaling")
parser.add_argument("--clip_length", type=int, default=1024, help="Internal clip length, leave unspecified if unsure")
parser.add_argument(
"--clip_length",
type=int,
default=16384,
help="Internal clip length, leave unspecified if unsure",
)
parser.add_argument(
"--sample_rate", type=int, default=44100, help="Output clip sample rate"
)
parser.add_argument(
"--bitrate",
type=int,
default=192000,
help="Output clip bitrate",
)
parser.add_argument("-i", "--input", type=str, help="Input audio file")
parser.add_argument("-o", "--output", type=str, help="Output audio file")
args = parser.parse_args()
if args.sample_rate < 8000:
print(
"Sample rate cannot be lower than 8000! (44100 is recommended for base models)"
)
exit()
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
generator = SISUGenerator()
generator = SISUGenerator().to(device)
generator = torch.compile(generator)
models_dir = args.model
clip_length = args.clip_length
@@ -35,17 +51,30 @@ input_audio = args.input
output_audio = args.output
if models_dir:
generator.load_state_dict(torch.load(f"{models_dir}", map_location=device, weights_only=True))
ckpt = torch.load(models_dir, map_location=device)
generator.load_state_dict(ckpt["G"])
else:
print(f"Generator model (--model) isn't specified. Do you have the trained model? If not you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)")
print(
"Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)"
)
generator = generator.to(device)
def start():
# To Mono!
audio, original_sample_rate = torchaudio.load(input_audio, normalize=True)
decoder = torchcodec.decoders.AudioDecoder(input_audio)
decoded_samples = decoder.get_all_samples()
audio = decoded_samples.data
original_sample_rate = decoded_samples.sample_rate
audio = AudioUtils.stereo_tensor_to_mono(audio)
resample_transform = torchaudio.transforms.Resample(
original_sample_rate, args.sample_rate
)
audio = resample_transform(audio)
splitted_audio = AudioUtils.split_audio(audio, clip_length)
splitted_audio_on_device = [t.to(device) for t in splitted_audio]
processed_audio = []
@@ -55,6 +84,13 @@ def start():
reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)
print(f"Saving {output_audio}!")
torchaudio.save(output_audio, reconstructed_audio.cpu().detach(), original_sample_rate)
torchaudio.save_with_torchcodec(
uri=output_audio,
src=reconstructed_audio,
sample_rate=args.sample_rate,
channels_first=True,
compression=args.bitrate,
)
start()