Skip to content

Commit

Permalink
remove the module classes scan, only add limit number of classes to u…
Browse files Browse the repository at this point in the history
…se name search.
  • Loading branch information
yhwen committed Aug 9, 2024
1 parent 06242de commit 7bd14ec
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 7 deletions.
16 changes: 9 additions & 7 deletions nvflare/fuel/utils/class_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Dict, List, Optional

from nvflare.security.logging import secure_format_exception
from nvflare.utils.components_utils import create_classes_table_static

DEPRECATED_PACKAGES = ["nvflare.app_common.pt", "nvflare.app_common.homomorphic_encryption"]

Expand Down Expand Up @@ -85,10 +86,10 @@ def __init__(self, base_pkgs: List[str], module_names: List[str], exclude_libs=T
self.exclude_libs = exclude_libs

self._logger = logging.getLogger(self.__class__.__name__)
self._class_table: Dict[str, str] = {}
self._create_classes_table()
self._class_table = create_classes_table_static()

def _create_classes_table(self):
class_table: Dict[str, str] = {}
scan_result_table = {}
for base in self.base_pkgs:
package = importlib.import_module(base)
Expand All @@ -111,20 +112,21 @@ def _create_classes_table(self):
# same class name exists in multiple modules
if name in scan_result_table:
scan_result = scan_result_table[name]
if name in self._class_table:
self._class_table.pop(name)
self._class_table[f"{scan_result.module_name}.{name}"] = module_name
self._class_table[f"{module_name}.{name}"] = module_name
if name in class_table:
class_table.pop(name)
class_table[f"{scan_result.module_name}.{name}"] = module_name
class_table[f"{module_name}.{name}"] = module_name
else:
scan_result = _ModuleScanResult(class_name=name, module_name=module_name)
scan_result_table[name] = scan_result
self._class_table[name] = module_name
class_table[name] = module_name
except (ModuleNotFoundError, RuntimeError) as e:
self._logger.debug(
f"Try to import module {module_name}, but failed: {secure_format_exception(e)}. "
f"Can't use name in config to refer to classes in module: {module_name}."
)
pass
return class_table

def get_module_name(self, class_name) -> Optional[str]:
"""Gets the name of the module that contains this class.
Expand Down
64 changes: 64 additions & 0 deletions nvflare/utils/components_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def create_classes_table_static():
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.abstract.shareable_generator import ShareableGenerator
from nvflare.app_common.aggregators import InTimeAccumulateWeightedAggregator
from nvflare.app_common.aggregators.collect_and_assemble_aggregator import CollectAndAssembleAggregator
from nvflare.app_common.aggregators.dxo_aggregator import DXOAggregator
from nvflare.app_common.ccwf import (
CrossSiteEvalClientController,
CrossSiteEvalServerController,
CyclicClientController,
SwarmClientController,
SwarmServerController,
)
from nvflare.app_common.ccwf.swarm_client_ctl import Gatherer
from nvflare.app_common.response_processors.global_weights_initializer import GlobalWeightsInitializer
from nvflare.app_common.shareablegenerators import FullModelShareableGenerator
from nvflare.app_common.workflows.cross_site_eval import CrossSiteEval
from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval
from nvflare.app_common.workflows.cyclic_ctl import CyclicController
from nvflare.app_common.workflows.global_model_eval import GlobalModelEval
from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
from nvflare.app_common.workflows.scatter_and_gather_scaffold import ScatterAndGatherScaffold

objects = {
ScatterAndGather,
ScatterAndGatherScaffold,
Aggregator,
CollectAndAssembleAggregator,
CrossSiteEval,
CrossSiteEvalClientController,
CrossSiteEvalServerController,
CrossSiteModelEval,
CyclicClientController,
CyclicController,
DXOAggregator,
GlobalModelEval,
GlobalWeightsInitializer,
Gatherer,
ShareableGenerator,
SwarmClientController,
SwarmServerController,
FullModelShareableGenerator,
InTimeAccumulateWeightedAggregator,
}

class_table = {}
for obj in objects:
class_table[obj.__name__] = obj.__module__
return class_table

0 comments on commit 7bd14ec

Please sign in to comment.