import os

import torch
import torchaudio
from torch.utils.data import Dataset


class AudioDataset(Dataset):
    """Dataset of ``(high_quality, low_quality)`` waveform pairs built from ``.wav`` files.

    Each item loads one ``.wav`` file from ``input_dir``. The loaded waveform is
    returned as the "high quality" tensor. A copy resampled to ``target_sample_rate``
    (16 kHz by default) is returned as the "low quality" tensor. Both are padded or
    truncated along dim 1 to a common sample count (see ``__getitem__``).

    Args:
        input_dir: Directory scanned (non-recursively) for names ending in ``.wav``
            (case-insensitive). Paths are sorted so item indices are stable.
        target_duration: Desired length in seconds, or ``None`` to use each file's own
            length. Must be non-negative.
        padding_mode: Mode forwarded to ``torch.nn.functional.pad``
            (``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``).
        padding_value: Fill value; only used when ``padding_mode == 'constant'``.
        target_sample_rate: Sample rate (Hz) of the low-quality waveform.

    Raises:
        ValueError: If ``target_duration`` is negative.
    """

    def __init__(self, input_dir, target_duration=None, padding_mode='constant',
                 padding_value=0.0, target_sample_rate=16000):
        if target_duration is not None and target_duration < 0:
            # A negative length would make pad_tensor slice from the end and drop samples.
            raise ValueError(
                f"target_duration must be non-negative or None, got {target_duration!r}"
            )
        # os.listdir order is arbitrary; sorting keeps indices reproducible across runs.
        self.input_files = sorted(
            os.path.join(input_dir, f)
            for f in os.listdir(input_dir)
            if f.lower().endswith('.wav')
        )
        self.target_duration = target_duration  # Duration in seconds or None if not set
        self.padding_mode = padding_mode
        self.padding_value = padding_value
        self.target_sample_rate = target_sample_rate
        # Resample transforms keyed by source sample rate. Building one computes a filter
        # kernel, so it is done once per rate instead of once per item.
        self._resamplers = {}

    def __len__(self):
        """Return the number of ``.wav`` files found in ``input_dir``."""
        return len(self.input_files)

    def __getitem__(self, idx):
        """Load file ``idx`` and return ``(high_quality_wav, low_quality_wav)``.

        ``high_quality_wav`` is the waveform returned by ``torchaudio.load`` (with
        ``normalize=True``) at the file's own sample rate. ``low_quality_wav`` is that
        waveform resampled to ``self.target_sample_rate``.

        Both tensors are padded/truncated along dim 1 to the same sample count:
        ``int(target_duration * target_sample_rate)`` when ``target_duration`` is set,
        otherwise the high-quality waveform's length.
        """
        high_quality_wav, sr_original = torchaudio.load(self.input_files[idx], normalize=True)

        low_quality_wav = self._resample(high_quality_wav, sr_original)

        if self.target_duration is not None:
            # NOTE(review): this count is in target-rate samples but is also applied to the
            # high-quality tensor, which is at sr_original. For files not at the target
            # rate, the high-quality output therefore covers
            # target_duration * target_sample_rate / sr_original seconds. Kept unchanged so
            # output shapes stay the same -- confirm this is intended.
            target_length = int(self.target_duration * self.target_sample_rate)
        else:
            # NOTE(review): the low-quality tensor is padded/truncated to the high-quality
            # sample count. That differs from its own length whenever sr_original differs
            # from the target rate -- confirm this is intended.
            target_length = high_quality_wav.size(1)

        high_quality_wav = self.pad_tensor(high_quality_wav, target_length)
        low_quality_wav = self.pad_tensor(low_quality_wav, target_length)
        return high_quality_wav, low_quality_wav

    def _resample(self, waveform, sr_original):
        """Return ``waveform`` resampled from ``sr_original`` to ``self.target_sample_rate``."""
        if sr_original == self.target_sample_rate:
            # torchaudio's Resample returns its input unchanged for equal rates; skip it.
            return waveform
        resampler = self._resamplers.get(sr_original)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(sr_original, self.target_sample_rate)
            self._resamplers[sr_original] = resampler
        return resampler(waveform)

    def pad_tensor(self, tensor, target_length):
        """Pad or truncate ``tensor`` to ``target_length`` along the time dimension (dim=1).

        Shorter tensors are padded on the right using ``self.padding_mode``.
        ``self.padding_value`` is passed only in ``'constant'`` mode, because PyTorch
        rejects a non-zero ``value`` for the other modes. Non-constant modes carry their
        own size limits; for example, ``'reflect'`` needs the padding amount to be
        smaller than the current length. Tensors at least ``target_length`` long are
        truncated to their first ``target_length`` samples.
        """
        current_length = tensor.size(1)
        if current_length >= target_length:
            return tensor[:, :target_length]

        padding = (0, target_length - current_length)  # (left_pad, right_pad) on the last dim
        if self.padding_mode == 'constant':
            return torch.nn.functional.pad(
                tensor, padding, mode='constant', value=self.padding_value
            )
        return torch.nn.functional.pad(tensor, padding, mode=self.padding_mode)