Merge new-arch, because it has proven to give the best results #1

Merged
NikkeDoy merged 14 commits from new-arch into main 2025-04-30 23:47:41 +03:00
3 changed files with 44 additions and 69 deletions
Showing only changes of commit 416500f7fc - Show all commits

View File

@ -12,8 +12,9 @@ class AudioDataset(Dataset):
#audio_sample_rates = [8000, 11025, 16000, 22050] #audio_sample_rates = [8000, 11025, 16000, 22050]
audio_sample_rates = [11025] audio_sample_rates = [11025]
def __init__(self, input_dir): def __init__(self, input_dir, device):
self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')] self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')]
self.device = device
def __len__(self): def __len__(self):
@ -32,4 +33,7 @@ class AudioDataset(Dataset):
resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate) resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
low_quality_audio = resample_transform_high(low_quality_audio) low_quality_audio = resample_transform_high(low_quality_audio)
return (AudioUtils.stereo_tensor_to_mono(high_quality_audio), original_sample_rate), (AudioUtils.stereo_tensor_to_mono(low_quality_audio), mangled_sample_rate) high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio).to(self.device)
low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio).to(self.device)
return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)

View File

@ -4,11 +4,11 @@ Jinja2==3.1.4
MarkupSafe==2.1.5 MarkupSafe==2.1.5
mpmath==1.3.0 mpmath==1.3.0
networkx==3.4.2 networkx==3.4.2
numpy==2.2.1 numpy==2.2.3
pytorch-triton-rocm==3.2.0+git0d4682f0 pytorch-triton-rocm==3.2.0+git4b3bb1f8
setuptools==70.2.0 setuptools==70.2.0
sympy==1.13.1 sympy==1.13.3
torch==2.6.0.dev20241222+rocm6.2.4 torch==2.7.0.dev20250226+rocm6.3
torchaudio==2.6.0.dev20241222+rocm6.2.4 torchaudio==2.6.0.dev20250226+rocm6.3
tqdm==4.67.1 tqdm==4.67.1
typing_extensions==4.12.2 typing_extensions==4.12.2

View File

