import torch
import torch.nn.functional as F


def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
    """Down-mix a multi-channel waveform to mono by averaging channels.

    Args:
        waveform: Tensor of shape (channels, samples) — channels on dim 0.

    Returns:
        Tensor of shape (1, samples).
    """
    return torch.mean(waveform, dim=0, keepdim=True)


def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 512) -> torch.Tensor:
    """Right-pad the last (time) dimension with zeros up to ``target_length``.

    Tensors already at or beyond ``target_length`` are returned unchanged
    (no truncation is performed).

    Args:
        audio_tensor: Audio tensor; time is the last dimension.
        target_length: Desired minimum size of the last dimension.

    Returns:
        The padded tensor, or the input tensor itself if no padding is needed.
    """
    padding_amount = target_length - audio_tensor.size(-1)
    if padding_amount <= 0:
        return audio_tensor
    # (0, padding_amount) pads only the right side of the last dimension.
    return F.pad(audio_tensor, (0, padding_amount))


def split_audio(
    audio_tensor: torch.Tensor,
    chunk_size: int = 512,
    pad_last_tensor: bool = False,
) -> list[torch.Tensor]:
    """Split audio along the last (time) dimension into ``chunk_size`` pieces.

    The final chunk may be shorter than ``chunk_size``; pass
    ``pad_last_tensor=True`` to zero-pad it to exactly ``chunk_size``.

    Args:
        audio_tensor: Audio tensor; time is the last dimension.
        chunk_size: Number of samples per chunk.
        pad_last_tensor: Whether to zero-pad the trailing chunk.

    Returns:
        List of chunk tensors, in temporal order.
    """
    # Split on dim=-1 (was dim=1): works for both 1-D waveforms and
    # (channels, samples) tensors, and matches reconstruct_audio's cat dim.
    chunks = list(torch.split(audio_tensor, chunk_size, dim=-1))
    # Guard: a zero-length input can yield no chunks; chunks[-1] would raise.
    if pad_last_tensor and chunks:
        last_chunk = chunks[-1]
        if last_chunk.size(-1) < chunk_size:
            chunks[-1] = pad_tensor(last_chunk, chunk_size)
    return chunks


def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
    """Concatenate chunks back into one tensor along the time dimension.

    Inverse of :func:`split_audio` (when the last chunk was not padded).

    Args:
        chunks: Non-empty list of chunk tensors with matching leading dims.

    Returns:
        The concatenated tensor.

    Raises:
        RuntimeError: If ``chunks`` is empty (``torch.cat`` requirement).
    """
    return torch.cat(chunks, dim=-1)


def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Peak-normalize audio so its maximum absolute value is 1.0.

    Near-silent input (peak below ``eps``) is returned unchanged to avoid
    division blow-up.

    Args:
        audio_tensor: Audio tensor of any shape.
        eps: Silence threshold for the peak amplitude.

    Returns:
        The normalized tensor, or the input itself if effectively silent.
    """
    max_val = torch.max(torch.abs(audio_tensor))
    if max_val < eps:
        return audio_tensor
    return audio_tensor / max_val