Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Dawn-LX committed Apr 29, 2022
1 parent d4fd492 commit 80afefa
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 39 deletions.
37 changes: 19 additions & 18 deletions experiments/grounding_weights/config_.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,34 +32,35 @@

train_dataset_config = dict(
split = "train",
ann_dir = "datasets/vidor-dataset/annotation",
video_dir = "datasets/vidor-dataset/train_videos",
# NOTE classeme_dir & proposal_dir 中的东西还在kgl103上,因为之前直接在103上tocache了
video_feature_dir = "/home/gkf/I3D_Pytorch/I3D_clip_features/clip16_overlap0.5",
classeme_dir = "proposals/miss60_minscore0p3/VidORtrain_freq1_classeme",
ann_dir = "/home/gkf/project/VidVRD_VidOR/vidor-dataset/annotation",
video_dir = "/home/gkf/project/VidVRD_VidOR/vidor-dataset/train_videos",
classeme_dir = "/home/gkf/project/deepSORT/tracking_results/miss60_minscore0p3/VidORtrain_freq1_classeme",
proposal_dir = {
0:"proposals/miss60_minscore0p3/VidORtrain_freq1_part01",
1:"proposals/miss60_minscore0p3/VidORtrain_freq1_part02",
2:"proposals/miss60_minscore0p3/VidORtrain_freq1_part03",
3:"proposals/miss60_minscore0p3/VidORtrain_freq1_part04",
4:"proposals/miss60_minscore0p3/VidORtrain_freq1_part05",
5:"proposals/miss60_minscore0p3/VidORtrain_freq1_part06",
6:"proposals/miss60_minscore0p3/VidORtrain_freq1_part07",
7:"proposals/miss60_minscore0p3/VidORtrain_freq1_part08",
8:"proposals/miss60_minscore0p3/VidORtrain_freq1_part09",
9:"proposals/miss60_minscore0p3/VidORtrain_freq1_part10",
10:"proposals/miss60_minscore0p3/VidORtrain_freq1_part11",
11:"proposals/miss60_minscore0p3/VidORtrain_freq1_part12",
12:"proposals/miss60_minscore0p3/VidORtrain_freq1_part13",
13:"proposals/miss60_minscore0p3/VidORtrain_freq1_part14",
# 1:"proposals/miss60_minscore0p3/VidORtrain_freq1_part02",
# 2:"proposals/miss60_minscore0p3/VidORtrain_freq1_part03",
# 3:"proposals/miss60_minscore0p3/VidORtrain_freq1_part04",
# 4:"proposals/miss60_minscore0p3/VidORtrain_freq1_part05",
# 5:"proposals/miss60_minscore0p3/VidORtrain_freq1_part06",
# 6:"proposals/miss60_minscore0p3/VidORtrain_freq1_part07",
# 7:"proposals/miss60_minscore0p3/VidORtrain_freq1_part08",
# 8:"proposals/miss60_minscore0p3/VidORtrain_freq1_part09",
# 9:"proposals/miss60_minscore0p3/VidORtrain_freq1_part10",
# 10:"proposals/miss60_minscore0p3/VidORtrain_freq1_part11",
# 11:"proposals/miss60_minscore0p3/VidORtrain_freq1_part12",
# 12:"proposals/miss60_minscore0p3/VidORtrain_freq1_part13",
# 13:"proposals/miss60_minscore0p3/VidORtrain_freq1_part14",
},
cache_dir = "datasets/cache",
cache_tag = "MEGAv7",
dim_boxfeature = 1024,
min_frames_th = 15,
max_proposal = 180,
max_preds = 200,
score_th = 0.4
)


