Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model allowlist and blocklists #446

Merged
merged 4 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
implement model allow/blocklists in config manager
  • Loading branch information
dlqqq committed Nov 8, 2023
commit 3bae8cfd4bca67a45e09e34df9fbcda28fd4fdd8
110 changes: 90 additions & 20 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import shutil
import time
from typing import Optional, Union
from typing import List, Optional, Union

from deepmerge import always_merger as Merger
from jsonschema import Draft202012Validator as Validator
Expand All @@ -12,10 +12,8 @@
AnyProvider,
EmProvidersDict,
LmProvidersDict,
ProviderRestrictions,
get_em_provider,
get_lm_provider,
is_provider_allowed,
)
from jupyter_core.paths import jupyter_data_dir
from traitlets import Integer, Unicode
Expand Down Expand Up @@ -57,6 +55,10 @@ class KeyEmptyError(Exception):
pass


class BlockedModelError(Exception):
    """Raised when a model is forbidden by the configured allow/blocklists."""


def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider):
# TODO: handle non-env auth strategies
if not provider.auth_strategy or provider.auth_strategy.type != "env":
Expand Down Expand Up @@ -99,27 +101,34 @@ def __init__(
log: Logger,
lm_providers: LmProvidersDict,
em_providers: EmProvidersDict,
restrictions: ProviderRestrictions,
allowed_providers: Optional[List[str]],
blocked_providers: Optional[List[str]],
allowed_models: Optional[List[str]],
blocked_models: Optional[List[str]],
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.log = log
"""List of LM providers."""

self._lm_providers = lm_providers
"""List of EM providers."""
"""List of LM providers."""
self._em_providers = em_providers
"""Provider restrictions."""
self._restrictions = restrictions
"""List of EM providers."""

self._allowed_providers = allowed_providers
self._blocked_providers = blocked_providers
self._allowed_models = allowed_models
self._blocked_models = blocked_models

self._last_read: Optional[int] = None
"""When the server last read the config file. If the file was not
modified after this time, then we can return the cached
`self._config`."""
self._last_read: Optional[int] = None

self._config: Optional[GlobalConfig] = None
"""In-memory cache of the `GlobalConfig` object parsed from the config
file."""
self._config: Optional[GlobalConfig] = None

self._init_config_schema()
self._init_validator()
Expand All @@ -140,6 +149,26 @@ def _init_config(self):
if os.path.exists(self.config_path):
with open(self.config_path, encoding="utf-8") as f:
config = GlobalConfig(**json.loads(f.read()))
lm_id = config.model_provider_id
em_id = config.embeddings_provider_id

# if the currently selected language or embedding model are
# forbidden, set them to `None` and log a warning.
if lm_id is not None and not self._validate_model(
lm_id, raise_exc=False
):
self.log.warning(
f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
)
config.model_provider_id = None
if em_id is not None and not self._validate_model(
em_id, raise_exc=False
):
self.log.warning(
f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
)
config.embeddings_provider_id = None

# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(config)
Expand Down Expand Up @@ -181,33 +210,74 @@ def _validate_config(self, config: GlobalConfig):
_, lm_provider = get_lm_provider(
config.model_provider_id, self._lm_providers
)
# do not check config for blocked providers
if not is_provider_allowed(config.model_provider_id, self._restrictions):
assert not lm_provider
return

# verify model is declared by some provider
if not lm_provider:
raise ValueError(
f"No language model is associated with '{config.model_provider_id}'."
)

# verify model is not blocked
self._validate_model(config.model_provider_id)

# verify model is authenticated
_validate_provider_authn(config, lm_provider)

# validate embedding model config
if config.embeddings_provider_id:
_, em_provider = get_em_provider(
config.embeddings_provider_id, self._em_providers
)
# do not check config for blocked providers
if not is_provider_allowed(
config.embeddings_provider_id, self._restrictions
):
assert not em_provider
return

# verify model is declared by some provider
if not em_provider:
raise ValueError(
f"No embedding model is associated with '{config.embeddings_provider_id}'."
)

# verify model is not blocked
self._validate_model(config.embeddings_provider_id)

# verify model is authenticated
_validate_provider_authn(config, em_provider)

def _validate_model(self, model_id: str, raise_exc=True) -> bool:
    """
    Validates a model against the provider/model allowlists and blocklists
    this manager was constructed with.

    `model_id` is a global model ID in the format
    `<provider>:<local-model-id>`; everything before the first colon is
    taken as the provider ID. The checks run in the order below, and the
    first one that fails decides the outcome:

    1. provider allowlist
    2. provider blocklist
    3. model allowlist
    4. model blocklist

    A list that is `None` or empty imposes no restriction.

    Returns `True` if the model is allowed. If the model is not allowed,
    raises a `BlockedModelError` when `raise_exc=True` (the default), and
    returns `False` when `raise_exc=False`.

    Raises a `ValueError`, regardless of `raise_exc`, if `model_id` is not
    a string containing a colon.
    """
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would let a malformed ID fail later with an
    # unrelated error.
    if not isinstance(model_id, str) or ":" not in model_id:
        raise ValueError(
            f"Invalid model ID {model_id!r}; expected the format "
            "'<provider>:<local-model-id>'."
        )
    provider_id, _, _ = model_id.partition(":")

    reason = None
    if self._allowed_providers and provider_id not in self._allowed_providers:
        reason = "Model provider not included in the provider allowlist."
    elif self._blocked_providers and provider_id in self._blocked_providers:
        reason = "Model provider included in the provider blocklist."
    elif self._allowed_models and model_id not in self._allowed_models:
        reason = "Model not included in the model allowlist."
    elif self._blocked_models and model_id in self._blocked_models:
        reason = "Model included in the model blocklist."

    if reason is None:
        return True
    if raise_exc:
        raise BlockedModelError(reason)
    return False

def _write_config(self, new_config: GlobalConfig):
"""Updates configuration and persists it to disk. This accepts a
complete `GlobalConfig` object, and should not be called publicly."""
Expand Down
21 changes: 18 additions & 3 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,27 @@ class AiExtension(ExtensionApp):
allowed_models = List(
Unicode(),
default_value=None,
help="Language models to allow, as a list of global model IDs in the format `<provider>:<local-model-id>`. If `None`, all are allowed. Defaults to `None`.",
help="""
Language models to allow, as a list of global model IDs in the format
`<provider>:<local-model-id>`. If `None`, all are allowed. Defaults to
`None`.

Note: Currently, if `allowed_providers` is also set, then this field is
ignored. This is subject to change in a future non-major release. Using
both traits is considered to be undefined behavior at this time.
""",
allow_none=True,
config=True,
)

blocked_models = List(
Unicode(),
default_value=None,
help="Language models to block, as a list of global model IDs in the format `<provider>:<local-model-id>`. If `None`, none are blocked. Defaults to `None`.",
help="""
Language models to block, as a list of global model IDs in the format
`<provider>:<local-model-id>`. If `None`, none are blocked. Defaults to
`None`.
""",
allow_none=True,
config=True,
)
Expand Down Expand Up @@ -98,7 +110,10 @@ def initialize_settings(self):
log=self.log,
lm_providers=self.settings["lm_providers"],
em_providers=self.settings["em_providers"],
restrictions=restrictions,
allowed_providers=self.allowed_providers,
blocked_providers=self.blocked_providers,
allowed_models=self.allowed_models,
blocked_models=self.blocked_models,
)

