⚗️ | More architectural changes
data.py (62 changed lines)
@@ -1,6 +1,7 @@
 import os
+import random
 
 import torch
 import torchaudio
 import torchcodec.decoders as decoders
 import tqdm
@@ -10,9 +11,9 @@ import AudioUtils
 
 
 class AudioDataset(Dataset):
-    audio_sample_rates = [11025]
+    audio_sample_rates = [8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100]
 
-    def __init__(self, input_dir, clip_length: int = 8000, normalize: bool = True):
+    def __init__(self, input_dir, clip_length: int = 512, normalize: bool = True):
         self.clip_length = clip_length
         self.normalize = normalize
 
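The wider audio_sample_rates list is the pool of degradation targets: a clip is resampled down to a randomly chosen rate and back up to its original rate, which strips high-frequency content while keeping the nominal rate unchanged. A minimal standalone illustration of that round trip (the 44100 -> 8000 pair is just one choice from the list above; only torch and torchaudio are needed):

    import torch
    import torchaudio

    clip = torch.randn(1, 512)  # mono clip at the new default clip_length

    down = torchaudio.transforms.Resample(44100, 8000)  # orig_freq -> new_freq
    up = torchaudio.transforms.Resample(8000, 44100)

    degraded = up(down(clip))
    # The round trip can drift in length by a sample or so (512 -> 93 -> 513
    # here), which is why the new __getitem__ trims or pads the result.
    print(clip.shape, degraded.shape)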
@@ -30,45 +31,20 @@ class AudioDataset(Dataset):
             decoder = decoders.AudioDecoder(audio_clip)
             decoded_samples = decoder.get_all_samples()
 
-            audio = decoded_samples.data.float()  # ensure float32
+            audio = decoded_samples.data.float()
             original_sample_rate = decoded_samples.sample_rate
 
             audio = AudioUtils.stereo_tensor_to_mono(audio)
             if normalize:
                 audio = AudioUtils.normalize(audio)
 
-            mangled_sample_rate = random.choice(self.audio_sample_rates)
-            resample_transform_low = torchaudio.transforms.Resample(
-                original_sample_rate, mangled_sample_rate
-            )
-            resample_transform_high = torchaudio.transforms.Resample(
-                mangled_sample_rate, original_sample_rate
-            )
-
-            low_audio = resample_transform_high(resample_transform_low(audio))
-
-            splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length)
-            splitted_low_quality_audio = AudioUtils.split_audio(low_audio, clip_length)
-
-            if not splitted_high_quality_audio or not splitted_low_quality_audio:
-                continue  # skip empty or invalid clips
-
-            splitted_high_quality_audio[-1] = AudioUtils.pad_tensor(
-                splitted_high_quality_audio[-1], clip_length
-            )
-            splitted_low_quality_audio[-1] = AudioUtils.pad_tensor(
-                splitted_low_quality_audio[-1], clip_length
-            )
-
-            for high_quality_data, low_quality_data in zip(
-                splitted_high_quality_audio, splitted_low_quality_audio
-            ):
-                data.append(
-                    (
-                        (high_quality_data, low_quality_data),
-                        (original_sample_rate, mangled_sample_rate),
-                    )
-                )
+            splitted_high_quality_audio = AudioUtils.split_audio(audio, clip_length, True)
+
+            if not splitted_high_quality_audio:
+                continue
+
+            for splitted_audio_clip in splitted_high_quality_audio:
+                for audio_clip in torch.split(splitted_audio_clip, 1):
+                    data.append((audio_clip, original_sample_rate))
 
         self.audio_data = data
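AudioUtils is the repo's own helper module and isn't part of this diff, so the following is only a plausible reading of the calls above, not the actual implementation: split_audio(audio, clip_length, True) presumably chunks the time axis and pads the final chunk, which would explain why the explicit pad_tensor calls and the two-list guard could be deleted.

    import torch
    import torch.nn.functional as F

    # Assumed behavior of AudioUtils.pad_tensor: zero-pad the time axis to `length`.
    def pad_tensor(audio: torch.Tensor, length: int) -> torch.Tensor:
        missing = length - audio.shape[-1]
        return F.pad(audio, (0, missing)) if missing > 0 else audio

    # Assumed behavior of AudioUtils.split_audio: chunk along time and, if the
    # third argument is True, pad the last chunk up to clip_length.
    def split_audio(audio: torch.Tensor, clip_length: int, pad_last: bool = False):
        chunks = list(torch.split(audio, clip_length, dim=-1))
        if chunks and pad_last:
            chunks[-1] = pad_tensor(chunks[-1], clip_length)
        return chunks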
@@ -76,4 +52,20 @@ class AudioDataset(Dataset):
         return len(self.audio_data)
 
     def __getitem__(self, idx):
-        return self.audio_data[idx]
+        audio_clip = self.audio_data[idx]
+        mangled_sample_rate = random.choice(self.audio_sample_rates)
+
+        resample_transform_low = torchaudio.transforms.Resample(
+            audio_clip[1], mangled_sample_rate
+        )
+
+        resample_transform_high = torchaudio.transforms.Resample(
+            mangled_sample_rate, audio_clip[1]
+        )
+
+        low_audio_clip = resample_transform_high(resample_transform_low(audio_clip[0]))
+        if audio_clip[0].shape[1] < low_audio_clip.shape[1]:
+            low_audio_clip = low_audio_clip[:, :audio_clip[0].shape[1]]
+        elif audio_clip[0].shape[1] > low_audio_clip.shape[1]:
+            low_audio_clip = AudioUtils.pad_tensor(low_audio_clip, self.clip_length)
+        return ((audio_clip[0], low_audio_clip), (audio_clip[1], mangled_sample_rate))
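Moving pair generation into __getitem__ means the stored data shrinks to (clip, sample_rate) tuples and every epoch draws a fresh mangled rate per clip. A sketch of how the reworked dataset might be consumed ("data/" and the batch size are placeholders; PyTorch's default collate handles the nested tuples):

    from torch.utils.data import DataLoader

    dataset = AudioDataset("data/", clip_length=512)
    loader = DataLoader(dataset, batch_size=16, shuffle=True)

    for (high_quality, low_quality), (orig_rate, mangled_rate) in loader:
        # high_quality / low_quality: (batch, 1, clip_length) float32 tensors
        # orig_rate / mangled_rate: per-item sample rates collated into tensors
        ...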