[Model] Add oagbert metainfo version (#187)
* add oagbert metainfo version

* fix code formatting

* fix model load bugs

* add oagbert test

* remove unnecessary cuda code
Somefive committed Feb 22, 2021
1 parent 28892f7 commit af2344b
Showing 7 changed files with 589 additions and 3 deletions.
25 changes: 25 additions & 0 deletions cogdl/oag/bert_model.py
@@ -388,6 +388,11 @@ def __init__(self, config):
        self.dense_act = LinearActivation(config.hidden_size, config.hidden_size, act=config.hidden_act)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        hidden_states = self.dense_act(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
@@ -400,13 +405,33 @@ def __init__(self, config, bert_model_embedding_weights):
        self.decoder.weight = bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))

    def forward(self, hidden_states, masked_token_indexes):
        hidden_states = self.transform(hidden_states)

        if masked_token_indexes is not None:
            hidden_states = torch.index_select(
                hidden_states.view(-1, hidden_states.shape[-1]), 0,
                masked_token_indexes)

        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states


class BertPreTrainingHeads(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super(BertPreTrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self,
                sequence_output,
                pooled_output,
                masked_token_indexes=None):
        prediction_scores = self.predictions(sequence_output,
                                             masked_token_indexes)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class BertPreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and
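
For reference, a minimal standalone sketch (toy tensors and sizes, not part of this diff) of what the masked_token_indexes branch in BertLMPredictionHead.forward does: it flattens the hidden states and keeps only the masked positions, so the vocabulary decoder runs on a handful of vectors instead of the full batch.

import torch

# Toy shapes: batch of 2 sequences, length 4, hidden size 8.
hidden_states = torch.randn(2, 4, 8)

# Flattened (batch * seq_len) indexes of the masked positions, e.g. positions
# 1 and 3 of sequence 0 and position 2 of sequence 1 (index 1 * 4 + 2 = 6).
masked_token_indexes = torch.tensor([1, 3, 6])

# Mirrors the torch.index_select call above: only 3 of the 8 token vectors
# are passed on to the decoder projection.
selected = torch.index_select(hidden_states.view(-1, hidden_states.shape[-1]), 0, masked_token_indexes)
print(selected.shape)  # torch.Size([3, 8])
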
127 changes: 127 additions & 0 deletions cogdl/oag/dual_position_bert_model.py
@@ -0,0 +1,127 @@
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import logging
from .bert_model import BertPreTrainedModel, BertPreTrainingHeads, BertModel, BertEncoder, BertPooler, BertLayerNorm


logger = logging.getLogger(__name__)


class DualPositionBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super(DualPositionBertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.position_embeddings_second = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids, position_ids, position_ids_second):
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        position_embeddings_second = self.position_embeddings_second(position_ids_second)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + position_embeddings_second + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
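
A minimal standalone sketch (toy sizes; the values of position_ids_second are an assumption, since this diff does not show how they are built) of the embedding sum in DualPositionBertEmbeddings.forward, where two learned position tables are added on top of the word and token-type embeddings:

import torch
from torch import nn

hidden = 8
word_emb = nn.Embedding(30, hidden)        # stands in for word_embeddings
pos_emb = nn.Embedding(16, hidden)         # position_embeddings
pos_emb_second = nn.Embedding(16, hidden)  # position_embeddings_second
type_emb = nn.Embedding(2, hidden)         # token_type_embeddings

input_ids = torch.tensor([[5, 7, 9, 3]])
position_ids = torch.tensor([[0, 1, 2, 3]])
position_ids_second = torch.tensor([[0, 0, 1, 2]])  # assumed second numbering
token_type_ids = torch.zeros_like(input_ids)

embeddings = (word_emb(input_ids) + pos_emb(position_ids)
              + pos_emb_second(position_ids_second) + type_emb(token_type_ids))
print(embeddings.shape)  # torch.Size([1, 4, 8])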


class DualPositionBertModel(BertModel):
    def __init__(self, config):
        super(DualPositionBertModel, self).__init__(config)
        self.embeddings = DualPositionBertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)
        logger.info("Init BERT pretrain model")

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                output_all_encoded_layers=True,
                checkpoint_activations=False,
                position_ids=None,
                position_ids_second=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        embedding_output = self.embeddings(input_ids, token_type_ids, position_ids, position_ids_second)
        encoded_layers = self.encoder(
            embedding_output,
            extended_attention_mask,
            output_all_encoded_layers=output_all_encoded_layers,
            checkpoint_activations=checkpoint_activations)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)

        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output
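
A small standalone sketch (toy mask, not part of this diff) of the additive attention mask built in DualPositionBertModel.forward: padded key positions are pushed to a large negative value so that softmax assigns them effectively zero weight.

import torch

# 2 sequences of length 4; 1 marks real tokens, 0 marks padding.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# Broadcastable over (batch, heads, query_len, key_len), as in the forward above.
extended = attention_mask.unsqueeze(1).unsqueeze(2).float()
extended = (1.0 - extended) * -10000.0
print(extended.shape)  # torch.Size([2, 1, 1, 4])
# extended[0, 0, 0] is [-0., -0., -0., -10000.]: the padded key is masked out.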


class DualPositionBertForPreTrainingPreLN(BertPreTrainedModel):
    """BERT model with pre-training heads and dual position embeddings.
    Params:
        config: a BertConfig class instance with the configuration to build a new model.
    """

    def __init__(self, config):
        super(DualPositionBertForPreTrainingPreLN, self).__init__(config)
        self.bert = DualPositionBertModel(config)
        self.cls = BertPreTrainingHeads(
            config, self.bert.embeddings.word_embeddings.weight)
        self.apply(self.init_bert_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                masked_lm_labels=None,
                position_ids=None,
                position_ids_second=None,
                log=True):
        sequence_output, pooled_output = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            output_all_encoded_layers=False,
            checkpoint_activations=False,
            position_ids=position_ids,
            position_ids_second=position_ids_second)

        if masked_lm_labels is not None:
            # keep only the positions of masked tokens (labels != -1)
            masked_token_indexes = torch.nonzero(
                (masked_lm_labels + 1).view(-1)).view(-1)
            prediction_scores, _ = self.cls(
                sequence_output, pooled_output, masked_token_indexes)
            target = torch.index_select(masked_lm_labels.view(-1), 0,
                                        masked_token_indexes)

            loss_fct = CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), target)
            return masked_lm_loss
        else:
            prediction_scores, _ = self.cls(
                sequence_output, pooled_output)
            return prediction_scores
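
A minimal standalone sketch (hypothetical label values, not part of this diff) of the masked-LM bookkeeping in DualPositionBertForPreTrainingPreLN.forward: labels use -1 for positions that were not masked, so shifting by +1 and taking nonzero() yields exactly the masked positions, and CrossEntropyLoss(ignore_index=-1) acts as a safety net for any stray -1 targets.

import torch

masked_lm_labels = torch.tensor([[-1, -1, 2045, -1, 7592, -1]])

# Unmasked positions become 0 after the +1 shift, so nonzero() keeps only
# the masked ones.
masked_token_indexes = torch.nonzero((masked_lm_labels + 1).view(-1)).view(-1)
print(masked_token_indexes)  # tensor([2, 4])

# The loss targets are the original token ids at those positions.
target = torch.index_select(masked_lm_labels.view(-1), 0, masked_token_indexes)
print(target)  # tensor([2045, 7592])
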
19 changes: 16 additions & 3 deletions cogdl/oag/oagbert.py
@@ -6,10 +6,13 @@
from transformers import BertTokenizer

from .bert_model import BertConfig, BertForPreTrainingPreLN
from .oagbert_metainfo import OAGMetaInfoBertModel

PRETRAINED_MODEL_ARCHIVE_MAP = {
"oagbert-v1": "https://cloud.tsinghua.edu.cn/f/051c9f87d8544698826e/?dl=1",
"oagbert-test": "https://cloud.tsinghua.edu.cn/f/68a8d42802564d43984e/?dl=1",
"oagbert-v2-test": "https://cloud.tsinghua.edu.cn/f/baff5abe84c4483bb690/?dl=1",
"oagbert-v2": "https://cloud.tsinghua.edu.cn/f/f06448fa3c234317bd16/?dl=1"
}


@@ -44,12 +47,22 @@ def _load(model_name_or_path: str, load_weights: bool = False):
        model_name_or_path = f"saved/{model_name_or_path}"
    else:
        raise KeyError("Cannot find the pretrained model {}".format(model_name_or_path))

    try:
        version = open(os.path.join(model_name_or_path, "version")).readline().strip()
    except Exception:
        version = None

    bert_config = BertConfig.from_dict(json.load(open(os.path.join(model_name_or_path, "bert_config.json"))))
    tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
-    bert_model = OAGBertPretrainingModel(bert_config)
    if version == "2":
        bert_model = OAGMetaInfoBertModel(bert_config, tokenizer)
    else:
        bert_model = OAGBertPretrainingModel(bert_config)

-    if load_weights:
-        bert_model.load_state_dict(torch.load(os.path.join(model_name_or_path, "pytorch_model.bin")))
    model_weight_path = os.path.join(model_name_or_path, "pytorch_model.bin")
    if load_weights and os.path.exists(model_weight_path):
        bert_model.load_state_dict(torch.load(model_weight_path))

    return bert_config, tokenizer, bert_model
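
A hedged usage sketch of the updated loader. It assumes _load is imported from cogdl.oag.oagbert and that the checkpoint directory (including the optional version file) is available under saved/; anything not shown in this diff is an assumption.

from cogdl.oag.oagbert import _load

# With the new "oagbert-v2" entry, _load reads the optional "version" file in
# the checkpoint directory, builds OAGMetaInfoBertModel when it contains "2"
# (OAGBertPretrainingModel otherwise), and loads pytorch_model.bin only if it
# actually exists on disk.
bert_config, tokenizer, bert_model = _load("oagbert-v2", load_weights=True)
bert_model.eval()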


0 comments on commit af2344b