self.log.info("Registered providers.")
Expand Down
73 changes: 58 additions & 15 deletions packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def common_cm_kwargs(config_path, schema_path):
"em_providers": em_providers,
"config_path": config_path,
"schema_path": schema_path,
"restrictions": {"allowed_providers": None, "blocked_providers": None},
"allowed_providers": None,
"blocked_providers": None,
"allowed_models": None,
"blocked_models": None,
}


Expand All @@ -46,6 +49,26 @@ def cm(common_cm_kwargs):
return ConfigManager(**common_cm_kwargs)


@pytest.fixture
def cm_with_blocklists(common_cm_kwargs):
    """Provide a `ConfigManager` built from the common kwargs, with the
    `ai21` provider and the `cohere:medium` model on the blocklists."""
    blocklist_overrides = {
        "blocked_providers": ["ai21"],
        "blocked_models": ["cohere:medium"],
    }
    return ConfigManager(**{**common_cm_kwargs, **blocklist_overrides})


@pytest.fixture
def cm_with_allowlists(common_cm_kwargs):
    """Provide a `ConfigManager` built from the common kwargs, with the
    `ai21` provider and the `cohere:medium` model on the allowlists."""
    allowlist_overrides = {
        "allowed_providers": ["ai21"],
        "allowed_models": ["cohere:medium"],
    }
    return ConfigManager(**{**common_cm_kwargs, **allowlist_overrides})


