forked from microsoft/torchgeo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_batch.py
162 lines (134 loc) · 5.33 KB
/
test_batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import math
from collections.abc import Iterator
from itertools import product
import pytest
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch.utils.data import DataLoader
from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler, Units
class CustomBatchGeoSampler(BatchGeoSampler):
    """Minimal concrete ``BatchGeoSampler`` used to exercise the base class.

    It yields ``len(self)`` batches; every batch holds ``len(self)`` bounding
    boxes whose six coordinates all equal the box's position in the batch.
    """

    def __init__(self) -> None:
        """Do nothing; the parent initializer is intentionally not called."""

    def __iter__(self) -> Iterator[list[BoundingBox]]:
        """Yield one batch of bounding boxes per sampler step.

        Returns:
            An iterator over lists of ``BoundingBox`` objects.
        """
        for _ in range(len(self)):
            batch = [BoundingBox(k, k, k, k, k, k) for k in range(len(self))]
            yield batch

    def __len__(self) -> int:
        """Return the fixed number of batches, which is also the batch size."""
        return 2
class CustomGeoDataset(GeoDataset):
    """Bare-bones ``GeoDataset`` whose samples simply echo the query back."""

    def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
        """Initialize the parent dataset, then record the CRS and resolution.

        Args:
            crs: coordinate reference system stored on the dataset
                (defaults to EPSG:3005).
            res: resolution stored on the dataset.
        """
        super().__init__()
        self._crs = crs
        self.res = res

    def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
        """Return a sample that contains only the requested bounding box.

        Args:
            query: bounding box used to index the dataset.

        Returns:
            A dict mapping ``"index"`` to ``query``.
        """
        sample = {"index": query}
        return sample
class TestBatchGeoSampler:
    """Tests for ``BatchGeoSampler`` exercised through ``CustomBatchGeoSampler``."""

    @pytest.fixture(scope="class")
    def dataset(self) -> CustomGeoDataset:
        """Provide a dataset whose index holds a single entry."""
        geo_ds = CustomGeoDataset()
        bounds = (0, 100, 200, 300, 400, 500)
        geo_ds.index.insert(0, bounds)
        return geo_ds

    @pytest.fixture(scope="function")
    def sampler(self) -> CustomBatchGeoSampler:
        """Provide a fresh sampler for every test."""
        return CustomBatchGeoSampler()

    def test_iter(self, sampler: CustomBatchGeoSampler) -> None:
        """The first batch holds the boxes for positions 0 and 1."""
        first_batch = next(iter(sampler))
        assert first_batch == [
            BoundingBox(0, 0, 0, 0, 0, 0),
            BoundingBox(1, 1, 1, 1, 1, 1),
        ]

    def test_len(self, sampler: CustomBatchGeoSampler) -> None:
        """The sampler reports two batches."""
        assert len(sampler) == 2

    @pytest.mark.slow
    @pytest.mark.parametrize("num_workers", [0, 1, 2])
    def test_dataloader(
        self,
        dataset: CustomGeoDataset,
        sampler: CustomBatchGeoSampler,
        num_workers: int,
    ) -> None:
        """Iterating a ``DataLoader`` driven by the sampler runs to completion."""
        loader = DataLoader(
            dataset,
            batch_sampler=sampler,
            num_workers=num_workers,
            collate_fn=stack_samples,
        )
        # Only exhaustion of the loader matters; the batches are discarded.
        for _ in loader:
            pass

    def test_abstract(self, dataset: CustomGeoDataset) -> None:
        """Instantiating the base class directly raises ``TypeError``."""
        with pytest.raises(TypeError, match="Can't instantiate abstract class"):
            BatchGeoSampler(dataset)  # type: ignore[abstract]