@ -20,49 +20,35 @@ from data import AudioDataset
from generator import SISUGenerator from generator import SISUGenerator
from discriminator import SISUDiscriminator from discriminator import SISUDiscriminator
import librosa import torchaudio.transforms as T
import numpy as np
def mfcc_loss(y_true, y_pred, sr): # Init script argument parser
"""Calculates MFCC loss between two audio signals. parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
help="Path to the discriminator model file")
parser.add_argument("--device", type=str, default="cpu", help="Select device")
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
Args: args = parser.parse_args()
y_true: Target audio signal (PyTorch tensor).
y_pred: Predicted audio signal (PyTorch tensor).
sr: Sample rate (NumPy scalar).
Returns: device = torch.device(args.device if torch.cuda.is_available() else "cpu")
MFCC loss (PyTorch tensor). print(f"Using device: {device}")
"""
# 1. Ensure sr is a NumPy scalar (not a Tensor) mfcc_transform = T.MFCC(
if isinstance(sr, torch.Tensor): sample_rate=16000, # Adjust to your sample rate
sr = sr.item() n_mfcc=20,
melkwargs={'n_fft': 2048, 'hop_length': 512} # adjust n_fft and hop_length to your needs.
# 2. Convert y_true and y_pred to NumPy arrays (and detach from graph) ).to(device)
y_true_np = y_true.cpu().detach().numpy()[0] # .cpu() and .detach() are crucial!
y_pred_np = y_pred.cpu().detach().numpy()[0]
# 3. Dynamically calculate n_fft based on signal length
signal_length = min(y_true_np.shape[0], y_pred_np.shape[0]) # Use shortest signal length
n_fft = min(2048, 2**int(np.log2(signal_length))) # Power of 2, up to 2048
# 4. Calculate MFCCs using adjusted n_fft
mfccs_true = librosa.feature.mfcc(y=y_true_np, sr=sr, n_fft=n_fft, n_mfcc=20)
mfccs_pred = librosa.feature.mfcc(y=y_pred_np, sr=sr, n_fft=n_fft, n_mfcc=20)
# 5. Truncate MFCCs to the same length (important!)
len_true = mfccs_true.shape[1]
len_pred = mfccs_pred.shape[1]
min_len = min(len_true, len_pred)
mfccs_true = mfccs_true[:, :min_len]
mfccs_pred = mfccs_pred[:, :min_len]
# 6. Convert MFCCs back to PyTorch tensors and ensure correct device
mfccs_true = torch.tensor(mfccs_true, device=y_true.device, dtype=torch.float32)
mfccs_pred = torch.tensor(mfccs_pred, device=y_pred.device, dtype=torch.float32)
def gpu_mfcc_loss(y_true, y_pred):
mfccs_true = mfcc_transform(y_true)
mfccs_pred = mfcc_transform(y_pred)
min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
mfccs_true = mfccs_true[:, :, :min_len]
mfccs_pred = mfccs_pred[:, :, :min_len]
return torch.mean((mfccs_true - mfccs_pred)**2) return torch.mean((mfccs_true - mfccs_pred)**2)
def discriminator_train(high_quality, low_quality, real_labels, fake_labels): def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
@ -93,7 +79,7 @@ def generator_train(low_quality, high_quality, real_labels):
# Forward pass for fake samples (from generator output) # Forward pass for fake samples (from generator output)
generator_output = generator(low_quality[0]) generator_output = generator(low_quality[0])
mfcc_l = mfcc_loss(high_quality[0], generator_output, high_quality[1]) mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)
discriminator_decision = discriminator(generator_output) discriminator_decision = discriminator(generator_output)
adversarial_loss = criterion_g(discriminator_decision, real_labels) adversarial_loss = criterion_g(discriminator_decision, real_labels)
@ -105,26 +91,11 @@ def generator_train(low_quality, high_quality, real_labels):
return (generator_output, combined_loss, adversarial_loss, mfcc_l) return (generator_output, combined_loss, adversarial_loss, mfcc_l)
# Init script argument parser
parser = argparse.ArgumentParser(description="Training script")
parser.add_argument("--generator", type=str, default=None,
help="Path to the generator model file")
parser.add_argument("--discriminator", type=str, default=None,
help="Path to the discriminator model file")
parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
args = parser.parse_args()
# Check for CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
debug = args.verbose debug = args.verbose
# Initialize dataset and dataloader # Initialize dataset and dataloader
dataset_dir = './dataset/good' dataset_dir = './dataset/good'
dataset = AudioDataset(dataset_dir) dataset = AudioDataset(dataset_dir, device)
# ========= MULTIPLE ========= # ========= MULTIPLE =========
@ -147,14 +118,14 @@ discriminator = SISUDiscriminator()
epoch: int = args.epoch epoch: int = args.epoch
if args.generator is not None:
generator.load_state_dict(torch.load(args.generator, weights_only=True))
if args.discriminator is not None:
discriminator.load_state_dict(torch.load(args.discriminator, weights_only=True))
generator = generator.to(device) generator = generator.to(device)
discriminator = discriminator.to(device) discriminator = discriminator.to(device)
if args.generator is not None:
generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
if args.discriminator is not None:
discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))
# Loss # Loss
criterion_g = nn.MSELoss() criterion_g = nn.MSELoss()
criterion_d = nn.BCELoss() criterion_d = nn.BCELoss()
@ -182,8 +153,8 @@ def start_training():
# ========= TRAINING ========= # ========= TRAINING =========
for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"): for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
# for high_quality_clip, low_quality_clip in train_data_loader: # for high_quality_clip, low_quality_clip in train_data_loader:
high_quality_sample = (high_quality_clip[0].to(device), high_quality_clip[1]) high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
low_quality_sample = (low_quality_clip[0].to(device), low_quality_clip[1]) low_quality_sample = (low_quality_clip[0], low_quality_clip[1])
# ========= LABELS ========= # ========= LABELS =========
batch_size = high_quality_clip[0].size(0) batch_size = high_quality_clip[0].size(0)