38 lines
1.1 KiB
Python
38 lines
1.1 KiB
Python
import torch
|
|
import torchaudio
|
|
from generator import SISUGenerator
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the generator, load its trained weights, and move it to the
# selected device in evaluation mode.
generator = SISUGenerator()
# map_location remaps tensors onto the selected device while loading. Without
# it, a checkpoint saved on a GPU fails to load when CUDA is unavailable, which
# would defeat the CPU fallback chosen above.
generator.load_state_dict(
    torch.load("generator.pt", map_location=device, weights_only=True)
)
generator.to(device)
# Evaluation mode: modules such as dropout / batch-norm use inference behavior.
generator.eval()
|
|
|
|
def generate_audio(input_audio_path, output_audio_path, output_sample_rate=None):
    """Run the module-level ``generator`` on an audio file and save the result.

    Args:
        input_audio_path: Path of the input audio file, read with
            ``torchaudio.load``.
        output_audio_path: Destination path for the generated audio, written
            with ``torchaudio.save``.
        output_sample_rate: Sample rate recorded in the output file. Defaults
            to the sample rate of the input file.

    Returns:
        The generator output as a CPU tensor reshaped to 2-D
        ``(channels, frames)``, i.e. exactly what was written to disk.
    """
    # Load the input audio and move it to the model's device before normalizing.
    low_quality_wav, sr_b = torchaudio.load(input_audio_path)
    low_quality_wav = normalize(low_quality_wav.to(device))

    # Flatten everything after the first (channel) dimension.
    low_quality_wav = low_quality_wav.view(low_quality_wav.size(0), -1)

    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        fake_audio = generator(low_quality_wav)

    # NOTE(review): the generator's output shape is not visible in this file.
    # Collapse any leading dims so torchaudio.save receives a 2-D tensor.
    fake_audio = fake_audio.cpu()
    fake_audio = fake_audio.reshape(-1, fake_audio.size(-1))

    # NOTE(review): defaulting to the input rate assumes the generator does not
    # change the sample rate. Pass output_sample_rate explicitly if it does.
    if output_sample_rate is None:
        output_sample_rate = sr_b
    torchaudio.save(output_audio_path, fake_audio, output_sample_rate)

    print(f"Generated audio saved to {output_audio_path}")
    return fake_audio
|
|
|
|
def normalize(wav):
    """Peak-normalize ``wav`` so its largest absolute sample becomes 1.

    The peak is taken over every element of the tensor, so all channels are
    scaled by the same factor.

    Args:
        wav: Audio tensor of any shape.

    Returns:
        A new tensor equal to ``wav / max(|wav|)``. Empty or all-zero (silent)
        input is returned as an unchanged copy. Dividing by a zero peak would
        otherwise fill the result with NaN, and ``torch.max`` raises on an
        empty tensor.
    """
    if wav.numel() == 0:
        return wav.clone()
    peak = torch.max(torch.abs(wav))
    if peak == 0:
        return wav.clone()
    return wav / peak
|
|
|
|
# Example run: process one sample file and write the generated result next to it.
input_audio_path = "/mnt/games/Home/Downloads/SISU/sample_3_023.wav"
output_audio_path = "/mnt/games/Home/Downloads/SISU/godtier_audio.wav"
generate_audio(
    input_audio_path=input_audio_path,
    output_audio_path=output_audio_path,
)
|