import os
import random

import torch
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset

import AudioUtils


class AudioDataset(Dataset):
    """Dataset of paired (high-quality, degraded) waveforms built from ``.wav`` files.

    Each item is produced on the fly: the file is loaded at its native sample
    rate and a degraded copy is generated by resampling it to a rate picked at
    random from ``audio_sample_rates``.
    """

    # Candidate target rates (Hz) for the degraded copy. Other rates such as
    # 8000, 16000 or 22050 can be supplied per instance through the
    # ``sample_rates`` argument of ``__init__``.
    audio_sample_rates = [11025]

    def __init__(self, input_dir, *, sample_rates=None):
        """Collect the ``.wav`` files found directly inside *input_dir*.

        Args:
            input_dir: Directory to scan (non-recursive). Only names ending in
                ``.wav`` are kept.
            sample_rates: Optional iterable of target sample rates overriding
                the class-level ``audio_sample_rates`` for this instance.

        Raises:
            ValueError: If *sample_rates* is given but empty.
            OSError: Propagated from ``os.listdir`` when *input_dir* cannot be
                read.
        """
        if sample_rates is not None:
            rates = list(sample_rates)
            if not rates:
                raise ValueError("sample_rates must contain at least one rate")
            self.audio_sample_rates = rates

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        self.input_files = sorted(
            os.path.join(input_dir, name)
            for name in os.listdir(input_dir)
            if name.endswith('.wav')
        )

        # Resample transforms keyed by (source_rate, target_rate), filled
        # lazily so each transform is built once rather than per item.
        self._resamplers = {}

    def __len__(self):
        """Return the number of ``.wav`` files discovered."""
        return len(self.input_files)

    def _get_resampler(self, source_rate, target_rate):
        """Return a cached ``Resample`` transform for the given rate pair."""
        key = (source_rate, target_rate)
        resampler = self._resamplers.get(key)
        if resampler is None:
            resampler = T.Resample(source_rate, target_rate)
            self._resamplers[key] = resampler
        return resampler

    def __getitem__(self, idx):
        """Load file *idx* and return it together with a degraded copy.

        Returns:
            A pair ``((high, high_rate), (low, low_rate))`` where ``high`` is
            the loaded waveform, ``low`` is the same audio resampled to a
            randomly chosen rate from ``audio_sample_rates``, and each rate is
            the sample rate of the waveform it accompanies. Both waveforms are
            passed through ``AudioUtils.stereo_tensor_to_mono`` before being
            returned. Because ``low`` is at a different sample rate, the two
            waveforms generally differ in length.
        """
        # Load high-quality audio at the file's native sample rate.
        high_quality_audio, original_sample_rate = torchaudio.load(
            self.input_files[idx], normalize=True
        )

        # Generate low-quality audio with random downsampling.
        mangled_sample_rate = random.choice(self.audio_sample_rates)
        resample_transform = self._get_resampler(
            original_sample_rate, mangled_sample_rate
        )
        low_quality_audio = resample_transform(high_quality_audio)

        # NOTE(review): stereo_tensor_to_mono is defined in AudioUtils;
        # presumably it collapses the channel dimension -- confirm there.
        return (
            (AudioUtils.stereo_tensor_to_mono(high_quality_audio), original_sample_rate),
            (AudioUtils.stereo_tensor_to_mono(low_quality_audio), mangled_sample_rate),
        )