Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve schema typing (3) #120521

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Improve schema typing (3)
  • Loading branch information
cdce8p committed Jun 26, 2024
commit c0592d72fd0c3d92f6b042efc0462c661232df30
4 changes: 2 additions & 2 deletions homeassistant/components/input_button/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ class InputButtonStorageCollection(collection.DictStorageCollection):

CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)

async def _process_create_data(self, data: dict) -> vol.Schema:
async def _process_create_data(self, data: dict) -> dict[str, str]:
"""Validate the config is valid."""
return self.CREATE_UPDATE_SCHEMA(data)
return self.CREATE_UPDATE_SCHEMA(data) # type: ignore[no-any-return]

@callback
def _get_suggested_id(self, info: dict) -> str:
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/components/input_text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ class InputTextStorageCollection(collection.DictStorageCollection):

CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_text))

async def _process_create_data(self, data: dict[str, Any]) -> vol.Schema:
async def _process_create_data(self, data: dict[str, Any]) -> dict[str, Any]:
"""Validate the config is valid."""
return self.CREATE_UPDATE_SCHEMA(data)
return self.CREATE_UPDATE_SCHEMA(data) # type: ignore[no-any-return]

@callback
def _get_suggested_id(self, info: dict[str, Any]) -> str:
Expand Down
6 changes: 4 additions & 2 deletions homeassistant/components/light/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def is_on(hass: HomeAssistant, entity_id: str) -> bool:


