Skip to content

Commit

Permalink
httpgrpc: correct handling of non-loggable errors
Browse files Browse the repository at this point in the history
Signed-off-by: Yuri Nikolic <durica.nikolic@grafana.com>
  • Loading branch information
duricanikolic committed Oct 27, 2023
1 parent 7d64494 commit 7057a5d
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 3 deletions.
27 changes: 26 additions & 1 deletion httpgrpc/httpgrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ package httpgrpc

import (
"context"
"errors"
"fmt"

"github.com/go-kit/log/level"
"google.golang.org/grpc/metadata"
grpcstatus "google.golang.org/grpc/status"

spb "github.com/gogo/googleapis/google/rpc"
"github.com/gogo/protobuf/types"
Expand Down Expand Up @@ -44,7 +46,7 @@ func ErrorFromHTTPResponse(resp *HTTPResponse) error {

// HTTPResponseFromError converts a grpc error into an HTTP response
func HTTPResponseFromError(err error) (*HTTPResponse, bool) {
s, ok := status.FromError(err)
s, ok := statusFromError(err)
if !ok {
return nil, false
}
Expand All @@ -63,6 +65,29 @@ func HTTPResponseFromError(err error) (*HTTPResponse, bool) {
return &resp, true
}

// statusFromError tries to cast the given error into status.Status.
// If the given error, or any error from its tree are a status.Status,
// that status.Status and the outcome true are returned.
// Otherwise, nil and the outcome false are returned.
// This implementation differs from status.FromError() because the
// latter checks only if the given error can be cast to status.Status,
// and doesn't check other errors in the given error's tree.
func statusFromError(err error) (*status.Status, bool) {
if err == nil {
return nil, false
}
type grpcStatus interface{ GRPCStatus() *grpcstatus.Status }
var gs grpcStatus
if errors.As(err, &gs) {
st := gs.GRPCStatus()
if st == nil {
return nil, false
}
return status.FromGRPCStatus(st), true
}
return nil, false
}

const (
MetadataMethod = "httpgrpc-method"
MetadataURL = "httpgrpc-url"
Expand Down
147 changes: 147 additions & 0 deletions httpgrpc/httpgrpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@ package httpgrpc

import (
"context"
"fmt"
"testing"

"github.com/gogo/status"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
grpcstatus "google.golang.org/grpc/status"
)

func TestAppendMessageSizeToOutgoingContext(t *testing.T) {
Expand All @@ -24,3 +28,146 @@ func TestAppendMessageSizeToOutgoingContext(t *testing.T) {
require.Equal(t, []string{"GET"}, md.Get(MetadataMethod))
require.Equal(t, []string{"/test"}, md.Get(MetadataURL))
}

func TestErrorf(t *testing.T) {
code := 400
errMsg := "this is an error"
expectedHTTPResponse := &HTTPResponse{
Code: int32(code),
Body: []byte(errMsg),
}
err := Errorf(code, errMsg)
stat, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, code, int(stat.Code()))
require.Equal(t, errMsg, stat.Message())
checkDetailAsHTTPResponse(t, expectedHTTPResponse, stat)
}

func TestErrorFromHTTPResponse(t *testing.T) {
var code int32 = 400
errMsg := "this is an error"
headers := []*Header{{Key: "X-Header", Values: []string{"a", "b", "c"}}}
resp := &HTTPResponse{
Code: code,
Headers: headers,
Body: []byte(errMsg),
}
err := ErrorFromHTTPResponse(resp)
require.Error(t, err)
stat, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, code, int32(stat.Code()))
require.Equal(t, errMsg, stat.Message())
checkDetailAsHTTPResponse(t, resp, stat)
}

func TestHTTPResponseFromError(t *testing.T) {
msgErr := "this is an error"
testCases := map[string]struct {
err error
isGRPCError bool
isHTTPGRCPError bool
expectedHTTPResponse *HTTPResponse
}{
"no error cannot be parsed to an HTTPResponse": {
err: nil,
},
"a random error cannot be parsed to an HTTPResponse": {
err: fmt.Errorf(msgErr),
},
"a gRPC error built by gogo/status cannot be parsed to an HTTPResponse": {
err: status.Error(codes.Internal, msgErr),
},
"a gRPC error built by grpc/status cannot be parsed to an HTTPResponse": {
err: grpcstatus.Error(codes.Internal, msgErr),
},
"a gRPC error built by httpgrpc can be parsed to an HTTPResponse": {
err: Errorf(400, msgErr),
expectedHTTPResponse: &HTTPResponse{Code: 400, Body: []byte(msgErr)},
},
"a wrapped gRPC error built by httpgrpc can be parsed to an HTTPResponse": {
err: fmt.Errorf("wrapped: %w", Errorf(400, msgErr)),
expectedHTTPResponse: &HTTPResponse{Code: 400, Body: []byte(msgErr)},
},
}
for testName, testData := range testCases {
t.Run(testName, func(t *testing.T) {
resp, ok := HTTPResponseFromError(testData.err)
if testData.expectedHTTPResponse == nil {
require.False(t, ok)
require.Nil(t, resp)
} else {
require.True(t, ok)

}
})
}
}

func TestStatusFromError(t *testing.T) {
msgErr := "this is an error"
testCases := map[string]struct {
err error
expectedStatus *status.Status
}{
"no error cannot be cast to status.Status": {
err: nil,
},
"a random error cannot be cast to status.Status": {
err: fmt.Errorf(msgErr),
},
"a wrapped error of a random error cannot be cast to status.Status": {
err: fmt.Errorf("wrapped: %w", fmt.Errorf(msgErr)),
},
"a gRPC error built by gogo/status can be cast to status.Status": {
err: status.Error(codes.Internal, msgErr),
expectedStatus: status.New(codes.Internal, msgErr),
},
"a wrapped error of a gRPC error built by gogo/status can be cast to status.Status": {
err: fmt.Errorf("wrapped: %w", status.Error(codes.Internal, msgErr)),
expectedStatus: status.New(codes.Internal, msgErr),
},
"a gRPC error built by grpc/status can be cast to status.Status": {
err: grpcstatus.Error(codes.Internal, msgErr),
expectedStatus: status.New(codes.Internal, msgErr),
},
"a wrapped error of a gRPC error built by grpc/status can be cast to status.Status": {
err: fmt.Errorf("wrapped: %w", grpcstatus.Error(codes.Internal, msgErr)),
expectedStatus: status.New(codes.Internal, msgErr),
},
"a gRPC error built by httpgrpc can be cast to status.Status": {
err: Errorf(400, msgErr),
expectedStatus: status.New(400, msgErr),
},
"a wrapped gRPC error built by httpgrpc can be cast to status.Status": {
err: fmt.Errorf("wrapped: %w", Errorf(400, msgErr)),
expectedStatus: status.New(400, msgErr),
},
}
for testName, testData := range testCases {
t.Run(testName, func(t *testing.T) {
stat, ok := statusFromError(testData.err)
if testData.expectedStatus == nil {
require.False(t, ok)
require.Nil(t, stat)
} else {
require.True(t, ok)
require.NotNil(t, stat)
require.Equal(t, testData.expectedStatus.Code(), stat.Code())
require.Equal(t, testData.expectedStatus.Message(), stat.Message())
}
})
}
}

func checkDetailAsHTTPResponse(t *testing.T, httpResponse *HTTPResponse, stat *status.Status) {
details := stat.Details()
require.Len(t, details, 1)
respDetails, ok := details[0].(*HTTPResponse)
require.True(t, ok)
require.NotNil(t, respDetails)
require.Equal(t, httpResponse.Code, respDetails.Code)
require.Equal(t, httpResponse.Headers, respDetails.Headers)
require.Equal(t, httpResponse.Body, respDetails.Body)
}
14 changes: 12 additions & 2 deletions httpgrpc/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,18 @@ func (s Server) Handle(ctx context.Context, r *httpgrpc.HTTPRequest) (*httpgrpc.

recorder := httptest.NewRecorder()
s.handler.ServeHTTP(recorder, req)
header := recorder.Header()
resp := &httpgrpc.HTTPResponse{
Code: int32(recorder.Code),
Headers: fromHeader(recorder.Header()),
Headers: fromHeader(header),
Body: recorder.Body.Bytes(),
}
if recorder.Code/100 == 5 {
return nil, httpgrpc.ErrorFromHTTPResponse(resp)
err := httpgrpc.ErrorFromHTTPResponse(resp)
if containsDoNotLogErrorKey(header) {
return nil, middleware.DoNotLogError{Err: err}
}
return nil, err
}
return resp, nil
}
Expand Down Expand Up @@ -234,3 +239,8 @@ func fromHeader(hs http.Header) []*httpgrpc.Header {
}
return result
}

func containsDoNotLogErrorKey(hs http.Header) bool {
_, ok := hs[middleware.DoNotLogErrorHeader]
return ok
}
34 changes: 34 additions & 0 deletions httpgrpc/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,37 @@ func TestTracePropagation(t *testing.T) {
assert.Equal(t, "world", recorder.Body.String())
assert.Equal(t, 200, recorder.Code)
}

func TestContainsDoNotLogErrorKey(t *testing.T) {
testCases := map[string]struct {
header http.Header
expectedOutcome bool
}{
"if headers do not contain X-DoNotLogError, return false": {
header: http.Header{
"X-First": []string{"a", "b", "c"},
"X-Second": []string{"1", "2"},
},
expectedOutcome: false,
},
"if headers contain X-DoNotLogError with a value, return true": {
header: http.Header{
"X-First": []string{"a", "b", "c"},
"X-DoNotLogError": []string{"true"},
},
expectedOutcome: true,
},
"if headers contain X-DoNotLogError without a value, return true": {
header: http.Header{
"X-First": []string{"a", "b", "c"},
"X-DoNotLogError": nil,
},
expectedOutcome: true,
},
}
for testName, testData := range testCases {
t.Run(testName, func(t *testing.T) {
require.Equal(t, testData.expectedOutcome, containsDoNotLogErrorKey(testData.header))
})
}
}
11 changes: 11 additions & 0 deletions middleware/grpc_logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package middleware
import (
"context"
"errors"
"net/http"
"time"

"github.com/go-kit/log"
Expand All @@ -24,11 +25,21 @@ const (
gRPC = "gRPC"
)

var (
DoNotLogErrorHeader = http.CanonicalHeaderKey("X-DoNotLogError")
)

// An error can implement ShouldLog() to control whether GRPCServerLog will log.
type OptionalLogging interface {
ShouldLog(ctx context.Context, duration time.Duration) bool
}

type DoNotLogError struct{ Err error }

func (i DoNotLogError) Error() string { return i.Err.Error() }
func (i DoNotLogError) Unwrap() error { return i.Err }
func (i DoNotLogError) ShouldLog(_ context.Context, _ time.Duration) bool { return false }

// GRPCServerLog logs grpc requests, errors, and latency.
type GRPCServerLog struct {
Log log.Logger
Expand Down

0 comments on commit 7057a5d

Please sign in to comment.