Skip to content

Commit

Permalink
Backport PR #446: Model allowlist and blocklists (#452)
Browse files Browse the repository at this point in the history
Co-authored-by: david qiu <david@qiu.dev>
Co-authored-by: Jason Weill <93281816+JasonWeill@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 8, 2023
1 parent 419acd5 commit 0502887
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 49 deletions.
110 changes: 90 additions & 20 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import shutil
import time
from typing import Optional, Union
from typing import List, Optional, Union

from deepmerge import always_merger as Merger
from jsonschema import Draft202012Validator as Validator
Expand All @@ -12,10 +12,8 @@
AnyProvider,
EmProvidersDict,
LmProvidersDict,
ProviderRestrictions,
get_em_provider,
get_lm_provider,
is_provider_allowed,
)
from jupyter_core.paths import jupyter_data_dir
from traitlets import Integer, Unicode
Expand Down Expand Up @@ -57,6 +55,10 @@ class KeyEmptyError(Exception):
pass


class BlockedModelError(Exception):
pass


def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider):
# TODO: handle non-env auth strategies
if not provider.auth_strategy or provider.auth_strategy.type != "env":
Expand Down Expand Up @@ -99,27 +101,34 @@ def __init__(
log: Logger,
lm_providers: LmProvidersDict,
em_providers: EmProvidersDict,
restrictions: ProviderRestrictions,
allowed_providers: Optional[List[str]],
blocked_providers: Optional[List[str]],
allowed_models: Optional[List[str]],
blocked_models: Optional[List[str]],
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.log = log
"""List of LM providers."""

self._lm_providers = lm_providers
"""List of EM providers."""
"""List of LM providers."""
self._em_providers = em_providers
"""Provider restrictions."""
self._restrictions = restrictions
"""List of EM providers."""

self._allowed_providers = allowed_providers
self._blocked_providers = blocked_providers
self._allowed_models = allowed_models
self._blocked_models = blocked_models

self._last_read: Optional[int] = None
"""When the server last read the config file. If the file was not
modified after this time, then we can return the cached
`self._config`."""
self._last_read: Optional[int] = None

self._config: Optional[GlobalConfig] = None
"""In-memory cache of the `GlobalConfig` object parsed from the config
file."""
self._config: Optional[GlobalConfig] = None

self._init_config_schema()
self._init_validator()
Expand All @@ -140,6 +149,26 @@ def _init_config(self):
if os.path.exists(self.config_path):
with open(self.config_path, encoding="utf-8") as f:
config = GlobalConfig(**json.loads(f.read()))
lm_id = config.model_provider_id
em_id = config.embeddings_provider_id

# if the currently selected language or embedding model are
# forbidden, set them to `None` and log a warning.
if lm_id is not None and not self._validate_model(
lm_id, raise_exc=False
):
self.log.warning(
f"Language model {lm_id} is forbidden by current allow/blocklists. Setting to None."
)
config.model_provider_id = None
if em_id is not None and not self._validate_model(
em_id, raise_exc=False
):
self.log.warning(
f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
)
config.embeddings_provider_id = None

# re-write to the file to validate the config and apply any
# updates to the config file immediately
self._write_config(config)
Expand Down Expand Up @@ -181,33 +210,74 @@ def _validate_config(self, config: GlobalConfig):
_, lm_provider = get_lm_provider(
config.model_provider_id, self._lm_providers
)
# do not check config for blocked providers
if not is_provider_allowed(config.model_provider_id, self._restrictions):
assert not lm_provider
return

# verify model is declared by some provider
if not lm_provider:
raise ValueError(
f"No language model is associated with '{config.model_provider_id}'."
)

# verify model is not blocked
self._validate_model(config.model_provider_id)

# verify model is authenticated
_validate_provider_authn(config, lm_provider)

# validate embedding model config
if config.embeddings_provider_id:
_, em_provider = get_em_provider(
config.embeddings_provider_id, self._em_providers
)
# do not check config for blocked providers
if not is_provider_allowed(
config.embeddings_provider_id, self._restrictions
):
assert not em_provider
return

# verify model is declared by some provider
if not em_provider:
raise ValueError(
f"No embedding model is associated with '{config.embeddings_provider_id}'."
)

# verify model is not blocked
self._validate_model(config.embeddings_provider_id)

# verify model is authenticated
_validate_provider_authn(config, em_provider)

def _validate_model(self, model_id: str, raise_exc=True):
"""
Validates a model against the set of allow/blocklists specified by the
traitlets configuration, returning `True` if the model is allowed, and
raising a `BlockedModelError` otherwise. If `raise_exc=False`, this
function returns `False` if the model is not allowed.
"""

assert model_id is not None
components = model_id.split(":", 1)
assert len(components) == 2
provider_id, _ = components

try:
if self._allowed_providers and provider_id not in self._allowed_providers:
raise BlockedModelError(
"Model provider not included in the provider allowlist."
)

if self._blocked_providers and provider_id in self._blocked_providers:
raise BlockedModelError(
"Model provider included in the provider blocklist."
)

if self._allowed_models and model_id not in self._allowed_models:
raise BlockedModelError("Model not included in the model allowlist.")

if self._blocked_models and model_id in self._blocked_models:
raise BlockedModelError("Model included in the model blocklist.")
except BlockedModelError as e:
if raise_exc:
raise e
else:
return False

return True

def _write_config(self, new_config: GlobalConfig):
"""Updates configuration and persists it to disk. This accepts a
complete `GlobalConfig` object, and should not be called publicly."""
Expand Down
13 changes: 12 additions & 1 deletion packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,18 @@ class AiExtension(ExtensionApp):

def initialize_settings(self):
start = time.time()

# Read from allowlist and blocklist
restrictions = {
"allowed_providers": self.allowed_providers,
"blocked_providers": self.blocked_providers,
}
self.settings["allowed_models"] = self.allowed_models
self.settings["blocked_models"] = self.blocked_models
self.log.info(f"Configured provider allowlist: {self.allowed_providers}")
self.log.info(f"Configured provider blocklist: {self.blocked_providers}")
self.log.info(f"Configured model allowlist: {self.allowed_models}")
self.log.info(f"Configured model blocklist: {self.blocked_models}")

self.settings["model_parameters"] = self.model_parameters
self.log.info(f"Configured model parameters: {self.model_parameters}")
Expand All @@ -116,7 +124,10 @@ def initialize_settings(self):
log=self.log,
lm_providers=self.settings["lm_providers"],
em_providers=self.settings["em_providers"],
restrictions=restrictions,
allowed_providers=self.allowed_providers,
blocked_providers=self.blocked_providers,
allowed_models=self.allowed_models,
blocked_models=self.blocked_models,
)

