"""Dataset that yields (high-quality, degraded) mono audio pairs as NumPy arrays.

Loading and resampling are done with torchaudio on PyTorch tensors; the
results are converted to NumPy arrays before being returned.
"""

import os
import random
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset

# Project-local helper module. It is assumed (not verifiable from this file)
# that AudioUtils.stereo_tensor_to_mono accepts a (C, L) tensor and returns a
# (1, L) tensor.
import AudioUtils


class AudioDatasetNumPy(Dataset):
    """Serve pairs of clean and bandwidth-degraded audio clips as NumPy arrays.

    Every ``.wav`` file found under ``input_dir`` (recursively) is one sample.
    The degraded version is produced by resampling the clean audio down to a
    rate drawn from ``audio_sample_rates`` and back up to the original rate.
    Both versions are converted to mono and padded/truncated to
    ``MAX_LENGTH`` samples.
    """

    # Candidate intermediate sample rates (Hz) used to degrade the audio.
    audio_sample_rates = [11025]
    # Fixed length, in samples, of every returned array (zero-padded or cut).
    MAX_LENGTH = 44100

    def __init__(self, input_dir: Union[str, "os.PathLike[str]"]) -> None:
        """Collect the paths of all ``.wav`` files below ``input_dir``.

        Args:
            input_dir: Root directory that is walked recursively.

        The extension match is case-insensitive (``.WAV`` is accepted) and
        the resulting list is sorted, so a given index always maps to the
        same file regardless of the order in which the filesystem lists
        directory entries. A warning is printed when nothing is found.
        """
        self.input_files = sorted(
            os.path.join(root, name)
            for root, _, files in os.walk(input_dir)
            for name in files
            if name.lower().endswith('.wav')
        )
        if not self.input_files:
            print(f"Warning: No .wav files found in {input_dir}")

    def __len__(self) -> int:
        """Return the number of audio files in the dataset."""
        return len(self.input_files)

    @staticmethod
    def _mean_to_mono(audio_pt: torch.Tensor) -> torch.Tensor:
        """Average the channel axis (axis 0) unless there is only one channel."""
        if audio_pt.shape[0] == 1:
            return audio_pt
        return torch.mean(audio_pt, dim=0, keepdim=True)

    @staticmethod
    def _pad_or_truncate(audio_np: np.ndarray, max_len: int) -> np.ndarray:
        """Force axis 1 of a 2-D array to exactly ``max_len`` entries.

        Shorter input is zero-padded at the end; longer input is cut.
        Input that already has the right length is returned unchanged.
        """
        current_len = audio_np.shape[1]
        if current_len < max_len:
            # Pad only the end of axis 1; axis 0 is left untouched.
            return np.pad(
                audio_np,
                ((0, 0), (0, max_len - current_len)),
                mode='constant',
                constant_values=0,
            )
        if current_len > max_len:
            return audio_np[:, :max_len]
        return audio_np

    def __getitem__(
        self, idx: int
    ) -> Optional[Tuple[Tuple[np.ndarray, int], Tuple[np.ndarray, int]]]:
        """Load one file and return its clean and degraded versions.

        Args:
            idx: Index into the list of discovered files.

        Returns:
            ``((high_quality, original_sample_rate),
            (low_quality, mangled_sample_rate))`` where both arrays have
            ``MAX_LENGTH`` entries along axis 1 (shape ``(1, MAX_LENGTH)``
            provided the mono conversion yields a single channel).
            ``mangled_sample_rate`` is the intermediate rate used to degrade
            the signal; the low-quality array itself has been resampled back
            to ``original_sample_rate``.

            ``None`` is returned when the file cannot be loaded, after
            printing the error. Callers (e.g. a collate function) must be
            prepared to filter such entries out.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        path = self.input_files[idx]

        # --- Load with torchaudio (produces PyTorch tensors) ---
        try:
            high_quality_pt, original_sample_rate = torchaudio.load(
                path, normalize=True
            )
        except Exception as e:
            # Deliberate best-effort: report and signal the failure with None
            # rather than aborting the whole iteration.
            print(f"Error loading file {path}: {e}")
            return None

        # --- Degrade: resample down to a random rate, then back up ---
        mangled_sample_rate = random.choice(self.audio_sample_rates)
        if original_sample_rate != mangled_sample_rate:
            downsample = T.Resample(original_sample_rate, mangled_sample_rate)
            upsample = T.Resample(mangled_sample_rate, original_sample_rate)
            low_quality_pt = upsample(downsample(high_quality_pt))
        else:
            # Rates already match, so there is nothing to degrade; copy so the
            # two outputs never alias the same tensor.
            low_quality_pt = high_quality_pt.clone()

        # --- Stereo to mono (still on PyTorch tensors) ---
        try:
            high_quality_mono = AudioUtils.stereo_tensor_to_mono(high_quality_pt)
            low_quality_mono = AudioUtils.stereo_tensor_to_mono(low_quality_pt)
        except Exception as e:
            # AudioUtils is opaque from here; if it rejects the input, fall
            # back to a plain channel average for both signals.
            print(f"Warning: Mono conversion issue with {path}: {e}. Using original.")
            high_quality_mono = self._mean_to_mono(high_quality_pt)
            low_quality_mono = self._mean_to_mono(low_quality_pt)

        # --- Convert to NumPy and normalise the length ---
        high_quality_audio_np = self._pad_or_truncate(
            high_quality_mono.numpy(), self.MAX_LENGTH
        )
        low_quality_audio_np = self._pad_or_truncate(
            low_quality_mono.numpy(), self.MAX_LENGTH
        )

        return (
            (high_quality_audio_np, original_sample_rate),
            (low_quality_audio_np, mangled_sample_rate),
        )