import os
import random

import torch
import torchaudio
import torchcodec.decoders as decoders
import tqdm
from torch.utils.data import Dataset

import AudioUtils


class AudioDataset(Dataset):
    """Dataset of fixed-length audio clips paired with a degraded copy.

    Each stored example is a single-channel high-quality clip; at access time
    a degraded version is generated by resampling down to a random rate and
    back up, for resolution-restoration style training.
    """

    # Candidate sample rates used to "mangle" (degrade) clips in __getitem__.
    audio_sample_rates = [8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100]

    def __init__(self, input_dir, clip_length: int = 512, normalize: bool = True):
        """Scan *input_dir* (non-recursive) for .wav/.mp3/.flac files,
        decode each, optionally normalize, and split into clips.

        Args:
            input_dir: Directory containing audio files.
            clip_length: Target number of samples per stored clip.
            normalize: Whether to normalize each file via AudioUtils.normalize.
        """
        self.clip_length = clip_length
        self.normalize = normalize

        input_files = [
            os.path.join(input_dir, f)
            for f in os.listdir(input_dir)
            if os.path.isfile(os.path.join(input_dir, f))
            and f.lower().endswith((".wav", ".mp3", ".flac"))
        ]

        data = []
        # NOTE: renamed the loop variable (was `audio_clip`) — the original
        # shadowed it again in the innermost loop below.
        for file_path in tqdm.tqdm(
            input_files, desc=f"Processing {len(input_files)} audio file(s)"
        ):
            decoder = decoders.AudioDecoder(file_path)
            decoded_samples = decoder.get_all_samples()
            audio = decoded_samples.data.float()
            original_sample_rate = decoded_samples.sample_rate

            if normalize:
                audio = AudioUtils.normalize(audio)

            clips = AudioUtils.split_audio(audio, clip_length, True)
            if not clips:
                continue

            # Store each channel of each clip as its own (1, samples) example.
            for clip in clips:
                for channel in torch.split(clip, 1):
                    data.append((channel, original_sample_rate))

        self.audio_data = data

    def __len__(self):
        """Number of stored (channel, sample_rate) examples."""
        return len(self.audio_data)

    def __getitem__(self, idx):
        """Return ((high_quality, degraded), (orig_rate, mangled_rate)).

        The degraded clip is produced by resampling the high-quality clip
        down to a randomly chosen rate and back up to the original rate,
        then trimming/padding it to match the high-quality clip's length.
        """
        high_audio, original_rate = self.audio_data[idx]
        mangled_sample_rate = random.choice(self.audio_sample_rates)

        downsample = torchaudio.transforms.Resample(
            original_rate, mangled_sample_rate
        )
        upsample = torchaudio.transforms.Resample(
            mangled_sample_rate, original_rate
        )
        low_audio = upsample(downsample(high_audio))

        target_len = high_audio.shape[1]
        if low_audio.shape[1] > target_len:
            low_audio = low_audio[:, :target_len]
        elif low_audio.shape[1] < target_len:
            # BUGFIX: pad to the paired high-quality clip's actual length
            # rather than self.clip_length, so both tensors in the pair always
            # match even if split_audio ever yields a clip shorter than
            # clip_length. (Identical behavior when clips are exactly
            # clip_length samples long.)
            low_audio = AudioUtils.pad_tensor(low_audio, target_len)

        return ((high_audio, low_audio), (original_rate, mangled_sample_rate))