Skip to content

Commit

Permalink
1.6.14
Browse files Browse the repository at this point in the history
  • Loading branch information
Yang Heng committed Dec 25, 2021
1 parent 3c6df74 commit 4809801
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 75 deletions.
45 changes: 0 additions & 45 deletions demos/aspect_polarity_classification/augment.py

This file was deleted.

6 changes: 3 additions & 3 deletions demos/aspect_polarity_classification/run_fast_lsa_deberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
apc_config_english.hidden_dim = 768
apc_config_english.embed_dim = 768
apc_config_english.num_epoch = 25
apc_config_english.log_step = 5
apc_config_english.log_step = 10
apc_config_english.learning_rate = 1e-5
apc_config_english.batch_size = 16
apc_config_english.evaluate_begin = 2
Expand All @@ -51,7 +51,7 @@
Laptop14 = ABSADatasetList.Laptop14
Trainer(config=apc_config_english,
dataset=Laptop14, # train set and test set will be automatically detected
checkpoint_save_mode=1, # =None to avoid save model
checkpoint_save_mode=0, # =None to avoid save model
auto_device=True # automatic choose CUDA or CPU
)

Expand Down Expand Up @@ -104,7 +104,7 @@
apc_config_english.seed = seeds

apc_config_english.cross_validate_fold = -1 # disable cross_validate
#

