105 lines
4.7 KiB
Python
105 lines
4.7 KiB
Python
# Keep necessary PyTorch imports for torchaudio and Dataset structure
|
|
from torch.utils.data import Dataset
|
|
import torch
|
|
import torchaudio
|
|
import torchaudio.transforms as T # Keep using torchaudio transforms
|
|
|
|
# Import NumPy
|
|
import numpy as np
|
|
|
|
import os
|
|
import random
|
|
# Assume AudioUtils is available and works on PyTorch Tensors as before
|
|
import AudioUtils
|
|
|
|
class AudioDatasetNumPy(Dataset):
    """Paired (clean, degraded) audio dataset that yields NumPy arrays.

    Every ``.wav`` file found (recursively) under ``input_dir`` is one item.
    The degraded copy is produced by resampling the clean waveform down to a
    randomly chosen rate from ``audio_sample_rates`` and back up to the
    original rate. Both signals are converted to mono and padded/truncated
    to ``MAX_LENGTH`` samples. No device handling is done; outputs are CPU
    NumPy arrays.
    """

    # Candidate "low quality" sample rates in Hz; one is picked per item.
    audio_sample_rates = [11025]

    # Fixed output length, measured in samples (not seconds).
    MAX_LENGTH = 44100

    def __init__(self, input_dir, *, max_length=None, sample_rates=None):
        """Index all ``.wav`` files under ``input_dir``.

        Args:
            input_dir: Root directory, searched recursively with ``os.walk``.
            max_length: Optional per-instance override of ``MAX_LENGTH``
                (number of samples); must be positive.
            sample_rates: Optional per-instance override of
                ``audio_sample_rates``; must be non-empty.

        Raises:
            ValueError: If ``max_length`` is not positive or ``sample_rates``
                is empty.
        """
        # Validate before touching the filesystem.
        if max_length is not None:
            if max_length <= 0:
                raise ValueError(f"max_length must be positive, got {max_length}")
            self.MAX_LENGTH = max_length
        if sample_rates is not None:
            sample_rates = list(sample_rates)
            if not sample_rates:
                raise ValueError("sample_rates must contain at least one rate")
            self.audio_sample_rates = sample_rates

        # Sorted so that index -> file mapping is identical on every machine
        # (os.walk order depends on the filesystem); suffix match is
        # case-insensitive so files like "TAKE1.WAV" are not skipped.
        self.input_files = sorted(
            os.path.join(root, name)
            for root, _, files in os.walk(input_dir)
            for name in files
            if name.lower().endswith('.wav')
        )
        if not self.input_files:
            print(f"Warning: No .wav files found in {input_dir}")

        # T.Resample precomputes its filter kernel when constructed, so keep
        # one instance per (orig_rate, new_rate) pair instead of rebuilding
        # it for every item.
        self._resamplers = {}

    def __len__(self):
        """Return the number of indexed ``.wav`` files."""
        return len(self.input_files)

    def __getitem__(self, idx):
        """Load one file and return its clean and degraded versions.

        Returns:
            ``((high_quality_audio_np, original_sample_rate),
            (low_quality_audio_np, mangled_sample_rate))``. Both arrays are
            padded/truncated to ``MAX_LENGTH`` along their last axis, i.e.
            shape ``(1, MAX_LENGTH)`` when the mono conversion yields
            ``(1, L)``.

            Returns ``None`` if ``torchaudio.load`` fails. The default
            DataLoader collate function cannot batch ``None``, so use a
            ``collate_fn`` that filters such items out.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        path = self.input_files[idx]

        # --- Load using torchaudio (produces PyTorch tensors) ---
        try:
            high_quality_audio_pt, original_sample_rate = torchaudio.load(
                path, normalize=True
            )
        except Exception as e:
            print(f"Error loading file {path}: {e}")
            return None

        # --- Generate low-quality audio with random downsampling ---
        mangled_sample_rate = random.choice(self.audio_sample_rates)
        low_quality_audio_pt = self._degrade(
            high_quality_audio_pt, original_sample_rate, mangled_sample_rate
        )

        # --- Stereo to mono (still PyTorch tensors) ---
        high_mono, low_mono = self._to_mono_pair(
            high_quality_audio_pt, low_quality_audio_pt, path
        )

        # --- Convert to NumPy and fix the length ---
        high_quality_audio_np = self._pad_or_truncate(high_mono.numpy(), self.MAX_LENGTH)
        low_quality_audio_np = self._pad_or_truncate(low_mono.numpy(), self.MAX_LENGTH)

        return (high_quality_audio_np, original_sample_rate), (low_quality_audio_np, mangled_sample_rate)

    def _get_resampler(self, orig_sr, new_sr):
        """Return a cached ``T.Resample(orig_sr, new_sr)`` instance."""
        # setdefault keeps this working for subclasses whose __init__ does
        # not call ours and therefore never created the cache.
        cache = self.__dict__.setdefault("_resamplers", {})
        key = (orig_sr, new_sr)
        resampler = cache.get(key)
        if resampler is None:
            resampler = T.Resample(orig_sr, new_sr)
            cache[key] = resampler
        return resampler

    def _degrade(self, audio, orig_sr, target_sr):
        """Band-limit ``audio`` by round-tripping it through ``target_sr``.

        The result is back at ``orig_sr``. When ``target_sr`` is not below
        ``orig_sr``, going up and back down would not band-limit the signal,
        so an unmodified copy is returned instead.
        """
        if orig_sr <= target_sr:
            return audio.clone()
        downsampled = self._get_resampler(orig_sr, target_sr)(audio)
        return self._get_resampler(target_sr, orig_sr)(downsampled)

    def _to_mono_pair(self, high, low, path):
        """Convert both tensors to mono, preferring ``AudioUtils``.

        If ``AudioUtils.stereo_tensor_to_mono`` raises for either tensor, a
        warning is printed and both fall back to ``_downmix``.
        """
        try:
            return (
                AudioUtils.stereo_tensor_to_mono(high),
                AudioUtils.stereo_tensor_to_mono(low),
            )
        except Exception as e:
            print(f"Warning: Mono conversion issue with {path}: {e}. Using original.")
            return self._downmix(high), self._downmix(low)

    @staticmethod
    def _downmix(audio):
        """Average channels (axis 0) unless there is already a single channel."""
        if audio.shape[0] == 1:
            return audio
        return torch.mean(audio, dim=0, keepdim=True)

    @staticmethod
    def _pad_or_truncate(audio_np, max_len):
        """Zero-pad at the end, or truncate, the last axis to ``max_len``."""
        current_len = audio_np.shape[-1]
        if current_len < max_len:
            # Pad only the last (time) axis, after the existing samples.
            pad_width = [(0, 0)] * (audio_np.ndim - 1) + [(0, max_len - current_len)]
            return np.pad(audio_np, pad_width, mode='constant', constant_values=0)
        if current_len > max_len:
            return audio_np[..., :max_len]
        return audio_np