Skip to content

Commit

Permalink
Fix CallbackProcessor.Get() for removed or replaced same name callback (
Browse files Browse the repository at this point in the history
  • Loading branch information
zaneli authored and jinzhu committed Sep 12, 2019
1 parent b954854 commit d5cafb5
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 6 deletions.
10 changes: 7 additions & 3 deletions callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,15 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
// db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
for _, p := range cp.parent.processors {
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
return *p.processor
if p.name == callbackName && p.kind == cp.kind {
if p.remove {
callback = nil
} else {
callback = *p.processor
}
}
}
return nil
return
}

// getRIndex get right index from string slice
Expand Down
48 changes: 45 additions & 3 deletions callbacks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ package gorm_test

import (
"errors"

"github.com/jinzhu/gorm"

"reflect"
"testing"

"github.com/jinzhu/gorm"
)

func (s *Product) BeforeCreate() (err error) {
Expand Down Expand Up @@ -175,3 +174,46 @@ func TestCallbacksWithErrors(t *testing.T) {
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
}
}

func TestGetCallback(t *testing.T) {
scope := DB.NewScope(nil)

if DB.Callback().Create().Get("gorm:test_callback") != nil {
t.Errorf("`gorm:test_callback` should be nil")
}

DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
callback := DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
}

DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
}

DB.Callback().Create().Remove("gorm:test_callback")
if DB.Callback().Create().Get("gorm:test_callback") != nil {
t.Errorf("`gorm:test_callback` should be nil")
}

DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
}
}

0 comments on commit d5cafb5

Please sign in to comment.