🐛 | Fixed output not being same length.
This commit is contained in:
24
data.py
24
data.py
@ -13,37 +13,33 @@ class AudioDataset(Dataset):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.input_files)
|
return len(self.input_files)
|
||||||
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
# Load audio samples using torchaudio
|
|
||||||
high_quality_wav, sr_original = torchaudio.load(self.input_files[idx], normalize=True)
|
high_quality_wav, sr_original = torchaudio.load(self.input_files[idx], normalize=True)
|
||||||
|
|
||||||
# Resample to 16000 Hz if necessary
|
|
||||||
resample_transform = torchaudio.transforms.Resample(sr_original, 16000)
|
resample_transform = torchaudio.transforms.Resample(sr_original, 16000)
|
||||||
low_quality_wav = resample_transform(high_quality_wav)
|
low_quality_wav = resample_transform(high_quality_wav)
|
||||||
|
|
||||||
# Calculate target length in samples if target_duration is specified
|
# Calculate target length based on desired duration and 16000 Hz
|
||||||
if self.target_duration is not None:
|
if self.target_duration is not None:
|
||||||
target_length = int(self.target_duration * 16000) # Assuming 16000 Hz as target sample rate
|
target_length = int(self.target_duration * 44100)
|
||||||
else:
|
else:
|
||||||
target_length = high_quality_wav.size(1)
|
# Calculate duration of original high quality audio
|
||||||
|
duration_original = high_quality_wav.shape[1] / sr_original
|
||||||
|
target_length = int(duration_original * 16000)
|
||||||
|
|
||||||
# Pad high_quality_wav and low_quality_wav to target_length
|
# Pad both to the calculated target length
|
||||||
high_quality_wav = self.pad_tensor(high_quality_wav, target_length)
|
high_quality_wav = self.pad_tensor(high_quality_wav, target_length)
|
||||||
low_quality_wav = self.pad_tensor(low_quality_wav, target_length)
|
low_quality_wav = self.pad_tensor(low_quality_wav, target_length)
|
||||||
|
|
||||||
return high_quality_wav, low_quality_wav
|
return low_quality_wav, high_quality_wav
|
||||||
|
|
||||||
def pad_tensor(self, tensor, target_length):
|
def pad_tensor(self, tensor, target_length):
|
||||||
"""Pad tensor to target length along the time dimension (dim=1)."""
|
|
||||||
current_length = tensor.size(1)
|
current_length = tensor.size(1)
|
||||||
|
|
||||||
if current_length < target_length:
|
if current_length < target_length:
|
||||||
# Calculate padding amount for each side
|
|
||||||
padding_amount = target_length - current_length
|
padding_amount = target_length - current_length
|
||||||
padding = (0, padding_amount) # (left_pad, right_pad) for 1D padding
|
padding = (0, padding_amount)
|
||||||
tensor = torch.nn.functional.pad(tensor, padding, mode=self.padding_mode, value=self.padding_value)
|
tensor = torch.nn.functional.pad(tensor, padding, mode=self.padding_mode, value=self.padding_value)
|
||||||
else:
|
elif current_length > target_length:
|
||||||
# If tensor is longer than target, truncate it
|
|
||||||
tensor = tensor[:, :target_length]
|
tensor = tensor[:, :target_length]
|
||||||
|
|
||||||
return tensor
|
return tensor
|
||||||
|
Reference in New Issue
Block a user