⚗️ | More architectural changes
app.py (55 lines changed)
@@ -4,6 +4,7 @@ import torch
 import torchaudio
+import torchcodec
 import tqdm
 from accelerate import Accelerator
 
 import AudioUtils
 from generator import SISUGenerator
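
Note: the new torchcodec import backs the decode step later in this file (decoded_samples.data / decoded_samples.sample_rate) and the torchaudio.save_with_torchcodec call at the end. A minimal sketch of that decode step, assuming torchcodec's AudioDecoder API; the file name is hypothetical:

    # Sketch only: assumed torchcodec decoding API, hypothetical input path.
    from torchcodec.decoders import AudioDecoder

    decoder = AudioDecoder("input.wav")          # hypothetical path
    decoded_samples = decoder.get_all_samples()
    audio = decoded_samples.data                 # tensor, [channels, samples]
    original_sample_rate = decoded_samples.sample_rate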
@@ -15,7 +16,7 @@ parser.add_argument("--model", type=str, help="Model to use for upscaling")
 parser.add_argument(
     "--clip_length",
     type=int,
-    default=16384,
+    default=8000,
     help="Internal clip length, leave unspecified if unsure",
 )
 parser.add_argument(
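
Note: the internal clip length drops from 16384 to 8000 samples. The input is processed in fixed-length windows, so a shorter clip lowers peak memory per generator pass at the cost of more iterations. AudioUtils.split_audio is project code not shown here; a stand-in built on torch.split (an assumption, not the actual helper) would be:

    import torch

    def split_audio_sketch(audio: torch.Tensor, clip_length: int = 8000) -> list[torch.Tensor]:
        # Chop the time (last) dimension into clips of at most clip_length
        # samples; the final clip may be shorter.
        return list(torch.split(audio, clip_length, dim=-1))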
@@ -38,21 +39,44 @@ if args.sample_rate < 8000:
     )
     exit()
 
 device = torch.device(args.device if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
 # ---------------------------
 # Init accelerator
 # ---------------------------
 
-generator = SISUGenerator().to(device)
+accelerator = Accelerator(mixed_precision="bf16")
+
+# ---------------------------
+# Models
+# ---------------------------
+generator = SISUGenerator()
+
+accelerator.print("🔨 | Compiling models...")
+generator = torch.compile(generator)
+accelerator.print("✅ | Compiling done!")
+
+# ---------------------------
+# Prepare accelerator
+# ---------------------------
+generator = accelerator.prepare(generator)
+
+# ---------------------------
+# Checkpoint helpers
+# ---------------------------
 
 models_dir = args.model
 clip_length = args.clip_length
 input_audio = args.input
 output_audio = args.output
 
 if models_dir:
-    ckpt = torch.load(models_dir, map_location=device)
-    generator.load_state_dict(ckpt["G"])
+    ckpt = torch.load(models_dir)
+    accelerator.unwrap_model(generator).load_state_dict(ckpt["G"])
+    accelerator.print("💾 | Loaded model!")
 else:
     print(
         "Generator model (--model) isn't specified. Do you have the trained model? If not, you need to train it OR acquire it from somewhere (DON'T ASK ME, YET!)"
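
Note: the model is now built, compiled with torch.compile, and wrapped by accelerator.prepare before any weights load, and loading goes through accelerator.unwrap_model so the state dict lands on the underlying module rather than the wrapper. The training-side save is not part of this commit; a hypothetical counterpart that matches this loader's expected "G" key could look like:

    # Hypothetical training-side save (assumption, not in this diff): store the
    # unwrapped generator under the "G" key that this loader reads.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        torch.save({"G": accelerator.unwrap_model(generator).state_dict()}, "generator.pt")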
@@ -67,7 +91,8 @@ def start():
     audio = decoded_samples.data
     original_sample_rate = decoded_samples.sample_rate
 
-    audio = AudioUtils.stereo_tensor_to_mono(audio)
+    # Support for multichannel audio
+    # audio = AudioUtils.stereo_tensor_to_mono(audio)
     audio = AudioUtils.normalize(audio)
 
     resample_transform = torchaudio.transforms.Resample(
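
Note: with the mono downmix commented out, audio keeps a [channels, samples] layout all the way to the inference loop, which now iterates over channels explicitly. A small shape walkthrough under that assumed layout (unsqueeze(0) stands in for the loop's t.view(1, t.shape[0], t.shape[-1])):

    import torch

    stereo = torch.randn(2, 48000)                # [channels, samples]
    clips = list(torch.split(stereo, 8000, dim=-1))
    batched = clips[0].unsqueeze(0)               # [1, 2, 8000], batch of one
    per_channel = torch.split(batched, 1, dim=1)  # two [1, 1, 8000] mono tensors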
@@ -77,14 +102,20 @@ def start():
     audio = resample_transform(audio)
 
     splitted_audio = AudioUtils.split_audio(audio, clip_length)
-    splitted_audio_on_device = [t.to(device) for t in splitted_audio]
+    splitted_audio_on_device = [t.view(1, t.shape[0], t.shape[-1]).to(accelerator.device) for t in splitted_audio]
     processed_audio = []
 
-    for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."):
-        processed_audio.append(generator(clip))
+    with torch.no_grad():
+        for clip in tqdm.tqdm(splitted_audio_on_device, desc="Processing..."):
+            channels = []
+            for audio_channel in torch.split(clip, 1, dim=1):
+                output_piece = generator(audio_channel)
+                channels.append(output_piece.detach().cpu())
+            output_clip = torch.cat(channels, dim=1)
+            processed_audio.append(output_clip)
 
     reconstructed_audio = AudioUtils.reconstruct_audio(processed_audio)
-    print(f"Saving {output_audio}!")
+    reconstructed_audio = reconstructed_audio.squeeze(0)
+    print(f"🔊 | Saving {output_audio}!")
     torchaudio.save_with_torchcodec(
         uri=output_audio,
         src=reconstructed_audio,
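
Note: reconstruct_audio evidently returns the batched [1, channels, samples] layout, hence the squeeze(0) down to the 2-D [channels, samples] shape that torchaudio save functions expect by default. The save call is cut off in this diff; a complete invocation would presumably also pass the target sample rate, along these lines:

    # Assumption: the diff truncates mid-call. sample_rate mirrors the
    # torchaudio.save signature; args.sample_rate exists earlier in this script.
    torchaudio.save_with_torchcodec(
        uri=output_audio,
        src=reconstructed_audio,   # 2-D [channels, samples] after squeeze(0)
        sample_rate=args.sample_rate,
    )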