Skip to content

Commit

Permalink
adding frequency splits network
Browse files Browse the repository at this point in the history
  • Loading branch information
JNaranjo-Alcazar committed Dec 4, 2020
1 parent 9d6fe66 commit 5355205
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 12 deletions.
12 changes: 11 additions & 1 deletion code/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# data paths
path = '/repos/DCASE2021-Task1/data/gammatone_64/'
path = '/repos/DCASE2021-Task1/data/gammatone_256/'

# audio representation hyperparameters
freq_bands = 64
Expand All @@ -16,6 +16,10 @@
dense_layer = None
dropouts_rate_cl = None

split_freqs = True
n_split_freqs = 3
f_split_freqs = [64, 128]

# callbacks
# Early Stopping
early_stopping = True
Expand All @@ -34,6 +38,12 @@
patience_lr_on_plateau = None
min_lr_on_plateau = None

# Save models and csvs
save_outputs = True
best_model_path = '/repos/DCASE2021-Task1/outputs/best.h5'
last_model_path = '/repos/DCASE2021-Task1/outputs/last.h5'
log_path = '/repos/DCASE2021-Task1/outputs/log.csv'

# training hyperparameters
quick_test = True
loss_type = 'focal_loss'
Expand Down
20 changes: 14 additions & 6 deletions code/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
from data_augmentation import MixupGenerator
from focal_loss import categorical_focal_loss
from load_data import load_h5s
from models import res_conv_standard_post_csse
from models import res_conv_standard_post_csse, res_conv_standard_post_csse_split_freqs
from tests import (check_reshape_variable, check_model_depth, check_alpha_list, check_loss_type, check_data_generator,
check_training_verbose, is_boolean, check_callbacks)


# check config options
check_reshape_variable(config.reshape_method)
check_model_depth(config.n_filters, config.pools_size, config.dropouts_rate)
Expand All @@ -24,10 +23,19 @@
print('Validation shape: {}'.format(val_x.shape))

# creating model
model = res_conv_standard_post_csse(X.shape[1], X.shape[2], X.shape[3], Y.shape[1],
config.n_filters, config.pools_size, config.dropouts_rate, config.ratio,
config.reshape_method, config.dense_layer,
verbose=True)

if config.split_freqs is not True:
model = res_conv_standard_post_csse(X.shape[1], X.shape[2], X.shape[3], Y.shape[1],
config.n_filters, config.pools_size, config.dropouts_rate, config.ratio,
config.reshape_method, config.dense_layer,
verbose=config.verbose)

else:
model = res_conv_standard_post_csse_split_freqs(X.shape[1], X.shape[2], X.shape[3], Y.shape[1],
config.n_filters, config.pools_size, config.dropouts_rate,
config.ratio,
config.reshape_method, config.dense_layer,
config.n_split_freqs, config.f_split_freqs, verbose=config.verbose)

# checking focal loss if necessary
if config.loss_type == 'focal_loss':
Expand Down
78 changes: 76 additions & 2 deletions code/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import keras.layers
from modules import network_module
from modules import network_module, freq_split

from keras.models import Model

