:albemic: | More tests.
This commit is contained in:
parent 54338e55a9
commit f928d8c2cf
@@ -28,7 +28,7 @@ class AttentionBlock(nn.Module):
         return x * attention_weights

 class SISUDiscriminator(nn.Module):
-    def __init__(self, layers=64): #Increased base layer count
+    def __init__(self, layers=4): #Increased base layer count
         super(SISUDiscriminator, self).__init__()
         self.model = nn.Sequential(
             discriminator_block(1, layers, kernel_size=3, stride=1), #Aggressive downsampling
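Note: the only change in this hunk is the base channel count, layers=64 down to layers=4, which thins out every discriminator_block whose width derives from layers (the inline comment still reads "Increased base layer count" even though the value shrinks). The helper itself is outside this diff; a hypothetical stand-in, consistent only with the call discriminator_block(1, layers, kernel_size=3, stride=1) shown above, could look like:

    import torch.nn as nn

    # Hypothetical stand-in for the real helper (not part of this commit);
    # it only needs to accept (in_channels, out_channels, kernel_size, stride).
    def discriminator_block(in_channels, out_channels, kernel_size=3, stride=1):
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size,
                      stride=stride, padding=kernel_size // 2),
            nn.LeakyReLU(0.2, inplace=True),
        )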
@@ -34,7 +34,7 @@ class ResidualInResidualBlock(nn.Module):
         return x + residual

 class SISUGenerator(nn.Module):
-    def __init__(self, layer=64, num_rirb=4): #increased base layer and rirb amounts
+    def __init__(self, layer=4, num_rirb=4): #increased base layer and rirb amounts
         super(SISUGenerator, self).__init__()
         self.conv1 = nn.Sequential(
             nn.Conv1d(1, layer, kernel_size=7, padding=3),
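Note: since conv1 starts with nn.Conv1d(1, layer, kernel_size=7, padding=3), the layer argument is simply the number of feature maps the generator's stem produces, so dropping it from 64 to 4 makes the model much thinner. A quick self-contained shape check (batch size and sample count are arbitrary placeholders):

    import torch
    import torch.nn as nn

    x = torch.randn(8, 1, 16000)                      # hypothetical mono audio batch
    y = nn.Conv1d(1, 4, kernel_size=7, padding=3)(x)  # layer=4 as in this commit
    assert y.shape == (8, 4, 16000)                   # kernel 7 with padding 3 keeps the length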
training.py (22 lines changed)
@@ -30,7 +30,7 @@ parser.add_argument("--discriminator", type=str, default=None,
                     help="Path to the discriminator model file")
 parser.add_argument("--device", type=str, default="cpu", help="Select device")
 parser.add_argument("--epoch", type=int, default=0, help="Current epoch for model versioning")
-parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
+parser.add_argument("--debug", action="store_true", help="Print debug logs")

 args = parser.parse_args()

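Note: with --verbose renamed to --debug, invocations that want the extra logging need the new flag; for example (device value is a placeholder):

    python training.py --device cuda --debug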
@@ -80,19 +80,20 @@ def generator_train(low_quality, high_quality, real_labels):
     # Forward pass for fake samples (from generator output)
     generator_output = generator(low_quality[0])

-    mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)
+    #mfcc_l = gpu_mfcc_loss(high_quality[0], generator_output)

     discriminator_decision = discriminator(generator_output)
     adversarial_loss = criterion_g(discriminator_decision, real_labels)

-    combined_loss = adversarial_loss + 0.5 * mfcc_l
+    #combined_loss = adversarial_loss + 0.5 * mfcc_l

-    combined_loss.backward()
+    adversarial_loss.backward()
     optimizer_g.step()

-    return (generator_output, combined_loss, adversarial_loss, mfcc_l)
+    #return (generator_output, combined_loss, adversarial_loss, mfcc_l)
+    return (generator_output, adversarial_loss)

-debug = args.verbose
+debug = args.debug

 # Initialize dataset and dataloader
 dataset_dir = './dataset/good'
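Note: taken together, the + lines leave generator_train driven by the adversarial term alone. A consolidated sketch of the function after this commit, reconstructed from the hunk above (the zero_grad call and anything else living in the unchanged context lines is an assumption, not shown here):

    def generator_train(low_quality, high_quality, real_labels):
        optimizer_g.zero_grad()  # assumed to sit in the unchanged context above this hunk

        # Forward pass for fake samples (from generator output)
        generator_output = generator(low_quality[0])

        # MFCC loss is disabled in this commit, so only the GAN term drives the update
        discriminator_decision = discriminator(generator_output)
        adversarial_loss = criterion_g(discriminator_decision, real_labels)

        adversarial_loss.backward()
        optimizer_g.step()

        return (generator_output, adversarial_loss)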
@@ -100,7 +101,7 @@ dataset = AudioDataset(dataset_dir, device)

 # ========= SINGLE =========

-train_data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
+train_data_loader = DataLoader(dataset, batch_size=256, shuffle=True)

 # Initialize models and move them to device
 generator = SISUGenerator()
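Note: raising batch_size from 16 to 256 means each epoch takes 16x fewer optimizer steps, at the cost of more memory per step. Assuming the DataLoader default drop_last=False, the step count per epoch is simply:

    import math
    steps_per_epoch = math.ceil(len(dataset) / 256)   # equals len(train_data_loader)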
@@ -157,12 +158,13 @@ def start_training():

         # ========= GENERATOR =========
         generator.train()
-        generator_output, combined_loss, adversarial_loss, mfcc_l = generator_train(low_quality_sample, high_quality_sample, real_labels)
+        #generator_output, combined_loss, adversarial_loss, mfcc_l = generator_train(low_quality_sample, high_quality_sample, real_labels)
+        generator_output, adversarial_loss = generator_train(low_quality_sample, high_quality_sample, real_labels)

         if debug:
-            print(d_loss, combined_loss, adversarial_loss, mfcc_l)
+            print(d_loss, adversarial_loss)
         scheduler_d.step(d_loss)
-        #scheduler_g.step(combined_loss)
+        scheduler_g.step(adversarial_loss)

         # ========= SAVE LATEST AUDIO =========
         high_quality_audio = (high_quality_clip[0][0], high_quality_clip[1][0])
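Note: scheduler_g.step(adversarial_loss) passes a metric into step(), which matches the ReduceLROnPlateau calling convention rather than the epoch-based one. Assuming that is the scheduler in use (its construction is not part of this diff, and the optimizer settings below are hypothetical), the pairing would look roughly like:

    import torch

    # Assumption: scheduler/optimizer construction is not shown in this commit;
    # this only illustrates the step(metric) convention used above.
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)
    scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode="min", patience=5)

    # later, once per iteration or epoch inside the training loop:
    scheduler_g.step(adversarial_loss.item())   # LR drops when the metric stops improving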