# SISU/data.py
#
# 105 lines
# 4.7 KiB
# Python

# Keep necessary PyTorch imports for torchaudio and Dataset structure
from torch.utils.data import Dataset
import torch
import torchaudio
import torchaudio.transforms as T # Keep using torchaudio transforms
# Import NumPy
import numpy as np
import os
import random
# Assume AudioUtils is available and works on PyTorch Tensors as before
import AudioUtils
class AudioDatasetNumPy(Dataset):
    """Dataset pairing each ``.wav`` file with a degraded copy, as NumPy arrays.

    Each item is produced by loading one audio file with ``torchaudio``,
    degrading it by resampling down to a rate picked from
    ``audio_sample_rates`` and back up to the file's original rate, mixing
    both versions down to a single channel, and padding/truncating both to
    ``MAX_LENGTH`` samples.
    """

    # Candidate sample rates (Hz) for the degradation step; one is chosen at
    # random for every item.
    audio_sample_rates = [11025]
    # Number of samples every returned array is padded or truncated to.
    MAX_LENGTH = 44100

    def __init__(self, input_dir):
        """Collect every ``.wav`` file found recursively under ``input_dir``.

        The extension match is case-insensitive and the resulting list is
        sorted, so that a given index refers to the same file regardless of
        the order in which the filesystem enumerates directory entries.

        Args:
            input_dir: Root directory that is walked recursively.
        """
        self.input_files = sorted(
            os.path.join(root, f)
            for root, _, files in os.walk(input_dir)
            for f in files
            if f.lower().endswith('.wav')
        )
        if not self.input_files:
            print(f"Warning: No .wav files found in {input_dir}")

    def __len__(self):
        """Return the number of audio files discovered by ``__init__``."""
        return len(self.input_files)

    @staticmethod
    def _pad_or_truncate(audio_np, max_len):
        """Zero-pad at the end of, or cut, axis 1 of ``audio_np`` to ``max_len``."""
        current_len = audio_np.shape[1]
        if current_len < max_len:
            # np.pad takes one (before, after) pair per axis; only the end of
            # the sample axis is padded.
            return np.pad(
                audio_np,
                ((0, 0), (0, max_len - current_len)),
                mode='constant',
                constant_values=0,
            )
        if current_len > max_len:
            return audio_np[:, :max_len]
        return audio_np

    @staticmethod
    def _average_channels(audio_pt):
        """Return ``audio_pt`` unchanged if it has one channel, else the channel mean."""
        if audio_pt.shape[0] == 1:
            return audio_pt
        return torch.mean(audio_pt, dim=0, keepdim=True)

    def __getitem__(self, idx):
        """Load item ``idx`` and return its high- and low-quality versions.

        Args:
            idx: Index into ``self.input_files``.

        Returns:
            ``((high_np, original_sample_rate), (low_np, mangled_sample_rate))``
            where both arrays are 2-D with ``MAX_LENGTH`` entries along axis 1
            (expected shape ``(1, MAX_LENGTH)``, assuming the mono conversion
            yields a single channel). ``low_np`` has already been resampled
            back to ``original_sample_rate``; ``mangled_sample_rate`` is the
            intermediate rate it was degraded through, not its current rate.

            Returns ``None`` if the file cannot be loaded. NOTE: callers that
            batch items must filter these out themselves (e.g. with a custom
            ``collate_fn``).
        """
        path = self.input_files[idx]

        # --- Load with torchaudio (produces PyTorch tensors) ---
        try:
            high_quality_audio_pt, original_sample_rate = torchaudio.load(
                path, normalize=True
            )
        except Exception as e:
            # Best effort: report the unreadable file and signal it with None
            # instead of aborting the whole epoch.
            print(f"Error loading file {path}: {e}")
            return None

        # --- Degrade: resample down to a random low rate, then back up ---
        mangled_sample_rate = random.choice(self.audio_sample_rates)
        if original_sample_rate != mangled_sample_rate:
            downsample = T.Resample(original_sample_rate, mangled_sample_rate)
            upsample = T.Resample(mangled_sample_rate, original_sample_rate)
            low_quality_audio_pt = upsample(downsample(high_quality_audio_pt))
        else:
            # Rates already match, so there is nothing to degrade; use a copy.
            low_quality_audio_pt = high_quality_audio_pt.clone()

        # --- Stereo to mono (still PyTorch tensors) ---
        # Assumes AudioUtils.stereo_tensor_to_mono takes a (C, L) tensor and
        # returns a (1, L) tensor -- TODO confirm against AudioUtils.
        try:
            high_quality_audio_pt_mono = AudioUtils.stereo_tensor_to_mono(high_quality_audio_pt)
            low_quality_audio_pt_mono = AudioUtils.stereo_tensor_to_mono(low_quality_audio_pt)
        except Exception as e:
            # If either conversion fails, fall back to plain channel averaging
            # for both signals so the pair is processed consistently.
            print(f"Warning: Mono conversion issue with {path}: {e}. Using original.")
            high_quality_audio_pt_mono = self._average_channels(high_quality_audio_pt)
            low_quality_audio_pt_mono = self._average_channels(low_quality_audio_pt)

        # --- Convert to NumPy and force a fixed length ---
        high_quality_audio_np = self._pad_or_truncate(
            high_quality_audio_pt_mono.numpy(), self.MAX_LENGTH
        )
        low_quality_audio_np = self._pad_or_truncate(
            low_quality_audio_pt_mono.numpy(), self.MAX_LENGTH
        )

        return (high_quality_audio_np, original_sample_rate), (low_quality_audio_np, mangled_sample_rate)