Skip to content

Commit

Permalink
[Feature] Add GDAL backend and Support LEVIR-CD Dataset (open-mmlab#2903
Browse files Browse the repository at this point in the history
)

## Motivation

For support with reading multiple remote sensing image formats, please
refer to https://gdal.org/drivers/raster/index.html.

Byte, UInt16, Int16, UInt32, Int32, Float32, Float64, CInt16, CInt32,
CFloat32 and CFloat64 are supported for reading and writing.

Support input of two images for change detection tasks, and support the
LEVIR-CD dataset.

## Modification

Add LoadSingleRSImageFromFile in 'mmseg/datasets/transforms/loading.py'.
Load a single remote sensing image for object segmentation tasks.

Add LoadMultipleRSImageFromFile in
'mmseg/datasets/transforms/loading.py'.
Load two remote sensing images for change detection tasks.

Add ConcatCDInput  in 'mmseg/datasets/transforms/transforms.py'.
Combine images that have been separately augmented for data enhancement.

Add BaseCDDataset in 'mmseg/datasets/basesegdataset.py'
Base class for datasets used in change detection tasks.

---------

Co-authored-by: xiexinch <xiexinch@outlook.com>
  • Loading branch information
Zoulinx and xiexinch committed May 8, 2023
1 parent 77836e6 commit 77591b9
Show file tree
Hide file tree
Showing 14 changed files with 809 additions and 25 deletions.
59 changes: 59 additions & 0 deletions configs/_base_/datasets/levir_256x256.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# dataset settings
dataset_type = 'LEVIRCDDataset'
data_root = r'data/LEVIRCD'

albu_train_transforms = [
dict(type='RandomBrightnessContrast', p=0.2),
dict(type='HorizontalFlip', p=0.5),
dict(type='VerticalFlip', p=0.5)
]

train_pipeline = [
dict(type='LoadMultipleRSImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Albu', transforms=albu_train_transforms),
dict(type='ConcatCDInput'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadMultipleRSImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='ConcatCDInput'),
dict(type='PackSegInputs')
]

tta_pipeline = [
dict(type='LoadMultipleRSImageFromFile'),
dict(
type='TestTimeAug',
transforms=[[dict(type='LoadAnnotations')],
[dict(type='ConcatCDInput')],
[dict(type='PackSegInputs')]])
]
train_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='train/A',
img_path2='train/B',
seg_map_path='train/label'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='test/A', img_path2='test/B', seg_map_path='test/label'),
pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
_base_ = [
'../_base_/models/upernet_swin.py', '../_base_/datasets/levir_256x256.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_20k.py'
]
crop_size = (256, 256)
norm_cfg = dict(type='BN', requires_grad=True)
data_preprocessor = dict(
size=crop_size,
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53, 123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375, 58.395, 57.12, 57.375])

model = dict(
data_preprocessor=data_preprocessor,
backbone=dict(
in_channels=6,
embed_dims=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
use_abs_pos_embed=False,
drop_path_rate=0.3,
patch_norm=True),
decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=2),
auxiliary_head=dict(in_channels=384, num_classes=2))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))

