diff --git a/app.py b/app.py
index a24486b..ed51803 100644
--- a/app.py
+++ b/app.py
@@ -18,7 +18,7 @@ from generator import SISUGenerator
 parser = argparse.ArgumentParser(description="Training script")
 parser.add_argument("--device", type=str, default="cpu", help="Select device")
 parser.add_argument("--model", type=str, help="Model to use for upscaling")
-parser.add_argument("--clip_length", type=int, default=256, help="Internal clip length, leave unspecified if unsure")
+parser.add_argument("--clip_length", type=int, default=1024, help="Internal clip length, leave unspecified if unsure")
 parser.add_argument("-i", "--input", type=str, help="Input audio file")
 parser.add_argument("-o", "--output", type=str, help="Output audio file")
 
diff --git a/data.py b/data.py
index 59986f1..c3e1047 100644
--- a/data.py
+++ b/data.py
@@ -11,7 +11,7 @@ import AudioUtils
 class AudioDataset(Dataset):
     audio_sample_rates = [11025]
 
-    def __init__(self, input_dir, device, clip_length = 256):
+    def __init__(self, input_dir, device, clip_length = 1024):
         self.device = device
         input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav') or f.endswith('.mp3') or f.endswith('.flac')]
 
diff --git a/training.py b/training.py
index 52275d5..1be713c 100644
--- a/training.py
+++ b/training.py
@@ -43,27 +43,38 @@ print(f"Using device: {device}")
 
 # Parameters
 sample_rate = 44100
-n_fft = 128
-hop_length = 128
+n_fft = 1024
 win_length = n_fft
+hop_length = n_fft // 4
 n_mels = 40
-n_mfcc = 13 # If using MFCC
+n_mfcc = 13
 
 mfcc_transform = T.MFCC(
-    sample_rate,
-    n_mfcc,
-    melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
+    sample_rate=sample_rate,
+    n_mfcc=n_mfcc,
+    melkwargs={
+        'n_fft': n_fft,
+        'hop_length': hop_length,
+        'win_length': win_length,
+        'n_mels': n_mels,
+        'power': 1.0,
+    }
 ).to(device)
 
 mel_transform = T.MelSpectrogram(
-    sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
-    win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
+    sample_rate=sample_rate,
+    n_fft=n_fft,
+    hop_length=hop_length,
+    win_length=win_length,
+    n_mels=n_mels,
+    power=1.0 # Magnitude Mel
 ).to(device)
 
 stft_transform = T.Spectrogram(
-    n_fft=n_fft, win_length=win_length, hop_length=hop_length
+    n_fft=n_fft,
+    win_length=win_length,
+    hop_length=hop_length
 ).to(device)
 
-debug = args.debug
 
 # Initialize dataset and dataloader
@@ -76,7 +87,7 @@ os.makedirs(audio_output_dir, exist_ok=True)
 
 
 # ========= SINGLE =========
-train_data_loader = DataLoader(dataset, batch_size=8192, shuffle=True, num_workers=24)
+train_data_loader = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=24)
 
 
 # ========= MODELS =========
@@ -94,6 +105,7 @@ if args.continue_training:
 else:
     generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
     discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
+
     epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
     epoch = epoch_from_file["epoch"] + 1
 
@@ -178,19 +190,10 @@ def start_training():
 
         low_quality_audio = (bad_quality_data, original_sample_rate)
         ai_enhanced_audio = (generator_output, original_sample_rate)
 
-        new_epoch = generator_epoch+epoch
-
-        # if generator_epoch % 25 == 0:
-        #     print(f"Saved epoch {new_epoch}!")
-        #     torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0][-1].cpu().detach(), high_quality_audio[1][-1])
-        #     torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0][-1].cpu().detach(), high_quality_audio[1][-1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
-        #     torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0][-1].cpu().detach(), high_quality_audio[1][-1])
-
-        #if debug:
-        #     print(generator.state_dict().keys())
-        #     print(discriminator.state_dict().keys())
         torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
         torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
+
+        new_epoch = generator_epoch+epoch
         Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})
 