Skip to content

Commit

Permalink
Run proxy finder function and add result to request context (#87)
Browse files Browse the repository at this point in the history
This lets us run the proxy resolution code only once, even if we have to
perform multiple requests (e.g. for authentication).
  • Loading branch information
samuong committed Apr 27, 2022
1 parent efd9189 commit 2d9008b
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 28 deletions.
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019, 2021 The Alpaca Authors
// Copyright 2019, 2021, 2022 The Alpaca Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -87,7 +87,7 @@ func main() {

pacWrapper := NewPACWrapper(PACData{Port: *port})
proxyFinder := NewProxyFinder(pacURL, pacWrapper)
proxyHandler := NewProxyHandler(proxyFinder.findProxyForRequest, a, proxyFinder.blockProxy)
proxyHandler := NewProxyHandler(a, getProxyFromContext, proxyFinder.blockProxy)
mux := http.NewServeMux()
pacWrapper.SetupHandlers(mux)

Expand Down
35 changes: 21 additions & 14 deletions proxy.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019, 2021 The Alpaca Authors
// Copyright 2019, 2021, 2022 The Alpaca Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -37,7 +37,7 @@ type ProxyHandler struct {

type proxyFunc func(*http.Request) (*url.URL, error)

func NewProxyHandler(proxy proxyFunc, auth *authenticator, block func(string)) ProxyHandler {
func NewProxyHandler(auth *authenticator, proxy proxyFunc, block func(string)) ProxyHandler {
return ProxyHandler{&http.Transport{Proxy: proxy}, auth, block}
}

Expand Down Expand Up @@ -69,22 +69,20 @@ func (ph ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {

func (ph ProxyHandler) handleConnect(w http.ResponseWriter, req *http.Request) {
// Establish a connection to the server, or an upstream proxy.
u, err := ph.transport.Proxy(req)
id := req.Context().Value(contextKeyID)
proxy, err := ph.transport.Proxy(req)
if err != nil {
log.Printf("[%d] Error finding proxy for %v: %v", id, req.Host, err)
w.WriteHeader(http.StatusInternalServerError)
return
log.Printf("[%d] Error finding proxy for request: %v", id, err)
}
var server net.Conn
if u == nil {
server, err = net.Dial("tcp", req.Host)
if proxy == nil {
server, err = connectDirect(req)
} else {
server, err = connectViaProxy(req, u.Host, ph.auth)
server, err = connectViaProxy(req, proxy.Host, ph.auth)
var dialErr *dialError
if errors.As(err, &dialErr) {
log.Printf("[%d] Temporarily blocking unreachable proxy: %q", id, u.Host)
ph.block(u.Host)
log.Printf("[%d] Temporarily blocking proxy: %q", id, proxy.Host)
ph.block(proxy.Host)
}
}
if err != nil {
Expand Down Expand Up @@ -130,12 +128,21 @@ func (ph ProxyHandler) handleConnect(w http.ResponseWriter, req *http.Request) {
go func() { _, _ = io.Copy(client, server); client.Close() }()
}

func connectDirect(req *http.Request) (net.Conn, error) {
server, err := net.Dial("tcp", req.Host)
if err != nil {
id := req.Context().Value(contextKeyID)
log.Printf("[%d] Error dialling host %s: %v", id, req.Host, err)
}
return server, err
}

func connectViaProxy(req *http.Request, proxy string, auth *authenticator) (net.Conn, error) {
id := req.Context().Value(contextKeyID)
var tr transport
defer tr.Close()
if err := tr.dial("tcp", proxy); err != nil {
log.Printf("[%d] Error dialling %s: %v", id, proxy, err)
log.Printf("[%d] Error dialling proxy %s: %v", id, proxy, err)
return nil, err
}
resp, err := tr.RoundTrip(req)
Expand Down Expand Up @@ -176,7 +183,7 @@ func (ph ProxyHandler) proxyRequest(w http.ResponseWriter, req *http.Request, au
resp, err := ph.transport.RoundTrip(req)
if err != nil {
log.Printf("[%d] Error forwarding request: %v", id, err)
w.WriteHeader(http.StatusInternalServerError)
w.WriteHeader(http.StatusBadGateway)
var dialErr *dialError
if errors.As(err, &dialErr) && dialErr.address != req.Host {
log.Printf("[%d] Temporarily blocking unreachable proxy: %q",
Expand All @@ -196,7 +203,7 @@ func (ph ProxyHandler) proxyRequest(w http.ResponseWriter, req *http.Request, au
resp, err = auth.do(req, ph.transport)
if err != nil {
log.Printf("[%d] Error forwarding request (with auth): %v", id, err)
w.WriteHeader(http.StatusInternalServerError)
w.WriteHeader(http.StatusBadGateway)
return
}
defer resp.Body.Close()
Expand Down
29 changes: 19 additions & 10 deletions proxy_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019 The Alpaca Authors
// Copyright 2019, 2021, 2022 The Alpaca Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@ package main

import (
"bufio"
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand Down Expand Up @@ -54,17 +55,17 @@ func (tp testProxy) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

func newDirectProxy() ProxyHandler {
return NewProxyHandler(
func(r *http.Request) (*url.URL, error) { return nil, nil },
nil,
func(string) {},
)
return NewProxyHandler(nil, http.ProxyURL(nil), func(string) {})
}

func newChildProxy(parent *httptest.Server) ProxyHandler {
return NewProxyHandler(func(r *http.Request) (*url.URL, error) {
return &url.URL{Host: parent.Listener.Addr().String()}, nil
}, nil, func(string) {})
func newChildProxy(parent *httptest.Server) http.Handler {
parentURL := &url.URL{Host: parent.Listener.Addr().String()}
childProxy := NewProxyHandler(nil, getProxyFromContext, func(string) {})
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := context.WithValue(req.Context(), contextKeyProxy, parentURL)
reqWithProxy := req.WithContext(ctx)
childProxy.ServeHTTP(w, reqWithProxy)
})
}

func proxyServer(t *testing.T, proxy *httptest.Server) proxyFunc {
Expand Down Expand Up @@ -336,3 +337,11 @@ func TestConnectResponseHasCorrectNewlines(t *testing.T) {
assert.NotContains(t, noCRLFs, "\r", "response contains unmatched CR")
assert.NotContains(t, noCRLFs, "\n", "response contains unmatched LF")
}

func TestConnectToNonExistentHost(t *testing.T) {
proxy := httptest.NewServer(newDirectProxy())
defer proxy.Close()
client := http.Client{Transport: &http.Transport{Proxy: proxyServer(t, proxy)}}
_, err := client.Get("https://nonexistent.test")
require.Error(t, err)
}
22 changes: 20 additions & 2 deletions proxyfinder.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019, 2021 The Alpaca Authors
// Copyright 2019, 2021, 2022 The Alpaca Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,7 @@
package main

import (
"context"
"errors"
"log"
"net"
Expand All @@ -24,6 +25,16 @@ import (
"sync"
)

const contextKeyProxy = contextKey("proxy")

func getProxyFromContext(req *http.Request) (*url.URL, error) {
if value := req.Context().Value(contextKeyProxy); value != nil {
proxy := value.(*url.URL)
return proxy, nil
}
return nil, nil
}

type ProxyFinder struct {
runner *PACRunner
fetcher *pacFetcher
Expand Down Expand Up @@ -53,7 +64,14 @@ func (pf *ProxyFinder) WrapHandler(next http.Handler) http.Handler {
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
pf.checkForUpdates()
next.ServeHTTP(w, req)
proxy, err := pf.findProxyForRequest(req)
if err != nil {
log.Printf("[%d] %v", req.Context().Value(contextKeyID), err)
w.WriteHeader(http.StatusInternalServerError)
return
}
ctx := context.WithValue(req.Context(), contextKeyProxy, proxy)
next.ServeHTTP(w, req.WithContext(ctx))
})
}

Expand Down

0 comments on commit 2d9008b

Please sign in to comment.