self.log.info("Registered providers.")
Expand Down
70 changes: 57 additions & 13 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid
from asyncio import AbstractEventLoop
from dataclasses import asdict
from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING, Dict, List, Optional

import tornado
from jupyter_ai.chat_handlers import BaseChatHandler
Expand Down Expand Up @@ -237,14 +237,58 @@ def on_close(self):
self.log.debug("Chat clients: %s", self.root_chat_handlers.keys())


class ModelProviderHandler(BaseAPIHandler):
class ProviderHandler(BaseAPIHandler):
"""
Helper base class used for HTTP handlers hosting endpoints relating to
providers. Wrapper around BaseAPIHandler.
"""

@property
def lm_providers(self) -> Dict[str, "BaseProvider"]:
return self.settings["lm_providers"]

@property
def em_providers(self) -> Dict[str, "BaseEmbeddingsProvider"]:
return self.settings["em_providers"]

@property
def allowed_models(self) -> Optional[List[str]]:
return self.settings["allowed_models"]

@property
def blocked_models(self) -> Optional[List[str]]:
return self.settings["blocked_models"]

def _filter_blocked_models(self, providers: List[ListProvidersEntry]):
"""
Satisfy the model-level allow/blocklist by filtering models accordingly.
The provider-level allow/blocklist is already handled in
`AiExtension.initialize_settings()`.
"""
if self.blocked_models is None and self.allowed_models is None:
return providers

def filter_predicate(local_model_id: str):
model_id = provider.id + ":" + local_model_id
if self.blocked_models:
return model_id not in self.blocked_models
else:
return model_id in self.allowed_models

# filter out every model w/ model ID according to allow/blocklist
for provider in providers:
provider.models = list(filter(filter_predicate, provider.models))

# filter out every provider with no models which satisfy the allow/blocklist, then return
return filter((lambda p: len(p.models) > 0), providers)


class ModelProviderHandler(ProviderHandler):
@web.authenticated
def get(self):
providers = []

# Step 1: gather providers
for provider in self.lm_providers.values():
# skip old legacy OpenAI chat provider used only in magics
if provider.id == "openai-chat":
Expand All @@ -267,17 +311,16 @@ def get(self):
)
)

response = ListProvidersResponse(
providers=sorted(providers, key=lambda p: p.name)
)
self.finish(response.json())
# Step 2: sort & filter providers
providers = self._filter_blocked_models(providers)
providers = sorted(providers, key=lambda p: p.name)

# Finally, yield response.
response = ListProvidersResponse(providers=providers)
self.finish(response.json())

class EmbeddingsModelProviderHandler(BaseAPIHandler):
@property
def em_providers(self) -> Dict[str, "BaseEmbeddingsProvider"]:
return self.settings["em_providers"]

class EmbeddingsModelProviderHandler(ProviderHandler):
@web.authenticated
def get(self):
providers = []
Expand All @@ -293,9 +336,10 @@ def get(self):
)
)

response = ListProvidersResponse(
providers=sorted(providers, key=lambda p: p.name)
)
providers = self._filter_blocked_models(providers)
providers = sorted(providers, key=lambda p: p.name)

response = ListProvidersResponse(providers=providers)
self.finish(response.json())


Expand Down
Loading

0 comments on commit 0502887

Please sign in to comment.