36 lines
1.3 KiB
Python
36 lines
1.3 KiB
Python
from torch.utils.data import Dataset
|
|
import torch.nn.functional as F
|
|
import torch
|
|
import torchaudio
|
|
import os
|
|
import random
|
|
|
|
import torchaudio.transforms as T
|
|
import AudioUtils
|
|
|
|
class AudioDataset(Dataset):
    """Clean/degraded audio pairs built from every ``.wav`` under a directory.

    Each item loads one file and builds a degraded copy. The copy is resampled
    to a rate drawn at random from ``audio_sample_rates`` and then resampled
    back to the file's own rate. Both versions are passed through
    ``AudioUtils.stereo_tensor_to_mono`` before being returned.
    """

    # Candidate intermediate sample rates (Hz); one is drawn per item.
    # A wider set used previously: 8000, 11025, 16000, 22050.
    audio_sample_rates = [11025]

    def __init__(self, input_dir):
        """Collect every file ending in ``.wav`` found recursively under *input_dir*.

        The paths are sorted so that a given index refers to the same file on
        every run and machine; ``os.walk`` order depends on the filesystem.
        """
        self.input_files = sorted(
            os.path.join(root, f)
            for root, _, files in os.walk(input_dir)
            for f in files
            if f.endswith('.wav')
        )
        # Maps (orig_freq, new_freq) to a Resample module, so each rate pair's
        # kernel is built once instead of twice per item.
        self._resamplers = {}

    def __len__(self):
        """Return the number of ``.wav`` files discovered."""
        return len(self.input_files)

    def _resampler(self, orig_freq, new_freq):
        """Return a cached ``Resample`` transform for ``orig_freq -> new_freq``."""
        key = (orig_freq, new_freq)
        transform = self._resamplers.get(key)
        if transform is None:
            transform = torchaudio.transforms.Resample(orig_freq, new_freq)
            self._resamplers[key] = transform
        return transform

    @staticmethod
    def _match_length(audio, num_frames):
        """Trim, or right-pad with zeros, the last dimension of *audio* to *num_frames*."""
        current = audio.shape[-1]
        if current > num_frames:
            return audio[..., :num_frames]
        if current < num_frames:
            return F.pad(audio, (0, num_frames - current))
        return audio

    def __getitem__(self, idx):
        """Load file *idx* and return its clean and degraded versions.

        Returns:
            ``((clean_mono, original_sample_rate), (degraded_mono, mangled_sample_rate))``.
            Both tensors are at ``original_sample_rate`` and share the same
            last-dimension length. ``mangled_sample_rate`` is the intermediate
            rate the degraded copy was resampled through, not the rate of the
            returned degraded tensor.
        """
        high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True)
        num_frames = high_quality_audio.shape[-1]

        # Degrade by resampling to a random intermediate rate and back. This
        # only removes content when that rate is below the file's own rate.
        mangled_sample_rate = random.choice(self.audio_sample_rates)
        low_quality_audio = self._resampler(original_sample_rate, mangled_sample_rate)(high_quality_audio)
        low_quality_audio = self._resampler(mangled_sample_rate, original_sample_rate)(low_quality_audio)

        # Each resampling step rounds its output length, so the round trip can
        # add a few frames. Realign so the pair lines up sample-for-sample.
        low_quality_audio = self._match_length(low_quality_audio, num_frames)

        return (
            (AudioUtils.stereo_tensor_to_mono(high_quality_audio), original_sample_rate),
            (AudioUtils.stereo_tensor_to_mono(low_quality_audio), mangled_sample_rate),
        )
|