-
Notifications
You must be signed in to change notification settings - Fork 3.2k
/
run_tf_squad.py
675 lines (588 loc) · 33.6 KB
/
run_tf_squad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import subprocess
import time
import argparse
import json
import logging
import tensorflow as tf
import horovod.tensorflow as hvd
from horovod.tensorflow.compression import Compression
from gpu_affinity import set_affinity
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
from tqdm import tqdm
import dllogger
from utils import is_main_process, format_step, get_rank, get_world_size, log
from configuration import ElectraConfig
from modeling import TFElectraForQuestionAnswering
from tokenization import ElectraTokenizer
from optimization import create_optimizer
from squad_utils import SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features, \
SquadResult, RawResult, get_answers
# Names of the pretrained ELECTRA checkpoints advertised in the --electra_model help text.
# Expands to "google/electra-{small,base,large}-generator" followed by
# "google/electra-{small,base,large}-discriminator", in that order.
# See all ELECTRA models at https://huggingface.co/models?filter=electra
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/electra-{}-{}".format(_size, _role)
    for _role in ("generator", "discriminator")
    for _size in ("small", "base", "large")
]
def _str2bool(value):
    """Convert a command-line token to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False") as True. This converter accepts the usual spellings instead. The
    empty string still maps to False, which is what ``bool("")`` returned.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got {!r}".format(value))


def parse_args():
    """Build the command-line parser, parse ``sys.argv`` and validate the result.

    Returns:
        argparse.Namespace: the parsed arguments.

    Raises:
        ValueError: if ``--do_train`` is not set and no usable
            ``--init_checkpoint`` was given (missing, empty, or the literal
            string 'None').
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--electra_model", default=None, type=str, required=True,
                        help="Model selected in the list: " + ", ".join(TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST))
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="Path to dataset.")
    parser.add_argument("--output_dir", default=".", type=str, required=True,
                        help="The output directory where the model checkpoints and predictions will be written.")
    parser.add_argument("--init_checkpoint",
                        default=None,
                        type=str,
                        help="The checkpoint file from pretraining")

    # Other parameters
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_predict", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to use evaluate accuracy of predictions")
    parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument("--predict_file", default=None, type=str,
                        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
    parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
    parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.")
    parser.add_argument("--learning_rate", default=1e-4, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay_rate", default=0.01, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--layerwise_lr_decay", default=0.8, type=float,
                        help="The layerwise learning rate decay. Shallower layers have lower learning rates.")
    parser.add_argument("--num_train_epochs", default=3, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1.0, type=float,
                        help="Total number of training steps to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
                             "of training.")
    parser.add_argument("--max_seq_length", default=384, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--vocab_file", default=None, type=str,
                        help="Path to vocabulary file use for tokenization")
    parser.add_argument("--ci", action="store_true", help="true if running on CI")
    parser.add_argument(
        "--joint_head",
        default=True,
        # type=bool would turn "--joint_head False" into True; parse the text instead.
        type=_str2bool,
        help="Jointly predict the start and end positions",
    )
    parser.add_argument(
        "--beam_size",
        default=4,
        type=int,
        help="Beam size when doing joint predictions",
    )
    parser.add_argument("--n_best_size", default=20, type=int,
                        help="The total number of n-best predictions to generate in the nbest_predictions.json "
                             "output file.")
    parser.add_argument("--max_answer_length", default=30, type=int,
                        help="The maximum length of an answer that can be generated. This is needed because the start "
                             "and end predictions are not conditioned on one another.")
    parser.add_argument("--verbose_logging", action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
    )
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    # A string default (from the environment) is still converted by argparse through type=int.
    parser.add_argument("--local_rank",
                        type=int,
                        default=os.getenv('LOCAL_RANK', -1),
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Automatic mixed precision training")
    parser.add_argument('--fp16_all_reduce',
                        action='store_true',
                        help="Whether to use 16-bit all reduce")
    parser.add_argument('--xla',
                        action='store_true',
                        help="Whether to use XLA")
    parser.add_argument('--version_2_with_negative',
                        action='store_true',
                        help='If true, the SQuAD examples contain some that do not have an answer.')
    parser.add_argument('--null_score_diff_threshold',
                        type=float, default=0.0,
                        help="If null_score - best_non_null is greater than the threshold predict null.")
    parser.add_argument('--log_freq',
                        type=int, default=50,
                        help='frequency of logging loss.')
    parser.add_argument('--json-summary', type=str, default="results/dllogger.json",
                        help='If provided, the json summary will be written to the specified file.')
    parser.add_argument("--eval_script",
                        help="Script to evaluate squad predictions",
                        default="evaluate.py",
                        type=str)
    parser.add_argument("--use_env",
                        action='store_true',
                        help="Whether to read local rank from ENVVAR")
    # The two help texts below used to describe the opposite of what the flags do.
    parser.add_argument('--skip_checkpoint',
                        default=False,
                        action='store_true',
                        help="Whether to skip saving checkpoints")
    parser.add_argument('--disable-progress-bar',
                        default=False,
                        action='store_true',
                        help='Disable tqdm progress bar')
    parser.add_argument("--skip_cache",
                        default=False,
                        action='store_true',
                        help="Whether to skip caching features")
    parser.add_argument("--cache_dir",
                        default=None,
                        type=str,
                        help="Location to cache train feaures. Will default to the dataset direct")

    args = parser.parse_args()

    # Without training there are no fresh weights, so a checkpoint is mandatory.
    if not args.do_train and (not args.init_checkpoint or args.init_checkpoint == 'None'):
        raise ValueError("Checkpoint is required if do_train is not set")

    return args
def get_dataset_from_features(features, batch_size, drop_remainder=True, ngpu=8, mode="train", v2=False):
    """Turn a list of SQuAD features into a batched ``tf.data.Dataset``.

    Each element of the dataset is the 8-tuple
    (input_ids, attention_mask, token_type_ids, start_position, end_position,
    cls_index, p_mask, is_impossible), read from the same-named attributes of
    every feature.

    Args:
        features: iterable of feature objects exposing the attributes above.
        batch_size: number of examples per batch; also used to size the shuffle
            buffer (3x) and the prefetch buffer.
        drop_remainder: forwarded to ``Dataset.batch``.
        ngpu: when greater than 1 the dataset is sharded by
            ``get_world_size()`` / ``get_rank()``.
        mode: the dataset is shuffled only when this equals "train".
        v2: accepted for interface compatibility; not used in this function.

    Returns:
        The sharded (optional), shuffled (optional), batched and prefetched dataset.
    """
    def _column(attribute, dtype):
        # One tensor per feature attribute, stacked across all features.
        return tf.convert_to_tensor([getattr(feature, attribute) for feature in features], dtype=dtype)

    columns = (
        _column("input_ids", tf.int64),
        _column("attention_mask", tf.int64),
        _column("token_type_ids", tf.int64),
        _column("start_position", tf.int64),
        _column("end_position", tf.int64),
        _column("cls_index", tf.int64),
        _column("p_mask", tf.float32),
        _column("is_impossible", tf.float32),
    )

    dataset = tf.data.Dataset.from_tensor_slices(columns)

    # Give each worker its own slice of the examples.
    if ngpu > 1:
        dataset = dataset.shard(get_world_size(), get_rank())

    if mode == "train":
        dataset = dataset.shuffle(batch_size * 3)

    batched = dataset.batch(batch_size, drop_remainder=drop_remainder)
    return batched.prefetch(batch_size)
@tf.function
def train_step(model, inputs, loss, amp, opt, init, v2=False, loss_class=None, fp16=False, clip_norm=1.0):
    """Run one optimisation step and return the (unscaled) loss.

    Args:
        model: callable returning at least (start_logits, end_logits, cls_logits).
        inputs: 8-element batch (input_ids, input_mask, segment_ids,
            start_positions, end_positions, cls_index, p_mask, is_impossible).
        loss: loss callable applied to the start and end logits.
        amp: when true, the loss is scaled and gradients unscaled through ``opt``.
        opt: optimizer; must provide ``get_scaled_loss`` /
            ``get_unscaled_gradients`` when ``amp`` is true.
        init: when true, model and optimizer variables are broadcast from rank 0
            after the update.
        v2: when true, the ``is_impossible`` classification loss is added.
        loss_class: loss callable for the classification head (used when ``v2``).
        fp16: when true, Horovod compresses gradients with ``Compression.fp16``.
        clip_norm: global-norm threshold for gradient clipping.

    Returns:
        The loss value before AMP scaling, with gradients stopped.
    """
    with tf.GradientTape() as tape:
        (input_ids, input_mask, segment_ids, start_positions, end_positions,
         cls_index, p_mask, is_impossible) = inputs

        # The answerability label is only fed to the model for SQuAD v2.
        if not v2:
            is_impossible = None

        model_outputs = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=segment_ids,
            start_positions=start_positions,
            end_positions=end_positions,
            cls_index=cls_index,
            p_mask=p_mask,
            is_impossible=is_impossible,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            training=True,
        )
        start_logits, end_logits, cls_logits = model_outputs[0:3]

        # Labels with more than one dimension get their trailing axis squeezed away.
        if len(start_positions.shape) > 1:
            start_positions = tf.squeeze(start_positions, axis=-1, name="squeeze_start_positions")
        if len(end_positions.shape) > 1:
            end_positions = tf.squeeze(end_positions, axis=-1, name="squeeze_end_positions")
        if v2 and is_impossible is not None and cls_logits is not None and len(is_impossible.shape) > 1:
            is_impossible = tf.squeeze(is_impossible, axis=-1, name="squeeze_is_impossible")

        # Clamp positions into [0, sequence_length].
        # NOTE(review): the original comment says out-of-range positions are "ignored",
        # but no ignore index is passed to the loss here - confirm intended behaviour.
        ignored_index = start_logits.shape[1]
        start_positions = tf.clip_by_value(start_positions, 0, ignored_index, name="clip_start_positions")
        end_positions = tf.clip_by_value(end_positions, 0, ignored_index, name="clip_end_positions")

        # Losses are always computed on float32 logits.
        start_loss = loss(y_true=start_positions, y_pred=tf.cast(start_logits, tf.float32))
        end_loss = loss(y_true=end_positions, y_pred=tf.cast(end_logits, tf.float32))
        loss_value = (start_loss + end_loss) / 2

        if v2:
            cls_loss_value = loss_class(y_true=is_impossible, y_pred=tf.cast(cls_logits, tf.float32))
            loss_value += cls_loss_value * 0.5

        # Keep the reported loss free of AMP loss scaling.
        unscaled_loss = tf.stop_gradient(loss_value)
        if amp:
            loss_value = opt.get_scaled_loss(loss_value)

    # Wrap the tape so gradients are averaged across Horovod workers.
    compression = Compression.fp16 if fp16 else Compression.none
    tape = hvd.DistributedGradientTape(tape, sparse_as_dense=True, compression=compression)
    gradients = tape.gradient(loss_value, model.trainable_variables)
    if amp:
        gradients = opt.get_unscaled_gradients(gradients)
    gradients, _ = tf.clip_by_global_norm(gradients, clip_norm=clip_norm)
    opt.apply_gradients(zip(gradients, model.trainable_variables))

    # Broadcast after the first update, once the optimizer variables exist.
    if init:
        hvd.broadcast_variables(model.variables, root_rank=0)
        hvd.broadcast_variables(opt.variables(), root_rank=0)

    return unscaled_loss
@tf.function
def infer_step(model, input_ids,
               attention_mask=None,
               token_type_ids=None,
               cls_index=None,
               p_mask=None,
               position_ids=None,
               head_mask=None,
               inputs_embeds=None,
               training=False,
               ):
    """Call ``model`` on one batch inside a ``tf.function`` and return its outputs.

    Every keyword argument is forwarded to the model unchanged; ``training``
    defaults to False.
    """
    forwarded = dict(
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        cls_index=cls_index,
        p_mask=p_mask,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        training=training,
    )
    outputs = model(input_ids, **forwarded)
    return outputs
def _mean_latency_ms(sorted_latencies, fraction=1.0):
    """Return the mean, in milliseconds, of the fastest ``fraction`` of latencies.

    ``sorted_latencies`` must be a non-empty ascending list of seconds. At least
    one sample is always used, so short lists cannot cause a division by zero.
    """
    count = max(1, int(len(sorted_latencies) * fraction))
    return sum(sorted_latencies[:count]) / count * 1000


def main():
    """Fine-tune and/or evaluate ELECTRA on SQuAD according to the CLI arguments.

    Steps: initialise Horovod and logging, configure TensorFlow, load the model,
    tokenizer and (cached) features, then for every epoch optionally train and
    optionally predict / evaluate. Prediction, evaluation and result files are
    produced on the main process only.
    """
    args = parse_args()

    hvd.init()
    set_affinity(hvd.local_rank())

    if is_main_process():
        log("Running total processes: {}".format(get_world_size()))
    log("Starting process: {}".format(get_rank()))

    # Only the main process gets real dllogger backends.
    if is_main_process():
        dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                                           filename=args.json_summary),
                                dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE, step_format=format_step)])
    else:
        dllogger.init(backends=[])

    dllogger.metadata("exact_match", {"unit": None})
    dllogger.metadata("F1", {"unit": None})
    dllogger.metadata("inference_sequences_per_second", {"unit": "sequences/s"})
    dllogger.metadata("training_sequences_per_second", {"unit": "sequences/s"})

    tf.random.set_seed(args.seed)
    dllogger.log(step="PARAMETER", data={"SEED": args.seed})

    # script parameters
    BATCH_SIZE = args.train_batch_size
    EVAL_BATCH_SIZE = args.predict_batch_size
    USE_XLA = args.xla
    USE_AMP = args.amp
    EPOCHS = args.num_train_epochs

    if not args.do_train:
        EPOCHS = args.num_train_epochs = 1
        log("Since running inference only, setting args.num_train_epochs to 1")

    if not os.path.exists(args.output_dir) and is_main_process():
        os.makedirs(args.output_dir)

    # TensorFlow configuration: one visible GPU per process, selected by local rank.
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
    tf.config.optimizer.set_jit(USE_XLA)
    if args.amp:
        policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16", loss_scale="dynamic")
        tf.keras.mixed_precision.experimental.set_policy(policy)
        print('Compute dtype: %s' % policy.compute_dtype)  # Compute dtype: float16
        print('Variable dtype: %s' % policy.variable_dtype)  # Variable dtype: float32

    if is_main_process():
        log("***** Loading tokenizer and model *****")
    # Load tokenizer and model from pretrained model/vocabulary.
    electra_model = args.electra_model
    config = ElectraConfig.from_pretrained(electra_model, cache_dir=args.cache_dir)
    config.update({"amp": args.amp})
    if args.vocab_file is None:
        tokenizer = ElectraTokenizer.from_pretrained(electra_model, cache_dir=args.cache_dir)
    else:
        tokenizer = ElectraTokenizer(
            vocab_file=args.vocab_file,
            do_lower_case=args.do_lower_case)

    model = TFElectraForQuestionAnswering.from_pretrained(electra_model, config=config, cache_dir=args.cache_dir, args=args)

    if is_main_process():
        log("***** Loading dataset *****")
    # Load data
    processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
    train_examples = processor.get_train_examples(args.data_dir) if args.do_train else None
    dev_examples = processor.get_dev_examples(args.data_dir) if args.do_predict else None

    if is_main_process():
        log("***** Loading features *****")
    # Load cached features
    squad_version = '2.0' if args.version_2_with_negative else '1.1'
    if args.cache_dir is None:
        args.cache_dir = args.data_dir
    # The model name (positional argument 0) is not referenced by the templates.
    # [-1] is used instead of [1] so a name without "/" no longer raises IndexError;
    # the resulting file names are unchanged.
    model_short_name = electra_model.split("/")[-1]
    cached_train_features_file = args.cache_dir.rstrip('/') + '/' + 'TF2_train-v{4}.json_{1}_{2}_{3}'.format(
        model_short_name, str(args.max_seq_length), str(args.doc_stride),
        str(args.max_query_length), squad_version)
    cached_dev_features_file = args.cache_dir.rstrip('/') + '/' + 'TF2_dev-v{4}.json_{1}_{2}_{3}'.format(
        model_short_name, str(args.max_seq_length), str(args.doc_stride),
        str(args.max_query_length), squad_version)

    try:
        # NOTE(security): pickle.load executes arbitrary code from the file; only
        # point --cache_dir / --data_dir at trusted locations.
        train_features = []
        dev_features = []
        # Only open the cache files that this run actually needs.
        if args.do_train:
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        if args.do_predict:
            with open(cached_dev_features_file, "rb") as reader:
                dev_features = pickle.load(reader)
    except Exception as err:  # missing or unreadable cache: rebuild (best effort)
        log("Could not load cached features ({}); rebuilding".format(err))
        train_features = (  # TODO: (yy) do on rank 0?
            squad_convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=True,
                return_dataset="",
            )
            if args.do_train
            else []
        )
        dev_features = (
            squad_convert_examples_to_features(
                examples=dev_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=False,
                return_dataset="",
            )
            if args.do_predict
            else []
        )
        # Dump Cached features
        if not args.skip_cache and is_main_process():
            if args.do_train:
                log("***** Building Cache Files: {} *****".format(cached_train_features_file))
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)
            if args.do_predict:
                log("***** Building Cache Files: {} *****".format(cached_dev_features_file))
                with open(cached_dev_features_file, "wb") as writer:
                    pickle.dump(dev_features, writer)

    len_train_features = len(train_features)
    total_train_steps = int((len_train_features * EPOCHS / BATCH_SIZE) / get_world_size()) + 1
    train_steps_per_epoch = int((len_train_features / BATCH_SIZE) / get_world_size()) + 1
    len_dev_features = len(dev_features)
    total_dev_steps = int((len_dev_features / EVAL_BATCH_SIZE)) + 1

    train_dataset = get_dataset_from_features(train_features, BATCH_SIZE,
                                              v2=args.version_2_with_negative) if args.do_train else []
    # ngpu=1 disables sharding: the dev set is consumed in full by the process that evaluates.
    dev_dataset = get_dataset_from_features(dev_features, EVAL_BATCH_SIZE, drop_remainder=False, ngpu=1, mode="dev",
                                            v2=args.version_2_with_negative) if args.do_predict else []

    opt = create_optimizer(init_lr=args.learning_rate, num_train_steps=total_train_steps,
                           num_warmup_steps=int(args.warmup_proportion * total_train_steps),
                           weight_decay_rate=args.weight_decay_rate,
                           layerwise_lr_decay=args.layerwise_lr_decay,
                           n_transformer_layers=model.num_hidden_layers)
    if USE_AMP:
        # loss scaling is currently required when using mixed precision
        opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, "dynamic")

    # Define loss function
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    loss_class = tf.keras.losses.BinaryCrossentropy(
        from_logits=True,
        name='binary_crossentropy'
    )
    metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
    model.compile(optimizer=opt, loss=loss, metrics=[metric])
    train_loss_results = []

    if args.do_train and is_main_process():
        log("***** Running training *****")
        log(" Num examples = ", len_train_features)
        log(" Num Epochs = ", args.num_train_epochs)
        log(" Instantaneous batch size per GPU = ", args.train_batch_size)
        log(
            " Total train batch size (w. parallel, distributed & accumulation) = ",
            args.train_batch_size
            * get_world_size(),
        )
        log(" Total optimization steps =", total_train_steps)

    total_train_time = 0
    latency = []
    # Stay None unless an evaluation actually produced scores (checked before the final summary).
    exact_match = None
    f1 = None
    for epoch in range(EPOCHS):
        if args.do_train:
            epoch_loss_avg = tf.keras.metrics.Mean()
            epoch_perf_avg = tf.keras.metrics.Mean()
            epoch_start = time.time()

            epoch_iterator = tqdm(train_dataset, total=train_steps_per_epoch, desc="Iteration", mininterval=5,
                                  disable=not is_main_process())
            for step, inputs in enumerate(epoch_iterator):
                # Stop this epoch once the global step count exceeds max_steps (if > 0).
                if args.max_steps > 0 and (epoch * train_steps_per_epoch + step) > args.max_steps:
                    break
                iter_start = time.time()
                # Optimize the model. Gradient all-reduce is compressed to fp16 under
                # AMP (as before) and now also when --fp16_all_reduce is passed; the
                # flag was previously parsed but ignored.
                loss_value = train_step(model, inputs, loss, USE_AMP, opt, (step == 0 and epoch == 0),
                                        v2=args.version_2_with_negative, loss_class=loss_class,
                                        fp16=USE_AMP or args.fp16_all_reduce)
                # Introduce CPU-GPU sync for training perf computation
                loss_numpy = loss_value.numpy()
                epoch_perf_avg.update_state(1. * BATCH_SIZE / (time.time() - iter_start))
                if step % args.log_freq == 0:
                    if is_main_process():
                        log("\nEpoch: {:03d}, Step:{:6d}, Loss:{:12.8f}, Perf:{:5.0f}, loss_scale:{}, opt_step:{}".format(
                            epoch, step, loss_value,
                            epoch_perf_avg.result() * get_world_size(), opt.loss_scale if config.amp else 1,
                            int(opt.iterations)))
                    dllogger.log(step=(epoch, step,), data={"step_loss": float(loss_numpy),
                                                            "train_perf": float(epoch_perf_avg.result().numpy() * get_world_size())})

                # Track progress
                epoch_loss_avg.update_state(loss_value)  # Add current batch loss

            # End epoch
            train_loss_results.append(epoch_loss_avg.result())
            total_train_time += float(time.time() - epoch_start)
            # Summarize and save checkpoint at the end of each epoch
            if is_main_process():
                dllogger.log(step=tuple(), data={"e2e_train_time": total_train_time,
                                                 "training_sequences_per_second": float(
                                                     epoch_perf_avg.result().numpy() * get_world_size()),
                                                 "final_loss": float(epoch_loss_avg.result().numpy())})

            if not args.skip_checkpoint:
                if args.ci:
                    checkpoint_name = "{}/electra_base_qa_v2_{}_epoch_{}_ckpt".format(args.output_dir, args.version_2_with_negative, epoch + 1)
                else:
                    checkpoint_name = "checkpoints/electra_base_qa_v2_{}_epoch_{}_ckpt".format(args.version_2_with_negative, epoch + 1)
                if is_main_process():
                    model.save_weights(checkpoint_name)

        # Predict after every epoch when requested, otherwise only after the last one.
        if args.do_predict and (args.evaluate_during_training or epoch == args.num_train_epochs - 1):
            if not args.do_train:
                log("***** Loading checkpoint: {} *****".format(args.init_checkpoint))
                model.load_weights(args.init_checkpoint).expect_partial()

            current_feature_id = 0
            all_results = []
            if is_main_process():
                log("***** Running evaluation *****")
                log(" Num Batches = ", total_dev_steps)
                log(" Batch size = ", args.predict_batch_size)

            raw_infer_start = time.time()
            if is_main_process():
                infer_perf_avg = tf.keras.metrics.Mean()
                dev_iterator = tqdm(dev_dataset, total=total_dev_steps, desc="Iteration", mininterval=5,
                                    disable=not is_main_process())
                for input_ids, input_mask, segment_ids, start_positions, end_positions, cls_index, p_mask, is_impossible in dev_iterator:
                    # training=False is needed only if there are layers with different
                    # behavior during training versus inference (e.g. Dropout).
                    iter_start = time.time()

                    if not args.joint_head:
                        batch_start_logits, batch_end_logits = infer_step(model, input_ids,
                                                                          attention_mask=input_mask,
                                                                          token_type_ids=segment_ids,
                                                                          )[:2]
                        # Synchronize with GPU to compute time
                        _ = batch_start_logits.numpy()
                    else:
                        outputs = infer_step(model, input_ids,
                                             attention_mask=input_mask,
                                             token_type_ids=segment_ids,
                                             cls_index=cls_index,
                                             p_mask=p_mask,
                                             )
                        # Synchronize with GPU to compute time
                        _ = outputs[0].numpy()

                    infer_time = (time.time() - iter_start)
                    infer_perf_avg.update_state(1. * EVAL_BATCH_SIZE / infer_time)
                    latency.append(infer_time)

                    # Pair every row of the batch with its feature; the dev dataset is
                    # neither shuffled nor sharded, so order matches dev_features.
                    for row in range(input_ids.shape[0]):
                        dev_feature = dev_features[current_feature_id]
                        current_feature_id += 1
                        unique_id = int(dev_feature.unique_id)
                        if not args.joint_head:
                            start_logits = batch_start_logits[row].numpy().tolist()
                            end_logits = batch_end_logits[row].numpy().tolist()
                            all_results.append(RawResult(unique_id=unique_id,
                                                         start_logits=start_logits,
                                                         end_logits=end_logits))
                        else:
                            output = [tensor[row].numpy().tolist() for tensor in outputs]

                            start_logits = output[0]
                            start_top_index = output[1]
                            end_logits = output[2]
                            end_top_index = output[3]
                            cls_logits = output[4]
                            result = SquadResult(
                                unique_id,
                                start_logits,
                                end_logits,
                                start_top_index=start_top_index,
                                end_top_index=end_top_index,
                                cls_logits=cls_logits,
                            )

                            all_results.append(result)

                # Compute and save predictions
                answers, nbest_answers = get_answers(dev_examples, dev_features, all_results, args)

                output_prediction_file = os.path.join(args.output_dir, "predictions.json")
                output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
                e2e_infer_time = time.time() - raw_infer_start
                with open(output_prediction_file, "w") as f:
                    f.write(json.dumps(answers, indent=4) + "\n")
                with open(output_nbest_file, "w") as f:
                    f.write(json.dumps(nbest_answers, indent=4) + "\n")

                if args.do_eval:
                    if args.version_2_with_negative:
                        dev_file = "dev-v2.0.json"
                    else:
                        dev_file = "dev-v1.1.json"

                    eval_out = subprocess.check_output([sys.executable, args.eval_script,
                                                        args.data_dir + "/" + dev_file, output_prediction_file])
                    log(eval_out.decode('UTF-8'))
                    # Scores are pulled out of the script output by position:
                    # first value after ":" is exact match, second is F1.
                    scores = str(eval_out).strip()
                    exact_match = float(scores.split(":")[1].split(",")[0])
                    if args.version_2_with_negative:
                        f1 = float(scores.split(":")[2].split(",")[0])
                    else:
                        f1 = float(scores.split(":")[2].split("}")[0])

                    log("Epoch: {:03d} Results: {}".format(epoch, eval_out.decode('UTF-8')))
                    log("**EVAL SUMMARY** - Epoch: {:03d}, EM: {:6.3f}, F1: {:6.3f}, Infer_Perf: {:4.0f} seq/s"
                        .format(epoch, exact_match, f1, infer_perf_avg.result()))

                # The two slowest batches are dropped before averaging (presumably
                # warm-up outliers - confirm). With fewer than three samples nothing
                # is dropped, so the summary no longer divides by zero.
                sorted_latency = sorted(latency)
                latency_all = sorted_latency[:-2] or sorted_latency
                if latency_all:
                    log(
                        "**LATENCY SUMMARY** - Epoch: {:03d}, Ave: {:6.3f} ms, 90%: {:6.3f} ms, 95%: {:6.3f} ms, 99%: {:6.3f} ms"
                        .format(epoch,
                                _mean_latency_ms(latency_all),
                                _mean_latency_ms(latency_all, 0.9),
                                _mean_latency_ms(latency_all, 0.95),
                                _mean_latency_ms(latency_all, 0.99),
                                ))
                dllogger.log(step=tuple(),
                             data={"inference_sequences_per_second": float(infer_perf_avg.result().numpy()),
                                   "e2e_inference_time": e2e_infer_time})

    # exact_match is only set when prediction + evaluation ran; without this check
    # "--do_train --do_eval" without "--do_predict" ended in a NameError.
    if is_main_process() and args.do_train and args.do_eval and exact_match is not None:
        log(
            "**RESULTS SUMMARY** - EM: {:6.3f}, F1: {:6.3f}, Train_Time: {:4.0f} s, Train_Perf: {:4.0f} seq/s, Infer_Perf: {:4.0f} seq/s"
            .format(exact_match, f1, total_train_time, epoch_perf_avg.result() * get_world_size(),
                    infer_perf_avg.result()))
        dllogger.log(step=tuple(), data={"exact_match": exact_match, "F1": f1})
# Run the fine-tuning / evaluation driver only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()