This repository has been archived by the owner on Aug 10, 2023. It is now read-only.

Commit c6b98da
fix modules.base; support fast_loss
liuqiuhui2015 committed Feb 6, 2021
1 parent 3ce7e4e

Showing 24 changed files with 228 additions and 153 deletions.
2 changes: 1 addition & 1 deletion adv/predict/doc/para/predict_doc_para.py
@@ -22,7 +22,7 @@
 
 def load_fixing(module):
 
-	if "fix_load" in dir(module):
+	if hasattr(module, "fix_load"):
 		module.fix_load()
 
 td = h5py.File(cnfg.test_data, "r")
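
The change above recurs throughout this commit: membership tests against dir(module) are replaced with hasattr(module, ...). The two checks usually agree, but dir() builds and sorts a list of every attribute name on each call, while hasattr() performs a single attribute lookup and also sees attributes served dynamically via __getattr__. A minimal sketch of the difference, using a hypothetical Toy class that is not part of this repository:

	class Toy:

		def fix_load(self):
			pass

	toy = Toy()

	# old style: materializes and scans the full attribute list each time
	print("fix_load" in dir(toy))	# True

	# new style: one attribute lookup; also catches __getattr__ fallbacks
	print(hasattr(toy, "fix_load"))	# True
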
2 changes: 1 addition & 1 deletion adv/predict/predict_ape.py
@@ -22,7 +22,7 @@
 
 def load_fixing(module):
 
-	if "fix_load" in dir(module):
+	if hasattr(module, "fix_load"):
 		module.fix_load()
 
 td = h5py.File(cnfg.test_data, "r")
2 changes: 1 addition & 1 deletion adv/rank/doc/para/rank_loss_para.py
@@ -29,7 +29,7 @@
 
 def load_fixing(module):
 
-	if "fix_load" in dir(module):
+	if hasattr(module, "fix_load"):
 		module.fix_load()
 
 td = h5py.File(sys.argv[2], "r")
2 changes: 1 addition & 1 deletion adv/rank/doc/rank_loss_sent.py
@@ -29,7 +29,7 @@
 
 def load_fixing(module):
 
-	if "fix_load" in dir(module):
+	if hasattr(module, "fix_load"):
 		module.fix_load()
 
 td = h5py.File(sys.argv[2], "r")
6 changes: 3 additions & 3 deletions adv/train/doc/para/train_doc_para.py
@@ -14,7 +14,7 @@
 from utils.fmt.base import tostr, save_states, load_states, pad_id
 from utils.fmt.base4torch import parse_cuda, load_emb
 
-from lrsch import GoogleLR
+from lrsch import GoogleLR as LRScheduler
 from loss.base import LabelSmoothingLoss
 
 from random import shuffle
@@ -176,7 +176,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False):
 
 def init_fixing(module):
 
-	if "fix_init" in dir(module):
+	if hasattr(module, "fix_init"):
 		module.fix_init()
 
 rid = cnfg.run_id
@@ -280,7 +280,7 @@ def init_fixing(module):
 	logger.info("Load optimizer state from: " + fine_tune_state)
 	optimizer.load_state_dict(h5load(fine_tune_state))
 
-lrsch = GoogleLR(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale)
+lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale)
 
 num_checkpoint = cnfg.num_checkpoint
 cur_checkid = 0
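
Aliasing the import as LRScheduler decouples the training scripts from the concrete schedule class: swapping in another scheduler now only touches the import line, since the instantiation already uses the generic name. For context, GoogleLR implements the warmup-then-decay learning rate schedule of Vaswani et al. (2017). The sketch below is an assumption about that formula, based on the paper and on the (optimizer, isize, warm_step, scale) call signature visible in this diff, not the actual contents of lrsch.py:

	def noam_lr(step, isize, warm_step, scale=1.0):
		# linear warmup for warm_step updates, then inverse-square-root decay,
		# both scaled by isize ** -0.5 (the model dimension)
		return scale * (isize ** -0.5) * min(step ** -0.5, step * (warm_step ** -1.5))

	# e.g. isize=512, warm_step=8000: the rate peaks at step 8000, then decays
	for step in (1, 4000, 8000, 16000):
		print(step, noam_lr(step, isize=512, warm_step=8000))
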
6 changes: 3 additions & 3 deletions adv/train/train_ape.py
@@ -14,7 +14,7 @@
 from utils.fmt.base import tostr, save_states, load_states, pad_id
 from utils.fmt.base4torch import parse_cuda, load_emb
 
-from lrsch import GoogleLR
+from lrsch import GoogleLR as LRScheduler
 from loss.base import LabelSmoothingLoss
 
 from random import shuffle
@@ -174,7 +174,7 @@ def hook_lr_update(optm, flags=None):
 
 def init_fixing(module):
 
-	if "fix_init" in dir(module):
+	if hasattr(module, "fix_init"):
 		module.fix_init()
 
 rid = cnfg.run_id
@@ -270,7 +270,7 @@ def init_fixing(module):
 	logger.info("Load optimizer state from: " + fine_tune_state)
 	optimizer.load_state_dict(h5load(fine_tune_state))
 
-lrsch = GoogleLR(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale)
+lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale)
 
 num_checkpoint = cnfg.num_checkpoint
 cur_checkid = 0
6 changes: 3 additions & 3 deletions adv/train/train_dynb.py
@@ -16,7 +16,7 @@
 
 from utils.fmt.base4torch import parse_cuda, load_emb
 
-from lrsch import GoogleLR
+from lrsch import GoogleLR as LRScheduler
 from loss.base import LabelSmoothingLoss
 
 from random import shuffle
@@ -195,7 +195,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False):
 
 def init_fixing(module):
 
-	if "fix_init" in dir(module):
+	if hasattr(module, "fix_init"):
 		module.fix_init()
 
 rid = cnfg.run_id
@@ -291,7 +291,7 @@ def init_fixing(module):
 	logger.info("Load optimizer state from: " + fine_tune_state)
 	optimizer.load_state_dict(h5load(fine_tune_state))
 
-lrsch = GoogleLR(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale)
+lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale)
 
 num_checkpoint = cnfg.num_checkpoint
 cur_checkid = 0
3 changes: 3 additions & 0 deletions cnfg/hyp.py
@@ -17,6 +17,9 @@
 use_k_relative_position = 0
 disable_std_pemb = False
 
+# use the fast implementation of the label smoothing loss; note that it cannot exclude the negative impact of special tokens, like <pad>, on training. `forbidden_indexes` in `cnfg/base.py` shall be set to None to enable this option.
+use_fast_loss = False
+
 # configure maximum batch size w.r.t GPU memory
 max_sentences_gpu = 768
 max_tokens_gpu = 4608
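
The loss.base code behind this flag is not part of the diff shown here, so the following is only a sketch of what a "fast" label smoothing loss typically looks like: the smoothing mass is spread uniformly over the entire vocabulary in closed form, which avoids building a masked target distribution but, as the comment warns, cannot keep probability away from special tokens such as <pad>. All names below are illustrative assumptions, not the repository's API:

	import torch
	import torch.nn.functional as F

	def fast_label_smoothing_loss(logits, target, smoothing=0.1, pad_id=0):
		# logits: (batch, seq, n_class); target: (batch, seq)
		logp = F.log_softmax(logits, dim=-1)
		nll = -logp.gather(-1, target.unsqueeze(-1)).squeeze(-1)
		# uniform smoothing over ALL classes, <pad> included -- this is the
		# speed/correctness trade-off the configuration comment describes
		smooth = -logp.mean(dim=-1)
		loss = (1.0 - smoothing) * nll + smoothing * smooth
		# target-side padding positions can still be masked out of the sum
		return loss.masked_fill(target.eq(pad_id), 0.0).sum()
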