import torch
import torch.nn as nn
import torch.optim as optim

import torchaudio
import torchaudio.transforms as T

def gpu_mfcc_loss(mfcc_transform: T.MFCC, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    mfccs_true = mfcc_transform(y_true)
    mfccs_pred = mfcc_transform(y_pred)

    # Ensure same time dimension length (framing can differ by a frame)
    min_len = min(mfccs_true.shape[-1], mfccs_pred.shape[-1])
    mfccs_true = mfccs_true[..., :min_len]
    mfccs_pred = mfccs_pred[..., :min_len]

    # L2 loss (mean squared error) over the MFCC coefficients
    loss = torch.mean((mfccs_true - mfccs_pred)**2)
    return loss

def mel_spectrogram_l1_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    mel_spec_true = mel_transform(y_true)
    mel_spec_pred = mel_transform(y_pred)

    # Ensure same time dimension length (due to potential framing differences)
    min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
    mel_spec_true = mel_spec_true[..., :min_len]
    mel_spec_pred = mel_spec_pred[..., :min_len]

    # L1 Loss (Mean Absolute Error)
    loss = torch.mean(torch.abs(mel_spec_true - mel_spec_pred))
    return loss

def mel_spectrogram_l2_loss(mel_transform: T.MelSpectrogram, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    mel_spec_true = mel_transform(y_true)
    mel_spec_pred = mel_transform(y_pred)

    min_len = min(mel_spec_true.shape[-1], mel_spec_pred.shape[-1])
    mel_spec_true = mel_spec_true[..., :min_len]
    mel_spec_pred = mel_spec_pred[..., :min_len]

    # L2 Loss (Mean Squared Error)
    loss = torch.mean((mel_spec_true - mel_spec_pred)**2)
    return loss

def log_stft_magnitude_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    # Assumes stft_transform returns linear magnitudes (e.g., T.Spectrogram(power=1.0))
    stft_mag_true = stft_transform(y_true)
    stft_mag_pred = stft_transform(y_pred)

    min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
    stft_mag_true = stft_mag_true[..., :min_len]
    stft_mag_pred = stft_mag_pred[..., :min_len]

    # Log Magnitude L1 Loss
    loss = torch.mean(torch.abs(torch.log(stft_mag_true + eps) - torch.log(stft_mag_pred + eps)))
    return loss

def spectral_convergence_loss(stft_transform: T.Spectrogram, y_true: torch.Tensor, y_pred: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    # Also expects magnitude spectrograms (same transform as the log-magnitude loss)
    stft_mag_true = stft_transform(y_true)
    stft_mag_pred = stft_transform(y_pred)

    min_len = min(stft_mag_true.shape[-1], stft_mag_pred.shape[-1])
    stft_mag_true = stft_mag_true[..., :min_len]
    stft_mag_pred = stft_mag_pred[..., :min_len]

    # Calculate Frobenius norms and the loss
    # Ensure norms are calculated over frequency and time dims ([..., freq, time])
    norm_true = torch.linalg.norm(stft_mag_true, ord='fro', dim=(-2, -1))
    norm_diff = torch.linalg.norm(stft_mag_true - stft_mag_pred, ord='fro', dim=(-2, -1))

    # Average loss over the batch
    loss = torch.mean(norm_diff / (norm_true + eps))
    return loss
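
# --- Illustrative usage (a minimal sketch, not called anywhere in this file) ---
# Shows one way the transforms above might be constructed and the losses invoked.
# All parameter values (sample rate, n_fft, hop_length, n_mels, n_mfcc) and the
# dummy (batch, time) shapes are placeholder assumptions, not project settings.
def _example_spectral_losses(device: str = "cpu") -> None:
    sample_rate = 22050
    mel_transform = T.MelSpectrogram(sample_rate=sample_rate, n_fft=1024, hop_length=256, n_mels=80).to(device)
    stft_transform = T.Spectrogram(n_fft=1024, hop_length=256, power=1.0).to(device)  # power=1.0 -> linear magnitudes
    mfcc_transform = T.MFCC(sample_rate=sample_rate, n_mfcc=40).to(device)

    y_true = torch.randn(4, sample_rate, device=device)  # dummy target waveforms, shape (batch, time)
    y_pred = torch.randn(4, sample_rate, device=device)  # dummy generated waveforms, shape (batch, time)

    print("mel L1:", mel_spectrogram_l1_loss(mel_transform, y_true, y_pred).item())
    print("mel L2:", mel_spectrogram_l2_loss(mel_transform, y_true, y_pred).item())
    print("log-STFT L1:", log_stft_magnitude_loss(stft_transform, y_true, y_pred).item())
    print("spectral convergence:", spectral_convergence_loss(stft_transform, y_true, y_pred).item())
    print("MFCC MSE:", gpu_mfcc_loss(mfcc_transform, y_true, y_pred).item())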

def discriminator_train(high_quality, low_quality, real_labels, fake_labels, discriminator, generator, criterion, optimizer):
    optimizer.zero_grad()

    # Forward pass for real samples
    discriminator_decision_from_real = discriminator(high_quality[0])
    d_loss_real = criterion(discriminator_decision_from_real, real_labels.expand_as(discriminator_decision_from_real))

    # Forward pass for fake samples (generator output is kept out of the graph)
    with torch.no_grad():
        generator_output = generator(low_quality[0])
    discriminator_decision_from_fake = discriminator(generator_output)  # no detach needed: produced under no_grad
    d_loss_fake = criterion(discriminator_decision_from_fake, fake_labels.expand_as(discriminator_decision_from_fake))

    # Combine real and fake losses
    d_loss = (d_loss_real + d_loss_fake) / 2.0

    # Backward pass and optimization
    d_loss.backward()
    # Optional: gradient clipping (can help stabilize training)
    # nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
    optimizer.step()

    return d_loss

def generator_train(
    low_quality,
    high_quality,
    real_labels,
    generator,
    discriminator,
    adv_criterion, # Criterion for adversarial loss (e.g., BCEWithLogitsLoss)
    g_optimizer,
    device,
    # --- Pass necessary transforms and loss weights ---
    mel_transform: T.MelSpectrogram, # Example: Pass Mel transform
    # stft_transform: T.Spectrogram, # Pass STFT transform if using STFT losses
    # mfcc_transform: T.MFCC,      # Pass MFCC transform if using MFCC loss
    lambda_adv: float = 1.0,       # Weight for adversarial loss
    lambda_mel_l1: float = 10.0,   # Example: Weight for Mel L1 loss
    # lambda_log_stft: float = 0.0, # Set weights > 0 for losses you want to use
    # lambda_mfcc: float = 0.0
):
    g_optimizer.zero_grad()

    # 1. Generate high-quality audio from low-quality input
    generator_output = generator(low_quality[0])

    # 2. Calculate Adversarial Loss (Generator tries to fool discriminator)
    discriminator_decision = discriminator(generator_output)
    # Generator wants discriminator to output "real" labels for its fakes
    adversarial_loss = adv_criterion(discriminator_decision, real_labels.expand_as(discriminator_decision))

    # 3. Calculate Reconstruction/Spectrogram Loss(es)
    # --- Choose and calculate the losses you want to include ---
    mel_l1 = 0.0
    # log_stft_l1 = 0.0
    # mfcc_l = 0.0

    # Calculate Mel L1 Loss if weight is positive
    if lambda_mel_l1 > 0:
        mel_l1 = mel_spectrogram_l1_loss(mel_transform, high_quality[0], generator_output)

    # # Calculate Log STFT L1 Loss if weight is positive
    # if lambda_log_stft > 0:
    #     log_stft_l1 = log_stft_magnitude_loss(stft_transform, high_quality[0], generator_output)

    # # Calculate MFCC Loss if weight is positive
    # if lambda_mfcc > 0:
    #     mfcc_l = gpu_mfcc_loss(mfcc_transform, high_quality[0], generator_output)
    # --- End of Loss Calculation Choices ---

    # 4. Combine Losses
    # Keep the reconstruction terms as tensors (for the weighted sum and for
    # consistent logging) even when they were skipped above
    mel_l1_tensor = torch.tensor(mel_l1, device=device) if isinstance(mel_l1, float) else mel_l1
    # log_stft_l1_tensor = torch.tensor(log_stft_l1, device=device) if isinstance(log_stft_l1, float) else log_stft_l1
    # mfcc_l_tensor = torch.tensor(mfcc_l, device=device) if isinstance(mfcc_l, float) else mfcc_l

    combined_loss = (
        (lambda_adv * adversarial_loss)
        + (lambda_mel_l1 * mel_l1_tensor)
        # + (lambda_log_stft * log_stft_l1_tensor)
        # + (lambda_mfcc * mfcc_l_tensor)
    )

    # 5. Backward Pass and Optimization
    combined_loss.backward()
    # Optional: Gradient Clipping
    # nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
    g_optimizer.step()

    # 6. Return values for logging
    return generator_output, combined_loss, adversarial_loss, mel_l1_tensor
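
# --- Illustrative training step (a minimal sketch, not part of this module) ---
# Shows how discriminator_train and generator_train above might be wired together
# for one batch. The dataloader format, model constructors, label shapes, learning
# rate, and the "audio tensor wrapped in a list" batch convention are placeholder
# assumptions, not the project's actual training loop.
def _example_training_loop(generator, discriminator, dataloader, mel_transform, device):
    criterion = nn.BCEWithLogitsLoss()
    d_optimizer = optim.Adam(discriminator.parameters(), lr=2e-4)
    g_optimizer = optim.Adam(generator.parameters(), lr=2e-4)

    for high_quality, low_quality in dataloader:  # assumed to yield (hq_batch, lq_batch) pairs
        hq_audio = high_quality[0].to(device)  # assumed: waveform tensor at index 0
        lq_audio = low_quality[0].to(device)

        batch_size = hq_audio.shape[0]
        # Assumes the discriminator outputs one logit per example
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        d_loss = discriminator_train([hq_audio], [lq_audio], real_labels, fake_labels,
                                     discriminator, generator, criterion, d_optimizer)
        _, g_loss, adv_loss, mel_l1 = generator_train([lq_audio], [hq_audio], real_labels,
                                                      generator, discriminator, criterion,
                                                      g_optimizer, device, mel_transform)

        print(f"d_loss={d_loss.item():.4f} g_loss={g_loss.item():.4f} "
              f"adv={adv_loss.item():.4f} mel_l1={mel_l1.item():.4f}")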