Made training a bit faster.

2025-06-07 20:43:52 +03:00
parent 2ded03713d
commit 03fdc050cc
3 changed files with 27 additions and 24 deletions

app.py

@@ -18,7 +18,7 @@ from generator import SISUGenerator
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--model", type=str, help="Model to use for upscaling")
parser.add_argument("--clip_length", type=int, default=256, help="Internal clip length, leave unspecified if unsure")
parser.add_argument("--clip_length", type=int, default=1024, help="Internal clip length, leave unspecified if unsure")
parser.add_argument("-i", "--input", type=str, help="Input audio file")
parser.add_argument("-o", "--output", type=str, help="Output audio file")

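Note on the --clip_length default: the upscaler presumably processes the input in fixed-size windows, so the new 1024-sample default means fewer, longer chunks per file. A minimal illustrative sketch of that kind of chunking (the actual splitting code is not part of this diff; split_into_clips is a hypothetical helper):

import torch

def split_into_clips(waveform: torch.Tensor, clip_length: int = 1024) -> torch.Tensor:
    # waveform: (channels, samples); trailing samples that don't fill a
    # whole clip are dropped in this simplified version
    n = waveform.shape[-1] // clip_length
    return waveform[..., : n * clip_length].reshape(waveform.shape[0], n, clip_length)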
data.py

@@ -11,7 +11,7 @@ import AudioUtils
class AudioDataset(Dataset):
    audio_sample_rates = [11025]
-    def __init__(self, input_dir, device, clip_length = 256):
+    def __init__(self, input_dir, device, clip_length = 1024):
        self.device = device
        input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav') or f.endswith('.mp3') or f.endswith('.flac')]

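Same default bump on the dataset side. The __getitem__ logic is not shown in this hunk; a hypothetical sketch of cutting a clip_length window from a loaded file (random_clip is an illustrative helper, not the project's code):

import random

import torchaudio

def random_clip(path: str, clip_length: int = 1024):
    # load the whole file, then take a random fixed-length window from it
    waveform, sample_rate = torchaudio.load(path)
    if waveform.shape[-1] > clip_length:
        start = random.randint(0, waveform.shape[-1] - clip_length)
        waveform = waveform[..., start:start + clip_length]
    return waveform, sample_rate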

@@ -43,27 +43,38 @@ print(f"Using device: {device}")
# Parameters
sample_rate = 44100
-n_fft = 128
-hop_length = 128
+n_fft = 1024
+win_length = n_fft
+hop_length = n_fft // 4
n_mels = 40
-n_mfcc = 13 # If using MFCC
+n_mfcc = 13
mfcc_transform = T.MFCC(
-    sample_rate,
-    n_mfcc,
-    melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
+    sample_rate=sample_rate,
+    n_mfcc=n_mfcc,
+    melkwargs={
+        'n_fft': n_fft,
+        'hop_length': hop_length,
+        'win_length': win_length,
+        'n_mels': n_mels,
+        'power': 1.0,
+    }
).to(device)
mel_transform = T.MelSpectrogram(
-    sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
-    win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
+    sample_rate=sample_rate,
+    n_fft=n_fft,
+    hop_length=hop_length,
+    win_length=win_length,
+    n_mels=n_mels,
+    power=1.0  # Magnitude Mel
).to(device)
stft_transform = T.Spectrogram(
-    n_fft=n_fft, win_length=win_length, hop_length=hop_length
+    n_fft=n_fft,
+    win_length=win_length,
+    hop_length=hop_length
).to(device)
debug = args.debug
# Initialize dataset and dataloader
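With n_fft = 1024 and hop_length = n_fft // 4 = 256, a 1024-sample clip yields 5 frames (torchaudio's transforms pad with center=True by default, giving 1024 // 256 + 1 frames). A quick shape check under those defaults (illustrative, not part of the commit):

import torch

x = torch.randn(1, 1024).to(device)   # one clip_length-sized mono clip
print(mel_transform(x).shape)          # torch.Size([1, 40, 5])
print(mfcc_transform(x).shape)         # torch.Size([1, 13, 5])
print(stft_transform(x).shape)         # torch.Size([1, 513, 5]), n_fft // 2 + 1 bins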
@@ -76,7 +87,7 @@ os.makedirs(audio_output_dir, exist_ok=True)
# ========= SINGLE =========
-train_data_loader = DataLoader(dataset, batch_size=8192, shuffle=True, num_workers=24)
+train_data_loader = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=24)
# ========= MODELS =========
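The batch_size drop from 8192 to 1024 more than offsets the 4x longer clips; back-of-envelope for raw fp32 waveforms (the spectrogram features computed per batch add on top of this):

# before: 8192 clips x  256 samples x 4 bytes = 8 MiB per batch
# after:  1024 clips x 1024 samples x 4 bytes = 4 MiB per batch
print(1024 * 1024 * 4 / 2**20, "MiB")  # 4.0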
@@ -94,6 +105,7 @@ if args.continue_training:
else:
    generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
+    discriminator.load_state_dict(torch.load(f"{models_dir}/temp_discriminator.pt", map_location=device, weights_only=True))
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
epoch = epoch_from_file["epoch"] + 1
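Data.read_data / Data.write_data are project helpers not shown in this diff; from their use here they appear to round-trip the epoch counter through JSON. A hypothetical stand-in with that assumed behavior:

import json

def read_data(path: str) -> dict:
    with open(path) as f:
        return json.load(f)

def write_data(path: str, data: dict) -> None:
    with open(path, "w") as f:
        json.dump(data, f)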
@@ -178,19 +190,10 @@ def start_training():
        low_quality_audio = (bad_quality_data, original_sample_rate)
        ai_enhanced_audio = (generator_output, original_sample_rate)
        new_epoch = generator_epoch+epoch
-        # if generator_epoch % 25 == 0:
-        # print(f"Saved epoch {new_epoch}!")
-        # torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0][-1].cpu().detach(), high_quality_audio[1][-1])
-        # torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0][-1].cpu().detach(), high_quality_audio[1][-1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
-        # torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0][-1].cpu().detach(), high_quality_audio[1][-1])
-        #if debug:
-        # print(generator.state_dict().keys())
-        # print(discriminator.state_dict().keys())
        torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
        torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
-        new_epoch = generator_epoch+epoch
        Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})
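Since temp_generator.pt is rewritten every epoch and read back on resume, a crash mid-save could leave a truncated checkpoint. One possible hardening (an assumption, not something this commit does): write to a temporary file and swap it into place atomically.

import os

import torch

def save_atomic(state_dict, path: str) -> None:
    # os.replace is atomic on POSIX, so an interrupted save can't
    # leave a half-written file at `path`
    tmp = path + ".tmp"
    torch.save(state_dict, tmp)
    os.replace(tmp, path)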