Skip to content

Commit

Permalink
Add an Option for using a custom http.Client in WebDrivers/Pages
Browse files Browse the repository at this point in the history
  • Loading branch information
sclevine committed Sep 10, 2015
1 parent 547867f commit 26c5f50
Show file tree
Hide file tree
Showing 14 changed files with 153 additions and 33 deletions.
7 changes: 5 additions & 2 deletions agouti.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,12 @@ func Selendroid(jarFile string, options ...Option) *WebDriver {
}

// SauceLabs opens a Sauce Labs session and returns a *Page. Does not support Sauce Connect.
func SauceLabs(name, platform, browser, version, username, accessKey string) (*Page, error) {
//
// This method takes the same Options as *WebDriver.NewPage. Passing the Desired Option will
// completely override the provided name, platform, browser, and version.
func SauceLabs(name, platform, browser, version, username, accessKey string, options ...Option) (*Page, error) {
url := fmt.Sprintf("http://%s:%s@ondemand.saucelabs.com/wd/hub", username, accessKey)
capabilities := NewCapabilities().Browser(name).Platform(platform).Version(version)
capabilities["name"] = name
return NewPage(url, Desired(capabilities))
return NewPage(url, append([]Option{Desired(capabilities)}, options...)...)
}
7 changes: 7 additions & 0 deletions api/internal/bus/bus_suite_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package bus_test

import (
"net/http"
"testing"

. "github.com/onsi/ginkgo"
Expand All @@ -11,3 +12,9 @@ func TestBus(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Bus Suite")
}

type roundTripperFunc func(*http.Request) (*http.Response, error)

func (r roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) {
return r(request)
}
7 changes: 4 additions & 3 deletions api/internal/bus/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

type Client struct {
SessionURL string
HTTPClient *http.Client
}

func (c *Client) Send(method, endpoint string, body interface{}, result interface{}) error {
Expand All @@ -20,7 +21,7 @@ func (c *Client) Send(method, endpoint string, body interface{}, result interfac
}

requestURL := strings.TrimSuffix(c.SessionURL+"/"+endpoint, "/")
responseBody, err := makeRequest(requestURL, method, requestBody)
responseBody, err := c.makeRequest(requestURL, method, requestBody)
if err != nil {
return err
}
Expand All @@ -46,7 +47,7 @@ func bodyToJSON(body interface{}) ([]byte, error) {
return bodyJSON, nil
}

func makeRequest(url, method string, body []byte) ([]byte, error) {
func (c *Client) makeRequest(url, method string, body []byte) ([]byte, error) {
request, err := http.NewRequest(method, url, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("invalid request: %s", err)
Expand All @@ -56,7 +57,7 @@ func makeRequest(url, method string, body []byte) ([]byte, error) {
request.Header.Add("Content-Type", "application/json")
}

response, err := http.DefaultClient.Do(request)
response, err := c.HTTPClient.Do(request)
if err != nil {
return nil, fmt.Errorf("request failed: %s", err)
}
Expand Down
19 changes: 17 additions & 2 deletions api/internal/bus/client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package bus_test

import (
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -43,7 +44,10 @@ var _ = Describe("Session", func() {
var client *Client

BeforeEach(func() {
client = &Client{SessionURL: server.URL + "/session/some-id"}
client = &Client{
SessionURL: server.URL + "/session/some-id",
HTTPClient: http.DefaultClient,
}
})

It("should make a request with the method and full session endpoint", func() {
Expand All @@ -52,8 +56,19 @@ var _ = Describe("Session", func() {
Expect(requestMethod).To(Equal("GET"))
})

It("should use the provided HTTP client", func() {
var path string
client.HTTPClient = &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
path = request.URL.Path
return nil, errors.New("some error")
})}
err := client.Send("GET", "some/endpoint", nil, nil)
Expect(err).To(MatchError(ContainSubstring("some error")))
Expect(path).To(Equal("/session/some-id/some/endpoint"))
})

Context("with a valid request body", func() {
It("should make a application/json request with the provided body", func() {
It("should make a request with the provided body and application/json content type", func() {
body := struct{ SomeValue string }{"some request value"}
Expect(client.Send("POST", "some/endpoint", body, nil)).To(Succeed())
Expect(requestBody).To(Equal(`{"SomeValue":"some request value"}`))
Expand Down
14 changes: 9 additions & 5 deletions api/internal/bus/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,23 @@ import (
"net/http"
)

func Connect(url string, capabilities map[string]interface{}) (*Client, error) {
func Connect(url string, capabilities map[string]interface{}, httpClient *http.Client) (*Client, error) {
requestBody, err := capabilitiesToJSON(capabilities)
if err != nil {
return nil, err
}

sessionID, err := openSession(url, requestBody)
if httpClient == nil {
httpClient = http.DefaultClient
}

sessionID, err := openSession(url, requestBody, httpClient)
if err != nil {
return nil, err
}

sessionURL := fmt.Sprintf("%s/session/%s", url, sessionID)
return &Client{sessionURL}, nil
return &Client{sessionURL, httpClient}, nil
}

func capabilitiesToJSON(capabilities map[string]interface{}) (io.Reader, error) {
Expand All @@ -40,15 +44,15 @@ func capabilitiesToJSON(capabilities map[string]interface{}) (io.Reader, error)
return bytes.NewReader(capabiltiesJSON), err
}

func openSession(url string, body io.Reader) (sessionID string, err error) {
func openSession(url string, body io.Reader, httpClient *http.Client) (sessionID string, err error) {
request, err := http.NewRequest("POST", fmt.Sprintf("%s/session", url), body)
if err != nil {
return "", err
}

request.Header.Add("Content-Type", "application/json")

response, err := http.DefaultClient.Do(request)
response, err := httpClient.Do(request)
if err != nil {
return "", err
}
Expand Down
60 changes: 49 additions & 11 deletions api/internal/bus/connect_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package bus_test

import (
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand All @@ -24,7 +25,7 @@ var _ = Describe(".Connect", func() {
responseBody = `{"sessionId": "some-id"}`
requestPath, requestMethod, requestBody, requestContentType = "", "", "", ""
server = httptest.NewServer(http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
requestPath = request.URL.Path // TODO: use
requestPath = request.URL.Path
requestMethod = request.Method
requestBodyBytes, _ := ioutil.ReadAll(request.Body)
requestBody = string(requestBodyBytes)
Expand All @@ -37,67 +38,104 @@ var _ = Describe(".Connect", func() {
server.Close()
})

It("should successfully make an application/json POST request to the session endpoint", func() {
_, err := Connect(server.URL, nil)
It("should successfully make an POST request with content type application/json to the session endpoint", func() {
_, err := Connect(server.URL, nil, nil)
Expect(err).NotTo(HaveOccurred())
Expect(requestMethod).To(Equal("POST"))
Expect(requestPath).To(Equal("/session"))
Expect(requestContentType).To(Equal("application/json"))
})

It("should make the request using the provided HTTP client", func() {
var path string
client := &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
path = request.URL.Path
return nil, errors.New("some error")
})}
_, err := Connect(server.URL, nil, client)
Expect(err).To(MatchError(ContainSubstring("some error")))
Expect(path).To(Equal("/session"))
})

It("should return a client with a session URL", func() {
client, err := Connect(server.URL, nil)
client, err := Connect(server.URL, nil, nil)
Expect(err).NotTo(HaveOccurred())
Expect(client.SessionURL).To(ContainSubstring("/session/some-id"))
})

It("should make the request with the provided desired capabilities", func() {
_, err := Connect(server.URL, map[string]interface{}{"some": "json"})
_, err := Connect(server.URL, map[string]interface{}{"some": "json"}, nil)
Expect(err).NotTo(HaveOccurred())
Expect(requestBody).To(MatchJSON(`{"desiredCapabilities": {"some": "json"}}`))
})

Context("when the capabilities are nil", func() {
It("should make the request with empty capabilities", func() {
_, err := Connect(server.URL, nil)
_, err := Connect(server.URL, nil, nil)
Expect(err).NotTo(HaveOccurred())
Expect(requestBody).To(MatchJSON(`{"desiredCapabilities": {}}`))
})
})

Context("when the capabilities are invalid", func() {
It("should return an error", func() {
_, err := Connect(server.URL, map[string]interface{}{"some": func() {}})
_, err := Connect(server.URL, map[string]interface{}{"some": func() {}}, nil)
Expect(err).To(MatchError("json: unsupported type: func()"))
})
})

Context("when the provided HTTP client is nil", func() {
var (
defaultClient *http.Client
path string
)

BeforeEach(func() {
defaultClient = http.DefaultClient
http.DefaultClient = &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
path = request.URL.Path
return nil, errors.New("some error")
})}

})

AfterEach(func() {
http.DefaultClient = defaultClient
})

It("should use the default HTTP client", func() {
_, err := Connect(server.URL, nil, nil)
Expect(err).To(MatchError(ContainSubstring("some error")))
Expect(path).To(Equal("/session"))
})
})

Context("when the request is invalid", func() {
It("should return an error", func() {
_, err := Connect("%@#$%", nil)
_, err := Connect("%@#$%", nil, nil)
Expect(err.Error()).To(ContainSubstring(`parse %@: invalid URL escape "%@"`))
})
})

Context("when the request fails", func() {
It("should return an error", func() {
_, err := Connect("http://#", nil)
_, err := Connect("http://#", nil, nil)
Expect(err.Error()).To(ContainSubstring("Post http://#/session"))
})
})

Context("when the response contains invalid JSON", func() {
It("should return an error", func() {
responseBody = "$$$"
_, err := Connect(server.URL, nil)
_, err := Connect(server.URL, nil, nil)
Expect(err).To(MatchError("invalid character '$' looking for beginning of value"))
})
})

Context("when the response does not contain a session ID", func() {
It("should return an error", func() {
responseBody = "{}"
_, err := Connect(server.URL, nil)
_, err := Connect(server.URL, nil, nil)
Expect(err).To(MatchError("failed to retrieve a session ID"))
})
})
Expand Down
7 changes: 6 additions & 1 deletion api/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"encoding/base64"
"errors"
"net/http"

"github.com/sclevine/agouti/api/internal/bus"
)
Expand All @@ -16,7 +17,11 @@ type Bus interface {
}

func Open(url string, capabilities map[string]interface{}) (*Session, error) {
busClient, err := bus.Connect(url, capabilities)
return OpenWithClient(url, capabilities, nil)
}

func OpenWithClient(url string, capabilities map[string]interface{}, client *http.Client) (*Session, error) {
busClient, err := bus.Connect(url, capabilities, client)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion api/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
. "github.com/sclevine/agouti/internal/matchers"
)

var _ = Describe("Bus", func() {
var _ = Describe("Session", func() {
var (
bus *mocks.Bus
session *Session
Expand Down
12 changes: 7 additions & 5 deletions api/webdriver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@ package api

import (
"fmt"
"net/http"
"time"

"github.com/sclevine/agouti/api/internal/service"
)

type WebDriver struct {
Timeout time.Duration
Debug bool
service driverService
sessions []*Session
Timeout time.Duration
Debug bool
HTTPClient *http.Client
service driverService
sessions []*Session
}

type driverService interface {
Expand Down Expand Up @@ -43,7 +45,7 @@ func (w *WebDriver) Open(desiredCapabilites map[string]interface{}) (*Session, e
return nil, fmt.Errorf("service not started")
}

session, err := Open(url, desiredCapabilites)
session, err := OpenWithClient(url, desiredCapabilites, w.HTTPClient)
if err != nil {
return nil, err
}
Expand Down
19 changes: 19 additions & 0 deletions api/webdriver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ import (
"github.com/sclevine/agouti/api/internal/mocks"
)

type roundTripperFunc func(*http.Request) (*http.Response, error)

func (r roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) {
return r(request)
}

var _ = Describe("WebDriver", func() {
var (
webDriver *WebDriver
Expand Down Expand Up @@ -82,6 +88,19 @@ var _ = Describe("WebDriver", func() {
Expect(err).To(MatchError("failed to retrieve a session ID"))
})
})

Context("when a custom HTTP client is set", func() {
It("should open the session using that client", func() {
var path string
webDriver.HTTPClient = &http.Client{Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
path = request.URL.Path
return nil, errors.New("some error")
})}
_, err := webDriver.Open(nil)
Expect(err).To(MatchError(ContainSubstring("some error")))
Expect(path).To(Equal("/session"))
})
})
})

Describe("#Start", func() {
Expand Down
Loading

0 comments on commit 26c5f50

Please sign in to comment.