Skip to content

Commit

Permalink
test: gtid and context ut (#759)
Browse files Browse the repository at this point in the history
  • Loading branch information
baerwang committed Sep 16, 2023
1 parent b40e2af commit df70525
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 27 deletions.
38 changes: 13 additions & 25 deletions pkg/runtime/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,7 @@ func WithHints(ctx context.Context, hints []*hint.Hint) context.Context {

// Tenant extracts the tenant.
func Tenant(ctx context.Context) string {
tenant, ok := ctx.Value(proto.ContextKeyTenant{}).(string)
if !ok {
return ""
}
return tenant
return isString(ctx, proto.ContextKeyTenant{})
}

// IsRead returns true if this is a read operation
Expand All @@ -99,40 +95,25 @@ func IsDirect(ctx context.Context) bool {

// SQL returns the original sql string.
func SQL(ctx context.Context) string {
if sql, ok := ctx.Value(proto.ContextKeySQL{}).(string); ok {
return sql
}
return ""
return isString(ctx, proto.ContextKeySQL{})
}

func Schema(ctx context.Context) string {
if schema, ok := ctx.Value(proto.ContextKeySchema{}).(string); ok {
return schema
}
return ""
return isString(ctx, proto.ContextKeySchema{})
}

func Version(ctx context.Context) string {
if schema, ok := ctx.Value(proto.ContextKeyServerVersion{}).(string); ok {
return schema
}
return ""
return isString(ctx, proto.ContextKeyServerVersion{})
}

// NodeLabel returns the label of node.
func NodeLabel(ctx context.Context) string {
if label, ok := ctx.Value(keyNodeLabel{}).(string); ok {
return label
}
return ""
return isString(ctx, keyNodeLabel{})
}

// TransactionID returns the transactions id
func TransactionID(ctx context.Context) string {
if label, ok := ctx.Value(keyTransactionID{}).(string); ok {
return label
}
return ""
return isString(ctx, keyTransactionID{})
}

// Hints extracts the hints.
Expand Down Expand Up @@ -162,3 +143,10 @@ func getFlag(ctx context.Context) cFlag {
}
return f
}

func isString(ctx context.Context, v any) string {
if data, ok := ctx.Value(v).(string); ok {
return data
}
return ""
}
64 changes: 64 additions & 0 deletions pkg/runtime/context/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package context

import (
"context"
"testing"
)

import (
"github.com/stretchr/testify/assert"
)

import (
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/proto/hint"
)

func TestContext(t *testing.T) {
ctx := context.Background()

id := "1024"
assert.Equal(t, TransactionID(WithTransactionID(ctx, id)), id)

label := "arana"
assert.Equal(t, NodeLabel(WithNodeLabel(ctx, label)), label)

res, err := hint.Parse("master")
assert.NoError(t, err)
hints := []*hint.Hint{res}
assert.Empty(t, Hints(ctx))
assert.Equal(t, Hints(WithHints(ctx, hints)), hints)

assert.Empty(t, TransientVariables(ctx))
value, err := proto.NewValue("arana")
assert.NoError(t, err)
variables := map[string]proto.Value{"arana": value}
variablesCtx := context.WithValue(ctx, proto.ContextKeyTransientVariables{}, variables)
assert.Equal(t, TransientVariables(variablesCtx), variables)

assert.True(t, IsDirect(WithDirect(ctx)))
assert.True(t, IsWrite(WithWrite(ctx)))
assert.True(t, IsRead(WithRead(ctx)))

assert.Empty(t, Tenant(ctx))
assert.Empty(t, SQL(ctx))
assert.Empty(t, Schema(ctx))
assert.Empty(t, Version(ctx))
}
4 changes: 2 additions & 2 deletions pkg/runtime/gtid/gtid.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package gtid

import (
"fmt"
"strconv"
"sync"
)

Expand Down Expand Up @@ -58,5 +58,5 @@ func NewID() ID {

// String ID to string
func (i ID) String() string {
return fmt.Sprintf("%s-%d", i.NodeID, i.Seq)
return i.NodeID + "-" + strconv.FormatInt(i.Seq, 10)
}
33 changes: 33 additions & 0 deletions pkg/runtime/gtid/gtid_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package gtid

import (
"testing"
)

import (
"github.com/stretchr/testify/assert"
)

func TestGtID(t *testing.T) {
id := NewID()
assert.NotEmpty(t, id.NodeID)
assert.NotEmpty(t, id.Seq)
assert.NotEmpty(t, id.String())
}

0 comments on commit df70525

Please sign in to comment.