@pytest.fixture(autouse=True)
def reset(config_path, schema_path):
"""Fixture that deletes the config and config schema after each test."""
Expand Down Expand Up @@ -98,23 +121,43 @@ def test_snapshot_default_config(cm: ConfigManager, snapshot):
assert config_from_cm == snapshot(exclude=lambda prop, path: prop == "last_read")


def test_init_with_existing_config(
cm: ConfigManager, config_path: str, schema_path: str
):
def test_init_with_existing_config(cm: ConfigManager, common_cm_kwargs):
configure_to_cohere(cm)
del cm

log = logging.getLogger()
lm_providers = get_lm_providers()
em_providers = get_em_providers()
ConfigManager(
log=log,
lm_providers=lm_providers,
em_providers=em_providers,
config_path=config_path,
schema_path=schema_path,
restrictions={"allowed_providers": None, "blocked_providers": None},
)
ConfigManager(**common_cm_kwargs)


def test_init_with_blocklists(cm: ConfigManager, common_cm_kwargs):
    """A `ConfigManager` created over an existing config whose selected
    models are covered by the blocklists should store those blocklists and
    report no selected language or embedding model."""
    # Persist a config through the first manager, then discard the manager
    # so the second one starts from the existing config.
    configure_to_openai(cm)
    del cm

    blocked_providers = ["openai"]  # blocks EM
    blocked_models = ["openai-chat-new:gpt-3.5-turbo"]  # blocks LM
    kwargs = {
        **common_cm_kwargs,
        "blocked_providers": blocked_providers,
        "blocked_models": blocked_models,
    }
    test_cm = ConfigManager(**kwargs)
    assert test_cm._blocked_providers == blocked_providers
    assert test_cm._blocked_models == blocked_models
    # `None` is a singleton: compare by identity, not equality.
    assert test_cm.lm_gid is None
    assert test_cm.em_gid is None


def test_init_with_allowlists(cm: ConfigManager, common_cm_kwargs):
    """A `ConfigManager` created over an existing config whose selected
    models fall outside the provider allowlist should store that allowlist
    and report no selected language or embedding model."""
    # Persist a config through the first manager, then discard the manager
    # so the second one starts from the existing config.
    configure_to_cohere(cm)
    del cm

    allowed_providers = ["openai"]  # blocks both LM & EM

    kwargs = {**common_cm_kwargs, "allowed_providers": allowed_providers}
    test_cm = ConfigManager(**kwargs)
    assert test_cm._allowed_providers == allowed_providers
    # `None` is a singleton: compare by identity, not equality.
    assert test_cm._allowed_models is None
    assert test_cm.lm_gid is None
    assert test_cm.em_gid is None


def test_property_access_on_default_config(cm: ConfigManager):
Expand Down
Loading