⚗️ | Added some stupid ways for training + some makeup
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 
 
@@ -52,7 +53,7 @@ class ResidualInResidualBlock(nn.Module):
 
 
 class SISUGenerator(nn.Module):
-    def __init__(self, channels=16, num_rirb=4, alpha=1.0):
+    def __init__(self, channels=16, num_rirb=4, alpha=1):
         super(SISUGenerator, self).__init__()
         self.alpha = alpha
 
@@ -66,7 +67,9 @@ class SISUGenerator(nn.Module):
             *[ResidualInResidualBlock(channels) for _ in range(num_rirb)]
         )
 
-        self.final_layer = nn.Conv1d(channels, 1, kernel_size=3, padding=1)
+        self.final_layer = nn.Sequential(
+            nn.Conv1d(channels, 1, kernel_size=3, padding=1), nn.Tanh()
+        )
 
     def forward(self, x):
         residual_input = x
@@ -75,4 +78,4 @@ class SISUGenerator(nn.Module):
         learned_residual = self.final_layer(x_rirb_out)
         output = residual_input + self.alpha * learned_residual
 
-        return output
+        return torch.tanh(output)
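Net effect of this commit, as a minimal sketch: the learned residual is now squashed to (-1, 1) by the nn.Tanh inside final_layer, scaled by alpha, and added to the input, and that sum is squashed a second time by torch.tanh before being returned, so the generator output is always bounded to (-1, 1). The snippet below reproduces only the changed path; ResidualInResidualBlock does not appear in this diff, so a dummy tensor stands in for the RIRB stack output, and the tensor shapes are illustrative assumptions.

import torch
import torch.nn as nn

channels = 16  # matches the __init__ default in the diff
alpha = 1      # matches the new default (was 1.0)

# The new final_layer from the diff: conv down to 1 channel, then squash.
final_layer = nn.Sequential(
    nn.Conv1d(channels, 1, kernel_size=3, padding=1), nn.Tanh()
)

x = torch.randn(1, 1, 2048)            # (batch, channels, samples) dummy audio
x_rirb_out = x.repeat(1, channels, 1)  # assumed stand-in for the RIRB stack output

learned_residual = final_layer(x_rirb_out)  # bounded to (-1, 1) by nn.Tanh
output = x + alpha * learned_residual       # residual sum can exceed (-1, 1)
output = torch.tanh(output)                 # second squash, as in the new return

print(output.min().item(), output.max().item())  # always within (-1, 1)

Note the double tanh: the residual passes through nn.Tanh and the sum through torch.tanh again, which keeps audio samples in range but flattens gradients near saturation; that trade-off may be what the commit message's "stupid ways for training" refers to.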