Skip to content

Commit

Permalink
Add a passthrough mode for the new webdriver servlet
Browse files Browse the repository at this point in the history
This should allow ends that speaks the same dialect
of the wire protocol to just use the selenium server
as a fairly dumb pipe.
  • Loading branch information
shs96c committed May 21, 2017
1 parent ea3fa47 commit cb3b1e3
Show file tree
Hide file tree
Showing 19 changed files with 927 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ private void streamW3CProtocolParameters(
gson.toJson(firstMatch, out);
}

private Optional<Result> createSession(HttpClient client, InputStream newSessionBlob, long size)
public Optional<Result> createSession(HttpClient client, InputStream newSessionBlob, long size)
throws IOException {
// Create the http request and send it
HttpRequest request = new HttpRequest(HttpMethod.POST, "/session");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public HttpResponse encode(Response response) {
? HTTP_OK
: HTTP_INTERNAL_ERROR;

byte[] data = beanToJsonConverter.convert(response).getBytes(UTF_8);
byte[] data = beanToJsonConverter.convert(getValueToEncode(response)).getBytes(UTF_8);

HttpResponse httpResponse = new HttpResponse();
httpResponse.setStatus(status);
Expand All @@ -71,6 +71,8 @@ public HttpResponse encode(Response response) {
return httpResponse;
}

protected abstract Object getValueToEncode(Response response);

@Override
public Response decode(HttpResponse encodedResponse) {
String contentType = nullToEmpty(encodedResponse.getHeader(CONTENT_TYPE));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,9 @@ protected Response reconstructValue(Response response) {

return response;
}

@Override
protected Object getValueToEncode(Response response) {
return response;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import static java.net.HttpURLConnection.HTTP_OK;

import com.google.common.base.Strings;
import com.google.common.base.Throwables;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
Expand All @@ -35,6 +36,7 @@
import org.openqa.selenium.remote.internal.JsonToWebElementConverter;

import java.lang.reflect.Constructor;
import java.util.HashMap;
import java.util.Optional;
import java.util.function.Function;
import java.util.logging.Logger;
Expand Down Expand Up @@ -145,6 +147,31 @@ public Response decode(HttpResponse encodedResponse) {
return response;
}

@Override
protected Object getValueToEncode(Response response) {
HashMap<Object, Object> toReturn = new HashMap<>();
Object value = response.getValue();
if (value instanceof WebDriverException) {
HashMap<Object, Object> exception = new HashMap<>();
exception.put(
"error",
response.getState() != null ?
response.getState() :
errorCodes.toState(response.getStatus()));
exception.put("message", ((WebDriverException) value).getMessage());
exception.put("stacktrace", Throwables.getStackTraceAsString((WebDriverException) value));
if (value instanceof UnhandledAlertException) {
HashMap<String, Object> data = new HashMap<>();
data.put("text", ((UnhandledAlertException) value).getAlertText());
exception.put("data", data);
}

value = exception;
}
toReturn.put("value", value);
return toReturn;
}

protected Response reconstructValue(Response response) {
response.setValue(elementConverter.apply(response.getValue()));
return response;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@

interface ActiveSession extends CommandHandler {

/**
* Used to provide logging information and thread names.
*/
String getDescription();

SessionId getId();

/**
Expand Down
11 changes: 10 additions & 1 deletion java/server/src/org/openqa/selenium/remote/server/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ java_library(
'CommandHandler.java',
'ExceptionHandler.java',
'InMemorySession.java',
'Passthrough.java',
'ProtocolConverter.java',
'ServicedSession.java',
'SessionCodec.java',
'SessionFactory.java',
'TeeReader.java',
'WebDriverServlet.java',
],
provided_deps = [
Expand All @@ -80,7 +86,10 @@ java_library(
'//java/client/src/org/openqa/selenium/remote:remote',
'//third_party/java/gson:gson',
'//third_party/java/guava:guava',
]
],
visibility = [
'//java/server/test/org/openqa/selenium/remote/server:tests',
],
)

export_file(name = 'client',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,21 @@
import static org.openqa.selenium.remote.DesiredCapabilities.chrome;
import static org.openqa.selenium.remote.DesiredCapabilities.firefox;
import static org.openqa.selenium.remote.DesiredCapabilities.htmlUnit;
import static org.openqa.selenium.remote.Dialect.OSS;
import static org.openqa.selenium.remote.Dialect.W3C;

import com.google.common.cache.Cache;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.stream.JsonReader;
import com.google.gson.stream.JsonToken;

import org.openqa.selenium.SessionNotCreatedException;
import org.openqa.selenium.remote.BeanToJsonConverter;
import org.openqa.selenium.remote.Dialect;
import org.openqa.selenium.remote.SessionId;

import java.io.BufferedInputStream;
Expand Down Expand Up @@ -54,16 +58,16 @@ class BeginSession implements CommandHandler {

private final Cache<SessionId, ActiveSession> allSessions;
private final DriverSessions legacySessions;
private final Map<String, Function<Path, ActiveSession>> factories;
private final Map<String, SessionFactory> factories;
private final Function<Path, ActiveSession> defaultFactory;

public BeginSession(Cache<SessionId, ActiveSession> allSessions, DriverSessions legacySessions) {
this.allSessions = allSessions;
this.legacySessions = legacySessions;

this.factories = ImmutableMap.of(
chrome().getBrowserName(), new InMemorySession.Factory(legacySessions),
firefox().getBrowserName(), new InMemorySession.Factory(legacySessions),
chrome().getBrowserName(), new ServicedSession.Factory("org.openqa.selenium.chrome.ChromeDriverService"),
firefox().getBrowserName(), new ServicedSession.Factory("org.openqa.selenium.firefox.GeckoDriverService"),
htmlUnit().getBrowserName(), new InMemorySession.Factory(legacySessions));

defaultFactory = null;
Expand All @@ -80,15 +84,24 @@ public void execute(HttpServletRequest req, HttpServletResponse resp) throws IOE
List<Map<String, Object>> firstMatch = new LinkedList<>();

readCapabilities(allCaps, req, ossKeys, alwaysMatch, firstMatch);
List<Function<Path, ActiveSession>> browserGenerators = determineBrowser(
List<SessionFactory> browserGenerators = determineBrowser(
ossKeys,
alwaysMatch,
firstMatch);

ImmutableSet.Builder<Dialect> downstreamDialects = ImmutableSet.builder();
// Favour OSS for now
if (!ossKeys.isEmpty()) {
downstreamDialects.add(OSS);
}
if (!alwaysMatch.isEmpty() || !firstMatch.isEmpty()) {
downstreamDialects.add(W3C);
}

ActiveSession session = browserGenerators.stream()
.map(func -> {
try {
return func.apply(allCaps);
return func.apply(allCaps, downstreamDialects.build());
} catch (Exception e) {
LOG.log(Level.INFO, "Unable to start session.", e);
}
Expand Down Expand Up @@ -192,7 +205,7 @@ private void readCapabilities(
}
}

private List<Function<Path, ActiveSession>> determineBrowser(
private List<SessionFactory> determineBrowser(
Map<String, Object> ossKeys,
Map<String, Object> alwaysMatchKeys,
List<Map<String, Object>> firstMatchKeys) {
Expand All @@ -203,7 +216,7 @@ private List<Function<Path, ActiveSession>> determineBrowser(
allCapabilities.add(ossKeys);

// Can we figure out the browser from any of these?
ImmutableList.Builder<Function<Path, ActiveSession>> builder = ImmutableList.builder();
ImmutableList.Builder<SessionFactory> builder = ImmutableList.builder();
for (Map<String, Object> caps : allCapabilities) {
caps.entrySet().stream()
.map(entry -> guessBrowserName(entry.getKey(), entry.getValue()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.openqa.selenium.Capabilities;
import org.openqa.selenium.ImmutableCapabilities;
import org.openqa.selenium.SessionNotCreatedException;
import org.openqa.selenium.remote.Dialect;
import org.openqa.selenium.remote.JsonToBeanConverter;
import org.openqa.selenium.remote.SessionId;

Expand All @@ -17,7 +18,7 @@
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.Set;
import java.util.logging.Logger;
import java.util.stream.Collectors;

Expand All @@ -41,14 +42,6 @@ public InMemorySession(SessionId id, Session session, JsonHttpCommandHandler com
this.commandHandler = commandHandler;
}

@Override
public String getDescription() {
return String.format(
"%s: Legacy Session -> %s",
id,
session.getCapabilities().getBrowserName());
}

@Override
public SessionId getId() {
return id;
Expand All @@ -70,7 +63,7 @@ public void execute(HttpServletRequest req, HttpServletResponse resp) throws IOE
commandHandler.handleRequest(req, resp);
}

public static class Factory implements Function<Path, ActiveSession> {
public static class Factory implements SessionFactory {

private final DriverSessions legacySessions;
private final JsonHttpCommandHandler jsonHttpCommandHandler;
Expand All @@ -83,7 +76,7 @@ public Factory(DriverSessions legacySessions) {
}

@Override
public ActiveSession apply(Path path) {
public ActiveSession apply(Path path, Set<Dialect> downstreamDialects) {
try (BufferedReader reader = Files.newBufferedReader(path, UTF_8)) {
Map<?, ?> blob = new JsonToBeanConverter().convert(Map.class, CharStreams.toString(reader));

Expand Down
138 changes: 138 additions & 0 deletions java/server/src/org/openqa/selenium/remote/server/Passthrough.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package org.openqa.selenium.remote.server;

import static java.nio.charset.StandardCharsets.UTF_8;

import com.google.common.base.Strings;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.CharStreams;
import com.google.common.net.MediaType;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Reader;
import java.io.StringWriter;
import java.io.Writer;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.Enumeration;
import java.util.Objects;
import java.util.logging.Logger;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

public class Passthrough implements SessionCodec {

private final static Logger LOG = Logger.getLogger(Passthrough.class.getName());

private final static ImmutableSet<String> IGNORED_REQ_HEADERS = ImmutableSet.<String>builder()
.add("connection")
.add("keep-alive")
.add("proxy-authorization")
.add("proxy-authenticate")
.add("proxy-connection")
.add("te")
.add("trailer")
.add("transfer-encoding")
.add("upgrade")
.build();

private final URL upstream;

public Passthrough(URL upstream) {
this.upstream = upstream;
}

@Override
public void handle(HttpServletRequest req, HttpServletResponse resp) throws IOException {
String suffix = req.getPathInfo();
if (Strings.isNullOrEmpty(suffix)) {
suffix = "/";
}

URL target = new URL(upstream.toExternalForm() + suffix);
HttpURLConnection connection = (HttpURLConnection) target.openConnection();
connection.setInstanceFollowRedirects(true);
connection.setRequestMethod(req.getMethod());
connection.setDoInput(true);
connection.setDoOutput(true);
connection.setUseCaches(false);

Enumeration<String> allHeaders = req.getHeaderNames();
while (allHeaders.hasMoreElements()) {
String name = allHeaders.nextElement();
if (IGNORED_REQ_HEADERS.contains(name.toLowerCase())) {
continue;
}

Enumeration<String> allValues = req.getHeaders(name);
while (allValues.hasMoreElements()) {
String value = allValues.nextElement();
connection.addRequestProperty(name, value);
}
}
// None of this "keep alive" nonsense.
connection.setRequestProperty("Connection", "close");

if ("POST".equalsIgnoreCase(req.getMethod()) || "PUT".equalsIgnoreCase(req.getMethod())) {
// We always transform to UTF-8 on the way up.
String contentType = req.getHeader("Content-Type");
contentType = contentType == null ? MediaType.JAVASCRIPT_UTF_8.toString() : contentType;

MediaType type = MediaType.parse(contentType);
connection.setRequestProperty("Content-Type", type.withCharset(UTF_8).toString());

String charSet = req.getCharacterEncoding() != null ?
req.getCharacterEncoding() :
UTF_8.name();

StringWriter logWriter = new StringWriter();
try (
InputStream is = req.getInputStream();
Reader reader = new InputStreamReader(is, charSet);
Reader in = new TeeReader(reader, logWriter);
OutputStream os = connection.getOutputStream();
Writer out = new OutputStreamWriter(os, UTF_8)) {
CharStreams.copy(in, out);
}
LOG.info("To upstream: " + logWriter.toString());
}

resp.setStatus(connection.getResponseCode());
// clear response defaults.
resp.setHeader("Date",null);
resp.setHeader("Server",null);

connection.getHeaderFields().entrySet().stream()
.filter(entry -> entry.getKey() != null && entry.getValue() != null)
.filter(entry -> !IGNORED_REQ_HEADERS.contains(entry.getKey().toLowerCase()))
.forEach(entry -> {
entry.getValue().stream()
.filter(Objects::nonNull)
.forEach(value -> {
resp.addHeader(entry.getKey(), value);
});
});
InputStream in = connection.getErrorStream();
if (in == null) {
in = connection.getInputStream();
}

String charSet = connection.getContentEncoding() != null ? connection.getContentEncoding() : UTF_8.name();
StringWriter logWriter = new StringWriter();
try (
Reader reader = new InputStreamReader(in, charSet);
Reader tee = new TeeReader(reader, logWriter);
OutputStream os = resp.getOutputStream();
Writer out = new OutputStreamWriter(os, UTF_8)) {
CharStreams.copy(tee, out);
} finally {
in.close();
}

LOG.info("To downstream: " + logWriter.toString());
}
}
Loading

0 comments on commit cb3b1e3

Please sign in to comment.