import torch
import torch.nn.functional as F


def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
    """
    Convert stereo (C, N) to mono (1, N). Ensures a channel dimension.
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)  # (N,) -> (1, N)

    if waveform.shape[0] > 1:
        mono_waveform = torch.mean(waveform, dim=0, keepdim=True)  # (1, N)
    else:
        mono_waveform = waveform
    return mono_waveform
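# Illustrative usage (not part of the original module; shapes such as a
# 2-channel, 16000-sample clip are assumed for the example only):
# >>> stereo = torch.randn(2, 16000)
# >>> stereo_tensor_to_mono(stereo).shape
# torch.Size([1, 16000])
# >>> stereo_tensor_to_mono(torch.randn(16000)).shape  # 1-D input gains a channel dim
# torch.Size([1, 16000])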


def stretch_tensor(tensor: torch.Tensor, target_length: int) -> torch.Tensor:
    """
    Stretch audio along the time dimension to target_length.
    Input assumed (1, N). Returns (1, target_length).
    """
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)  # ensure (1, N)

    tensor = tensor.unsqueeze(0)  # (1, 1, N) for interpolate
    stretched = F.interpolate(
        tensor, size=target_length, mode="linear", align_corners=False
    )
    return stretched.squeeze(0)  # back to (1, target_length)
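# Illustrative usage sketch (lengths assumed, not from the original module):
# >>> clip = torch.randn(1, 100)
# >>> stretch_tensor(clip, 128).shape  # linearly resampled along time
# torch.Size([1, 128])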


def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 128) -> torch.Tensor:
    """
    Pad or crop to a fixed length. Input assumed (1, N). Returns (1, target_length).
    """
    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)

    current_length = audio_tensor.shape[-1]
    if current_length < target_length:
        padding_needed = target_length - current_length
        padding_tuple = (0, padding_needed)  # pad only the end of the time axis
        padded_audio_tensor = F.pad(
            audio_tensor, padding_tuple, mode="constant", value=0
        )
    else:
        padded_audio_tensor = audio_tensor[..., :target_length]  # crop if too long

    return padded_audio_tensor
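# Illustrative usage sketch (shapes assumed for this example only):
# >>> pad_tensor(torch.randn(1, 100), target_length=128).shape  # zero-padded at the end
# torch.Size([1, 128])
# >>> pad_tensor(torch.randn(1, 200), target_length=128).shape  # cropped to the target
# torch.Size([1, 128])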


def split_audio(
    audio_tensor: torch.Tensor, chunk_size: int = 128
) -> list[torch.Tensor]:
    """
    Split into chunks of shape (1, chunk_size); the final chunk may be shorter.
    """
    if not isinstance(chunk_size, int) or chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer.")

    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)

    num_samples = audio_tensor.shape[-1]
    if num_samples == 0:
        return []

    chunks = list(torch.split(audio_tensor, chunk_size, dim=-1))
    return chunks
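# Illustrative usage sketch (sample counts assumed): 300 samples with chunk_size=128
# yield chunks of 128, 128 and 44 samples; the trailing chunk is not padded here.
# >>> parts = split_audio(torch.randn(1, 300), chunk_size=128)
# >>> [p.shape[-1] for p in parts]
# [128, 128, 44]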


def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
    """
    Reconstruct audio from chunks. Returns (1, N).
    """
    if not chunks:
        return torch.empty(1, 0)

    # Ensure every chunk has a channel dimension before concatenating.
    chunks = [c if c.dim() == 2 else c.unsqueeze(0) for c in chunks]
    try:
        reconstructed_tensor = torch.cat(chunks, dim=-1)
    except RuntimeError as e:
        raise RuntimeError(
            "Failed to concatenate audio chunks. Ensure chunks have compatible shapes "
            f"for concatenation along dim -1. Original error: {e}"
        ) from e

    return reconstructed_tensor
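# Illustrative round trip (assumed shapes): splitting and re-concatenating
# should reproduce the original tensor exactly, since no padding is involved.
# >>> audio = torch.randn(1, 300)
# >>> torch.equal(reconstruct_audio(split_audio(audio)), audio)
# True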


def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Peak-normalize audio to the range [-1, 1]. Near-silent input is returned unchanged.
    """
    max_val = torch.max(torch.abs(audio_tensor))
    if max_val < eps:
        return audio_tensor  # silence, skip normalization
    return audio_tensor / max_val
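# Minimal end-to-end sketch (hypothetical values; not part of the original module):
# fold a stereo clip to mono, stretch it, normalize it, then chunk and rebuild it.
if __name__ == "__main__":
    stereo = torch.randn(2, 1_000)          # assumed 2-channel test signal
    mono = stereo_tensor_to_mono(stereo)    # (1, 1000)
    stretched = stretch_tensor(mono, 512)   # (1, 512)
    normalized = normalize(stretched)       # peak-normalized to [-1, 1]
    rebuilt = reconstruct_audio(split_audio(normalized, chunk_size=128))
    print(mono.shape, stretched.shape, rebuilt.shape)
    print(torch.equal(rebuilt, normalized))  # round trip preserves the data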