Adahybrid integrated in query_strategies (#79)
* Seamlessly Integrate AdaHybrid

* adahybrid integrated - no additional advancement yet
Seondong committed Feb 4, 2021
1 parent 365d0a0 commit 95b73ff
Showing 20 changed files with 302 additions and 182 deletions.
216 changes: 85 additions & 131 deletions main.py
@@ -14,8 +14,6 @@
from collections import defaultdict
from datetime import timedelta
import datetime
from query_strategies import badge, bATE, upDATE, gATE, random, DATE, diversity, uncertainty, hybrid, xgb, xgb_lr, ssl_ae, tabnet, deepSAD, multideepSAD

import numpy as np
import torch
import torch.utils.data as Data
@@ -24,66 +22,50 @@
from model.AttTreeEmbedding import Attention, DATEModel
from ranger import Ranger
from utils import torch_threshold, metrics, metrics_active
from query_strategies import random, xgb, xgb_lr, badge, DATE, diversity, bATE, upDATE, gATE, ssl_ae, tabnet, uncertainty, deepSAD, multideepSAD, VIME, adahybrid, hybrid
warnings.filterwarnings("ignore")

class ExpWeights(object):
""" Expenential weight helper, adapted form RP1 paper
sample: sample the weights based on its underlying distribution l
update_dists(feedback): update the underlying distribution with new feedback (should be loss)
"""
def __init__(self,
arms=[0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0],
lr = 2,
window = 20, # we don't use this yet..
epsilon = 0,
decay = 1):

self.arms = arms
self.l = {i:0 for i in range(len(self.arms))}
self.arm = 0
self.value = self.arms[self.arm]
self.error_buffer = []
self.window = window
self.lr = lr
self.epsilon = epsilon
self.decay = decay

self.choices = [self.arm]
self.data = []

def sample(self):

if np.random.uniform() > self.epsilon:
self.p = [np.exp(x) for x in self.l.values()]
self.p /= np.sum(self.p) # normalize to make it a distribution
print(self.p)
self.arm = np.random.choice(range(0,len(self.p)), p=self.p)
else:
self.arm = int(np.random.uniform() * len(self.arms))

self.value = self.arms[self.arm]
self.choices.append(self.arm)

return(self.value)

def update_dists(self, feedback, norm=1):

# Need to normalize score.
# Since this is non-stationary, subtract mean of previous 5.
if not math.isfinite(feedback):
return
self.error_buffer.append(feedback)
self.error_buffer = self.error_buffer[-5:]

feedback -= np.mean(self.error_buffer)
feedback /= norm

print(feedback)
# Selection strategies
def initialize_sampler(samp, args):
"""Initialize selection strategies"""
if samp == 'random':
sampler = random.RandomSampling(args)
elif samp == 'xgb':
sampler = xgb.XGBSampling(args)
elif samp == 'xgb_lr':
sampler = xgb_lr.XGBLRSampling(args)
elif samp == 'badge':
sampler = badge.BadgeSampling(args)
elif samp in ['DATE', 'noupDATE', 'randomupDATE']:
sampler = DATE.DATESampling(args)
elif samp == 'diversity':
sampler = diversity.DiversitySampling(args)
elif samp == 'bATE':
sampler = bATE.bATESampling(args)
elif samp == 'upDATE':
sampler = upDATE.upDATESampling(args)
elif samp == 'gATE':
sampler = gATE.gATESampling(args)
elif samp == 'ssl_ae':
sampler = ssl_ae.SSLAutoencoderSampling(args)
elif samp == 'tabnet':
sampler = tabnet.TabnetSampling(args)
elif samp == 'deepSAD': # check
sampler = deepSAD.deepSADSampling(args)
elif samp == 'multideepSAD':
sampler = multideepSAD.multideepSADSampling(args)
elif samp == 'VIME':
sampler = VIME.VIMESampling(args)
elif samp == 'hybrid':
sampler = hybrid.HybridSampling(args)
elif samp == 'adahybrid':
sampler = adahybrid.AdaHybridSampling(args)
else:
print('Make sure the sampling strategy is listed in the argument --sampling')
sampler.set_name(samp)
return sampler

self.l[self.arm] *= self.decay
self.l[self.arm] -= self.lr * feedback/max(self.p[self.arm], 1e-16)

self.data.append(feedback)
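
For reference, a minimal sketch of how this helper is driven, mirroring the week-loop code further below (ada_lr and norm_precision are names taken from that code):

weight_sampler = ExpWeights(lr=ada_lr)           # arms default to 0.0, 0.1, ..., 1.0
weight = weight_sampler.sample()                 # EXP3-style draw of a mixing weight
weights = [weight, 1 - weight]                   # split between the two subsamplers
# ... run one week of selection and compute norm_precision ...
weight_sampler.update_dists(1 - norm_precision)  # feed back the loss; a low loss reinforces the chosen arm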

def make_logger(curr_time, name=None):
""" Initialize loggers, log files are saved under the ./intermediary/logs directory
@@ -132,8 +114,11 @@ def inspection_plan(rate_init, rate_final, numWeeks, option):
first_half = np.linspace(rate_init, rate_final, 10)
second_half = np.linspace(rate_final, rate_final, numWeeks - len(first_half))
return np.concatenate((first_half, second_half))
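
A worked example of the branch shown above, with hypothetical rates: inspection_plan(10, 5, 15, option) returns a length-15 plan,

# first_half:  np.linspace(10, 5, 10) -> 10.0, 9.44, 8.89, ..., 5.56, 5.0
# second_half: np.linspace(5, 5, 5)   -> 5.0 for the remaining five weeks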





if __name__ == '__main__':

curr_time = str(round(time.time(),3))
@@ -252,7 +237,7 @@ def inspection_plan(rate_init, rate_final, numWeeks, option):

output_file = "./results/performances/results-" + args.output + '-' + samp + '-' + subsamps + '-' + str(final_inspection_rate) + ".csv"
with open(output_file, 'a') as ff:
output_metric_name = ['runID', 'data', 'num_train','num_valid','num_test','num_select','num_inspected','num_uninspected','num_test_illicit','test_illicit_rate', 'upper_bound_precision', 'upper_bound_recall','upper_bound_rev', 'sampling', 'initial_inspection_rate', 'current_inspection_rate', 'final_inspection_rate', 'inspection_rate_option', 'mode', 'subsamplings', 'weights','unc_mode', 'train_start', 'valid_start', 'test_start', 'test_end', 'numWeek', 'precision', 'recall', 'revenue', 'norm-precision', 'norm-recall', 'norm-revenue', 'save']
output_metric_name = ['runID', 'data', 'num_train','num_valid','num_test','num_select','num_inspected','num_uninspected','num_test_illicit','test_illicit_rate', 'upper_bound_precision', 'upper_bound_recall','upper_bound_rev', 'sampling', 'initial_inspection_rate', 'current_inspection_rate', 'final_inspection_rate', 'inspection_rate_option', 'mode', 'subsamplings', 'initial_weights', 'current_weights', 'unc_mode', 'train_start', 'valid_start', 'test_start', 'test_end', 'numWeek', 'precision', 'recall', 'revenue', 'norm-precision', 'norm-recall', 'norm-revenue', 'save']
print(",".join(output_metric_name),file=ff)

path = None
@@ -269,7 +254,20 @@ def inspection_plan(rate_init, rate_final, numWeeks, option):
confirmed_inspection_plan = inspection_plan(initial_inspection_rate, final_inspection_rate, numWeeks, inspection_rate_option)
logger.info('Inspection rate for testing periods: %s', confirmed_inspection_plan)



if samp in ['hybrid', 'adahybrid']:
subsamplings = args.subsamplings
initial_weights = [float(weight) for weight in args.weights.split("/")]
final_weights = initial_weights
else:
subsamplings = '-'
initial_weights = '-'
final_weights = '-'

# Initialize a sampler (kept outside the week loop since we do not change the sampler every week)
# NOTE: If this were inside the week loop, a new sampler would be initialized every week, which would also reset the parameters in the sampler.
sampler = initialize_sampler(samp, args)
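
For illustration, with the hypothetical arguments --subsamplings xgb/DATE --weights 0.9/0.1, the setup above yields:

subsamplings = 'xgb/DATE'
initial_weights = [0.9, 0.1]  # parsed from "0.9/0.1"
final_weights = [0.9, 0.1]    # carried forward until the sampler reports updated weights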

# Customs selection simulation for long term (if test_length = 7 days, simulate for numWeeks)
for i in range(numWeeks):

@@ -286,72 +284,38 @@ def inspection_plan(rate_init, rate_final, numWeeks, option):
else:
data.offset = data.test.index[0]
current_inspection_rate = confirmed_inspection_plan[i] # ToDo: Add multiple decaying strategy
print(i, current_inspection_rate)
print(f'Test episode: #{i}, Current inspection rate: {current_inspection_rate}')
logger.info('%s, %s', data.train_lab.shape, data.test.shape)

# Initialize unceratinty module for some cases
# Initialize uncertainty module for some cases
if unc_mode == 'self-supervised':
if samp in ['bATE', 'diversity', 'hybrid', 'upDATE', 'gATE', 'adahybrid']:

if uncertainty_module is None :
uncertainty_module = uncertainty.Uncertainty(data.train_lab, './uncertainty_models/')
uncertainty_module.train()
uncertainty_module.test_data = data.test

num_samples = int(len(data.test)*current_inspection_rate/100)

# Selection strategies
def initialize_sampler(samp):
"""Initialize selection strategies"""
if samp == 'random':
sampler = random.RandomSampling(data, args)
elif samp == 'xgb':
sampler = xgb.XGBSampling(data, args)
elif samp == 'xgb_lr':
sampler = xgb_lr.XGBLRSampling(data, args)
elif samp == 'badge':
sampler = badge.BadgeSampling(data, args)
elif samp in ['DATE', 'noupDATE', 'randomupDATE']:
sampler = DATE.DATESampling(data, args)
elif samp == 'diversity':
sampler = diversity.DiversitySampling(data, args, uncertainty_module)
elif samp == 'bATE':
sampler = bATE.bATESampling(data, args, uncertainty_module)
elif samp == 'upDATE':
sampler = upDATE.upDATESampling(data, args, uncertainty_module)
elif samp == 'gATE':
sampler = gATE.gATESampling(data, args, uncertainty_module)
elif samp == 'ssl_ae':
sampler = ssl_ae.SSLAutoencoderSampling(data, args)
elif samp == 'tabnet':
sampler = tabnet.TabnetSampling(data, args)
elif samp == 'deepSAD': # check
sampler = deepSAD.deepSADSampling(data, args)
elif samp == 'multideepSAD':
sampler = multideepSAD.multideepSADSampling(data, args)
else:
print('Make sure the sampling strategy is listed in the argument --sampling')
return sampler

if samp not in ['hybrid', 'adahybrid']:
sampler = initialize_sampler(samp)
# Retrieve subsampler weights from the previous week, for hybrid models
if samp in ['hybrid', 'adahybrid']:
try:
final_weights = sampler.get_weights()
except NameError:
pass # use the previously defined final_weights (= initial_weights)

# If you need to update the sampler every week, you can initialize it here instead.

# set uncertainty module
sampler.set_uncertainty_module(uncertainty_module)

# set previous weeks' weights, for hybrid models
if samp in ['hybrid', 'adahybrid']:
sampler.set_weights(final_weights)

# set data to sampler
sampler.set_data(data)

if samp == 'hybrid':
subsamplers = [initialize_sampler(samp) for samp in args.subsamplings.split("/")]
weights = [float(weight) for weight in args.weights.split("/")]
sampler = hybrid.HybridSampling(data, args, subsamplers, weights)

if samp == 'adahybrid':
subsamplers = [initialize_sampler(samp) for samp in args.subsamplings.split("/")]
# TODO: Ideally, it should support multiple strategies.
assert(len(subsamplers) == 2)
if i == 0:
weight_sampler = ExpWeights(lr = ada_lr)
weight = weight_sampler.sample()
print(weight)
weights = [weight, 1 - weight]
sampler = hybrid.HybridSampling(data, args, subsamplers, weights)

# If it fails to query, try one more time. If it fails again, do random sampling.
try:
chosen = sampler.query(num_samples)
@@ -361,7 +325,7 @@ def initialize_sampler(samp):
traceback.print_exc()


logger.info("%s, %s, %s", len(set(chosen)), len(chosen), num_samples)
logger.info("# of unique queried item: %s, # of queried item: %s, # of samples to be queried: %s", len(set(chosen)), len(chosen), num_samples)
assert len(set(chosen)) == num_samples


@@ -392,25 +356,18 @@ def initialize_sampler(samp):
logger.info(f'Metrics Active DATE:\n Pr@{current_inspection_rate}:{round(active_precisions, 4)}, Re@{current_inspection_rate}:{round(active_recalls, 4)} Rev@{current_inspection_rate}:{round(active_revenues, 4)}')

with open(output_file, 'a') as ff:
if samp == 'hybrid':
subsamplings = args.subsamplings
weights = args.weights
elif samp == 'adahybrid':
subsamplings = args.subsamplings
weights = '/'.join([str(weight) for weight in weights])
else:
subsamplings = '-'
weights = '-'

upper_bound_precision = min(100*np.mean(data.test_cls_label)/current_inspection_rate, 1)
upper_bound_recall = min(current_inspection_rate/np.mean(data.test_cls_label)/100, 1)
upper_bound_revenue = sum(sorted(data.test_reg_label, reverse=True)[:len(chosen)]) / sum(data.test_reg_label)
norm_precision = active_precisions/upper_bound_precision
norm_recall = active_recalls/upper_bound_recall
norm_revenue = active_revenues/upper_bound_revenue

if samp in ['hybrid', 'adahybrid']:
initial_weights_str = '/'.join([str(weight) for weight in initial_weights])
final_weights_str = '/'.join([str(weight) for weight in final_weights])

output_metric = [curr_time, chosen_data, len(data.train_lab), len(data.valid_lab), len(data.test), len(chosen), len(inspected_imports), len(uninspected_imports), np.sum(data.test_cls_label), np.mean(data.test_cls_label), upper_bound_precision, upper_bound_recall, upper_bound_revenue, samp, initial_inspection_rate, current_inspection_rate, final_inspection_rate, inspection_rate_option, mode, subsamplings, weights, unc_mode, train_start_day.strftime('%y-%m-%d'), valid_start_day.strftime('%y-%m-%d'), test_start_day.strftime('%y-%m-%d'), test_end_day.strftime('%y-%m-%d'), i+1, round(active_precisions,4), round(active_recalls,4), round(active_revenues,4), round(norm_precision,4), round(norm_recall,4), round(norm_revenue,4), save]
output_metric = [curr_time, chosen_data, len(data.train_lab), len(data.valid_lab), len(data.test), len(chosen), len(inspected_imports), len(uninspected_imports), np.sum(data.test_cls_label), np.mean(data.test_cls_label), upper_bound_precision, upper_bound_recall, upper_bound_revenue, samp, initial_inspection_rate, current_inspection_rate, final_inspection_rate, inspection_rate_option, mode, subsamplings, initial_weights_str, final_weights_str, unc_mode, train_start_day.strftime('%y-%m-%d'), valid_start_day.strftime('%y-%m-%d'), test_start_day.strftime('%y-%m-%d'), test_end_day.strftime('%y-%m-%d'), i+1, round(active_precisions,4), round(active_recalls,4), round(active_revenues,4), round(norm_precision,4), round(norm_recall,4), round(norm_revenue,4), save]

output_metric = list(map(str,output_metric))
logger.debug(output_metric)
@@ -441,11 +398,8 @@ def initialize_sampler(samp):

# Review needed: Check if the weights are updated as desired.
if samp == 'adahybrid':
print(weight_sampler.p)
weight_sampler.update_dists(1-norm_precision)
logger.info(f'Ada distribution: {weight_sampler.p}')
logger.info(f'Ada arm: {weight_sampler.value}')
logger.info(f'Feedbacks: {weight_sampler.data}')
sampler.update(norm_precision)
# pdb.set_trace()
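
Assuming the ExpWeights machinery now lives inside adahybrid.AdaHybridSampling, update presumably does roughly what the removed inline code did. A sketch under that assumption (weight_sampler is a hypothetical attribute name; set_weights appears in the new week loop):

def update(self, norm_precision):
    # sketch: feed back the loss 1 - norm_precision, as the old inline code did
    self.weight_sampler.update_dists(1 - norm_precision)
    w = self.weight_sampler.sample()  # draw next week's mixing weight
    self.set_weights([w, 1 - w])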

# Renew valid & test period & dataset
if i == numWeeks - 1:
6 changes: 3 additions & 3 deletions query_strategies/DATE.py
@@ -19,7 +19,7 @@
from .strategy import Strategy
from xgboost import XGBClassifier
from torch.utils.data import DataLoader
from utils import find_best_threshold,process_leaf_idx, torch_threshold, metrics
from utils import find_best_threshold, process_leaf_idx, torch_threshold, metrics
from model.AttTreeEmbedding import Attention, DATEModel
from model.utils import FocalLoss

@@ -29,8 +29,8 @@ class DATESampling(Strategy):
""" DATE strategy: Using DATE classification probability to measure fraudness of imports
Reference: DATE: Dual Attentive Tree-aware Embedding for Customs Fraud Detection; KDD 2020 """

def __init__(self, data, args):
super(DATESampling,self).__init__(data, args)
def __init__(self, args):
super(DATESampling,self).__init__(args)
self.model_name = "DATE"
self.model_path = "./intermediary/saved_models/%s-%s.pkl" % (self.model_name,self.args.identifier)
self.batch_size = args.batch_size
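
With data dropped from the constructor, the weekly call pattern in main.py becomes (a sketch mirroring the main-loop changes above):

sampler = DATE.DATESampling(args)                   # no data at construction time
sampler.set_uncertainty_module(uncertainty_module)  # as in the new week loop
sampler.set_data(data)                              # data is injected each week
chosen = sampler.query(num_samples)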
1 change: 1 addition & 0 deletions query_strategies/VIME
Submodule VIME added at 996c58
Empty file added query_strategies/VIME.py
Empty file.
