Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Constituency js #916

Merged
merged 7 commits into from
Feb 24, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
tidy up, remove sentence lengths from return dict
  • Loading branch information
Mark Neumann committed Feb 23, 2018
commit 46cbd8c70dd6198f19cb9ec544005a2ebcf9ae71
13 changes: 1 addition & 12 deletions allennlp/models/constituency_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,6 @@ def forward(self, # type: ignore
The original spans tensor.
tokens : ``List[List[str]]``, required.
A list of tokens in the sentence for each element in the batch.
sentence_lengths : ``torch.LongTensor``, required.
A tensor of shape (batch_size), representing the lengths of the non-padded
elements of ``sentences``.
num_spans : ``torch.LongTensor``, required.
A tensor of shape (batch_size), representing the lengths of non-padded spans
in ``enumerated_spans``.
Expand All @@ -173,7 +170,6 @@ def forward(self, # type: ignore
# a length 1 sentence in PTB, which do exist. -.-
span_mask = span_mask.unsqueeze(-1)

sentence_lengths = get_lengths_from_binary_sequence_mask(mask)
num_spans = get_lengths_from_binary_sequence_mask(span_mask)

encoded_text = self.encoder(embedded_text_input, mask)
Expand All @@ -187,7 +183,6 @@ def forward(self, # type: ignore
"class_probabilities": class_probabilities,
"spans": spans,
"tokens": [meta["tokens"] for meta in metadata],
"sentence_lengths": sentence_lengths,
"num_spans": num_spans
}
if span_labels is not None:
Expand All @@ -203,7 +198,6 @@ def forward(self, # type: ignore
predicted_trees = self.construct_trees(class_probabilities.cpu().data,
spans.cpu().data,
output_dict["tokens"],
sentence_lengths.cpu().data,
num_spans.data)
self._evalb_score(predicted_trees, batch_gold_trees)

Expand All @@ -224,9 +218,8 @@ def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor
all_spans = output_dict["spans"].cpu().data

all_sentences = output_dict["tokens"]
sentence_lengths = output_dict["sentence_lengths"].data
num_spans = output_dict["num_spans"].data
trees = self.construct_trees(all_predictions, all_spans, all_sentences, sentence_lengths, num_spans)
trees = self.construct_trees(all_predictions, all_spans, all_sentences, num_spans)

batch_size = all_predictions.size(0)
output_dict["spans"] = [all_spans[i, :num_spans[i]] for i in range(batch_size)]
Expand All @@ -239,7 +232,6 @@ def construct_trees(self,
predictions: torch.FloatTensor,
all_spans: torch.LongTensor,
sentences: List[List[str]],
sentence_lengths: torch.LongTensor,
num_spans: torch.LongTensor) -> List[Tree]:
"""
Construct ``nltk.Tree``'s for each batch element by greedily nesting spans.
Expand All @@ -256,9 +248,6 @@ def construct_trees(self,
indices we scored.
sentences : ``List[List[str]]``, required.
A list of tokens in the sentence for each element in the batch.
sentence_lengths : ``torch.LongTensor``, required.
A tensor of shape (batch_size), representing the lengths of the non-padded
elements of ``sentences``.
num_spans : ``torch.LongTensor``, required.
A tensor of shape (batch_size), representing the lengths of non-padded spans
in ``enumerated_spans``.
Expand Down
48 changes: 35 additions & 13 deletions allennlp/service/predictors/constituency_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,31 +53,53 @@ def predict_batch_json(self, inputs: List[JsonDict], cuda_device: int = -1) -> L


def _build_hierplane_tree(self, tree: Tree, index: int, is_root: bool) -> JsonDict:
"""
Recursively builds a JSON dictionary from an NLTK ``Tree`` suitable for
rendering trees using the `Hierplane library<https://allenai.github.io/hierplane/>`.

Parameters
----------
tree : ``Tree``, required.
The tree to convert into Hierplane JSON.
index : int, required.
The character index into the tree, used for creating spans.
is_root : bool
An indicator which allows us to add the outer Hierplane JSON which
is required for rendering.

Returns
-------
A JSON dictionary render-able by Hierplane for the given tree.
"""
children = []
for child in tree:
if isinstance(child, Tree):
# If the child is a tree, it has children,
# as NLTK leaves are just strings.
children.append(self._build_hierplane_tree(child, index, is_root=False))
else:
print(tree)
# We're at a leaf, so add the length of
# the word to the character index.
index += len(child)

word = " ".join(tree.leaves())

span = " ".join(tree.leaves())
label = tree.label()
hierplane_node = {
"word": word,
"nodeType": tree.label(),
"attributes": [tree.label()],
"link": tree.label(),
#"spans": [{"start": index, "end": index + len(word) + 1,}],
}
"word": span,
"nodeType": label,
"attributes": [label],
"link": label,
#"spans": [{"start": index, "end": index + len(word) + 1,}],
}
if children:
hierplane_node["children"] = children
#else:
# hierplane_node["spans"] = [{"start": index, "end": index + len(word) + 1,}]
else:
# Only add spans to leaves. TODO: Ask Sam about this.
hierplane_node["spans"] = [{"start": index, "end": index + len(span) + 1,}]

if is_root:
hierplane_node = {
"text": " ".join(tree.leaves()),
"root": hierplane_node
"text": span,
"root": hierplane_node
}
return hierplane_node
2 changes: 1 addition & 1 deletion tests/models/constituency_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_decode_runs(self):
output_dict = self.model(**training_tensors)
decode_output_dict = self.model.decode(output_dict)
assert set(decode_output_dict.keys()) == {'spans', 'class_probabilities', 'trees',
'tokens', 'sentence_lengths', 'num_spans', 'loss'}
'tokens', 'num_spans', 'loss'}
metrics = self.model.get_metrics(reset=True)
metric_keys = set(metrics.keys())
assert "evalb_precision" in metric_keys
Expand Down
7 changes: 3 additions & 4 deletions tests/service/predictors/constituency_parser_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# pylint: disable=no-self-use,invalid-name
# pylint: disable=no-self-use,invalid-name,protected-access
from unittest import TestCase

from nltk import Tree
Expand Down Expand Up @@ -57,9 +57,8 @@ def test_batch_prediction(self):
def test_build_hierplane_tree(self):
tree = Tree.fromstring("(S (NP (D the) (N dog)) (VP (V chased) (NP (D the) (N cat))))")
archive = load_archive('tests/fixtures/constituency_parser/serialization/model.tar.gz')
predictor = Predictor.from_archive(archive, 'constituency-parser')
predictor = Predictor.from_archive(archive, 'constituency-parser')

hierplane_tree = predictor._build_hierplane_tree(tree, 0, is_root=True)

text = " ".join(tree.leaves())
print(hierplane_tree)
print(hierplane_tree)