add RawNet wrapper
ljuvela committed Jul 25, 2023
1 parent 1df873b commit 4efcaed
Showing 4 changed files with 166 additions and 22 deletions.
26 changes: 16 additions & 10 deletions train_watermark.py
@@ -43,7 +43,8 @@ def train(rank, a, h):

watermark = WatermarkModelEnsemble(
model_type=h.watermark_model,
sample_rate=h.sampling_rate
sample_rate=h.sampling_rate,
config=h
).to(device)

if rank == 0:
@@ -81,6 +82,17 @@ def train(rank, a, h):
msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
watermark = DistributedDataParallel(watermark, device_ids=[rank]).to(device)

if a.pretrained_watermark_path is not None:
state_dict = torch.load(a.pretrained_watermark_path)
watermark.load_pretrained_state_dict(state_dict)

if a.freeze_watermark_weights:
for param in watermark.parameters():
param.requires_grad_(False)
# unfreeze output layer
# watermark.output_layer_requires_grad_(True)
# watermark.eval()

optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
h.learning_rate, betas=[h.adam_b1, h.adam_b2])
@@ -204,18 +216,9 @@ def train(rank, a, h):
wm_losses_f.append(wm_loss_r)
wm_losses_r.append(wm_loss_f)

# Collaborator (watermark), Generator is aligned with Discriminator
# y_df_hat_wm_r, y_df_hat_wm_g, _, _ = mpd_watermark(y, y_g_hat)
# loss_disc_f_wm, losses_disc_f_wm_r, losses_disc_f_wm_g = discriminator_loss(
# disc_real_outputs=y_df_hat_wm_r, disc_generated_outputs=y_df_hat_wm_g)
# y_ds_hat_wm_r, y_ds_hat_wm_g, _, _ = msd_watermark(y, y_g_hat)
# loss_disc_s_wm, losses_disc_s_wm_r, losses_disc_s_wm_g = discriminator_loss(
# disc_real_outputs=y_ds_hat_wm_r, disc_generated_outputs=y_ds_hat_wm_g)

# Adversarial (S, F), Feature matching (S, F), Mel, Collaborative
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel + loss_wm_total


loss_gen_all.backward()
optim_g.step()
optim_wm.step()
@@ -274,6 +277,8 @@ def train(rank, a, h):
for label, losses in zip(watermark.get_labels(), wm_losses_f):
sw.add_scalar(f"training_watermark/{label}_fake", sum(losses), steps)

# TODO: log minibatch EER

# Validation
if steps % a.validation_interval == 0: # and steps != 0:
generator.eval()
@@ -373,6 +378,7 @@ def main():
parser.add_argument('--input_validation_file', default='LJSpeech-1.1/validation.txt')
parser.add_argument('--checkpoint_path', default='cp_hifigan')
parser.add_argument('--pretrained_watermark_path', default=None)
parser.add_argument('--freeze_watermark_weights', default=False, type=bool)
parser.add_argument('--config', default='')
parser.add_argument('--training_epochs', default=3100, type=int)
parser.add_argument('--stdout_interval', default=5, type=int)
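A note on the new --freeze_watermark_weights option: argparse's type=bool converts any non-empty string to True, so --freeze_watermark_weights False would still enable freezing. A minimal sketch of a more predictable flag definition, not part of this commit:

# Sketch only: a store_true flag defaults to False and becomes True when passed,
# avoiding the type=bool pitfall where any non-empty string parses as True.
parser.add_argument('--freeze_watermark_weights', action='store_true',
                    help='freeze all watermark model weights during training')
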
121 changes: 113 additions & 8 deletions wrappers/models.py
@@ -19,10 +19,17 @@ def import_module_from_file(name, filepath):
name='lfcc',
filepath=os.path.realpath(f"{__file__}/../../third_party/asvspoof-2021/LA/Baseline-LFCC-LCNN/project/baseline_LA/model.py"))

rawnet = import_module_from_file(
name='rawnet',
filepath=os.path.realpath(f"{__file__}/../../third_party/asvspoof-2021/LA/Baseline-RawNet2/model.py")
)

class LFCC_LCNN(lfcc.Model):

def __init__(self, in_dim, out_dim, sample_rate):
def __init__(self, in_dim, out_dim,
sample_rate,
sigmoid_output=True,
dropout_prob=0.7):
"""
Args:
in_dim: input dimension, default 1 for single channel wav
@@ -36,14 +43,17 @@ def __init__(self, in_dim, out_dim, sample_rate):
super().__init__(
in_dim, out_dim,
args=args, prj_conf=prj_conf,
mean_std=mean_std)
mean_std=mean_std,
dropout_prob=dropout_prob)

self.sample_rate = sample_rate
if self.sample_rate != 16000:
self.resampler = Resample(orig_freq=self.sample_rate, new_freq=16000)
else:
self.resampler = None

self.sigmoid_out = sigmoid_output

def forward(self, x):
"""
Args:
@@ -64,26 +74,121 @@ def forward(self, x):

feature_vec = self._compute_embedding(x[:, 0, :], datalength=None)
# return feature_vec
scores = self._compute_score(feature_vec)
scores = self._compute_score(feature_vec, inference=(not self.sigmoid_out))
scores = scores.reshape(-1, 1)
return scores


if __name__ == "__main__":
class RawNet(rawnet.RawNet):

def __init__(
self,
sample_rate=16000,
first_conv=1024,
in_channels=1,
filts=[20, [20, 20], [20, 128], [128, 128]],
nb_fc_node=1024,
gru_node=1024,
nb_gru_layer=3,
nb_classes=2,
device=torch.device('cpu')):
"""
Args:
nb_samp: ?
first_conv: no. of filter coefficients
in_channels: ?
filts: no. of filters channel in residual blocks
nb_fc_node: ?
gru_node: ?
nb_gru_layer: ?
nb_classes: ?
"""
d_args = {
# 'nb_samp': nb_samp,
'first_conv': first_conv,
'in_channels': in_channels,
'filts': filts,
# 'blocks': blocks,
'nb_fc_node': nb_fc_node,
'gru_node': gru_node,
'nb_gru_layer': nb_gru_layer,
'nb_classes': nb_classes
}

super().__init__(d_args=d_args, device=device)

self.sample_rate = sample_rate
if self.sample_rate != 16000:
self.resampler = Resample(orig_freq=self.sample_rate, new_freq=16000)
else:
self.resampler = None


model = LFCC_LCNN(in_dim=1, out_dim=1)
model = model.eval()
def forward(self, x):
"""
Args:
x: (batch, channels=1, length)
Returns:
scores: (batch, length=1)
"""

if x.ndim != 3:
raise ValueError(f"Expected input of shape (batch, channels=1, timestesps), got {x.shape}")
if x.size(1) != 1:
raise ValueError(f"Expected single channel input, got {x.shape}")

if self.resampler is not None:
x = self.resampler(x)

log_out = super().forward(x[:, 0, :])

# slice from (batch, num_classes) -> (batch, 1)
log_out = log_out[:, 0:1]

return torch.exp(log_out)


def test_lfcc_lcnn():

model = LFCC_LCNN(in_dim=1, out_dim=1, sample_rate=16000)

batch = 2
timesteps = 16000
channels = 1
x = 0.1 * torch.randn(batch, timesteps, requires_grad=True)
x = 0.1 * torch.randn(batch, 1, timesteps, requires_grad=True)
x = torch.nn.Parameter(x)

scores = model.forward(x)

# check that gradients pass
scores.sum().backward()

print(f"{x.grad}")

assert x.grad is not None





if __name__ == "__main__":

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

batch = 3
timesteps = 16000
channels = 1
x = 0.1 * torch.randn(batch, 1, timesteps, requires_grad=True)
x = torch.nn.Parameter(x)
x_dev = x.to(device)

model = RawNet(sample_rate=22050)
model = model.to(device)

scores = model.forward(x_dev)
scores.pow(2).sum().backward()

print(f"{x.grad}")
assert x.grad is not None
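
For symmetry with test_lfcc_lcnn, the RawNet smoke test in the __main__ block above could also be wrapped in a function. A minimal CPU-only sketch, with the hypothetical name test_rawnet (not part of this commit):

def test_rawnet():
    # Gradient-flow check analogous to test_lfcc_lcnn: the wrapper resamples
    # 22.05 kHz input to 16 kHz and returns (batch, 1) scores for class 0.
    model = RawNet(sample_rate=22050)
    batch, timesteps = 2, 16000
    x = torch.nn.Parameter(0.1 * torch.randn(batch, 1, timesteps))
    scores = model.forward(x)
    scores.pow(2).sum().backward()
    assert x.grad is not None
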
39 changes: 36 additions & 3 deletions wrappers/watermark.py
@@ -2,22 +2,27 @@

from models import MultiPeriodDiscriminator, MultiScaleDiscriminator

from .models import LFCC_LCNN
from .models import LFCC_LCNN, RawNet

class WatermarkModelEnsemble(torch.nn.Module):

def __init__(self, model_type:str, sample_rate:int):
def __init__(self, model_type:str, sample_rate:int, config):
super().__init__()

self.models = torch.nn.ModuleDict()
self.model_type = model_type

if model_type == "hifi_gan":
self.models['mpd'] = MultiPeriodDiscriminator()
self.models['msd'] = MultiScaleDiscriminator()
elif model_type == "lfcc_lcnn":
self.models['lfcc_lcnn'] = LFCC_LCNN(
in_dim=1, out_dim=1,
sample_rate=sample_rate)
sample_rate=sample_rate,
sigmoid_output=config.get('lfcc_lcnn_sigmoid_out', True),
dropout_prob=config.get('lfcc_lcnn_dropout_prob', 0.7))
elif model_type == 'raw_net':
self.models['raw_net'] = RawNet(sample_rate=sample_rate)
elif model_type is None:
pass
else:
@@ -50,3 +55,31 @@ def get_labels(self):
def get_num_models(self):
return len(self.models.keys())


def load_pretrained_state_dict(self, state_dict):

if self.model_type == 'lfcc_lcnn':

state_dict_old = self.models['lfcc_lcnn'].state_dict()

optional_keys = ['m_frontend.0.window', 'resampler.kernel']
for ok in optional_keys:
val = state_dict.get(ok, state_dict_old[ok])
state_dict[ok] = val

self.models['lfcc_lcnn'].load_state_dict(state_dict)
elif self.model_type == 'raw_net':

self.models['raw_net'].load_state_dict(state_dict)
else:
raise NotImplementedError()

def output_layer_requires_grad_(self, requires_grad: bool = True):

if self.model_type == 'lfcc_lcnn':
self.models['lfcc_lcnn'].m_output_act.requires_grad_(requires_grad)
else:
raise NotImplementedError()



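For reference, a minimal sketch of how the new loading and freezing hooks fit together outside the DDP training loop. The checkpoint path, config values, and import path are placeholders, not part of this commit:

# Sketch only: exercise the new WatermarkModelEnsemble hooks on CPU.
import torch
from wrappers.watermark import WatermarkModelEnsemble

config = {'lfcc_lcnn_sigmoid_out': True, 'lfcc_lcnn_dropout_prob': 0.7}  # assumed config keys
watermark = WatermarkModelEnsemble(model_type='lfcc_lcnn', sample_rate=22050, config=config)

state_dict = torch.load('watermark_pretrained.pt', map_location='cpu')  # placeholder path
watermark.load_pretrained_state_dict(state_dict)

# Freeze everything, then optionally unfreeze just the output layer,
# mirroring the commented-out option in train_watermark.py above.
for param in watermark.parameters():
    param.requires_grad_(False)
watermark.output_layer_requires_grad_(True)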