# covington.py -- Covington-style dependency parsing demo
# forked from VowpalWabbit/vowpal_wabbit (117 lines / 5.08 KB)
import pyvw
# Training corpus: each sentence is a list of (word, head-index) pairs.
# The head index points at the word's parent within the sentence; -1
# marks the root word.
my_dataset = [
    [("the",      1),   # 0
     ("monster",  2),   # 1
     ("ate",     -1),   # 2  (root)
     ("a",        5),   # 3
     ("big",      5),   # 4
     ("sandwich", 2)],  # 5
    [("the",      1),   # 0
     ("sandwich", 2),   # 1
     ("is",      -1),   # 2  (root)
     ("tasty",    2)],  # 3
    [("a",        1),   # 0
     ("sandwich", 2),   # 1
     ("ate",     -1),   # 2  (root)
     ("itself",   2)],  # 3
]
class CovingtonDepParser(pyvw.SearchTask):
    """Covington-style dependency parser, binary-decision formulation.

    For each word n we scan candidate heads m = -1..N-1 (m == -1 is the
    artificial root) and make a binary prediction "is m the parent of n?"
    (action 2 = yes, 1 = no); the first "yes" fixes word n's head.
    Sentences are lists of (word, head-index) pairs; head -1 means root.
    """
    def __init__(self, vw, sch, num_actions):
        pyvw.SearchTask.__init__(self, vw, sch, num_actions)
        # Hamming loss over per-word head assignments; AUTO_CONDITION_FEATURES
        # adds features derived from the predictions conditioned on below.
        sch.set_options( sch.AUTO_HAMMING_LOSS | sch.AUTO_CONDITION_FEATURES )
    def _run(self, sentence):
        """Predict a head index for every word; returns a list of length N."""
        N = len(sentence)
        # initialize our output so everything is a root
        output = [-1 for i in range(N)]
        for n in range(N):
            wordN, parN = sentence[n]
            for m in range(-1, N):
                if m == n: continue
                # FIX: was `m > 0`, which wrongly treated word 0 as the root;
                # only m == -1 denotes the artificial root (cf. the LDF task).
                wordM = sentence[m][0] if m >= 0 else "*root*"
                # ask the question: is m the parent of n?
                isParent = 2 if m == parN else 1
                # arc direction: candidate head left ('l') or right ('r') of n
                arc = 'l' if m < n else 'r'
                # construct the example lazily (only built when needed).
                # FIX: the 'b' (head-word) namespace used wordN where it should
                # use wordM, mirroring CovingtonDepParserLDF.makeExample.
                ex = lambda: self.vw.example(
                    {'a': [wordN, arc + '_' + wordN],
                     'b': [wordM, arc + '_' + wordM],
                     'p': [wordN + '_' + wordM, arc + '_' + wordN + '_' + wordM],
                     'd': [ str(m-n <= d) + '<=' + str(d) for d in [-8, -4, -2, -1, 1, 2, 4, 8] ] +
                          [ str(m-n >= d) + '>=' + str(d) for d in [-8, -4, -2, -1, 1, 2, 4, 8] ] })
                pred = self.sch.predict(examples  = ex,
                                        my_tag    = (m+1)*N + n + 1,
                                        oracle    = isParent,
                                        condition = [ (max(0, (m  )*N + n + 1), 'p'),
                                                      (max(0, (m+1)*N + n    ), 'q') ])
                if pred == 2:
                    output[n] = m
                    break
        return output
class CovingtonDepParserLDF(pyvw.SearchTask):
    """Covington-style dependency parser, label-dependent-features (LDF) version.

    For each word n we build one cost-sensitive example per candidate head
    m (m == -1 is the artificial root, m == n excluded) and make a single
    LDF prediction selecting the head.
    """
    def __init__(self, vw, sch, num_actions):
        pyvw.SearchTask.__init__(self, vw, sch, num_actions)
        sch.set_options( sch.AUTO_HAMMING_LOSS | sch.IS_LDF | sch.AUTO_CONDITION_FEATURES )
    def makeExample(self, sentence, n, m):
        """Build the cost-sensitive example for "m is the head of word n"."""
        wordN = sentence[n][0]
        wordM = sentence[m][0] if m >= 0 else '*ROOT*'
        # arc direction: candidate head left ('l') or right ('r') of n
        arc = 'l' if m < n else 'r'
        ex = self.vw.example( { 'a': [wordN, arc + '_' + wordN],
                                'b': [wordM, arc + '_' + wordM],
                                'p': [wordN + '_' + wordM, arc + '_' + wordN + '_' + wordM],
                                'd': [ str(m-n <= d) + '<=' + str(d) for d in [-8, -4, -2, -1, 1, 2, 4, 8] ] +
                                     [ str(m-n >= d) + '>=' + str(d) for d in [-8, -4, -2, -1, 1, 2, 4, 8] ] },
                              labelType=self.vw.lCostSensitive)
        # The label string is "<id>:0".  The :0 cost is irrelevant (any number
        # would do); the id 100 + n - m stays >= 1 for any |n - m| < 100.
        ex.set_label_string(str(100 + n - m) + ":0")
        return ex
    def _run(self, sentence):
        """Predict a head index for every word; returns a list of length N."""
        N = len(sentence)
        # initialize our output so everything is a root
        output = [-1 for i in range(N)]
        for n in range(N):
            # make LDF examples, one per candidate head m != n.
            # FIX: bind m as a default argument -- the original late-binding
            # closures all captured the loop variable and hence the final m.
            examples = [ (lambda m=m: self.makeExample(sentence, n, m))
                         for m in range(-1, N) if n != m ]
            # truth: index of the gold head in the candidate list; heads left
            # of n are shifted by +1 because m == -1 occupies slot 0, and
            # skipping m == n means heads right of n keep their own index.
            parN = sentence[n][1]
            oracle = parN+1 if parN < n else parN
            # make a prediction
            pred = self.sch.predict(examples  = examples,
                                    my_tag    = n+1,
                                    oracle    = oracle,
                                    condition = [ (n, 'p'), (n-1, 'q') ] )
            # invert the oracle encoding above: slots 0..n map to heads
            # -1..n-1, slots > n map to themselves.
            # FIX: was `pred < n`, which mapped pred == n to the excluded
            # self-loop head n instead of n-1.
            output[n] = pred-1 if pred <= n else pred
        return output
# TODO: if they make sure search=0 <==> ldf <==> csoaa_ldf
# demo the non-ldf version (parenthesized print works in Python 2 and 3):
print('training non-LDF')
vw = pyvw.vw("--search 2 --search_task hook --ring_size 1024 --quiet")
task = vw.init_search_task(CovingtonDepParser)
for p in range(2):   # two passes over the training data
    task.learn(my_dataset)
print('testing non-LDF')
print(task.predict( [(w,-1) for w in "the monster ate a sandwich".split()] ))
print('should have printed [ 1 2 -1 4 2 ]')

# demo the ldf version:
print('training LDF')
vw = pyvw.vw("--search 0 --csoaa_ldf m --search_task hook --ring_size 1024 --quiet")
task = vw.init_search_task(CovingtonDepParserLDF)
for p in range(100):   # FIX: comment said "two passes" but this runs 100
    task.learn(my_dataset)
print('testing LDF')
print(task.predict( [(w,-1) for w in "the monster ate a sandwich".split()] ))
print('should have printed [ 1 2 -1 4 2 ]')