diff --git a/contrib/screener-api/client/client.go b/contrib/screener-api/client/client.go index 48e438175f..c5f0e70811 100644 --- a/contrib/screener-api/client/client.go +++ b/contrib/screener-api/client/client.go @@ -6,6 +6,7 @@ import ( "crypto/hmac" "crypto/sha256" "encoding/hex" + "encoding/json" "fmt" "strings" "time" @@ -17,11 +18,6 @@ import ( "github.com/synapsecns/sanguine/core/metrics" ) -var ( - // BlacklistEndpoint is the endpoint for blacklisting an address. - BlacklistEndpoint = "/api/data/sync/" -) - // ScreenerClient is an interface for the Screener API. type ScreenerClient interface { ScreenAddress(ctx context.Context, ruleset, address string) (blocked bool, err error) @@ -69,7 +65,7 @@ func (c clientImpl) ScreenAddress(ctx context.Context, ruleset, address string) // BlackListBody is the json payload that represents a blacklisted address. type BlackListBody struct { - TypeReq string `json:"typereq"` + Type string `json:"type"` ID string `json:"id"` Data string `json:"data"` Address string `json:"address"` @@ -90,19 +86,27 @@ func (c clientImpl) BlacklistAddress(ctx context.Context, appsecret string, appi timestamp := fmt.Sprintf("%d", time.Now().Unix()) queryString := "" - signature := GenerateSignature(appsecret, appid, timestamp, nonce, queryString, body) + bodyBz, err := json.Marshal(body) + if err != nil { + return "", fmt.Errorf("error marshaling body: %w", err) + } + + message := fmt.Sprintf("%s%s%s%s%s%s%s", + appid, timestamp, nonce, "POST", "/api/data/sync/", queryString, string(bodyBz)) + + signature := GenerateSignature(appsecret, message) resp, err := c.rClient.R(). SetContext(ctx). SetHeader("Content-Type", "application/json"). - SetHeader("appid", appid). - SetHeader("timestamp", timestamp). - SetHeader("nonce", nonce). - SetHeader("queryString", queryString). - SetHeader("signature", signature). - SetResult(&blacklistRes). + SetHeader("AppID", appid). + SetHeader("Timestamp", timestamp). + SetHeader("Nonce", nonce). + SetHeader("QueryString", queryString). + SetHeader("Signature", signature). SetBody(body). - Post(BlacklistEndpoint) + SetResult(&blacklistRes). + Post("/api/data/sync/") if err != nil { return resp.Status(), fmt.Errorf("error from server: %s: %w", resp.String(), err) @@ -115,6 +119,17 @@ func (c clientImpl) BlacklistAddress(ctx context.Context, appsecret string, appi return blacklistRes.Status, nil } +// GenerateSignature generates a signature for the request. +func GenerateSignature( + secret, + message string, +) string { + key := []byte(secret) + h := hmac.New(sha256.New, key) + h.Write([]byte(message)) + return hex.EncodeToString(h.Sum(nil)) +} + // NewNoOpClient creates a new no-op client for the Screener API. // it returns false for every address. func NewNoOpClient() (ScreenerClient, error) { @@ -131,32 +146,4 @@ func (n noOpClient) BlacklistAddress(_ context.Context, _ string, _ string, _ Bl return "", nil } -// GenerateSignature generates a signature for the request. -func GenerateSignature(secret string, - appid string, - timestamp string, - nonce string, - queryString string, - body BlackListBody, -) string { - key := []byte(secret) - - // Concatenate the body. - message := fmt.Sprintf( - "%s%s%s%s%s%s%s", - appid, - timestamp, - nonce, - "POST", - BlacklistEndpoint, - queryString, - body, - ) - - h := hmac.New(sha256.New, key) - h.Write([]byte(message)) - - return strings.ToLower(hex.EncodeToString(h.Sum(nil))) -} - var _ ScreenerClient = noOpClient{} diff --git a/contrib/screener-api/db/db_test.go b/contrib/screener-api/db/db_test.go index 05eaabcaa3..bf93dfef85 100644 --- a/contrib/screener-api/db/db_test.go +++ b/contrib/screener-api/db/db_test.go @@ -72,7 +72,7 @@ func (d *DBSuite) TestBlacklist() { testAddress := gofakeit.BitcoinAddress() blacklistBody := db.BlacklistedAddress{ - TypeReq: "create", + Type: "create", ID: "testId", Address: testAddress, Network: "bitcoin", @@ -88,7 +88,7 @@ func (d *DBSuite) TestBlacklist() { d.Require().NotNil(blacklistedAddress) // update the address - blacklistBody.TypeReq = "update" + blacklistBody.Type = "update" blacklistBody.Remark = "testRemarkUpdated" err = testDB.UpdateBlacklistedAddress(d.GetTestContext(), blacklistBody.ID, blacklistBody) d.Require().NoError(err) diff --git a/contrib/screener-api/db/models.go b/contrib/screener-api/db/models.go index bcc89e3f4c..383bce26fd 100644 --- a/contrib/screener-api/db/models.go +++ b/contrib/screener-api/db/models.go @@ -20,7 +20,7 @@ type BlacklistedAddress struct { CreatedAt time.Time `json:"createdAt"` UpdatedAt time.Time `json:"updatedAt"` - TypeReq string `gorm:"column:typereq" json:"typereq"` + Type string `gorm:"column:type" json:"type"` ID string `gorm:"column:id;primary_key" json:"id"` Data string `gorm:"column:data" json:"data"` Address string `gorm:"column:address" json:"address"` diff --git a/contrib/screener-api/db/sql/base/base.go b/contrib/screener-api/db/sql/base/base.go index 91e9fca664..747e9b92bc 100644 --- a/contrib/screener-api/db/sql/base/base.go +++ b/contrib/screener-api/db/sql/base/base.go @@ -60,7 +60,7 @@ func (s *Store) PutBlacklistedAddress(ctx context.Context, body db.BlacklistedAd {Name: idName}, }, DoUpdates: clause.AssignmentColumns([]string{ - idName, typeReqName, dataName, addressName, networkName, tagName, remarkName}, + idName, typeName, dataName, addressName, networkName, tagName, remarkName}, ), }).Create(&body) if dbTx.Error != nil { diff --git a/contrib/screener-api/db/sql/base/namer.go b/contrib/screener-api/db/sql/base/namer.go index 7d305d7052..302873c27d 100644 --- a/contrib/screener-api/db/sql/base/namer.go +++ b/contrib/screener-api/db/sql/base/namer.go @@ -10,7 +10,7 @@ func init() { addressName = namer.GetConsistentName("Address") indicatorName = namer.GetConsistentName("Indicators") - typeReqName = namer.GetConsistentName("TypeReq") + typeName = namer.GetConsistentName("Type") idName = namer.GetConsistentName("ID") dataName = namer.GetConsistentName("Data") networkName = namer.GetConsistentName("Network") @@ -22,7 +22,7 @@ var ( addressName string indicatorName string - typeReqName string + typeName string idName string dataName string networkName string diff --git a/contrib/screener-api/docs/docs.go b/contrib/screener-api/docs/docs.go index ec00e7b428..752cecc4c9 100644 --- a/contrib/screener-api/docs/docs.go +++ b/contrib/screener-api/docs/docs.go @@ -160,7 +160,7 @@ const docTemplate = `{ "tag": { "type": "string" }, - "typereq": { + "type": { "type": "string" }, "updatedAt": { diff --git a/contrib/screener-api/docs/swagger.json b/contrib/screener-api/docs/swagger.json index ef9bf8b0e5..5132742779 100644 --- a/contrib/screener-api/docs/swagger.json +++ b/contrib/screener-api/docs/swagger.json @@ -149,7 +149,7 @@ "tag": { "type": "string" }, - "typereq": { + "type": { "type": "string" }, "updatedAt": { diff --git a/contrib/screener-api/docs/swagger.yaml b/contrib/screener-api/docs/swagger.yaml index aa75918ffc..9d108f9ce2 100644 --- a/contrib/screener-api/docs/swagger.yaml +++ b/contrib/screener-api/docs/swagger.yaml @@ -15,7 +15,7 @@ definitions: type: string tag: type: string - typereq: + type: type: string updatedAt: type: string diff --git a/contrib/screener-api/screener/screener.go b/contrib/screener-api/screener/screener.go index 33b2bba172..e262066b96 100644 --- a/contrib/screener-api/screener/screener.go +++ b/contrib/screener-api/screener/screener.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "strings" "sync" @@ -94,7 +95,6 @@ func NewScreener(ctx context.Context, cfg config.Config, metricHandler metrics.H screener.router.Handle(http.MethodGet, "/:ruleset/address/:address", screener.screenAddress) screener.router.Handle(http.MethodPost, "/api/data/sync", screener.authMiddleware(cfg), screener.blacklistAddress) - screener.router.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerfiles.Handler)) return &screener, nil @@ -159,7 +159,7 @@ func (s *screenerImpl) blacklistAddress(c *gin.Context) { return } - span.SetAttributes(attribute.String("type", blacklistBody.TypeReq)) + span.SetAttributes(attribute.String("type", blacklistBody.Type)) span.SetAttributes(attribute.String("id", blacklistBody.ID)) span.SetAttributes(attribute.String("data", blacklistBody.Data)) span.SetAttributes(attribute.String("network", blacklistBody.Network)) @@ -168,7 +168,7 @@ func (s *screenerImpl) blacklistAddress(c *gin.Context) { span.SetAttributes(attribute.String("address", blacklistBody.Address)) blacklistedAddress := db.BlacklistedAddress{ - TypeReq: blacklistBody.TypeReq, + Type: blacklistBody.Type, ID: blacklistBody.ID, Data: blacklistBody.Data, Network: blacklistBody.Network, @@ -177,7 +177,7 @@ func (s *screenerImpl) blacklistAddress(c *gin.Context) { Address: strings.ToLower(blacklistBody.Address), } - switch blacklistBody.TypeReq { + switch blacklistBody.Type { case "create": if err := s.db.PutBlacklistedAddress(ctx, blacklistedAddress); err != nil { span.AddEvent("error", trace.WithAttributes(attribute.String("error", err.Error()))) @@ -185,28 +185,34 @@ func (s *screenerImpl) blacklistAddress(c *gin.Context) { return } + span.AddEvent("blacklistedAddress", trace.WithAttributes(attribute.String("address", blacklistBody.Address))) c.JSON(http.StatusOK, gin.H{"status": "success"}) return case "update": - if err := s.db.UpdateBlacklistedAddress(c, blacklistedAddress.ID, blacklistedAddress); err != nil { + if err := s.db.UpdateBlacklistedAddress(ctx, blacklistedAddress.ID, blacklistedAddress); err != nil { + span.AddEvent("error", trace.WithAttributes(attribute.String("error", err.Error()))) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + span.AddEvent("blacklistedAddress", trace.WithAttributes(attribute.String("address", blacklistBody.Address))) c.JSON(http.StatusOK, gin.H{"status": "success"}) return case "delete": - if err := s.db.DeleteBlacklistedAddress(c, blacklistedAddress.Address); err != nil { + if err := s.db.DeleteBlacklistedAddress(ctx, blacklistedAddress.Address); err != nil { + span.AddEvent("error", trace.WithAttributes(attribute.String("error", err.Error()))) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + span.AddEvent("blacklistedAddress", trace.WithAttributes(attribute.String("address", blacklistBody.Address))) c.JSON(http.StatusOK, gin.H{"status": "success"}) return default: + span.AddEvent("error", trace.WithAttributes(attribute.String("error", err.Error()))) c.JSON(http.StatusBadRequest, gin.H{"error": "invalid type"}) return } @@ -216,34 +222,45 @@ func (s *screenerImpl) blacklistAddress(c *gin.Context) { // compare it with the signature provided. If they match, the request is allowed to pass through. func (s *screenerImpl) authMiddleware(cfg config.Config) gin.HandlerFunc { return func(c *gin.Context) { - var blacklistBody client.BlackListBody + _, span := s.metrics.Tracer().Start(c.Request.Context(), "authMiddleware") + defer span.End() - if err := c.ShouldBindBodyWith(&blacklistBody, binding.JSON); err != nil { - c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) - c.Abort() - return - } + appID := c.Request.Header.Get("AppID") + timestamp := c.Request.Header.Get("Timestamp") + nonce := c.Request.Header.Get("Nonce") + signature := c.Request.Header.Get("Signature") + queryString := c.Request.Header.Get("QueryString") + bodyBytes, _ := io.ReadAll(c.Request.Body) + bodyStr := string(bodyBytes) - appid := cfg.AppID - appsecret := cfg.AppSecret + c.Request.Body = io.NopCloser(strings.NewReader(bodyStr)) - nonce := c.GetHeader("nonce") - timestamp := c.GetHeader("timestamp") - queryString := c.GetHeader("queryString") - if nonce == "" || timestamp == "" || appid == "" { - c.JSON(http.StatusConflict, gin.H{"error": "missing headers"}) - c.Abort() - return - } + span.SetAttributes( + attribute.String("appId", appID), + attribute.String("timestamp", timestamp), + attribute.String("nonce", nonce), + attribute.String("signature", signature), + attribute.String("queryString", queryString), + attribute.String("bodyString", bodyStr), + ) + + message := fmt.Sprintf("%s%s%s%s%s%s%s", + appID, timestamp, nonce, "POST", "/api/data/sync/", queryString, bodyStr) - // reconstruct signature - expected := client.GenerateSignature(appsecret, appid, timestamp, nonce, queryString, blacklistBody) + span.AddEvent("message", trace.WithAttributes(attribute.String("message", message))) - if c.GetHeader("Signature") != expected { - c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) + expectedSignature := client.GenerateSignature(cfg.AppSecret, message) + + span.AddEvent("generated_signature", trace.WithAttributes(attribute.String("expectedSignature", expectedSignature))) + + if expectedSignature != signature { + span.AddEvent("error", trace.WithAttributes(attribute.String("error", "Invalid signature"))) + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid signature"}) c.Abort() return } + + span.AddEvent("signature_validated") c.Next() } } diff --git a/contrib/screener-api/screener/suite_test.go b/contrib/screener-api/screener/suite_test.go index aa759d400e..7be0ee3bc1 100644 --- a/contrib/screener-api/screener/suite_test.go +++ b/contrib/screener-api/screener/suite_test.go @@ -80,7 +80,7 @@ func (s *ScreenerSuite) TestScreener() { s.T().Setenv("TRM_URL", "") cfg := config.Config{ - AppSecret: "appsecret", + AppSecret: "secret", AppID: "appid", TRMKey: "", Rulesets: map[string]config.RulesetConfig{ @@ -155,7 +155,7 @@ func (s *ScreenerSuite) TestScreener() { // now test crud screener blacklistBody := client.BlackListBody{ - TypeReq: "create", + Type: "create", ID: "1", Data: "{\"test\":\"data\"}", Address: "0x123", @@ -166,27 +166,31 @@ func (s *ScreenerSuite) TestScreener() { // post to the blacklist status, err := apiClient.BlacklistAddress(s.GetTestContext(), cfg.AppSecret, cfg.AppID, blacklistBody) + fmt.Println(status) Equal(s.T(), "success", status) Nil(s.T(), err) // update an address on the blacklist - blacklistBody.TypeReq = "update" + blacklistBody.Type = "update" blacklistBody.Remark = "new remark" status, err = apiClient.BlacklistAddress(s.GetTestContext(), cfg.AppSecret, cfg.AppID, blacklistBody) + fmt.Println(status) Equal(s.T(), "success", status) Nil(s.T(), err) // delete the address on the blacklist - blacklistBody.TypeReq = "delete" + blacklistBody.Type = "delete" blacklistBody.ID = "1" status, err = apiClient.BlacklistAddress(s.GetTestContext(), cfg.AppSecret, cfg.AppID, blacklistBody) + fmt.Println(status) Equal(s.T(), "success", status) Nil(s.T(), err) // unauthorized status, err = apiClient.BlacklistAddress(s.GetTestContext(), "bad", cfg.AppID, blacklistBody) + fmt.Println(status) NotEqual(s.T(), "success", status) NotNil(s.T(), err) }