From a75919c71fa2e1978acbd3ae71169f94c5eadd36 Mon Sep 17 00:00:00 2001 From: David Christofas Date: Wed, 29 Apr 2020 15:03:10 +0200 Subject: [PATCH] implement ocs to http status code mapping --- internal/http/services/owncloud/ocs/ocs.go | 15 +- .../owncloud/ocs/response/response.go | 129 +++++++++++++----- 2 files changed, 105 insertions(+), 39 deletions(-) diff --git a/internal/http/services/owncloud/ocs/ocs.go b/internal/http/services/owncloud/ocs/ocs.go index 85c8d6a7fe..0481da58ec 100644 --- a/internal/http/services/owncloud/ocs/ocs.go +++ b/internal/http/services/owncloud/ocs/ocs.go @@ -30,6 +30,11 @@ import ( "github.com/mitchellh/mapstructure" ) +const ( + apiV1 = "v1.php" + apiV2 = "v2.php" +) + func init() { global.Register("ocs", New) } @@ -86,13 +91,11 @@ func (s *svc) Handler() http.Handler { log.Debug().Str("head", head).Str("tail", r.URL.Path).Msg("ocs routing") - // TODO v2 uses a status code mapper - // see https://github.com/owncloud/core/commit/bacf1603ffd53b7a5f73854d1d0ceb4ae545ce9f#diff-262cbf0df26b45bad0cf00d947345d9c - if head == "v1.php" || head == "v2.php" { - s.V1Handler.Handler().ServeHTTP(w, r) + if !(head == apiV1 || head == apiV2) { + response.WriteOCSError(w, r, response.MetaNotFound.StatusCode, "Not found", nil) return } - - response.WriteOCSError(w, r, response.MetaNotFound.StatusCode, "Not found", nil) + ctx := response.WithAPIVersion(r.Context(), head) + s.V1Handler.Handler().ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/internal/http/services/owncloud/ocs/response/response.go b/internal/http/services/owncloud/ocs/response/response.go index d89f916cf1..013dc9c940 100644 --- a/internal/http/services/owncloud/ocs/response/response.go +++ b/internal/http/services/owncloud/ocs/response/response.go @@ -19,6 +19,8 @@ package response import ( + "bytes" + "context" "encoding/json" "encoding/xml" "net/http" @@ -28,6 +30,16 @@ import ( "github.com/cs3org/reva/pkg/appctx" ) +type key int + +const ( + apiVersionKey key = 1 +) + +var ( + defaultStatusCodeMapper = OcsV2StatusCodes +) + // Response is the top level response structure type Response struct { OCS *Payload `json:"ocs"` @@ -137,54 +149,28 @@ func WriteOCSData(w http.ResponseWriter, r *http.Request, m *Meta, d interface{} // WriteOCSResponse handles writing ocs responses in json and xml func WriteOCSResponse(w http.ResponseWriter, r *http.Request, res *Response, err error) { - var encoded []byte - if err != nil { appctx.GetLogger(r.Context()).Error().Err(err).Msg(res.OCS.Meta.Message) } + var encoder func(*Payload) ([]byte, error) if r.URL.Query().Get("format") == "json" { w.Header().Set("Content-Type", "application/json") - encoded, err = json.Marshal(res) + encoder = encodeJSON } else { w.Header().Set("Content-Type", "application/xml") - _, err = w.Write([]byte(xml.Header)) - if err != nil { - appctx.GetLogger(r.Context()).Error().Err(err).Msg("error writing xml header") - w.WriteHeader(http.StatusInternalServerError) - return - } - encoded, err = xml.Marshal(res.OCS) + encoder = encodeXML } + encoded, err := encoder(res.OCS) if err != nil { appctx.GetLogger(r.Context()).Error().Err(err).Msg("error encoding ocs response") w.WriteHeader(http.StatusInternalServerError) return } - // TODO map error for v2 only? - // see https://github.com/owncloud/core/commit/bacf1603ffd53b7a5f73854d1d0ceb4ae545ce9f#diff-262cbf0df26b45bad0cf00d947345d9c - switch res.OCS.Meta.StatusCode { - case MetaNotFound.StatusCode: - w.WriteHeader(http.StatusNotFound) - case MetaServerError.StatusCode: - w.WriteHeader(http.StatusInternalServerError) - case MetaUnknownError.StatusCode: - w.WriteHeader(http.StatusInternalServerError) - case MetaUnauthorized.StatusCode: - w.WriteHeader(http.StatusUnauthorized) - case 100: - w.WriteHeader(http.StatusOK) - case 104: - w.WriteHeader(http.StatusForbidden) - default: - // any 2xx, 4xx and 5xx will be used as is - if res.OCS.Meta.StatusCode >= 200 && res.OCS.Meta.StatusCode < 600 { - w.WriteHeader(res.OCS.Meta.StatusCode) - } else { - w.WriteHeader(http.StatusBadRequest) - } - } + m := statusCodeMapper(r.Context()) + statusCode := m(res.OCS.Meta) + w.WriteHeader(statusCode) _, err = w.Write(encoded) if err != nil { @@ -203,3 +189,80 @@ func UserIDToString(userID *user.UserId) string { } return userID.OpaqueId + "@" + userID.Idp } + +func encodeXML(payload *Payload) ([]byte, error) { + marshalled, err := xml.Marshal(payload) + if err != nil { + return nil, err + } + b := new(bytes.Buffer) + b.Write([]byte(xml.Header)) + b.Write(marshalled) + return b.Bytes(), nil +} + +func encodeJSON(payload *Payload) ([]byte, error) { + return json.Marshal(payload) +} + +// OcsV1StatusCodes returns the http status codes for the OCS API v1. +func OcsV1StatusCodes(meta *Meta) int { + return http.StatusOK +} + +// OcsV2StatusCodes maps the OCS codes to http status codes for the ocs API v2. +func OcsV2StatusCodes(meta *Meta) int { + sc := meta.StatusCode + switch sc { + case MetaNotFound.StatusCode: + return http.StatusNotFound + case MetaUnknownError.StatusCode: + fallthrough + case MetaServerError.StatusCode: + return http.StatusInternalServerError + case MetaUnauthorized.StatusCode: + return http.StatusUnauthorized + case 100: + return http.StatusOK + } + // any 2xx, 4xx and 5xx will be used as is + if sc >= 200 && sc < 600 { + return sc + } + + // any error codes > 100 are treated as client errors + if sc > 100 && sc < 200 { + return http.StatusBadRequest + } + + // TODO change this status code? + return http.StatusOK +} + +// WithAPIVersion puts the api version in the context. +func WithAPIVersion(parent context.Context, version string) context.Context { + return context.WithValue(parent, apiVersionKey, version) +} + +// APIVersion retrieves the api version from the context. +func APIVersion(ctx context.Context) string { + value := ctx.Value(apiVersionKey) + if value != nil { + return value.(string) + } + return "" +} + +func statusCodeMapper(ctx context.Context) func(*Meta) int { + version := APIVersion(ctx) + var mapper func(*Meta) int + switch version { + case "v1.php": + mapper = OcsV1StatusCodes + case "v2.php": + mapper = OcsV2StatusCodes + default: + mapper = defaultStatusCodeMapper + } + return mapper +}