diff --git a/main.go b/main.go index 3c6ea74..a0599a2 100644 --- a/main.go +++ b/main.go @@ -1,4 +1,4 @@ -// Copyright 2019, 2021 The Alpaca Authors +// Copyright 2019, 2021, 2022 The Alpaca Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -87,7 +87,7 @@ func main() { pacWrapper := NewPACWrapper(PACData{Port: *port}) proxyFinder := NewProxyFinder(pacURL, pacWrapper) - proxyHandler := NewProxyHandler(proxyFinder.findProxyForRequest, a, proxyFinder.blockProxy) + proxyHandler := NewProxyHandler(a, getProxyFromContext, proxyFinder.blockProxy) mux := http.NewServeMux() pacWrapper.SetupHandlers(mux) diff --git a/proxy.go b/proxy.go index e71990f..76819d5 100644 --- a/proxy.go +++ b/proxy.go @@ -1,4 +1,4 @@ -// Copyright 2019, 2021 The Alpaca Authors +// Copyright 2019, 2021, 2022 The Alpaca Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ type ProxyHandler struct { type proxyFunc func(*http.Request) (*url.URL, error) -func NewProxyHandler(proxy proxyFunc, auth *authenticator, block func(string)) ProxyHandler { +func NewProxyHandler(auth *authenticator, proxy proxyFunc, block func(string)) ProxyHandler { return ProxyHandler{&http.Transport{Proxy: proxy}, auth, block} } @@ -69,22 +69,20 @@ func (ph ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (ph ProxyHandler) handleConnect(w http.ResponseWriter, req *http.Request) { // Establish a connection to the server, or an upstream proxy. - u, err := ph.transport.Proxy(req) id := req.Context().Value(contextKeyID) + proxy, err := ph.transport.Proxy(req) if err != nil { - log.Printf("[%d] Error finding proxy for %v: %v", id, req.Host, err) - w.WriteHeader(http.StatusInternalServerError) - return + log.Printf("[%d] Error finding proxy for request: %v", id, err) } var server net.Conn - if u == nil { - server, err = net.Dial("tcp", req.Host) + if proxy == nil { + server, err = connectDirect(req) } else { - server, err = connectViaProxy(req, u.Host, ph.auth) + server, err = connectViaProxy(req, proxy.Host, ph.auth) var dialErr *dialError if errors.As(err, &dialErr) { - log.Printf("[%d] Temporarily blocking unreachable proxy: %q", id, u.Host) - ph.block(u.Host) + log.Printf("[%d] Temporarily blocking proxy: %q", id, proxy.Host) + ph.block(proxy.Host) } } if err != nil { @@ -130,12 +128,21 @@ func (ph ProxyHandler) handleConnect(w http.ResponseWriter, req *http.Request) { go func() { _, _ = io.Copy(client, server); client.Close() }() } +func connectDirect(req *http.Request) (net.Conn, error) { + server, err := net.Dial("tcp", req.Host) + if err != nil { + id := req.Context().Value(contextKeyID) + log.Printf("[%d] Error dialling host %s: %v", id, req.Host, err) + } + return server, err +} + func connectViaProxy(req *http.Request, proxy string, auth *authenticator) (net.Conn, error) { id := req.Context().Value(contextKeyID) var tr transport defer tr.Close() if err := tr.dial("tcp", proxy); err != nil { - log.Printf("[%d] Error dialling %s: %v", id, proxy, err) + log.Printf("[%d] Error dialling proxy %s: %v", id, proxy, err) return nil, err } resp, err := tr.RoundTrip(req) @@ -176,7 +183,7 @@ func (ph ProxyHandler) proxyRequest(w http.ResponseWriter, req *http.Request, au resp, err := ph.transport.RoundTrip(req) if err != nil { log.Printf("[%d] Error forwarding request: %v", id, err) - w.WriteHeader(http.StatusInternalServerError) + w.WriteHeader(http.StatusBadGateway) var dialErr *dialError if errors.As(err, &dialErr) && dialErr.address != req.Host { log.Printf("[%d] Temporarily blocking unreachable proxy: %q", @@ -196,7 +203,7 @@ func (ph ProxyHandler) proxyRequest(w http.ResponseWriter, req *http.Request, au resp, err = auth.do(req, ph.transport) if err != nil { log.Printf("[%d] Error forwarding request (with auth): %v", id, err) - w.WriteHeader(http.StatusInternalServerError) + w.WriteHeader(http.StatusBadGateway) return } defer resp.Body.Close() diff --git a/proxy_test.go b/proxy_test.go index fd7d1d0..c9d860c 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -1,4 +1,4 @@ -// Copyright 2019 The Alpaca Authors +// Copyright 2019, 2021, 2022 The Alpaca Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package main import ( "bufio" + "context" "crypto/tls" "crypto/x509" "fmt" @@ -54,17 +55,17 @@ func (tp testProxy) ServeHTTP(w http.ResponseWriter, req *http.Request) { } func newDirectProxy() ProxyHandler { - return NewProxyHandler( - func(r *http.Request) (*url.URL, error) { return nil, nil }, - nil, - func(string) {}, - ) + return NewProxyHandler(nil, http.ProxyURL(nil), func(string) {}) } -func newChildProxy(parent *httptest.Server) ProxyHandler { - return NewProxyHandler(func(r *http.Request) (*url.URL, error) { - return &url.URL{Host: parent.Listener.Addr().String()}, nil - }, nil, func(string) {}) +func newChildProxy(parent *httptest.Server) http.Handler { + parentURL := &url.URL{Host: parent.Listener.Addr().String()} + childProxy := NewProxyHandler(nil, getProxyFromContext, func(string) {}) + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx := context.WithValue(req.Context(), contextKeyProxy, parentURL) + reqWithProxy := req.WithContext(ctx) + childProxy.ServeHTTP(w, reqWithProxy) + }) } func proxyServer(t *testing.T, proxy *httptest.Server) proxyFunc { @@ -336,3 +337,11 @@ func TestConnectResponseHasCorrectNewlines(t *testing.T) { assert.NotContains(t, noCRLFs, "\r", "response contains unmatched CR") assert.NotContains(t, noCRLFs, "\n", "response contains unmatched LF") } + +func TestConnectToNonExistentHost(t *testing.T) { + proxy := httptest.NewServer(newDirectProxy()) + defer proxy.Close() + client := http.Client{Transport: &http.Transport{Proxy: proxyServer(t, proxy)}} + _, err := client.Get("https://nonexistent.test") + require.Error(t, err) +} diff --git a/proxyfinder.go b/proxyfinder.go index 2f45901..2d99185 100644 --- a/proxyfinder.go +++ b/proxyfinder.go @@ -1,4 +1,4 @@ -// Copyright 2019, 2021 The Alpaca Authors +// Copyright 2019, 2021, 2022 The Alpaca Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ package main import ( + "context" "errors" "log" "net" @@ -24,6 +25,16 @@ import ( "sync" ) +const contextKeyProxy = contextKey("proxy") + +func getProxyFromContext(req *http.Request) (*url.URL, error) { + if value := req.Context().Value(contextKeyProxy); value != nil { + proxy := value.(*url.URL) + return proxy, nil + } + return nil, nil +} + type ProxyFinder struct { runner *PACRunner fetcher *pacFetcher @@ -53,7 +64,14 @@ func (pf *ProxyFinder) WrapHandler(next http.Handler) http.Handler { } return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { pf.checkForUpdates() - next.ServeHTTP(w, req) + proxy, err := pf.findProxyForRequest(req) + if err != nil { + log.Printf("[%d] %v", req.Context().Value(contextKeyID), err) + w.WriteHeader(http.StatusInternalServerError) + return + } + ctx := context.WithValue(req.Context(), contextKeyProxy, proxy) + next.ServeHTTP(w, req.WithContext(ctx)) }) }