import os
import random

import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset


class AudioDataset(Dataset):
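    """Dataset of paired high-quality and degraded audio clips.

    Each item loads a .wav file from input_dir and returns the original
    waveform alongside a copy resampled to a randomly chosen lower sample
    rate, as ((high_quality, original_sr), (low_quality, mangled_sr)).
    """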

    audio_sample_rates = [8000, 11025, 16000, 22050]

    def __init__(self, input_dir, target_duration=None, padding_mode='constant', padding_value=0.0):
        self.input_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.wav')]
        self.target_duration = target_duration  # Duration in seconds, or None if not set
        self.padding_mode = padding_mode
        self.padding_value = padding_value
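        # Note: padding_mode and padding_value are not used yet; the padding /
        # stretch step in __getitem__ is currently commented out.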

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
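        # Load the clean, high-quality waveform and build a degraded copy by
        # resampling it to a randomly chosen lower sample rate.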
        high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True)

        mangled_sample_rate = random.choice(self.audio_sample_rates)
        resample_transform = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
        low_quality_audio = resample_transform(high_quality_audio)

        # Calculate the target length from the desired duration at a fixed 44100 Hz
        # if self.target_duration is not None:
        #     target_length = int(self.target_duration * 44100)
        # else:
        #     # Fall back to the length of the original high-quality audio
        #     target_length = high_quality_audio.size(1)

        # Stretch both waveforms to the calculated target length
        # high_quality_audio = self.stretch_tensor(high_quality_audio, target_length)
        # low_quality_audio = self.stretch_tensor(low_quality_audio, target_length)

        return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)

    def stretch_tensor(self, tensor, target_length):
        # Stretch (or compress) the waveform to exactly target_length samples using
        # linear interpolation. F.interpolate expects a batch dimension, so add one
        # and strip it afterwards; passing size= avoids the off-by-one lengths that
        # scale_factor rounding can produce.
        tensor = F.interpolate(tensor.unsqueeze(0), size=target_length, mode='linear', align_corners=False).squeeze(0)

        return tensor
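

# Minimal usage sketch (assumes a local directory of .wav files; the path below is a
# placeholder and the DataLoader settings are only an example). Each element is a
# ((high_quality, original_sr), (low_quality, mangled_sr)) pair; waveform lengths
# differ between files, so batch_size=1 is used here to avoid a custom collate_fn.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = AudioDataset("data/wavs", target_duration=None)
    loader = DataLoader(dataset, batch_size=1, shuffle=True)

    (hq, hq_sr), (lq, lq_sr) = next(iter(loader))
    print(hq.shape, int(hq_sr), lq.shape, int(lq_sr))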