From a135c765da2b99685a6cf42e64cdf34d3121ce17 Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Mon, 5 May 2025 00:50:56 +0300 Subject: [PATCH] :bug: | Misc fixes... --- data.py | 10 +++++----- training.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/data.py b/data.py index 6d64af5..59986f1 100644 --- a/data.py +++ b/data.py @@ -11,7 +11,7 @@ import AudioUtils class AudioDataset(Dataset): audio_sample_rates = [11025] - def __init__(self, input_dir, device): + def __init__(self, input_dir, device, clip_length = 256): self.device = device input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav') or f.endswith('.mp3') or f.endswith('.flac')] @@ -28,11 +28,11 @@ class AudioDataset(Dataset): low_audio = resample_transform_low(audio) low_audio = resample_transform_high(low_audio) - splitted_high_quality_audio = AudioUtils.split_audio(audio, 128) - splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(splitted_high_quality_audio[-1], 128) + splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length) + splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(splitted_high_quality_audio[-1], clip_length) - splitted_low_quality_audio = AudioUtils.split_audio(low_audio, 128) - splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(splitted_low_quality_audio[-1], 128) + splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length) + splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(splitted_low_quality_audio[-1], clip_length) for high_quality_sample, low_quality_sample in zip(splitted_high_quality_audio, splitted_low_quality_audio): data.append(((high_quality_sample, low_quality_sample), (original_sample_rate, mangled_sample_rate))) diff --git a/training.py b/training.py index ab9b35b..5ccabc7 100644 --- a/training.py +++ b/training.py @@ -76,7 +76,7 @@ os.makedirs(audio_output_dir, exist_ok=True) # ========= SINGLE ========= -train_data_loader = DataLoader(dataset, batch_size=256, shuffle=True) +train_data_loader = DataLoader(dataset, batch_size=1024, shuffle=True) # ========= MODELS =========