Commit 4d676ee

update code for OIV6

yrcong committed May 1, 2023
1 parent 03d80d6 commit 4d676ee
Showing 27 changed files with 610 additions and 33 deletions.
2 changes: 1 addition & 1 deletion datasets/__init__.py
@@ -16,6 +16,6 @@ def get_coco_api_from_dataset(dataset):


 def build_dataset(image_set, args):
-    if args.dataset == 'vg' or args.dataset_file == 'oi':
+    if args.dataset == 'vg' or args.dataset == 'oi':
         return build_coco(image_set, args)
     raise ValueError(f'dataset {args.dataset} not supported')
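With the guard fixed, both supported values of args.dataset flow through the same COCO-style builder. A minimal sketch of the call site (a hedged illustration, not repository code; the Namespace fields mirror the args read by datasets/coco.py, and the paths are hypothetical placeholders):

from argparse import Namespace
from datasets import build_dataset

# Hypothetical paths; build() in datasets/coco.py reads args.ann_path and args.img_folder.
args = Namespace(dataset='oi', ann_path='data/oiv6/', img_folder='data/oiv6/images/')
dataset_train = build_dataset(image_set='train', args=args)  # resolves to build_coco(...)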
Binary file modified datasets/__pycache__/__init__.cpython-36.pyc
Binary file modified datasets/__pycache__/coco.cpython-36.pyc
Binary file modified datasets/__pycache__/coco_eval.cpython-36.pyc
Binary file modified datasets/__pycache__/transforms.cpython-36.pyc
1 change: 0 additions & 1 deletion datasets/coco.py
@@ -169,7 +169,6 @@ def build(image_set, args):
     ann_path = args.ann_path
     img_folder = args.img_folder
 
-    #TODO: adapt vg as coco
     if image_set == 'train':
         ann_file = ann_path + 'train.json'
     elif image_set == 'val':
72 changes: 60 additions & 12 deletions engine.py
@@ -15,6 +15,7 @@
 import util.misc as utils
 from util.box_ops import rescale_bboxes
 from lib.evaluation.sg_eval import BasicSceneGraphEvaluator, calculate_mR_from_evaluator_list
+from lib.openimages_evaluation import task_evaluation_sg

 def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
@@ -87,15 +88,19 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, args):
     header = 'Test:'
 
-    # initilize evaluator
-    evaluator = BasicSceneGraphEvaluator.all_modes(multiple_preds=False)
-    if args.eval:
-        evaluator_list = []
-        for index, name in enumerate(data_loader.dataset.rel_categories):
-            if index == 0:
-                continue
-            evaluator_list.append((index, name, BasicSceneGraphEvaluator.all_modes()))
+    # TODO merge evaluation programs
+    if args.dataset == 'vg':
+        evaluator = BasicSceneGraphEvaluator.all_modes(multiple_preds=False)
+        if args.eval:
+            evaluator_list = []
+            for index, name in enumerate(data_loader.dataset.rel_categories):
+                if index == 0:
+                    continue
+                evaluator_list.append((index, name, BasicSceneGraphEvaluator.all_modes()))
+        else:
+            evaluator_list = None
     else:
         evaluator_list = None
+    all_results = []
 
     iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
     coco_evaluator = CocoEvaluator(base_ds, iou_types)
@@ -108,7 +113,6 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, args):
         outputs = model(samples)
         loss_dict = criterion(outputs, targets)
         weight_dict = criterion.weight_dict
-        evaluate_rel_batch(outputs, targets, evaluator, evaluator_list)
 
         # reduce losses over all GPUs for logging purposes
         loss_dict_reduced = utils.reduce_dict(loss_dict)
@@ -124,15 +128,24 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, args):
         metric_logger.update(obj_error=loss_dict_reduced['obj_error'])
         metric_logger.update(rel_error=loss_dict_reduced['rel_error'])
 
+        if args.dataset == 'vg':
+            evaluate_rel_batch(outputs, targets, evaluator, evaluator_list)
+        else:
+            evaluate_rel_batch_oi(outputs, targets, all_results)
+
         orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
         results = postprocessors['bbox'](outputs, orig_target_sizes)
 
         res = {target['image_id'].item(): output for target, output in zip(targets, results)}
         if coco_evaluator is not None:
             coco_evaluator.update(res)
 
-    evaluator['sgdet'].print_stats()
-
     if args.eval:
+        if args.dataset == 'vg':
+            evaluator['sgdet'].print_stats()
+        else:
+            task_evaluation_sg.eval_rel_results(all_results, 100, do_val=True, do_vis=False)
+
+    if args.eval and args.dataset == 'vg':
         calculate_mR_from_evaluator_list(evaluator_list, 'sgdet')

# gather the stats from all processes
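Taken together, the changes above split evaluate() into two paths: VG keeps the recall-based scene-graph evaluator, while OI buffers one result dict per image and scores everything once at the end. A hedged sketch of the resulting control flow (model, data_loader, args, evaluator and evaluator_list stand in for the objects built earlier in evaluate()):

all_results = []
for samples, targets in data_loader:  # simplified; the real loop uses metric_logger.log_every
    outputs = model(samples)
    if args.dataset == 'vg':
        evaluate_rel_batch(outputs, targets, evaluator, evaluator_list)
    else:
        evaluate_rel_batch_oi(outputs, targets, all_results)

if args.eval and args.dataset != 'vg':
    # 100 is the per-image relation budget passed by this commit.
    task_evaluation_sg.eval_rel_results(all_results, 100, do_val=True, do_vis=False)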
@@ -167,7 +180,7 @@ def evaluate_rel_batch(outputs, targets, evaluator, evaluator_list):
         pred_sub_scores, pred_sub_classes = torch.max(outputs['sub_logits'][batch].softmax(-1)[:, :-1], dim=1)
         pred_obj_scores, pred_obj_classes = torch.max(outputs['obj_logits'][batch].softmax(-1)[:, :-1], dim=1)
         rel_scores = outputs['rel_logits'][batch][:,1:-1].softmax(-1)
-        #
+
         pred_entry = {'sub_boxes': sub_bboxes_scaled,
                       'sub_classes': pred_sub_classes.cpu().clone().numpy(),
                       'sub_scores': pred_sub_scores.cpu().clone().numpy(),
@@ -187,3 +200,38 @@ def evaluate_rel_batch(outputs, targets, evaluator, evaluator_list):
                 continue
             evaluator_rel['sgdet'].evaluate_scene_graph_entry(gt_entry_rel, pred_entry)
 
+
+def evaluate_rel_batch_oi(outputs, targets, all_results):
+
+    for batch, target in enumerate(targets):
+        target_bboxes_scaled = rescale_bboxes(target['boxes'].cpu(), torch.flip(target['orig_size'],dims=[0]).cpu()).clone().numpy()  # recovered boxes with original size
+
+        sub_bboxes_scaled = rescale_bboxes(outputs['sub_boxes'][batch].cpu(), torch.flip(target['orig_size'],dims=[0]).cpu()).clone().numpy()
+        obj_bboxes_scaled = rescale_bboxes(outputs['obj_boxes'][batch].cpu(), torch.flip(target['orig_size'],dims=[0]).cpu()).clone().numpy()
+
+        pred_sub_scores, pred_sub_classes = torch.max(outputs['sub_logits'][batch].softmax(-1)[:, :-1], dim=1)
+        pred_obj_scores, pred_obj_classes = torch.max(outputs['obj_logits'][batch].softmax(-1)[:, :-1], dim=1)
+
+        rel_scores = outputs['rel_logits'][batch][:, :-1].softmax(-1)
+
+        relation_idx = target['rel_annotations'].cpu().numpy()
+        gt_sub_boxes = target_bboxes_scaled[relation_idx[:, 0]]
+        gt_sub_labels = target['labels'][relation_idx[:, 0]].cpu().clone().numpy()
+        gt_obj_boxes = target_bboxes_scaled[relation_idx[:, 1]]
+        gt_obj_labels = target['labels'][relation_idx[:, 1]].cpu().clone().numpy()
+
+        img_result_dict = {'sbj_boxes': sub_bboxes_scaled,
+                           'sbj_labels': pred_sub_classes.cpu().clone().numpy(),
+                           'sbj_scores': pred_sub_scores.cpu().clone().numpy(),
+                           'obj_boxes': obj_bboxes_scaled,
+                           'obj_labels': pred_obj_classes.cpu().clone().numpy(),
+                           'obj_scores': pred_obj_scores.cpu().clone().numpy(),
+                           'prd_scores': rel_scores.cpu().clone().numpy(),
+                           'image': str(target['image_id'].item())+'.jpg',
+                           'gt_sbj_boxes': gt_sub_boxes,
+                           'gt_sbj_labels': gt_sub_labels,
+                           'gt_obj_boxes': gt_obj_boxes,
+                           'gt_obj_labels': gt_obj_labels,
+                           'gt_prd_labels': relation_idx[:, 2]
+                           }
+        all_results.append(img_result_dict)
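One detail worth noting: the two batch evaluators slice the predicate logits differently before the softmax. Side by side (lines copied from the two functions; the reading of the indices is an assumption, presumably index 0 and the final column play background/no-relation roles in the respective label sets):

rel_scores_vg = outputs['rel_logits'][batch][:, 1:-1].softmax(-1)  # VG path: drops the first and last columns
rel_scores_oi = outputs['rel_logits'][batch][:, :-1].softmax(-1)   # OI path: drops only the last column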
2 changes: 1 addition & 1 deletion inference.py
@@ -51,7 +51,7 @@ def get_args_parser():

     parser.add_argument('--device', default='cuda',
                         help='device to use for training / testing')
-    parser.add_argument('--resume', default='ckpt/checkpoint0149.pth', help='resume from checkpoint')
+    parser.add_argument('--resume', default='ckpt/checkpoint0149_oi.pth', help='resume from checkpoint')
     parser.add_argument('--set_cost_class', default=1, type=float,
                         help="Class coefficient in the matching cost")
     parser.add_argument('--set_cost_bbox', default=5, type=float,
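The new default points inference at an OIV6-specific checkpoint. A minimal sketch of how such a flag is typically consumed (hedged: the 'model' key follows the DETR-style checkpoint convention and is an assumption here, as is the model object built elsewhere in inference.py):

import torch

ckpt = torch.load('ckpt/checkpoint0149_oi.pth', map_location='cpu')
model.load_state_dict(ckpt['model'])  # 'model' key assumed, DETR-style checkpoint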
Binary file modified lib/__pycache__/pytorch_misc.cpython-36.pyc
Binary file modified lib/evaluation/__pycache__/__init__.cpython-36.pyc
Binary file modified lib/evaluation/__pycache__/sg_eval.cpython-36.pyc
216 changes: 216 additions & 0 deletions lib/openimages_evaluation/ap_eval_rel.py
@@ -0,0 +1,216 @@
# Adapted from Detectron.pytorch/lib/datasets/voc_eval.py for
# this project by Ji Zhang, 2019
#-----------------------------------------------------------------------------
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
#
# Based on:
# --------------------------------------------------------
# Fast/er R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Bharath Hariharan
# --------------------------------------------------------

"""relationship AP evaluation code."""

from six.moves import cPickle as pickle
import logging
import numpy as np
import os
from tqdm import tqdm
from lib.fpn.box_intersections_cpu.bbox import bbox_overlaps

logger = logging.getLogger(__name__)

def boxes_union(boxes1, boxes2):
    assert boxes1.shape == boxes2.shape
    xmin = np.minimum(boxes1[:, 0], boxes2[:, 0])
    ymin = np.minimum(boxes1[:, 1], boxes2[:, 1])
    xmax = np.maximum(boxes1[:, 2], boxes2[:, 2])
    ymax = np.maximum(boxes1[:, 3], boxes2[:, 3])
    return np.vstack((xmin, ymin, xmax, ymax)).transpose()

def prepare_mAP_dets(topk_dets, cls_num):
    cls_image_ids = [[] for _ in range(cls_num)]
    cls_dets = [{'confidence': np.empty(0),
                 'BB_s': np.empty((0, 4)),
                 'BB_o': np.empty((0, 4)),
                 'BB_r': np.empty((0, 4)),
                 'LBL_s': np.empty(0),
                 'LBL_o': np.empty(0)} for _ in range(cls_num)]
    cls_gts = [{} for _ in range(cls_num)]
    npos = [0 for _ in range(cls_num)]
    for dets in tqdm(topk_dets):
        image_id = dets['image'].split('/')[-1].split('.')[0]
        sbj_boxes = dets['det_boxes_s_top']
        obj_boxes = dets['det_boxes_o_top']
        rel_boxes = boxes_union(sbj_boxes, obj_boxes)
        sbj_labels = dets['det_labels_s_top']
        obj_labels = dets['det_labels_o_top']
        prd_labels = dets['det_labels_p_top']
        det_scores = dets['det_scores_top']
        gt_boxes_sbj = dets['gt_boxes_sbj']
        gt_boxes_obj = dets['gt_boxes_obj']
        gt_boxes_rel = boxes_union(gt_boxes_sbj, gt_boxes_obj)
        gt_labels_sbj = dets['gt_labels_sbj']
        gt_labels_prd = dets['gt_labels_prd']
        gt_labels_obj = dets['gt_labels_obj']
        for c in range(cls_num):
            cls_inds = np.where(prd_labels == c)[0]
            # logger.info(cls_inds)
            if len(cls_inds):
                cls_sbj_boxes = sbj_boxes[cls_inds]
                cls_obj_boxes = obj_boxes[cls_inds]
                cls_rel_boxes = rel_boxes[cls_inds]
                cls_sbj_labels = sbj_labels[cls_inds]
                cls_obj_labels = obj_labels[cls_inds]
                cls_det_scores = det_scores[cls_inds]
                cls_dets[c]['confidence'] = np.concatenate((cls_dets[c]['confidence'], cls_det_scores))
                cls_dets[c]['BB_s'] = np.concatenate((cls_dets[c]['BB_s'], cls_sbj_boxes), 0)
                cls_dets[c]['BB_o'] = np.concatenate((cls_dets[c]['BB_o'], cls_obj_boxes), 0)
                cls_dets[c]['BB_r'] = np.concatenate((cls_dets[c]['BB_r'], cls_rel_boxes), 0)
                cls_dets[c]['LBL_s'] = np.concatenate((cls_dets[c]['LBL_s'], cls_sbj_labels))
                cls_dets[c]['LBL_o'] = np.concatenate((cls_dets[c]['LBL_o'], cls_obj_labels))
                cls_image_ids[c] += [image_id] * len(cls_inds)
            cls_gt_inds = np.where(gt_labels_prd == c)[0]
            cls_gt_boxes_sbj = gt_boxes_sbj[cls_gt_inds]
            cls_gt_boxes_obj = gt_boxes_obj[cls_gt_inds]
            cls_gt_boxes_rel = gt_boxes_rel[cls_gt_inds]
            cls_gt_labels_sbj = gt_labels_sbj[cls_gt_inds]
            cls_gt_labels_obj = gt_labels_obj[cls_gt_inds]
            cls_gt_num = len(cls_gt_inds)
            det = [False] * cls_gt_num
            npos[c] = npos[c] + cls_gt_num
            cls_gts[c][image_id] = {'gt_boxes_sbj': cls_gt_boxes_sbj,
                                    'gt_boxes_obj': cls_gt_boxes_obj,
                                    'gt_boxes_rel': cls_gt_boxes_rel,
                                    'gt_labels_sbj': cls_gt_labels_sbj,
                                    'gt_labels_obj': cls_gt_labels_obj,
                                    'gt_num': cls_gt_num,
                                    'det': det}
    return cls_image_ids, cls_dets, cls_gts, npos


def get_ap(rec, prec):
    """Compute AP given precision and recall.
    """
    # correct AP calculation
    # first append sentinel values at the end
    mrec = np.concatenate(([0.], rec, [1.]))
    mpre = np.concatenate(([0.], prec, [0.]))

    # compute the precision envelope
    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

    # to calculate area under PR curve, look for points
    # where X axis (recall) changes value
    i = np.where(mrec[1:] != mrec[:-1])[0]

    # and sum (\Delta recall) * prec
    ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap


def ap_eval(image_ids,
            dets,
            gts,
            npos,
            rel_or_phr=True,
            ovthresh=0.5):
    """
    Top-level function that does the relationship AP evaluation for one
    predicate class.
    image_ids: image id of each detection, parallel to the arrays in dets
    dets: detections for the class ('confidence', 'BB_s', 'BB_o', 'BB_r',
        'LBL_s', 'LBL_o')
    gts: per-image ground truth for the class, as built by prepare_mAP_dets
    npos: number of ground-truth instances of the class
    rel_or_phr: True evaluates relationships (min of subject/object IoU);
        False evaluates phrases (IoU of the union boxes)
    ovthresh: overlap threshold (default = 0.5)
    """

    confidence = dets['confidence']
    BB_s = dets['BB_s']
    BB_o = dets['BB_o']
    BB_r = dets['BB_r']
    LBL_s = dets['LBL_s']
    LBL_o = dets['LBL_o']

    # sort by confidence
    sorted_ind = np.argsort(-confidence)
    BB_s = BB_s[sorted_ind, :]
    BB_o = BB_o[sorted_ind, :]
    BB_r = BB_r[sorted_ind, :]
    LBL_s = LBL_s[sorted_ind]
    LBL_o = LBL_o[sorted_ind]
    image_ids = [image_ids[x] for x in sorted_ind]

    # go down dets and mark TPs and FPs
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    gts_visited = {k: [False] * v['gt_num'] for k, v in gts.items()}
    for d in range(nd):
        R = gts[image_ids[d]]
        visited = gts_visited[image_ids[d]]
        bb_s = BB_s[d, :].astype(float)
        bb_o = BB_o[d, :].astype(float)
        bb_r = BB_r[d, :].astype(float)
        lbl_s = LBL_s[d]
        lbl_o = LBL_o[d]
        ovmax = -np.inf
        BBGT_s = R['gt_boxes_sbj'].astype(float)
        BBGT_o = R['gt_boxes_obj'].astype(float)
        BBGT_r = R['gt_boxes_rel'].astype(float)
        LBLGT_s = R['gt_labels_sbj']
        LBLGT_o = R['gt_labels_obj']
        if BBGT_s.size > 0:
            valid_mask = np.logical_and(LBLGT_s == lbl_s, LBLGT_o == lbl_o)
            if valid_mask.any():
                if rel_or_phr:  # means it is evaluating relationships
                    overlaps_s = bbox_overlaps(
                        bb_s[None, :].astype(dtype=np.float32, copy=False),
                        BBGT_s.astype(dtype=np.float32, copy=False))[0]
                    overlaps_o = bbox_overlaps(
                        bb_o[None, :].astype(dtype=np.float32, copy=False),
                        BBGT_o.astype(dtype=np.float32, copy=False))[0]
                    overlaps = np.minimum(overlaps_s, overlaps_o)
                else:
                    overlaps = bbox_overlaps(
                        bb_r[None, :].astype(dtype=np.float32, copy=False),
                        BBGT_r.astype(dtype=np.float32, copy=False))[0]
                overlaps *= valid_mask
                ovmax = np.max(overlaps)
                jmax = np.argmax(overlaps)
            else:
                ovmax = 0.
                jmax = -1

        if ovmax > ovthresh:
            if not visited[jmax]:
                tp[d] = 1.
                visited[jmax] = 1
            else:
                fp[d] = 1.
        else:
            fp[d] = 1.

    # compute precision recall
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / (float(npos) + 1e-12)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = get_ap(rec, prec)

    return rec, prec, ap
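For orientation, a small usage sketch of the two pure-NumPy helpers above (toy values chosen for illustration; assumes the repository root is on PYTHONPATH):

import numpy as np
from lib.openimages_evaluation.ap_eval_rel import boxes_union, get_ap

# boxes_union returns the elementwise enclosing box of paired boxes.
b1 = np.array([[0., 0., 10., 10.]])
b2 = np.array([[5., 5., 20., 20.]])
print(boxes_union(b1, b2))  # [[ 0.  0. 20. 20.]]

# get_ap integrates a VOC-style PR curve. With npos=2 and detections marked
# [TP, FP, TP] in confidence order: rec = [0.5, 0.5, 1.0],
# prec = [1.0, 0.5, 0.667], so AP = 0.5*1.0 + 0.5*0.667 ~= 0.833.
rec = np.array([0.5, 0.5, 1.0])
prec = np.array([1.0, 0.5, 2. / 3.])
print(round(float(get_ap(rec, prec)), 3))  # 0.833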