diff --git a/data.py b/data.py index 214b6e5..67abb91 100644 --- a/data.py +++ b/data.py @@ -13,37 +13,33 @@ class AudioDataset(Dataset): def __len__(self): return len(self.input_files) + def __getitem__(self, idx): - # Load audio samples using torchaudio high_quality_wav, sr_original = torchaudio.load(self.input_files[idx], normalize=True) - # Resample to 16000 Hz if necessary resample_transform = torchaudio.transforms.Resample(sr_original, 16000) low_quality_wav = resample_transform(high_quality_wav) - # Calculate target length in samples if target_duration is specified + # Calculate target length based on desired duration and 16000 Hz if self.target_duration is not None: - target_length = int(self.target_duration * 16000) # Assuming 16000 Hz as target sample rate + target_length = int(self.target_duration * 44100) else: - target_length = high_quality_wav.size(1) + # Calculate duration of original high quality audio + duration_original = high_quality_wav.shape[1] / sr_original + target_length = int(duration_original * 16000) - # Pad high_quality_wav and low_quality_wav to target_length + # Pad both to the calculated target length high_quality_wav = self.pad_tensor(high_quality_wav, target_length) low_quality_wav = self.pad_tensor(low_quality_wav, target_length) - return high_quality_wav, low_quality_wav + return low_quality_wav, high_quality_wav def pad_tensor(self, tensor, target_length): - """Pad tensor to target length along the time dimension (dim=1).""" current_length = tensor.size(1) - if current_length < target_length: - # Calculate padding amount for each side padding_amount = target_length - current_length - padding = (0, padding_amount) # (left_pad, right_pad) for 1D padding + padding = (0, padding_amount) tensor = torch.nn.functional.pad(tensor, padding, mode=self.padding_mode, value=self.padding_value) - else: - # If tensor is longer than target, truncate it + elif current_length > target_length: tensor = tensor[:, :target_length] - return tensor