diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 000000000..3a09b77cb --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,24 @@ +name: "CodeQL" + +on: [push, pull_request] + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + steps: + - name: Checkout repo + uses: actions/checkout@v3 + + - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: 'go' + + - name: CodeQL Analysis + uses: github/codeql-action/analyze@v1 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 000000000..e1b413c9e --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,214 @@ +name: Test + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + postgres: + - '15' + - '14' + - '13' + - '12' + - '11' + - '10' + - '9.6' + go: + - '1.20' + - '1.19' + - '1.18' + - '1.17' + - '1.16' + - '1.15' + - '1.14' + steps: + - name: setup postgres pre-reqs + run: | + mkdir init + cat < init/root.crt + -----BEGIN CERTIFICATE----- + MIIEBjCCAu6gAwIBAgIJAPizR+OD14YnMA0GCSqGSIb3DQEBCwUAMF4xCzAJBgNV + BAYTAlVTMQ8wDQYDVQQIDAZOZXZhZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgG + A1UECgwRZ2l0aHViLmNvbS9saWIvcHExDjAMBgNVBAMMBXBxIENBMB4XDTIxMDkw + MjAxNTUwMloXDTMxMDkwMzAxNTUwMlowXjELMAkGA1UEBhMCVVMxDzANBgNVBAgM + Bk5ldmFkYTESMBAGA1UEBwwJTGFzIFZlZ2FzMRowGAYDVQQKDBFnaXRodWIuY29t + L2xpYi9wcTEOMAwGA1UEAwwFcHEgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw + ggEKAoIBAQDb9d6sjdU6GdibGrXRMOHREH3MRUS8T4TFqGgPEGVDP/V5bAZlBSGP + AN0o9DTyVLcbQpBt8zMTw9KeIzIIe5NIVkSmA16lw/YckGhOM+kZIkiDuE6qt5Ia + OQCRMdXkZ8ejG/JUu+rHU8FJZL8DE+jyYherzdjkeVAQ7JfzxAwW2Dl7T/47g337 + Pwmf17AEb8ibSqmXyUN7R5NhJQs+hvaYdNagzdx91E1H+qlyBvmiNeasUQljLvZ+ + Y8wAuU79neA+d09O4PBiYwV17rSP6SZCeGE3oLZviL/0KM9Xig88oB+2FmvQ6Zxa + L7SoBlqS+5pBZwpH7eee/wCIKAnJtMAJAgMBAAGjgcYwgcMwDwYDVR0TAQH/BAUw + AwEB/zAdBgNVHQ4EFgQUfIXEczahbcM2cFrwclJF7GbdajkwgZAGA1UdIwSBiDCB + hYAUfIXEczahbcM2cFrwclJF7GbdajmhYqRgMF4xCzAJBgNVBAYTAlVTMQ8wDQYD + VQQIDAZOZXZhZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgGA1UECgwRZ2l0aHVi + LmNvbS9saWIvcHExDjAMBgNVBAMMBXBxIENBggkA+LNH44PXhicwDQYJKoZIhvcN + AQELBQADggEBABFyGgSz2mHVJqYgX1Y+7P+MfKt83cV2uYDGYvXrLG2OGiCilVul + oTBG+8omIMSHOsQZvWMpA5H0tnnlQHrKpKpUyKkSL+Wv5GL0UtBmHX7mVRiaK2l4 + q2BjRaQUitp/FH4NSdXtVrMME5T1JBBZHsQkNL3cNRzRKwY/Vj5UGEDxDS7lILUC + e01L4oaK0iKQn4beALU+TvKoAHdPvoxpPpnhkF5ss9HmdcvRktJrKZemDJZswZ7/ + +omx8ZPIYYUH5VJJYYE88S7guAt+ZaKIUlel/t6xPbo2ZySFSg9u1uB99n+jTo3L + 1rAxFnN3FCX2jBqgP29xMVmisaN5k04UmyI= + -----END CERTIFICATE----- + CONF + cat < init/server.crt + -----BEGIN CERTIFICATE----- + MIIDqzCCApOgAwIBAgIJAPiewLrOyYipMA0GCSqGSIb3DQEBCwUAMF4xCzAJBgNV + BAYTAlVTMQ8wDQYDVQQIDAZOZXZhZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgG + A1UECgwRZ2l0aHViLmNvbS9saWIvcHExDjAMBgNVBAMMBXBxIENBMB4XDTIxMDkw + MjAxNTUwMloXDTMxMDkwMzAxNTUwMlowTjELMAkGA1UEBhMCVVMxDzANBgNVBAgM + Bk5ldmFkYTESMBAGA1UEBwwJTGFzIFZlZ2FzMRowGAYDVQQKDBFnaXRodWIuY29t + L2xpYi9wcTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKf6H4UzmANN + QiQJe92Mf3ETMYmpZKNNO9DPEHyNLIkag+XwMrBTdcCK0mLvsNCYpXuBN6703KCd + WAFOeMmj7gOsWtvjt5Xm6bRHLgegekXzcG/jDwq/wyzeDzr/YkITuIlG44Lf9lhY + FLwiHlHOWHnwrZaEh6aU//02aQkzyX5INeXl/3TZm2G2eIH6AOxOKOU27MUsyVSQ + 5DE+SDKGcRP4bElueeQWvxAXNMZYb7sVSDdfHI3zr32K4k/tC8x0fZJ5XN/dvl4t + 4N4MrYlmDO5XOrb/gQH1H4iu6+5EMDfZYab4fkThnNFdfFqu4/8Scv7KZ8mWqpKM + fGAjEPctQi0CAwEAAaN8MHowHQYDVR0OBBYEFENExPbmDyFB2AJUdbMvVyhlNPD5 + MAkGA1UdEwQCMAAwCwYDVR0PBAQDAgWgMBMGA1UdEQQMMAqCCHBvc3RncmVzMCwG + CWCGSAGG+EIBDQQfFh1PcGVuU1NMIEdlbmVyYXRlZCBDZXJ0aWZpY2F0ZTANBgkq + hkiG9w0BAQsFAAOCAQEAMRVbV8RiEsmp9HAtnVCZmRXMIbgPGrqjeSwk586s4K8v + BSqNCqxv6s5GfCRmDYiqSqeuCVDtUJS1HsTmbxVV7Ke71WMo+xHR1ICGKOa8WGCb + TGsuicG5QZXWaxeMOg4s0qpKmKko0d1aErdVsanU5dkrVS7D6729Ffnzu4lwApk6 + invAB67p8u7sojwqRq5ce0vRaG+YFylTrWomF9kauEb8gKbQ9Xc7QfX+h+UH/mq9 + Nvdj8LOHp6/82bZdnsYUOtV4lS1IA/qzeXpqBphxqfWabD1yLtkyJyImZKq8uIPp + 0CG4jhObPdWcCkXD6bg3QK3mhwlC79OtFgxWmldCRQ== + -----END CERTIFICATE----- + CONF + cat < init/server.key + -----BEGIN PRIVATE KEY----- + MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCn+h+FM5gDTUIk + CXvdjH9xEzGJqWSjTTvQzxB8jSyJGoPl8DKwU3XAitJi77DQmKV7gTeu9NygnVgB + TnjJo+4DrFrb47eV5um0Ry4HoHpF83Bv4w8Kv8Ms3g86/2JCE7iJRuOC3/ZYWBS8 + Ih5Rzlh58K2WhIemlP/9NmkJM8l+SDXl5f902ZthtniB+gDsTijlNuzFLMlUkOQx + PkgyhnET+GxJbnnkFr8QFzTGWG+7FUg3XxyN8699iuJP7QvMdH2SeVzf3b5eLeDe + DK2JZgzuVzq2/4EB9R+IruvuRDA32WGm+H5E4ZzRXXxaruP/EnL+ymfJlqqSjHxg + IxD3LUItAgMBAAECggEAOE2naQ9tIZYw2EFxikZApVcooJrtx6ropMnzHbx4NBB2 + K4mChAXFj184u77ZxmGT/jzGvFcI6LE0wWNbK0NOUV7hKZk/fPhkV3AQZrAMrAu4 + IVi7PwAd3JkmA8F8XuebUDA5rDGDsgL8GD9baFJA58abeLs9eMGyuF4XgOUh4bip + hgHa76O2rcDWNY5HZqqRslw75FzlYkB0PCts/UJxSswj70kTTihyOhDlrm2TnyxI + ne54UbGRrpfs9wiheSGLjDG81qZToBHQDwoAnjjZhu1VCaBISuGbgZrxyyRyqdnn + xPW+KczMv04XyvF7v6Pz+bUEppalLXGiXnH5UtWvZQKBgQDTPCdMpNE/hwlq4nAw + Kf42zIBWfbnMLVWYoeDiAOhtl9XAUAXn76xe6Rvo0qeAo67yejdbJfRq3HvGyw+q + 4PS8r9gXYmLYIPQxSoLL5+rFoBCN3qFippfjLB1j32mp7+15KjRj8FF2r6xIN8fu + XatSRsaqmvCWYLDRv/rbHnxwkwKBgQDLkyfFLF7BtwtPWKdqrwOM7ip1UKh+oDBS + vkCQ08aEFRBU7T3jChsx5GbaW6zmsSBwBwcrHclpSkz7n3aq19DDWObJR2p80Fma + rsXeIcvtEpkvT3pVX268P5d+XGs1kxgFunqTysG9yChW+xzcs5MdKBzuMPPn7rL8 + MKAzdar6PwKBgEypkzW8x3h/4Moa3k6MnwdyVs2NGaZheaRIc95yJ+jGZzxBjrMr + h+p2PbvU4BfO0AqOkpKRBtDVrlJqlggVVp04UHvEKE16QEW3Xhr0037f5cInX3j3 + Lz6yXwRFLAsR2aTUzWjL6jTh8uvO2s/GzQuyRh3a16Ar/WBShY+K0+zjAoGATnLT + xZjWnyHRmu8X/PWakamJ9RFzDPDgDlLAgM8LVgTj+UY/LgnL9wsEU6s2UuP5ExKy + QXxGDGwUhHar/SQTj+Pnc7Mwpw6HKSOmnnY5po8fNusSwml3O9XppEkrC0c236Y/ + 7EobJO5IFVTJh4cv7vFxTJzSsRL8KFD4uzvh+nMCgYEAqY8NBYtIgNJA2B6C6hHF + +bG7v46434ZHFfGTmMQwzE4taVg7YRnzYESAlvK4bAP5ZXR90n7GRGFhrXzoMZ38 + r0bw/q9rV+ReGda7/Bjf7ciCKiq0RODcHtf4IaskjPXCoQRGJtgCPLhWPfld6g9v + /HTvO96xv9e3eG/PKSPog94= + -----END PRIVATE KEY----- + CONF + cat < init/hba.sh + cat < /var/lib/postgresql/data/pg_hba.conf + local all all trust + host all postgres all trust + hostnossl all pqgossltest all reject + hostnossl all pqgosslcert all reject + hostssl all pqgossltest all trust + hostssl all pqgosslcert all cert + host all all all trust + EOF + CONF + sudo chown 999:999 ./init/* + sudo chmod 600 ./init/* + + - name: start postgres + run: | + docker run -d \ + --name pg \ + -p 5432:5432 \ + -v $(pwd)/init:/init \ + -e POSTGRES_PASSWORD=unused \ + -e POSTGRES_USER=postgres \ + postgres:${{ matrix.postgres }} \ + -c ssl=on \ + -c ssl_ca_file=/init/root.crt \ + -c ssl_cert_file=/init/server.crt \ + -c ssl_key_file=/init/server.key + + - name: configure postgres + run: | + n=0 + until [ "$n" -ge 10 ] + do + docker exec pg pg_isready -h localhost && break + n=$((n+1)) + echo waiting for postgres to be ready... + sleep 1 + done + docker exec pg bash /init/hba.sh + n=0 + until [ "$n" -ge 10 ] + do + docker exec pg su postgres -c '/usr/lib/postgresql/${{ matrix.postgres }}/bin/pg_ctl reload' && break + n=$((n+1)) + echo waiting for postgres to reload... + sleep 1 + done + + - name: setup hosts + run: echo '127.0.0.1 postgres' | sudo tee -a /etc/hosts + + - name: create db/roles + run: | + n=0 + until [ "$n" -ge 10 ] + do + docker exec pg pg_isready -h localhost && break + n=$((n+1)) + echo waiting for postgres to be ready... + sleep 1 + done + docker exec pg createdb -h localhost -U postgres pqgotest + docker exec pg createuser -h localhost -U postgres -DRS pqgossltest + docker exec pg createuser -h localhost -U postgres -DRS pqgosslcert + + - name: check out code into the Go module directory + uses: actions/checkout@v3 + + - name: set up go + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go }} + id: go + + - name: set key perms + run: sudo chmod 600 certs/postgresql.key + + - name: run tests + env: + PGUSER: postgres + PGHOST: localhost + PGPORT: 5432 + PQGOSSLTESTS: 1 + PQSSLCERTTEST_PATH: certs + run: | + PQTEST_BINARY_PARAMETERS=no go test -race -v ./... + PQTEST_BINARY_PARAMETERS=yes go test -race -v ./... + + - name: install goimports + run: go get golang.org/x/tools/cmd/goimports + + - name: install staticcheck + run: | + wget https://github.com/dominikh/go-tools/releases/latest/download/staticcheck_linux_amd64.tar.gz -O - | tar -xz staticcheck + + - name: run goimports + run: | + goimports -d -e . | awk '{ print } END { exit NR == 0 ? 0 : 1 }' + + - name: run staticcheck + run: ./staticcheck/staticcheck -go 1.13 ./... + + - name: build + run: go build -v . diff --git a/.gitignore b/.gitignore index 0f1d00e11..3243952a4 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ *.test *~ *.swp +.idea +.vscode \ No newline at end of file diff --git a/.travis.sh b/.travis.sh deleted file mode 100755 index ebf447030..000000000 --- a/.travis.sh +++ /dev/null @@ -1,73 +0,0 @@ -#!/bin/bash - -set -eu - -client_configure() { - sudo chmod 600 $PQSSLCERTTEST_PATH/postgresql.key -} - -pgdg_repository() { - local sourcelist='sources.list.d/postgresql.list' - - curl -sS 'https://www.postgresql.org/media/keys/ACCC4CF8.asc' | sudo apt-key add - - echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION | sudo tee "/etc/apt/$sourcelist" - sudo apt-get -o Dir::Etc::sourcelist="$sourcelist" -o Dir::Etc::sourceparts='-' -o APT::Get::List-Cleanup='0' update -} - -postgresql_configure() { - sudo tee /etc/postgresql/$PGVERSION/main/pg_hba.conf > /dev/null <<-config - local all all trust - hostnossl all pqgossltest 127.0.0.1/32 reject - hostnossl all pqgosslcert 127.0.0.1/32 reject - hostssl all pqgossltest 127.0.0.1/32 trust - hostssl all pqgosslcert 127.0.0.1/32 cert - host all all 127.0.0.1/32 trust - hostnossl all pqgossltest ::1/128 reject - hostnossl all pqgosslcert ::1/128 reject - hostssl all pqgossltest ::1/128 trust - hostssl all pqgosslcert ::1/128 cert - host all all ::1/128 trust - config - - xargs sudo install -o postgres -g postgres -m 600 -t /var/lib/postgresql/$PGVERSION/main/ <<-certificates - certs/root.crt - certs/server.crt - certs/server.key - certificates - - sort -VCu <<-versions || - $PGVERSION - 9.2 - versions - sudo tee -a /etc/postgresql/$PGVERSION/main/postgresql.conf > /dev/null <<-config - ssl_ca_file = 'root.crt' - ssl_cert_file = 'server.crt' - ssl_key_file = 'server.key' - config - - echo 127.0.0.1 postgres | sudo tee -a /etc/hosts > /dev/null - - sudo service postgresql restart -} - -postgresql_install() { - xargs sudo apt-get -y -o Dpkg::Options::='--force-confdef' -o Dpkg::Options::='--force-confnew' install <<-packages - postgresql-$PGVERSION - postgresql-server-dev-$PGVERSION - postgresql-contrib-$PGVERSION - packages -} - -postgresql_uninstall() { - sudo service postgresql stop - xargs sudo apt-get -y --purge remove <<-packages - libpq-dev - libpq5 - postgresql - postgresql-client-common - postgresql-common - packages - sudo rm -rf /var/lib/postgresql -} - -$1 diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 3498c53dc..000000000 --- a/.travis.yml +++ /dev/null @@ -1,44 +0,0 @@ -language: go - -go: - - 1.13.x - - 1.14.x - - master - -sudo: true - -env: - global: - - PGUSER=postgres - - PQGOSSLTESTS=1 - - PQSSLCERTTEST_PATH=$PWD/certs - - PGHOST=127.0.0.1 - matrix: - - PGVERSION=10 - - PGVERSION=9.6 - - PGVERSION=9.5 - - PGVERSION=9.4 - -before_install: - - ./.travis.sh postgresql_uninstall - - ./.travis.sh pgdg_repository - - ./.travis.sh postgresql_install - - ./.travis.sh postgresql_configure - - ./.travis.sh client_configure - - go get golang.org/x/tools/cmd/goimports - - go get golang.org/x/lint/golint - - GO111MODULE=on go get honnef.co/go/tools/cmd/staticcheck@2020.1.3 - -before_script: - - createdb pqgotest - - createuser -DRS pqgossltest - - createuser -DRS pqgosslcert - -script: - - > - goimports -d -e $(find -name '*.go') | awk '{ print } END { exit NR == 0 ? 0 : 1 }' - - go vet ./... - - staticcheck -go 1.13 ./... - - golint ./... - - PQTEST_BINARY_PARAMETERS=no go test -race -v ./... - - PQTEST_BINARY_PARAMETERS=yes go test -race -v ./... diff --git a/README.md b/README.md index ecd01939b..126ee5d35 100644 --- a/README.md +++ b/README.md @@ -19,10 +19,7 @@ * Unix socket support * Notifications: `LISTEN`/`NOTIFY` * pgpass support - -## Optional Features - -* GSS (Kerberos) auth (to use, see GoDoc) +* GSS (Kerberos) auth ## Tests @@ -30,4 +27,10 @@ ## Status -This package is effectively in maintenance mode and is not actively developed. Small patches and features are only rarely reviewed and merged. We recommend using [pgx](https://github.com/jackc/pgx) which is actively maintained. +This package is currently in maintenance mode, which means: +1. It generally does not accept new features. +2. It does accept bug fixes and version compatability changes provided by the community. +3. Maintainers usually do not resolve reported issues. +4. Community members are encouraged to help each other with reported issues. + +For users that require new features or reliable resolution of reported bugs, we recommend using [pgx](https://github.com/jackc/pgx) which is under active development. diff --git a/array.go b/array.go index e4933e227..39c8f7e2e 100644 --- a/array.go +++ b/array.go @@ -22,7 +22,7 @@ var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() // db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) // // var x []sql.NullInt64 -// db.QueryRow('SELECT ARRAY[235, 401]').Scan(pq.Array(&x)) +// db.QueryRow(`SELECT ARRAY[235, 401]`).Scan(pq.Array(&x)) // // Scanning multi-dimensional arrays is not supported. Arrays where the lower // bound is not one (such as `[0:0]={1}') are not supported. @@ -35,19 +35,31 @@ func Array(a interface{}) interface { return (*BoolArray)(&a) case []float64: return (*Float64Array)(&a) + case []float32: + return (*Float32Array)(&a) case []int64: return (*Int64Array)(&a) + case []int32: + return (*Int32Array)(&a) case []string: return (*StringArray)(&a) + case [][]byte: + return (*ByteaArray)(&a) case *[]bool: return (*BoolArray)(a) case *[]float64: return (*Float64Array)(a) + case *[]float32: + return (*Float32Array)(a) case *[]int64: return (*Int64Array)(a) + case *[]int32: + return (*Int32Array)(a) case *[]string: return (*StringArray)(a) + case *[][]byte: + return (*ByteaArray)(a) } return GenericArray{a} @@ -267,6 +279,70 @@ func (a Float64Array) Value() (driver.Value, error) { return "{}", nil } +// Float32Array represents a one-dimensional array of the PostgreSQL double +// precision type. +type Float32Array []float32 + +// Scan implements the sql.Scanner interface. +func (a *Float32Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Float32Array", src) +} + +func (a *Float32Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Float32Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Float32Array, len(elems)) + for i, v := range elems { + var x float64 + if x, err = strconv.ParseFloat(string(v), 32); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + b[i] = float32(x) + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Float32Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendFloat(b, float64(a[0]), 'f', -1, 32) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendFloat(b, float64(a[i]), 'f', -1, 32) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + // GenericArray implements the driver.Valuer and sql.Scanner interfaces for // an array or slice of any dimension. type GenericArray struct{ A interface{} } @@ -483,6 +559,69 @@ func (a Int64Array) Value() (driver.Value, error) { return "{}", nil } +// Int32Array represents a one-dimensional array of the PostgreSQL integer types. +type Int32Array []int32 + +// Scan implements the sql.Scanner interface. +func (a *Int32Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Int32Array", src) +} + +func (a *Int32Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Int32Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Int32Array, len(elems)) + for i, v := range elems { + x, err := strconv.ParseInt(string(v), 10, 32) + if err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + b[i] = int32(x) + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Int32Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendInt(b, int64(a[0]), 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendInt(b, int64(a[i]), 10) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + // StringArray represents a one-dimensional array of the PostgreSQL character types. type StringArray []string diff --git a/array_test.go b/array_test.go index f724bcd88..5ca9f7a55 100644 --- a/array_test.go +++ b/array_test.go @@ -104,16 +104,33 @@ func TestArrayScanner(t *testing.T) { t.Errorf("Expected *Int64Array, got %T", s) } + s = Array(&[]float32{}) + if _, ok := s.(*Float32Array); !ok { + t.Errorf("Expected *Float32Array, got %T", s) + } + + s = Array(&[]int32{}) + if _, ok := s.(*Int32Array); !ok { + t.Errorf("Expected *Int32Array, got %T", s) + } + s = Array(&[]string{}) if _, ok := s.(*StringArray); !ok { t.Errorf("Expected *StringArray, got %T", s) } + s = Array(&[][]byte{}) + if _, ok := s.(*ByteaArray); !ok { + t.Errorf("Expected *ByteaArray, got %T", s) + } + for _, tt := range []interface{}{ &[]sql.Scanner{}, &[][]bool{}, &[][]float64{}, &[][]int64{}, + &[][]float32{}, + &[][]int32{}, &[][]string{}, } { s = Array(tt) @@ -139,17 +156,34 @@ func TestArrayValuer(t *testing.T) { t.Errorf("Expected *Int64Array, got %T", v) } + v = Array([]float32{}) + if _, ok := v.(*Float32Array); !ok { + t.Errorf("Expected *Float32Array, got %T", v) + } + + v = Array([]int32{}) + if _, ok := v.(*Int32Array); !ok { + t.Errorf("Expected *Int32Array, got %T", v) + } + v = Array([]string{}) if _, ok := v.(*StringArray); !ok { t.Errorf("Expected *StringArray, got %T", v) } + v = Array([][]byte{}) + if _, ok := v.(*ByteaArray); !ok { + t.Errorf("Expected *ByteaArray, got %T", v) + } + for _, tt := range []interface{}{ nil, []driver.Value{}, [][]bool{}, [][]float64{}, [][]int64{}, + [][]float32{}, + [][]int32{}, [][]string{}, } { v = Array(tt) @@ -773,6 +807,313 @@ func BenchmarkInt64ArrayValue(b *testing.B) { } } +func TestFloat32ArrayScanUnsupported(t *testing.T) { + var arr Float32Array + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to Float32Array") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestFloat32ArrayScanEmpty(t *testing.T) { + var arr Float32Array + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestFloat32ArrayScanNil(t *testing.T) { + arr := Float32Array{5, 5, 5} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var Float32ArrayStringTests = []struct { + str string + arr Float32Array +}{ + {`{}`, Float32Array{}}, + {`{1.2}`, Float32Array{1.2}}, + {`{3.456,7.89}`, Float32Array{3.456, 7.89}}, + {`{3,1,2}`, Float32Array{3, 1, 2}}, +} + +func TestFloat32ArrayScanBytes(t *testing.T) { + for _, tt := range Float32ArrayStringTests { + bytes := []byte(tt.str) + arr := Float32Array{5, 5, 5} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkFloat32ArrayScanBytes(b *testing.B) { + var a Float32Array + var x interface{} = []byte(`{1.2,3.4,5.6,7.8,9.01,2.34,5.67,8.90,1.234,5.678}`) + + for i := 0; i < b.N; i++ { + a = Float32Array{} + a.Scan(x) + } +} + +func TestFloat32ArrayScanString(t *testing.T) { + for _, tt := range Float32ArrayStringTests { + arr := Float32Array{5, 5, 5} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestFloat32ArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{5.6},{7.8}}`, "cannot convert ARRAY[2][1] to Float32Array"}, + {`{NULL}`, "parsing array element index 0:"}, + {`{a}`, "parsing array element index 0:"}, + {`{5.6,a}`, "parsing array element index 1:"}, + {`{5.6,7.8,a}`, "parsing array element index 2:"}, + } { + arr := Float32Array{5, 5, 5} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, Float32Array{5, 5, 5}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestFloat32ArrayValue(t *testing.T) { + result, err := Float32Array(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = Float32Array([]float32{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = Float32Array([]float32{1.2, 3.4, 5.6}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{1.2,3.4,5.6}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkFloat32ArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]float32, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Float32() + } + a := Float32Array(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestInt32ArrayScanUnsupported(t *testing.T) { + var arr Int32Array + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to Int32Array") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestInt32ArrayScanEmpty(t *testing.T) { + var arr Int32Array + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestInt32ArrayScanNil(t *testing.T) { + arr := Int32Array{5, 5, 5} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var Int32ArrayStringTests = []struct { + str string + arr Int32Array +}{ + {`{}`, Int32Array{}}, + {`{12}`, Int32Array{12}}, + {`{345,678}`, Int32Array{345, 678}}, +} + +func TestInt32ArrayScanBytes(t *testing.T) { + for _, tt := range Int32ArrayStringTests { + bytes := []byte(tt.str) + arr := Int32Array{5, 5, 5} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkInt32ArrayScanBytes(b *testing.B) { + var a Int32Array + var x interface{} = []byte(`{1,2,3,4,5,6,7,8,9,0}`) + + for i := 0; i < b.N; i++ { + a = Int32Array{} + a.Scan(x) + } +} + +func TestInt32ArrayScanString(t *testing.T) { + for _, tt := range Int32ArrayStringTests { + arr := Int32Array{5, 5, 5} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestInt32ArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{5},{6}}`, "cannot convert ARRAY[2][1] to Int32Array"}, + {`{NULL}`, "parsing array element index 0:"}, + {`{a}`, "parsing array element index 0:"}, + {`{5,a}`, "parsing array element index 1:"}, + {`{5,6,a}`, "parsing array element index 2:"}, + } { + arr := Int32Array{5, 5, 5} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, Int32Array{5, 5, 5}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestInt32ArrayValue(t *testing.T) { + result, err := Int32Array(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = Int32Array([]int32{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = Int32Array([]int32{1, 2, 3}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{1,2,3}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkInt32ArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]int32, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Int31() + } + a := Int32Array(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + func TestStringArrayScanUnsupported(t *testing.T) { var arr StringArray err := arr.Scan(true) diff --git a/auth/kerberos/krb_unix.go b/auth/kerberos/krb_unix.go index 7d5ec76a5..b6f27ce57 100644 --- a/auth/kerberos/krb_unix.go +++ b/auth/kerberos/krb_unix.go @@ -1,3 +1,4 @@ +//go:build !windows // +build !windows package kerberos diff --git a/auth/kerberos/krb_windows.go b/auth/kerberos/krb_windows.go index 973be8fc6..ca26d02c6 100644 --- a/auth/kerberos/krb_windows.go +++ b/auth/kerberos/krb_windows.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package kerberos @@ -8,7 +9,7 @@ import ( ) // GSS implements the pq.GSS interface. -type Gss struct { +type GSS struct { creds *sspi.Credentials ctx *negotiate.ClientContext } diff --git a/certs/Makefile b/certs/Makefile new file mode 100644 index 000000000..a84e31e9c --- /dev/null +++ b/certs/Makefile @@ -0,0 +1,37 @@ +.PHONY: all root-ssl server-ssl client-ssl + +# Rebuilds self-signed root/server/client certs/keys in a consistent way +all: root-ssl server-ssl client-ssl + rm -f .srl + +root-ssl: + openssl req -new -sha256 -nodes -newkey rsa:2048 \ + -config ./certs/root.cnf \ + -keyout /tmp/root.key \ + -out /tmp/root.csr + openssl x509 -req -days 3653 -sha256 \ + -in /tmp/root.csr \ + -extfile /etc/ssl/openssl.cnf -extensions v3_ca \ + -signkey /tmp/root.key \ + -out ./certs/root.crt + +server-ssl: + openssl req -new -sha256 -nodes -newkey rsa:2048 \ + -config ./certs/server.cnf \ + -keyout ./certs/server.key \ + -out /tmp/server.csr + openssl x509 -req -days 3653 -sha256 \ + -extfile ./certs/server.cnf -extensions req_ext \ + -CA ./certs/root.crt -CAkey /tmp/root.key -CAcreateserial \ + -in /tmp/server.csr \ + -out ./certs/server.crt + +client-ssl: + openssl req -new -sha256 -nodes -newkey rsa:2048 \ + -config ./certs/postgresql.cnf \ + -keyout ./certs/postgresql.key \ + -out /tmp/postgresql.csr + openssl x509 -req -days 3653 -sha256 \ + -CA ./certs/root.crt -CAkey /tmp/root.key -CAcreateserial \ + -in /tmp/postgresql.csr \ + -out ./certs/postgresql.crt diff --git a/certs/postgresql.cnf b/certs/postgresql.cnf new file mode 100644 index 000000000..fa8ffc489 --- /dev/null +++ b/certs/postgresql.cnf @@ -0,0 +1,10 @@ +[req] +distinguished_name = req_distinguished_name +prompt = no + +[req_distinguished_name] +C = US +ST = Nevada +L = Las Vegas +O = github.com/lib/pq +CN = pqgosslcert diff --git a/certs/postgresql.crt b/certs/postgresql.crt index 6e6b4284a..c1815f865 100644 --- a/certs/postgresql.crt +++ b/certs/postgresql.crt @@ -1,69 +1,20 @@ -Certificate: - Data: - Version: 3 (0x2) - Serial Number: 2 (0x2) - Signature Algorithm: sha256WithRSAEncryption - Issuer: C=US, ST=Nevada, L=Las Vegas, O=github.com/lib/pq, CN=pq CA - Validity - Not Before: Oct 11 15:10:11 2014 GMT - Not After : Oct 8 15:10:11 2024 GMT - Subject: C=US, ST=Nevada, L=Las Vegas, O=github.com/lib/pq, CN=pqgosslcert - Subject Public Key Info: - Public Key Algorithm: rsaEncryption - RSA Public Key: (1024 bit) - Modulus (1024 bit): - 00:e3:8c:06:9a:70:54:51:d1:34:34:83:39:cd:a2: - 59:0f:05:ed:8d:d8:0e:34:d0:92:f4:09:4d:ee:8c: - 78:55:49:24:f8:3c:e0:34:58:02:b2:e7:94:58:c1: - e8:e5:bb:d1:af:f6:54:c1:40:b1:90:70:79:0d:35: - 54:9c:8f:16:e9:c2:f0:92:e6:64:49:38:c1:76:f8: - 47:66:c4:5b:4a:b6:a9:43:ce:c8:be:6c:4d:2b:94: - 97:3c:55:bc:d1:d0:6e:b7:53:ae:89:5c:4b:6b:86: - 40:be:c1:ae:1e:64:ce:9c:ae:87:0a:69:e5:c8:21: - 12:be:ae:1d:f6:45:df:16:a7 - Exponent: 65537 (0x10001) - X509v3 extensions: - X509v3 Subject Key Identifier: - 9B:25:31:63:A2:D8:06:FF:CB:E3:E9:96:FF:0D:BA:DC:12:7D:04:CF - X509v3 Authority Key Identifier: - keyid:52:93:ED:1E:76:0A:9F:65:4F:DE:19:66:C1:D5:22:40:35:CB:A0:72 - - X509v3 Basic Constraints: - CA:FALSE - X509v3 Key Usage: - Digital Signature, Non Repudiation, Key Encipherment - Signature Algorithm: sha256WithRSAEncryption - 3e:f5:f8:0b:4e:11:bd:00:86:1f:ce:dc:97:02:98:91:11:f5: - 65:f6:f2:8a:b2:3e:47:92:05:69:28:c9:e9:b4:f7:cf:93:d1: - 2d:81:5d:00:3c:23:be:da:70:ea:59:e1:2c:d3:25:49:ae:a6: - 95:54:c1:10:df:23:e3:fe:d6:e4:76:c7:6b:73:ad:1b:34:7c: - e2:56:cc:c0:37:ae:c5:7a:11:20:6c:3d:05:0e:99:cd:22:6c: - cf:59:a1:da:28:d4:65:ba:7d:2f:2b:3d:69:6d:a6:c1:ae:57: - bf:56:64:13:79:f8:48:46:65:eb:81:67:28:0b:7b:de:47:10: - b3:80:3c:31:d1:58:94:01:51:4a:c7:c8:1a:01:a8:af:c4:cd: - bb:84:a5:d9:8b:b4:b9:a1:64:3e:95:d9:90:1d:d5:3f:67:cc: - 3b:ba:f5:b4:d1:33:77:ee:c2:d2:3e:7e:c5:66:6e:b7:35:4c: - 60:57:b0:b8:be:36:c8:f3:d3:95:8c:28:4a:c9:f7:27:a4:0d: - e5:96:99:eb:f5:c8:bd:f3:84:6d:ef:02:f9:8a:36:7d:6b:5f: - 36:68:37:41:d9:74:ae:c6:78:2e:44:86:a1:ad:43:ca:fb:b5: - 3e:ba:10:23:09:02:ac:62:d1:d0:83:c8:95:b9:e3:5e:30:ff: - 5b:2b:38:fa -----BEGIN CERTIFICATE----- -MIIDEzCCAfugAwIBAgIBAjANBgkqhkiG9w0BAQsFADBeMQswCQYDVQQGEwJVUzEP -MA0GA1UECBMGTmV2YWRhMRIwEAYDVQQHEwlMYXMgVmVnYXMxGjAYBgNVBAoTEWdp -dGh1Yi5jb20vbGliL3BxMQ4wDAYDVQQDEwVwcSBDQTAeFw0xNDEwMTExNTEwMTFa -Fw0yNDEwMDgxNTEwMTFaMGQxCzAJBgNVBAYTAlVTMQ8wDQYDVQQIEwZOZXZhZGEx -EjAQBgNVBAcTCUxhcyBWZWdhczEaMBgGA1UEChMRZ2l0aHViLmNvbS9saWIvcHEx -FDASBgNVBAMTC3BxZ29zc2xjZXJ0MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB -gQDjjAaacFRR0TQ0gznNolkPBe2N2A400JL0CU3ujHhVSST4POA0WAKy55RYwejl -u9Gv9lTBQLGQcHkNNVScjxbpwvCS5mRJOMF2+EdmxFtKtqlDzsi+bE0rlJc8VbzR -0G63U66JXEtrhkC+wa4eZM6crocKaeXIIRK+rh32Rd8WpwIDAQABo1owWDAdBgNV -HQ4EFgQUmyUxY6LYBv/L4+mW/w263BJ9BM8wHwYDVR0jBBgwFoAUUpPtHnYKn2VP -3hlmwdUiQDXLoHIwCQYDVR0TBAIwADALBgNVHQ8EBAMCBeAwDQYJKoZIhvcNAQEL -BQADggEBAD71+AtOEb0Ahh/O3JcCmJER9WX28oqyPkeSBWkoyem098+T0S2BXQA8 -I77acOpZ4SzTJUmuppVUwRDfI+P+1uR2x2tzrRs0fOJWzMA3rsV6ESBsPQUOmc0i -bM9Zodoo1GW6fS8rPWltpsGuV79WZBN5+EhGZeuBZygLe95HELOAPDHRWJQBUUrH -yBoBqK/EzbuEpdmLtLmhZD6V2ZAd1T9nzDu69bTRM3fuwtI+fsVmbrc1TGBXsLi+ -Nsjz05WMKErJ9yekDeWWmev1yL3zhG3vAvmKNn1rXzZoN0HZdK7GeC5EhqGtQ8r7 -tT66ECMJAqxi0dCDyJW5414w/1srOPo= +MIIDPjCCAiYCCQD4nsC6zsmIqjANBgkqhkiG9w0BAQsFADBeMQswCQYDVQQGEwJV +UzEPMA0GA1UECAwGTmV2YWRhMRIwEAYDVQQHDAlMYXMgVmVnYXMxGjAYBgNVBAoM +EWdpdGh1Yi5jb20vbGliL3BxMQ4wDAYDVQQDDAVwcSBDQTAeFw0yMTA5MDIwMTU1 +MDJaFw0zMTA5MDMwMTU1MDJaMGQxCzAJBgNVBAYTAlVTMQ8wDQYDVQQIDAZOZXZh +ZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgGA1UECgwRZ2l0aHViLmNvbS9saWIv +cHExFDASBgNVBAMMC3BxZ29zc2xjZXJ0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAx0ucPVUNCrVmbyithwWrmmZ1dGudBwhSyDB6af4z5Cr+S6dx2SRU +UGUw3Lv+z+tUqQ7hJj0oNddIQeYKl/Tt6JPpZsQfERP/cUGedtyt7HnCKobBL+0B +NvHnDIUiIL4LgfiZK4DWJkGmm7nTHo/7qKAw60vCMLUW98DC0Xhlk9MHYG+e9Zai +3G0vY2X6DUYcSmzBI3JakFEgMZTQg3ofUQMz8TYeK3/DYadLXkl08d18LL3Dnefx +0xRuBPNTa2tLfVnFkfFi6Z9xVB/WhG6+X4OLnO85v5xUOGTV+g154iR7FOkrrl5F +lEUBj+yaIoTRi+MyZ/oYqWwQUDYS3+Te9wIDAQABMA0GCSqGSIb3DQEBCwUAA4IB +AQCCJpwUWCx7xfXv3vH3LQcffZycyRHYPgTCbiQw3x9aBb77jUAh5O6lEj/W0nx2 +SCTEsCsRSAiFwfUb+g/AFCW84dELRWmf38eoqACebLymqnvxyZA+O87yu07XyFZR +TnmbDMzZgsyWWGwS3JoGFk+ibWY4AImYQnSJO8Pi0kZ37ngbAyJ3RtDhhEQJWw/Q +D04p3uky/ea7Gyz0QTx5o40n4gq7nEzF1OS6IHozM840J5aZrxRiXEa56fsmJHmI +IGyI07SGlWJ15r1wc8lB+8ilnAqH1QQlYzTIW0Q4NZE7n3uQg1EVuueGiGO2ex2/ +he9lDiJfOQuPuLbOxzctP9v9 -----END CERTIFICATE----- diff --git a/certs/postgresql.key b/certs/postgresql.key index eb8b20be9..8380da4bc 100644 --- a/certs/postgresql.key +++ b/certs/postgresql.key @@ -1,15 +1,28 @@ ------BEGIN RSA PRIVATE KEY----- -MIICWwIBAAKBgQDjjAaacFRR0TQ0gznNolkPBe2N2A400JL0CU3ujHhVSST4POA0 -WAKy55RYwejlu9Gv9lTBQLGQcHkNNVScjxbpwvCS5mRJOMF2+EdmxFtKtqlDzsi+ -bE0rlJc8VbzR0G63U66JXEtrhkC+wa4eZM6crocKaeXIIRK+rh32Rd8WpwIDAQAB -AoGAM5dM6/kp9P700i8qjOgRPym96Zoh5nGfz/rIE5z/r36NBkdvIg8OVZfR96nH -b0b9TOMR5lsPp0sI9yivTWvX6qyvLJRWy2vvx17hXK9NxXUNTAm0PYZUTvCtcPeX -RnJpzQKNZQPkFzF0uXBc4CtPK2Vz0+FGvAelrhYAxnw1dIkCQQD+9qaW5QhXjsjb -Nl85CmXgxPmGROcgLQCO+omfrjf9UXrituU9Dz6auym5lDGEdMFnkzfr+wpasEy9 -mf5ZZOhDAkEA5HjXfVGaCtpydOt6hDon/uZsyssCK2lQ7NSuE3vP+sUsYMzIpEoy -t3VWXqKbo+g9KNDTP4WEliqp1aiSIylzzQJANPeqzihQnlgEdD4MdD4rwhFJwVIp -Le8Lcais1KaN7StzOwxB/XhgSibd2TbnPpw+3bSg5n5lvUdo+e62/31OHwJAU1jS -I+F09KikQIr28u3UUWT2IzTT4cpVv1AHAQyV3sG3YsjSGT0IK20eyP9BEBZU2WL0 -7aNjrvR5aHxKc5FXsQJABsFtyGpgI5X4xufkJZVZ+Mklz2n7iXa+XPatMAHFxAtb -EEMt60rngwMjXAzBSC6OYuYogRRAY3UCacNC5VhLYQ== ------END RSA PRIVATE KEY----- +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDHS5w9VQ0KtWZv +KK2HBauaZnV0a50HCFLIMHpp/jPkKv5Lp3HZJFRQZTDcu/7P61SpDuEmPSg110hB +5gqX9O3ok+lmxB8RE/9xQZ523K3secIqhsEv7QE28ecMhSIgvguB+JkrgNYmQaab +udMej/uooDDrS8IwtRb3wMLReGWT0wdgb571lqLcbS9jZfoNRhxKbMEjclqQUSAx +lNCDeh9RAzPxNh4rf8Nhp0teSXTx3XwsvcOd5/HTFG4E81Nra0t9WcWR8WLpn3FU +H9aEbr5fg4uc7zm/nFQ4ZNX6DXniJHsU6SuuXkWURQGP7JoihNGL4zJn+hipbBBQ +NhLf5N73AgMBAAECggEAHLNY1sRO0oH5NHzpMI6yfdPPimqM/JxIP6grmOQQ2QUQ +BhkhHiJLOiC4frFcKtk7IfWQmw8noUlVkJfuYp/VOy9B55jK2IzGtqq6hWeWbH3E +Zpdtbtd021LO8VCi75Au3BLPDCLLtEq0Ea0bKEWX+lrHcLtCRf1uR1OtOrlZ94Wl +DUhm7YJC4cS1bi6Kdf03R+fw2oFi7/QdywcT4ow032jGWOly/Jl7bSHZK7xLtM/i +9HfMwmusD/iuz7mtLU7VCpnlKZm6MfS5D427ybW8MruuiZEtQJ6QtRIrHBHk93aK +Op0tjJ6tMav1UsJzgVz9+uWILE9l0AjAa4AvbfNzEQKBgQD8mma9SLQPtBb6cXuT +CQgjE4vyph8mRnm/pTz3QLIpMiLy2+aKJD/u4cduzLw1vjuH1tlb7NQ9c891jAJh +JhwDwqKAXfFicfRs/PYWngx/XtGhbbpgm1yA6XuYL1D06gzmjzXgHvZMOFcts+GF +y0JEuV7v6eYrpQJRQYCwY6xTgwKBgQDJ+bHAlgOaC94DZEXZMiUznCCjBjAstiXG +BEN7Cnfn6vgvPm/b6BkKn4VrsCmbZQKT7QJDSOhYwXCC2ZlrKiF8GEUHX4mi8347 +8B+DsuokTLNmN61QAZbb1c3XQVnr15xH8ijm7yYs4tCBmVLKBmpw1T4IZXXlVE5k +gmee+AwIfQKBgGr+P0wnclVAc4cq8CusZKzux5VEtebxbPo21CbqWUxHtzPk3rZe +elIFggK1Z3bgF7kG0NQ18QQCfLoOTqe1i6IwG8KBiA+pst1DHD0iPqroj6RvpMTs +qXbU7ovcZs8GH+a8fBZtJufL6WkrSvfvyybu2X6HNP4Bi4S9WPPdlA1fAoGAE5m/ +vkjQoKp2KS4Z+TH8mj2UjT2Uf0JN+CGByvcBG+iZnTwZ7uVfSMCiWgkGgKYU0fY2 +OgFhSvu6x3gGg3fbOAfC6yxCVyX6IibzZ/x87HjlEA5nK1R8J2lgSHt3FoQeDn1Z +qs+ajNCWG32doy1sNvb6xiXSgybjVK2zEKJRyKECgYBJTk2IABebjvInNb6tagcI +nD4d2LgBmZJZsTruHXrpO0s3XCQcFKks4JKH1CVjd34f7LkxzEOGbE7wKBBd652s +ob6gFKnbqTniTo3NRUycB6ymo4LSaBvKgeY5hYbVxrYheRLPGY+gPVYb3VMKu9N9 +76rcaFqJOz7OeywRG5bHUg== +-----END PRIVATE KEY----- diff --git a/certs/root.cnf b/certs/root.cnf new file mode 100644 index 000000000..ea6def267 --- /dev/null +++ b/certs/root.cnf @@ -0,0 +1,10 @@ +[req] +distinguished_name = req_distinguished_name +prompt = no + +[req_distinguished_name] +C = US +ST = Nevada +L = Las Vegas +O = github.com/lib/pq +CN = pq CA diff --git a/certs/root.crt b/certs/root.crt index aecf8f621..390a907c3 100644 --- a/certs/root.crt +++ b/certs/root.crt @@ -1,24 +1,24 @@ -----BEGIN CERTIFICATE----- -MIIEAzCCAuugAwIBAgIJANmheROCdW1NMA0GCSqGSIb3DQEBBQUAMF4xCzAJBgNV -BAYTAlVTMQ8wDQYDVQQIEwZOZXZhZGExEjAQBgNVBAcTCUxhcyBWZWdhczEaMBgG -A1UEChMRZ2l0aHViLmNvbS9saWIvcHExDjAMBgNVBAMTBXBxIENBMB4XDTE0MTAx -MTE1MDQyOVoXDTI0MTAwODE1MDQyOVowXjELMAkGA1UEBhMCVVMxDzANBgNVBAgT -Bk5ldmFkYTESMBAGA1UEBxMJTGFzIFZlZ2FzMRowGAYDVQQKExFnaXRodWIuY29t -L2xpYi9wcTEOMAwGA1UEAxMFcHEgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw -ggEKAoIBAQCV4PxP7ShzWBzUCThcKk3qZtOLtHmszQVtbqhvgTpm1kTRtKBdVMu0 -pLAHQ3JgJCnAYgH0iZxVGoMP16T3irdgsdC48+nNTFM2T0cCdkfDURGIhSFN47cb -Pgy306BcDUD2q7ucW33+dlFSRuGVewocoh4BWM/vMtMvvWzdi4Ag/L/jhb+5wZxZ -sWymsadOVSDePEMKOvlCa3EdVwVFV40TVyDb+iWBUivDAYsS2a3KajuJrO6MbZiE -Sp2RCIkZS2zFmzWxVRi9ZhzIZhh7EVF9JAaNC3T52jhGUdlRq3YpBTMnd89iOh74 -6jWXG7wSuPj3haFzyNhmJ0ZUh+2Ynoh1AgMBAAGjgcMwgcAwHQYDVR0OBBYEFFKT -7R52Cp9lT94ZZsHVIkA1y6ByMIGQBgNVHSMEgYgwgYWAFFKT7R52Cp9lT94ZZsHV -IkA1y6ByoWKkYDBeMQswCQYDVQQGEwJVUzEPMA0GA1UECBMGTmV2YWRhMRIwEAYD -VQQHEwlMYXMgVmVnYXMxGjAYBgNVBAoTEWdpdGh1Yi5jb20vbGliL3BxMQ4wDAYD -VQQDEwVwcSBDQYIJANmheROCdW1NMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEF -BQADggEBAAEhCLWkqJNMI8b4gkbmj5fqQ/4+oO83bZ3w2Oqf6eZ8I8BC4f2NOyE6 -tRUlq5+aU7eqC1cOAvGjO+YHN/bF/DFpwLlzvUSXt+JP/pYcUjL7v+pIvwqec9hD -ndvM4iIbkD/H/OYQ3L+N3W+G1x7AcFIX+bGCb3PzYVQAjxreV6//wgKBosMGFbZo -HPxT9RPMun61SViF04H5TNs0derVn1+5eiiYENeAhJzQNyZoOOUuX1X/Inx9bEPh -C5vFBtSMgIytPgieRJVWAiMLYsfpIAStrHztRAbBs2DU01LmMgRvHdxgFEKinC/d -UHZZQDP+6pT+zADrGhQGXe4eThaO6f0= +MIIEBjCCAu6gAwIBAgIJAPizR+OD14YnMA0GCSqGSIb3DQEBCwUAMF4xCzAJBgNV +BAYTAlVTMQ8wDQYDVQQIDAZOZXZhZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgG +A1UECgwRZ2l0aHViLmNvbS9saWIvcHExDjAMBgNVBAMMBXBxIENBMB4XDTIxMDkw +MjAxNTUwMloXDTMxMDkwMzAxNTUwMlowXjELMAkGA1UEBhMCVVMxDzANBgNVBAgM +Bk5ldmFkYTESMBAGA1UEBwwJTGFzIFZlZ2FzMRowGAYDVQQKDBFnaXRodWIuY29t +L2xpYi9wcTEOMAwGA1UEAwwFcHEgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQDb9d6sjdU6GdibGrXRMOHREH3MRUS8T4TFqGgPEGVDP/V5bAZlBSGP +AN0o9DTyVLcbQpBt8zMTw9KeIzIIe5NIVkSmA16lw/YckGhOM+kZIkiDuE6qt5Ia +OQCRMdXkZ8ejG/JUu+rHU8FJZL8DE+jyYherzdjkeVAQ7JfzxAwW2Dl7T/47g337 +Pwmf17AEb8ibSqmXyUN7R5NhJQs+hvaYdNagzdx91E1H+qlyBvmiNeasUQljLvZ+ +Y8wAuU79neA+d09O4PBiYwV17rSP6SZCeGE3oLZviL/0KM9Xig88oB+2FmvQ6Zxa +L7SoBlqS+5pBZwpH7eee/wCIKAnJtMAJAgMBAAGjgcYwgcMwDwYDVR0TAQH/BAUw +AwEB/zAdBgNVHQ4EFgQUfIXEczahbcM2cFrwclJF7GbdajkwgZAGA1UdIwSBiDCB +hYAUfIXEczahbcM2cFrwclJF7GbdajmhYqRgMF4xCzAJBgNVBAYTAlVTMQ8wDQYD +VQQIDAZOZXZhZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgGA1UECgwRZ2l0aHVi +LmNvbS9saWIvcHExDjAMBgNVBAMMBXBxIENBggkA+LNH44PXhicwDQYJKoZIhvcN +AQELBQADggEBABFyGgSz2mHVJqYgX1Y+7P+MfKt83cV2uYDGYvXrLG2OGiCilVul +oTBG+8omIMSHOsQZvWMpA5H0tnnlQHrKpKpUyKkSL+Wv5GL0UtBmHX7mVRiaK2l4 +q2BjRaQUitp/FH4NSdXtVrMME5T1JBBZHsQkNL3cNRzRKwY/Vj5UGEDxDS7lILUC +e01L4oaK0iKQn4beALU+TvKoAHdPvoxpPpnhkF5ss9HmdcvRktJrKZemDJZswZ7/ ++omx8ZPIYYUH5VJJYYE88S7guAt+ZaKIUlel/t6xPbo2ZySFSg9u1uB99n+jTo3L +1rAxFnN3FCX2jBqgP29xMVmisaN5k04UmyI= -----END CERTIFICATE----- diff --git a/certs/server.cnf b/certs/server.cnf new file mode 100644 index 000000000..ba4b55703 --- /dev/null +++ b/certs/server.cnf @@ -0,0 +1,29 @@ +[ req ] +default_bits = 2048 +distinguished_name = subject +req_extensions = req_ext +x509_extensions = x509_ext +string_mask = utf8only +prompt = no + +[ subject ] +C = US +ST = Nevada +L = Las Vegas +O = github.com/lib/pq + +[ x509_ext ] +subjectKeyIdentifier = hash +authorityKeyIdentifier = keyid,issuer + +basicConstraints = CA:FALSE +keyUsage = digitalSignature, keyEncipherment +subjectAltName = DNS:postgres +nsComment = "OpenSSL Generated Certificate" + +[ req_ext ] +subjectKeyIdentifier = hash +basicConstraints = CA:FALSE +keyUsage = digitalSignature, keyEncipherment +subjectAltName = DNS:postgres +nsComment = "OpenSSL Generated Certificate" diff --git a/certs/server.crt b/certs/server.crt index ddc995a6d..1a0ac0d43 100644 --- a/certs/server.crt +++ b/certs/server.crt @@ -1,81 +1,22 @@ -Certificate: - Data: - Version: 3 (0x2) - Serial Number: 1 (0x1) - Signature Algorithm: sha256WithRSAEncryption - Issuer: C=US, ST=Nevada, L=Las Vegas, O=github.com/lib/pq, CN=pq CA - Validity - Not Before: Oct 11 15:05:15 2014 GMT - Not After : Oct 8 15:05:15 2024 GMT - Subject: C=US, ST=Nevada, L=Las Vegas, O=github.com/lib/pq, CN=postgres - Subject Public Key Info: - Public Key Algorithm: rsaEncryption - RSA Public Key: (2048 bit) - Modulus (2048 bit): - 00:d7:8a:4c:85:fb:17:a5:3c:8f:e0:72:11:29:ce: - 3f:b0:1f:3f:7d:c6:ee:7f:a7:fc:02:2b:35:47:08: - a6:3d:90:df:5c:56:14:94:00:c7:6d:d1:d2:e2:61: - 95:77:b8:e3:a6:66:31:f9:1f:21:7d:62:e1:27:da: - 94:37:61:4a:ea:63:53:a0:61:b8:9c:bb:a5:e2:e7: - b7:a6:d8:0f:05:04:c7:29:e2:ea:49:2b:7f:de:15: - 00:a6:18:70:50:c7:0c:de:9a:f9:5a:96:b0:e1:94: - 06:c6:6d:4a:21:3b:b4:0f:a5:6d:92:86:34:b2:4e: - d7:0e:a7:19:c0:77:0b:7b:87:c8:92:de:42:ff:86: - d2:b7:9a:a4:d4:15:23:ca:ad:a5:69:21:b8:ce:7e: - 66:cb:85:5d:b9:ed:8b:2d:09:8d:94:e4:04:1e:72: - ec:ef:d0:76:90:15:5a:a4:f7:91:4b:e9:ce:4e:9d: - 5d:9a:70:17:9c:d8:e9:73:83:ea:3d:61:99:a6:cd: - ac:91:40:5a:88:77:e5:4e:2a:8e:3d:13:f3:f9:38: - 6f:81:6b:8a:95:ca:0e:07:ab:6f:da:b4:8c:d9:ff: - aa:78:03:aa:c7:c2:cf:6f:64:92:d3:d8:83:d5:af: - f1:23:18:a7:2e:7b:17:0b:e7:7d:f1:fa:a8:41:a3: - 04:57 - Exponent: 65537 (0x10001) - X509v3 extensions: - X509v3 Subject Key Identifier: - EE:F0:B3:46:DC:C7:09:EB:0E:B6:2F:E5:FE:62:60:45:44:9F:59:CC - X509v3 Authority Key Identifier: - keyid:52:93:ED:1E:76:0A:9F:65:4F:DE:19:66:C1:D5:22:40:35:CB:A0:72 - - X509v3 Basic Constraints: - CA:FALSE - X509v3 Key Usage: - Digital Signature, Non Repudiation, Key Encipherment - Signature Algorithm: sha256WithRSAEncryption - 7e:5a:6e:be:bf:d2:6c:c1:d6:fa:b6:fb:3f:06:53:36:08:87: - 9d:95:b1:39:af:9e:f6:47:38:17:39:da:25:7c:f2:ad:0c:e3: - ab:74:19:ca:fb:8c:a0:50:c0:1d:19:8a:9c:21:ed:0f:3a:d1: - 96:54:2e:10:09:4f:b8:70:f7:2b:99:43:d2:c6:15:bc:3f:24: - 7d:28:39:32:3f:8d:a4:4f:40:75:7f:3e:0d:1c:d1:69:f2:4e: - 98:83:47:97:d2:25:ac:c9:36:86:2f:04:a6:c4:86:c7:c4:00: - 5f:7f:b9:ad:fc:bf:e9:f5:78:d7:82:1a:51:0d:fc:ab:9e:92: - 1d:5f:0c:18:d1:82:e0:14:c9:ce:91:89:71:ff:49:49:ff:35: - bf:7b:44:78:42:c1:d0:66:65:bb:28:2e:60:ca:9b:20:12:a9: - 90:61:b1:96:ec:15:46:c9:37:f7:07:90:8a:89:45:2a:3f:37: - ec:dc:e3:e5:8f:c3:3a:57:80:a5:54:60:0c:e1:b2:26:99:2b: - 40:7e:36:d1:9a:70:02:ec:63:f4:3b:72:ae:81:fb:30:20:6d: - cb:48:46:c6:b5:8f:39:b1:84:05:25:55:8d:f5:62:f6:1b:46: - 2e:da:a3:4c:26:12:44:d7:56:b6:b8:a9:ca:d3:ab:71:45:7c: - 9f:48:6d:1e -----BEGIN CERTIFICATE----- -MIIDlDCCAnygAwIBAgIBATANBgkqhkiG9w0BAQsFADBeMQswCQYDVQQGEwJVUzEP -MA0GA1UECBMGTmV2YWRhMRIwEAYDVQQHEwlMYXMgVmVnYXMxGjAYBgNVBAoTEWdp -dGh1Yi5jb20vbGliL3BxMQ4wDAYDVQQDEwVwcSBDQTAeFw0xNDEwMTExNTA1MTVa -Fw0yNDEwMDgxNTA1MTVaMGExCzAJBgNVBAYTAlVTMQ8wDQYDVQQIEwZOZXZhZGEx -EjAQBgNVBAcTCUxhcyBWZWdhczEaMBgGA1UEChMRZ2l0aHViLmNvbS9saWIvcHEx -ETAPBgNVBAMTCHBvc3RncmVzMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC -AQEA14pMhfsXpTyP4HIRKc4/sB8/fcbuf6f8Ais1RwimPZDfXFYUlADHbdHS4mGV -d7jjpmYx+R8hfWLhJ9qUN2FK6mNToGG4nLul4ue3ptgPBQTHKeLqSSt/3hUAphhw -UMcM3pr5Wpaw4ZQGxm1KITu0D6VtkoY0sk7XDqcZwHcLe4fIkt5C/4bSt5qk1BUj -yq2laSG4zn5my4Vdue2LLQmNlOQEHnLs79B2kBVapPeRS+nOTp1dmnAXnNjpc4Pq -PWGZps2skUBaiHflTiqOPRPz+ThvgWuKlcoOB6tv2rSM2f+qeAOqx8LPb2SS09iD -1a/xIxinLnsXC+d98fqoQaMEVwIDAQABo1owWDAdBgNVHQ4EFgQU7vCzRtzHCesO -ti/l/mJgRUSfWcwwHwYDVR0jBBgwFoAUUpPtHnYKn2VP3hlmwdUiQDXLoHIwCQYD -VR0TBAIwADALBgNVHQ8EBAMCBeAwDQYJKoZIhvcNAQELBQADggEBAH5abr6/0mzB -1vq2+z8GUzYIh52VsTmvnvZHOBc52iV88q0M46t0Gcr7jKBQwB0Zipwh7Q860ZZU -LhAJT7hw9yuZQ9LGFbw/JH0oOTI/jaRPQHV/Pg0c0WnyTpiDR5fSJazJNoYvBKbE -hsfEAF9/ua38v+n1eNeCGlEN/Kuekh1fDBjRguAUyc6RiXH/SUn/Nb97RHhCwdBm -ZbsoLmDKmyASqZBhsZbsFUbJN/cHkIqJRSo/N+zc4+WPwzpXgKVUYAzhsiaZK0B+ -NtGacALsY/Q7cq6B+zAgbctIRsa1jzmxhAUlVY31YvYbRi7ao0wmEkTXVra4qcrT -q3FFfJ9IbR4= +MIIDqzCCApOgAwIBAgIJAPiewLrOyYipMA0GCSqGSIb3DQEBCwUAMF4xCzAJBgNV +BAYTAlVTMQ8wDQYDVQQIDAZOZXZhZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgG +A1UECgwRZ2l0aHViLmNvbS9saWIvcHExDjAMBgNVBAMMBXBxIENBMB4XDTIxMDkw +MjAxNTUwMloXDTMxMDkwMzAxNTUwMlowTjELMAkGA1UEBhMCVVMxDzANBgNVBAgM +Bk5ldmFkYTESMBAGA1UEBwwJTGFzIFZlZ2FzMRowGAYDVQQKDBFnaXRodWIuY29t +L2xpYi9wcTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKf6H4UzmANN +QiQJe92Mf3ETMYmpZKNNO9DPEHyNLIkag+XwMrBTdcCK0mLvsNCYpXuBN6703KCd +WAFOeMmj7gOsWtvjt5Xm6bRHLgegekXzcG/jDwq/wyzeDzr/YkITuIlG44Lf9lhY +FLwiHlHOWHnwrZaEh6aU//02aQkzyX5INeXl/3TZm2G2eIH6AOxOKOU27MUsyVSQ +5DE+SDKGcRP4bElueeQWvxAXNMZYb7sVSDdfHI3zr32K4k/tC8x0fZJ5XN/dvl4t +4N4MrYlmDO5XOrb/gQH1H4iu6+5EMDfZYab4fkThnNFdfFqu4/8Scv7KZ8mWqpKM +fGAjEPctQi0CAwEAAaN8MHowHQYDVR0OBBYEFENExPbmDyFB2AJUdbMvVyhlNPD5 +MAkGA1UdEwQCMAAwCwYDVR0PBAQDAgWgMBMGA1UdEQQMMAqCCHBvc3RncmVzMCwG +CWCGSAGG+EIBDQQfFh1PcGVuU1NMIEdlbmVyYXRlZCBDZXJ0aWZpY2F0ZTANBgkq +hkiG9w0BAQsFAAOCAQEAMRVbV8RiEsmp9HAtnVCZmRXMIbgPGrqjeSwk586s4K8v +BSqNCqxv6s5GfCRmDYiqSqeuCVDtUJS1HsTmbxVV7Ke71WMo+xHR1ICGKOa8WGCb +TGsuicG5QZXWaxeMOg4s0qpKmKko0d1aErdVsanU5dkrVS7D6729Ffnzu4lwApk6 +invAB67p8u7sojwqRq5ce0vRaG+YFylTrWomF9kauEb8gKbQ9Xc7QfX+h+UH/mq9 +Nvdj8LOHp6/82bZdnsYUOtV4lS1IA/qzeXpqBphxqfWabD1yLtkyJyImZKq8uIPp +0CG4jhObPdWcCkXD6bg3QK3mhwlC79OtFgxWmldCRQ== -----END CERTIFICATE----- diff --git a/certs/server.key b/certs/server.key index bd7b019b6..152e0f23e 100644 --- a/certs/server.key +++ b/certs/server.key @@ -1,27 +1,28 @@ ------BEGIN RSA PRIVATE KEY----- -MIIEogIBAAKCAQEA14pMhfsXpTyP4HIRKc4/sB8/fcbuf6f8Ais1RwimPZDfXFYU -lADHbdHS4mGVd7jjpmYx+R8hfWLhJ9qUN2FK6mNToGG4nLul4ue3ptgPBQTHKeLq -SSt/3hUAphhwUMcM3pr5Wpaw4ZQGxm1KITu0D6VtkoY0sk7XDqcZwHcLe4fIkt5C -/4bSt5qk1BUjyq2laSG4zn5my4Vdue2LLQmNlOQEHnLs79B2kBVapPeRS+nOTp1d -mnAXnNjpc4PqPWGZps2skUBaiHflTiqOPRPz+ThvgWuKlcoOB6tv2rSM2f+qeAOq -x8LPb2SS09iD1a/xIxinLnsXC+d98fqoQaMEVwIDAQABAoIBAF3ZoihUhJ82F4+r -Gz4QyDpv4L1reT2sb1aiabhcU8ZK5nbWJG+tRyjSS/i2dNaEcttpdCj9HR/zhgZM -bm0OuAgG58rVwgS80CZUruq++Qs+YVojq8/gWPTiQD4SNhV2Fmx3HkwLgUk3oxuT -SsvdqzGE3okGVrutCIcgy126eA147VPMoej1Bb3fO6npqK0pFPhZfAc0YoqJuM+k -obRm5pAnGUipyLCFXjA9HYPKwYZw2RtfdA3CiImHeanSdqS+ctrC9y8BV40Th7gZ -haXdKUNdjmIxV695QQ1mkGqpKLZFqhzKioGQ2/Ly2d1iaKN9fZltTusu8unepWJ2 -tlT9qMECgYEA9uHaF1t2CqE+AJvWTihHhPIIuLxoOQXYea1qvxfcH/UMtaLKzCNm -lQ5pqCGsPvp+10f36yttO1ZehIvlVNXuJsjt0zJmPtIolNuJY76yeussfQ9jHheB -5uPEzCFlHzxYbBUyqgWaF6W74okRGzEGJXjYSP0yHPPdU4ep2q3bGiUCgYEA34Af -wBSuQSK7uLxArWHvQhyuvi43ZGXls6oRGl+Ysj54s8BP6XGkq9hEJ6G4yxgyV+BR -DUOs5X8/TLT8POuIMYvKTQthQyCk0eLv2FLdESDuuKx0kBVY3s8lK3/z5HhrdOiN -VMNZU+xDKgKc3hN9ypkk8vcZe6EtH7Y14e0rVcsCgYBTgxi8F/M5K0wG9rAqphNz -VFBA9XKn/2M33cKjO5X5tXIEKzpAjaUQvNxexG04rJGljzG8+mar0M6ONahw5yD1 -O7i/XWgazgpuOEkkVYiYbd8RutfDgR4vFVMn3hAP3eDnRtBplRWH9Ec3HTiNIys6 -F8PKBOQjyRZQQC7jyzW3hQKBgACe5HeuFwXLSOYsb6mLmhR+6+VPT4wR1F95W27N -USk9jyxAnngxfpmTkiziABdgS9N+pfr5cyN4BP77ia/Jn6kzkC5Cl9SN5KdIkA3z -vPVtN/x/ThuQU5zaymmig1ThGLtMYggYOslG4LDfLPxY5YKIhle+Y+259twdr2yf -Mf2dAoGAaGv3tWMgnIdGRk6EQL/yb9PKHo7ShN+tKNlGaK7WwzBdKs+Fe8jkgcr7 -pz4Ne887CmxejdISzOCcdT+Zm9Bx6I/uZwWOtDvWpIgIxVX9a9URj/+D1MxTE/y4 -d6H+c89yDY62I2+drMpdjCd3EtCaTlxpTbRS+s1eAHMH7aEkcCE= ------END RSA PRIVATE KEY----- +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCn+h+FM5gDTUIk +CXvdjH9xEzGJqWSjTTvQzxB8jSyJGoPl8DKwU3XAitJi77DQmKV7gTeu9NygnVgB +TnjJo+4DrFrb47eV5um0Ry4HoHpF83Bv4w8Kv8Ms3g86/2JCE7iJRuOC3/ZYWBS8 +Ih5Rzlh58K2WhIemlP/9NmkJM8l+SDXl5f902ZthtniB+gDsTijlNuzFLMlUkOQx +PkgyhnET+GxJbnnkFr8QFzTGWG+7FUg3XxyN8699iuJP7QvMdH2SeVzf3b5eLeDe +DK2JZgzuVzq2/4EB9R+IruvuRDA32WGm+H5E4ZzRXXxaruP/EnL+ymfJlqqSjHxg +IxD3LUItAgMBAAECggEAOE2naQ9tIZYw2EFxikZApVcooJrtx6ropMnzHbx4NBB2 +K4mChAXFj184u77ZxmGT/jzGvFcI6LE0wWNbK0NOUV7hKZk/fPhkV3AQZrAMrAu4 +IVi7PwAd3JkmA8F8XuebUDA5rDGDsgL8GD9baFJA58abeLs9eMGyuF4XgOUh4bip +hgHa76O2rcDWNY5HZqqRslw75FzlYkB0PCts/UJxSswj70kTTihyOhDlrm2TnyxI +ne54UbGRrpfs9wiheSGLjDG81qZToBHQDwoAnjjZhu1VCaBISuGbgZrxyyRyqdnn +xPW+KczMv04XyvF7v6Pz+bUEppalLXGiXnH5UtWvZQKBgQDTPCdMpNE/hwlq4nAw +Kf42zIBWfbnMLVWYoeDiAOhtl9XAUAXn76xe6Rvo0qeAo67yejdbJfRq3HvGyw+q +4PS8r9gXYmLYIPQxSoLL5+rFoBCN3qFippfjLB1j32mp7+15KjRj8FF2r6xIN8fu +XatSRsaqmvCWYLDRv/rbHnxwkwKBgQDLkyfFLF7BtwtPWKdqrwOM7ip1UKh+oDBS +vkCQ08aEFRBU7T3jChsx5GbaW6zmsSBwBwcrHclpSkz7n3aq19DDWObJR2p80Fma +rsXeIcvtEpkvT3pVX268P5d+XGs1kxgFunqTysG9yChW+xzcs5MdKBzuMPPn7rL8 +MKAzdar6PwKBgEypkzW8x3h/4Moa3k6MnwdyVs2NGaZheaRIc95yJ+jGZzxBjrMr +h+p2PbvU4BfO0AqOkpKRBtDVrlJqlggVVp04UHvEKE16QEW3Xhr0037f5cInX3j3 +Lz6yXwRFLAsR2aTUzWjL6jTh8uvO2s/GzQuyRh3a16Ar/WBShY+K0+zjAoGATnLT +xZjWnyHRmu8X/PWakamJ9RFzDPDgDlLAgM8LVgTj+UY/LgnL9wsEU6s2UuP5ExKy +QXxGDGwUhHar/SQTj+Pnc7Mwpw6HKSOmnnY5po8fNusSwml3O9XppEkrC0c236Y/ +7EobJO5IFVTJh4cv7vFxTJzSsRL8KFD4uzvh+nMCgYEAqY8NBYtIgNJA2B6C6hHF ++bG7v46434ZHFfGTmMQwzE4taVg7YRnzYESAlvK4bAP5ZXR90n7GRGFhrXzoMZ38 +r0bw/q9rV+ReGda7/Bjf7ciCKiq0RODcHtf4IaskjPXCoQRGJtgCPLhWPfld6g9v +/HTvO96xv9e3eG/PKSPog94= +-----END PRIVATE KEY----- diff --git a/conn.go b/conn.go index b3ab14d3c..bc0983608 100644 --- a/conn.go +++ b/conn.go @@ -2,6 +2,7 @@ package pq import ( "bufio" + "bytes" "context" "crypto/md5" "crypto/sha256" @@ -18,6 +19,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "time" "unicode" @@ -30,21 +32,28 @@ var ( ErrNotSupported = errors.New("pq: Unsupported command") ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction") ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") - ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less") - ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly") + ErrSSLKeyUnknownOwnership = errors.New("pq: Could not get owner information for private key, may not be properly protected") + ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key has world access. Permissions should be u=rw,g=r (0640) if owned by root, or u=rw (0600), or less") + + ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly") errUnexpectedReady = errors.New("unexpected ReadyForQuery") errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") ) +// Compile time validation that our types implement the expected interfaces +var ( + _ driver.Driver = Driver{} +) + // Driver is the Postgres database driver. type Driver struct{} // Open opens a new connection to the database. name is a connection string. // Most users should only use it through database/sql package from the standard // library. -func (d *Driver) Open(name string) (driver.Conn, error) { +func (d Driver) Open(name string) (driver.Conn, error) { return Open(name) } @@ -104,7 +113,9 @@ type defaultDialer struct { func (d defaultDialer) Dial(network, address string) (net.Conn, error) { return d.d.Dial(network, address) } -func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { +func (d defaultDialer) DialTimeout( + network, address string, timeout time.Duration, +) (net.Conn, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() return d.DialContext(ctx, network, address) @@ -134,9 +145,10 @@ type conn struct { saveMessageType byte saveMessageBuffer []byte - // If true, this connection is bad and all public-facing functions should - // return ErrBadConn. - bad bool + // If an error is set, this connection is bad and all public-facing + // functions should return the appropriate error by calling get() + // (ErrBadConn) or getForNext(). + err syncErr // If set, this connection should never use the binary format when // receiving query results from prepared statements. Only provided for @@ -160,6 +172,40 @@ type conn struct { gss GSS } +type syncErr struct { + err error + sync.Mutex +} + +// Return ErrBadConn if connection is bad. +func (e *syncErr) get() error { + e.Lock() + defer e.Unlock() + if e.err != nil { + return driver.ErrBadConn + } + return nil +} + +// Return the error set on the connection. Currently only used by rows.Next. +func (e *syncErr) getForNext() error { + e.Lock() + defer e.Unlock() + return e.err +} + +// Set error, only if it isn't set yet. +func (e *syncErr) set(err error) { + if err == nil { + panic("attempt to set nil err") + } + e.Lock() + defer e.Unlock() + if e.err == nil { + e.err = err + } +} + // Handle driver-side settings in parsed connection string. func (cn *conn) handleDriverSettings(o values) (err error) { boolSetting := func(key string, val *bool) error { @@ -187,7 +233,8 @@ func (cn *conn) handlePgpass(o values) { if _, ok := o["password"]; ok { return } - filename := os.Getenv("PGPASSFILE") + // Get passfile from the options + filename := o["passfile"] if filename == "" { // XXX this code doesn't work on Windows where the default filename is // XXX %APPDATA%\postgresql\pgpass.conf @@ -217,47 +264,56 @@ func (cn *conn) handlePgpass(o values) { } defer file.Close() scanner := bufio.NewScanner(io.Reader(file)) + // From: https://github.com/tg/pgpass/blob/master/reader.go + for scanner.Scan() { + if scanText(scanner.Text(), o) { + break + } + } +} + +// GetFields is a helper function for scanText. +func getFields(s string) []string { + fs := make([]string, 0, 5) + f := make([]rune, 0, len(s)) + + var esc bool + for _, c := range s { + switch { + case esc: + f = append(f, c) + esc = false + case c == '\\': + esc = true + case c == ':': + fs = append(fs, string(f)) + f = f[:0] + default: + f = append(f, c) + } + } + return append(fs, string(f)) +} + +// ScanText assists HandlePgpass in it's objective. +func scanText(line string, o values) bool { hostname := o["host"] ntw, _ := network(o) port := o["port"] db := o["dbname"] username := o["user"] - // From: https://github.com/tg/pgpass/blob/master/reader.go - getFields := func(s string) []string { - fs := make([]string, 0, 5) - f := make([]rune, 0, len(s)) - - var esc bool - for _, c := range s { - switch { - case esc: - f = append(f, c) - esc = false - case c == '\\': - esc = true - case c == ':': - fs = append(fs, string(f)) - f = f[:0] - default: - f = append(f, c) - } - } - return append(fs, string(f)) + if len(line) == 0 || line[0] == '#' { + return false } - for scanner.Scan() { - line := scanner.Text() - if len(line) == 0 || line[0] == '#' { - continue - } - split := getFields(line) - if len(split) != 5 { - continue - } - if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) { - o["password"] = split[4] - return - } + split := getFields(line) + if len(split) != 5 { + return false } + if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) { + o["password"] = split[4] + return true + } + return false } func (cn *conn) writeBuf(b byte) *writeBuf { @@ -281,7 +337,7 @@ func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { if err != nil { return nil, err } - c.dialer = d + c.Dialer(d) return c.open(context.Background()) } @@ -292,7 +348,13 @@ func (c *Connector) open(ctx context.Context) (cn *conn, err error) { // the user. defer errRecoverNoErrBadConn(&err) - o := c.opts + // Create a new values map (copy). This makes it so maps in different + // connections do not reference the same underlying data structure, so it + // is safe for multiple connections to concurrently write to their opts. + o := make(values) + for k, v := range c.opts { + o[k] = v + } cn = &conn{ opts: o, @@ -503,7 +565,7 @@ func (cn *conn) isInTransaction() bool { func (cn *conn) checkIsInTransaction(intxn bool) { if cn.isInTransaction() != intxn { - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("unexpected transaction status %v", cn.txnStatus) } } @@ -513,8 +575,8 @@ func (cn *conn) Begin() (_ driver.Tx, err error) { } func (cn *conn) begin(mode string) (_ driver.Tx, err error) { - if cn.bad { - return nil, driver.ErrBadConn + if err := cn.err.get(); err != nil { + return nil, err } defer cn.errRecover(&err) @@ -524,11 +586,11 @@ func (cn *conn) begin(mode string) (_ driver.Tx, err error) { return nil, err } if commandTag != "BEGIN" { - cn.bad = true + cn.err.set(driver.ErrBadConn) return nil, fmt.Errorf("unexpected command tag %s", commandTag) } if cn.txnStatus != txnStatusIdleInTransaction { - cn.bad = true + cn.err.set(driver.ErrBadConn) return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) } return cn, nil @@ -542,8 +604,8 @@ func (cn *conn) closeTxn() { func (cn *conn) Commit() (err error) { defer cn.closeTxn() - if cn.bad { - return driver.ErrBadConn + if err := cn.err.get(); err != nil { + return err } defer cn.errRecover(&err) @@ -564,12 +626,12 @@ func (cn *conn) Commit() (err error) { _, commandTag, err := cn.simpleExec("COMMIT") if err != nil { if cn.isInTransaction() { - cn.bad = true + cn.err.set(driver.ErrBadConn) } return err } if commandTag != "COMMIT" { - cn.bad = true + cn.err.set(driver.ErrBadConn) return fmt.Errorf("unexpected command tag %s", commandTag) } cn.checkIsInTransaction(false) @@ -578,8 +640,8 @@ func (cn *conn) Commit() (err error) { func (cn *conn) Rollback() (err error) { defer cn.closeTxn() - if cn.bad { - return driver.ErrBadConn + if err := cn.err.get(); err != nil { + return err } defer cn.errRecover(&err) return cn.rollback() @@ -590,7 +652,7 @@ func (cn *conn) rollback() (err error) { _, commandTag, err := cn.simpleExec("ROLLBACK") if err != nil { if cn.isInTransaction() { - cn.bad = true + cn.err.set(driver.ErrBadConn) } return err } @@ -630,7 +692,7 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err case 'T', 'D': // ignore any results default: - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("unknown response for simple query: %q", t) } } @@ -652,7 +714,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // the user can close, though, to avoid connections from being // leaked. A "rows" with done=true works fine for that purpose. if err != nil { - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("unexpected message %q in simple query execution", t) } if res == nil { @@ -663,8 +725,11 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // Set the result and tag to the last command complete if there wasn't a // query already run. Although queries usually return from here and cede // control to Next, a query with zero results does not. - if t == 'C' && res.colNames == nil { + if t == 'C' { res.result, res.tag = cn.parseComplete(r.string()) + if res.colNames != nil { + return + } } res.done = true case 'Z': @@ -676,7 +741,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { err = parseError(r) case 'D': if res == nil { - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("unexpected DataRow in simple query execution") } // the query didn't fail; kick off to Next @@ -691,7 +756,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. default: - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("unknown response for simple query: %q", t) } } @@ -713,7 +778,9 @@ func (noRows) RowsAffected() (int64, error) { // Decides which column formats to use for a prepared statement. The input is // an array of type oids, one element per result column. -func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) { +func decideColumnFormats( + colTyps []fieldDesc, forceText bool, +) (colFmts []format, colFmtData []byte) { if len(colTyps) == 0 { return nil, colFmtDataAllText } @@ -784,8 +851,8 @@ func (cn *conn) prepareTo(q, stmtName string) *stmt { } func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { - if cn.bad { - return nil, driver.ErrBadConn + if err := cn.err.get(); err != nil { + return nil, err } defer cn.errRecover(&err) @@ -823,8 +890,8 @@ func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { } func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { - if cn.bad { - return nil, driver.ErrBadConn + if err := cn.err.get(); err != nil { + return nil, err } if cn.inCopy { return nil, errCopyInProgress @@ -857,8 +924,8 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { // Implement the optional "Execer" interface for one-shot queries func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { - if cn.bad { - return nil, driver.ErrBadConn + if err := cn.err.get(); err != nil { + return nil, err } defer cn.errRecover(&err) @@ -891,9 +958,20 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err return r, err } +type safeRetryError struct { + Err error +} + +func (se *safeRetryError) Error() string { + return se.Err.Error() +} + func (cn *conn) send(m *writeBuf) { - _, err := cn.c.Write(m.wrap()) + n, err := cn.c.Write(m.wrap()) if err != nil { + if n == 0 { + err = &safeRetryError{Err: err} + } panic(err) } } @@ -918,7 +996,7 @@ func (cn *conn) sendSimpleMessage(typ byte) (err error) { // the message yourself. func (cn *conn) saveMessage(typ byte, buf *readBuf) { if cn.saveMessageType != 0 { - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("unexpected saveMessageType %d", cn.saveMessageType) } cn.saveMessageType = typ @@ -1064,7 +1142,7 @@ func isDriverSetting(key string) bool { return true case "password": return true - case "sslmode", "sslcert", "sslkey", "sslrootcert": + case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni": return true case "fallback_application_name": return true @@ -1074,9 +1152,9 @@ func isDriverSetting(key string) bool { return true case "binary_parameters": return true - case "service": + case "krbsrvname": return true - case "spn": + case "krbspn": return true default: return false @@ -1168,13 +1246,13 @@ func (cn *conn) auth(r *readBuf, o values) { var token []byte - if spn, ok := o["spn"]; ok { + if spn, ok := o["krbspn"]; ok { // Use the supplied SPN if provided.. token, err = cli.GetInitTokenFromSpn(spn) } else { // Allow the kerberos service name to be overridden service := "postgres" - if val, ok := o["service"]; ok { + if val, ok := o["krbsrvname"]; ok { service = val } @@ -1288,8 +1366,8 @@ func (st *stmt) Close() (err error) { if st.closed { return nil } - if st.cn.bad { - return driver.ErrBadConn + if err := st.cn.err.get(); err != nil { + return err } defer st.cn.errRecover(&err) @@ -1302,14 +1380,14 @@ func (st *stmt) Close() (err error) { t, _ := st.cn.recv1() if t != '3' { - st.cn.bad = true + st.cn.err.set(driver.ErrBadConn) errorf("unexpected close response: %q", t) } st.closed = true t, r := st.cn.recv1() if t != 'Z' { - st.cn.bad = true + st.cn.err.set(driver.ErrBadConn) errorf("expected ready for query, but got: %q", t) } st.cn.processReadyForQuery(r) @@ -1318,8 +1396,12 @@ func (st *stmt) Close() (err error) { } func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { - if st.cn.bad { - return nil, driver.ErrBadConn + return st.query(v) +} + +func (st *stmt) query(v []driver.Value) (r *rows, err error) { + if err := st.cn.err.get(); err != nil { + return nil, err } defer st.cn.errRecover(&err) @@ -1331,8 +1413,8 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { } func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { - if st.cn.bad { - return nil, driver.ErrBadConn + if err := st.cn.err.get(); err != nil { + return nil, err } defer st.cn.errRecover(&err) @@ -1418,7 +1500,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { parts := strings.Split(commandTag, " ") if len(parts) != 3 { - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("unexpected INSERT command tag %s", commandTag) } affectedRows = &parts[len(parts)-1] @@ -1430,7 +1512,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { } n, err := strconv.ParseInt(*affectedRows, 10, 64) if err != nil { - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("could not parse commandTag: %s", err) } return driver.RowsAffected(n), commandTag @@ -1497,8 +1579,8 @@ func (rs *rows) Next(dest []driver.Value) (err error) { } conn := rs.cn - if conn.bad { - return driver.ErrBadConn + if err := conn.err.getForNext(); err != nil { + return err } defer conn.errRecover(&err) @@ -1522,7 +1604,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { case 'D': n := rs.rb.int16() if err != nil { - conn.bad = true + conn.err.set(driver.ErrBadConn) errorf("unexpected DataRow after error %s", err) } if n < len(dest) { @@ -1564,10 +1646,10 @@ func (rs *rows) NextResultSet() error { // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be // used as part of an SQL statement. For example: // -// tblname := "my_table" -// data := "my_data" -// quoted := pq.QuoteIdentifier(tblname) -// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) +// tblname := "my_table" +// data := "my_data" +// quoted := pq.QuoteIdentifier(tblname) +// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) // // Any double quotes in name will be escaped. The quoted identifier will be // case sensitive when used in a query. If the input string contains a zero @@ -1580,12 +1662,24 @@ func QuoteIdentifier(name string) string { return `"` + strings.Replace(name, `"`, `""`, -1) + `"` } +// BufferQuoteIdentifier satisfies the same purpose as QuoteIdentifier, but backed by a +// byte buffer. +func BufferQuoteIdentifier(name string, buffer *bytes.Buffer) { + end := strings.IndexRune(name, 0) + if end > -1 { + name = name[:end] + } + buffer.WriteRune('"') + buffer.WriteString(strings.Replace(name, `"`, `""`, -1)) + buffer.WriteRune('"') +} + // QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal // to DDL and other statements that do not accept parameters) to be used as part // of an SQL statement. For example: // -// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") -// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) +// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") +// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) // // Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be // replaced by two backslashes (i.e. "\\") and the C-style escape identifier @@ -1689,10 +1783,9 @@ func (cn *conn) processParameterStatus(r *readBuf) { case "server_version": var major1 int var major2 int - var minor int - _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor) + _, err = fmt.Sscanf(r.string(), "%d.%d", &major1, &major2) if err == nil { - cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor + cn.parameterStatus.serverVersion = major1*10000 + major2*100 } case "TimeZone": @@ -1717,7 +1810,7 @@ func (cn *conn) readReadyForQuery() { cn.processReadyForQuery(r) return default: - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("unexpected message %q; expected ReadyForQuery", t) } } @@ -1737,12 +1830,16 @@ func (cn *conn) readParseResponse() { cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("unexpected Parse response %q", t) } } -func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) { +func (cn *conn) readStatementDescribeResponse() ( + paramTyps []oid.Oid, + colNames []string, + colTyps []fieldDesc, +) { for { t, r := cn.recv1() switch t { @@ -1762,7 +1859,7 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [ cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("unexpected Describe statement response %q", t) } } @@ -1780,7 +1877,7 @@ func (cn *conn) readPortalDescribeResponse() rowsHeader { cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("unexpected Describe response %q", t) } panic("not reached") @@ -1796,7 +1893,7 @@ func (cn *conn) readBindResponse() { cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("unexpected Bind response %q", t) } } @@ -1823,20 +1920,22 @@ func (cn *conn) postExecuteWorkaround() { cn.saveMessage(t, r) return default: - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("unexpected message during extended query execution: %q", t) } } } // Only for Exec(), since we ignore the returned data -func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) { +func (cn *conn) readExecuteResponse( + protocolState string, +) (res driver.Result, commandTag string, err error) { for { t, r := cn.recv1() switch t { case 'C': if err != nil { - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("unexpected CommandComplete after error %s", err) } res, commandTag = cn.parseComplete(r.string()) @@ -1850,7 +1949,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co err = parseError(r) case 'T', 'D', 'I': if err != nil { - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("unexpected %q after error %s", t, err) } if t == 'I' { @@ -1858,7 +1957,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co } // ignore any results default: - cn.bad = true + cn.err.set(driver.ErrBadConn) errorf("unknown %s response: %q", protocolState, t) } } @@ -1940,6 +2039,8 @@ func parseEnviron(env []string) (out map[string]string) { accrue("user") case "PGPASSWORD": accrue("password") + case "PGPASSFILE": + accrue("passfile") case "PGSERVICE", "PGSERVICEFILE", "PGREALM": unsupported() case "PGOPTIONS": @@ -1954,6 +2055,8 @@ func parseEnviron(env []string) (out map[string]string) { accrue("sslkey") case "PGSSLROOTCERT": accrue("sslrootcert") + case "PGSSLSNI": + accrue("sslsni") case "PGREQUIRESSL", "PGSSLCRL": unsupported() case "PGREQUIREPEER": @@ -1994,3 +2097,19 @@ func alnumLowerASCII(ch rune) rune { } return -1 // discard } + +// The database/sql/driver package says: +// All Conn implementations should implement the following interfaces: Pinger, SessionResetter, and Validator. +var _ driver.Pinger = &conn{} +var _ driver.SessionResetter = &conn{} + +func (cn *conn) ResetSession(ctx context.Context) error { + // Ensure bad connections are reported: From database/sql/driver: + // If a connection is never returned to the connection pool but immediately reused, then + // ResetSession is called prior to reuse but IsValid is not called. + return cn.err.get() +} + +func (cn *conn) IsValid() bool { + return cn.err.get() == nil +} diff --git a/conn_go115.go b/conn_go115.go new file mode 100644 index 000000000..f4ef030f9 --- /dev/null +++ b/conn_go115.go @@ -0,0 +1,8 @@ +//go:build go1.15 +// +build go1.15 + +package pq + +import "database/sql/driver" + +var _ driver.Validator = &conn{} diff --git a/conn_go18.go b/conn_go18.go index 09e2ea464..63d4ca6aa 100644 --- a/conn_go18.go +++ b/conn_go18.go @@ -10,6 +10,10 @@ import ( "time" ) +const ( + watchCancelDialContextTimeout = time.Second * 10 +) + // Implement the "QueryerContext" interface func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { list := make([]driver.Value, len(args)) @@ -42,6 +46,14 @@ func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.Nam return cn.Exec(query, list) } +// Implement the "ConnPrepareContext" interface +func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if finish := cn.watchCancel(ctx); finish != nil { + defer finish() + } + return cn.Prepare(query) +} + // Implement the "ConnBeginTx" interface func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { var mode string @@ -89,25 +101,37 @@ func (cn *conn) Ping(ctx context.Context) error { func (cn *conn) watchCancel(ctx context.Context) func() { if done := ctx.Done(); done != nil { - finished := make(chan struct{}) + finished := make(chan struct{}, 1) go func() { select { case <-done: + select { + case finished <- struct{}{}: + default: + // We raced with the finish func, let the next query handle this with the + // context. + return + } + + // Set the connection state to bad so it does not get reused. + cn.err.set(ctx.Err()) + // At this point the function level context is canceled, // so it must not be used for the additional network // request to cancel the query. // Create a new context to pass into the dial. - ctxCancel, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout) defer cancel() _ = cn.cancel(ctxCancel) - finished <- struct{}{} case <-finished: } }() return func() { select { case <-finished: + cn.err.set(ctx.Err()) + cn.Close() case finished <- struct{}{}: } } @@ -116,7 +140,16 @@ func (cn *conn) watchCancel(ctx context.Context) func() { } func (cn *conn) cancel(ctx context.Context) error { - c, err := dial(ctx, cn.dialer, cn.opts) + // Create a new values map (copy). This makes sure the connection created + // in this method cannot write to the same underlying data, which could + // cause a concurrent map write panic. This is necessary because cancel + // is called from a goroutine in watchCancel. + o := make(values) + for k, v := range cn.opts { + o[k] = v + } + + c, err := dial(ctx, cn.dialer, o) if err != nil { return err } @@ -126,7 +159,7 @@ func (cn *conn) cancel(ctx context.Context) error { can := conn{ c: c, } - err = can.ssl(cn.opts) + err = can.ssl(o) if err != nil { return err } @@ -147,3 +180,68 @@ func (cn *conn) cancel(ctx context.Context) error { return err } } + +// Implement the "StmtQueryContext" interface +func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + list := make([]driver.Value, len(args)) + for i, nv := range args { + list[i] = nv.Value + } + finish := st.watchCancel(ctx) + r, err := st.query(list) + if err != nil { + if finish != nil { + finish() + } + return nil, err + } + r.finish = finish + return r, nil +} + +// Implement the "StmtExecContext" interface +func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + list := make([]driver.Value, len(args)) + for i, nv := range args { + list[i] = nv.Value + } + + if finish := st.watchCancel(ctx); finish != nil { + defer finish() + } + + return st.Exec(list) +} + +// watchCancel is implemented on stmt in order to not mark the parent conn as bad +func (st *stmt) watchCancel(ctx context.Context) func() { + if done := ctx.Done(); done != nil { + finished := make(chan struct{}) + go func() { + select { + case <-done: + // At this point the function level context is canceled, + // so it must not be used for the additional network + // request to cancel the query. + // Create a new context to pass into the dial. + ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout) + defer cancel() + + _ = st.cancel(ctxCancel) + finished <- struct{}{} + case <-finished: + } + }() + return func() { + select { + case <-finished: + case finished <- struct{}{}: + } + } + } + return nil +} + +func (st *stmt) cancel(ctx context.Context) error { + return st.cn.cancel(ctx) +} diff --git a/conn_go19.go b/conn_go19.go new file mode 100644 index 000000000..003601565 --- /dev/null +++ b/conn_go19.go @@ -0,0 +1,35 @@ +//go:build go1.9 +// +build go1.9 + +package pq + +import ( + "database/sql/driver" + "reflect" +) + +var _ driver.NamedValueChecker = (*conn)(nil) + +func (c *conn) CheckNamedValue(nv *driver.NamedValue) error { + if _, ok := nv.Value.(driver.Valuer); ok { + // Ignore Valuer, for backward compatibility with pq.Array(). + return driver.ErrSkip + } + + // Ignoring []byte / []uint8. + if _, ok := nv.Value.([]uint8); ok { + return driver.ErrSkip + } + + v := reflect.ValueOf(nv.Value) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() == reflect.Slice { + var err error + nv.Value, err = Array(v.Interface()).Value() + return err + } + + return driver.ErrSkip +} diff --git a/conn_go19_test.go b/conn_go19_test.go new file mode 100644 index 000000000..43c982631 --- /dev/null +++ b/conn_go19_test.go @@ -0,0 +1,83 @@ +//go:build go1.9 +// +build go1.9 + +package pq + +import ( + "fmt" + "reflect" + "testing" +) + +func TestArrayArg(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + for _, tc := range []struct { + pgType string + in, out interface{} + }{ + { + pgType: "int[]", + in: []int{245, 231}, + out: []int64{245, 231}, + }, + { + pgType: "int[]", + in: &[]int{245, 231}, + out: []int64{245, 231}, + }, + { + pgType: "int[]", + in: []int64{245, 231}, + }, + { + pgType: "int[]", + in: &[]int64{245, 231}, + out: []int64{245, 231}, + }, + { + pgType: "varchar[]", + in: []string{"hello", "world"}, + }, + { + pgType: "varchar[]", + in: &[]string{"hello", "world"}, + out: []string{"hello", "world"}, + }, + } { + if tc.out == nil { + tc.out = tc.in + } + t.Run(fmt.Sprintf("%#v", tc.in), func(t *testing.T) { + r, err := db.Query(fmt.Sprintf("SELECT $1::%s", tc.pgType), tc.in) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if !r.Next() { + if r.Err() != nil { + t.Fatal(r.Err()) + } + t.Fatal("expected row") + } + + defer func() { + if r.Next() { + t.Fatal("unexpected row") + } + }() + + got := reflect.New(reflect.TypeOf(tc.out)) + if err := r.Scan(Array(got.Interface())); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(tc.out, got.Elem().Interface()) { + t.Errorf("got %v, want %v", got, tc.out) + } + }) + } + +} diff --git a/conn_test.go b/conn_test.go index 0d25c9554..96f70dddf 100644 --- a/conn_test.go +++ b/conn_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "database/sql/driver" + "errors" "fmt" "io" "net" @@ -143,10 +144,6 @@ func TestOpenURL(t *testing.T) { const pgpassFile = "/tmp/pqgotest_pgpass" func TestPgpass(t *testing.T) { - if os.Getenv("TRAVIS") != "true" { - t.Skip("not running under Travis, skipping pgpass tests") - } - testAssert := func(conninfo string, expected string, reason string) { conn, err := openTestConnConninfo(conninfo) if err != nil { @@ -177,7 +174,10 @@ func TestPgpass(t *testing.T) { } testAssert("", "ok", "missing .pgpass, unexpected error %#v") os.Setenv("PGPASSFILE", pgpassFile) + defer os.Unsetenv("PGPASSFILE") testAssert("host=/tmp", "fail", ", unexpected error %#v") + os.Unsetenv("PGPASSFILE") + os.Remove(pgpassFile) pgpass, err := os.OpenFile(pgpassFile, os.O_RDWR|os.O_CREATE, 0644) if err != nil { @@ -192,6 +192,7 @@ localhost:*:*:*:pass_C if err != nil { t.Fatalf("Unexpected error writing pgpass file %#v", err) } + defer os.Remove(pgpassFile) pgpass.Close() assertPassword := func(extra values, expected string) { @@ -214,19 +215,22 @@ localhost:*:*:*:pass_C t.Fatalf("For %v expected %s got %s", extra, expected, pw) } } + // missing passfile means empty psasword + assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "") // wrong permissions for the pgpass file means it should be ignored - assertPassword(values{"host": "example.com", "user": "foo"}, "") + assertPassword(values{"host": "example.com", "passfile": pgpassFile, "user": "foo"}, "") // fix the permissions and check if it has taken effect os.Chmod(pgpassFile, 0600) - assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "pass_A") - assertPassword(values{"host": "example.com", "user": "foo"}, "pass_fallback") - assertPassword(values{"host": "example.com", "dbname": "some_db", "user": "some_user"}, "pass_B") + + assertPassword(values{"host": "server", "passfile": pgpassFile, "dbname": "some_db", "user": "some_user"}, "pass_A") + assertPassword(values{"host": "example.com", "passfile": pgpassFile, "user": "foo"}, "pass_fallback") + assertPassword(values{"host": "example.com", "passfile": pgpassFile, "dbname": "some_db", "user": "some_user"}, "pass_B") // localhost also matches the default "" and UNIX sockets - assertPassword(values{"host": "", "user": "some_user"}, "pass_C") - assertPassword(values{"host": "/tmp", "user": "some_user"}, "pass_C") - // cleanup - os.Remove(pgpassFile) - os.Setenv("PGPASSFILE", "") + assertPassword(values{"host": "", "passfile": pgpassFile, "user": "some_user"}, "pass_C") + assertPassword(values{"host": "/tmp", "passfile": pgpassFile, "user": "some_user"}, "pass_C") + // passfile connection parameter takes precedence + os.Setenv("PGPASSFILE", "/tmp") + assertPassword(values{"host": "server", "passfile": pgpassFile, "dbname": "some_db", "user": "some_user"}, "pass_A") } func TestExec(t *testing.T) { @@ -321,7 +325,7 @@ func TestStatment(t *testing.T) { if !r1.Next() { if r.Err() != nil { - t.Fatal(r1.Err()) + t.Fatal(r.Err()) } t.Fatal("expected row") } @@ -703,8 +707,8 @@ func TestBadConn(t *testing.T) { if err != driver.ErrBadConn { t.Fatalf("expected driver.ErrBadConn, got: %#v", err) } - if !cn.bad { - t.Fatalf("expected cn.bad") + if err := cn.err.get(); err != driver.ErrBadConn { + t.Fatalf("expected driver.ErrBadConn, got %#v", err) } cn = conn{} @@ -716,8 +720,8 @@ func TestBadConn(t *testing.T) { if err != driver.ErrBadConn { t.Fatalf("expected driver.ErrBadConn, got: %#v", err) } - if !cn.bad { - t.Fatalf("expected cn.bad") + if err := cn.err.get(); err != driver.ErrBadConn { + t.Fatalf("expected driver.ErrBadConn, got %#v", err) } } @@ -1293,7 +1297,7 @@ func TestNullAfterNonNull(t *testing.T) { if !r.Next() { if r.Err() != nil { - t.Fatal(err) + t.Fatal(r.Err()) } t.Fatal("expected row") } @@ -1308,7 +1312,7 @@ func TestNullAfterNonNull(t *testing.T) { if !r.Next() { if r.Err() != nil { - t.Fatal(err) + t.Fatal(r.Err()) } t.Fatal("expected row") } @@ -1504,7 +1508,7 @@ func TestRuntimeParameters(t *testing.T) { } value, success := tryGetParameterValue() - if success != test.success && !test.success { + if success != test.success && !success { t.Fatalf("%v: unexpected error: %v", test.conninfo, err) } if success != test.success { @@ -1640,10 +1644,6 @@ func TestRowsResultTag(t *testing.T) { { query: "CREATE TEMP TABLE t (a int); DROP TABLE t; SELECT 1", }, - // Verify that an no-results query doesn't set the tag. - { - query: "CREATE TEMP TABLE t (a int); SELECT 1 WHERE FALSE; DROP TABLE t;", - }, } // If this is the only test run, this will correct the connection string. @@ -1748,6 +1748,36 @@ func TestMultipleResult(t *testing.T) { } } +func TestMultipleEmptyResult(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + rows, err := db.Query("select 1 where false; select 2") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + for rows.Next() { + t.Fatal("unexpected row") + } + if !rows.NextResultSet() { + t.Fatal("expected more result sets", rows.Err()) + } + for rows.Next() { + var i int + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + if i != 2 { + t.Fatalf("expected 2, got %d", i) + } + } + if rows.NextResultSet() { + t.Fatal("unexpected result set") + } +} + func TestCopyInStmtAffectedRows(t *testing.T) { db := openTestConn(t) defer db.Close() @@ -1775,3 +1805,169 @@ func TestCopyInStmtAffectedRows(t *testing.T) { res.RowsAffected() res.LastInsertId() } + +func TestConnPrepareContext(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + tests := []struct { + name string + ctx func() (context.Context, context.CancelFunc) + sql string + err error + }{ + { + name: "context.Background", + ctx: func() (context.Context, context.CancelFunc) { + return context.Background(), nil + }, + sql: "SELECT 1", + err: nil, + }, + { + name: "context.WithTimeout exceeded", + ctx: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), -time.Minute) + }, + sql: "SELECT 1", + err: context.DeadlineExceeded, + }, + { + name: "context.WithTimeout", + ctx: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Minute) + }, + sql: "SELECT 1", + err: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := tt.ctx() + if cancel != nil { + defer cancel() + } + _, err := db.PrepareContext(ctx, tt.sql) + switch { + case (err != nil) != (tt.err != nil): + t.Fatalf("conn.PrepareContext() unexpected nil err got = %v, expected = %v", err, tt.err) + case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()): + t.Errorf("conn.PrepareContext() got = %v, expected = %v", err.Error(), tt.err.Error()) + } + }) + } +} + +func TestStmtQueryContext(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + tests := []struct { + name string + ctx func() (context.Context, context.CancelFunc) + sql string + cancelExpected bool + }{ + { + name: "context.Background", + ctx: func() (context.Context, context.CancelFunc) { + return context.Background(), nil + }, + sql: "SELECT pg_sleep(1);", + cancelExpected: false, + }, + { + name: "context.WithTimeout exceeded", + ctx: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 1*time.Second) + }, + sql: "SELECT pg_sleep(10);", + cancelExpected: true, + }, + { + name: "context.WithTimeout", + ctx: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Minute) + }, + sql: "SELECT pg_sleep(1);", + cancelExpected: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := tt.ctx() + if cancel != nil { + defer cancel() + } + stmt, err := db.PrepareContext(ctx, tt.sql) + if err != nil { + t.Fatal(err) + } + _, err = stmt.QueryContext(ctx) + pgErr := (*Error)(nil) + switch { + case (err != nil) != tt.cancelExpected: + t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, cancelExpected = %v", err, tt.cancelExpected) + case (err != nil && tt.cancelExpected) && !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode): + t.Errorf("stmt.QueryContext() got = %v, cancelExpected = %v", err.Error(), tt.cancelExpected) + } + }) + } +} + +func TestStmtExecContext(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + tests := []struct { + name string + ctx func() (context.Context, context.CancelFunc) + sql string + cancelExpected bool + }{ + { + name: "context.Background", + ctx: func() (context.Context, context.CancelFunc) { + return context.Background(), nil + }, + sql: "SELECT pg_sleep(1);", + cancelExpected: false, + }, + { + name: "context.WithTimeout exceeded", + ctx: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 1*time.Second) + }, + sql: "SELECT pg_sleep(10);", + cancelExpected: true, + }, + { + name: "context.WithTimeout", + ctx: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Minute) + }, + sql: "SELECT pg_sleep(1);", + cancelExpected: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := tt.ctx() + if cancel != nil { + defer cancel() + } + stmt, err := db.PrepareContext(ctx, tt.sql) + if err != nil { + t.Fatal(err) + } + _, err = stmt.ExecContext(ctx) + pgErr := (*Error)(nil) + switch { + case (err != nil) != tt.cancelExpected: + t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, cancelExpected = %v", err, tt.cancelExpected) + case (err != nil && tt.cancelExpected) && !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode): + t.Errorf("stmt.QueryContext() got = %v, cancelExpected = %v", err.Error(), tt.cancelExpected) + } + }) + } +} diff --git a/connector.go b/connector.go index 6a0ee7fc1..1145e1225 100644 --- a/connector.go +++ b/connector.go @@ -27,7 +27,12 @@ func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { return c.open(ctx) } -// Driver returnst the underlying driver of this Connector. +// Dialer allows change the dialer used to open connections. +func (c *Connector) Dialer(dialer Dialer) { + c.dialer = dialer +} + +// Driver returns the underlying driver of this Connector. func (c *Connector) Driver() driver.Driver { return &Driver{} } diff --git a/connector_example_test.go b/connector_example_test.go index 9401fa0d4..dbae4db52 100644 --- a/connector_example_test.go +++ b/connector_example_test.go @@ -1,3 +1,4 @@ +//go:build go1.10 // +build go1.10 package pq_test diff --git a/connector_test.go b/connector_test.go index 3d2c67b06..ab34fc50b 100644 --- a/connector_test.go +++ b/connector_test.go @@ -1,3 +1,4 @@ +//go:build go1.10 // +build go1.10 package pq @@ -6,6 +7,7 @@ import ( "context" "database/sql" "database/sql/driver" + "os" "testing" ) @@ -65,3 +67,21 @@ func TestNewConnector_Driver(t *testing.T) { } txn.Rollback() } + +func TestNewConnector_Environ(t *testing.T) { + name := "" + os.Setenv("PGPASSFILE", "/tmp/.pgpass") + defer os.Unsetenv("PGPASSFILE") + c, err := NewConnector(name) + if err != nil { + t.Fatal(err) + } + for key, expected := range map[string]string{ + "passfile": "/tmp/.pgpass", + } { + if got := c.opts[key]; got != expected { + t.Fatalf("Getting values from environment variables, for %v expected %s got %s", key, expected, got) + } + } + +} diff --git a/copy.go b/copy.go index d3bc1edd8..a8f16b2b2 100644 --- a/copy.go +++ b/copy.go @@ -1,6 +1,8 @@ package pq import ( + "bytes" + "context" "database/sql/driver" "encoding/binary" "errors" @@ -19,29 +21,35 @@ var ( // CopyIn creates a COPY FROM statement which can be prepared with // Tx.Prepare(). The target table should be visible in search_path. func CopyIn(table string, columns ...string) string { - stmt := "COPY " + QuoteIdentifier(table) + " (" + buffer := bytes.NewBufferString("COPY ") + BufferQuoteIdentifier(table, buffer) + buffer.WriteString(" (") + makeStmt(buffer, columns...) + return buffer.String() +} + +// MakeStmt makes the stmt string for CopyIn and CopyInSchema. +func makeStmt(buffer *bytes.Buffer, columns ...string) { + //s := bytes.NewBufferString() for i, col := range columns { if i != 0 { - stmt += ", " + buffer.WriteString(", ") } - stmt += QuoteIdentifier(col) + BufferQuoteIdentifier(col, buffer) } - stmt += ") FROM STDIN" - return stmt + buffer.WriteString(") FROM STDIN") } // CopyInSchema creates a COPY FROM statement which can be prepared with // Tx.Prepare(). func CopyInSchema(schema, table string, columns ...string) string { - stmt := "COPY " + QuoteIdentifier(schema) + "." + QuoteIdentifier(table) + " (" - for i, col := range columns { - if i != 0 { - stmt += ", " - } - stmt += QuoteIdentifier(col) - } - stmt += ") FROM STDIN" - return stmt + buffer := bytes.NewBufferString("COPY ") + BufferQuoteIdentifier(schema, buffer) + buffer.WriteRune('.') + BufferQuoteIdentifier(table, buffer) + buffer.WriteString(" (") + makeStmt(buffer, columns...) + return buffer.String() } type copyin struct { @@ -52,8 +60,11 @@ type copyin struct { closed bool - sync.Mutex // guards err - err error + mu struct { + sync.Mutex + err error + driver.Result + } } const ciBufferSize = 64 * 1024 @@ -97,13 +108,13 @@ awaitCopyInResponse: err = parseError(r) case 'Z': if err == nil { - ci.setBad() + ci.setBad(driver.ErrBadConn) errorf("unexpected ReadyForQuery in response to COPY") } cn.processReadyForQuery(r) return nil, err default: - ci.setBad() + ci.setBad(driver.ErrBadConn) errorf("unknown response for copy query: %q", t) } } @@ -122,7 +133,7 @@ awaitCopyInResponse: cn.processReadyForQuery(r) return nil, err default: - ci.setBad() + ci.setBad(driver.ErrBadConn) errorf("unknown response for CopyFail: %q", t) } } @@ -143,7 +154,7 @@ func (ci *copyin) resploop() { var r readBuf t, err := ci.cn.recvMessage(&r) if err != nil { - ci.setBad() + ci.setBad(driver.ErrBadConn) ci.setError(err) ci.done <- true return @@ -151,6 +162,8 @@ func (ci *copyin) resploop() { switch t { case 'C': // complete + res, _ := ci.cn.parseComplete(r.string()) + ci.setResult(res) case 'N': if n := ci.cn.noticeHandler; n != nil { n(parseError(&r)) @@ -163,7 +176,7 @@ func (ci *copyin) resploop() { err := parseError(&r) ci.setError(err) default: - ci.setBad() + ci.setBad(driver.ErrBadConn) ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t)) ci.done <- true return @@ -171,34 +184,45 @@ func (ci *copyin) resploop() { } } -func (ci *copyin) setBad() { - ci.Lock() - ci.cn.bad = true - ci.Unlock() +func (ci *copyin) setBad(err error) { + ci.cn.err.set(err) } -func (ci *copyin) isBad() bool { - ci.Lock() - b := ci.cn.bad - ci.Unlock() - return b +func (ci *copyin) getBad() error { + return ci.cn.err.get() } -func (ci *copyin) isErrorSet() bool { - ci.Lock() - isSet := (ci.err != nil) - ci.Unlock() - return isSet +func (ci *copyin) err() error { + ci.mu.Lock() + err := ci.mu.err + ci.mu.Unlock() + return err } // setError() sets ci.err if one has not been set already. Caller must not be // holding ci.Mutex. func (ci *copyin) setError(err error) { - ci.Lock() - if ci.err == nil { - ci.err = err + ci.mu.Lock() + if ci.mu.err == nil { + ci.mu.err = err } - ci.Unlock() + ci.mu.Unlock() +} + +func (ci *copyin) setResult(result driver.Result) { + ci.mu.Lock() + ci.mu.Result = result + ci.mu.Unlock() +} + +func (ci *copyin) getResult() driver.Result { + ci.mu.Lock() + result := ci.mu.Result + ci.mu.Unlock() + if result == nil { + return driver.RowsAffected(0) + } + return result } func (ci *copyin) NumInput() int { @@ -221,17 +245,21 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { return nil, errCopyInClosed } - if ci.isBad() { - return nil, driver.ErrBadConn + if err := ci.getBad(); err != nil { + return nil, err } defer ci.cn.errRecover(&err) - if ci.isErrorSet() { - return nil, ci.err + if err := ci.err(); err != nil { + return nil, err } if len(v) == 0 { - return driver.RowsAffected(0), ci.Close() + if err := ci.Close(); err != nil { + return driver.RowsAffected(0), err + } + + return ci.getResult(), nil } numValues := len(v) @@ -253,14 +281,51 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { return driver.RowsAffected(0), nil } +// CopyData inserts a raw string into the COPY stream. The insert is +// asynchronous and CopyData can return errors from previous CopyData calls to +// the same COPY stmt. +// +// You need to call Exec(nil) to sync the COPY stream and to get any +// errors from pending data, since Stmt.Close() doesn't return errors +// to the user. +func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, err error) { + if ci.closed { + return nil, errCopyInClosed + } + + if finish := ci.cn.watchCancel(ctx); finish != nil { + defer finish() + } + + if err := ci.getBad(); err != nil { + return nil, err + } + defer ci.cn.errRecover(&err) + + if err := ci.err(); err != nil { + return nil, err + } + + ci.buffer = append(ci.buffer, []byte(line)...) + ci.buffer = append(ci.buffer, '\n') + + if len(ci.buffer) > ciBufferFlushSize { + ci.flush(ci.buffer) + // reset buffer, keep bytes for message identifier and length + ci.buffer = ci.buffer[:5] + } + + return driver.RowsAffected(0), nil +} + func (ci *copyin) Close() (err error) { if ci.closed { // Don't do anything, we're already closed return nil } ci.closed = true - if ci.isBad() { - return driver.ErrBadConn + if err := ci.getBad(); err != nil { + return err } defer ci.cn.errRecover(&err) @@ -276,8 +341,7 @@ func (ci *copyin) Close() (err error) { <-ci.done ci.cn.inCopy = false - if ci.isErrorSet() { - err = ci.err + if err := ci.err(); err != nil { return err } return nil diff --git a/copy_test.go b/copy_test.go index a888a8948..5a1cf92af 100644 --- a/copy_test.go +++ b/copy_test.go @@ -4,9 +4,11 @@ import ( "bytes" "database/sql" "database/sql/driver" + "fmt" "net" "strings" "testing" + "time" ) func TestCopyInStmt(t *testing.T) { @@ -73,11 +75,20 @@ func TestCopyInMultipleValues(t *testing.T) { } } - _, err = stmt.Exec() + result, err := stmt.Exec() + if err != nil { + t.Fatal(err) + } + + rowsAffected, err := result.RowsAffected() if err != nil { t.Fatal(err) } + if rowsAffected != 500 { + t.Fatalf("expected 500 rows affected, not %d", rowsAffected) + } + err = stmt.Close() if err != nil { t.Fatal(err) @@ -121,7 +132,7 @@ func TestCopyInRaiseStmtTrigger(t *testing.T) { _, err = txn.Exec(` CREATE OR REPLACE FUNCTION pg_temp.temptest() - RETURNS trigger AS + RETURNS trigger AS $BODY$ begin raise notice 'Hello world'; return new; @@ -134,7 +145,7 @@ func TestCopyInRaiseStmtTrigger(t *testing.T) { _, err = txn.Exec(` CREATE TRIGGER temptest_trigger BEFORE INSERT - ON temp + ON temp FOR EACH ROW EXECUTE PROCEDURE pg_temp.temptest()`) if err != nil { @@ -397,10 +408,13 @@ func TestCopyRespLoopConnectionError(t *testing.T) { t.Fatal(err) } } - _, err = stmt.Exec() - if err == nil { - t.Fatalf("expected error") - } + retry(t, time.Second*5, func() error { + _, err = stmt.Exec() + if err == nil { + return fmt.Errorf("expected error") + } + return nil + }) switch pge := err.(type) { case *Error: if pge.Code.Name() != "admin_shutdown" { @@ -411,6 +425,8 @@ func TestCopyRespLoopConnectionError(t *testing.T) { default: if err == driver.ErrBadConn { // likely an EPIPE + } else if err == errCopyInClosed { + // ignore } else { t.Fatalf("unexpected error, got %+#v", err) } @@ -419,6 +435,24 @@ func TestCopyRespLoopConnectionError(t *testing.T) { _ = stmt.Close() } +// retry executes f in a backoff loop until it doesn't return an error. If this +// doesn't happen within duration, t.Fatal is called with the latest error. +func retry(t *testing.T, duration time.Duration, f func() error) { + start := time.Now() + next := time.Millisecond * 100 + for { + err := f() + if err == nil { + return + } + if time.Since(start) > duration { + t.Fatal(err) + } + time.Sleep(next) + next *= 2 + } +} + func BenchmarkCopyIn(b *testing.B) { db := openTestConn(b) defer db.Close() @@ -466,3 +500,11 @@ func BenchmarkCopyIn(b *testing.B) { b.Fatalf("expected %d items, not %d", b.N, num) } } + +var bigTableColumns = []string{"ABIOGENETICALLY", "ABORIGINALITIES", "ABSORBABILITIES", "ABSORBEFACIENTS", "ABSORPTIOMETERS", "ABSTRACTIONISMS", "ABSTRACTIONISTS", "ACANTHOCEPHALAN", "ACCEPTABILITIES", "ACCEPTINGNESSES", "ACCESSARINESSES", "ACCESSIBILITIES", "ACCESSORINESSES", "ACCIDENTALITIES", "ACCIDENTOLOGIES", "ACCLIMATISATION", "ACCLIMATIZATION", "ACCOMMODATINGLY", "ACCOMMODATIONAL", "ACCOMPLISHMENTS", "ACCOUNTABLENESS", "ACCOUNTANTSHIPS", "ACCULTURATIONAL", "ACETOPHENETIDIN", "ACETYLSALICYLIC", "ACHONDROPLASIAS", "ACHONDROPLASTIC", "ACHROMATICITIES", "ACHROMATISATION", "ACHROMATIZATION", "ACIDIMETRICALLY", "ACKNOWLEDGEABLE", "ACKNOWLEDGEABLY", "ACKNOWLEDGEMENT", "ACKNOWLEDGMENTS", "ACQUIRABILITIES", "ACQUISITIVENESS", "ACRIMONIOUSNESS", "ACROPARESTHESIA", "ACTINOBIOLOGIES", "ACTINOCHEMISTRY", "ACTINOTHERAPIES", "ADAPTABLENESSES", "ADDITIONALITIES", "ADENOCARCINOMAS", "ADENOHYPOPHYSES", "ADENOHYPOPHYSIS", "ADENOIDECTOMIES", "ADIATHERMANCIES", "ADJUSTABILITIES", "ADMINISTRATIONS", "ADMIRABLENESSES", "ADMISSIBILITIES", "ADRENALECTOMIES", "ADSORBABILITIES", "ADVENTUROUSNESS", "ADVERSARINESSES", "ADVISABLENESSES", "AERODYNAMICALLY", "AERODYNAMICISTS", "AEROELASTICIANS", "AEROHYDROPLANES", "AEROLITHOLOGIES", "AEROSOLISATIONS", "AEROSOLIZATIONS", "AFFECTABILITIES", "AFFECTIVENESSES", "AFFORDABILITIES", "AFFRANCHISEMENT", "AFTERSENSATIONS", "AGGLUTINABILITY", "AGGRANDISEMENTS", "AGGRANDIZEMENTS", "AGGREGATENESSES", "AGRANULOCYTOSES", "AGRANULOCYTOSIS", "AGREEABLENESSES", "AGRIBUSINESSMAN", "AGRIBUSINESSMEN", "AGRICULTURALIST", "AIRWORTHINESSES", "ALCOHOLISATIONS", "ALCOHOLIZATIONS", "ALCOHOLOMETRIES", "ALEXIPHARMAKONS", "ALGORITHMICALLY", "ALKALINISATIONS", "ALKALINIZATIONS", "ALLEGORICALNESS", "ALLEGORISATIONS", "ALLEGORIZATIONS", "ALLELOMORPHISMS", "ALLERGENICITIES", "ALLOTETRAPLOIDS", "ALLOTETRAPLOIDY", "ALLOTRIOMORPHIC", "ALLOWABLENESSES", "ALPHABETISATION", "ALPHABETIZATION", "ALTERNATIVENESS", "ALTITUDINARIANS", "ALUMINOSILICATE", "ALUMINOTHERMIES", "AMARYLLIDACEOUS", "AMBASSADORSHIPS", "AMBIDEXTERITIES", "AMBIGUOUSNESSES", "AMBISEXUALITIES", "AMBITIOUSNESSES", "AMINOPEPTIDASES", "AMINOPHENAZONES", "AMMONIFICATIONS", "AMORPHOUSNESSES", "AMPHIDIPLOIDIES", "AMPHITHEATRICAL", "ANACOLUTHICALLY", "ANACREONTICALLY", "ANAESTHESIOLOGY", "ANAESTHETICALLY", "ANAGRAMMATISING", "ANAGRAMMATIZING", "ANALOGOUSNESSES", "ANALYZABILITIES", "ANAMORPHOSCOPES", "ANCYLOSTOMIASES", "ANCYLOSTOMIASIS", "ANDROGYNOPHORES", "ANDROMEDOTOXINS", "ANDROMONOECIOUS", "ANDROMONOECISMS", "ANESTHETIZATION", "ANFRACTUOSITIES", "ANGUSTIROSTRATE", "ANIMATRONICALLY", "ANISOTROPICALLY", "ANKYLOSTOMIASES", "ANKYLOSTOMIASIS", "ANNIHILATIONISM", "ANOMALISTICALLY", "ANOMALOUSNESSES", "ANONYMOUSNESSES", "ANSWERABILITIES", "ANTAGONISATIONS", "ANTAGONIZATIONS", "ANTAPHRODISIACS", "ANTEPENULTIMATE", "ANTHROPOBIOLOGY", "ANTHROPOCENTRIC", "ANTHROPOGENESES", "ANTHROPOGENESIS", "ANTHROPOGENETIC", "ANTHROPOLATRIES", "ANTHROPOLOGICAL", "ANTHROPOLOGISTS", "ANTHROPOMETRIES", "ANTHROPOMETRIST", "ANTHROPOMORPHIC", "ANTHROPOPATHIES", "ANTHROPOPATHISM", "ANTHROPOPHAGIES", "ANTHROPOPHAGITE", "ANTHROPOPHAGOUS", "ANTHROPOPHOBIAS", "ANTHROPOPHOBICS", "ANTHROPOPHUISMS", "ANTHROPOPSYCHIC", "ANTHROPOSOPHIES", "ANTHROPOSOPHIST", "ANTIABORTIONIST", "ANTIALCOHOLISMS", "ANTIAPHRODISIAC", "ANTIARRHYTHMICS", "ANTICAPITALISMS", "ANTICAPITALISTS", "ANTICARCINOGENS", "ANTICHOLESTEROL", "ANTICHOLINERGIC", "ANTICHRISTIANLY", "ANTICLERICALISM", "ANTICLIMACTICAL", "ANTICOINCIDENCE", "ANTICOLONIALISM", "ANTICOLONIALIST", "ANTICOMPETITIVE", "ANTICONVULSANTS", "ANTICONVULSIVES", "ANTIDEPRESSANTS", "ANTIDERIVATIVES", "ANTIDEVELOPMENT", "ANTIEDUCATIONAL", "ANTIEGALITARIAN", "ANTIFASHIONABLE", "ANTIFEDERALISTS", "ANTIFERROMAGNET", "ANTIFORECLOSURE", "ANTIHELMINTHICS", "ANTIHISTAMINICS", "ANTILIBERALISMS", "ANTILIBERTARIAN", "ANTILOGARITHMIC", "ANTIMATERIALISM", "ANTIMATERIALIST", "ANTIMETABOLITES", "ANTIMILITARISMS", "ANTIMILITARISTS", "ANTIMONARCHICAL", "ANTIMONARCHISTS", "ANTIMONOPOLISTS", "ANTINATIONALIST", "ANTINUCLEARISTS", "ANTIODONTALGICS", "ANTIPERISTALSES", "ANTIPERISTALSIS", "ANTIPERISTALTIC", "ANTIPERSPIRANTS", "ANTIPHLOGISTICS", "ANTIPORNOGRAPHY", "ANTIPROGRESSIVE", "ANTIQUARIANISMS", "ANTIRADICALISMS", "ANTIRATIONALISM", "ANTIRATIONALIST", "ANTIRATIONALITY", "ANTIREPUBLICANS", "ANTIROMANTICISM", "ANTISEGREGATION", "ANTISENTIMENTAL", "ANTISEPARATISTS", "ANTISEPTICISING", "ANTISEPTICIZING", "ANTISEXUALITIES", "ANTISHOPLIFTING", "ANTISOCIALITIES", "ANTISPECULATION", "ANTISPECULATIVE", "ANTISYPHILITICS", "ANTITHEORETICAL", "ANTITHROMBOTICS", "ANTITRADITIONAL", "ANTITRANSPIRANT", "ANTITRINITARIAN", "ANTITUBERCULOUS", "ANTIVIVISECTION", "APHELIOTROPISMS", "APOCALYPTICALLY", "APOCALYPTICISMS", "APOLIPOPROTEINS", "APOLITICALITIES", "APOPHTHEGMATISE", "APOPHTHEGMATIST", "APOPHTHEGMATIZE", "APOTHEGMATISING", "APOTHEGMATIZING", "APPEALABILITIES", "APPEALINGNESSES", "APPENDICULARIAN", "APPLICABILITIES", "APPRENTICEHOODS", "APPRENTICEMENTS", "APPRENTICESHIPS", "APPROACHABILITY", "APPROPINQUATING", "APPROPINQUATION", "APPROPINQUITIES", "APPROPRIATENESS", "ARACHNOIDITISES", "ARBITRARINESSES", "ARBORICULTURIST", "ARCHAEBACTERIUM", "ARCHAEOBOTANIES", "ARCHAEOBOTANIST", "ARCHAEOMETRISTS", "ARCHAEOPTERYXES", "ARCHAEZOOLOGIES", "ARCHEOASTRONOMY", "ARCHEOBOTANISTS", "ARCHEOLOGICALLY", "ARCHEOMAGNETISM", "ARCHEOZOOLOGIES", "ARCHEOZOOLOGIST", "ARCHGENETHLIACS", "ARCHIDIACONATES", "ARCHIEPISCOPACY", "ARCHIEPISCOPATE", "ARCHITECTURALLY", "ARCHPRIESTHOODS", "ARCHPRIESTSHIPS", "ARGUMENTATIVELY", "ARIBOFLAVINOSES", "ARIBOFLAVINOSIS", "AROMATHERAPISTS", "ARRONDISSEMENTS", "ARTERIALISATION", "ARTERIALIZATION", "ARTERIOGRAPHIES", "ARTIFICIALISING", "ARTIFICIALITIES", "ARTIFICIALIZING", "ASCLEPIADACEOUS", "ASSENTIVENESSES"} + +func BenchmarkCopy(b *testing.B) { + for i := 0; i < b.N; i++ { + CopyIn("temp", bigTableColumns...) + } +} diff --git a/doc.go b/doc.go index 78c670b1d..b57184801 100644 --- a/doc.go +++ b/doc.go @@ -57,8 +57,6 @@ supported: * sslkey - Key file location. The file must contain PEM encoded data. * sslrootcert - The location of the root certificate file. The file must contain PEM encoded data. - * spn - Configures GSS (Kerberos) SPN. - * service - GSS (Kerberos) service name to use when constructing the SPN (default is `postgres`). Valid values for sslmode are: @@ -259,5 +257,12 @@ package: This package is in a separate module so that users who don't need Kerberos don't have to download unnecessary dependencies. +When imported, additional connection string parameters are supported: + + * krbsrvname - GSS (Kerberos) service name when constructing the + SPN (default is `postgres`). This will be combined with the host + to form the full SPN: `krbsrvname/host`. + * krbspn - GSS (Kerberos) SPN. This takes priority over + `krbsrvname` if present. */ package pq diff --git a/encode.go b/encode.go index c4dafe270..bffe6096a 100644 --- a/encode.go +++ b/encode.go @@ -200,11 +200,17 @@ func appendEscapedText(buf []byte, text string) []byte { func mustParse(f string, typ oid.Oid, s []byte) time.Time { str := string(s) - // check for a 30-minute-offset timezone - if (typ == oid.T_timestamptz || typ == oid.T_timetz) && - str[len(str)-3] == ':' { - f += ":00" + // Check for a minute and second offset in the timezone. + if typ == oid.T_timestamptz || typ == oid.T_timetz { + for i := 3; i <= 6; i += 3 { + if str[len(str)-i] == ':' { + f += ":00" + continue + } + break + } } + // Special case for 24:00 time. // Unfortunately, golang does not parse 24:00 as a proper time. // In this case, we want to try "round to the next day", to differentiate. @@ -416,7 +422,7 @@ func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, erro if remainderIdx < len(str) && str[remainderIdx] == '.' { fracStart := remainderIdx + 1 - fracOff := strings.IndexAny(str[fracStart:], "-+ ") + fracOff := strings.IndexAny(str[fracStart:], "-+Z ") if fracOff < 0 { fracOff = len(str) - fracStart } @@ -426,7 +432,7 @@ func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, erro remainderIdx += fracOff + 1 } if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') { - // time zone separator is always '-' or '+' (UTC is +00) + // time zone separator is always '-' or '+' or 'Z' (UTC is +00) var tzSign int switch c := str[tzStart]; c { case '-': @@ -448,7 +454,11 @@ func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, erro remainderIdx += 3 } tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec) + } else if tzStart < len(str) && str[tzStart] == 'Z' { + // time zone Z separator indicates UTC is +00 + remainderIdx += 1 } + var isoYear int if isBC { @@ -553,7 +563,7 @@ func parseBytea(s []byte) (result []byte, err error) { if len(s) < 4 { return nil, fmt.Errorf("invalid bytea sequence %v", s) } - r, err := strconv.ParseInt(string(s[1:4]), 8, 9) + r, err := strconv.ParseUint(string(s[1:4]), 8, 8) if err != nil { return nil, fmt.Errorf("could not parse bytea value: %s", err.Error()) } diff --git a/encode_test.go b/encode_test.go index 3406c315a..69f9ebb10 100644 --- a/encode_test.go +++ b/encode_test.go @@ -59,6 +59,8 @@ var timeTests = []struct { time.FixedZone("", -(7*60*60+42*60)))}, {"2001-02-03 04:05:06-07:30:09", time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9)))}, + {"2001-02-03 04:05:06+07:30:09", time.Date(2001, time.February, 3, 4, 5, 6, 0, + time.FixedZone("", +(7*60*60+30*60+9)))}, {"2001-02-03 04:05:06+07", time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", 7*60*60))}, {"0011-02-03 04:05:06 BC", time.Date(-10, time.February, 3, 4, 5, 6, 0, time.FixedZone("", 0))}, @@ -251,6 +253,8 @@ func TestTimeWithTimezone(t *testing.T) { }{ {"11:59:59+00:00", time.Date(0, 1, 1, 11, 59, 59, 0, time.UTC)}, {"11:59:59+04:00", time.Date(0, 1, 1, 11, 59, 59, 0, time.FixedZone("+04", 4*60*60))}, + {"11:59:59+04:01:02", time.Date(0, 1, 1, 11, 59, 59, 0, time.FixedZone("+04:01:02", 4*60*60+1*60+2))}, + {"11:59:59-04:01:02", time.Date(0, 1, 1, 11, 59, 59, 0, time.FixedZone("-04:01:02", -(4*60*60+1*60+2)))}, {"24:00+00", time.Date(0, 1, 2, 0, 0, 0, 0, time.UTC)}, {"24:00Z", time.Date(0, 1, 2, 0, 0, 0, 0, time.UTC)}, {"24:00-04:00", time.Date(0, 1, 2, 0, 0, 0, 0, time.FixedZone("-04", -4*60*60))}, @@ -824,6 +828,43 @@ func TestAppendEscapedTextExistingBuffer(t *testing.T) { } } +var formatAndParseTimestamp = []struct { + time time.Time + expected string +}{ + {time.Time{}, "0001-01-01 00:00:00Z"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "2001-02-03 04:05:06.123456789Z"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "2001-02-03 04:05:06.123456789+02:00"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "2001-02-03 04:05:06.123456789-06:00"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "2001-02-03 04:05:06-07:30:09"}, + + {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z"}, + {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00"}, + {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00"}, + + {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z BC"}, + {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00 BC"}, + {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00 BC"}, + + {time.Date(1, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09"}, + {time.Date(0, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09 BC"}, +} + +func TestFormatAndParseTimestamp(t *testing.T) { + for _, val := range formatAndParseTimestamp { + formattedTime := FormatTimestamp(val.time) + parsedTime, err := ParseTimestamp(nil, string(formattedTime)) + + if err != nil { + t.Errorf("invalid parsing, err: %v", err.Error()) + } + + if val.time.UTC() != parsedTime.UTC() { + t.Errorf("invalid parsing from formatted timestamp, got %v; expected %v", parsedTime.String(), val.time.String()) + } + } +} + func BenchmarkAppendEscapedText(b *testing.B) { longString := "" for i := 0; i < 100; i++ { diff --git a/error.go b/error.go index 3d66ba7c5..f67c5a5fa 100644 --- a/error.go +++ b/error.go @@ -402,6 +402,11 @@ func (err *Error) Fatal() bool { return err.Severity == Efatal } +// SQLState returns the SQLState of the error. +func (err *Error) SQLState() string { + return string(err.Code) +} + // Get implements the legacy PGError interface. New code should use the fields // of the Error struct directly. func (err *Error) Get(k byte) (v string) { @@ -444,7 +449,7 @@ func (err *Error) Get(k byte) (v string) { return "" } -func (err Error) Error() string { +func (err *Error) Error() string { return "pq: " + err.Message } @@ -484,7 +489,7 @@ func (cn *conn) errRecover(err *error) { case nil: // Do nothing case runtime.Error: - cn.bad = true + cn.err.set(driver.ErrBadConn) panic(v) case *Error: if v.Fatal() { @@ -493,23 +498,26 @@ func (cn *conn) errRecover(err *error) { *err = v } case *net.OpError: - cn.bad = true + cn.err.set(driver.ErrBadConn) *err = v + case *safeRetryError: + cn.err.set(driver.ErrBadConn) + *err = driver.ErrBadConn case error: - if v == io.EOF || v.(error).Error() == "remote error: handshake failure" { + if v == io.EOF || v.Error() == "remote error: handshake failure" { *err = driver.ErrBadConn } else { *err = v } default: - cn.bad = true + cn.err.set(driver.ErrBadConn) panic(fmt.Sprintf("unknown error: %#v", e)) } // Any time we return ErrBadConn, we need to remember it since *Tx doesn't // mark the connection bad in database/sql. if *err == driver.ErrBadConn { - cn.bad = true + cn.err.set(driver.ErrBadConn) } } diff --git a/go18_test.go b/go18_test.go index 72cd71fe9..6166db275 100644 --- a/go18_test.go +++ b/go18_test.go @@ -3,6 +3,8 @@ package pq import ( "context" "database/sql" + "database/sql/driver" + "errors" "runtime" "strings" "testing" @@ -74,6 +76,8 @@ func TestMultipleSimpleQuery(t *testing.T) { const contextRaceIterations = 100 +const cancelErrorCode ErrorCode = "57014" + func TestContextCancelExec(t *testing.T) { db := openTestConn(t) defer db.Close() @@ -86,7 +90,7 @@ func TestContextCancelExec(t *testing.T) { // Not canceled until after the exec has started. if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil { t.Fatal("expected error") - } else if err.Error() != "pq: canceling statement due to user request" { + } else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { t.Fatalf("unexpected error: %s", err) } @@ -124,7 +128,7 @@ func TestContextCancelQuery(t *testing.T) { // Not canceled until after the exec has started. if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil { t.Fatal("expected error") - } else if err.Error() != "pq: canceling statement due to user request" { + } else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { t.Fatalf("unexpected error: %s", err) } @@ -142,7 +146,7 @@ func TestContextCancelQuery(t *testing.T) { cancel() if err != nil { t.Fatal(err) - } else if err := rows.Close(); err != nil { + } else if err := rows.Close(); err != nil && err != driver.ErrBadConn && err != context.Canceled { t.Fatal(err) } }() @@ -176,13 +180,26 @@ func TestIssue617(t *testing.T) { } }() } - numGoroutineFinish := runtime.NumGoroutine() - // We use N/2 and not N because the GC and other actors may increase or - // decrease the number of goroutines. - if numGoroutineFinish-numGoroutineStart >= N/2 { - t.Errorf("goroutine leak detected, was %d, now %d", numGoroutineStart, numGoroutineFinish) + // Give time for goroutines to terminate + delayTime := time.Millisecond * 50 + waitTime := time.Second + iterations := int(waitTime / delayTime) + + var numGoroutineFinish int + for i := 0; i < iterations; i++ { + time.Sleep(delayTime) + + numGoroutineFinish = runtime.NumGoroutine() + + // We use N/2 and not N because the GC and other actors may increase or + // decrease the number of goroutines. + if numGoroutineFinish-numGoroutineStart < N/2 { + return + } } + + t.Errorf("goroutine leak detected, was %d, now %d", numGoroutineStart, numGoroutineFinish) } func TestContextCancelBegin(t *testing.T) { @@ -201,7 +218,7 @@ func TestContextCancelBegin(t *testing.T) { // Not canceled until after the exec has started. if _, err := tx.Exec("select pg_sleep(1)"); err == nil { t.Fatal("expected error") - } else if err.Error() != "pq: canceling statement due to user request" { + } else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { t.Fatalf("unexpected error: %s", err) } @@ -226,9 +243,9 @@ func TestContextCancelBegin(t *testing.T) { cancel() if err != nil { t.Fatal(err) - } else if err := tx.Rollback(); err != nil && - err.Error() != "pq: canceling statement due to user request" && - err != sql.ErrTxDone { + } else if err, pgErr := tx.Rollback(), (*Error)(nil); err != nil && + !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) && + err != sql.ErrTxDone && err != driver.ErrBadConn && err != context.Canceled { t.Fatal(err) } }() @@ -317,3 +334,19 @@ func TestTxOptions(t *testing.T) { t.Errorf("Expected error to mention isolation level, got %q", err) } } + +func TestErrorSQLState(t *testing.T) { + r := readBuf([]byte{67, 52, 48, 48, 48, 49, 0, 0}) // 40001 + err := parseError(&r) + var sqlErr errWithSQLState + if !errors.As(err, &sqlErr) { + t.Fatal("SQLState interface not satisfied") + } + if state := err.SQLState(); state != "40001" { + t.Fatalf("unexpected SQL state %v", state) + } +} + +type errWithSQLState interface { + SQLState() string +} diff --git a/go19_test.go b/go19_test.go index 25566d6f8..f810e4890 100644 --- a/go19_test.go +++ b/go19_test.go @@ -1,3 +1,4 @@ +//go:build go1.9 // +build go1.9 package pq diff --git a/issues_test.go b/issues_test.go index 3a330a0a9..e2c93925a 100644 --- a/issues_test.go +++ b/issues_test.go @@ -1,6 +1,12 @@ package pq -import "testing" +import ( + "context" + "database/sql" + "errors" + "testing" + "time" +) func TestIssue494(t *testing.T) { db := openTestConn(t) @@ -24,3 +30,129 @@ func TestIssue494(t *testing.T) { t.Fatal("expected error") } } + +func TestIssue1046(t *testing.T) { + ctxTimeout := time.Second * 2 + + db := openTestConn(t) + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + stmt, err := db.PrepareContext(ctx, `SELECT pg_sleep(10) AS id`) + if err != nil { + t.Fatal(err) + } + + var d []uint8 + err = stmt.QueryRowContext(ctx).Scan(&d) + dl, _ := ctx.Deadline() + since := time.Since(dl) + if since > ctxTimeout { + t.Logf("FAIL %s: query returned after context deadline: %v\n", t.Name(), since) + t.Fail() + } + if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { + t.Logf("ctx.Err(): [%T]%+v\n", ctx.Err(), ctx.Err()) + t.Logf("got err: [%T] %+v expected errCode: %v", err, err, cancelErrorCode) + t.Fail() + } +} + +func TestIssue1062(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + // Ensure that cancelling a QueryRowContext does not result in an ErrBadConn. + + for i := 0; i < 100; i++ { + ctx, cancel := context.WithCancel(context.Background()) + go cancel() + row := db.QueryRowContext(ctx, "select 1") + + var v int + err := row.Scan(&v) + if pgErr := (*Error)(nil); err != nil && + err != context.Canceled && + !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { + t.Fatalf("Scan resulted in unexpected error %v for canceled QueryRowContext at attempt %d", err, i+1) + } + } +} + +func connIsValid(t *testing.T, db *sql.DB) { + t.Helper() + + ctx := context.Background() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // the connection must be valid + err = conn.PingContext(ctx) + if err != nil { + t.Errorf("PingContext err=%#v", err) + } + // close must not return an error + err = conn.Close() + if err != nil { + t.Errorf("Close err=%#v", err) + } +} + +func TestQueryCancelRace(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + // cancel a query while executing on Postgres: must return the cancelled error code + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + row := db.QueryRowContext(ctx, "select pg_sleep(0.5)") + var pgSleepVoid string + err := row.Scan(&pgSleepVoid) + if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { + t.Fatalf("expected cancelled error; err=%#v", err) + } + + // get a connection: it must be a valid + connIsValid(t, db) +} + +// Test cancelling a scan after it is started. This broke with 1.10.4. +func TestQueryCancelledReused(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + // run a query that returns a lot of data + rows, err := db.QueryContext(ctx, "select generate_series(1, 10000)") + if err != nil { + t.Fatal(err) + } + + // scan the first value + if !rows.Next() { + t.Error("expected rows.Next() to return true") + } + var i int + err = rows.Scan(&i) + if err != nil { + t.Fatal(err) + } + if i != 1 { + t.Error(i) + } + + // cancel the context and close rows, ignoring errors + cancel() + rows.Close() + + // get a connection: it must be valid + connIsValid(t, db) +} diff --git a/notice.go b/notice.go index 01dd8c723..70ad122a7 100644 --- a/notice.go +++ b/notice.go @@ -1,3 +1,4 @@ +//go:build go1.10 // +build go1.10 package pq diff --git a/notice_example_test.go b/notice_example_test.go index 9392ad0bb..04bcb1d9d 100644 --- a/notice_example_test.go +++ b/notice_example_test.go @@ -1,3 +1,4 @@ +//go:build go1.10 // +build go1.10 package pq_test diff --git a/notice_test.go b/notice_test.go index 92212e778..e9da9af3a 100644 --- a/notice_test.go +++ b/notice_test.go @@ -1,3 +1,4 @@ +//go:build go1.10 // +build go1.10 package pq diff --git a/oid/gen.go b/oid/gen.go index 7c634cdc5..29a8de8ba 100644 --- a/oid/gen.go +++ b/oid/gen.go @@ -1,3 +1,4 @@ +//go:build ignore // +build ignore // Generate the table of OID values diff --git a/ssl.go b/ssl.go index d90208455..36b61ba45 100644 --- a/ssl.go +++ b/ssl.go @@ -8,6 +8,7 @@ import ( "os" "os/user" "path/filepath" + "strings" ) // ssl generates a function to upgrade a net.Conn based on the "sslmode" and @@ -50,6 +51,16 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) } + // Set Server Name Indication (SNI), if enabled by connection parameters. + // By default SNI is on, any value which is not starting with "1" disables + // SNI -- that is the same check vanilla libpq uses. + if sslsni := o["sslsni"]; sslsni == "" || strings.HasPrefix(sslsni, "1") { + // RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4 + // or IPv6). This check is coded already crypto.tls.hostnameInSNI, so + // just always set ServerName here and let crypto/tls do the filtering. + tlsConf.ServerName = o["host"] + } + err := sslClientCertificates(&tlsConf, o) if err != nil { return nil, err @@ -83,6 +94,16 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { // in the user's home directory. The configured files must exist and have // the correct permissions. func sslClientCertificates(tlsConf *tls.Config, o values) error { + sslinline := o["sslinline"] + if sslinline == "true" { + cert, err := tls.X509KeyPair([]byte(o["sslcert"]), []byte(o["sslkey"])) + if err != nil { + return err + } + tlsConf.Certificates = []tls.Certificate{cert} + return nil + } + // user.Current() might fail when cross-compiling. We have to ignore the // error and continue without home directory defaults, since we wouldn't // know from where to load them. @@ -137,9 +158,17 @@ func sslCertificateAuthority(tlsConf *tls.Config, o values) error { if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 { tlsConf.RootCAs = x509.NewCertPool() - cert, err := ioutil.ReadFile(sslrootcert) - if err != nil { - return err + sslinline := o["sslinline"] + + var cert []byte + if sslinline == "true" { + cert = []byte(sslrootcert) + } else { + var err error + cert, err = ioutil.ReadFile(sslrootcert) + if err != nil { + return err + } } if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { diff --git a/ssl_permissions.go b/ssl_permissions.go index 3b7c3a2a3..d587f102e 100644 --- a/ssl_permissions.go +++ b/ssl_permissions.go @@ -1,8 +1,30 @@ +//go:build !windows // +build !windows package pq -import "os" +import ( + "errors" + "os" + "syscall" +) + +const ( + rootUserID = uint32(0) + + // The maximum permissions that a private key file owned by a regular user + // is allowed to have. This translates to u=rw. + maxUserOwnedKeyPermissions os.FileMode = 0600 + + // The maximum permissions that a private key file owned by root is allowed + // to have. This translates to u=rw,g=r. + maxRootOwnedKeyPermissions os.FileMode = 0640 +) + +var ( + errSSLKeyHasUnacceptableUserPermissions = errors.New("permissions for files not owned by root should be u=rw (0600) or less") + errSSLKeyHasUnacceptableRootPermissions = errors.New("permissions for root owned files should be u=rw,g=r (0640) or less") +) // sslKeyPermissions checks the permissions on user-supplied ssl key files. // The key file should have very little access. @@ -13,8 +35,59 @@ func sslKeyPermissions(sslkey string) error { if err != nil { return err } - if info.Mode().Perm()&0077 != 0 { - return ErrSSLKeyHasWorldPermissions + + err = hasCorrectPermissions(info) + + // return ErrSSLKeyHasWorldPermissions for backwards compatability with + // existing code. + if err == errSSLKeyHasUnacceptableUserPermissions || err == errSSLKeyHasUnacceptableRootPermissions { + err = ErrSSLKeyHasWorldPermissions } - return nil + return err +} + +// hasCorrectPermissions checks the file info (and the unix-specific stat_t +// output) to verify that the permissions on the file are correct. +// +// If the file is owned by the same user the process is running as, +// the file should only have 0600 (u=rw). If the file is owned by root, +// and the group matches the group that the process is running in, the +// permissions cannot be more than 0640 (u=rw,g=r). The file should +// never have world permissions. +// +// Returns an error when the permission check fails. +func hasCorrectPermissions(info os.FileInfo) error { + // if file's permission matches 0600, allow access. + userPermissionMask := (os.FileMode(0777) ^ maxUserOwnedKeyPermissions) + + // regardless of if we're running as root or not, 0600 is acceptable, + // so we return if we match the regular user permission mask. + if info.Mode().Perm()&userPermissionMask == 0 { + return nil + } + + // We need to pull the Unix file information to get the file's owner. + // If we can't access it, there's some sort of operating system level error + // and we should fail rather than attempting to use faulty information. + sysInfo := info.Sys() + if sysInfo == nil { + return ErrSSLKeyUnknownOwnership + } + + unixStat, ok := sysInfo.(*syscall.Stat_t) + if !ok { + return ErrSSLKeyUnknownOwnership + } + + // if the file is owned by root, we allow 0640 (u=rw,g=r) to match what + // Postgres does. + if unixStat.Uid == rootUserID { + rootPermissionMask := (os.FileMode(0777) ^ maxRootOwnedKeyPermissions) + if info.Mode().Perm()&rootPermissionMask != 0 { + return errSSLKeyHasUnacceptableRootPermissions + } + return nil + } + + return errSSLKeyHasUnacceptableUserPermissions } diff --git a/ssl_permissions_test.go b/ssl_permissions_test.go new file mode 100644 index 000000000..b0bdca107 --- /dev/null +++ b/ssl_permissions_test.go @@ -0,0 +1,109 @@ +//go:build !windows +// +build !windows + +package pq + +import ( + "os" + "syscall" + "testing" + "time" +) + +type stat_t_wrapper struct { + stat syscall.Stat_t +} + +func (stat_t *stat_t_wrapper) Name() string { + return "pem.key" +} + +func (stat_t *stat_t_wrapper) Size() int64 { + return int64(100) +} + +func (stat_t *stat_t_wrapper) Mode() os.FileMode { + return os.FileMode(stat_t.stat.Mode) +} + +func (stat_t *stat_t_wrapper) ModTime() time.Time { + return time.Now() +} + +func (stat_t *stat_t_wrapper) IsDir() bool { + return true +} + +func (stat_t *stat_t_wrapper) Sys() interface{} { + return &stat_t.stat +} + +func TestHasCorrectRootGroupPermissions(t *testing.T) { + currentUID := uint32(os.Getuid()) + currentGID := uint32(os.Getgid()) + + testData := []struct { + expectedError error + stat syscall.Stat_t + }{ + { + expectedError: nil, + stat: syscall.Stat_t{ + Mode: 0600, + Uid: currentUID, + Gid: currentGID, + }, + }, + { + expectedError: nil, + stat: syscall.Stat_t{ + Mode: 0640, + Uid: 0, + Gid: currentGID, + }, + }, + { + expectedError: errSSLKeyHasUnacceptableUserPermissions, + stat: syscall.Stat_t{ + Mode: 0666, + Uid: currentUID, + Gid: currentGID, + }, + }, + { + expectedError: errSSLKeyHasUnacceptableRootPermissions, + stat: syscall.Stat_t{ + Mode: 0666, + Uid: 0, + Gid: currentGID, + }, + }, + } + + for _, test := range testData { + wrapper := &stat_t_wrapper{ + stat: test.stat, + } + + if test.expectedError != hasCorrectPermissions(wrapper) { + if test.expectedError == nil { + t.Errorf( + "file owned by %d:%d with %s should not have failed check with error \"%s\"", + test.stat.Uid, + test.stat.Gid, + wrapper.Mode(), + hasCorrectPermissions(wrapper), + ) + continue + } + t.Errorf( + "file owned by %d:%d with %s, expected \"%s\", got \"%s\"", + test.stat.Uid, + test.stat.Gid, + wrapper.Mode(), + test.expectedError, + hasCorrectPermissions(wrapper), + ) + } + } +} diff --git a/ssl_test.go b/ssl_test.go index 3eafbfd20..4c631b81b 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -3,12 +3,20 @@ package pq // This file contains SSL tests import ( + "bytes" _ "crypto/sha256" + "crypto/tls" "crypto/x509" "database/sql" + "errors" + "fmt" + "io" + "net" "os" "path/filepath" + "strings" "testing" + "time" ) func maybeSkipSSLTests(t *testing.T) { @@ -79,9 +87,14 @@ func TestSSLVerifyFull(t *testing.T) { if err == nil { t.Fatal("expected error") } - _, ok := err.(x509.UnknownAuthorityError) - if !ok { - t.Fatalf("expected x509.UnknownAuthorityError, got %#+v", err) + { + var x509err x509.UnknownAuthorityError + if !errors.As(err, &x509err) { + var x509err x509.HostnameError + if !errors.As(err, &x509err) { + t.Fatalf("expected x509.UnknownAuthorityError or x509.HostnameError, got %#+v", err) + } + } } rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") @@ -91,9 +104,11 @@ func TestSSLVerifyFull(t *testing.T) { if err == nil { t.Fatal("expected error") } - _, ok = err.(x509.HostnameError) - if !ok { - t.Fatalf("expected x509.HostnameError, got %#+v", err) + { + var x509err x509.HostnameError + if !errors.As(err, &x509err) { + t.Fatalf("expected x509.HostnameError, got %#+v", err) + } } // OK _, err = openSSLConn(t, rootCert+"host=postgres sslmode=verify-full user=pqgossltest") @@ -116,9 +131,11 @@ func TestSSLRequireWithRootCert(t *testing.T) { if err == nil { t.Fatal("expected error") } - _, ok := err.(x509.UnknownAuthorityError) - if !ok { - t.Fatalf("expected x509.UnknownAuthorityError, got %s, %#+v", err, err) + { + var x509err x509.UnknownAuthorityError + if !errors.As(err, &x509err) { + t.Fatalf("expected x509.UnknownAuthorityError, got %s, %#+v", err, err) + } } nonExistentCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "non_existent.crt") @@ -154,7 +171,8 @@ func TestSSLVerifyCA(t *testing.T) { // Not OK according to the system CA { _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest") - if _, ok := err.(x509.UnknownAuthorityError); !ok { + var x509err x509.UnknownAuthorityError + if !errors.As(err, &x509err) { t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err) } } @@ -162,7 +180,8 @@ func TestSSLVerifyCA(t *testing.T) { // Still not OK according to the system CA; empty sslrootcert is treated as unspecified. { _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest sslrootcert=''") - if _, ok := err.(x509.UnknownAuthorityError); !ok { + var x509err x509.UnknownAuthorityError + if !errors.As(err, &x509err) { t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err) } } @@ -233,7 +252,8 @@ func TestSSLClientCertificates(t *testing.T) { // Cert present, key not specified, should fail { _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert) - if _, ok := err.(*os.PathError); !ok { + var pathErr *os.PathError + if !errors.As(err, &pathErr) { t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err) } } @@ -241,7 +261,8 @@ func TestSSLClientCertificates(t *testing.T) { // Cert present, empty key specified, should fail { _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey=''") - if _, ok := err.(*os.PathError); !ok { + var pathErr *os.PathError + if !errors.As(err, &pathErr) { t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err) } } @@ -249,7 +270,8 @@ func TestSSLClientCertificates(t *testing.T) { // Cert present, non-existent key, should fail { _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey=/tmp/filedoesnotexist") - if _, ok := err.(*os.PathError); !ok { + var pathErr *os.PathError + if !errors.As(err, &pathErr) { t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err) } } @@ -277,3 +299,135 @@ func TestSSLClientCertificates(t *testing.T) { } } } + +// Check that clint sends SNI data when `sslsni` is not disabled +func TestSNISupport(t *testing.T) { + t.Parallel() + tests := []struct { + name string + conn_param string + hostname string + expected_sni string + }{ + { + name: "SNI is set by default", + conn_param: "", + hostname: "localhost", + expected_sni: "localhost", + }, + { + name: "SNI is passed when asked for", + conn_param: "sslsni=1", + hostname: "localhost", + expected_sni: "localhost", + }, + { + name: "SNI is not passed when disabled", + conn_param: "sslsni=0", + hostname: "localhost", + expected_sni: "", + }, + { + name: "SNI is not set for IPv4", + conn_param: "", + hostname: "127.0.0.1", + expected_sni: "", + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Start mock postgres server on OS-provided port + listener, err := net.Listen("tcp", "127.0.0.1:") + if err != nil { + t.Fatal(err) + } + serverErrChan := make(chan error, 1) + serverSNINameChan := make(chan string, 1) + go mockPostgresSSL(listener, serverErrChan, serverSNINameChan) + + defer listener.Close() + defer close(serverErrChan) + defer close(serverSNINameChan) + + // Try to establish a connection with the mock server. Connection will error out after TLS + // clientHello, but it is enough to catch SNI data on the server side + port := strings.Split(listener.Addr().String(), ":")[1] + connStr := fmt.Sprintf("sslmode=require host=%s port=%s %s", tt.hostname, port, tt.conn_param) + + // We are okay to skip this error as we are polling serverErrChan and we'll get an error + // or timeout from the server side in case of problems here. + db, _ := sql.Open("postgres", connStr) + _, _ = db.Exec("SELECT 1") + + // Check SNI data + select { + case sniHost := <-serverSNINameChan: + if sniHost != tt.expected_sni { + t.Fatalf("Expected SNI to be 'localhost', got '%+v' instead", sniHost) + } + case err = <-serverErrChan: + t.Fatalf("mock server failed with error: %+v", err) + case <-time.After(time.Second): + t.Fatal("exceeded connection timeout without erroring out") + } + }) + } +} + +// Make a postgres mock server to test TLS SNI +// +// Accepts postgres StartupMessage and handles TLS clientHello, then closes a connection. +// While reading clientHello catch passed SNI data and report it to nameChan. +func mockPostgresSSL(listener net.Listener, errChan chan error, nameChan chan string) { + var sniHost string + + conn, err := listener.Accept() + if err != nil { + errChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(time.Second)) + if err != nil { + errChan <- err + return + } + + // Receive StartupMessage with SSL Request + startupMessage := make([]byte, 8) + if _, err := io.ReadFull(conn, startupMessage); err != nil { + errChan <- err + return + } + // StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber + if !bytes.Equal(startupMessage, []byte{0, 0, 0, 0x8, 0x4, 0xd2, 0x16, 0x2f}) { + errChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage) + return + } + + // Respond with SSLOk + _, err = conn.Write([]byte("S")) + if err != nil { + errChan <- err + return + } + + // Set up TLS context to catch clientHello. It will always error out during handshake + // as no certificate is set. + srv := tls.Server(conn, &tls.Config{ + GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { + sniHost = argHello.ServerName + return nil, nil + }, + }) + defer srv.Close() + + // Do the TLS handshake ignoring errors + _ = srv.Handshake() + + nameChan <- sniHost +} diff --git a/ssl_windows.go b/ssl_windows.go index 5d2c763ce..73663c8f1 100644 --- a/ssl_windows.go +++ b/ssl_windows.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package pq diff --git a/url.go b/url.go index f4d8a7c20..aec6e95be 100644 --- a/url.go +++ b/url.go @@ -40,10 +40,10 @@ func ParseURL(url string) (string, error) { } var kvs []string - escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) + escaper := strings.NewReplacer(`'`, `\'`, `\`, `\\`) accrue := func(k, v string) { if v != "" { - kvs = append(kvs, k+"="+escaper.Replace(v)) + kvs = append(kvs, k+"='"+escaper.Replace(v)+"'") } } diff --git a/url_test.go b/url_test.go index 4ff0ce034..9b6345595 100644 --- a/url_test.go +++ b/url_test.go @@ -5,7 +5,7 @@ import ( ) func TestSimpleParseURL(t *testing.T) { - expected := "host=hostname.remote" + expected := "host='hostname.remote'" str, err := ParseURL("postgres://hostname.remote") if err != nil { t.Fatal(err) @@ -17,7 +17,7 @@ func TestSimpleParseURL(t *testing.T) { } func TestIPv6LoopbackParseURL(t *testing.T) { - expected := "host=::1 port=1234" + expected := "host='::1' port='1234'" str, err := ParseURL("postgres://[::1]:1234") if err != nil { t.Fatal(err) @@ -29,7 +29,7 @@ func TestIPv6LoopbackParseURL(t *testing.T) { } func TestFullParseURL(t *testing.T) { - expected := `dbname=database host=hostname.remote password=top\ secret port=1234 user=username` + expected := `dbname='database' host='hostname.remote' password='top secret' port='1234' user='username'` str, err := ParseURL("postgres://username:top%20secret@hostname.remote:1234/database") if err != nil { t.Fatal(err) diff --git a/user_other.go b/user_other.go new file mode 100644 index 000000000..3dae8f557 --- /dev/null +++ b/user_other.go @@ -0,0 +1,10 @@ +// Package pq is a pure Go Postgres driver for the database/sql package. + +//go:build js || android || hurd || zos +// +build js android hurd zos + +package pq + +func userCurrent() (string, error) { + return "", ErrCouldNotDetectUsername +} diff --git a/user_posix.go b/user_posix.go index a51019205..5f2d439bc 100644 --- a/user_posix.go +++ b/user_posix.go @@ -1,6 +1,7 @@ // Package pq is a pure Go Postgres driver for the database/sql package. -// +build aix darwin dragonfly freebsd linux nacl netbsd openbsd plan9 solaris rumprun +//go:build aix || darwin || dragonfly || freebsd || (linux && !android) || nacl || netbsd || openbsd || plan9 || solaris || rumprun || illumos +// +build aix darwin dragonfly freebsd linux,!android nacl netbsd openbsd plan9 solaris rumprun illumos package pq