-
Notifications
You must be signed in to change notification settings - Fork 0
/
symmetry_transforms.py
164 lines (136 loc) · 6.98 KB
/
symmetry_transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from cenotaph.basics.neighbourhood import Neighbourhood
from cenotaph.basics.neighbourhood import SquareNeighbourhood
from cenotaph.basics.generic_functions import combine_patterns
from cenotaph.basics.matrix_displaced_copies import matrix_displaced_copies
from cenotaph.basics.base_classes import *
class SymmetryTransform(SingleChannelImageDescriptor):
    """LOCuST: LOCal Symmetry Transform. Abstract base class"""

    def __init__(self, imgfile, resolution=3, geometric_transform='r90',
                 statistical_test='distance (L1)', *args):
        """
        Default constructor

        Parameters
        ----------
        imgfile : str
            The path to the input image.
        resolution : odd int
            The dimension (side length) of the square neighbourhood.
        geometric_transform : str
            The geometric transform to check symmetry for. Can be:
                'r90'      -> counter-clockwise rotation by 90°
                'r180'     -> counter-clockwise rotation by 180°
                'r270'     -> counter-clockwise rotation by 270°
                'hMirror'  -> mirror about the horizontal mid-line
                'vMirror'  -> mirror about the vertical mid-line
                'd1Mirror' -> mirror about the bottom right/top left diagonal
                'd2Mirror' -> mirror about the bottom left/top right diagonal
            The value is validated lazily, when the symmetry map is computed.
        statistical_test : str
            The statistical test used for checking symmetry. Can be:
                'L2-distance-std' -> Euclidean distance with preliminary zero
                                     mean and unit variance standardisation.
            NOTE(review): the default value 'distance (L1)' is not handled by
            _perform_test, which raises ValueError for it; pass
            'L2-distance-std' explicitly until that is resolved.
        args : int
            Amount of the displacement (in pixels). Intended for translation
            transforms; currently unused because no translation is supported.

        Raises
        ------
        ValueError
            If resolution is not an odd number.
        """
        # An odd side length guarantees the neighbourhood has a central pixel.
        # (The original built this exception without raising it.)
        if resolution % 2 != 1:
            raise ValueError('Resolution needs to be an odd number')

        super().__init__(imgfile)
        self._resolution = resolution
        self._geometric_transform = geometric_transform
        self._statistical_test = statistical_test
        self._updated = False

        # Placeholder only: replaced by _perform_test when the map is computed.
        # NOTE(review): if self._img is a NumPy array, .size is the element
        # count (1-D placeholder), not the image shape — confirm intent.
        self._map = np.empty(self._img.size)

        # Generate the original and the transformed neighbourhood (which are
        # the same at the beginning)
        self._original_neighbourhood = SquareNeighbourhood(resolution)
        self._transformed_neighbourhood = SquareNeighbourhood(resolution)

    def _compute_features(self):
        """Dummy implementation: this descriptor produces no feature vector."""
        self._features = np.empty([0])

    def _compute_symmetry_map(self):
        """Compute the symmetry map and store it in self._map."""
        # Start from a pristine neighbourhood: _apply_transform mutates it in
        # place, so recomputing without a reset would compound the transform.
        self._transformed_neighbourhood = SquareNeighbourhood(self._resolution)
        self._apply_transform()

        # Generate the transformed patterns (kept as an attribute for
        # subclasses; _perform_test builds its own pattern stacks)
        self._transformed_patterns = matrix_displaced_copies(
            self._img, self._transformed_neighbourhood.get_integer_points())

        # Carry out the comparison between the original and the transformed
        # patterns
        self._perform_test()
        self._updated = True

    def get_symmetry_map(self):
        """The resulting symmetry map, computed lazily on first access.

        Returns
        -------
        _map : ndarray
            A matrix the same size of the input image representing the
            symmetry map
        """
        if not self._updated:
            # _compute_symmetry_map sets self._updated itself
            self._compute_symmetry_map()
        return self._map

    def _apply_transform(self):
        """Apply the geometric transformation to the neighbourhood of points.

        Raises
        ------
        ValueError
            If the requested geometric transform is not supported.
        """
        if self._geometric_transform == 'r90':
            self._transformed_neighbourhood.rotate(90)
        elif self._geometric_transform == 'r180':
            self._transformed_neighbourhood.rotate(180)
        elif self._geometric_transform == 'r270':
            self._transformed_neighbourhood.rotate(270)
        elif self._geometric_transform == 'hMirror':
            self._transformed_neighbourhood.reflect(1.0, 0.0, 0.0)
        elif self._geometric_transform == 'vMirror':
            self._transformed_neighbourhood.reflect(0.0, 1.0, 0.0)
        elif self._geometric_transform == 'd1Mirror':
            self._transformed_neighbourhood.reflect(1.0, 1.0, 0.0)
        elif self._geometric_transform == 'd2Mirror':
            self._transformed_neighbourhood.reflect(-1.0, 1.0, 0.0)
        else:
            raise ValueError('Geometric transform not supported')

    def _perform_test(self, neglect_fixed_points=True):
        """Perform the statistical test between the original and the
        transformed patterns, storing the result in self._map.

        Parameters
        ----------
        neglect_fixed_points : bool
            If True the fixed points (points the transform maps onto
            themselves) are not considered in the test

        Raises
        ------
        ValueError
            If the requested statistical test is not supported.
        """
        if neglect_fixed_points:
            # Compute the fixed points
            fixed_points = Neighbourhood.compare(self._original_neighbourhood,
                                                 self._transformed_neighbourhood)

            # Work on copies so the stored neighbourhoods keep all points
            original_neighbourhood_no_fixed_points = \
                Neighbourhood.from_points(self._original_neighbourhood.get_points())
            original_neighbourhood_no_fixed_points.delete_points(fixed_points)
            transformed_neighbourhood_no_fixed_points = \
                Neighbourhood.from_points(self._transformed_neighbourhood.get_points())
            transformed_neighbourhood_no_fixed_points.delete_points(fixed_points)

            original_points = \
                original_neighbourhood_no_fixed_points.get_integer_points()
            transformed_points = \
                transformed_neighbourhood_no_fixed_points.get_integer_points()
        else:
            original_points = self._original_neighbourhood.get_integer_points()
            transformed_points = \
                self._transformed_neighbourhood.get_integer_points()

        # Generate the original and the transformed patterns
        original_patterns = matrix_displaced_copies(self._img, original_points)
        transformed_patterns = matrix_displaced_copies(self._img,
                                                       transformed_points)

        if self._statistical_test == 'L2-distance-std':
            # NOTE(review): standardise() is not defined in this class
            # (presumably inherited — verify); it runs after the patterns are
            # built and its return value is unused. Also, n=1 looks at odds
            # with the documented Euclidean (L2) distance — confirm against
            # combine_patterns.
            self.standardise()
            self._map = combine_patterns(original_patterns,
                                         transformed_patterns,
                                         'distance',
                                         n=1)
        else:
            raise ValueError('Statistical test not supported')