✨ | Added smarter ways that would've been needed from the begining.
This commit is contained in:
33
training.py
33
training.py
@ -41,11 +41,24 @@ args = parser.parse_args()
|
||||
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# mfcc_transform = T.MFCC(
|
||||
# sample_rate=44100,
|
||||
# n_mfcc=20,
|
||||
# melkwargs={'n_fft': 2048, 'hop_length': 256}
|
||||
# ).to(device)
|
||||
# Parameters
|
||||
sample_rate = 44100
|
||||
n_fft = 2048
|
||||
hop_length = 256
|
||||
win_length = n_fft
|
||||
n_mels = 128
|
||||
n_mfcc = 20 # If using MFCC
|
||||
|
||||
mfcc_transform = T.MFCC(
|
||||
sample_rate,
|
||||
n_mfcc,
|
||||
melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
|
||||
).to(device)
|
||||
|
||||
mel_transform = T.MelSpectrogram(
|
||||
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
|
||||
win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
|
||||
).to(device)
|
||||
|
||||
debug = args.debug
|
||||
|
||||
@ -130,18 +143,20 @@ def start_training():
|
||||
|
||||
# ========= GENERATOR =========
|
||||
generator.train()
|
||||
generator_output, adversarial_loss = generator_train(
|
||||
generator_output, combined_loss, adversarial_loss, mel_l1_tensor = generator_train(
|
||||
low_quality_sample,
|
||||
high_quality_sample,
|
||||
real_labels,
|
||||
generator,
|
||||
discriminator,
|
||||
criterion_g,
|
||||
optimizer_g
|
||||
criterion_d,
|
||||
optimizer_g,
|
||||
device,
|
||||
mel_transform
|
||||
)
|
||||
|
||||
if debug:
|
||||
print(d_loss, adversarial_loss)
|
||||
print(combined_loss, adversarial_loss, mel_l1_tensor)
|
||||
scheduler_d.step(d_loss.detach())
|
||||
scheduler_g.step(adversarial_loss.detach())
|
||||
|
||||
|
Reference in New Issue
Block a user