Skip to content

Commit

Permalink
Add more tests for token client
Browse files Browse the repository at this point in the history
  • Loading branch information
sideshow committed Nov 3, 2016
1 parent 59dc349 commit 40fc7ef
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 6 deletions.
14 changes: 8 additions & 6 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ var (
HTTPClientTimeout = 30 * time.Second
)

// DialTLS is the default dial function for creating TLS connections for
// non-proxied HTTPS requests.
var DialTLS = func(network, addr string, cfg *tls.Config) (net.Conn, error) {
return tls.DialWithDialer(&net.Dialer{Timeout: TLSDialTimeout}, network, addr, cfg)
}

// Client represents a connection with the APNs
type Client struct {
Host string
Expand Down Expand Up @@ -64,9 +70,7 @@ func NewClient(certificate tls.Certificate) *Client {
}
transport := &http2.Transport{
TLSClientConfig: tlsConfig,
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
return tls.DialWithDialer(&net.Dialer{Timeout: TLSDialTimeout}, network, addr, cfg)
},
DialTLS: DialTLS,
}
return &Client{
HTTPClient: &http.Client{
Expand All @@ -88,9 +92,7 @@ func NewClient(certificate tls.Certificate) *Client {
// connection and disconnection as a denial-of-service attack.
func NewTokenClient(token *token.Token) *Client {
transport := &http2.Transport{
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
return tls.DialWithDialer(&net.Dialer{Timeout: TLSDialTimeout}, network, addr, cfg)
},
DialTLS: DialTLS,
}
return &Client{
Token: token,
Expand Down
40 changes: 40 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package apns2_test

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"fmt"
"io/ioutil"
Expand All @@ -15,6 +18,7 @@ import (

apns "github.com/sideshow/apns2"
"github.com/sideshow/apns2/certificate"
"github.com/sideshow/apns2/token"
"github.com/stretchr/testify/assert"
)

Expand All @@ -27,6 +31,12 @@ func mockNotification() *apns.Notification {
return n
}

func mockToken() *token.Token {
pubkeyCurve := elliptic.P256()
authKey, _ := ecdsa.GenerateKey(pubkeyCurve, rand.Reader)
return &token.Token{AuthKey: authKey}
}

func mockCert() tls.Certificate {
return tls.Certificate{}
}
Expand All @@ -42,16 +52,31 @@ func TestClientDefaultHost(t *testing.T) {
assert.Equal(t, "https://api.development.push.apple.com", client.Host)
}

func TestTokenDefaultHost(t *testing.T) {
client := apns.NewTokenClient(mockToken()).Development()
assert.Equal(t, "https://api.development.push.apple.com", client.Host)
}

func TestClientDevelopmentHost(t *testing.T) {
client := apns.NewClient(mockCert()).Development()
assert.Equal(t, "https://api.development.push.apple.com", client.Host)
}

func TestTokenClientDevelopmentHost(t *testing.T) {
client := apns.NewTokenClient(mockToken()).Development()
assert.Equal(t, "https://api.development.push.apple.com", client.Host)
}

func TestClientProductionHost(t *testing.T) {
client := apns.NewClient(mockCert()).Production()
assert.Equal(t, "https://api.push.apple.com", client.Host)
}

func TestTokenClientProductionHost(t *testing.T) {
client := apns.NewTokenClient(mockToken()).Production()
assert.Equal(t, "https://api.push.apple.com", client.Host)
}

func TestClientBadUrlError(t *testing.T) {
n := mockNotification()
res, err := mockClient("badurl://badurl.com").Push(n)
Expand Down Expand Up @@ -150,6 +175,21 @@ func TestHeaders(t *testing.T) {
assert.NoError(t, err)
}

func TestAuthorizationHeader(t *testing.T) {
n := mockNotification()
token := mockToken()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "application/json; charset=utf-8", r.Header.Get("Content-Type"))
assert.Equal(t, fmt.Sprintf("bearer %v", token.Bearer), r.Header.Get("authorization"))
}))
defer server.Close()

client := mockClient(server.URL)
client.Token = token
_, err := client.Push(n)
assert.NoError(t, err)
}

func TestPayload(t *testing.T) {
n := mockNotification()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down

0 comments on commit 40fc7ef

Please sign in to comment.