diff --git a/conn.go b/conn.go index 4435ed81..d95fb888 100644 --- a/conn.go +++ b/conn.go @@ -155,6 +155,9 @@ type conn struct { // If not nil, notifications will be synchronously sent here notificationHandler func(*Notification) + + // GSSAPI context + gss Gss } // Handle driver-side settings in parsed connection string. @@ -1071,7 +1074,10 @@ func isDriverSetting(key string) bool { return true case "binary_parameters": return true - + case "service": + return true + case "spn": + return true default: return false } @@ -1151,6 +1157,56 @@ func (cn *conn) auth(r *readBuf, o values) { if r.int32() != 0 { errorf("unexpected authentication response: %q", t) } + case 7: // GSSAPI, startup + cli, err := NewGSS() + if err != nil { + errorf("kerberos error: %s", err.Error()) + } + + var token []byte + + if spn, ok := o["spn"]; ok { + // Use the supplied SPN if provided.. + token, err = cli.GetInitTokenFromSpn(spn) + } else { + // Allow the kerberos service name to be overridden + service := "postgres" + if val, ok := o["service"]; ok { + service = val + } + + token, err = cli.GetInitToken(o["host"], service) + } + + if err != nil { + errorf("failed to get Kerberos ticket: %q", err) + } + + w := cn.writeBuf('p') + w.bytes(token) + cn.send(w) + + // Store for GSSAPI continue message + cn.gss = cli + + case 8: // GSSAPI continue + + if cn.gss == nil { + errorf("GSSAPI protocol error") + } + + b := []byte(*r) + + done, tokOut, err := cn.gss.Continue(b) + if err == nil && !done { + w := cn.writeBuf('p') + w.bytes(tokOut) + cn.send(w) + } + + // Errors fall through and read the more detailed message + // from the server.. + case 10: sc := scram.NewClient(sha256.New, o["user"], o["password"]) sc.Step(nil) diff --git a/go.mod b/go.mod index edf0b343..a33cf4c4 100644 --- a/go.mod +++ b/go.mod @@ -1 +1,14 @@ module github.com/lib/pq + +go 1.13 + +require ( + github.com/alexbrainman/sspi v0.0.0-20180613141037-e580b900e9f5 // indirect + github.com/jcmturner/gokrb5/v8 v8.2.0 + golang.org/x/crypto v0.0.0-20200311171314-f7b00557c8c4 // indirect + gopkg.in/jcmturner/aescts.v1 v1.0.1 // indirect + gopkg.in/jcmturner/dnsutils.v1 v1.0.1 // indirect + gopkg.in/jcmturner/goidentity.v3 v3.0.0 // indirect + gopkg.in/jcmturner/gokrb5.v7 v7.5.0 + gopkg.in/jcmturner/rpc.v1 v1.1.0 // indirect +) diff --git a/krb.go b/krb.go new file mode 100644 index 00000000..e98432f0 --- /dev/null +++ b/krb.go @@ -0,0 +1,40 @@ +package pq + +import ( + "net" + "strings" +) + +/* + * Basic GSSAPI interface to abstract Windows (SSPI) from Unix + * APIs within the driver + */ + +type Gss interface { + GetInitToken(host string, service string) ([]byte, error) + GetInitTokenFromSpn(spn string) ([]byte, error) + Continue(inToken []byte) (done bool, outToken []byte, err error) +} + +/* + * Find the A record associated with a hostname + * In general, hostnames supplied to the driver should be + * canonicalized because the KDC usually only has one + * principal and not one per potential alias of a host. + */ +func canonicalizeHostname(host string) (string, error) { + canon := host + + name, err := net.LookupCNAME(host) + if err != nil { + return "", err + } + + name = strings.TrimSuffix(name, ".") + + if name != "" { + canon = name + } + + return canon, nil +} diff --git a/krb_unix.go b/krb_unix.go new file mode 100644 index 00000000..b5483663 --- /dev/null +++ b/krb_unix.go @@ -0,0 +1,123 @@ +// +build !windows + +package pq + +import ( + "fmt" + "os" + "os/user" + "strings" + + "github.com/jcmturner/gokrb5/v8/client" + "github.com/jcmturner/gokrb5/v8/config" + "github.com/jcmturner/gokrb5/v8/credentials" + "github.com/jcmturner/gokrb5/v8/spnego" +) + +/* + * UNIX Kerberos support, using jcmturner's pure-go + * implementation + */ + +// Implements the Gss interface +type gss struct { + cli *client.Client +} + +func NewGSS() (Gss, error) { + g := &gss{} + err := g.init() + + if err != nil { + return nil, err + } + + return g, nil +} + +func (g *gss) init() error { + cfgPath, ok := os.LookupEnv("KRB5_CONFIG") + if !ok { + cfgPath = "/etc/krb5.conf" + } + + cfg, err := config.Load(cfgPath) + if err != nil { + return err + } + + u, err := user.Current() + if err != nil { + return err + } + + ccpath := "/tmp/krb5cc_" + u.Uid + + ccname := os.Getenv("KRB5CCNAME") + if strings.HasPrefix(ccname, "FILE:") { + ccpath = strings.SplitN(ccname, ":", 2)[1] + } + + ccache, err := credentials.LoadCCache(ccpath) + if err != nil { + return err + } + + cl, err := client.NewFromCCache(ccache, cfg, client.DisablePAFXFAST(true)) + if err != nil { + return err + } + + cl.Login() + + g.cli = cl + + return nil +} + +func (g *gss) GetInitToken(host string, service string) ([]byte, error) { + + // Resolve the hostname down to an 'A' record, if required (usually, it is) + if g.cli.Config.LibDefaults.DNSCanonicalizeHostname { + var err error + host, err = canonicalizeHostname(host) + if err != nil { + return nil, err + } + } + + spn := service + "/" + host + + return g.GetInitTokenFromSpn(spn) +} + +func (g *gss) GetInitTokenFromSpn(spn string) ([]byte, error) { + s := spnego.SPNEGOClient(g.cli, spn) + + st, err := s.InitSecContext() + if err != nil { + return nil, fmt.Errorf("kerberos error (InitSecContext): %s", err.Error()) + } + + b, err := st.Marshal() + if err != nil { + return nil, fmt.Errorf("kerberos error (Marshaling token): %s", err.Error()) + } + + return b, nil +} + +func (g *gss) Continue(inToken []byte) (done bool, outToken []byte, err error) { + t := &spnego.SPNEGOToken{} + err = t.Unmarshal(inToken) + if err != nil { + return true, nil, fmt.Errorf("kerberos error (Unmarshaling token): %s", err.Error()) + } + + state := t.NegTokenResp.State() + if state != spnego.NegStateAcceptCompleted { + return true, nil, fmt.Errorf("kerberos: expected state 'Completed' - got %d", state) + } + + return true, nil, nil +} diff --git a/krb_windows.go b/krb_windows.go new file mode 100644 index 00000000..71e328ea --- /dev/null +++ b/krb_windows.go @@ -0,0 +1,61 @@ +// +build windows + +package pq + +import ( + "github.com/alexbrainman/sspi" + "github.com/alexbrainman/sspi/negotiate" +) + +type gss struct { + creds *sspi.Credentials + ctx *negotiate.ClientContext +} + +func NewGSS() (Gss, error) { + g := &gss{} + err := g.init() + + if err != nil { + return nil, err + } + + return g, nil +} + +func (g *gss) init() error { + creds, err := negotiate.AcquireCurrentUserCredentials() + if err != nil { + return err + } + + g.creds = creds + return nil +} + +func (g *gss) GetInitToken(host string, service string) ([]byte, error) { + + host, err := canonicalizeHostname(host) + if err != nil { + return nil, err + } + + spn := service + "/" + host + + return g.GetInitTokenFromSpn(spn) +} + +func (g *gss) GetInitTokenFromSpn(spn string) ([]byte, error) { + ctx, token, err := negotiate.NewClientContext(g.creds, spn) + if err != nil { + return nil, err + } + + g.ctx = ctx + + return token, nil +} + +func (g *gss) Continue(inToken []byte) (done bool, outToken []byte, err error) { + return g.ctx.Update(inToken) +}