-
Notifications
You must be signed in to change notification settings - Fork 4
/
gATE.py
54 lines (43 loc) · 2.09 KB
/
gATE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import numpy as np
from scipy import stats
from sklearn.metrics import pairwise_distances
from sklearn.preprocessing import normalize
from .DATE import DATESampling
from .badge import init_centers
from utils import timer_func
class gATESampling(DATESampling):
    """gATE strategy: our proposed model for better exploration.

    Switches between two exploration modes depending on the DATE model's
    performance:

    * bATE sampling: gradient embeddings weighted by revenue and uncertainty,
      then diversified with ``init_centers``;
    * uniform random sampling from ``available_indices``.
    """

    # bATE is used only when ``self.get_model().module.performance`` is
    # strictly greater than this value; otherwise exploration falls back to
    # random sampling. Kept as a class attribute so subclasses/instances can
    # tune the switch without changing ``query``'s signature.
    bate_performance_threshold = 0.3

    def __init__(self, args):
        super().__init__(args)

    def get_uncertainty(self):
        """Return one uncertainty score per candidate item.

        Without an uncertainty module, the score is derived from the DATE
        output ``p`` as ``-1.8 * |p - 0.5| + 1``. It is 1.0 at ``p = 0.5`` and
        falls linearly to 0.1 at ``p = 0`` or ``p = 1``.
        NOTE(review): assumes ``get_output()`` yields one probability per
        available item, aligned with ``available_indices``. Confirm against
        DATESampling.

        With an uncertainty module, its ``'feature_importance'`` measure is
        computed over the module's ``test_data``. The result is then
        restricted to ``self.available_indices``.

        Returns:
            np.ndarray of uncertainty scores.
        """
        if self.uncertainty_module is None:
            # Convert first so list / Series / array outputs are all handled.
            outputs = np.asarray(self.get_output(), dtype=float)
            return -1.8 * np.abs(outputs - 0.5) + 1
        uncertainty = self.uncertainty_module.measure(
            self.uncertainty_module.test_data, 'feature_importance'
        )
        return np.asarray(uncertainty)[self.available_indices]

    def bATE_sampling(self, k):
        """Pick ``k`` items via ``init_centers`` on weighted gradient embeddings.

        Row ``i`` of the gradient embedding is scaled by
        ``self.rev_score()(revenue[i]) * uncertainty[i]``. Seeding is then
        biased toward high-revenue, high-uncertainty items.

        Args:
            k: number of items to select.

        Returns:
            Whatever ``init_centers`` returns. ``query`` uses it to index
            ``self.available_indices``, so it is presumably an index array.

        Raises:
            ValueError: if the uncertainty scores do not match the number of
                embeddings, or there are fewer revenues than embeddings.
        """
        # NOTE(review): assumes a 2-D (n_items, dim) embedding, since each
        # row is iterated element-wise. Confirm against get_grad_embedding.
        grad_embedding = np.asarray(self.get_grad_embedding(), dtype=float)
        uncertainty_score = np.asarray(self.get_uncertainty(), dtype=float)
        revs = np.asarray(self.get_revenue())

        n_items = len(grad_embedding)
        if len(uncertainty_score) != n_items:
            raise ValueError(
                f"uncertainty has {len(uncertainty_score)} entries but there "
                f"are {n_items} gradient embeddings"
            )
        if len(revs) < n_items:
            raise ValueError(
                f"revenue has {len(revs)} entries but there are "
                f"{n_items} gradient embeddings"
            )

        # rev_score() returns a scalar scoring function; build it once.
        rev_fn = self.rev_score()
        rev_weights = np.array([rev_fn(r) for r in revs[:n_items]], dtype=float)

        # Same left-to-right order as emb * rev * unc, applied row-wise.
        weighted = (
            grad_embedding
            * rev_weights[:, np.newaxis]
            * uncertainty_score[:, np.newaxis]
        )
        return init_centers(weighted, k)

    @timer_func
    def query(self, k, model_available=False):
        """Select ``k`` items to inspect and return them as a list of indices.

        Args:
            k: number of items to select.
            model_available: if False, first trains the XGB model, prepares
                the DATE input and trains the DATE model.

        Returns:
            list of entries taken from ``self.available_indices``.
        """
        if not model_available:
            self.train_xgb_model()
            self.prepare_DATE_input()
            self.train_DATE_model()

        if self.get_model().module.performance > self.bate_performance_threshold:
            # DATE is good enough: explore guided by its embeddings.
            chosen = self.bATE_sampling(k)
            print('bATE is used for exploration')
            return self.available_indices[chosen].tolist()

        # DATE is not reliable yet: explore uniformly at random.
        print('random is used for exploration')
        return np.random.choice(self.available_indices, k, replace = False).tolist()