def preprocess_turn_on_alternatives(
hass: HomeAssistant, params: dict[str, Any]
hass: HomeAssistant, params: dict[str, Any] | dict[str | vol.Optional, Any]
) -> None:
"""Process extra data for turn light on request.

Expand Down Expand Up @@ -406,7 +406,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: # noqa:
# of the light base platform.
hass.async_create_task(profiles.async_initialize(), eager_start=True)

def preprocess_data(data: dict[str, Any]) -> dict[str | vol.Optional, Any]:
def preprocess_data(
data: dict[str | vol.Optional, Any],
) -> dict[str | vol.Optional, Any]:
"""Preprocess the service data."""
base: dict[str | vol.Optional, Any] = {
entity_field: data.pop(entity_field)
Expand Down
8 changes: 5 additions & 3 deletions homeassistant/components/motioneye/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,14 +226,16 @@ async def async_step_init(
if self.show_advanced_options:
# The input URL is not validated as being a URL, to allow for the possibility
# the template input won't be a valid URL until after it's rendered
stream_kwargs = {}
description: dict[str, str] | None = None
if CONF_STREAM_URL_TEMPLATE in self._config_entry.options:
stream_kwargs["description"] = {
description = {
"suggested_value": self._config_entry.options[
CONF_STREAM_URL_TEMPLATE
]
}

schema[vol.Optional(CONF_STREAM_URL_TEMPLATE, **stream_kwargs)] = str
schema[vol.Optional(CONF_STREAM_URL_TEMPLATE, description=description)] = (
str
)

return self.async_show_form(step_id="init", data_schema=vol.Schema(schema))
5 changes: 3 additions & 2 deletions homeassistant/components/zha/device_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,9 @@ async def async_get_action_capabilities(
hass: HomeAssistant, config: ConfigType
) -> dict[str, vol.Schema]:
"""List action capabilities."""

return {"extra_fields": DEVICE_ACTION_SCHEMAS.get(config[CONF_TYPE], {})}
if (fields := DEVICE_ACTION_SCHEMAS.get(config[CONF_TYPE])) is None:
return {}
return {"extra_fields": fields}


async def _execute_service_based_action(
Expand Down
6 changes: 2 additions & 4 deletions homeassistant/components/zwave_js/triggers/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,8 @@ def validate_event_data(obj: dict) -> dict:
except ValidationError as exc:
# Filter out required field errors if keys can be missing, and if there are
# still errors, raise an exception
if errors := [
error for error in exc.errors() if error["type"] != "value_error.missing"
]:
raise vol.MultipleInvalid(errors) from exc
if [error for error in exc.errors() if error["type"] != "value_error.missing"]:
raise vol.MultipleInvalid from exc
return obj


Expand Down
9 changes: 5 additions & 4 deletions homeassistant/data_entry_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
import abc
import asyncio
from collections import defaultdict
from collections.abc import Callable, Container, Iterable, Mapping
from collections.abc import Callable, Container, Hashable, Iterable, Mapping
from contextlib import suppress
import copy
from dataclasses import dataclass
from enum import StrEnum
from functools import partial
import logging
from types import MappingProxyType
from typing import Any, Generic, Required, TypedDict
from typing import Any, Generic, Required, TypedDict, cast

from typing_extensions import TypeVar
import voluptuous as vol
Expand Down Expand Up @@ -120,7 +120,7 @@ class InvalidData(vol.Invalid): # type: ignore[misc]
def __init__(
self,
message: str,
path: list[str | vol.Marker] | None,
path: list[Hashable] | None,
error_message: str | None,
schema_errors: dict[str, Any],
**kwargs: Any,
Expand Down Expand Up @@ -384,6 +384,7 @@ async def _async_configure(
if (
data_schema := cur_step.get("data_schema")
) is not None and user_input is not None:
data_schema = cast(vol.Schema, data_schema)
try:
user_input = data_schema(user_input) # type: ignore[operator]
except vol.Invalid as ex:
Expand Down Expand Up @@ -694,7 +695,7 @@ def add_suggested_values_to_schema(
):
# Copy the marker to not modify the flow schema
new_key = copy.copy(key)
new_key.description = {"suggested_value": suggested_values[key]}
new_key.description = {"suggested_value": suggested_values[key.schema]}
bdraco marked this conversation as resolved.
Show resolved Hide resolved
schema[new_key] = val
return vol.Schema(schema)

Expand Down
12 changes: 6 additions & 6 deletions homeassistant/helpers/config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,7 @@ def removed(

def key_value_schemas(
key: str,
value_schemas: dict[Hashable, VolSchemaType],
value_schemas: dict[Hashable, VolSchemaType | Callable[[Any], dict[str, Any]]],
default_schema: VolSchemaType | None = None,
default_description: str | None = None,
) -> Callable[[Any], dict[Hashable, Any]]:
Expand Down Expand Up @@ -1016,12 +1016,12 @@ def key_value_validator(value: Any) -> dict[Hashable, Any]:
# Validator helpers


def key_dependency(
def key_dependency[_KT: Hashable, _VT](
key: Hashable, dependency: Hashable
) -> Callable[[dict[Hashable, Any]], dict[Hashable, Any]]:
) -> Callable[[dict[_KT, _VT]], dict[_KT, _VT]]:
"""Validate that all dependencies exist for key."""

def validator(value: dict[Hashable, Any]) -> dict[Hashable, Any]:
def validator(value: dict[_KT, _VT]) -> dict[_KT, _VT]:
"""Test dependencies."""
if not isinstance(value, dict):
raise vol.Invalid("key dependencies require a dict")
Expand Down Expand Up @@ -1405,13 +1405,13 @@ def script_action(value: Any) -> dict:
)


def STATE_CONDITION_SCHEMA(value: Any) -> dict:
def STATE_CONDITION_SCHEMA(value: Any) -> dict[str, Any]:
"""Validate a state condition."""
if not isinstance(value, dict):
raise vol.Invalid("Expected a dictionary")

if CONF_ATTRIBUTE in value:
validated: dict = STATE_CONDITION_ATTRIBUTE_SCHEMA(value)
validated: dict[str, Any] = STATE_CONDITION_ATTRIBUTE_SCHEMA(value)
else:
validated = STATE_CONDITION_STATE_SCHEMA(value)

Expand Down
17 changes: 10 additions & 7 deletions homeassistant/helpers/intent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from abc import abstractmethod
import asyncio
from collections.abc import Collection, Coroutine, Iterable
from collections.abc import Callable, Collection, Coroutine, Iterable
import dataclasses
from dataclasses import dataclass, field
from enum import Enum, auto
Expand Down Expand Up @@ -37,6 +37,9 @@

_LOGGER = logging.getLogger(__name__)
type _SlotsType = dict[str, Any]
type _IntentSlotsType = dict[
str | tuple[str, str], VolSchemaType | Callable[[Any], Any]
]

INTENT_TURN_OFF = "HassTurnOff"
INTENT_TURN_ON = "HassTurnOn"
Expand Down Expand Up @@ -808,8 +811,8 @@ def __init__(
self,
intent_type: str,
speech: str | None = None,
required_slots: dict[str | tuple[str, str], VolSchemaType] | None = None,
optional_slots: dict[str | tuple[str, str], VolSchemaType] | None = None,
required_slots: _IntentSlotsType | None = None,
optional_slots: _IntentSlotsType | None = None,
required_domains: set[str] | None = None,
required_features: int | None = None,
required_states: set[str] | None = None,
Expand All @@ -825,7 +828,7 @@ def __init__(
self.description = description
self.platforms = platforms

self.required_slots: dict[tuple[str, str], VolSchemaType] = {}
self.required_slots: _IntentSlotsType = {}
if required_slots:
for key, value_schema in required_slots.items():
if isinstance(key, str):
Expand All @@ -834,7 +837,7 @@ def __init__(

self.required_slots[key] = value_schema

self.optional_slots: dict[tuple[str, str], VolSchemaType] = {}
self.optional_slots: _IntentSlotsType = {}
if optional_slots:
for key, value_schema in optional_slots.items():
if isinstance(key, str):
Expand Down Expand Up @@ -1108,8 +1111,8 @@ def __init__(
domain: str,
service: str,
speech: str | None = None,
required_slots: dict[str | tuple[str, str], VolSchemaType] | None = None,
optional_slots: dict[str | tuple[str, str], VolSchemaType] | None = None,
required_slots: _IntentSlotsType | None = None,
optional_slots: _IntentSlotsType | None = None,
required_domains: set[str] | None = None,
required_features: int | None = None,
required_states: set[str] | None = None,
Expand Down
8 changes: 5 additions & 3 deletions homeassistant/helpers/schema_config_entry_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ async def _async_form_step(
and key.default is not vol.UNDEFINED
and key not in self._options
):
user_input[str(key.schema)] = key.default()
user_input[str(key.schema)] = cast(
Callable[[], Any], key.default
)()

if user_input is not None and form_step.validate_user_input is not None:
# Do extra validation of user input
Expand Down Expand Up @@ -215,7 +217,7 @@ def _update_and_remove_omitted_optional_keys(
)
):
# Key not present, delete keys old value (if present) too
values.pop(key, None)
values.pop(key.schema, None)

async def _show_next_step_or_create_entry(
self, form_step: SchemaFlowFormStep
Expand Down Expand Up @@ -491,7 +493,7 @@ def wrapped_entity_config_entry_title(
def entity_selector_without_own_entities(
handler: SchemaOptionsFlowHandler,
entity_selector_config: selector.EntitySelectorConfig,
) -> vol.Schema:
) -> selector.EntitySelector:
"""Return an entity selector which excludes own entities."""
entity_registry = er.async_get(handler.hass)
entities = er.async_entries_for_config_entry(
Expand Down