⚡ | Made training a bit faster.
This commit is contained in:
47
training.py
47
training.py
@ -43,27 +43,38 @@ print(f"Using device: {device}")
|
||||
|
||||
# Parameters
|
||||
sample_rate = 44100
|
||||
n_fft = 128
|
||||
hop_length = 128
|
||||
n_fft = 1024
|
||||
win_length = n_fft
|
||||
hop_length = n_fft // 4
|
||||
n_mels = 40
|
||||
n_mfcc = 13 # If using MFCC
|
||||
n_mfcc = 13
|
||||
|
||||
mfcc_transform = T.MFCC(
|
||||
sample_rate,
|
||||
n_mfcc,
|
||||
melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
|
||||
sample_rate=sample_rate,
|
||||
n_mfcc=n_mfcc,
|
||||
melkwargs={
|
||||
'n_fft': n_fft,
|
||||
'hop_length': hop_length,
|
||||
'win_length': win_length,
|
||||
'n_mels': n_mels,
|
||||
'power': 1.0,
|
||||
}
|
||||
).to(device)
|
||||
|
||||
mel_transform = T.MelSpectrogram(
|
||||
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
|
||||
win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
|
||||
sample_rate=sample_rate,
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
n_mels=n_mels,
|
||||
power=1.0 # Magnitude Mel
|
||||
).to(device)
|
||||
|
||||
stft_transform = T.Spectrogram(
|
||||
n_fft=n_fft, win_length=win_length, hop_length=hop_length
|
||||
n_fft=n_fft,
|
||||
win_length=win_length,
|
||||
hop_length=hop_length
|
||||
).to(device)
|
||||
|
||||
debug = args.debug
|
||||
|
||||
# Initialize dataset and dataloader
|
||||
@ -76,7 +87,7 @@ os.makedirs(audio_output_dir, exist_ok=True)
|
||||
|
||||
# ========= SINGLE =========
|
||||
|
||||
train_data_loader = DataLoader(dataset, batch_size=8192, shuffle=True, num_workers=24)
|
||||
train_data_loader = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=24)
|
||||
|
||||
|
||||
# ========= MODELS =========
|
||||
@ -94,6 +105,7 @@ if args.continue_training:
|
||||
else:
|
||||
generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
|
||||
discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
|
||||
|
||||
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
|
||||
epoch = epoch_from_file["epoch"] + 1
|
||||
|
||||
@ -178,19 +190,10 @@ def start_training():
|
||||
low_quality_audio = (bad_quality_data, original_sample_rate)
|
||||
ai_enhanced_audio = (generator_output, original_sample_rate)
|
||||
|
||||
new_epoch = generator_epoch+epoch
|
||||
|
||||
# if generator_epoch % 25 == 0:
|
||||
# print(f"Saved epoch {new_epoch}!")
|
||||
# torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0][-1].cpu().detach(), high_quality_audio[1][-1])
|
||||
# torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0][-1].cpu().detach(), high_quality_audio[1][-1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
|
||||
# torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0][-1].cpu().detach(), high_quality_audio[1][-1])
|
||||
|
||||
#if debug:
|
||||
# print(generator.state_dict().keys())
|
||||
# print(discriminator.state_dict().keys())
|
||||
torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
|
||||
torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
|
||||
|
||||
new_epoch = generator_epoch+epoch
|
||||
Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user