import torch
import torch.nn.functional as F

def stereo_tensor_to_mono(waveform):
    """Downmix a multi-channel waveform to one channel by averaging.

    Dimension 0 is treated as the channel axis (presumably ``(channels, time)``
    audio; TODO confirm against callers). When that axis has more than one
    entry, all channels are averaged with equal weight and dimension 0 is kept
    with size 1. Otherwise the input tensor itself is returned unchanged.

    Args:
        waveform: Tensor whose first dimension indexes channels.

    Returns:
        A tensor with size 1 along dimension 0, or ``waveform`` itself when it
        has at most one channel.
    """
    channel_count = waveform.shape[0]
    if channel_count <= 1:
        # Nothing to mix: hand back the original object untouched.
        return waveform
    # Equal-weight downmix; keepdim preserves the leading channel axis.
    return waveform.mean(dim=0, keepdim=True)

def stretch_tensor(tensor, target_length):
    """Resample ``tensor`` along its last dimension to ``target_length`` samples.

    Uses linear interpolation (``align_corners=False``) via
    ``torch.nn.functional.interpolate``. Every leading dimension is treated as
    an independent channel, so ``(time,)``, ``(channels, time)`` and
    ``(batch, channels, time)`` inputs are all accepted. The result keeps the
    input's leading dimensions.

    Args:
        tensor: Floating-point tensor with at least one dimension. Its last
            dimension is the one that is stretched and must be non-empty.
        target_length: Desired size of the last dimension. It is converted
            with ``int()`` and must be positive.

    Returns:
        A tensor of shape ``(*tensor.shape[:-1], target_length)``.

    Raises:
        ValueError: If ``tensor`` is 0-dimensional, its last dimension is
            empty, or ``target_length`` is not positive.
    """
    if tensor.dim() == 0:
        raise ValueError("stretch_tensor expects a tensor with at least one dimension")
    source_length = tensor.size(-1)
    if source_length == 0:
        raise ValueError("stretch_tensor cannot resample an empty last dimension")
    target_length = int(target_length)
    if target_length <= 0:
        raise ValueError(f"target_length must be positive, got {target_length}")

    leading_shape = tensor.shape[:-1]
    # 'linear' mode needs a 3-D (N, C, L) input and resamples only the last
    # axis. Fold every leading dim into the channel axis so each row is
    # interpolated independently, whatever the input rank.
    flat = tensor.reshape(1, -1, source_length)
    # Passing ``size`` rather than a float scale factor guarantees the exact
    # output length. floor(L * scale) can otherwise land on target_length - 1.
    stretched = F.interpolate(flat, size=target_length, mode='linear', align_corners=False)
    return stretched.reshape(*leading_shape, target_length)