This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

An initial VilBERT model for NLVR2 #4423

Merged
12 commits merged on Jul 16, 2020
1 change: 1 addition & 0 deletions .github/workflows/pull_request.yml
@@ -4,6 +4,7 @@ on:
pull_request:
branches:
- master
- vision

jobs:
changelog:
11 changes: 6 additions & 5 deletions allennlp/commands/train.py
@@ -7,7 +7,8 @@
import argparse
import logging
import os
from typing import Any, Dict, List, Optional
from os import PathLike
from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed as dist
@@ -106,8 +107,8 @@ def train_model_from_args(args: argparse.Namespace):


def train_model_from_file(
parameter_filename: str,
serialization_dir: str,
parameter_filename: Union[str, PathLike],
serialization_dir: Union[str, PathLike],
overrides: str = "",
recover: bool = False,
force: bool = False,
@@ -161,7 +162,7 @@ def train_model_from_file(

def train_model(
params: Params,
serialization_dir: str,
serialization_dir: Union[str, PathLike],
recover: bool = False,
force: bool = False,
node_rank: int = 0,
@@ -287,7 +288,7 @@ def train_model(
def _train_worker(
process_rank: int,
params: Params,
serialization_dir: str,
serialization_dir: Union[str, PathLike],
include_package: List[str] = None,
dry_run: bool = False,
node_rank: int = 0,
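
Aside, not part of the diff: with the signatures above broadened from str to Union[str, PathLike], callers can pass pathlib.Path objects directly. A minimal sketch, assuming a hypothetical config file and output directory:

from pathlib import Path

from allennlp.commands.train import train_model_from_file

# Hypothetical paths; plain strings are still accepted as before.
config_file = Path("experiments") / "vilbert_nlvr2.jsonnet"
output_dir = Path("/tmp") / "nlvr2_run"

train_model_from_file(
    parameter_filename=config_file,
    serialization_dir=output_dir,
    force=True,  # overwrite the serialization directory if it already exists
)
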
5 changes: 3 additions & 2 deletions allennlp/common/file_utils.py
@@ -7,6 +7,7 @@
import logging
import tempfile
import json
from os import PathLike
from urllib.parse import urlparse
from pathlib import Path
from typing import Optional, Tuple, Union, IO, Callable, Set, List, Iterator, Iterable
@@ -89,7 +90,7 @@ def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[


def cached_path(
url_or_filename: Union[str, Path],
url_or_filename: Union[str, PathLike],
cache_dir: Union[str, Path] = None,
extract_archive: bool = False,
force_extract: bool = False,
@@ -119,7 +120,7 @@ def cached_path(
if cache_dir is None:
cache_dir = CACHE_DIRECTORY

if isinstance(url_or_filename, Path):
if isinstance(url_or_filename, PathLike):
url_or_filename = str(url_or_filename)

# If we're using the /a/b/foo.zip!c/d/file.txt syntax, handle it here.
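
Aside, not part of the diff: since the check is now isinstance(url_or_filename, PathLike), any object implementing os.fspath is converted to a string before being resolved, while URLs keep going through the caching branch. A small sketch with a throwaway local file:

import pathlib
import tempfile

from allennlp.common.file_utils import cached_path

# Create a throwaway file so there is something on disk to resolve.
tmp_file = pathlib.Path(tempfile.mkdtemp()) / "example_weights.th"
tmp_file.write_bytes(b"\x00")

# A plain str and an os.PathLike now resolve to the same local path.
assert cached_path(str(tmp_file)) == cached_path(tmp_file)
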
7 changes: 6 additions & 1 deletion allennlp/common/logging.py
@@ -1,6 +1,9 @@
import logging
from logging import Filter
import os
from os import PathLike
from typing import Union

import sys


@@ -54,7 +57,9 @@ def filter(self, record):
return record.levelno < logging.ERROR


def prepare_global_logging(serialization_dir: str, rank: int = 0, world_size: int = 1,) -> None:
def prepare_global_logging(
serialization_dir: Union[str, PathLike], rank: int = 0, world_size: int = 1,
) -> None:
root_logger = logging.getLogger()

# create handlers
5 changes: 3 additions & 2 deletions allennlp/common/params.py
@@ -1,4 +1,5 @@
from typing import Any, Dict, List
from os import PathLike
from typing import Any, Dict, List, Union
from collections.abc import MutableMapping
from collections import OrderedDict
import copy
@@ -456,7 +457,7 @@ def _check_is_dict(self, new_history, value):

@classmethod
def from_file(
cls, params_file: str, params_overrides: str = "", ext_vars: dict = None
cls, params_file: Union[str, PathLike], params_overrides: str = "", ext_vars: dict = None
) -> "Params":
"""
Load a `Params` object from a configuration file.
3 changes: 2 additions & 1 deletion allennlp/common/testing/model_test_case.py
@@ -1,5 +1,6 @@
import copy
import json
from os import PathLike
from typing import Any, Dict, Iterable, Set, Union

import torch
@@ -48,7 +49,7 @@ def set_up_model(self, param_file, dataset_file):

def ensure_model_can_train_save_and_load(
self,
param_file: str,
param_file: Union[PathLike, str],
tolerance: float = 1e-4,
cuda_device: int = -1,
gradients_to_ignore: Set[str] = None,
1 change: 1 addition & 0 deletions allennlp/data/dataset_readers/__init__.py
@@ -18,3 +18,4 @@
from allennlp.data.dataset_readers.sharded_dataset_reader import ShardedDatasetReader
from allennlp.data.dataset_readers.babi import BabiReader
from allennlp.data.dataset_readers.text_classification_json import TextClassificationJsonReader
from allennlp.data.dataset_readers.nlvr2_reader import Nlvr2LxmertReader
257 changes: 257 additions & 0 deletions allennlp/data/dataset_readers/nlvr2_reader.py
@@ -0,0 +1,257 @@
from typing import List, Dict, Any
import base64
import csv
import json
import os
import sys
import time

from overrides import overrides
import numpy as np
import spacy

from allennlp.data.token_indexers.token_indexer import TokenIndexer
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import (
ArrayField,
LabelField,
TextField,
MetadataField,
)
from allennlp.data.instance import Instance
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.data.token_indexers import PretrainedTransformerIndexer


FIELDNAMES = [
"img_id",
"img_h",
"img_w",
"objects_id",
"objects_conf",
"attrs_id",
"attrs_conf",
"num_boxes",
"boxes",
"features",
]


def load_obj_tsv(fname, topk=None):
"""Load object features from tsv file.

:param fname: The path to the tsv file.
:param topk: Only load features for top K images (lines) in the tsv file.
Will load all the features if topk is either -1 or None.
:return: A list of image object features where each feature is a dict.
See FIELDNAMES above for the keys in the feature dict.
"""
csv.field_size_limit(sys.maxsize)
data = []
start_time = time.time()
print("Start to load Faster-RCNN detected objects from %s" % fname)
with open(fname) as f:
reader = csv.DictReader(f, FIELDNAMES, delimiter="\t")
for i, item in enumerate(reader):

for key in ["img_h", "img_w", "num_boxes"]:
item[key] = int(item[key])

boxes = item["num_boxes"]
decode_config = [
("objects_id", (boxes,), np.int64),
("objects_conf", (boxes,), np.float32),
("attrs_id", (boxes,), np.int64),
("attrs_conf", (boxes,), np.float32),
("boxes", (boxes, 4), np.float32),
("features", (boxes, -1), np.float32),
]
for key, shape, dtype in decode_config:
item[key] = np.frombuffer(base64.b64decode(item[key]), dtype=dtype)
item[key] = item[key].reshape(shape)
item[key].setflags(write=False)

data.append(item)
if topk is not None and len(data) == topk:
break
elapsed_time = time.time() - start_time
print("Loaded %d images in file %s in %d seconds." % (len(data), fname, elapsed_time))
return data

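# Illustrative aside (not part of the original module): a minimal sketch of
# how one row of such a tsv could be written so that load_obj_tsv above can
# decode it. The array contents, image id, and output path are hypothetical
# stand-ins for real Faster-RCNN outputs; base64, csv, np, and FIELDNAMES are
# the imports/constants defined at the top of this file.
def _example_write_tsv_row(out_path: str = "example.tsv") -> None:
    num_boxes, feat_dim = 36, 2048
    row = {
        "img_id": "some-img-id",
        "img_h": 480,
        "img_w": 640,
        "num_boxes": num_boxes,
        "objects_id": np.zeros(num_boxes, dtype=np.int64),
        "objects_conf": np.ones(num_boxes, dtype=np.float32),
        "attrs_id": np.zeros(num_boxes, dtype=np.int64),
        "attrs_conf": np.ones(num_boxes, dtype=np.float32),
        "boxes": np.zeros((num_boxes, 4), dtype=np.float32),
        "features": np.zeros((num_boxes, feat_dim), dtype=np.float32),
    }
    # Arrays are stored as base64-encoded raw bytes, mirroring the
    # np.frombuffer(base64.b64decode(...)) calls in load_obj_tsv.
    for key in ("objects_id", "objects_conf", "attrs_id", "attrs_conf", "boxes", "features"):
        row[key] = base64.b64encode(row[key].tobytes()).decode("ascii")
    with open(out_path, "w", newline="") as f:
        csv.DictWriter(f, FIELDNAMES, delimiter="\t").writerow(row)
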

@DatasetReader.register("nlvr2_lxmert")
class Nlvr2LxmertReader(DatasetReader):
"""
Parameters
----------
text_path_prefix: ``str``
Path to folder containing text files for each dataset split. These files contain
the sentences and metadata for each task instance.
visual_path_prefix: ``str``
Path to folder containing `tsv` files with the extracted objects and visual
features
topk_images: ``int``, optional (default=-1)
Number of images to load from each split's visual features file. If -1, all
images are loaded
mask_prepositions_verbs: ``bool``, optional (default=False)
Whether to mask prepositions and verbs in each sentence
drop_prepositions_verbs: ``bool``, optional (default=False)
Whether to drop (remove without replacement) prepositions and verbs in each sentence
lazy : ``bool``, optional
Whether to load data lazily. Passed to super class.
"""

def __init__(
self,
text_path_prefix: str,
visual_path_prefix: str,
topk_images: int = -1,
mask_prepositions_verbs: bool = False,
drop_prepositions_verbs: bool = False,
lazy: bool = False,
) -> None:
super().__init__(lazy)
self.text_path_prefix = text_path_prefix
self.visual_path_prefix = visual_path_prefix
self._tokenizer = PretrainedTransformerTokenizer("bert-base-uncased")
self._token_indexers: Dict[str, TokenIndexer] = {
"tokens": PretrainedTransformerIndexer("bert-base-uncased")
}
self.topk_images = topk_images
self.mask_prepositions_verbs = mask_prepositions_verbs
self.drop_prepositions_verbs = drop_prepositions_verbs
self.image_data: Dict[str, Dict[str, Any]] = {}
# Loading Spacy to find prepositions and verbs
self.spacy = spacy.load("en_core_web_sm")

def get_all_grouped_instances(self, split: str):
text_file_path = os.path.join(self.text_path_prefix, split + ".json")
visual_file_path = os.path.join(self.visual_path_prefix, split + ".tsv")
visual_features = load_obj_tsv(visual_file_path, self.topk_images)
for img in visual_features:
self.image_data[img["img_id"]] = img
instances = []
with open(text_file_path) as f:
examples = json.load(f)
examples_dict = {}
for example in examples:
if example["img0"] not in self.image_data or example["img1"] not in self.image_data:
continue
identifier = example["identifier"].split("-")
identifier = identifier[0] + "-" + identifier[1] + "-" + identifier[-1]
if identifier not in examples_dict:
examples_dict[identifier] = {
"sent": example["sent"],
"identifier": identifier,
"image_ids": [],
}
examples_dict[identifier]["image_ids"] += [
example["img0"],
example["img1"],
]
for key in examples_dict:
instances.append(
self.text_to_instance(
examples_dict[key]["sent"],
examples_dict[key]["identifier"],
examples_dict[key]["image_ids"],
None,
None,
only_predictions=True,
)
)
return instances

@overrides
def _read(self, split: str):
text_file_path = os.path.join(self.text_path_prefix, split + ".json")
visual_file_path = os.path.join(self.visual_path_prefix, split + ".tsv")
visual_features = load_obj_tsv(visual_file_path, self.topk_images)
for img in visual_features:
self.image_data[img["img_id"]] = img
with open(text_file_path) as f:
examples = json.load(f)
for example in examples:
if example["img0"] not in self.image_data or example["img1"] not in self.image_data:
continue
yield self.text_to_instance(
example["sent"],
example["identifier"],
[example["img0"], example["img1"]],
example["label"],
)

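# For reference (not part of the original file): each entry of the split's
# json file is expected to provide at least the keys read above, roughly
#
#   {"identifier": "...", "sent": "...", "img0": "...", "img1": "...", "label": 0}
#
# where "img0"/"img1" must match "img_id" values in the corresponding tsv and
# "label" is integer-like (it is passed to int(...) in text_to_instance below).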
@overrides
def text_to_instance(
self, # type: ignore
question: str,
identifier: str,
image_ids: List[str],
denotation: str = None,
only_predictions: bool = False,
) -> Instance:
if self.mask_prepositions_verbs:
doc = self.spacy(question)
prep_verb_starts = [
(token.idx, len(token))
for token in doc
if token.dep_ == "prep" or token.pos_ == "VERB"
]
new_question = ""
prev_end = 0
for (idx, length) in prep_verb_starts:
new_question += question[prev_end:idx] + self._tokenizer.tokenizer.mask_token
prev_end = idx + length
new_question += question[prev_end:]
question = new_question
elif self.drop_prepositions_verbs:
doc = self.spacy(question)
prep_verb_starts = [
(token.idx, len(token))
for token in doc
if token.dep_ == "prep" or token.pos_ == "VERB"
]
new_question = ""
prev_end = 0
for (idx, length) in prep_verb_starts:
new_question += question[prev_end:idx]
prev_end = idx + length
new_question += question[prev_end:]
question = new_question
tokenized_sentence = self._tokenizer.tokenize(question)
sentence_field = TextField(tokenized_sentence, self._token_indexers)

original_identifier = identifier
all_boxes = []
all_features = []
for key in image_ids:
img_info = self.image_data[key]
boxes = img_info["boxes"].copy()
features = img_info["features"].copy()
assert len(boxes) == len(features)

# Normalize the boxes (to 0 ~ 1)
img_h, img_w = img_info["img_h"], img_info["img_w"]
# Dim=1 indices for `boxes`: 0 and 2 are x_min and x_max, respectively;
# 1 and 3 are y_min and y_max, respectively
boxes[..., (0, 2)] /= img_w
boxes[..., (1, 3)] /= img_h
np.testing.assert_array_less(boxes, 1 + 1e-5)
np.testing.assert_array_less(-boxes, 0 + 1e-5)

all_boxes.append(boxes)
all_features.append(features)
features = np.stack(all_features)
boxes = np.stack(all_boxes)
fields = {
"visual_features": ArrayField(features),
"box_coordinates": ArrayField(boxes),
"sentence": MetadataField(question),
"image_id": MetadataField(image_ids),
"identifier": MetadataField(original_identifier),
"sentence_field": sentence_field,
}

if denotation is not None:
fields["denotation"] = LabelField(int(denotation), skip_indexing=True)
return Instance(fields)
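
Aside, not part of the diff: a minimal sketch of how the new reader might be used directly (it is also registered as "nlvr2_lxmert" for use from configuration files). The directory layout and split name are hypothetical; the reader itself appends "<split>.json" and "<split>.tsv" to the two prefixes, and reading goes through the usual DatasetReader.read entry point.

from allennlp.data.dataset_readers import Nlvr2LxmertReader

reader = Nlvr2LxmertReader(
    text_path_prefix="data/nlvr2/text",        # expects data/nlvr2/text/dev.json
    visual_path_prefix="data/nlvr2/features",  # expects data/nlvr2/features/dev.tsv
    topk_images=100,                           # only decode the first 100 rows of the tsv
)
for instance in reader.read("dev"):
    # One sentence paired with the boxes and features of two images, stacked
    # along a leading image dimension by text_to_instance.
    print(instance.fields["box_coordinates"].array.shape)  # e.g. (2, num_boxes, 4)
    break
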
1 change: 1 addition & 0 deletions allennlp/models/__init__.py
@@ -7,3 +7,4 @@
from allennlp.models.archival import archive_model, load_archive, Archive
from allennlp.models.simple_tagger import SimpleTagger
from allennlp.models.basic_classifier import BasicClassifier
from allennlp.models.vilbert import Nlvr2Vilbert