param_scheduler = [
dict(
type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
dict(
type='PolyLR',
eta_min=0.0,
power=1.0,
begin=1500,
end=20000,
by_epoch=False,
)
]

train_dataloader = dict(batch_size=4)
val_dataloader = dict(batch_size=1)
test_dataloader = val_dataloader
10 changes: 10 additions & 0 deletions docs/en/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,16 @@ Run it with
docker run --gpus all --shm-size=8g -it -v {DATA_DIR}:/mmsegmentation/data mmsegmentation
```

### Optional Dependencies

#### Install GDAL

[GDAL](https://gdal.org/) is a translator library for raster and vector geospatial data formats. Install GDAL to read complex formats and extremely large remote sensing images.

```shell
conda install GDAL
```

## Trouble shooting

If you have some issues during the installation, please first view the [FAQ](notes/faq.md) page.
Expand Down
30 changes: 30 additions & 0 deletions docs/en/user_guides/2_dataset_prepare.md
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,33 @@ It includes 400 images for training, 400 images for validation and 400 images fo

- You could set Datasets version with `MapillaryDataset_v1` and `MapillaryDataset_v2` in your configs.
View the Mapillary Vistas Datasets config file here [V1.2](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/_base_/datasets/mapillary_v1.py) and [V2.0](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/_base_/datasets/mapillary_v2.py)

## LEVIR-CD

[LEVIR-CD](https://justchenhao.github.io/LEVIR/) Large-scale Remote Sensing Change Detection Dataset for Building.

Download the dataset from [here](https://justchenhao.github.io/LEVIR/).

The supplement version of the dataset can be requested on the [homepage](https://github.com/S2Looking/Dataset)

Please download the supplement version of the dataset, then unzip `LEVIR-CD+.zip`, the contents of original datasets include:

```none
│ ├── LEVIR-CD+
│ │ ├── train
│ │ │ ├── A
│ │ │ ├── B
│ │ │ ├── label
│ │ ├── test
│ │ │ ├── A
│ │ │ ├── B
│ │ │ ├── label
```

For LEVIR-CD dataset, please run the following command to crop images without overlap:

```shell
python tools/dataset_converters/levircd.py --dataset-path /path/to/LEVIR-CD+ --out_dir /path/to/LEVIR-CD
```

The size of cropped image is 256x256, which is consistent with the original paper.
10 changes: 10 additions & 0 deletions docs/zh_cn/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,16 @@ docker build -t mmsegmentation docker/
docker run --gpus all --shm-size=8g -it -v {DATA_DIR}:/mmsegmentation/data mmsegmentation
```

### 可选依赖

#### 安装 GDAL

[GDAL](https://gdal.org/) 是一个用于栅格和矢量地理空间数据格式的转换库。安装 GDAL 可以读取复杂格式和极大的遥感图像。

```shell
conda install GDAL
```

## 问题解答

如果您在安装过程中遇到了其他问题,请第一时间查阅 [FAQ](notes/faq.md) 文件。如果没有找到答案,您也可以在 GitHub 上提出 [issue](https://github.com/open-mmlab/mmsegmentation/issues/new/choose)
30 changes: 30 additions & 0 deletions docs/zh_cn/user_guides/2_dataset_prepare.md
Original file line number Diff line number Diff line change
Expand Up @@ -616,3 +616,33 @@ python tools/convert_datasets/refuge.py --raw_data_root=/path/to/refuge/REFUGE2/

- 您可以在配置中使用 `MapillaryDataset_v1``Mapillary Dataset_v2` 设置数据集版本。
在此处 [V1.2](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/_base_/datasets/mapillary_v1.py) 和 [V2.0](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/_base_/datasets/mapillary_v2.py) 查看 Mapillary Vistas 数据集配置文件

## LEVIR-CD

[LEVIR-CD](https://justchenhao.github.io/LEVIR/) 大规模遥感建筑变化检测数据集。

数据集可以在[主页](https://justchenhao.github.io/LEVIR/)上请求获得。

数据集的补充版本可以在[主页](https://github.com/S2Looking/Dataset)上请求获得。

请下载数据集的补充版本,然后解压 `LEVIR-CD+.zip`,数据集的内容包括:

```none
│ ├── LEVIR-CD+
│ │ ├── train
│ │ │ ├── A
│ │ │ ├── B
│ │ │ ├── label
│ │ ├── test
│ │ │ ├── A
│ │ │ ├── B
│ │ │ ├── label
```

对于 LEVIR-CD 数据集,请运行以下命令无重叠裁剪影像:

```shell
python tools/dataset_converters/levircd.py --dataset-path /path/to/LEVIR-CD+ --out_dir /path/to/LEVIR-CD
```

裁剪后的影像大小为256x256,与原论文保持一致。
23 changes: 14 additions & 9 deletions mmseg/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .basesegdataset import BaseCDDataset, BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
Expand All @@ -12,6 +12,7 @@
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .levir import LEVIRCDDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .mapillary import MapillaryDataset_v1, MapillaryDataset_v2
Expand All @@ -25,13 +26,15 @@
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
LoadBiomedicalAnnotation, LoadBiomedicalData,
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
PackSegInputs, PhotoMetricDistortion, RandomCrop,
RandomCutOut, RandomMosaic, RandomRotate,
RandomRotFlip, Rerange, ResizeShortestEdge,
ResizeToMultiple, RGB2Gray, SegRescale)
BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
LoadAnnotations, LoadBiomedicalAnnotation,
LoadBiomedicalData, LoadBiomedicalImageFromFile,
LoadImageFromNDArray, LoadMultipleRSImageFromFile,
LoadSingleRSImageFromFile, PackSegInputs,
PhotoMetricDistortion, RandomCrop, RandomCutOut,
RandomMosaic, RandomRotate, RandomRotFlip, Rerange,
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale)
from .voc import PascalVOCDataset

# yapf: enable
Expand All @@ -51,5 +54,7 @@
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1',
'MapillaryDataset_v2', 'Albu'
'MapillaryDataset_v2', 'Albu', 'LEVIRCDDataset',
'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile',
'ConcatCDInput', 'BaseCDDataset'
]
Loading

0 comments on commit 77591b9

Please sign in to comment.