class TestRandomBatchGeoSampler:
    """Tests for ``RandomBatchGeoSampler``."""

    @pytest.fixture(scope="class")
    def dataset(self) -> CustomGeoDataset:
        """Provide a dataset whose index holds two entries with equal bounds."""
        ds = CustomGeoDataset()
        ds.index.insert(0, (0, 100, 200, 300, 400, 500))
        ds.index.insert(1, (0, 100, 200, 300, 400, 500))
        return ds

    @pytest.fixture(
        scope="function",
        params=product([3, 4.5, (2, 2), (3, 4.5), (4.5, 3)], [Units.PIXELS, Units.CRS]),
    )
    def sampler(
        self, dataset: CustomGeoDataset, request: SubRequest
    ) -> RandomBatchGeoSampler:
        """Provide a sampler for every combination of patch size and units.

        Args:
            dataset: dataset to sample from.
            request: pytest request carrying the ``(size, units)`` parameter.

        Returns:
            A sampler with ``batch_size=2`` and ``length=10``.
        """
        size, units = request.param
        return RandomBatchGeoSampler(
            dataset, size, batch_size=2, length=10, units=units
        )

    def test_iter(self, sampler: RandomBatchGeoSampler) -> None:
        """Every sampled box lies inside the ROI and has the requested extent."""
        roi = sampler.roi
        for batch in sampler:
            for query in batch:
                # Each axis: roi lower <= query lower <= query upper <= roi upper.
                assert roi.minx <= query.minx <= query.maxx <= roi.maxx
                # The y check previously compared query.miny with itself, so the
                # upper y edge of the query was never validated.
                assert roi.miny <= query.miny <= query.maxy <= roi.maxy
                assert roi.mint <= query.mint <= query.maxt <= roi.maxt

                # sampler.size is compared as (y extent, x extent).
                assert math.isclose(query.maxx - query.minx, sampler.size[1])
                assert math.isclose(query.maxy - query.miny, sampler.size[0])
                # The time extent of a query equals that of the ROI.
                assert math.isclose(query.maxt - query.mint, roi.maxt - roi.mint)

    def test_len(self, sampler: RandomBatchGeoSampler) -> None:
        """The number of batches is ``length // batch_size``."""
        assert len(sampler) == sampler.length // sampler.batch_size

    def test_roi(self, dataset: CustomGeoDataset) -> None:
        """With an explicit ROI, every sampled box is contained in that ROI."""
        roi = BoundingBox(0, 50, 200, 250, 400, 450)
        sampler = RandomBatchGeoSampler(dataset, 2, batch_size=2, length=10, roi=roi)
        for batch in sampler:
            for query in batch:
                assert query in roi

    def test_small_area(self) -> None:
        """Sampling runs to completion when one index entry is only 1x1."""
        ds = CustomGeoDataset(res=1)
        ds.index.insert(0, (0, 10, 0, 10, 0, 10))
        ds.index.insert(1, (20, 21, 20, 21, 20, 21))
        sampler = RandomBatchGeoSampler(ds, 2, batch_size=2, length=10)
        # Only exhaustion of the sampler matters; the batches are discarded.
        for _ in sampler:
            continue

    def test_point_data(self) -> None:
        """Sampling with size 0 runs to completion on zero-extent entries."""
        ds = CustomGeoDataset()
        ds.index.insert(0, (0, 0, 0, 0, 0, 0))
        ds.index.insert(1, (1, 1, 1, 1, 1, 1))
        sampler = RandomBatchGeoSampler(ds, 0, batch_size=2, length=10)
        for _ in sampler:
            continue

    def test_weighted_sampling(self) -> None:
        """Every sampled box equals the bounds of the non-degenerate entry.

        The index holds one zero-extent entry and one 10x10 entry; presumably
        the sampler weights entries by area, so the zero-extent one is never
        chosen -- confirm against the sampler implementation.
        """
        ds = CustomGeoDataset()
        ds.index.insert(0, (0, 0, 0, 0, 0, 0))
        ds.index.insert(1, (0, 10, 0, 10, 0, 10))
        sampler = RandomBatchGeoSampler(ds, 1, batch_size=2, length=10)
        expected = BoundingBox(0, 10, 0, 10, 0, 10)
        for batch in sampler:
            for bbox in batch:
                assert bbox == expected

    @pytest.mark.slow
    @pytest.mark.parametrize("num_workers", [0, 1, 2])
    def test_dataloader(
        self,
        dataset: CustomGeoDataset,
        sampler: RandomBatchGeoSampler,
        num_workers: int,
    ) -> None:
        """Iterating a ``DataLoader`` driven by the sampler runs to completion."""
        dl = DataLoader(
            dataset,
            batch_sampler=sampler,
            num_workers=num_workers,
            collate_fn=stack_samples,
        )
        for _ in dl:
            continue