Skip to content

Commit

Permalink
[py] Add low-level sync API to use DevTools
Browse files Browse the repository at this point in the history
  • Loading branch information
p0deje committed May 19, 2024
1 parent 0795c03 commit 690e94f
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 0 deletions.
1 change: 1 addition & 0 deletions py/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def event_class(method):
''' A decorator that registers a class as an event class. '''
def decorate(cls):
_event_parsers[method] = cls
cls.event_class = method
return cls
return decorate
Expand Down
1 change: 1 addition & 0 deletions py/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@ trio-websocket==0.9.2
twine==4.0.2
typing_extensions==4.9.0
urllib3[socks]==2.0.7
websocket-client==1.8.0
wsproto==1.2.0
zipp==3.17.0
29 changes: 29 additions & 0 deletions py/selenium/webdriver/remote/webdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@
from .shadowroot import ShadowRoot
from .switch_to import SwitchTo
from .webelement import WebElement
from .websocket_connection import WebSocketConnection

cdp = None
devtools = None


def import_cdp():
Expand Down Expand Up @@ -206,6 +208,7 @@ def __init__(
self._authenticator_id = None
self.start_client()
self.start_session(capabilities)
self._websocket_connection = None

def __repr__(self):
return f'<{type(self).__module__}.{type(self).__name__} (session="{self.session_id}")>'
Expand Down Expand Up @@ -1017,6 +1020,32 @@ def get_log(self, log_type):
"""
return self.execute(Command.GET_LOG, {"type": log_type})["value"]

def start_devtools(self):
global devtools
if self._websocket_connection:
return devtools, self._websocket_connection
else:
global cdp
import_cdp()

if not devtools:
if self.caps.get("se:cdp"):
ws_url = self.caps.get("se:cdp")
version = self.caps.get("se:cdpVersion").split(".")[0]
else:
version, ws_url = self._get_cdp_details()

if not ws_url:
raise WebDriverException("Unable to find url to connect to from capabilities")

devtools = cdp.import_devtools(version)
self._websocket_connection = WebSocketConnection(ws_url)
targets = self._websocket_connection.execute(devtools.target.get_targets())
target_id = targets[0].target_id
session = self._websocket_connection.execute(devtools.target.attach_to_target(target_id, True))
self._websocket_connection.session_id = session
return devtools, self._websocket_connection

@asynccontextmanager
async def bidi_connection(self):
global cdp
Expand Down
124 changes: 124 additions & 0 deletions py/selenium/webdriver/remote/websocket_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
from ssl import CERT_NONE
from threading import Thread
from time import sleep

from websocket import WebSocketApp

logger = logging.getLogger("websocket")

class WebSocketConnection:
_response_wait_timeout = 30
_response_wait_interval = 0.1

_max_log_message_size = 9999

def __init__(self, url):
self.session_id = None
self.url = url

self._id = 0
self._callbacks = {}
self._messages = {}
self._started = False

self._start_ws()
self._wait_until(lambda: self._started)

def close(self):
self._ws_thread.join(timeout=_response_wait_timeout)
self._ws.close()
self._started = False
self._ws = None

def execute(self, command):
self._id += 1
payload = self._serialize_command(command)
payload["id"] = self._id
if self.session_id:
payload["sessionId"] = self.session_id

data = json.dumps(payload)
logger.debug(f"WebSocket -> {data}"[:self._max_log_message_size])
self._ws.send(data)

self._wait_until(lambda: self._id in self._messages)
result = self._messages.pop(self._id)["result"]
return self._deserialize_result(result, command)

def on(self, event, callback):
if event not in self._callbacks:
self._callbacks[event.event_class] = []
self._callbacks[event.event_class].append(lambda params: callback(event.from_json(params)))

def _serialize_command(self, command):
return next(command)

def _deserialize_result(self, result, command):
try:
_ = command.send(result)
raise InternalError("The command's generator function did not exit when expected!")
except StopIteration as exit:
return exit.value

def _start_ws(self):
def on_open(ws):
self._started = True

def on_message(ws, message):
self._process_message(message)

def on_error(ws, error):
logger.debug(f"WebSocket error: {error}")
ws.close()

def run_socket():
if self.url.startswith("wss://"):
self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True)
else:
self._ws.run_forever(suppress_origin=True)

self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error)
self._ws_thread = Thread(target=run_socket)
self._ws_thread.start()

def _process_message(self, message):
message = json.loads(message)
logger.debug(f"WebSocket <- {message}"[:self._max_log_message_size])

if 'id' in message:
self._messages[message['id']] = message

if 'method' in message:
params = message['params']
for callback in self._callbacks.get(message['method'], []):
callback(params)

def _wait_until(self, condition):
timeout = self._response_wait_timeout
interval = self._response_wait_interval

while timeout > 0:
result = condition()
if result:
return result
else:
timeout -= interval
sleep(interval)
39 changes: 39 additions & 0 deletions py/test/selenium/webdriver/common/devtools_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

from selenium.webdriver.common.by import By
from selenium.webdriver.common.log import Log
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait


@pytest.mark.xfail_safari
def test_check_console_messages(driver, pages):
devtools, connection = driver.start_devtools()
console_api_calls = []

connection.execute(devtools.runtime.enable())
connection.on(devtools.runtime.ConsoleAPICalled, console_api_calls.append)
driver.execute_script("console.log('I love cheese')")
driver.execute_script("console.error('I love bread')")
WebDriverWait(driver, 5).until(lambda _: len(console_api_calls) == 2)

assert console_api_calls[0].type_ == "log"
assert console_api_calls[0].args[0].value == "I love cheese"
assert console_api_calls[1].type_ == "error"
assert console_api_calls[1].args[0].value == "I love bread"

0 comments on commit 690e94f

Please sign in to comment.