train_config = dict(
batch_size = 8,
total_epoch = 80,
Expand Down
11 changes: 6 additions & 5 deletions experiments/grounding_weights/config_bin1.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,9 @@

train_dataset_config = dict(
split = "train",
ann_dir = "datasets/vidor-dataset/annotation",
video_dir = "datasets/vidor-dataset/train_videos",
# NOTE classeme_dir & proposal_dir 中的东西还在kgl103上,因为之前直接在103上tocache了
video_feature_dir = "/home/gkf/I3D_Pytorch/I3D_clip_features/clip16_overlap0.5",
classeme_dir = "proposals/miss60_minscore0p3/VidORtrain_freq1_classeme",
ann_dir = "/home/gkf/project/VidVRD_VidOR/vidor-dataset/annotation",
video_dir = "/home/gkf/project/VidVRD_VidOR/vidor-dataset/train_videos",
classeme_dir = "/home/gkf/project/deepSORT/tracking_results/miss60_minscore0p3/VidORtrain_freq1_classeme",
proposal_dir = {
0:"proposals/miss60_minscore0p3/VidORtrain_freq1_part01",
1:"proposals/miss60_minscore0p3/VidORtrain_freq1_part02",
Expand All @@ -53,13 +51,16 @@
12:"proposals/miss60_minscore0p3/VidORtrain_freq1_part13",
13:"proposals/miss60_minscore0p3/VidORtrain_freq1_part14",
},
cache_dir = "datasets/cache",
cache_tag = "MEGAv7",
dim_boxfeature = 1024,
min_frames_th = 15,
max_proposal = 180,
max_preds = 200,
score_th = 0.4
)


train_config = dict(
batch_size = 8,
total_epoch = 80,
Expand Down
11 changes: 5 additions & 6 deletions experiments/grounding_weights/config_bin5.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,11 @@
cache_tag = "MEGAv9_m60s0.3_freq1"
)


train_dataset_config = dict(
split = "train",
ann_dir = "datasets/vidor-dataset/annotation",
video_dir = "datasets/vidor-dataset/train_videos",
# NOTE classeme_dir & proposal_dir 中的东西还在kgl103上,因为之前直接在103上tocache了
video_feature_dir = "/home/gkf/I3D_Pytorch/I3D_clip_features/clip16_overlap0.5",
classeme_dir = "proposals/miss60_minscore0p3/VidORtrain_freq1_classeme",
ann_dir = "/home/gkf/project/VidVRD_VidOR/vidor-dataset/annotation",
video_dir = "/home/gkf/project/VidVRD_VidOR/vidor-dataset/train_videos",
classeme_dir = "/home/gkf/project/deepSORT/tracking_results/miss60_minscore0p3/VidORtrain_freq1_classeme",
proposal_dir = {
0:"proposals/miss60_minscore0p3/VidORtrain_freq1_part01",
1:"proposals/miss60_minscore0p3/VidORtrain_freq1_part02",
Expand All @@ -53,6 +50,8 @@
12:"proposals/miss60_minscore0p3/VidORtrain_freq1_part13",
13:"proposals/miss60_minscore0p3/VidORtrain_freq1_part14",
},
cache_dir = "datasets/cache",
cache_tag = "MEGAv7",
dim_boxfeature = 1024,
min_frames_th = 15,
max_proposal = 180,
Expand Down
24 changes: 14 additions & 10 deletions tools/train_vidor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from collections import defaultdict

from dataloaders.dataloader_vidor import Dataset
from models import BIG_C_vidor,Base_C
from models import BIG_C_vidor,Base_C,DEBUG

from utils.DataParallel import VidSGG_DataParallel
from utils.utils_func import create_logger,parse_config_py,dura_intersection_ts,vIoU_ts
Expand Down Expand Up @@ -502,10 +502,6 @@ def train_grounding_stage(
from_checkpoint = False,
ckpt_path = None
):
## import model class
temp = model_class_path.split('.')[0].split('/')
model_class_path = ".".join(temp)
DEBUG = import_module(model_class_path).DEBUG

## create dirs and logger
if experiment_dir == None:
Expand Down Expand Up @@ -536,8 +532,9 @@ def train_grounding_stage(
trainable_num = sum([p.numel() for p in model.parameters() if p.requires_grad])
logger.info("number of model.parameters: total:{},trainable:{}".format(total_num,trainable_num))

model = VORG_DataParallel(model,device_ids=device_ids)
model = model.cuda("cuda:{}".format(device_ids[0]))
device_ids = list(range(torch.cuda.device_count()))
model = VidSGG_DataParallel(model,device_ids=device_ids)
model = model.cuda()

# training configs

Expand Down Expand Up @@ -566,13 +563,11 @@ def train_grounding_stage(
dataset_len,batch_size,dataloader_len,batch_size,dataloader_len,batch_size*dataloader_len
)
)


milestones = [int(m*dataset_len/batch_size) for m in epoch_lr_milestones]
optimizer = torch.optim.Adam(model.parameters(), lr = initial_lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones,gamma=lr_decay)


if from_checkpoint:
model,optimizer,scheduler,crt_epoch,batch_size_ = load_checkpoint(model,optimizer,scheduler,ckpt_path)
# assert batch_size == batch_size_ , "batch_size from checkpoint not match : {} != {}"
Expand Down Expand Up @@ -637,6 +632,8 @@ def train_grounding_stage(
save_path = os.path.join(experiment_dir,'model_epoch_{}_{}.pth'.format(total_epoch,save_tag))
save_checkpoint(batch_size,epoch,model,optimizer,scheduler,save_path)
logger.info("checkpoint is saved: {}".format(save_path))
logger.info(f"log saved at {log_path}")
logger.handlers.clear()



Expand Down Expand Up @@ -694,9 +691,16 @@ def train_grounding_stage(
--cfg_path experiments/exp5/config_.py \
--save_tag retrain
## for exp6
## for exp6 (80 epochs, around 6.5 hours for 1 RTX 2080Ti with batch_size=4)
CUDA_VISIBLE_DEVICES=1 python tools/train_vidor.py \
--train_baseline \
--cfg_path experiments/exp6/config_.py \
--save_tag retrain
## for train grounding stage
CUDA_VISIBLE_DEVICES=2,3 python tools/train_vidor.py \
--train_grounding \
--cfg_path experiments/grounding_weights/config_.py \
--save_tag retrain
'''

0 comments on commit 80afefa

Please sign in to comment.