Expand Down Expand Up @@ -43,4 +43,78 @@ def res_conv_standard_post_csse(h, w, n_channels, n_classes,

return model

#TODO trident model

def res_conv_standard_post_csse_split_freqs(h, w, n_channels, n_classes,
                                            nfilters, pools_size, dropouts_rate, ratio, reshape_type, dense_layer,
                                            n_split_freqs, f_split_freqs, verbose=False):
    """Build a residual cSSE model with separate towers per frequency band.

    The input spectrogram is split along the frequency axis (axis 1) at the
    cutoffs in ``f_split_freqs``; each band is processed by its own stack of
    ``network_module`` blocks (towers do not share weights, since each call
    builds fresh layers), and the tower outputs are concatenated before the
    classification head.

    :param h: input height (frequency bins)
    :param w: input width (time frames)
    :param n_channels: number of input channels
    :param n_classes: number of output classes
    :param nfilters: list of filter counts, one per network_module stage
    :param pools_size: list of pooling sizes, parallel to nfilters
    :param dropouts_rate: list of dropout rates, parallel to nfilters
    :param ratio: squeeze-excitation ratio passed to network_module
    :param reshape_type: 'global_avg', 'flatten' or 'global_max'
    :param dense_layer: if None, a Dense+BN+softmax head is appended
    :param n_split_freqs: number of frequency bands
    :param f_split_freqs: list of n_split_freqs - 1 frequency cutoffs
    :param verbose: if True, print the model summary
    :return: a compiled-ready keras Model
    """
    ip = keras.layers.Input(shape=(h, w, n_channels))

    # Bug fix: the original 2-split branch did
    #   keras.layers.Lambda(freq_split)(ip, n_split_freqs, f_split_freqs)
    # which passes the split parameters to the layer's __call__ instead of to
    # the wrapped function, failing at build time. Both cases now use the
    # `arguments` mapping, as the 3-split branch already did, which also
    # removes the duplicated branch bodies.
    splits = keras.layers.Lambda(
        freq_split,
        arguments={'n_split_freqs': n_split_freqs, 'f_split_freqs': f_split_freqs})(ip)

    # One tower of network_module stages per frequency band.
    towers = []
    for split in splits:
        x_t = split
        for i in range(len(nfilters)):
            x_t = network_module(x_t, nfilters[i], ratio, pools_size[i], dropouts_rate[i])
        towers.append(x_t)

    # Re-join the bands along the frequency axis.
    x = keras.layers.concatenate(towers, axis=1)

    # Collapse the convolutional feature maps to a vector.
    if reshape_type == 'global_avg':
        x = keras.layers.GlobalAveragePooling2D()(x)
    elif reshape_type == 'flatten':
        x = keras.layers.Flatten()(x)
    elif reshape_type == 'global_max':
        x = keras.layers.GlobalMaxPooling2D()(x)

    # NOTE(review): as in res_conv_standard_post_csse, only dense_layer=None
    # adds a head; a non-None value currently appends nothing — confirm intent.
    if dense_layer is None:
        x = keras.layers.Dense(n_classes)(x)
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.Activation('softmax')(x)

    model = Model(ip, x)

    if verbose:
        print(model.summary())

    return model
15 changes: 15 additions & 0 deletions code/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,18 @@ def network_module(inp, nfilters, ratio, pool_size, dropout_rate):
x = Dropout(dropout_rate)(x)

return x


def freq_split(inp, n_split_freqs, f_split_freqs):
    """Split a 4-D tensor (batch, freq, time, channels) along the frequency axis.

    Generalized from the original 2/3-way-only version: any n_split_freqs >= 1
    is now supported (the original silently returned None for other values).
    Behavior for n_split_freqs in {2, 3} is unchanged.

    :param inp: tensor/array indexed as (batch, freq, time, channels)
    :param n_split_freqs: number of bands to produce
    :param f_split_freqs: n_split_freqs - 1 ascending frequency-bin cutoffs
    :return: list of n_split_freqs slices covering the full frequency axis
    """
    if n_split_freqs - len(f_split_freqs) != 1:
        raise ValueError('Number of split frequencies and frequencies cutoff do not match.')

    # Band i spans [bounds[i], bounds[i+1]); the final None runs to the end.
    bounds = [0] + list(f_split_freqs) + [None]
    return [inp[:, bounds[i]:bounds[i + 1], :, :] for i in range(n_split_freqs)]
26 changes: 23 additions & 3 deletions code/tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import config
from callbacks import lr_on_plateau, early_stopping, GetLRAfterEpoch

import os
import keras

def check_reshape_variable(reshape_method):
possible_options = ['global_avg', 'global_max', 'flatten']
Expand Down Expand Up @@ -60,7 +61,8 @@ def is_boolean(inp):

def check_callbacks():

if config.early_stopping is not True and config.get_lr_after_epoch is not True and config.factor_lr_on_plateau is not True:
if (config.early_stopping is not True and config.get_lr_after_epoch is not True
and config.factor_lr_on_plateau and config.save_outputs is not True):
return None
else:
if config.early_stopping is True:
Expand Down Expand Up @@ -90,7 +92,25 @@ def check_callbacks():
else:
lr_onplt = []

callbacks = [es, get_lr, lr_onplt]
if config.save_outputs is True:
home_path = os.getenv('HOME')
save_best = keras.callbacks.ModelCheckpoint(home_path + config.best_model_path, save_best_only=True,
monitor='val_categorical_accuracy')
save = keras.callbacks.ModelCheckpoint(home_path + config.last_model_path)
csv_log = keras.callbacks.CSVLogger(home_path + config.log_path)

else:
save_best = []
save = []
csv_log = []

callbacks = [es, get_lr, lr_onplt, save_best, save, csv_log]
callbacks = list(filter(None, callbacks))

return callbacks


def check_split_freqs(split_freqs, n_split_freqs, f_split_freqs):
    """Validate that the frequency-split config is internally consistent.

    When splitting is enabled, there must be exactly one more band than
    cutoff frequencies; otherwise an Exception is raised. Does nothing when
    splitting is disabled.
    """
    if split_freqs is not True:
        return
    # n bands require exactly n - 1 cutoff points.
    if len(f_split_freqs) != n_split_freqs - 1:
        raise Exception('Number of split frequencies and frequencies cutoff do not match.')

0 comments on commit 5355205

Please sign in to comment.