Skip to content

Commit

Permalink
[py] Get devtools API's working properly
Browse files Browse the repository at this point in the history
This does dynamic importing of the generated APIs
  • Loading branch information
AutomatedTester committed Oct 2, 2020
1 parent 8185e9c commit e2987e2
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 87 deletions.
10 changes: 5 additions & 5 deletions py/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
#
# CDP domain: {{}}{{}}
from __future__ import annotations
from ..devtools.util import event_class, T_JSON_DICT
from ..util import event_class, T_JSON_DICT
from dataclasses import dataclass
import enum
import typing
Expand Down Expand Up @@ -245,7 +245,7 @@ def generate_decl(self):
code += ' = None'
return code

def generate_to_json(self, dict_, use_self = True):
def generate_to_json(self, dict_, use_self=True):
''' Generate the code that exports this property to the specified JSON
dict. '''
self_ref = 'self.' if use_self else ''
Expand Down Expand Up @@ -830,7 +830,7 @@ def generate_imports(self):
continue
if domain != self.domain:
dependencies.add(snake_case(domain))
code = '\n'.join(f'from ..devtools import {d}' for d in sorted(dependencies))
code = '\n'.join(f'from .. import {d}' for d in sorted(dependencies))

return code

Expand Down Expand Up @@ -926,9 +926,9 @@ def generate_init(init_path, domains):
'''
with open(init_path, "w") as init_file:
init_file.write(INIT_HEADER)
init_file.write('from ..devtools import util\n\n')
for domain in domains:
init_file.write('from ..devtools import {}\n'.format(domain.module))
init_file.write('from . import {}\n'.format(domain.module))
init_file.write('from .. import util\n\n')


def generate_docs(docs_path, domains):
Expand Down
22 changes: 20 additions & 2 deletions py/selenium/webdriver/remote/webdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,15 +1425,33 @@ def get_log(self, log_type):
async def get_devtools_connection(self):
assert sys.version_info >= (3, 6)

from selenium.webdriver.support import cdp
ws_url = None
if self.capabilities.get("se:options"):
ws_url = self.capabilities.get("se:options").get("cdp")
else:
ws_url = self.capabilities.get(self.vendor_prefix["debuggerAddress"])
version, ws_url = self._get_cdp_details()

if ws_url is None:
raise WebDriverException("Unable to find url to connect to from capabilities")

from selenium.webdriver.support import cdp
cdp.import_devtools(version)
async with cdp.open_cdp(ws_url) as conn:
yield conn

def _get_cdp_details(self):
import json
import urllib3

http = urllib3.PoolManager()
debugger_address = self.capabilities.get("{0}:chromeOptions".format(self.vendor_prefix)).get("debuggerAddress")
res = http.request('GET', "http://{0}/json/version".format(debugger_address))
data = json.loads(res.data)

browser_version = data.get("Browser")
websocket_url = data.get("webSocketDebuggerUrl")

import re
version = re.search(r".*/(\d+)\.", browser_version).group(1)

return version, websocket_url
183 changes: 103 additions & 80 deletions py/selenium/webdriver/support/cdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,34 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from trio_websocket import (
ConnectionClosed as WsConnectionClosed,
connect_websocket_url,
open_websocket_url
)
import trio
from collections import defaultdict
from contextlib import (contextmanager, asynccontextmanager)
from dataclasses import dataclass
import contextvars
import importlib
import itertools
import json
import logging
import sys
import typing

import trio
from trio_websocket import (
ConnectionClosed as WsConnectionClosed,
connect_websocket_url,
open_websocket_url
)

logger = logging.getLogger('trio_cdp')
T = typing.TypeVar('T')
MAX_WS_MESSAGE_SIZE = 2**24

global devtools


def import_devtools(version):
devtools = importlib.import_module("selenium.webdriver.common.devtools.v{}".format(version))


_connection_context: contextvars.ContextVar = contextvars.ContextVar('connection_context')
_session_context: contextvars.ContextVar = contextvars.ContextVar('session_context')
Expand Down Expand Up @@ -98,7 +112,6 @@ def set_global_session(session):
_session_context = contextvars.ContextVar('_session_context', default=session)



class BrowserError(Exception):
''' This exception is raised when the browser's response to a command
indicates that an error occurred. '''
Expand Down Expand Up @@ -134,7 +147,15 @@ class InternalError(Exception):
the integration with PyCDP. '''


