Skip to content

Commit

Permalink
On behalf of akshay@browserstack.com: using keep-alive for remote con…
Browse files Browse the repository at this point in the history
…nection

(minor tweaks to patch made by lukeis)
Fixes Issue SeleniumHQ#6452
  • Loading branch information
lukeis committed Oct 28, 2013
1 parent 56af8b3 commit 93dc128
Showing 1 changed file with 38 additions and 44 deletions.
82 changes: 38 additions & 44 deletions py/selenium/webdriver/remote/remote_connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2008-2009 WebDriver committers
# Copyright 2008-2009 Google Inc.
# Copyright 2013 BrowserStack
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,14 +17,13 @@
import logging
import socket
import string
import base64
import httplib
import datetime
import urllib2 as url_request

try:
from urllib import request as url_request
except ImportError:
import urllib2 as url_request

try:
from urllib import parse
from urllib2 import parse
except ImportError:
import urlparse as parse

Expand All @@ -33,7 +33,6 @@

LOGGER = logging.getLogger(__name__)


class Request(url_request.Request):
"""
Extends the url_request.Request to support all HTTP request types.
Expand Down Expand Up @@ -133,13 +132,14 @@ class RemoteConnection(object):
Communicates with the server using the WebDriver wire protocol:
http://code.google.com/p/selenium/wiki/JsonWireProtocol
"""

def __init__(self, remote_server_addr):
# Attempt to resolve the hostname and get an IP address.
parsed_url = parse.urlparse(remote_server_addr)
addr = ""
if parsed_url.hostname:
try:
netloc = socket.gethostbyname(parsed_url.hostname)
addr = netloc
if parsed_url.port:
netloc += ':%d' % parsed_url.port
if parsed_url.username:
Expand All @@ -155,6 +155,7 @@ def __init__(self, remote_server_addr):
parsed_url.hostname)

self._url = remote_server_addr
self._conn = httplib.HTTPConnection(str(addr),str(parsed_url.port))
self._commands = {
Command.STATUS: ('GET', '/status'),
Command.NEW_SESSION: ('POST', '/session'),
Expand Down Expand Up @@ -335,6 +336,8 @@ def __init__(self, remote_server_addr):
('GET','/session/$sessionId/log/types'),
}



def execute(self, command, params):
"""
Send a command to the remote server.
Expand Down Expand Up @@ -369,47 +372,37 @@ def _request(self, url, data=None, method=None):
LOGGER.debug('%s %s %s' % (method, url, data))

parsed_url = parse.urlparse(url)
password_manager = None
headers = {}
headers["Connection"] = "Keep-Alive"
headers[method] = parsed_url.path
headers["User-Agent"] = "Python http auth"
headers["Content-type"] = "text/html;charset=\"UTF-8\""
headers["Connection"] = "keep-alive"

# for basic auth
if parsed_url.username:
netloc = parsed_url.hostname
if parsed_url.port:
netloc += ":%s" % parsed_url.port
cleaned_url = parse.urlunparse((parsed_url.scheme,
netloc,
parsed_url.path,
parsed_url.params,
parsed_url.query,
parsed_url.fragment))
password_manager = url_request.HTTPPasswordMgrWithDefaultRealm()
password_manager.add_password(None,
"%s://%s" % (parsed_url.scheme, netloc),
parsed_url.username,
parsed_url.password)
request = Request(cleaned_url, data=data.encode('utf-8'), method=method)
else:
request = Request(url, data=data.encode('utf-8'), method=method)

request.add_header('Accept', 'application/json')
request.add_header('Content-Type', 'application/json;charset=UTF-8')

if password_manager:
opener = url_request.build_opener(url_request.HTTPRedirectHandler(),
HttpErrorHandler(),
url_request.HTTPBasicAuthHandler(password_manager))
else:
opener = url_request.build_opener(url_request.HTTPRedirectHandler(),
HttpErrorHandler())
response = opener.open(request)
auth = base64.standard_b64encode('%s:%s' % (parsed_url.username, parsed_url.password)).replace('\n','')
# Authorization header
headers["Authorization"] = "Basic %s" % auth

self._conn.request(method, parsed_url.path, data, headers)
resp = self._conn.getresponse()
statuscode = resp.status
statusmessage = resp.msg
LOGGER.debug('%s %s' %(statuscode, statusmessage))
data = resp.read()
try:
if response.code > 399 and response.code < 500:
return {'status': response.code, 'value': response.read()}
body = response.read().decode('utf-8').replace('\x00', '').strip()
content_type = [value for name, value in response.info().items() if name.lower() == "content-type"]
if statuscode > 399 and statuscode < 500:
return {'status': statuscode, 'value': data}
body = data.decode('utf-8').replace('\x00', '').strip()
content_type = []
if resp.getheader('Content-Type') is not None:
content_type = resp.getheader('Content-Type').split(';')
if not any([x.startswith('image/png') for x in content_type]):
try:
data = utils.load_json(body.strip())
except ValueError:
if response.code > 199 and response.code < 300:
if statuscode > 199 and statuscode < 300:
status = ErrorCode.SUCCESS
else:
status = ErrorCode.UNKNOWN_ERROR
Expand All @@ -428,4 +421,5 @@ def _request(self, url, data=None, method=None):
data = {'status': 0, 'value': body.strip()}
return data
finally:
response.close()
LOGGER.debug("Finished Request")
resp.close()

0 comments on commit 93dc128

Please sign in to comment.