forked from yangheng95/PyABSA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_sst_tadbert.py
63 lines (55 loc) · 2.18 KB
/
train_sst_tadbert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# -*- coding: utf-8 -*-
# file: train_sst_tadbert.py
# time: 2021/8/5
# author: yangheng <hy345@exeter.ac.uk>
# github: https://github.com/yangheng95
# Copyright (C) 2021. All Rights Reserved.
import os
import random
import warnings
import findfile
# Transfer Experiments and Multitask Experiments
from pyabsa import TCTrainer, TADConfigManager, TCDatasetList, BERTTADModelList, TADTrainer
from pyabsa.functional.dataset.dataset_manager import AdvTCDatasetList, DatasetItem
# Suppress every warning emitted for the rest of the run.
warnings.filterwarnings('ignore')
# A single randomly drawn seed in [1, 10000], wrapped in a list because
# get_config() assigns this list to config.seed. A new value is drawn each
# time the script runs, so runs are not reproducible unless this is pinned.
seeds = [random.randint(1, 10000)]
def get_config():
    """Return the TAD-BERT training configuration used by this script.

    Starts from ``TADConfigManager.get_tad_config_english()`` and overrides
    the hyper-parameters listed below, in order. ``seed`` is set to the
    module-level ``seeds`` list.
    """
    tad_config = TADConfigManager.get_tad_config_english()
    # Overrides are applied in insertion order, one attribute at a time.
    overrides = {
        'model': BERTTADModelList.TADBERT,
        'num_epoch': 1,
        'pretrained_bert': 'bert-base-uncased',
        'patience': 5,
        'evaluate_begin': 0,
        'max_seq_len': 80,
        'log_step': -1,
        'dropout': 0.5,
        'learning_rate': 1e-5,
        'cache_dataset': False,
        'seed': seeds,
        'l2reg': 1e-5,
        'cross_validate_fold': -1,
    }
    for attr_name, attr_value in overrides.items():
        setattr(tad_config, attr_name, attr_value)
    return tad_config
# Train a TAD-BERT model on each dataset named in the list below and load the
# resulting trained model into ``text_classifier``. Only 'SST2BAE' is enabled;
# the other SST-2 variants ('SST2', 'SST2PWWS', 'SST2TextFooler') can be run
# with the same settings by adding their names to the list. After the loop,
# ``dataset`` and ``text_classifier`` refer to the last dataset trained.
for dataset_name in ['SST2BAE']:
    dataset = DatasetItem(dataset_name)
    # get_config() is re-invoked for every dataset rather than shared across
    # iterations.
    text_classifier = TADTrainer(config=get_config(),
                                 dataset=dataset,
                                 checkpoint_save_mode=1,
                                 auto_device=True,
                                 ).load_trained_model()