⚡ | Made training a bit faster.
This commit is contained in:
2
app.py
2
app.py
@ -18,7 +18,7 @@ from generator import SISUGenerator
|
||||
parser = argparse.ArgumentParser(description="Training script")
|
||||
parser.add_argument("--device", type=str, default="cpu", help="Select device")
|
||||
parser.add_argument("--model", type=str, help="Model to use for upscaling")
|
||||
parser.add_argument("--clip_length", type=int, default=256, help="Internal clip length, leave unspecified if unsure")
|
||||
parser.add_argument("--clip_length", type=int, default=1024, help="Internal clip length, leave unspecified if unsure")
|
||||
parser.add_argument("-i", "--input", type=str, help="Input audio file")
|
||||
parser.add_argument("-o", "--output", type=str, help="Output audio file")
|
||||
|
||||
|
2
data.py
2
data.py
@ -11,7 +11,7 @@ import AudioUtils
|
||||
class AudioDataset(Dataset):
|
||||
audio_sample_rates = [11025]
|
||||
|
||||
def __init__(self, input_dir, device, clip_length = 256):
|
||||
def __init__(self, input_dir, device, clip_length = 1024):
|
||||
self.device = device
|
||||
input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav') or f.endswith('.mp3') or f.endswith('.flac')]
|
||||
|
||||
|
47
training.py
47
training.py
@ -43,27 +43,38 @@ print(f"Using device: {device}")
|
||||
|
||||
# Parameters
|
||||
sample_rate = 44100
|
||||
n_fft = 128
|
||||
hop_length = 128
|
||||
n_fft = 1024
|
||||
win_length = n_fft
|
||||
hop_length = n_fft // 4
|
||||
n_mels = 40
|
||||
n_mfcc = 13 # If using MFCC
|
||||
n_mfcc = 13
|
||||
|
||||
mfcc_transform = T.MFCC(
|
||||
sample_rate,
|
||||
n_mfcc,
|
||||
melkwargs = {'n_fft': n_fft, 'hop_length': hop_length}
|
||||
sample_rate=sample_rate,
|
||||
n_mfcc=n_mfcc,
|
||||
melkwargs={
|
||||
'n_fft': n_fft,
|
||||
'hop_length': hop_length,
|
||||
'win_length': win_length,
|
||||
'n_mels': n_mels,
|
||||
'power': 1.0,
|
||||
}
|
||||
).to(device)
|
||||
|
||||
mel_transform = T.MelSpectrogram(
|
||||
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
|
||||
win_length=win_length, n_mels=n_mels, power=1.0 # Magnitude Mel
|
||||
sample_rate=sample_rate,
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
n_mels=n_mels,
|
||||
power=1.0 # Magnitude Mel
|
||||
).to(device)
|
||||
|
||||
stft_transform = T.Spectrogram(
|
||||
n_fft=n_fft, win_length=win_length, hop_length=hop_length
|
||||
n_fft=n_fft,
|
||||
win_length=win_length,
|
||||
hop_length=hop_length
|
||||
).to(device)
|
||||
|
||||
debug = args.debug
|
||||
|
||||
# Initialize dataset and dataloader
|
||||
@ -76,7 +87,7 @@ os.makedirs(audio_output_dir, exist_ok=True)
|
||||
|
||||
# ========= SINGLE =========
|
||||
|
||||
train_data_loader = DataLoader(dataset, batch_size=8192, shuffle=True, num_workers=24)
|
||||
train_data_loader = DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=24)
|
||||
|
||||
|
||||
# ========= MODELS =========
|
||||
@ -94,6 +105,7 @@ if args.continue_training:
|
||||
else:
|
||||
generator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
|
||||
discriminator.load_state_dict(torch.load(f"{models_dir}/temp_generator.pt", map_location=device, weights_only=True))
|
||||
|
||||
epoch_from_file = Data.read_data(f"{models_dir}/epoch_data.json")
|
||||
epoch = epoch_from_file["epoch"] + 1
|
||||
|
||||
@ -178,19 +190,10 @@ def start_training():
|
||||
low_quality_audio = (bad_quality_data, original_sample_rate)
|
||||
ai_enhanced_audio = (generator_output, original_sample_rate)
|
||||
|
||||
new_epoch = generator_epoch+epoch
|
||||
|
||||
# if generator_epoch % 25 == 0:
|
||||
# print(f"Saved epoch {new_epoch}!")
|
||||
# torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0][-1].cpu().detach(), high_quality_audio[1][-1])
|
||||
# torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0][-1].cpu().detach(), high_quality_audio[1][-1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
|
||||
# torchaudio.save(f"{audio_output_dir}/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0][-1].cpu().detach(), high_quality_audio[1][-1])
|
||||
|
||||
#if debug:
|
||||
# print(generator.state_dict().keys())
|
||||
# print(discriminator.state_dict().keys())
|
||||
torch.save(discriminator.state_dict(), f"{models_dir}/temp_discriminator.pt")
|
||||
torch.save(generator.state_dict(), f"{models_dir}/temp_generator.pt")
|
||||
|
||||
new_epoch = generator_epoch+epoch
|
||||
Data.write_data(f"{models_dir}/epoch_data.json", {"epoch": new_epoch})
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user