import os
import random

import torch
import torchaudio
import torchcodec.decoders as decoders
import tqdm
from torch.utils.data import Dataset

import AudioUtils


class AudioDataset(Dataset):
    """In-memory dataset of paired (high-quality, degraded) audio clips.

    Every ``.wav``/``.mp3``/``.flac`` file directly inside ``input_dir`` is
    decoded, mixed down with ``AudioUtils.stereo_tensor_to_mono`` and
    optionally normalized with ``AudioUtils.normalize``. A degraded copy is
    made by resampling to a rate drawn from ``audio_sample_rates`` and back
    to the file's original rate. Both versions are split into
    ``clip_length``-sample chunks with ``AudioUtils.split_audio``; the last
    chunk of each is passed through ``AudioUtils.pad_tensor``.

    Each item is
    ``((high_quality_clip, low_quality_clip), (original_sample_rate, mangled_sample_rate))``.
    """

    # Candidate intermediate rates for the degradation round trip; one is
    # picked at random per file. Override on a subclass to change it.
    audio_sample_rates = [11025]

    # File extensions (matched case-insensitively) loaded from input_dir.
    _AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac")

    def __init__(self, input_dir, clip_length: int = 8000, normalize: bool = True):
        """Decode, degrade and chunk every supported audio file in ``input_dir``.

        Args:
            input_dir: Directory scanned (non-recursively) for audio files.
            clip_length: Number of samples in each stored clip.
            normalize: If true, apply ``AudioUtils.normalize`` to each file
                before it is degraded and split.
        """
        self.clip_length = clip_length
        self.normalize = normalize

        # Sorted so the dataset order is reproducible; os.listdir order is
        # arbitrary.
        input_files = [
            os.path.join(input_dir, name)
            for name in sorted(os.listdir(input_dir))
            if name.lower().endswith(self._AUDIO_EXTENSIONS)
            and os.path.isfile(os.path.join(input_dir, name))
        ]

        data = []
        # Resample kernels depend only on the (from, to) rate pair, so each
        # transform is built once and reused across files.
        resamplers = {}
        for audio_clip in tqdm.tqdm(
            input_files, desc=f"Processing {len(input_files)} audio file(s)"
        ):
            decoded_samples = decoders.AudioDecoder(audio_clip).get_all_samples()
            audio = decoded_samples.data.float()  # ensure float32
            original_sample_rate = decoded_samples.sample_rate

            audio = AudioUtils.stereo_tensor_to_mono(audio)
            if normalize:
                audio = AudioUtils.normalize(audio)

            # Degrade the audio with a resampling round trip through a
            # randomly chosen intermediate rate.
            mangled_sample_rate = random.choice(self.audio_sample_rates)
            downsample = self._get_resampler(
                resamplers, original_sample_rate, mangled_sample_rate
            )
            upsample = self._get_resampler(
                resamplers, mangled_sample_rate, original_sample_rate
            )
            low_audio = upsample(downsample(audio))
            # Each resampling step rounds its output length, so the round trip
            # can leave low_audio a few samples off from audio. Align them so
            # chunk i of both versions covers the same samples and the chunk
            # counts match. Assumes time is the last dimension, as
            # torchaudio's Resample does.
            low_audio = self._match_length(low_audio, audio.shape[-1])

            splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length)
            splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length)
            if not splitted_high_quality_audio or not splitted_low_quality_audio:
                continue  # skip empty or invalid clips

            splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(
                splitted_high_quality_audio[-1], clip_length
            )
            splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(
                splitted_low_quality_audio[-1], clip_length
            )

            for high_quality_data, low_quality_data in zip(
                splitted_high_quality_audio, splitted_low_quality_audio
            ):
                data.append(
                    (
                        (high_quality_data, low_quality_data),
                        (original_sample_rate, mangled_sample_rate),
                    )
                )

        self.audio_data = data

    @staticmethod
    def _get_resampler(cache, orig_freq, new_freq):
        """Return a ``Resample(orig_freq, new_freq)`` transform, memoized in ``cache``."""
        key = (orig_freq, new_freq)
        resampler = cache.get(key)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_freq, new_freq)
            cache[key] = resampler
        return resampler

    @staticmethod
    def _match_length(audio, length):
        """Trim or right-zero-pad the last dimension of ``audio`` to ``length``."""
        current = audio.shape[-1]
        if current > length:
            return audio[..., :length]
        if current < length:
            return torch.nn.functional.pad(audio, (0, length - current))
        return audio

    def __len__(self):
        """Return the number of stored clip pairs."""
        return len(self.audio_data)

    def __getitem__(self, idx):
        """Return the ``((high, low), (original_rate, mangled_rate))`` item at ``idx``."""
        return self.audio_data[idx]