import argparse
import sys

import torch
import torchaudio
import torchcodec
import tqdm
from accelerate import Accelerator

import AudioUtils
from generator import SISUGenerator

# ---------------------------
# Script arguments
# ---------------------------

parser = argparse.ArgumentParser(description="Inference script: upscale an audio file with a trained SISU generator")
parser.add_argument(
    "--device",
    type=str,
    default="cpu",
    help="Select device (currently unused; Accelerate picks the device)",
)
parser.add_argument("--model", type=str, help="Path to the trained generator checkpoint")
parser.add_argument(
    "--clip_length",
    type=int,
    default=8000,
    help="Internal clip length, leave unspecified if unsure",
)
parser.add_argument(
    "--sample_rate", type=int, default=44100, help="Output clip sample rate"
)
parser.add_argument(
    "--bitrate",
    type=int,
    default=192000,
    help="Output clip bitrate",
)
parser.add_argument("-i", "--input", type=str, required=True, help="Input audio file")
parser.add_argument("-o", "--output", type=str, required=True, help="Output audio file")
args = parser.parse_args()

if args.sample_rate < 8000:
    print(
        "Sample rate cannot be lower than 8000! (44100 is recommended for base models)"
    )
    sys.exit(1)

# ---------------------------
# Init accelerator
# ---------------------------

accelerator = Accelerator(mixed_precision="bf16")

# ---------------------------
# Models
# ---------------------------

generator = SISUGenerator()

accelerator.print("🔨 | Compiling models...")
generator = torch.compile(generator)
accelerator.print("✅ | Compiling done!")

# ---------------------------
# Prepare accelerator
# ---------------------------

generator = accelerator.prepare(generator)

# ---------------------------
# Checkpoint loading
# ---------------------------

model_path = args.model
clip_length = args.clip_length
input_audio = args.input
output_audio = args.output

if model_path:
    # Load to CPU first; Accelerate has already placed the model on the right device.
    ckpt = torch.load(model_path, map_location="cpu")
    accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
    accelerator.print("💾 | Loaded model!")
else:
    print(
        "Generator model (--model) isn't specified. Do you have the trained model? "
        "If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)"
    )
    sys.exit(1)


def start():
    # Decode the input file into a (channels, samples) float tensor.
    decoder = torchcodec.decoders.AudioDecoder(input_audio)
    decoded_samples = decoder.get_all_samples()
    audio = decoded_samples.data
    original_sample_rate = decoded_samples.sample_rate

    # Multichannel audio is processed channel by channel below, so no
    # downmix to mono is needed.
    # audio = AudioUtils.stereo_tensor_to_mono(audio)

    audio = AudioUtils.normalize(audio)

    # Resample to the requested output rate only when the rates differ.
    if original_sample_rate != args.sample_rate:
        resample_transform = torchaudio.transforms.Resample(
            original_sample_rate, args.sample_rate
        )
        audio = resample_transform(audio)

    # Split the track into fixed-length clips and add a batch dimension.
    split_clips = AudioUtils.split_audio(audio, clip_length)
    clips_on_device = [t.unsqueeze(0).to(accelerator.device) for t in split_clips]

    generator.eval()
    processed_audio = []
    with torch.no_grad():
        for clip in tqdm.tqdm(clips_on_device, desc="Processing..."):
            # Run each channel through the generator independently.
            channels = []
            for audio_channel in torch.split(clip, 1, dim=1):
                output_piece = generator(audio_channel)
                channels.append(output_piece.detach().cpu())
            output_clip = torch.cat(channels, dim=1)
            processed_audio.append(output_clip)

    # Stitch the processed clips back into one track and drop the batch dim.
    reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)
    reconstructed_audio = reconstructed_audio.squeeze(0)

    print(f"🔊 | Saving {output_audio}!")
    torchaudio.save_with_torchcodec(
        uri=output_audio,
        src=reconstructed_audio,
        sample_rate=args.sample_rate,
        channels_first=True,
        compression=args.bitrate,
    )


if __name__ == "__main__":
    start()
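
# ---------------------------
# Example usage (a sketch; the script filename and checkpoint path below are
# assumptions, not part of the project):
#
#   python upscale.py --model models/checkpoint.pt -i input.mp3 -o output.flac
#
# or, to let Accelerate manage device placement across configured hardware:
#
#   accelerate launch upscale.py --model models/checkpoint.pt -i input.wav -o output.wav
# ---------------------------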