From c2b7628a67c0b73c75cf4506fe5e418ef4dbd81b Mon Sep 17 00:00:00 2001 From: Munan Ning <276585665@qq.com> Date: Mon, 11 Jul 2022 15:06:09 +0800 Subject: [PATCH] Update cityscapes.py encoding map for cityscapse dataset, transform label in [0, 33] to [0, 18] --- u2pl/dataset/cityscapes.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/u2pl/dataset/cityscapes.py b/u2pl/dataset/cityscapes.py index 7eec683..665592a 100644 --- a/u2pl/dataset/cityscapes.py +++ b/u2pl/dataset/cityscapes.py @@ -38,11 +38,46 @@ def __getitem__(self, index): image = self.img_loader(image_path, "RGB") label = self.img_loader(label_path, "L") image, label = self.transform(image, label) + + # encode cityscapes labels + self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] + self.valid_classes = [ + 7, + 8, + 11, + 12, + 13, + 17, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 31, + 32, + 33, + ] + self.ignore_index = 255 + self.class_map = dict(zip(self.valid_classes, range(19))) # zip: return tuples + label = self.encode_segmap(label) return image[0], label[0, 0].long() def __len__(self): return len(self.list_sample_new) + def encode_segmap(self, mask): + # Put all void classes to zero + for _voidc in self.void_classes: + mask[mask == _voidc] = self.ignore_index + for _validc in self.valid_classes: + mask[mask == _validc] = self.class_map[_validc] + return mask + def build_transfrom(cfg): trs_form = []