class CDPBase:
@dataclass
class CmEventProxy:
''' A proxy object returned by :meth:`CdpBase.wait_for()``. After the
context manager executes, this proxy object will have a value set that
contains the returned event. '''
value: typing.Any = None


class CdpBase:

def __init__(self, ws, session_id, target_id):
self.ws = ws
Expand Down Expand Up @@ -169,7 +190,6 @@ async def execute(self, cmd: typing.Generator[dict, T, typing.Any]) -> T:
raise response
return response


def listen(self, *event_types, buffer_size=10):
''' Return an async iterator that iterates over events matching the
indicated types. '''
Expand Down Expand Up @@ -216,7 +236,7 @@ def _handle_cmd_response(self, data):
cmd, event = self.inflight_cmd.pop(cmd_id)
except KeyError:
logger.warning('Got a message with a command ID that does'
' not exist: {}'.format(data))
' not exist: {}'.format(data))
return
if 'error' in data:
# If the server reported an error, convert it to an exception and do
Expand All @@ -228,7 +248,7 @@ def _handle_cmd_response(self, data):
try:
response = cmd.send(data['result'])
raise InternalError("The command's generator function "
"did not exit when expected!")
"did not exit when expected!")
except StopIteration as exit:
return_ = exit.value
self.inflight_result[cmd_id] = return_
Expand All @@ -239,21 +259,83 @@ def _handle_event(self, data):
Handle an event.
:param dict data: event as a JSON dictionary
'''
event = cdp.util.parse_json_event(data)
event = devtools.util.parse_json_event(data)
logger.debug('Received event: %s', event)
to_remove = set()
for sender in self.channels[type(event)]:
try:
sender.send_nowait(event)
except trio.WouldBlock:
logger.error('Unable to send event "%r" due to full channel %s',
event, sender)
event, sender)
except trio.BrokenResourceError:
to_remove.add(sender)
if to_remove:
self.channels[type(event)] -= to_remove


class CdpSession(CdpBase):
'''
Contains the state for a CDP session.
Generally you should not instantiate this object yourself; you should call
:meth:`CdpConnection.open_session`.
'''

def __init__(self, ws, session_id, target_id):
'''
Constructor.
:param trio_websocket.WebSocketConnection ws:
:param devtools.target.SessionID session_id:
:param devtools.target.TargetID target_id:
'''
super().__init__(ws, session_id, target_id)

self._dom_enable_count = 0
self._dom_enable_lock = trio.Lock()
self._page_enable_count = 0
self._page_enable_lock = trio.Lock()

@asynccontextmanager
async def dom_enable(self):
'''
A context manager that executes ``dom.enable()`` when it enters and then
calls ``dom.disable()``.
This keeps track of concurrent callers and only disables DOM events when
all callers have exited.
'''
async with self._dom_enable_lock:
self._dom_enable_count += 1
if self._dom_enable_count == 1:
await self.execute(devtools.dom.enable())

yield

async with self._dom_enable_lock:
self._dom_enable_count -= 1
if self._dom_enable_count == 0:
await self.execute(devtools.dom.disable())

@asynccontextmanager
async def page_enable(self):
'''
A context manager that executes ``page.enable()`` when it enters and
then calls ``page.disable()`` when it exits.
This keeps track of concurrent callers and only disables page events
when all callers have exited.
'''
async with self._page_enable_lock:
self._page_enable_count += 1
if self._page_enable_count == 1:
await self.execute(devtools.page.enable())

yield

async with self._page_enable_lock:
self._page_enable_count -= 1
if self._page_enable_count == 0:
await self.execute(devtools.page.disable())


class CdpConnection(CdpBase, trio.abc.AsyncResource):
'''
Contains the connection state for a Chrome DevTools Protocol server.
Expand All @@ -265,6 +347,7 @@ class CdpConnection(CdpBase, trio.abc.AsyncResource):
You should generally call the :func:`open_cdp()` instead of
instantiating this class directly.
'''

def __init__(self, ws):
'''
Constructor
Expand All @@ -285,7 +368,7 @@ async def aclose(self):
await self.ws.aclose()

