import torch
import torch.nn.functional as F


def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
    """Collapse a (channels, samples) waveform to mono by averaging channels."""
    if waveform.shape[0] > 1:
        # Average across channels
        mono_waveform = torch.mean(waveform, dim=0, keepdim=True)
    else:
        # Already mono
        mono_waveform = waveform
    return mono_waveform


def stretch_tensor(tensor: torch.Tensor, target_length: int) -> torch.Tensor:
    """Resample a (channels, samples) tensor to exactly target_length samples."""
    # F.interpolate with mode='linear' expects a 3D (batch, channels, length) input,
    # so add a batch dimension and drop it afterwards. Passing size= instead of a
    # computed scale_factor guarantees the output length matches exactly.
    tensor = F.interpolate(tensor.unsqueeze(0), size=target_length,
                           mode='linear', align_corners=False)
    return tensor.squeeze(0)


def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128) -> torch.Tensor:
    """Right-pad the last dimension with zeros up to target_length samples."""
    current_length = audio_tensor.shape[-1]
    if current_length < target_length:
        padding_needed = target_length - current_length
        padding_tuple = (0, padding_needed)  # pad only the end of the last dimension
        padded_audio_tensor = F.pad(audio_tensor, padding_tuple, mode='constant', value=0)
    else:
        padded_audio_tensor = audio_tensor
    return padded_audio_tensor


def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 128) -> list[torch.Tensor]:
    """Split the last dimension into chunks of chunk_size samples."""
    if not isinstance(chunk_size, int) or chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer.")

    # Handle scalar tensor edge case if necessary
    if audio_tensor.dim() == 0:
        return [audio_tensor] if audio_tensor.numel() > 0 else []

    # Identify the dimension to split (usually the last one, representing time/samples)
    split_dim = -1
    num_samples = audio_tensor.shape[split_dim]

    if num_samples == 0:
        return []  # Return empty list if the dimension to split is empty

    # Use torch.split to divide the tensor into chunks.
    # It handles the last chunk being potentially smaller automatically.
    chunks = list(torch.split(audio_tensor, chunk_size, dim=split_dim))
    return chunks
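

# A minimal usage sketch, assuming these helpers are chained in the usual
# preprocessing order (downmix -> resample to a fixed length -> pad -> chunk).
# A dummy stereo waveform is used so the sketch runs standalone; the shapes and
# lengths below are illustrative assumptions, not values from the original module.
if __name__ == "__main__":
    waveform = torch.randn(2, 300)  # (channels, samples) stand-in for a loaded file

    mono = stereo_tensor_to_mono(waveform)        # -> shape (1, 300)
    stretched = stretch_tensor(mono, 256)         # -> shape (1, 256)
    padded = pad_tensor(stretched, 384)           # -> shape (1, 384)
    chunks = split_audio(padded, chunk_size=128)  # -> 3 chunks of shape (1, 128)

    print([tuple(c.shape) for c in chunks])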