42 lines
1.2 KiB
Python
42 lines
1.2 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def stereo_tensor_to_mono(waveform: torch.Tensor) -> torch.Tensor:
    """Collapse a multi-channel waveform to mono by averaging channels.

    Args:
        waveform: Tensor whose first dimension is the channel axis
            (e.g. ``(2, num_samples)`` for stereo).

    Returns:
        Tensor of shape ``(1, num_samples)`` — the per-sample mean across
        channels, with a singleton channel axis kept so the result stays 2-D.
    """
    # keepdim=True preserves the leading channel dimension for downstream code.
    return waveform.mean(dim=0, keepdim=True)
|
|
|
|
|
|
def pad_tensor(audio_tensor: torch.Tensor, target_length: int = 512) -> torch.Tensor:
    """Right-pad the last dimension with zeros up to ``target_length``.

    Args:
        audio_tensor: Audio tensor; padding is applied along the last
            (sample) dimension.
        target_length: Desired sample count after padding.

    Returns:
        A zero-padded tensor of last-dim size ``target_length``, or the
        original tensor unchanged when it is already long enough
        (no truncation is performed).
    """
    deficit = target_length - audio_tensor.size(-1)
    if deficit > 0:
        # (0, deficit) pads only the right side of the last dimension.
        return F.pad(audio_tensor, (0, deficit))
    return audio_tensor
|
|
|
|
|
|
def split_audio(audio_tensor: torch.Tensor, chunk_size: int = 512, pad_last_tensor: bool = False) -> list[torch.Tensor]:
    """Split an audio tensor into fixed-size chunks along the sample axis.

    Args:
        audio_tensor: Audio tensor whose last dimension is the sample axis
            (e.g. ``(channels, num_samples)`` or a 1-D ``(num_samples,)``).
        chunk_size: Number of samples per chunk; the final chunk may be
            shorter unless ``pad_last_tensor`` is set.
        pad_last_tensor: When True, zero-pad the final chunk to exactly
            ``chunk_size`` samples via ``pad_tensor``.

    Returns:
        List of tensors, each with at most ``chunk_size`` samples along
        the last dimension.
    """
    # Split along dim=-1 (not the previous hard-coded dim=1) so this agrees
    # with pad_tensor/reconstruct_audio and also accepts 1-D mono tensors;
    # for the 2-D (channels, samples) case the behavior is identical.
    chunks = list(torch.split(audio_tensor, chunk_size, dim=-1))

    if pad_last_tensor and chunks:
        last_chunk = chunks[-1]
        if last_chunk.size(-1) < chunk_size:
            chunks[-1] = pad_tensor(last_chunk, chunk_size)

    return chunks
|
|
|
|
|
|
def reconstruct_audio(chunks: list[torch.Tensor]) -> torch.Tensor:
    """Concatenate audio chunks back into a single contiguous tensor.

    Args:
        chunks: Non-empty list of tensors to join along the last
            (sample) dimension — typically the output of ``split_audio``.

    Returns:
        One tensor containing all chunks concatenated in order.
    """
    return torch.cat(chunks, dim=-1)
|
|
|
|
|
|
def normalize(audio_tensor: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Peak-normalize audio so the maximum absolute sample becomes 1.0.

    Args:
        audio_tensor: Audio tensor to scale.
        eps: Silence threshold; tensors whose peak is below this are
            returned unchanged to avoid dividing by (near) zero.

    Returns:
        The tensor divided by its peak absolute value, or the original
        tensor when it is effectively silent.
    """
    peak = audio_tensor.abs().max()
    if peak >= eps:
        return audio_tensor / peak
    # Effectively silent — scaling would just amplify numerical noise.
    return audio_tensor
|