Skip to content

Commit

Permalink
COCO
Browse files Browse the repository at this point in the history
  • Loading branch information
yuyang committed Jun 3, 2020
1 parent 06940a8 commit e60a2e5
Show file tree
Hide file tree
Showing 6 changed files with 123,393 additions and 4 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ This repository have done:
- [x] YOLOv3 Head
- [x] Keras Callbacks for Online Evaluation
- [x] Load Official Weight File
- [x] Data Format Converter(COCO and Pascal VOC)
- [x] K-Means for Anchors
- [x] Fight with 'NaN'
- [x] Train (Strategy and Model Config)
Expand Down Expand Up @@ -63,7 +64,7 @@ path/to/image2 x1,y1,x2,y2,label
...
```

Convert your data format firstly. We present a script for Pascal VOC in https://github.com/yuto3o/yolox/blob/master/data/pascal_voc/voc_convert.py
Convert your data format firstly. We present [a script for Pascal VOC](./data/pascal_voc/voc_convert.py) and [a script for COCO](./data/coco/coco_convert.py).

More details and a simple dataset could be got from https://github.com/YunYang1994/yymnist.

Expand Down
5 changes: 3 additions & 2 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ YOLOv4-tiny(YOLOv4未提出,非官方)

## 被训练支配的恐惧 !

- 因为目前可用的卡是一张游戏卡RzaiTX 2070S(8 G),因此在训练时使用了较小的batch size
- 因为目前可用的卡是一张游戏卡RzaiTX 2070S(8 G),因此在训练时使用了较小的batch,实际中尽量大batch可以省很多事
- 本项目的数据增强均使用在线形式,高级的数据增强方式会大大拖慢训练速度。
- 训练过程中,Tiny版问题不大,而完整版模型容易NaN或者收敛慢,还在调参中。

Expand All @@ -26,6 +26,7 @@ This repository have done:
- [x] YOLOv3 Head
- [x] Keras Callbacks for Online Evaluation
- [x] Load Official Weight File
- [x] Data Format Converter(COCO and Pascal VOC)
- [x] K-Means for Anchors
- [x] Fight with 'NaN'
- [x] Train (Strategy and Model Config)
Expand Down Expand Up @@ -63,7 +64,7 @@ path/to/image2 x1,y1,x2,y2,label
...
```

当然本项目也提供了一个简单的VOC格式[转换脚本](./data/pascal_voc/voc_convert.py)
当然本项目也提供了一个的[VOC格式转换脚本](./data/pascal_voc/voc_convert.py)和一个[COCO格式转换脚本](./data/coco/coco_convert.py)

也可以从其他大佬的项目中看到这种格式的运用,甚至可以得到一个简单的入门级目标检测数据集: https://github.com/YunYang1994/yymnist.

Expand Down
99 changes: 99 additions & 0 deletions data/coco/coco_convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# -*- coding: utf-8 -*-
import os
import json
import pprint

from absl import logging, app, flags
from collections import defaultdict

flags.DEFINE_string('coco_path', None, 'path to coco dataset')
flags.DEFINE_string('name_path', None, 'path to coco name file')
flags.DEFINE_string('txt_output_path', None, 'path to output txt file')
flags.DEFINE_boolean('use_crowd', True, 'use crowd annotation')

FLAGS = flags.FLAGS


def convert(coco_path, coco_name_path, txt_output_path, use_crowd=True):
def _read_txt_line(path):
with open(path, 'r') as f:
txt = f.readlines()

return [line.strip() for line in txt]

ann_path_train2017 = os.path.join(coco_path, 'annotations', 'instances_train2017.json')
img_path_train2017 = os.path.join(coco_path, 'images', 'train2017')

ann_path_val2017 = os.path.join(coco_path, 'annotations', 'instances_val2017.json')
img_path_val2017 = os.path.join(coco_path, 'images', 'val2017')

coco_name = _read_txt_line(coco_name_path)

def _check_bbox(x1, y1, x2, y2, w, h):

if x1 < 0 or x2 < 0 or x1 > w or x2 > w or y1 < 0 or y2 < 0 or y1 > h or y2 > h:
logging.warning('cross boundary (' + str(w) + ',' + str(h) + '),(' + ','.join(
[str(x1), str(y1), str(x2), str(y2)]) + ')')

return str(min(max(x1, 0.), w)), str(min(max(y1, 0.), h)), str(min(max(x2, 0.), w)), str(
min(max(y2, 0.), h))

return x1, y1, x2, y2

def _write_to_text(ann_path, img_path, txt_path):
dataset = json.load(open(ann_path, 'r'))
print('creating index...')
anns, cats, imgs = {}, {}, {}
imgToAnns = defaultdict(list)
if 'annotations' in dataset:
for ann in dataset['annotations']:
imgToAnns[ann['image_id']].append(ann)
anns[ann['id']] = ann

if 'images' in dataset:
for img in dataset['images']:
imgs[img['id']] = img

if 'categories' in dataset:
for cat in dataset['categories']:
cats[cat['id']] = cat

print('Categories')
pprint.pprint(cats)
print('index created!')

with open(txt_path, 'w') as f:
for img_id, img in imgs.items():
anns = imgToAnns[img_id]
iw, ih = img['width'], img['height']
file_name = img['file_name']

line = os.path.join(img_path, file_name)

for ann in anns:

label = cats[ann['category_id']]['name']
if label not in coco_name:
continue
if not use_crowd and ann['iscrowd'] == 1:
continue
idx = coco_name.index(label)
x, y, w, h = ann['bbox']
x1, y1, x2, y2 = x, y, x + w, y + h
x1, y1, x2, y2 = _check_bbox(x1, y1, x2, y2, iw, ih)

line += ' ' + ','.join([str(x1), str(y1), str(x2), str(y2), str(idx)])

logging.info(line)
f.write(line + '\n')

_write_to_text(ann_path_train2017, img_path_train2017, os.path.join(txt_output_path, 'train2017.txt'))
_write_to_text(ann_path_val2017, img_path_val2017, os.path.join(txt_output_path, 'val2017.txt'))


def main(_argv):
convert(FLAGS.coco_path, FLAGS.name_path, FLAGS.txt_output_path, FLAGS.use_crowd)


if __name__ == '__main__':
app.run(main)
Loading

0 comments on commit e60a2e5

Please sign in to comment.