import torch
import torch.nn.functional as F


def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
    """
    Convert stereo (C, N) to mono (1, N). Ensures a channel dimension.
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)  # (N,) -> (1, N)
    if waveform.shape[0] > 1:
        mono_waveform = torch.mean(waveform, dim=0, keepdim=True)  # (1, N)
    else:
        mono_waveform = waveform
    return mono_waveform


def stretch_tensor(tensor: torch.Tensor, target_length: int) -> torch.Tensor:
    """
    Stretch audio along the time dimension to target_length.
    Input assumed (1, N). Returns (1, target_length).
    """
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)  # ensure (1, N)
    tensor = tensor.unsqueeze(0)  # (1, 1, N) for interpolate
    stretched = F.interpolate(
        tensor, size=target_length, mode="linear", align_corners=False
    )
    return stretched.squeeze(0)  # back to (1, target_length)


def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128) -> torch.Tensor:
    """
    Pad to a fixed length. Input assumed (1, N). Returns (1, target_length).
    """
    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)
    current_length = audio_tensor.shape[-1]
    if current_length < target_length:
        padding_needed = target_length - current_length
        padding_tuple = (0, padding_needed)
        padded_audio_tensor = F.pad(
            audio_tensor, padding_tuple, mode="constant", value=0
        )
    else:
        padded_audio_tensor = audio_tensor[..., :target_length]  # crop if too long
    return padded_audio_tensor


def split_audio(
    audio_tensor: torch.Tensor, chunk_size: int = 128
) -> list[torch.Tensor]:
    """
    Split into chunks of (1, chunk_size). The final chunk may be shorter.
    """
    if not isinstance(chunk_size, int) or chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer.")
    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)
    num_samples = audio_tensor.shape[-1]
    if num_samples == 0:
        return []
    chunks = list(torch.split(audio_tensor, chunk_size, dim=-1))
    return chunks


def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
    """
    Reconstruct audio from chunks. Returns (1, N).
    """
    if not chunks:
        return torch.empty(1, 0)
    chunks = [c if c.dim() == 2 else c.unsqueeze(0) for c in chunks]
    try:
        reconstructed_tensor = torch.cat(chunks, dim=-1)
    except RuntimeError as e:
        raise RuntimeError(
            f"Failed to concatenate audio chunks. Ensure chunks have compatible shapes "
            f"for concatenation along dim -1. Original error: {e}"
        ) from e
    return reconstructed_tensor


def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Peak-normalize to [-1, 1]. Near-silent input is returned unchanged.
    """
    max_val = torch.max(torch.abs(audio_tensor))
    if max_val < eps:
        return audio_tensor  # silence, skip normalization
    return audio_tensor / max_val
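

# --- Usage sketch (illustrative, not part of the library) ---
# A minimal end-to-end example of the helpers above, assuming a stereo input
# of arbitrary length. The random waveform, its (2, 1000) shape, and the
# 128-sample chunk size are assumptions for demonstration only.
if __name__ == "__main__":
    stereo = torch.randn(2, 1000)             # fake stereo signal (2, N)

    mono = stereo_tensor_to_mono(stereo)      # (1, 1000)
    stretched = stretch_tensor(mono, 512)     # (1, 512), time-axis resample

    chunks = split_audio(mono, chunk_size=128)
    # torch.split leaves a short final chunk; pad every chunk to the fixed size.
    chunks = [pad_tensor(c, target_length=128) for c in chunks]

    restored = reconstruct_audio(chunks)      # (1, 1024) after padding
    restored = normalize(restored)            # peak-normalized to [-1, 1]

    print(mono.shape, stretched.shape, len(chunks), restored.shape)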