:alembic: | Experimenting with other model layouts.
 training.py | 27
@@ -38,7 +38,7 @@ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
 
 mfcc_transform = T.MFCC(
-    sample_rate=16000, # Adjust to your sample rate
+    sample_rate=44100, # Adjust to your sample rate
     n_mfcc=20,
     melkwargs={'n_fft': 2048, 'hop_length': 512} # adjust n_fft and hop_length to your needs.
 ).to(device)
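The MFCC transform above only produces meaningful coefficients when the audio fed to it really is at 44100 Hz. Below is a minimal standalone sketch, not taken from this repo, of how such a transform is typically applied; the input file name `clip.wav` is hypothetical, and the clip is resampled first when its rate differs from the transform's.

```python
import torch
import torchaudio
import torchaudio.transforms as T

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

mfcc_transform = T.MFCC(
    sample_rate=44100,                           # must match the audio fed in
    n_mfcc=20,
    melkwargs={'n_fft': 2048, 'hop_length': 512},
).to(device)

waveform, sr = torchaudio.load("clip.wav")       # hypothetical input file
if sr != 44100:
    # Bring the clip to the rate the transform was built for.
    waveform = T.Resample(orig_freq=sr, new_freq=44100)(waveform)

mfcc = mfcc_transform(waveform.to(device))       # shape: (channels, n_mfcc, time)
```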
@@ -97,20 +97,9 @@ debug = args.verbose
 dataset_dir = './dataset/good'
 dataset = AudioDataset(dataset_dir, device)
 
-# ========= MULTIPLE =========
-
-# dataset_size = len(dataset)
-# train_size = int(dataset_size * .9)
-# val_size = int(dataset_size-train_size)
-
-#train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
-
-# train_data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
-# val_data_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)
-
-# ========= SINGLE =========
 
-train_data_loader = DataLoader(dataset, batch_size=1, shuffle=True)
+train_data_loader = DataLoader(dataset, batch_size=128, shuffle=True)
 
 # Initialize models and move them to device
 generator = SISUGenerator()
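Raising `batch_size` from 1 to 128 means the default collate function now stacks 128 clips into a single tensor per step, which only works if every item the dataset returns has the same shape. A minimal sketch of that behaviour with a hypothetical fixed-length stand-in dataset (the repo's `AudioDataset` is not reproduced here):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class FixedLengthAudioDataset(Dataset):
    """Hypothetical stand-in: every item is a (1, 44100) mono clip."""
    def __init__(self, n_items=256, n_samples=44100):
        self.clips = torch.randn(n_items, 1, n_samples)

    def __len__(self):
        return len(self.clips)

    def __getitem__(self, idx):
        return self.clips[idx]

loader = DataLoader(FixedLengthAudioDataset(), batch_size=128, shuffle=True)
batch = next(iter(loader))
print(batch.shape)   # torch.Size([128, 1, 44100]) -- one stacked batch per step
```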
@@ -175,17 +164,17 @@ def start_training():
         scheduler_g.step(combined_loss)
 
         # ========= SAVE LATEST AUDIO =========
-        high_quality_audio = high_quality_clip
-        low_quality_audio = low_quality_clip
-        ai_enhanced_audio = (generator_output, high_quality_clip[1])
+        high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
+        low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0])
+        ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0])
 
         new_epoch = generator_epoch+epoch
 
         if generator_epoch % 10 == 0:
             print(f"Saved epoch {new_epoch}!")
-            torchaudio.save(f"./output/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
-            torchaudio.save(f"./output/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
-            torchaudio.save(f"./output/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])
+            torchaudio.save(f"./output/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
+            torchaudio.save(f"./output/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu(), ai_enhanced_audio[1])
+            torchaudio.save(f"./output/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu(), high_quality_audio[1])
 
         if debug:
             print(generator.state_dict().keys())
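The switch from `[0][0]` to `[0]` indexing reflects that each saved clip has to reach `torchaudio.save` as a 2-D `(channels, frames)` CPU tensor with a plain integer sample rate. A minimal sketch under assumed shapes; the batch of generated audio, the per-clip sample rates, and the output file name below are made up for illustration:

```python
import torch
import torchaudio

# Hypothetical batch from a generator: (batch, channels, frames), plus one
# sample rate per clip as produced by a batched data loader.
generator_output = torch.randn(128, 1, 44100)
sample_rates = torch.full((128,), 44100)

clip = generator_output[0]      # first clip only: shape (1, 44100)
rate = int(sample_rates[0])     # torchaudio.save expects a plain Python int

# The clip must be on CPU before writing it out.
torchaudio.save("example-ai.wav", clip.cpu(), rate)
```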