Skip to content

Commit

Permalink
improve the class_utils to handle the duplicate class name case.
Browse files Browse the repository at this point in the history
  • Loading branch information
yhwen committed Aug 23, 2024
1 parent f251be4 commit c5d4f7b
Show file tree
Hide file tree
Showing 3 changed files with 587 additions and 25 deletions.
37 changes: 14 additions & 23 deletions nvflare/fuel/utils/class_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import pkgutil
from typing import Dict, List, Optional

from nvflare.apis.fl_component import FLComponent
from nvflare.fuel.common.excepts import ConfigError
from nvflare.fuel.utils.components_utils import create_classes_table_static
from nvflare.security.logging import secure_format_exception

Expand Down Expand Up @@ -61,17 +63,6 @@ def instantiate_class(class_path, init_params):
return instance


class _ModuleScanResult:
"""Data class for ModuleScanner."""

def __init__(self, class_name: str, module_name: str):
self.class_name = class_name
self.module_name = module_name

def __str__(self):
return f"{self.class_name}:{self.module_name}"


class ModuleScanner:
def __init__(self, base_pkgs: List[str], module_names: List[str], exclude_libs=True):
"""Loads specified modules from base packages and then constructs a class to module name mapping.
Expand All @@ -90,7 +81,6 @@ def __init__(self, base_pkgs: List[str], module_names: List[str], exclude_libs=T

def create_classes_table(self):
class_table: Dict[str, str] = {}
scan_result_table = {}
for base in self.base_pkgs:
package = importlib.import_module(base)

Expand All @@ -108,18 +98,12 @@ def create_classes_table(self):
not name.startswith("_")
and inspect.isclass(obj)
and obj.__module__ == module_name
and issubclass(obj, FLComponent)
):
# same class name exists in multiple modules
if name in scan_result_table:
scan_result = scan_result_table[name]
if name in class_table:
class_table.pop(name)
class_table[f"{scan_result.module_name}.{name}"] = module_name
class_table[f"{module_name}.{name}"] = module_name
if name in class_table:
class_table[name].append(module_name)
else:
scan_result = _ModuleScanResult(class_name=name, module_name=module_name)
scan_result_table[name] = scan_result
class_table[name] = module_name
class_table[name] = [module_name]
except (ModuleNotFoundError, RuntimeError, AttributeError) as e:
self._logger.debug(
f"Try to import module {module_name}, but failed: {secure_format_exception(e)}. "
Expand All @@ -137,7 +121,14 @@ def get_module_name(self, class_name) -> Optional[str]:
Returns:
The module name if found.
"""
return self._class_table.get(class_name, None)
if class_name not in self._class_table:
return None

modules = self._class_table.get(class_name, None)
if modules and len(modules) > 1:
raise ConfigError(f"There are multiple modules with the class_name:{class_name}, modules are: {modules}")
else:
return modules[0]


def _retrieve_parameters(class__, parameters):
Expand Down
Loading

0 comments on commit c5d4f7b

Please sign in to comment.