diff --git a/griptape/config/__init__.py b/griptape/config/__init__.py index b242d80a7..4a87c2d01 100644 --- a/griptape/config/__init__.py +++ b/griptape/config/__init__.py @@ -9,6 +9,7 @@ from .anthropic_driver_config import AnthropicDriverConfig from .google_driver_config import GoogleDriverConfig from .cohere_driver_config import CohereDriverConfig +from .logging_config import LoggingConfig from .config import config @@ -22,5 +23,6 @@ "AnthropicDriverConfig", "GoogleDriverConfig", "CohereDriverConfig", + "LoggingConfig", "config", ] diff --git a/griptape/config/base_config.py b/griptape/config/base_config.py index 9209aa4a4..ef62a4e9b 100644 --- a/griptape/config/base_config.py +++ b/griptape/config/base_config.py @@ -1,14 +1,22 @@ +from __future__ import annotations + from abc import ABC +from typing import TYPE_CHECKING, Optional -from attrs import define +from attrs import define, field from griptape.mixins.serializable_mixin import SerializableMixin -from .base_driver_config import BaseDriverConfig -from .logging_config import LoggingConfig +if TYPE_CHECKING: + from .base_driver_config import BaseDriverConfig + from .logging_config import LoggingConfig @define(kw_only=True) class BaseConfig(SerializableMixin, ABC): - drivers: BaseDriverConfig - logging: LoggingConfig + _logging: Optional[LoggingConfig] = field(alias="logging") + _drivers: Optional[BaseDriverConfig] = field(alias="drivers") + + def reset(self) -> None: + self._logging = None + self._drivers = None diff --git a/griptape/config/config.py b/griptape/config/config.py index 97d501abb..7b70df409 100644 --- a/griptape/config/config.py +++ b/griptape/config/config.py @@ -1,15 +1,42 @@ -from attrs import Factory, define, field +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from attrs import define, field from .base_config import BaseConfig -from .base_driver_config import BaseDriverConfig from .logging_config import LoggingConfig from .openai_driver_config import OpenAiDriverConfig +if TYPE_CHECKING: + from .base_driver_config import BaseDriverConfig -@define + +@define(kw_only=True) class _Config(BaseConfig): - drivers: BaseDriverConfig = field(default=Factory(lambda: OpenAiDriverConfig()), kw_only=True) - logging: LoggingConfig = field(default=Factory(lambda: LoggingConfig()), kw_only=True) + _logging: Optional[LoggingConfig] = field(default=None, alias="logging") + _drivers: Optional[BaseDriverConfig] = field(default=None, alias="drivers") + + @property + def drivers(self) -> BaseDriverConfig: + """Lazily instantiates the drivers configuration to avoid client errors like missing API key.""" + if self._drivers is None: + self._drivers = OpenAiDriverConfig() + return self._drivers + + @drivers.setter + def drivers(self, drivers: BaseDriverConfig) -> None: + self._drivers = drivers + + @property + def logging(self) -> LoggingConfig: + if self._logging is None: + self._logging = LoggingConfig() + return self._logging + + @logging.setter + def logging(self, logging: LoggingConfig) -> None: + self._logging = logging config = _Config() diff --git a/tests/unit/config/test_config.py b/tests/unit/config/test_config.py new file mode 100644 index 000000000..ecd02e1a1 --- /dev/null +++ b/tests/unit/config/test_config.py @@ -0,0 +1,25 @@ +import pytest + +from griptape.config.openai_driver_config import OpenAiDriverConfig + + +class TestConfig: + @pytest.mark.skip_mock_config() + def test_init(self): + from griptape.config import LoggingConfig, config + + assert isinstance(config.drivers, OpenAiDriverConfig) + assert isinstance(config.logging, LoggingConfig) + + @pytest.mark.skip_mock_config() + def test_lazy_init(self): + from griptape.config import config + + assert config._drivers is None + assert config._logging is None + + assert config.drivers is not None + assert config.logging is not None + + assert config._drivers is not None + assert config._logging is not None diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 8a37f6d28..db881bc20 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,12 +1,12 @@ import pytest -from griptape.config import config -from griptape.events import event_bus from tests.mocks.mock_driver_config import MockDriverConfig @pytest.fixture(autouse=True) def mock_event_bus(): + from griptape.events import event_bus + event_bus.clear_event_listeners() yield event_bus @@ -15,7 +15,19 @@ def mock_event_bus(): @pytest.fixture(autouse=True) -def mock_config(): +def mock_config(request): + from griptape.config import config + + config.reset() + + # Some tests we don't want to use the autouse fixture's MockDriverConfig + if "skip_mock_config" in request.keywords: + yield + + return + config.drivers = MockDriverConfig() - return config + yield config + + config.reset()