Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Support format_result and fix prefix param in cityscape metric, and rename CitysMetric to CityscapesMetric #2660

Merged
merged 3 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[WIP] Fix prefix parameter in citys metric
  • Loading branch information
MeowZheng committed Feb 28, 2023
commit 3661c56c058fb54065f89e6d5be68e922cdfa469
7 changes: 4 additions & 3 deletions mmseg/evaluation/metrics/citys_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self,
assert self.metrics[0] == 'cityscapes'
self.to_label_id = to_label_id
self.suffix = suffix
self.prefix = prefix

def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
"""Process one batch of data and data_samples.
Expand All @@ -59,15 +60,15 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
data_batch (dict): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
mkdir_or_exist(self.suffix)
mkdir_or_exist(self.prefix)

for data_sample in data_samples:
pred_label = data_sample['pred_sem_seg']['data'][0].cpu().numpy()
# results2img
if self.to_label_id:
pred_label = self._convert_to_label_id(pred_label)
basename = osp.splitext(osp.basename(data_sample['img_path']))[0]
png_filename = osp.join(self.suffix, f'{basename}.png')
png_filename = osp.join(self.prefix, f'{basename}.png')
output = Image.fromarray(pred_label.astype(np.uint8)).convert('P')
import cityscapesscripts.helpers.labels as CSLabels
palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
Expand Down Expand Up @@ -101,7 +102,7 @@ def compute_metrics(self, results: list) -> Dict[str, float]:
msg = '\n' + msg
print_log(msg, logger=logger)

result_dir = self.suffix
result_dir = self.prefix

eval_results = dict()
print_log(f'Evaluating results under {result_dir} ...', logger=logger)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_evaluation/test_metrics/test_citys_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ def test_evaluate(self):
dict(**data, **result)
for data, result in zip(data_batch, predictions)
]
iou_metric = CitysMetric(citys_metrics=['cityscapes'])
iou_metric = CitysMetric(citys_metrics=['cityscapes'], prefix='tmp')
iou_metric.process(data_batch, data_samples)
res = iou_metric.evaluate(6)
self.assertIsInstance(res, dict)
# test to_label_id = True
iou_metric = CitysMetric(
citys_metrics=['cityscapes'], to_label_id=True)
citys_metrics=['cityscapes'], to_label_id=True, prefix='tmp')
iou_metric.process(data_batch, data_samples)
res = iou_metric.evaluate(6)
self.assertIsInstance(res, dict)
import shutil
shutil.rmtree('.format_cityscapes')
shutil.rmtree('tmp')