from torch.utils.data import Dataset import torch.nn.functional as F import torchaudio import os import random class AudioDataset(Dataset): audio_sample_rates = [8000, 11025, 16000, 22050] def __init__(self, input_dir, target_duration=None, padding_mode='constant', padding_value=0.0): self.input_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.wav')] self.target_duration = target_duration # Duration in seconds or None if not set self.padding_mode = padding_mode self.padding_value = padding_value def __len__(self): return len(self.input_files) def __getitem__(self, idx): high_quality_wav, sr_original = torchaudio.load(self.input_files[idx], normalize=True) sample_rate = random.choice(self.audio_sample_rates) resample_transform = torchaudio.transforms.Resample(sr_original, sample_rate) low_quality_wav = resample_transform(high_quality_wav) low_quality_wav = low_quality_wav # Calculate target length based on desired duration and 16000 Hz if self.target_duration is not None: target_length = int(self.target_duration * 44100) else: # Calculate duration of original high quality audio target_length = high_quality_wav.size(1) # Pad both to the calculated target length high_quality_wav = self.stretch_tensor(high_quality_wav, target_length) low_quality_wav = self.stretch_tensor(low_quality_wav, target_length) return low_quality_wav, high_quality_wav def stretch_tensor(self, tensor, target_length): current_length = tensor.size(1) scale_factor = target_length / current_length # Resample the tensor using linear interpolation tensor = F.interpolate(tensor.unsqueeze(0), scale_factor=scale_factor, mode='linear', align_corners=False).squeeze(0) return tensor