13 Commits

7 changed files with 261 additions and 179 deletions

AudioUtils.py (new file)

@@ -0,0 +1,18 @@
import torch
import torch.nn.functional as F

def stereo_tensor_to_mono(waveform):
    if waveform.shape[0] > 1:
        # Average across channels
        mono_waveform = torch.mean(waveform, dim=0, keepdim=True)
    else:
        # Already mono
        mono_waveform = waveform
    return mono_waveform

def stretch_tensor(tensor, target_length):
    scale_factor = target_length / tensor.size(1)
    tensor = F.interpolate(tensor, scale_factor=scale_factor, mode='linear', align_corners=False)
    return tensor
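
A quick usage sketch for the new helper (not part of the diff; shapes assumed). stereo_tensor_to_mono collapses a (channels, samples) waveform to (1, samples) by averaging; stretch_tensor is not referenced anywhere else in this change set, so only the mono helper is shown:

import torch
import AudioUtils

stereo = torch.randn(2, 44100)                   # (channels, samples)
mono = AudioUtils.stereo_tensor_to_mono(stereo)
print(mono.shape)                                # torch.Size([1, 44100])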

data.py

@@ -1,49 +1,53 @@
 from torch.utils.data import Dataset
 import torch.nn.functional as F
+import torch
 import torchaudio
 import os
 import random
+import torchaudio.transforms as T
+import AudioUtils

 class AudioDataset(Dataset):
-    audio_sample_rates = [8000, 11025, 16000, 22050]
+    audio_sample_rates = [11025]
+    MAX_LENGTH = 44100 # Define your desired maximum length here

-    def __init__(self, input_dir, target_duration=None, padding_mode='constant', padding_value=0.0):
-        self.input_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.wav')]
-        self.target_duration = target_duration # Duration in seconds or None if not set
-        self.padding_mode = padding_mode
-        self.padding_value = padding_value
+    def __init__(self, input_dir, device):
+        self.input_files = [os.path.join(root, f) for root, _, files in os.walk(input_dir) for f in files if f.endswith('.wav')]
+        self.device = device

     def __len__(self):
         return len(self.input_files)

     def __getitem__(self, idx):
-        # Load high-quality audio
         high_quality_audio, original_sample_rate = torchaudio.load(self.input_files[idx], normalize=True)

-        # Generate low-quality audio with random downsampling
         mangled_sample_rate = random.choice(self.audio_sample_rates)
-        resample_transform = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
-        low_quality_audio = resample_transform(high_quality_audio)
+        resample_transform_low = torchaudio.transforms.Resample(original_sample_rate, mangled_sample_rate)
+        low_quality_audio = resample_transform_low(high_quality_audio)

-        # Calculate target length based on desired duration and 16000 Hz
-        # if self.target_duration is not None:
-        #     target_length = int(self.target_duration * 44100)
-        # else:
-        #     # Calculate duration of original high quality audio
-        #     target_length = high_quality_wav.size(1)
+        resample_transform_high = torchaudio.transforms.Resample(mangled_sample_rate, original_sample_rate)
+        low_quality_audio = resample_transform_high(low_quality_audio)

-        # Pad both to the calculated target length
-        # high_quality_wav = self.stretch_tensor(high_quality_wav, target_length)
-        # low_quality_wav = self.stretch_tensor(low_quality_wav, target_length)
+        high_quality_audio = AudioUtils.stereo_tensor_to_mono(high_quality_audio)
+        low_quality_audio = AudioUtils.stereo_tensor_to_mono(low_quality_audio)

+        # Pad or truncate high-quality audio
+        if high_quality_audio.shape[1] < self.MAX_LENGTH:
+            padding = self.MAX_LENGTH - high_quality_audio.shape[1]
+            high_quality_audio = F.pad(high_quality_audio, (0, padding))
+        elif high_quality_audio.shape[1] > self.MAX_LENGTH:
+            high_quality_audio = high_quality_audio[:, :self.MAX_LENGTH]

+        # Pad or truncate low-quality audio
+        if low_quality_audio.shape[1] < self.MAX_LENGTH:
+            padding = self.MAX_LENGTH - low_quality_audio.shape[1]
+            low_quality_audio = F.pad(low_quality_audio, (0, padding))
+        elif low_quality_audio.shape[1] > self.MAX_LENGTH:
+            low_quality_audio = low_quality_audio[:, :self.MAX_LENGTH]

+        high_quality_audio = high_quality_audio.to(self.device)
+        low_quality_audio = low_quality_audio.to(self.device)

         return (high_quality_audio, original_sample_rate), (low_quality_audio, mangled_sample_rate)

-    def stretch_tensor(self, tensor, target_length):
-        current_length = tensor.size(1)
-        scale_factor = target_length / current_length
-
-        # Resample the tensor using linear interpolation
-        tensor = F.interpolate(tensor.unsqueeze(0), scale_factor=scale_factor, mode='linear', align_corners=False).squeeze(0)
-
-        return tensor
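
A minimal sketch of how the reworked dataset is consumed (not part of the diff; it assumes ./dataset/good holds .wav files, and the batch size here is illustrative). Every item is now mono, resampled down to 11025 Hz and back to the original rate, padded or truncated to MAX_LENGTH samples, and moved to the given device, which is what makes large batches possible:

import torch
from torch.utils.data import DataLoader
from data import AudioDataset

dataset = AudioDataset("./dataset/good", device=torch.device("cpu"))
(hq, hq_rate), (lq, lq_rate) = dataset[0]
print(hq.shape, lq.shape)        # torch.Size([1, 44100]) torch.Size([1, 44100])

loader = DataLoader(dataset, batch_size=4, shuffle=True)
(hq_batch, hq_rates), (lq_batch, lq_rates) = next(iter(loader))
print(hq_batch.shape)            # torch.Size([4, 1, 44100])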

discriminator.py

@@ -1,24 +1,58 @@
+import torch
 import torch.nn as nn
+import torch.nn.utils as utils

+def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, spectral_norm=True):
+    padding = (kernel_size // 2) * dilation
+    conv_layer = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding)
+    if spectral_norm:
+        conv_layer = utils.spectral_norm(conv_layer)
+    return nn.Sequential(
+        conv_layer,
+        nn.LeakyReLU(0.2, inplace=True),
+        nn.BatchNorm1d(out_channels)
+    )

+class AttentionBlock(nn.Module):
+    def __init__(self, channels):
+        super(AttentionBlock, self).__init__()
+        self.attention = nn.Sequential(
+            nn.Conv1d(channels, channels // 4, kernel_size=1),
+            nn.ReLU(),
+            nn.Conv1d(channels // 4, channels, kernel_size=1),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        attention_weights = self.attention(x)
+        return x * attention_weights

 class SISUDiscriminator(nn.Module):
-    def __init__(self):
+    def __init__(self, layers=4): #Increased base layer count
         super(SISUDiscriminator, self).__init__()
         self.model = nn.Sequential(
-            nn.Conv1d(2, 128, kernel_size=3, padding=1),
-            #nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 256, kernel_size=3, padding=1),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(256, 128, kernel_size=3, padding=1),
-            #nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 64, kernel_size=3, padding=1),
-            #nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(64, 1, kernel_size=3, padding=1),
-            #nn.LeakyReLU(0.2, inplace=True),
+            discriminator_block(1, layers, kernel_size=3, stride=1), #Aggressive downsampling
+            discriminator_block(layers, layers * 2, kernel_size=5, stride=2),
+            discriminator_block(layers * 2, layers * 4, kernel_size=5, dilation=4),
+            #AttentionBlock(layers * 4), #Added attention
+            #discriminator_block(layers * 4, layers * 8, kernel_size=5, dilation=4),
+            #AttentionBlock(layers * 8), #Added attention
+            #discriminator_block(layers * 8, layers * 16, kernel_size=5, dilation=8),
+            #discriminator_block(layers * 16, layers * 16, kernel_size=3, dilation=1),
+            #discriminator_block(layers * 16, layers * 8, kernel_size=3, dilation=2),
+            #discriminator_block(layers * 8, layers * 4, kernel_size=3, dilation=1),
+            discriminator_block(layers * 4, layers * 2, kernel_size=5, stride=2),
+            discriminator_block(layers * 2, layers, kernel_size=3, stride=1),
+            discriminator_block(layers, 1, kernel_size=3, stride=1, spectral_norm=False) #last layer no spectral norm.
         )
-        self.global_avg_pool = nn.AdaptiveAvgPool1d(1) # Output size (1,)
-        self.sigmoid = nn.Sigmoid()
+        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)

     def forward(self, x):
         x = self.model(x)
         x = self.global_avg_pool(x)
-        x = x.view(-1, 1) # Flatten to (batch_size, 1)
-        x = self.sigmoid(x)
+        x = x.view(-1, 1)
         return x
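
With test.py removed, a minimal replacement sanity check for the now single-channel discriminator could look like the following (not part of the diff; the input length is assumed to match MAX_LENGTH from data.py). Note the final sigmoid was dropped, so the output is an unbounded score:

import torch
from discriminator import SISUDiscriminator

discriminator = SISUDiscriminator(layers=4)
test_input = torch.randn(1, 1, 44100)   # (batch_size, channels, samples), mono now
output = discriminator(test_input)
print(output, output.shape)             # shape: torch.Size([1, 1])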

generator.py

@@ -1,27 +1,52 @@
 import torch.nn as nn

+def conv_block(in_channels, out_channels, kernel_size=3, dilation=1):
+    return nn.Sequential(
+        nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, dilation=dilation, padding=(kernel_size // 2) * dilation),
+        nn.BatchNorm1d(out_channels),
+        nn.PReLU()
+    )

+class AttentionBlock(nn.Module):
+    def __init__(self, channels):
+        super(AttentionBlock, self).__init__()
+        self.attention = nn.Sequential(
+            nn.Conv1d(channels, channels // 4, kernel_size=1),
+            nn.ReLU(),
+            nn.Conv1d(channels // 4, channels, kernel_size=1),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        attention_weights = self.attention(x)
+        return x * attention_weights

+class ResidualInResidualBlock(nn.Module):
+    def __init__(self, channels, num_convs=3):
+        super(ResidualInResidualBlock, self).__init__()
+        self.conv_layers = nn.Sequential(*[conv_block(channels, channels) for _ in range(num_convs)])
+        self.attention = AttentionBlock(channels)
+
+    def forward(self, x):
+        residual = x
+        x = self.conv_layers(x)
+        x = self.attention(x)
+        return x + residual

 class SISUGenerator(nn.Module):
-    def __init__(self, upscale_scale=1): # No noise_dim parameter
+    def __init__(self, layer=4, num_rirb=4): #increased base layer and rirb amounts
         super(SISUGenerator, self).__init__()
-        self.layers1 = nn.Sequential(
-            nn.Conv1d(2, 128, kernel_size=3, padding=1),
-            # nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 256, kernel_size=3, padding=1),
-            # nn.LeakyReLU(0.2, inplace=True),
+        self.conv1 = nn.Sequential(
+            nn.Conv1d(1, layer, kernel_size=7, padding=3),
+            nn.BatchNorm1d(layer),
+            nn.PReLU(),
         )
+        self.rir_blocks = nn.Sequential(*[ResidualInResidualBlock(layer) for _ in range(num_rirb)])
+        self.final_layer = nn.Conv1d(layer, 1, kernel_size=3, padding=1)

-        self.layers2 = nn.Sequential(
-            nn.Conv1d(256, 128, kernel_size=3, padding=1),
-            # nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(128, 64, kernel_size=3, padding=1),
-            # nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv1d(64, 2, kernel_size=3, padding=1),
-            # nn.Tanh()
-        )
-
-    def forward(self, x, scale):
-        x = self.layers1(x)
-        upsample = nn.Upsample(scale_factor=scale, mode='nearest')
-        x = upsample(x)
-        x = self.layers2(x)
-        return x
+    def forward(self, x):
+        residual = x
+        x = self.conv1(x)
+        x = self.rir_blocks(x)
+        x = self.final_layer(x)
+        return x + residual
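
The generator no longer takes a scale argument: the low-quality clip is already resampled back to the original rate in data.py, so input and output lengths match and the network learns a residual correction. A quick shape check (not part of the diff; sizes assumed):

import torch
from generator import SISUGenerator

generator = SISUGenerator(layer=4, num_rirb=4)
noisy = torch.randn(1, 1, 44100)    # (batch_size, channels, samples)
restored = generator(noisy)
print(restored.shape)               # torch.Size([1, 1, 44100]), same length as the input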

requirements.txt

@@ -1,12 +1,14 @@
-filelock>=3.16.1
-fsspec>=2024.10.0
-Jinja2>=3.1.4
-MarkupSafe>=2.1.5
-mpmath>=1.3.0
-networkx>=3.4.2
-numpy>=2.1.2
-pillow>=11.0.0
-setuptools>=70.2.0
-sympy>=1.13.1
-tqdm>=4.67.1
-typing_extensions>=4.12.2
+filelock==3.16.1
+fsspec==2024.10.0
+Jinja2==3.1.4
+MarkupSafe==2.1.5
+mpmath==1.3.0
+networkx==3.4.2
+numpy==2.2.3
+pytorch-triton-rocm==3.2.0+git4b3bb1f8
+setuptools==70.2.0
+sympy==1.13.3
+torch==2.7.0.dev20250226+rocm6.3
+torchaudio==2.6.0.dev20250226+rocm6.3
+tqdm==4.67.1
+typing_extensions==4.12.2

test.py (deleted)

@@ -1,10 +0,0 @@
import torch.nn as nn
import torch
from discriminator import SISUDiscriminator
discriminator = SISUDiscriminator()
test_input = torch.randn(1, 2, 1000) # Example input (batch_size, channels, frames)
output = discriminator(test_input)
print(output)
print("Output shape:", output.shape)

training.py

@@ -6,95 +6,120 @@ import torch.nn.functional as F
 import torchaudio
 import tqdm
+import argparse
+import math
+import os

 from torch.utils.data import random_split
 from torch.utils.data import DataLoader

+import AudioUtils
 from data import AudioDataset
 from generator import SISUGenerator
 from discriminator import SISUDiscriminator

-# Mel Spectrogram Loss
-class MelSpectrogramLoss(nn.Module):
-    def __init__(self, sample_rate=44100, n_fft=2048, hop_length=512, n_mels=128):
-        super(MelSpectrogramLoss, self).__init__()
-        self.mel_transform = torchaudio.transforms.MelSpectrogram(
-            sample_rate=sample_rate,
-            n_fft=n_fft,
-            hop_length=hop_length,
-            n_mels=n_mels
-        ).to(device) # Move to device
-
-    def forward(self, y_pred, y_true):
-        mel_pred = self.mel_transform(y_pred)
-        mel_true = self.mel_transform(y_true)
-        return F.l1_loss(mel_pred, mel_true)
-
-def snr(y_true, y_pred):
-    noise = y_true - y_pred
-    signal_power = torch.mean(y_true ** 2)
-    noise_power = torch.mean(noise ** 2)
-    snr_db = 10 * torch.log10(signal_power / noise_power)
-    return snr_db
+import torchaudio.transforms as T

+# Init script argument parser
+parser = argparse.ArgumentParser(description="Training script")
+parser.add_argument("--generator", type=str, default=None,
+                    help="Path to the generator model file")
+parser.add_argument("--discriminator", type=str, default=None,
+                    help="Path to the discriminator model file")
+parser.add_argument("--device", type=str, default="cpu", help="Select device")
+parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
+parser.add_argument("--debug", action="store_true", help="Print debug logs")

+args = parser.parse_args()

+device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")

+mfcc_transform = T.MFCC(
+    sample_rate=44100,
+    n_mfcc=20,
+    melkwargs={'n_fft': 2048, 'hop_length': 256}
+).to(device)

+def gpu_mfcc_loss(y_true, y_pred):
+    mfccs_true = mfcc_transform(y_true)
+    mfccs_pred = mfcc_transform(y_pred)
+    min_len = min(mfccs_true.shape[2], mfccs_pred.shape[2])
+    mfccs_true = mfccs_true[:, :, :min_len]
+    mfccs_pred = mfccs_pred[:, :, :min_len]
+    loss = torch.mean((mfccs_true - mfccs_pred)**2)
+    return loss

-def discriminator_train(high_quality, low_quality, scale, real_labels, fake_labels):
+def discriminator_train(high_quality, low_quality, real_labels, fake_labels):
     optimizer_d.zero_grad()

-    discriminator_decision_from_real = discriminator(high_quality)
-    # TODO: Experiment with criterions HERE!
+    # Forward pass for real samples
+    discriminator_decision_from_real = discriminator(high_quality[0])
     d_loss_real = criterion_d(discriminator_decision_from_real, real_labels)

-    generator_output = generator(low_quality, scale)
+    # Forward pass for fake samples (from generator output)
+    generator_output = generator(low_quality[0])
     discriminator_decision_from_fake = discriminator(generator_output.detach())
-    # TODO: Experiment with criterions HERE!
     d_loss_fake = criterion_d(discriminator_decision_from_fake, fake_labels)

+    # Combine real and fake losses
     d_loss = (d_loss_real + d_loss_fake) / 2.0

+    # Backward pass and optimization
     d_loss.backward()
-    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) #Gradient Clipping
+    nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0) # Gradient Clipping
     optimizer_d.step()

     return d_loss

-def generator_train(low_quality, scale, real_labels):
+def generator_train(low_quality, high_quality, real_labels):
     optimizer_g.zero_grad()

-    generator_output = generator(low_quality, scale)
+    # Forward pass for fake samples (from generator output)
+    generator_output = generator(low_quality[0])

+    #mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)

     discriminator_decision = discriminator(generator_output)
-    # TODO: Fix this shit
-    g_loss = criterion_g(discriminator_decision, real_labels)
+    adversarial_loss = criterion_g(discriminator_decision, real_labels)

-    g_loss.backward()
+    #combined_loss = adversarial_loss + 0.5 * mfcc_l

+    adversarial_loss.backward()
     optimizer_g.step()

-    return generator_output
+    #return (generator_output, combined_loss, adversarial_loss, mfcc_l)
+    return (generator_output, adversarial_loss)

-# Check for CUDA availability
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print(f"Using device: {device}")
+debug = args.debug

 # Initialize dataset and dataloader
 dataset_dir = './dataset/good'
-dataset = AudioDataset(dataset_dir, target_duration=2.0)
+dataset = AudioDataset(dataset_dir, device)

-dataset_size = len(dataset)
-train_size = int(dataset_size * .9)
-val_size = int(dataset_size-train_size)
+# ========= SINGLE =========

-train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
-train_data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
-val_data_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)
+train_data_loader = DataLoader(dataset, batch_size=256, shuffle=True)

 # Initialize models and move them to device
 generator = SISUGenerator()
 discriminator = SISUDiscriminator()

+epoch: int = args.epoch

 generator = generator.to(device)
 discriminator = discriminator.to(device)

+if args.generator is not None:
+    generator.load_state_dict(torch.load(args.generator, map_location=device, weights_only=True))
+if args.discriminator is not None:
+    discriminator.load_state_dict(torch.load(args.discriminator, map_location=device, weights_only=True))

 # Loss
-criterion_g = nn.L1Loss()
-criterion_g_mel = MelSpectrogramLoss().to(device)
-criterion_d = nn.BCEWithLogitsLoss()
+criterion_g = nn.MSELoss()
+criterion_d = nn.BCELoss()

 # Optimizers
 optimizer_g = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
@@ -104,44 +129,23 @@ optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.99
 scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=5)
 scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=5)

+models_dir = "models"
+os.makedirs(models_dir, exist_ok=True)

 def start_training():
+    generator_epochs = 5000

-    # Training loop
-    # ========= DISCRIMINATOR PRE-TRAINING =========
-    discriminator_epochs = 1
-    for discriminator_epoch in range(discriminator_epochs):
-
-        # ========= TRAINING =========
-        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {discriminator_epoch+1}/{discriminator_epochs}"):
-            high_quality_sample = high_quality_clip[0].to(device)
-            low_quality_sample = low_quality_clip[0].to(device)
-
-            scale = high_quality_clip[0].shape[2]/low_quality_clip[0].shape[2]
-
-            # ========= LABELS =========
-            batch_size = high_quality_sample.size(0)
-            real_labels = torch.ones(batch_size, 1).to(device)
-            fake_labels = torch.zeros(batch_size, 1).to(device)
-
-            # ========= DISCRIMINATOR =========
-            discriminator.train()
-            discriminator_train(high_quality_sample, low_quality_sample, scale, real_labels, fake_labels)
-
-    torch.save(discriminator.state_dict(), "models/discriminator-single-shot-pre-train.pt")
-
-    generator_epochs = 500
     for generator_epoch in range(generator_epochs):
         low_quality_audio = (torch.empty((1)), 1)
         high_quality_audio = (torch.empty((1)), 1)
         ai_enhanced_audio = (torch.empty((1)), 1)

-        # ========= TRAINING =========
-        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Epoch {generator_epoch+1}/{generator_epochs}"):
-            high_quality_sample = high_quality_clip[0].to(device)
-            low_quality_sample = low_quality_clip[0].to(device)
-
-            scale = high_quality_clip[0].shape[2]/low_quality_clip[0].shape[2]
+        times_correct = 0

+        # ========= TRAINING =========
+        for high_quality_clip, low_quality_clip in tqdm.tqdm(train_data_loader, desc=f"Training epoch {generator_epoch+1}/{generator_epochs}, Current epoch {epoch+1}"):
+        # for high_quality_clip, low_quality_clip in train_data_loader:
+            high_quality_sample = (high_quality_clip[0], high_quality_clip[1])
+            low_quality_sample = (low_quality_clip[0], low_quality_clip[1])

             # ========= LABELS =========
             batch_size = high_quality_clip[0].size(0)
@@ -150,34 +154,39 @@ def start_training():
             # ========= DISCRIMINATOR =========
             discriminator.train()
-            for _ in range(3):
-                discriminator_train(high_quality_sample, low_quality_sample, scale, real_labels, fake_labels)
+            d_loss = discriminator_train(high_quality_sample, low_quality_sample, real_labels, fake_labels)

             # ========= GENERATOR =========
             generator.train()
-            generator_output = generator_train(low_quality_sample, scale, real_labels)
+            #generator_output, combined_loss, adversarial_loss, mfcc_l = generator_train(low_quality_sample, high_quality_sample, real_labels)
+            generator_output, adversarial_loss = generator_train(low_quality_sample, high_quality_sample, real_labels)

+            if debug:
+                print(d_loss, adversarial_loss)
+            scheduler_d.step(d_loss)
+            scheduler_g.step(adversarial_loss)

             # ========= SAVE LATEST AUDIO =========
-            high_quality_audio = high_quality_clip
-            low_quality_audio = low_quality_clip
-            ai_enhanced_audio = (generator_output, high_quality_clip[1])
-
-            metric = snr(high_quality_audio[0].to(device), ai_enhanced_audio[0])
-            print(f"Generator metric {metric}!")
-            scheduler_g.step(metric)
+            high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
+            low_quality_audio = (low_quality_clip[0][0], low_quality_clip[1][0])
+            ai_enhanced_audio = (generator_output[0], high_quality_clip[1][0])

+        new_epoch = generator_epoch+epoch

         if generator_epoch % 10 == 0:
-            print(f"Saved epoch {generator_epoch}!")
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-crap.wav", low_quality_audio[0][0].cpu(), low_quality_audio[1])
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-ai.wav", ai_enhanced_audio[0][0].cpu(), ai_enhanced_audio[1])
-            torchaudio.save(f"./output/epoch-{generator_epoch}-audio-orig.wav", high_quality_audio[0][0].cpu(), high_quality_audio[1])
-
-        if generator_epoch % 50 == 0:
-            torch.save(discriminator.state_dict(), f"models/epoch-{generator_epoch}-discriminator.pt")
-            torch.save(generator.state_dict(), f"models/epoch-{generator_epoch}-generator.pt")
+            print(f"Saved epoch {new_epoch}!")
+            torchaudio.save(f"./output/epoch-{new_epoch}-audio-crap.wav", low_quality_audio[0].cpu(), high_quality_audio[1]) # <-- Because audio clip was resampled in data.py from original to crap and to original again.
+            torchaudio.save(f"./output/epoch-{new_epoch}-audio-ai.wav", ai_enhanced_audio[0].cpu(), ai_enhanced_audio[1])
+            torchaudio.save(f"./output/epoch-{new_epoch}-audio-orig.wav", high_quality_audio[0].cpu(), high_quality_audio[1])

+            if debug:
+                print(generator.state_dict().keys())
+                print(discriminator.state_dict().keys())
+            torch.save(discriminator.state_dict(), f"{models_dir}/discriminator_epoch_{new_epoch}.pt")
+            torch.save(generator.state_dict(), f"{models_dir}/generator_epoch_{new_epoch}.pt")

-    torch.save(discriminator.state_dict(), "models/epoch-500-discriminator.pt")
-    torch.save(generator.state_dict(), "models/epoch-500-generator.pt")
+    torch.save(discriminator, "models/epoch-5000-discriminator.pt")
+    torch.save(generator, "models/epoch-5000-generator.pt")
     print("Training complete!")

 start_training()
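
For completeness, a minimal inference sketch using the state-dict checkpoints this script writes (the checkpoint path, input file name, and output file name are assumptions, not part of the diff): load a generator checkpoint the same way training.py does, convert an input clip to mono with AudioUtils, and save the enhanced result.

import torch
import torchaudio
import AudioUtils
from generator import SISUGenerator

generator = SISUGenerator()
state = torch.load("models/generator_epoch_10.pt", map_location="cpu", weights_only=True)
generator.load_state_dict(state)
generator.eval()

waveform, sample_rate = torchaudio.load("input.wav", normalize=True)
waveform = AudioUtils.stereo_tensor_to_mono(waveform)

with torch.no_grad():
    enhanced = generator(waveform.unsqueeze(0)).squeeze(0)   # (1, 1, samples) -> (1, samples)

torchaudio.save("enhanced.wav", enhanced, sample_rate)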