import os
import random

import torchaudio
import torchaudio.transforms as T
import tqdm
from torch.utils.data import Dataset

import AudioUtils


class AudioDataset(Dataset):
    # Target sample rates used to degrade the audio; extend this list to
    # train on a wider range of downsampling artifacts.
    audio_sample_rates = [11025]

    def __init__(self, input_dir, device):
        self.device = device
        input_files = [
            os.path.join(root, f)
            for root, _, files in os.walk(input_dir)
            for f in files
            if f.endswith(('.wav', '.mp3', '.flac'))
        ]

        data = []
        for audio_clip in tqdm.tqdm(input_files, desc=f"Processing {len(input_files)} audio file(s)"):
            audio, original_sample_rate = torchaudio.load(audio_clip, normalize=True)
            audio = AudioUtils.stereo_tensor_to_mono(audio)

            # Generate the low-quality counterpart by resampling down to a
            # randomly chosen rate and back up to the original rate.
            mangled_sample_rate = random.choice(self.audio_sample_rates)
            resample_transform_low = T.Resample(original_sample_rate, mangled_sample_rate)
            resample_transform_high = T.Resample(mangled_sample_rate, original_sample_rate)
            low_audio = resample_transform_high(resample_transform_low(audio))

            # Split both versions into fixed-size chunks of 128 samples and
            # zero-pad the final (shorter) chunk up to the same length.
            split_high_quality_audio = AudioUtils.split_audio(audio, 128)
            split_high_quality_audio[-1] = AudioUtils.pad_tensor(split_high_quality_audio[-1], 128)

            split_low_quality_audio = AudioUtils.split_audio(low_audio, 128)
            split_low_quality_audio[-1] = AudioUtils.pad_tensor(split_low_quality_audio[-1], 128)

            # Pair each high-quality chunk with its low-quality counterpart.
            # zip() stops at the shorter list, so the pairing stays aligned
            # even if the resampling round trip changed the total length.
            for high_quality_sample, low_quality_sample in zip(split_high_quality_audio, split_low_quality_audio):
                data.append(((high_quality_sample, low_quality_sample), (original_sample_rate, mangled_sample_rate)))

        self.audio_data = data

    def __len__(self):
        return len(self.audio_data)

    def __getitem__(self, idx):
        return self.audio_data[idx]
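

# Minimal usage sketch. The "training_data" directory is a hypothetical
# example path, and AudioUtils is assumed to be the local helper module
# imported above. Each dataset item has the shape
# ((high_quality_chunk, low_quality_chunk), (original_rate, mangled_rate)),
# which the default DataLoader collate function handles as-is.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = AudioDataset("training_data", device="cpu")  # hypothetical path
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Inspect a single sample to verify chunking and the rate metadata.
    (high_quality, low_quality), (original_rate, mangled_rate) = dataset[0]
    print(f"chunk shapes: {high_quality.shape} / {low_quality.shape}")
    print(f"sample rates: {original_rate} -> {mangled_rate}")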