Laptop14 = ABSADatasetList.Laptop14
Trainer(config=apc_config_english,
dataset=Laptop14, # train set and test set will be automatically detected
Expand Down
36 changes: 18 additions & 18 deletions demos/aspect_polarity_classification/sentiment_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,33 @@

checkpoint_map = available_checkpoints(from_local=False)

examples = [
'The [ASP]battery-life[ASP], and this [ASP]battery[ASP] is ok',
'The [ASP] battery-life [ASP] is bad',
'The [ASP] battery-life [ASP] is good',
'The [ASP] battery-life [ASP] ',
'Strong build though which really adds to its [ASP]durability[ASP] .', # !sent! Positive
'Strong [ASP]build[ASP] though which really adds to its durability . !sent! Positive',
'The [ASP]battery life[ASP] is excellent - 6-7 hours without charging . !sent! Positive',
'I have had my computer for 2 weeks already and it [ASP]works[ASP] perfectly . !sent! Positive',
'And I may be the only one but I am really liking [ASP]Windows 8[ASP] . !sent! Positive',
]
sent_classifier = APCCheckpointManager.get_sentiment_classifier(checkpoint='fast_lsa_s_acc_84.9_f1_82.11.zip',
# examples = [
# 'The [ASP]battery-life[ASP], and this [ASP]battery[ASP] is ok',
# 'The [ASP] battery-life [ASP] is bad',
# 'The [ASP] battery-life [ASP] is good',
# 'The [ASP] battery-life [ASP] ',
# 'Strong build though which really adds to its [ASP]durability[ASP] .', # !sent! Positive
# 'Strong [ASP]build[ASP] though which really adds to its durability . !sent! Positive',
# 'The [ASP]battery life[ASP] is excellent - 6-7 hours without charging . !sent! Positive',
# 'I have had my computer for 2 weeks already and it [ASP]works[ASP] perfectly . !sent! Positive',
# 'And I may be the only one but I am really liking [ASP]Windows 8[ASP] . !sent! Positive',
# ]
sent_classifier = APCCheckpointManager.get_sentiment_classifier(checkpoint='fast_lsa_s_acc_85',
auto_device=True, # Use CUDA if available
)

# text = 'everything is always cooked to perfection , the [ASP]service[ASP] is excellent , the [ASP]decor[ASP] cool and understated . !sent! 1 1'
# sent_classifier.infer(text, print_result=True)

inference_sets = examples
# inference_sets = examples
#
# for ex in examples:
# result = sent_classifier.infer(ex, print_result=True)

for ex in examples:
result = sent_classifier.infer(ex, print_result=True)

inference_sets = ABSADatasetList.English
inference_sets = ABSADatasetList.Laptop14
results = sent_classifier.batch_infer(target_file=inference_sets,
print_result=True,
save_result=True,
ignore_error=False,
)
print(results)
# print(results)
2 changes: 1 addition & 1 deletion pyabsa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# Copyright (C) 2021. All Rights Reserved.

__version__ = '1.6.13a0'
__version__ = '1.6.14'
__name__ = 'pyabsa'

from termcolor import colored
Expand Down
24 changes: 21 additions & 3 deletions pyabsa/core/apc/prediction/sentiment_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def merge_results(self, results):
if final_res and "".join(final_res[-1]['text'].split()) == "".join(result['text'].split()):
final_res[-1]['aspect'].append(result['aspect'])
final_res[-1]['sentiment'].append(result['sentiment'])
final_res[-1]['confidence'].append(result['confidence'])
final_res[-1]['ref_sentiment'].append(result['ref_sentiment'])
final_res[-1]['ref_check'].append(result['ref_check'])
else:
Expand All @@ -212,6 +213,7 @@ def merge_results(self, results):
'text': result['text'].replace(' ', ' '),
'aspect': [result['aspect']],
'sentiment': [result['sentiment']],
'confidence': [result['confidence']],
'ref_sentiment': [result['ref_sentiment']],
'ref_check': [result['ref_check']]
}
Expand Down Expand Up @@ -249,13 +251,16 @@ def _infer(self, save_path=None, print_result=True):
sent = int(i_probs.argmax(axis=-1))
real_sent = int(sample['polarity'][i])

confidence = max(i_probs)

aspect = sample['aspect'][i]
text_raw = sample['text_raw'][i]

results.append({
'text': text_raw,
'aspect': aspect,
'sentiment': sent,
'confidence': confidence,
'ref_sentiment': real_sent,
'ref_check': correct[sent == real_sent] if real_sent != '-999' else '',
})
Expand All @@ -269,11 +274,24 @@ def _infer(self, save_path=None, print_result=True):
for i in range(len(result['aspect'])):
if result['ref_sentiment'][i] != -999:
if result['sentiment'][i] == result['ref_sentiment'][i]:
aspect_info = colored('{} -> {}(ref:{})'.format(result['aspect'][i], result['sentiment'][i], result['ref_sentiment'][i]), 'green')
aspect_info = colored('{} -> {}(confidence:{}, ref:{})'.format(
result['aspect'][i],
result['sentiment'][i],
round(result['confidence'][i], 3),
result['ref_sentiment'][i]),
'green')
else:
aspect_info = colored('{} -> {}(ref:{})'.format(result['aspect'][i], result['sentiment'][i], result['ref_sentiment'][i]), 'red')
aspect_info = colored('{} -> {}(confidence:{}, ref:{})'.format(
result['aspect'][i],
result['sentiment'][i],
round(result['confidence'][i], 3),
result['ref_sentiment'][i]),
'red')

else:
aspect_info = '{} -> {}'.format(result['aspect'][i], result['sentiment'][i])
aspect_info = '{} -> {}(confidence:{})'.format(result['aspect'][i],
round(result['confidence'][i], 3),
result['sentiment'][i])

text_printing = text_printing.replace(result['aspect'][i], aspect_info)
print(text_printing)
Expand Down
8 changes: 6 additions & 2 deletions pyabsa/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,14 +274,18 @@ def save_model(opt, model, tokenizer, save_path):

torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.bert.config.to_json_file(output_config_file)
tokenizer.tokenizer.save_vocabulary(model_output_dir)
if hasattr(tokenizer, 'tokenizer'):
tokenizer.tokenizer.save_vocabulary(model_output_dir)
else:
tokenizer.save_vocabulary(model_output_dir)

else:
raise ValueError('Invalid save_mode: {}'.format(opt.save_mode))
model.to(opt.device)


def check_update_log():
print(colored('check update log at https://github.com/yangheng95/PyABSA/blob/release/release-note.json', 'red'))
print(colored('check release notes at https://github.com/yangheng95/PyABSA/blob/release/release-note.json', 'red'))


def query_remote_version():
Expand Down
5 changes: 4 additions & 1 deletion pyabsa/utils/pyabsa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ def save_args(config, save_path):
def print_args(config, logger=None, mode=0):
args = [key for key in sorted(config.args.keys())]
for arg in args:
logger.info('{0}:{1}\t-->\tCalling Count:{2}'.format(arg, config.args[arg], config.args_call_count[arg]))
if logger:
logger.info('{0}:{1}\t-->\tCalling Count:{2}'.format(arg, config.args[arg], config.args_call_count[arg]))
else:
print('{0}:{1}\t-->\tCalling Count:{2}'.format(arg, config.args[arg], config.args_call_count[arg]))

# activated_args = []
# default_args = []
Expand Down
2 changes: 1 addition & 1 deletion release-note.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"1.6.13a0": {
"1.6.13": {
"1": "Some minor modifications"
},
"1.6.12": {
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
# Choose your license
license='MIT',
install_requires=['findfile>=1.5', 'autocuda>=0.8', 'spacy', 'networkx', 'seqeval', 'update_checker', 'typing_extensions',
'tqdm', 'termcolor', 'gitpython', 'googledrivedownloader', 'transformers>4.5', 'torch>=1.0', 'sentencepiece', 'termcolor'],
'tqdm', 'termcolor', 'gitpython', 'googledrivedownloader', 'transformers>4.5', 'torch>=1.0', 'sentencepiece'],
)

0 comments on commit 4809801

Please sign in to comment.