From c0592d72fd0c3d92f6b042efc0462c661232df30 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 26 Jun 2024 10:58:15 +0200 Subject: [PATCH] Improve schema typing (3) --- .../components/input_button/__init__.py | 4 ++-- homeassistant/components/input_text/__init__.py | 4 ++-- homeassistant/components/light/__init__.py | 6 ++++-- .../components/motioneye/config_flow.py | 8 +++++--- homeassistant/components/zha/device_action.py | 5 +++-- .../components/zwave_js/triggers/event.py | 6 ++---- homeassistant/data_entry_flow.py | 9 +++++---- homeassistant/helpers/config_validation.py | 12 ++++++------ homeassistant/helpers/intent.py | 17 ++++++++++------- .../helpers/schema_config_entry_flow.py | 8 +++++--- 10 files changed, 44 insertions(+), 35 deletions(-) diff --git a/homeassistant/components/input_button/__init__.py b/homeassistant/components/input_button/__init__.py index e70bbacd9338f8..47ec36969c6dd7 100644 --- a/homeassistant/components/input_button/__init__.py +++ b/homeassistant/components/input_button/__init__.py @@ -58,9 +58,9 @@ class InputButtonStorageCollection(collection.DictStorageCollection): CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS) - async def _process_create_data(self, data: dict) -> vol.Schema: + async def _process_create_data(self, data: dict) -> dict[str, str]: """Validate the config is valid.""" - return self.CREATE_UPDATE_SCHEMA(data) + return self.CREATE_UPDATE_SCHEMA(data) # type: ignore[no-any-return] @callback def _get_suggested_id(self, info: dict) -> str: diff --git a/homeassistant/components/input_text/__init__.py b/homeassistant/components/input_text/__init__.py index 3d75ff9f5c2b97..7d8f66636733b6 100644 --- a/homeassistant/components/input_text/__init__.py +++ b/homeassistant/components/input_text/__init__.py @@ -163,9 +163,9 @@ class InputTextStorageCollection(collection.DictStorageCollection): CREATE_UPDATE_SCHEMA = vol.Schema(vol.All(STORAGE_FIELDS, _cv_input_text)) - async def _process_create_data(self, data: dict[str, Any]) -> vol.Schema: + async def _process_create_data(self, data: dict[str, Any]) -> dict[str, Any]: """Validate the config is valid.""" - return self.CREATE_UPDATE_SCHEMA(data) + return self.CREATE_UPDATE_SCHEMA(data) # type: ignore[no-any-return] @callback def _get_suggested_id(self, info: dict[str, Any]) -> str: diff --git a/homeassistant/components/light/__init__.py b/homeassistant/components/light/__init__.py index 16367c35ec5422..b61625edaf2604 100644 --- a/homeassistant/components/light/__init__.py +++ b/homeassistant/components/light/__init__.py @@ -302,7 +302,7 @@ def is_on(hass: HomeAssistant, entity_id: str) -> bool: def preprocess_turn_on_alternatives( - hass: HomeAssistant, params: dict[str, Any] + hass: HomeAssistant, params: dict[str, Any] | dict[str | vol.Optional, Any] ) -> None: """Process extra data for turn light on request. @@ -406,7 +406,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: # noqa: # of the light base platform. hass.async_create_task(profiles.async_initialize(), eager_start=True) - def preprocess_data(data: dict[str, Any]) -> dict[str | vol.Optional, Any]: + def preprocess_data( + data: dict[str | vol.Optional, Any], + ) -> dict[str | vol.Optional, Any]: """Preprocess the service data.""" base: dict[str | vol.Optional, Any] = { entity_field: data.pop(entity_field) diff --git a/homeassistant/components/motioneye/config_flow.py b/homeassistant/components/motioneye/config_flow.py index bbbd2bc7fba1dd..49059b528db571 100644 --- a/homeassistant/components/motioneye/config_flow.py +++ b/homeassistant/components/motioneye/config_flow.py @@ -226,14 +226,16 @@ async def async_step_init( if self.show_advanced_options: # The input URL is not validated as being a URL, to allow for the possibility # the template input won't be a valid URL until after it's rendered - stream_kwargs = {} + description: dict[str, str] | None = None if CONF_STREAM_URL_TEMPLATE in self._config_entry.options: - stream_kwargs["description"] = { + description = { "suggested_value": self._config_entry.options[ CONF_STREAM_URL_TEMPLATE ] } - schema[vol.Optional(CONF_STREAM_URL_TEMPLATE, **stream_kwargs)] = str + schema[vol.Optional(CONF_STREAM_URL_TEMPLATE, description=description)] = ( + str + ) return self.async_show_form(step_id="init", data_schema=vol.Schema(schema)) diff --git a/homeassistant/components/zha/device_action.py b/homeassistant/components/zha/device_action.py index 8f5a03a7fe51f8..a0f16d61f41d3e 100644 --- a/homeassistant/components/zha/device_action.py +++ b/homeassistant/components/zha/device_action.py @@ -167,8 +167,9 @@ async def async_get_action_capabilities( hass: HomeAssistant, config: ConfigType ) -> dict[str, vol.Schema]: """List action capabilities.""" - - return {"extra_fields": DEVICE_ACTION_SCHEMAS.get(config[CONF_TYPE], {})} + if (fields := DEVICE_ACTION_SCHEMAS.get(config[CONF_TYPE])) is None: + return {} + return {"extra_fields": fields} async def _execute_service_based_action( diff --git a/homeassistant/components/zwave_js/triggers/event.py b/homeassistant/components/zwave_js/triggers/event.py index 921cae19b3a5c3..9938d08408ceb7 100644 --- a/homeassistant/components/zwave_js/triggers/event.py +++ b/homeassistant/components/zwave_js/triggers/event.py @@ -80,10 +80,8 @@ def validate_event_data(obj: dict) -> dict: except ValidationError as exc: # Filter out required field errors if keys can be missing, and if there are # still errors, raise an exception - if errors := [ - error for error in exc.errors() if error["type"] != "value_error.missing" - ]: - raise vol.MultipleInvalid(errors) from exc + if [error for error in exc.errors() if error["type"] != "value_error.missing"]: + raise vol.MultipleInvalid from exc return obj diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 155e64d259e4ce..f632e3e4ddecf7 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -5,7 +5,7 @@ import abc import asyncio from collections import defaultdict -from collections.abc import Callable, Container, Iterable, Mapping +from collections.abc import Callable, Container, Hashable, Iterable, Mapping from contextlib import suppress import copy from dataclasses import dataclass @@ -13,7 +13,7 @@ from functools import partial import logging from types import MappingProxyType -from typing import Any, Generic, Required, TypedDict +from typing import Any, Generic, Required, TypedDict, cast from typing_extensions import TypeVar import voluptuous as vol @@ -120,7 +120,7 @@ class InvalidData(vol.Invalid): # type: ignore[misc] def __init__( self, message: str, - path: list[str | vol.Marker] | None, + path: list[Hashable] | None, error_message: str | None, schema_errors: dict[str, Any], **kwargs: Any, @@ -384,6 +384,7 @@ async def _async_configure( if ( data_schema := cur_step.get("data_schema") ) is not None and user_input is not None: + data_schema = cast(vol.Schema, data_schema) try: user_input = data_schema(user_input) # type: ignore[operator] except vol.Invalid as ex: @@ -694,7 +695,7 @@ def add_suggested_values_to_schema( ): # Copy the marker to not modify the flow schema new_key = copy.copy(key) - new_key.description = {"suggested_value": suggested_values[key]} + new_key.description = {"suggested_value": suggested_values[key.schema]} schema[new_key] = val return vol.Schema(schema) diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 558baaeb7799ad..58c76a40c8e1d5 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -981,7 +981,7 @@ def removed( def key_value_schemas( key: str, - value_schemas: dict[Hashable, VolSchemaType], + value_schemas: dict[Hashable, VolSchemaType | Callable[[Any], dict[str, Any]]], default_schema: VolSchemaType | None = None, default_description: str | None = None, ) -> Callable[[Any], dict[Hashable, Any]]: @@ -1016,12 +1016,12 @@ def key_value_validator(value: Any) -> dict[Hashable, Any]: # Validator helpers -def key_dependency( +def key_dependency[_KT: Hashable, _VT]( key: Hashable, dependency: Hashable -) -> Callable[[dict[Hashable, Any]], dict[Hashable, Any]]: +) -> Callable[[dict[_KT, _VT]], dict[_KT, _VT]]: """Validate that all dependencies exist for key.""" - def validator(value: dict[Hashable, Any]) -> dict[Hashable, Any]: + def validator(value: dict[_KT, _VT]) -> dict[_KT, _VT]: """Test dependencies.""" if not isinstance(value, dict): raise vol.Invalid("key dependencies require a dict") @@ -1405,13 +1405,13 @@ def script_action(value: Any) -> dict: ) -def STATE_CONDITION_SCHEMA(value: Any) -> dict: +def STATE_CONDITION_SCHEMA(value: Any) -> dict[str, Any]: """Validate a state condition.""" if not isinstance(value, dict): raise vol.Invalid("Expected a dictionary") if CONF_ATTRIBUTE in value: - validated: dict = STATE_CONDITION_ATTRIBUTE_SCHEMA(value) + validated: dict[str, Any] = STATE_CONDITION_ATTRIBUTE_SCHEMA(value) else: validated = STATE_CONDITION_STATE_SCHEMA(value) diff --git a/homeassistant/helpers/intent.py b/homeassistant/helpers/intent.py index e191bddf10234e..1bf78ae3a29ce3 100644 --- a/homeassistant/helpers/intent.py +++ b/homeassistant/helpers/intent.py @@ -4,7 +4,7 @@ from abc import abstractmethod import asyncio -from collections.abc import Collection, Coroutine, Iterable +from collections.abc import Callable, Collection, Coroutine, Iterable import dataclasses from dataclasses import dataclass, field from enum import Enum, auto @@ -37,6 +37,9 @@ _LOGGER = logging.getLogger(__name__) type _SlotsType = dict[str, Any] +type _IntentSlotsType = dict[ + str | tuple[str, str], VolSchemaType | Callable[[Any], Any] +] INTENT_TURN_OFF = "HassTurnOff" INTENT_TURN_ON = "HassTurnOn" @@ -808,8 +811,8 @@ def __init__( self, intent_type: str, speech: str | None = None, - required_slots: dict[str | tuple[str, str], VolSchemaType] | None = None, - optional_slots: dict[str | tuple[str, str], VolSchemaType] | None = None, + required_slots: _IntentSlotsType | None = None, + optional_slots: _IntentSlotsType | None = None, required_domains: set[str] | None = None, required_features: int | None = None, required_states: set[str] | None = None, @@ -825,7 +828,7 @@ def __init__( self.description = description self.platforms = platforms - self.required_slots: dict[tuple[str, str], VolSchemaType] = {} + self.required_slots: _IntentSlotsType = {} if required_slots: for key, value_schema in required_slots.items(): if isinstance(key, str): @@ -834,7 +837,7 @@ def __init__( self.required_slots[key] = value_schema - self.optional_slots: dict[tuple[str, str], VolSchemaType] = {} + self.optional_slots: _IntentSlotsType = {} if optional_slots: for key, value_schema in optional_slots.items(): if isinstance(key, str): @@ -1108,8 +1111,8 @@ def __init__( domain: str, service: str, speech: str | None = None, - required_slots: dict[str | tuple[str, str], VolSchemaType] | None = None, - optional_slots: dict[str | tuple[str, str], VolSchemaType] | None = None, + required_slots: _IntentSlotsType | None = None, + optional_slots: _IntentSlotsType | None = None, required_domains: set[str] | None = None, required_features: int | None = None, required_states: set[str] | None = None, diff --git a/homeassistant/helpers/schema_config_entry_flow.py b/homeassistant/helpers/schema_config_entry_flow.py index 05e4a852ad9863..7463c9945b24b6 100644 --- a/homeassistant/helpers/schema_config_entry_flow.py +++ b/homeassistant/helpers/schema_config_entry_flow.py @@ -175,7 +175,9 @@ async def _async_form_step( and key.default is not vol.UNDEFINED and key not in self._options ): - user_input[str(key.schema)] = key.default() + user_input[str(key.schema)] = cast( + Callable[[], Any], key.default + )() if user_input is not None and form_step.validate_user_input is not None: # Do extra validation of user input @@ -215,7 +217,7 @@ def _update_and_remove_omitted_optional_keys( ) ): # Key not present, delete keys old value (if present) too - values.pop(key, None) + values.pop(key.schema, None) async def _show_next_step_or_create_entry( self, form_step: SchemaFlowFormStep @@ -491,7 +493,7 @@ def wrapped_entity_config_entry_title( def entity_selector_without_own_entities( handler: SchemaOptionsFlowHandler, entity_selector_config: selector.EntitySelectorConfig, -) -> vol.Schema: +) -> selector.EntitySelector: """Return an entity selector which excludes own entities.""" entity_registry = er.async_get(handler.hass) entities = er.async_entries_for_config_entry(