➖ | Removed/Updated dependencies.
This commit is contained in:
8
data.py
8
data.py
@ -12,8 +12,9 @@ class AudioDataset(Dataset):
|
||||
#audio_sample_rates = [8000, 11025, 16000, 22050]
|
||||
audio_sample_rates = [11025]
|
||||
|
||||
def __init__(self, input_dir):
|
||||
def __init__(self, input_dir, device):
|
||||
self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')]
|
||||
self.device = device
|
||||
|
||||
|
||||
def __len__(self):
|
||||
@ -32,4 +33,7 @@ class AudioDataset(Dataset):
|
||||
resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
|
||||
low_quality_audio = resample_transform_high(low_quality_audio)
|
||||
|
||||
return (AudioUtils.stereo_tensor_to_mono(high_quality_audio), original_sample_rate), (AudioUtils.stereo_tensor_to_mono(low_quality_audio), mangled_sample_rate)
|
||||
high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio).to(self.device)
|
||||
low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio).to(self.device)
|
||||
|
||||
return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)
|
||||
|
Reference in New Issue
Block a user