Skip to content

Commit

Permalink
Extract the AttrDict class to vissl.config (facebookresearch#245)
Browse files Browse the repository at this point in the history
Summary:
This code move allows to remove the implicit circular dependency between `vissl.utils.hydra_config` and `vissl.config`:
- `vissl.utils.hydra_config` directly imports `vissl.config`
- `vissl.config` used to implicitly depend on hydra_config by using the AttrDict class

This small refactoring is a first step toward implementing the feature described in issue facebookresearch#241 and also allows type-hinting the config function.

All files have been migrated to the new location of AttrDict, but in case some other files are not under the scope of this PR (internal files), they will continue to work, since hydra_config imports AttrDict.

Pull Request resolved: facebookresearch#245

Reviewed By: prigoyal

Differential Revision: D27188806

Pulled By: QuentinDuval

fbshipit-source-id: 7abcf1028c9f3e028c1e326627d28fcf1550ec1d
  • Loading branch information
QuentinDuval authored and facebook-github-bot committed Mar 25, 2021
1 parent 5e62994 commit 2b03a74
Show file tree
Hide file tree
Showing 44 changed files with 131 additions and 123 deletions.
4 changes: 2 additions & 2 deletions tests/test_mlp.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import unittest

import torch
from vissl.models.heads import LinearEvalMLP, MLP
from vissl.utils.hydra_config import AttrDict
from vissl.config import AttrDict
from vissl.models.heads import MLP, LinearEvalMLP


class TestMLP(unittest.TestCase):
Expand Down
5 changes: 3 additions & 2 deletions tools/cluster_features_and_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@

import numpy as np
from hydra.experimental import compose, initialize_config_module
from vissl.config import AttrDict
from vissl.data import build_dataset
from vissl.hooks import default_hook_generator
from vissl.utils.checkpoint import get_checkpoint_folder
from vissl.utils.distributed_launcher import launch_distributed
from vissl.utils.env import set_env_vars
from vissl.utils.hydra_config import AttrDict, convert_to_attrdict, is_hydra_available
from vissl.utils.hydra_config import convert_to_attrdict, is_hydra_available
from vissl.utils.io import save_file
from vissl.utils.logger import setup_logging, shutdown_logging
from vissl.utils.misc import merge_features, set_seeds, is_faiss_available
from vissl.utils.misc import is_faiss_available, merge_features, set_seeds


def get_data_features_and_images(cfg: AttrDict):
Expand Down
3 changes: 2 additions & 1 deletion tools/perf_measurement/benchmark_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import tqdm
from fvcore.common.timer import Timer
from hydra.experimental import compose, initialize_config_module
from vissl.config import AttrDict
from vissl.data import build_dataset, get_loader
from vissl.utils.hydra_config import AttrDict, convert_to_attrdict, is_hydra_available
from vissl.utils.hydra_config import convert_to_attrdict, is_hydra_available
from vissl.utils.logger import setup_logging


Expand Down
4 changes: 3 additions & 1 deletion vissl/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import logging

from vissl.config.attr_dict import AttrDict


# When to do version bump:
# - version bump is NOT required if the new keys are being added to the defaults.yaml
Expand All @@ -17,7 +19,7 @@
LATEST_CFG_VERSION = 1


def check_cfg_version(cfg):
def check_cfg_version(cfg: AttrDict):
"""
Check the config version
Expand Down
78 changes: 78 additions & 0 deletions vissl/config/attr_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


class AttrDict(dict):
"""
Dictionary subclass whose entries can be accessed like attributes (as well as normally).
Credits: https://aiida.readthedocs.io/projects/aiida-core/en/latest/_modules/aiida/common/extendeddicts.html#AttributeDict # noqa
"""

def __init__(self, dictionary):
"""
Recursively turn the `dict` and all its nested dictionaries into `AttrDict` instance.
"""
super().__init__()

for key, value in dictionary.items():
if isinstance(value, dict):
self[key] = AttrDict(value)
else:
self[key] = value

def __getattr__(self, key):
"""
Read a key as an attribute.
:raises AttributeError: if the attribute does not correspond to an existing key.
"""
if key in self:
return self[key]
else:
raise AttributeError(
f"{self.__class__.__name__} object has no attribute {key}."
)

def __setattr__(self, key, value):
"""
Set a key as an attribute.
"""
self[key] = value

def __delattr__(self, key):
"""
Delete a key as an attribute.
:raises AttributeError: if the attribute does not correspond to an existing key.
"""
if key in self:
del self[key]
else:
raise AttributeError(
f"{self.__class__.__name__} object has no attribute {key}."
)

def __getstate__(self):
"""
Needed for pickling this class.
"""
return self.__dict__.copy()

def __setstate__(self, dictionary):
"""
Needed for pickling this class.
"""
self.__dict__.update(dictionary)

def __deepcopy__(self, memo=None):
"""
Deep copy.
"""
from copy import deepcopy

if memo is None:
memo = {}
retval = deepcopy(dict(self))
return self.__class__(retval)

def __dir__(self):
return self.keys()
2 changes: 1 addition & 1 deletion vissl/data/ssl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from classy_vision.generic.distributed_util import get_world_size
from fvcore.common.file_io import PathManager
from torch.utils.data import Dataset
from vissl.config import AttrDict
from vissl.data import dataset_catalog
from vissl.data.data_helper import balanced_sub_sampling, unbalanced_sub_sampling
from vissl.data.ssl_transforms import get_transform
from vissl.utils.env import get_machine_local_and_dist_rank
from vissl.utils.hydra_config import AttrDict


def _convert_lbl_to_long(lbl):
Expand Down
2 changes: 1 addition & 1 deletion vissl/data/torchvision_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, STL10, SVHN
from vissl.utils.hydra_config import AttrDict
from vissl.config import AttrDict


class TorchvisionDatasetName:
Expand Down
3 changes: 2 additions & 1 deletion vissl/engines/extract_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import logging

from vissl.config import AttrDict
from vissl.trainer import SelfSupervisionTrainer
from vissl.utils.checkpoint import get_checkpoint_folder
from vissl.utils.collect_env import collect_env_info
from vissl.utils.env import get_machine_local_and_dist_rank, set_env_vars
from vissl.utils.hydra_config import AttrDict, print_cfg
from vissl.utils.hydra_config import print_cfg
from vissl.utils.io import save_file
from vissl.utils.logger import setup_logging, shutdown_logging
from vissl.utils.misc import set_seeds, setup_multiprocessing_method
Expand Down
3 changes: 2 additions & 1 deletion vissl/engines/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch
from classy_vision.hooks.classy_hook import ClassyHook
from vissl.config import AttrDict
from vissl.hooks import default_hook_generator
from vissl.trainer import SelfSupervisionTrainer
from vissl.utils.collect_env import collect_env_info
Expand All @@ -14,7 +15,7 @@
print_system_env_info,
set_env_vars,
)
from vissl.utils.hydra_config import AttrDict, print_cfg
from vissl.utils.hydra_config import print_cfg
from vissl.utils.logger import setup_logging, shutdown_logging
from vissl.utils.misc import set_seeds, setup_multiprocessing_method

Expand Down
2 changes: 1 addition & 1 deletion vissl/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List

from classy_vision.hooks.classy_hook import ClassyHook
from vissl.config import AttrDict
from vissl.hooks.deepclusterv2_hooks import ClusterMemoryHook, InitMemoryHook # noqa
from vissl.hooks.grad_clip_hooks import GradClipHook # noqa
from vissl.hooks.log_hooks import ( # noqa
Expand Down Expand Up @@ -31,7 +32,6 @@
SwAVMomentumNormalizePrototypesHook,
)
from vissl.hooks.tensorboard_hook import SSLTensorboardHook # noqa
from vissl.utils.hydra_config import AttrDict
from vissl.utils.tensorboard import get_tensorboard_hook, is_tensorboard_available


Expand Down
2 changes: 1 addition & 1 deletion vissl/losses/bce_logits_multiple_output_single_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from classy_vision.generic.util import is_on_gpu
from classy_vision.losses import ClassyLoss, register_loss
from torch import nn
from vissl.utils.hydra_config import AttrDict
from vissl.config import AttrDict


@register_loss("bce_logits_multiple_output_single_target")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from classy_vision.generic.util import is_on_gpu
from classy_vision.losses import ClassyLoss, register_loss
from torch import Tensor, nn
from vissl.utils.hydra_config import AttrDict
from vissl.config import AttrDict


class SmoothCrossEntropy(torch.nn.modules.CrossEntropyLoss):
Expand Down
2 changes: 1 addition & 1 deletion vissl/losses/deepclusterv2_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from classy_vision.losses import ClassyLoss, register_loss
from torch import nn
from vissl.utils.hydra_config import AttrDict
from vissl.config import AttrDict
from vissl.utils.misc import get_indices_sparse


Expand Down
2 changes: 1 addition & 1 deletion vissl/losses/multicrop_simclr_info_nce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import numpy as np
import torch
from classy_vision.losses import register_loss
from vissl.config import AttrDict
from vissl.losses.simclr_info_nce_loss import SimclrInfoNCECriterion, SimclrInfoNCELoss
from vissl.utils.hydra_config import AttrDict


@register_loss("multicrop_simclr_info_nce_loss")
Expand Down
2 changes: 1 addition & 1 deletion vissl/losses/nce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from classy_vision.generic.util import is_pos_int
from classy_vision.losses import ClassyLoss, register_loss
from torch import nn
from vissl.utils.hydra_config import AttrDict
from vissl.config import AttrDict


@register_loss("nce_loss_with_memory")
Expand Down
2 changes: 1 addition & 1 deletion vissl/losses/simclr_info_nce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from classy_vision.generic.distributed_util import get_cuda_device_index, get_rank
from classy_vision.losses import ClassyLoss, register_loss
from torch import nn
from vissl.config import AttrDict
from vissl.utils.distributed_gradients import gather_from_all
from vissl.utils.hydra_config import AttrDict


@register_loss("simclr_info_nce_loss")
Expand Down
2 changes: 1 addition & 1 deletion vissl/losses/swav_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from classy_vision.losses import ClassyLoss, register_loss
from fvcore.common.file_io import PathManager
from torch import nn
from vissl.utils.hydra_config import AttrDict
from vissl.config import AttrDict


@register_loss("swav_loss")
Expand Down
2 changes: 1 addition & 1 deletion vissl/losses/swav_momentum_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from classy_vision.losses import ClassyLoss, register_loss
from torch import nn
from vissl.utils.hydra_config import AttrDict
from vissl.config import AttrDict


@register_loss("swav_momentum_loss")
Expand Down
2 changes: 1 addition & 1 deletion vissl/meters/accuracy_list_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from classy_vision.generic.util import is_pos_int
from classy_vision.meters import AccuracyMeter, ClassyMeter, register_meter
from vissl.utils.hydra_config import AttrDict
from vissl.config import AttrDict


@register_meter("accuracy_list_meter")
Expand Down
2 changes: 1 addition & 1 deletion vissl/meters/mean_ap_list_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import torch
from classy_vision.generic.util import is_pos_int
from classy_vision.meters import ClassyMeter, register_meter
from vissl.config import AttrDict
from vissl.meters.mean_ap_meter import MeanAPMeter
from vissl.utils.hydra_config import AttrDict


@register_meter("mean_ap_list_meter")
Expand Down
2 changes: 1 addition & 1 deletion vissl/meters/mean_ap_meter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import torch
from classy_vision.generic.distributed_util import all_reduce_sum, gather_from_all
from classy_vision.meters import ClassyMeter, register_meter
from vissl.config import AttrDict
from vissl.utils.env import get_machine_local_and_dist_rank
from vissl.utils.hydra_config import AttrDict
from vissl.utils.svm_utils.evaluate import get_precision_recall


Expand Down
2 changes: 1 addition & 1 deletion vissl/models/heads/linear_eval_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

import torch
import torch.nn as nn
from vissl.config import AttrDict
from vissl.models.heads import register_model_head
from vissl.models.heads.mlp import MLP
from vissl.utils.hydra_config import AttrDict


@register_model_head("eval_mlp")
Expand Down
2 changes: 1 addition & 1 deletion vissl/models/heads/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import torch
import torch.nn as nn
from vissl.config import AttrDict
from vissl.models.heads import register_model_head
from vissl.utils.hydra_config import AttrDict


@register_model_head("mlp")
Expand Down
2 changes: 1 addition & 1 deletion vissl/models/heads/siamese_concat_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import torch
import torch.nn as nn
from vissl.config import AttrDict
from vissl.models.heads import register_model_head
from vissl.utils.hydra_config import AttrDict


@register_model_head("siamese_concat_view")
Expand Down
2 changes: 1 addition & 1 deletion vissl/models/heads/swav_prototypes_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import torch
import torch.nn as nn
from vissl.config import AttrDict
from vissl.models.heads import register_model_head
from vissl.utils.hydra_config import AttrDict


@register_model_head("swav_head")
Expand Down
2 changes: 1 addition & 1 deletion vissl/models/heads/vision_transformer_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from collections import OrderedDict

import torch.nn as nn
from vissl.config import AttrDict
from vissl.models.heads import register_model_head
from vissl.models.model_helpers import lecun_normal_init, trunc_normal_
from vissl.utils.hydra_config import AttrDict


@register_model_head("vision_transformer_head")
Expand Down
2 changes: 1 addition & 1 deletion vissl/models/trunks/alexnet_bvlc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import torch.nn as nn
from vissl.config import AttrDict
from vissl.models.model_helpers import Flatten, get_trunk_forward_outputs_module_list
from vissl.models.trunks import register_model_trunk
from vissl.utils.hydra_config import AttrDict


@register_model_trunk("alexnet_bvlc")
Expand Down
2 changes: 1 addition & 1 deletion vissl/models/trunks/alexnet_colorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import torch
import torch.nn as nn
from vissl.config import AttrDict
from vissl.models.model_helpers import Flatten, get_trunk_forward_outputs_module_list
from vissl.models.trunks import register_model_trunk
from vissl.utils.hydra_config import AttrDict


@register_model_trunk("alexnet_colorization")
Expand Down
2 changes: 1 addition & 1 deletion vissl/models/trunks/alexnet_deepcluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import torch
import torch.nn as nn
from vissl.config import AttrDict
from vissl.models.model_helpers import Flatten, get_trunk_forward_outputs_module_list
from vissl.models.trunks import register_model_trunk
from vissl.utils.hydra_config import AttrDict


@register_model_trunk("alexnet_deepcluster")
Expand Down
Loading

0 comments on commit 2b03a74

Please sign in to comment.