import os
import random

import torch
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset

import AudioUtils


class AudioDataset(Dataset):
    """Dataset of (high-quality, degraded) audio pairs built from ``.wav`` files.

    Every ``.wav`` file found recursively under ``input_dir`` is one item.
    The degraded version is produced by resampling the clip down to a rate
    picked at random from ``audio_sample_rates`` and then back up to the
    file's own sample rate. Both versions are passed through
    ``AudioUtils.stereo_tensor_to_mono`` and then zero-padded or truncated
    along the last axis to exactly ``MAX_LENGTH`` samples.
    """

    # Candidate intermediate sample rates (Hz) used to degrade the audio.
    audio_sample_rates = [11025]
    MAX_LENGTH = 44100  # Define your desired maximum length here

    def __init__(self, input_dir, device):
        """Collect the ``.wav`` files under ``input_dir``.

        Args:
            input_dir: Root directory that is walked recursively.
            device: Target passed to ``Tensor.to`` for the returned tensors.
        """
        # os.walk yields entries in arbitrary filesystem order; sorting keeps
        # the index -> file mapping stable across runs and machines.
        self.input_files = sorted(
            os.path.join(root, name)
            for root, _, files in os.walk(input_dir)
            for name in files
            if name.endswith('.wav')
        )
        self.device = device

    def __len__(self):
        """Return the number of ``.wav`` files that were discovered."""
        return len(self.input_files)

    def _get_resampler(self, orig_freq, new_freq):
        """Return a cached ``Resample`` transform for the given rate pair."""
        # Building a Resample module recomputes its kernel, so reuse one per
        # (orig, new) pair instead of rebuilding it for every item. The cache
        # is created lazily so it also exists when __init__ was bypassed or
        # overridden by a subclass.
        cache = self.__dict__.setdefault("_resamplers", {})
        key = (orig_freq, new_freq)
        if key not in cache:
            cache[key] = T.Resample(orig_freq, new_freq)
        return cache[key]

    def _fit_to_length(self, audio: torch.Tensor) -> torch.Tensor:
        """Zero-pad or truncate ``audio`` on its last axis to ``MAX_LENGTH``."""
        # NOTE(review): assumes the time axis is the last one, i.e. that
        # stereo_tensor_to_mono returns (channels, samples) -- confirm.
        missing = self.MAX_LENGTH - audio.shape[-1]
        if missing > 0:
            return F.pad(audio, (0, missing))
        if missing < 0:
            return audio[..., :self.MAX_LENGTH]
        return audio

    def __getitem__(self, idx):
        """Load item ``idx`` and build its degraded counterpart.

        Returns:
            ``((high_quality_audio, original_sample_rate),
            (low_quality_audio, mangled_sample_rate))``. Both tensors are on
            ``self.device`` and are ``MAX_LENGTH`` samples long. The
            low-quality tensor has already been resampled back to
            ``original_sample_rate``; ``mangled_sample_rate`` is the
            intermediate rate it was degraded through, not its current rate.
        """
        # Load high-quality audio
        high_quality_audio, original_sample_rate = torchaudio.load(
            self.input_files[idx], normalize=True
        )

        # Generate low-quality audio: downsample to a random rate, then
        # upsample back so both tensors share the original sample rate.
        mangled_sample_rate = random.choice(self.audio_sample_rates)
        downsample = self._get_resampler(original_sample_rate, mangled_sample_rate)
        upsample = self._get_resampler(mangled_sample_rate, original_sample_rate)
        low_quality_audio = upsample(downsample(high_quality_audio))

        high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio)
        low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio)

        # The resampling round trip can change the length slightly, so both
        # tensors are fitted to MAX_LENGTH independently.
        high_quality_audio = self._fit_to_length(high_quality_audio).to(self.device)
        low_quality_audio = self._fit_to_length(low_quality_audio).to(self.device)

        return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)