diff --git a/homeassistant/components/zha/__init__.py b/homeassistant/components/zha/__init__.py index 5c8d9381a2efb..53b56012e5c5a 100644 --- a/homeassistant/components/zha/__init__.py +++ b/homeassistant/components/zha/__init__.py @@ -90,8 +90,8 @@ async def async_setup_entry(hass, config_entry): # pylint: disable=W0611, W0612 import zhaquirks # noqa - zha_gateway = ZHAGateway(hass, config) - await zha_gateway.async_initialize(config_entry) + zha_gateway = ZHAGateway(hass, config, config_entry) + await zha_gateway.async_initialize() device_registry = await \ hass.helpers.device_registry.async_get_registry() diff --git a/homeassistant/components/zha/core/gateway.py b/homeassistant/components/zha/core/gateway.py index 4a38bc647e63c..351ad1c5a67d1 100644 --- a/homeassistant/components/zha/core/gateway.py +++ b/homeassistant/components/zha/core/gateway.py @@ -15,7 +15,7 @@ from homeassistant.components.system_log import LogEntry, _figure_out_source from homeassistant.core import callback from homeassistant.helpers.device_registry import ( - async_get_registry as get_dev_reg) + CONNECTION_ZIGBEE, async_get_registry as get_dev_reg) from homeassistant.helpers.dispatcher import async_dispatcher_send from ..api import async_get_device_info @@ -46,13 +46,14 @@ class ZHAGateway: """Gateway that handles events that happen on the ZHA Zigbee network.""" - def __init__(self, hass, config): + def __init__(self, hass, config, config_entry): """Initialize the gateway.""" self._hass = hass self._config = config self._devices = {} self._device_registry = collections.defaultdict(list) self.zha_storage = None + self.ha_device_registry = None self.application_controller = None self.radio_description = None hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] = self @@ -62,14 +63,16 @@ def __init__(self, hass, config): } self.debug_enabled = False self._log_relay_handler = LogRelayHandler(hass, self) + self._config_entry = config_entry - async def async_initialize(self, config_entry): + async def async_initialize(self): """Initialize controller and connect radio.""" self.zha_storage = await async_get_registry(self._hass) + self.ha_device_registry = await get_dev_reg(self._hass) - usb_path = config_entry.data.get(CONF_USB_PATH) + usb_path = self._config_entry.data.get(CONF_USB_PATH) baudrate = self._config.get(CONF_BAUDRATE, DEFAULT_BAUDRATE) - radio_type = config_entry.data.get(CONF_RADIO_TYPE) + radio_type = self._config_entry.data.get(CONF_RADIO_TYPE) radio_details = RADIO_TYPES[radio_type][RADIO]() radio = radio_details[RADIO] @@ -147,11 +150,10 @@ async def _async_remove_device(self, device, entity_refs): for entity_ref in entity_refs: remove_tasks.append(entity_ref.remove_future) await asyncio.wait(remove_tasks) - ha_device_registry = await get_dev_reg(self._hass) - reg_device = ha_device_registry.async_get_device( + reg_device = self.ha_device_registry.async_get_device( {(DOMAIN, str(device.ieee))}, set()) if reg_device is not None: - ha_device_registry.async_remove_device(reg_device.id) + self.ha_device_registry.async_remove_device(reg_device.id) def device_removed(self, device): """Handle device being removed from the network.""" @@ -241,6 +243,14 @@ def _async_get_or_create_device(self, zigpy_device, is_new_join): if zha_device is None: zha_device = ZHADevice(self._hass, zigpy_device, self) self._devices[zigpy_device.ieee] = zha_device + self.ha_device_registry.async_get_or_create( + config_entry_id=self._config_entry.entry_id, + connections={(CONNECTION_ZIGBEE, str(zha_device.ieee))}, + identifiers={(DOMAIN, str(zha_device.ieee))}, + name=zha_device.name, + manufacturer=zha_device.manufacturer, + model=zha_device.model + ) if not is_new_join: entry = self.zha_storage.async_get_or_create(zha_device) zha_device.async_update_last_seen(entry.last_seen) @@ -322,7 +332,11 @@ async def async_device_initialized(self, device, is_new_join): ) if is_new_join: - device_info = async_get_device_info(self._hass, zha_device) + device_info = async_get_device_info( + self._hass, + zha_device, + self.ha_device_registry + ) async_dispatcher_send( self._hass, ZHA_GW_MSG, diff --git a/tests/components/zha/conftest.py b/tests/components/zha/conftest.py index cd0f615973d6c..763c59cd25501 100644 --- a/tests/components/zha/conftest.py +++ b/tests/components/zha/conftest.py @@ -5,6 +5,8 @@ from homeassistant.components.zha.core.const import ( DOMAIN, DATA_ZHA, COMPONENTS ) +from homeassistant.helpers.device_registry import ( + async_get_registry as get_dev_reg) from homeassistant.components.zha.core.gateway import ZHAGateway from homeassistant.components.zha.core.registries import \ establish_device_mappings @@ -24,7 +26,7 @@ def config_entry_fixture(hass): @pytest.fixture(name='zha_gateway') -async def zha_gateway_fixture(hass): +async def zha_gateway_fixture(hass, config_entry): """Fixture representing a zha gateway. Create a ZHAGateway object that can be used to interact with as if we @@ -37,8 +39,10 @@ async def zha_gateway_fixture(hass): hass.data[DATA_ZHA].get(component, {}) ) zha_storage = await async_get_registry(hass) - gateway = ZHAGateway(hass, {}) + dev_reg = await get_dev_reg(hass) + gateway = ZHAGateway(hass, {}, config_entry) gateway.zha_storage = zha_storage + gateway.ha_device_registry = dev_reg return gateway