-
Notifications
You must be signed in to change notification settings - Fork 352
/
test_select.py
81 lines (59 loc) · 2.26 KB
/
test_select.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from copy import deepcopy
import theano
from numpy.testing import assert_raises
from blocks.bricks.base import Brick
from blocks.select import Path, Selector
class MockBrickTop(Brick):
    """Mock parent brick that owns no parameters of its own.

    Parameters
    ----------
    children : list of Brick
        Sub-bricks attached beneath this one; the selector tests resolve
        paths such as ``"/t1/b1"`` through them.
    **kwargs
        Forwarded unchanged to :class:`Brick` (the tests pass ``name``).

    """
    def __init__(self, children, **kwargs):
        super(MockBrickTop, self).__init__(**kwargs)
        self.children = children
        # Deliberately empty: only the children contribute parameters.
        self.parameters = list()
class MockBrickBottom(Brick):
    """Mock brick used as a leaf of the test hierarchy.

    It owns two ``theano.shared`` parameters initialised to 0, named "V"
    and "W" in that order; the selector tests rely on both the names and
    the order.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to :class:`Brick` (the tests pass ``name``).

    """
    def __init__(self, **kwargs):
        super(MockBrickBottom, self).__init__(**kwargs)
        self.parameters = [theano.shared(0, label) for label in ("V", "W")]
def test_path():
    """Parse textual paths into node tuples; check Path equality and hashing."""
    brick_node, param_node = Path.BrickName, Path.ParameterName

    bare = Path.parse("/brick")
    assert bare.nodes == (brick_node("brick"),)

    with_param = Path.parse("/brick.W")
    assert with_param.nodes == (brick_node("brick"), param_node("W"))

    nested = Path.parse("/brick1/brick2")
    assert nested.nodes == (brick_node("brick1"), brick_node("brick2"))

    # A deep copy must be equal (and hash equal) to its source while
    # remaining distinguishable from an unrelated path.
    clone = deepcopy(nested)
    assert clone == nested
    assert clone != with_param
    assert hash(clone) == hash(nested)
    # NOTE(review): Python only requires equal objects to hash equally;
    # distinct hashes for unequal paths are overwhelmingly likely but are
    # not guaranteed by the hash contract.
    assert hash(clone) != hash(with_param)
def test_selector_get_parameters_uniqueness():
    """Sibling bricks sharing a name make ``get_parameters`` raise.

    Both children of "top" are called "bottom", so their parameters would
    map to the same paths. The selector is expected to reject this with a
    ValueError.
    """
    twins = [MockBrickBottom(name="bottom") for _ in range(2)]
    selector = Selector([MockBrickTop(twins, name="top")])
    assert_raises(ValueError, selector.get_parameters)
def _assert_selects_only(selection, brick):
    """Assert that the Selector *selection* holds *brick* and nothing else."""
    # Check the size first so an empty selection fails with a clear
    # assertion instead of an IndexError.
    assert len(selection.bricks) == 1
    assert selection.bricks[0] is brick


def test_selector():
    """Select bricks and parameters by path and enumerate parameters.

    Hierarchy under test: t1 -> (b1, b2) and t2 -> (b2, b3); b2 is shared
    by both top-level bricks.
    """
    b1 = MockBrickBottom(name="b1")
    b2 = MockBrickBottom(name="b2")
    b3 = MockBrickBottom(name="b3")
    t1 = MockBrickTop([b1, b2], name="t1")
    t2 = MockBrickTop([b2, b3], name="t2")

    s1 = Selector([t1])
    _assert_selects_only(s1.select("/t1/b1"), b1)
    _assert_selects_only(s1.select("/t1"), t1)

    s2 = Selector([t1, t2])
    _assert_selects_only(s2.select("/t2/b2"), b2)
    # A path ending in ".<name>" selects parameters rather than bricks.
    selected_parameters = s2.select("/t2/b2.V")
    assert len(selected_parameters) == 1
    assert selected_parameters[0] is b2.parameters[0]

    # get_parameters() is expected to list t1's children in order, each
    # child's parameters in declaration order, keyed by their string path.
    parameters = list(s1.get_parameters().items())
    expected = [
        ("/t1/b1.V", b1.parameters[0]),
        ("/t1/b1.W", b1.parameters[1]),
        ("/t1/b2.V", b2.parameters[0]),
        ("/t1/b2.W", b2.parameters[1]),
    ]
    # The length check catches extra or duplicated entries, which the
    # element-wise comparison below would not notice.
    assert len(parameters) == len(expected)
    for (path, parameter), (expected_path, expected_parameter) in zip(
            parameters, expected):
        assert path == expected_path
        # Identity, not equality: the selector must hand back the very
        # shared variables owned by the bricks.
        assert parameter is expected_parameter