command-line configuration explorer (allenai#1309)
* config explorer

* wip

* wip

* progress

* work

* config tool

* configuration

* configuration command line tool

* add test for configure command

* fix test collisions

* fix docs

* scope suppress FutureWarning to only h5py imports

* fix import ordering

* fix docs

* fix docs (again)
joelgrus authored Jun 1, 2018
1 parent b0d0d94 commit 03d6fad
Showing 17 changed files with 471 additions and 15 deletions.
17 changes: 11 additions & 6 deletions allennlp/commands/__init__.py
@@ -3,6 +3,7 @@
import logging

from allennlp import __version__
from allennlp.commands.configure import Configure
from allennlp.commands.elmo import Elmo
from allennlp.commands.evaluate import Evaluate
from allennlp.commands.fine_tune import FineTune
@@ -33,6 +34,7 @@ def main(prog: str = None,

subcommands = {
# Default commands
"configure": Configure(),
"train": Train(),
"evaluate": Evaluate(),
"predict": Predict(),
@@ -49,11 +51,14 @@

for name, subcommand in subcommands.items():
subparser = subcommand.add_subparser(name, subparsers)
subparser.add_argument('--include-package',
type=str,
action='append',
default=[],
help='additional packages to include')
# configure doesn't need include-package because it imports
# whatever classes it needs.
if name != "configure":
subparser.add_argument('--include-package',
type=str,
action='append',
default=[],
help='additional packages to include')

args = parser.parse_args()

@@ -62,7 +67,7 @@
# so give the user some help.
if 'func' in dir(args):
# Import any additional modules needed (to register custom classes).
for package_name in args.include_package:
for package_name in getattr(args, 'include_package', ()):
import_submodules(package_name)
args.func(args)
else:
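For context: the switch to getattr(args, 'include_package', ()) matters because the configure subparser no longer defines --include-package, so args.include_package would raise AttributeError whenever that subcommand is used. A minimal, self-contained sketch of the pattern (the subcommand names below are illustrative only, not from the repository):

import argparse

parser = argparse.ArgumentParser(prog="demo")
subparsers = parser.add_subparsers(title="Commands")

train_parser = subparsers.add_parser("train")
train_parser.add_argument("--include-package", type=str, action="append", default=[])

configure_parser = subparsers.add_parser("configure")  # deliberately has no --include-package

args = parser.parse_args(["configure"])

# args.include_package would raise AttributeError here, because only the
# "train" subparser defines that argument; getattr with a default does not.
for package_name in getattr(args, "include_package", ()):
    print(package_name)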
61 changes: 61 additions & 0 deletions allennlp/commands/configure.py
@@ -0,0 +1,61 @@
"""
The ``configure`` subcommand generates a stub configuration for
the specified class (or for the top level configuration if no class specified).
.. code-block:: bash
$ allennlp configure --help
usage: allennlp configure [-h] [class]
Generate a configuration stub for a specific class (or for config as a whole if [class] is omitted).
positional arguments:
class
optional arguments:
-h, --help show this help message and exit
"""

import argparse

from allennlp.commands.subcommand import Subcommand
from allennlp.common.configuration import configure, Config, render_config

class Configure(Subcommand):
def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
# pylint: disable=protected-access
description = '''Generate a configuration stub for a specific class (or for config as a whole)'''
subparser = parser.add_parser(
name, description=description, help='Generate configuration stubs.')

subparser.add_argument('cla55', nargs='?', default='', metavar='class')
subparser.set_defaults(func=_configure)

return subparser

def _configure(args: argparse.Namespace) -> None:
cla55 = args.cla55
parts = cla55.split(".")
module = ".".join(parts[:-1])
class_name = parts[-1]

print()

try:
config = configure(cla55)
if isinstance(config, Config):
if cla55:
print(f"configuration stub for {cla55}:\n")
else:
print(f"configuration stub for AllenNLP:\n")
print(render_config(config))
else:
print(f"{class_name} is an abstract base class, choose one of the following subclasses:\n")
for subclass in config:
print("\t", subclass)
except ModuleNotFoundError:
print(f"unable to load module {module}")
except AttributeError:
print(f"class {class_name} does not exist in module {module}")

print()
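For context: the same stubs can be produced without the CLI by calling the new configure and render_config helpers directly. A minimal sketch, not part of this commit; the BasicIterator import path is an assumption, and any importable concrete class with a typed __init__ should behave the same way:

from allennlp.common.configuration import Config, configure, render_config

# Top-level stub (the BASE_CONFIG with dataset_reader, model, trainer, ...).
print(render_config(configure()))

# Stub for a single class.  For an abstract Registrable base class, configure()
# returns a list of fully-qualified subclass names instead of a Config.
config = configure("allennlp.data.iterators.BasicIterator")
if isinstance(config, Config):
    print(render_config(config))
else:
    for subclass in config:
        print("\t", subclass)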
7 changes: 6 additions & 1 deletion allennlp/commands/elmo.py
@@ -50,9 +50,14 @@

import logging
from typing import IO, List, Iterable, Tuple
import warnings

import argparse
import h5py

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
import h5py

import numpy
import torch

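For context: warnings.catch_warnings restores the previous filter state on exit, so the FutureWarning suppression above is scoped to the h5py import only. A standalone illustration of that scoping (the warning messages are made up for the demo):

import warnings

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    warnings.warn("this FutureWarning is suppressed", FutureWarning)

# Outside the block the previous filters are restored, so this one is shown.
warnings.warn("this FutureWarning is visible again", FutureWarning)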
262 changes: 262 additions & 0 deletions allennlp/common/configuration.py
@@ -0,0 +1,262 @@
"""
Tools for programmatically generating config files for AllenNLP models.
"""
# pylint: disable=protected-access

from typing import NamedTuple, Optional, Any, List, TypeVar, Generic, Type, Dict, Union
import inspect
import importlib

import torch

from allennlp.common import Registrable, JsonDict
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.iterators import DataIterator
from allennlp.data.vocabulary import Vocabulary
from allennlp.models.model import Model
from allennlp.modules.seq2seq_encoders import _Seq2SeqWrapper
from allennlp.modules.seq2vec_encoders import _Seq2VecWrapper
from allennlp.training.optimizers import Optimizer as AllenNLPOptimizer
from allennlp.training.trainer import Trainer


def full_name(cla55: type) -> str:
"""
Return the full name (including module) of the given class.
"""
return f"{cla55.__module__}.{cla55.__name__}"


class ConfigItem(NamedTuple):
"""
Each ``ConfigItem`` represents a single entry in a configuration JsonDict.
"""
name: str
annotation: type
default_value: Optional[Any] = None
comment: str = ''

def to_json(self) -> JsonDict:
return {
"annotation": full_name(self.annotation),
"default_value": str(self.default_value),
"comment": self.comment
}


T = TypeVar("T")


class Config(Generic[T]):
"""
A ``Config`` represents an entire subdict in a configuration file.
If it corresponds to a named subclass of a registrable class,
it will also contain a ``type`` item in addition to whatever
items are required by the subclass ``from_params`` method.
"""
def __init__(self, items: List[ConfigItem], typ3: str = None) -> None:
self.items = items
self.typ3 = typ3

def __repr__(self) -> str:
return f"Config({self.items})"

def to_json(self) -> JsonDict:
item_dict: JsonDict = {
item.name: item.to_json()
for item in self.items
}

if self.typ3:
item_dict["type"] = self.typ3

return item_dict


# ``None`` is sometimes the default value for a function parameter,
# so we use a special sentinel to indicate that a parameter has no
# default value.
_NO_DEFAULT = object()

def _get_config_type(cla55: type) -> Optional[str]:
"""
Find the name (if any) that a subclass was registered under.
We do this simply by iterating through the registry until we
find it.
"""
for subclass_dict in Registrable._registry.values():
for name, subclass in subclass_dict.items():
if subclass == cla55:
return name
return None


def _auto_config(cla55: Type[T]) -> Config[T]:
"""
Create the ``Config`` for a class by reflecting on its ``__init__``
method and applying a few hacks.
"""
argspec = inspect.getfullargspec(cla55.__init__)

items: List[ConfigItem] = []

num_args = len(argspec.args)
defaults = list(argspec.defaults or [])
num_default_args = len(defaults)
num_non_default_args = num_args - num_default_args

# Required args all come first, default args at the end.
defaults = [_NO_DEFAULT for _ in range(num_non_default_args)] + defaults

for name, default in zip(argspec.args, defaults):
# Don't include self
if name == "self":
continue
annotation = argspec.annotations.get(name)

# Don't include Model, the only place you'd specify that is top-level.
if annotation == Model:
continue

# Don't include params for an Optimizer
if torch.optim.Optimizer in cla55.__bases__ and name == "params":
continue

# Don't include datasets in the trainer
if cla55 == Trainer and name.endswith("_dataset"):
continue

# Hack in our Optimizer class to the trainer
if cla55 == Trainer and annotation == torch.optim.Optimizer:
annotation = AllenNLPOptimizer

items.append(ConfigItem(name, annotation, default))

return Config(items, typ3=_get_config_type(cla55))


def render_config(config: Config, indent: str = "") -> str:
"""
Pretty-print a config in sort-of-JSON+comments.
"""
# Add four spaces to the indent.
new_indent = indent + " "

return "".join([
# opening brace + newline
"{\n",
# "type": "...", (if present)
f'{new_indent}"type": "{config.typ3}",\n' if config.typ3 else '',
# render each item
"".join(_render(item, new_indent) for item in config.items),
# indent and close the brace
indent,
"}\n"
])

def _render(item: ConfigItem, indent: str = "") -> str:
"""
Render a single config item, with the provided indent
"""
optional = item.default_value != _NO_DEFAULT

# Anything with a from_params method is itself configurable
if hasattr(item.annotation, 'from_params'):
rendered_annotation = f"{item.annotation} (configurable)"
else:
rendered_annotation = str(item.annotation)

rendered_item = "".join([
# rendered_comment,
indent,
"// " if optional else "",
f'"{item.name}": ',
rendered_annotation,
f" (default: {item.default_value} )" if optional else "",
f" // {item.comment}" if item.comment else "",
"\n"
])

return rendered_item

BASE_CONFIG: Config = Config([
ConfigItem(name="dataset_reader",
annotation=DatasetReader,
default_value=_NO_DEFAULT,
comment="specify your dataset reader here"),
ConfigItem(name="validation_dataset_reader",
annotation=DatasetReader,
default_value=None,
comment="same as dataset_reader by default"),
ConfigItem(name="train_data_path",
annotation=str,
default_value=_NO_DEFAULT,
comment="path to the training data"),
ConfigItem(name="validation_data_path",
annotation=str,
default_value=None,
comment="path to the validation data"),
ConfigItem(name="test_data_path",
annotation=str,
default_value=None,
comment="path to the test data (you probably don't want to use this!)"),
ConfigItem(name="evaluate_on_test",
annotation=bool,
default_value=False,
comment="whether to evaluate on the test dataset at the end of training (don't do it!"),
ConfigItem(name="model",
annotation=Model,
default_value=_NO_DEFAULT,
comment="specify your model here"),
ConfigItem(name="iterator",
annotation=DataIterator,
default_value=_NO_DEFAULT,
comment="specify your data iterator here"),
ConfigItem(name="trainer",
annotation=Trainer,
default_value=_NO_DEFAULT,
comment="specify the trainer parameters here"),
ConfigItem(name="datasets_for_vocab_creation",
annotation=List[str],
default_value=None,
comment="if not specified, use all datasets"),
ConfigItem(name="vocabulary",
annotation=Vocabulary,
default_value=None,
comment="vocabulary options"),

])

def _valid_choices(cla55: type) -> Dict[str, str]:
"""
Return a mapping {registered_name -> subclass_name}
for the registered subclasses of `cla55`.
"""
choices: Dict[str, str] = {}

if cla55 not in Registrable._registry:
raise ValueError(f"{cla55} is not a known Registrable class")

for name, subclass in Registrable._registry[cla55].items():
# These wrapper classes need special treatment
if isinstance(subclass, (_Seq2SeqWrapper, _Seq2VecWrapper)):
subclass = subclass._module_class

choices[name] = full_name(subclass)

return choices

def configure(full_path: str = '') -> Union[Config, List[str]]:
if not full_path:
return BASE_CONFIG

parts = full_path.split(".")
class_name = parts[-1]
module_name = ".".join(parts[:-1])
module = importlib.import_module(module_name)
cla55 = getattr(module, class_name)

if Registrable in cla55.__bases__:
return list(_valid_choices(cla55).values())
else:
return _auto_config(cla55)
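For context: the default-alignment step in _auto_config relies on inspect.getfullargspec returning defaults only for the trailing arguments, so padding the front with the _NO_DEFAULT sentinel lines each argument up with its default. A standalone sketch of that idea using a toy class, not a class from the repository:

import inspect

_NO_DEFAULT = object()

class Toy:
    def __init__(self, size: int, dropout: float = 0.1, name: str = "toy") -> None:
        pass

argspec = inspect.getfullargspec(Toy.__init__)
# argspec.args == ['self', 'size', 'dropout', 'name'], argspec.defaults == (0.1, 'toy')
defaults = list(argspec.defaults or [])
num_non_default_args = len(argspec.args) - len(defaults)
defaults = [_NO_DEFAULT] * num_non_default_args + defaults

for name, default in zip(argspec.args, defaults):
    if name == "self":
        continue
    annotation = argspec.annotations.get(name)
    if default is _NO_DEFAULT:
        print(f"{name}: {annotation} (required)")
    else:
        print(f"{name}: {annotation} (default: {default})")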