@asynccontextmanager
async def open_session(self, target_id: cdp.target.TargetID) -> \
async def open_session(self, target_id) -> \
typing.AsyncIterator[CdpSession]:
'''
This context manager opens a session and enables the "simple" style of calling
Expand All @@ -297,11 +380,11 @@ async def open_session(self, target_id: cdp.target.TargetID) -> \
with session_context(session):
yield session

async def connect_session(self, target_id: cdp.target.TargetID) -> 'CdpSession':
async def connect_session(self, target_id) -> 'CdpSession':
'''
Returns a new :class:`CdpSession` connected to the specified target.
'''
session_id = await self.execute(cdp.target.attach_to_target(
session_id = await self.execute(devtools.target.attach_to_target(
target_id, True))
session = CdpSession(self.ws, session_id, target_id)
self.sessions[session_id] = session
Expand Down Expand Up @@ -331,78 +414,17 @@ async def _reader_task(self):
})
logger.debug('Received message %r', data)
if 'sessionId' in data:
session_id = cdp.target.SessionID(data['sessionId'])
session_id = devtools.target.SessionID(data['sessionId'])
try:
session = self.sessions[session_id]
except KeyError:
raise BrowserError('Browser sent a message for an invalid '
'session: {!r}'.format(session_id))
'session: {!r}'.format(session_id))
session._handle_data(data)
else:
self._handle_data(data)


class CdpSession(CdpBase):
'''
Contains the state for a CDP session.
Generally you should not instantiate this object yourself; you should call
:meth:`CdpConnection.open_session`.
'''
def __init__(self, ws, session_id, target_id):
'''
Constructor.
:param trio_websocket.WebSocketConnection ws:
:param cdp.target.SessionID session_id:
:param cdp.target.TargetID target_id:
'''
super().__init__(ws, session_id, target_id)

self._dom_enable_count = 0
self._dom_enable_lock = trio.Lock()
self._page_enable_count = 0
self._page_enable_lock = trio.Lock()

@asynccontextmanager
async def dom_enable(self):
'''
A context manager that executes ``dom.enable()`` when it enters and then
calls ``dom.disable()``.
This keeps track of concurrent callers and only disables DOM events when
all callers have exited.
'''
async with self._dom_enable_lock:
self._dom_enable_count += 1
if self._dom_enable_count == 1:
await self.execute(cdp.dom.enable())

yield

async with self._dom_enable_lock:
self._dom_enable_count -= 1
if self._dom_enable_count == 0:
await self.execute(cdp.dom.disable())

@asynccontextmanager
async def page_enable(self):
'''
A context manager that executes ``page.enable()`` when it enters and
then calls ``page.disable()`` when it exits.
This keeps track of concurrent callers and only disables page events
when all callers have exited.
'''
async with self._page_enable_lock:
self._page_enable_count += 1
if self._page_enable_count == 1:
await self.execute(cdp.page.enable())

yield

async with self._page_enable_lock:
self._page_enable_count -= 1
if self._page_enable_count == 0:
await self.execute(cdp.page.disable())


@asynccontextmanager
async def open_cdp(url) -> typing.AsyncIterator[CdpConnection]:
'''
Expand All @@ -414,6 +436,7 @@ async def open_cdp(url) -> typing.AsyncIterator[CdpConnection]:
connection automatically. If you want to use multiple connections concurrently, it
is recommended to open each on in a separate task.
'''

async with trio.open_nursery() as nursery:
conn = await connect_cdp(nursery, url)
try:
Expand All @@ -437,7 +460,7 @@ async def connect_cdp(nursery, url) -> CdpConnection:
such as running inside of a notebook.
'''
ws = await connect_websocket_url(nursery, url,
max_message_size=MAX_WS_MESSAGE_SIZE)
max_message_size=MAX_WS_MESSAGE_SIZE)
cdp_conn = CdpConnection(ws)
nursery.start_soon(cdp_conn._reader_task)
return cdp_conn

0 comments on commit e2987e2

Please sign in to comment.