Skip to content

Commit

Permalink
Merge branch 'jinzhu-v1.9.14'
Browse files Browse the repository at this point in the history
  • Loading branch information
lun committed Jul 6, 2020
2 parents 77551de + 65e997f commit 27d8d83
Show file tree
Hide file tree
Showing 29 changed files with 846 additions and 272 deletions.
40 changes: 1 addition & 39 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,41 +1,3 @@
# GORM

The fantastic ORM library for Golang, aims to be developer friendly.

[![go report card](https://goreportcard.com/badge/github.com/jinzhu/gorm "go report card")](https://goreportcard.com/report/github.com/jinzhu/gorm)
[![wercker status](https://app.wercker.com/status/8596cace912c9947dd9c8542ecc8cb8b/s/master "wercker status")](https://app.wercker.com/project/byKey/8596cace912c9947dd9c8542ecc8cb8b)
[![codecov](https://codecov.io/gh/jinzhu/gorm/branch/master/graph/badge.svg)](https://codecov.io/gh/jinzhu/gorm)
[![Join the chat at https://gitter.im/jinzhu/gorm](https://img.shields.io/gitter/room/jinzhu/gorm.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
[![Open Collective Backer](https://opencollective.com/gorm/tiers/backer/badge.svg?label=backer&color=brightgreen "Open Collective Backer")](https://opencollective.com/gorm)
[![Open Collective Sponsor](https://opencollective.com/gorm/tiers/sponsor/badge.svg?label=sponsor&color=brightgreen "Open Collective Sponsor")](https://opencollective.com/gorm)
[![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT)
[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm)

## Overview

* Full-Featured ORM (almost)
* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism)
* Hooks (Before/After Create/Save/Update/Delete/Find)
* Preloading (eager loading)
* Transactions
* Composite Primary Key
* SQL Builder
* Auto Migrations
* Logger
* Extendable, write Plugins based on GORM callbacks
* Every feature comes with tests
* Developer Friendly

## Getting Started

* GORM Guides [https://gorm.io](https://gorm.io)

## Contributing

[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)

## License

© Jinzhu, 2013~time.Now

Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License)
Moved to https://github.com/go-gorm/gorm
25 changes: 15 additions & 10 deletions callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package gorm
import "fmt"

// DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{}
var DefaultCallback = &Callback{logger: nopLogger{}}

// Callback is a struct that contains all CRUD callbacks
// Field `creates` contains callbacks will be call when creating object
Expand Down Expand Up @@ -96,11 +96,12 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
if cp.kind == "row_query" {
if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
cp.logger.Print(fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName))
cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName))
cp.before = "gorm:row_query"
}
}

cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
cp.name = callbackName
cp.processor = &callback
cp.parent.processors = append(cp.parent.processors, cp)
Expand All @@ -110,7 +111,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
// Remove a registered callback
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
func (cp *CallbackProcessor) Remove(callbackName string) {
cp.logger.Print(fmt.Sprintf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()))
cp.logger.Print("info", fmt.Sprintf("[info] removing callback `%v` from %v", callbackName, fileWithLineNum()))
cp.name = callbackName
cp.remove = true
cp.parent.processors = append(cp.parent.processors, cp)
Expand All @@ -119,11 +120,11 @@ func (cp *CallbackProcessor) Remove(callbackName string) {

// Replace a registered callback with new callback
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
// scope.SetColumn("Created", now)
// scope.SetColumn("Updated", now)
// scope.SetColumn("CreatedAt", now)
// scope.SetColumn("UpdatedAt", now)
// })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()))
cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum()))
cp.name = callbackName
cp.processor = &callback
cp.replace = true
Expand All @@ -135,11 +136,15 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S
// db.Callback().Create().Get("gorm:create")
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
for _, p := range cp.parent.processors {
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
return *p.processor
if p.name == callbackName && p.kind == cp.kind {
if p.remove {
callback = nil
} else {
callback = *p.processor
}
}
}
return nil
return
}

// getRIndex get right index from string slice
Expand All @@ -162,7 +167,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
for _, cp := range cps {
// show warning message the callback name already exists
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
cp.logger.Print(fmt.Sprintf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()))
cp.logger.Print("warning", fmt.Sprintf("[warning] duplicated callback `%v` from %v", cp.name, fileWithLineNum()))
}
allNames = append(allNames, cp.name)
}
Expand Down
50 changes: 37 additions & 13 deletions callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func beforeCreateCallback(scope *Scope) {
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
func updateTimeStampForCreateCallback(scope *Scope) {
if !scope.HasError() {
now := NowFunc()
now := scope.db.nowFunc()

if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
if createdAtField.IsBlank {
Expand Down Expand Up @@ -100,11 +100,15 @@ func createCallback(scope *Scope) {
returningColumn = scope.Quote(primaryField.DBName)
}

lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns)
var lastInsertIDReturningSuffix string
if lastInsertIDOutputInterstitial == "" {
lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
}

if len(columns) == 0 {
scope.Raw(fmt.Sprintf(
"INSERT %v INTO %v %v%v%v",
"INSERT%v INTO %v %v%v%v",
addExtraSpaceIfExist(insertModifier),
quotedTableName,
scope.Dialect().DefaultValueStr(),
Expand All @@ -113,18 +117,19 @@ func createCallback(scope *Scope) {
))
} else {
scope.Raw(fmt.Sprintf(
"INSERT %v INTO %v (%v) VALUES (%v)%v%v",
"INSERT%v INTO %v (%v)%v VALUES (%v)%v%v",
addExtraSpaceIfExist(insertModifier),
scope.QuotedTableName(),
strings.Join(columns, ","),
addExtraSpaceIfExist(lastInsertIDOutputInterstitial),
strings.Join(placeholders, ","),
addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
))
}

// execute create sql
if lastInsertIDReturningSuffix == "" || primaryField == nil {
// execute create sql: no primaryField
if primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
Expand All @@ -136,16 +141,35 @@ func createCallback(scope *Scope) {
}
}
}
} else {
if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
scope.db.RowsAffected = 1
return
}

// execute create sql: lastInsertID implemention for majority of dialects
if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()

// set primary value to primary field
if primaryField != nil && primaryField.IsBlank {
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
scope.Err(primaryField.Set(primaryValue))
}
}
} else {
scope.Err(ErrUnaddressable)
}
return
}

// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
scope.db.RowsAffected = 1
}
} else {
scope.Err(ErrUnaddressable)
}
return
}
}

Expand Down
4 changes: 2 additions & 2 deletions callback_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func init() {
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
func beforeDeleteCallback(scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("Missing WHERE clause while deleting"))
scope.Err(errors.New("missing WHERE clause while deleting"))
return
}
if !scope.HasError() {
Expand All @@ -40,7 +40,7 @@ func deleteCallback(scope *Scope) {
"UPDATE %v SET %v=%v%v%v",
scope.QuotedTableName(),
scope.Quote(deletedAtField.DBName),
scope.AddToVars(NowFunc()),
scope.AddToVars(scope.db.nowFunc()),
addExtraSpaceIfExist(scope.CombinedConditionSql()),
addExtraSpaceIfExist(extraOption),
)).Exec()
Expand Down
5 changes: 5 additions & 0 deletions callback_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ func queryCallback(scope *Scope) {

if !scope.HasError() {
scope.db.RowsAffected = 0

if str, ok := scope.Get("gorm:query_hint"); ok {
scope.SQL = fmt.Sprint(str) + scope.SQL
}

if str, ok := scope.Get("gorm:query_option"); ok {
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
}
Expand Down
5 changes: 5 additions & 0 deletions callback_row_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ type RowsQueryResult struct {
func rowQueryCallback(scope *Scope) {
if result, ok := scope.InstanceGet("row_query_result"); ok {
scope.prepareQuerySQL()

if str, ok := scope.Get("gorm:query_hint"); ok {
scope.SQL = fmt.Sprint(str) + scope.SQL
}

if str, ok := scope.Get("gorm:query_option"); ok {
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
}
Expand Down
4 changes: 2 additions & 2 deletions callback_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func assignUpdatingAttributesCallback(scope *Scope) {
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
func beforeUpdateCallback(scope *Scope) {
if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
scope.Err(errors.New("Missing WHERE clause while updating"))
scope.Err(errors.New("missing WHERE clause while updating"))
return
}
if _, ok := scope.Get("gorm:update_column"); !ok {
Expand All @@ -50,7 +50,7 @@ func beforeUpdateCallback(scope *Scope) {
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
func updateTimeStampForUpdateCallback(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.SetColumn("UpdatedAt", NowFunc())
scope.SetColumn("UpdatedAt", scope.db.nowFunc())
}
}

Expand Down
78 changes: 75 additions & 3 deletions callbacks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ package gorm_test

import (
"errors"

"github.com/lun-zhang/gorm"

"reflect"
"testing"

"github.com/lun-zhang/gorm"
)

func (s *Product) BeforeCreate() (err error) {
Expand Down Expand Up @@ -175,3 +174,76 @@ func TestCallbacksWithErrors(t *testing.T) {
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
}
}

func TestGetCallback(t *testing.T) {
scope := DB.NewScope(nil)

if DB.Callback().Create().Get("gorm:test_callback") != nil {
t.Errorf("`gorm:test_callback` should be nil")
}

DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 1) })
callback := DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 1 {
t.Errorf("`gorm:test_callback_value` should be `1, true` but `%v, %v`", v, ok)
}

DB.Callback().Create().Replace("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 2) })
callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 2 {
t.Errorf("`gorm:test_callback_value` should be `2, true` but `%v, %v`", v, ok)
}

DB.Callback().Create().Remove("gorm:test_callback")
if DB.Callback().Create().Get("gorm:test_callback") != nil {
t.Errorf("`gorm:test_callback` should be nil")
}

DB.Callback().Create().Register("gorm:test_callback", func(scope *gorm.Scope) { scope.Set("gorm:test_callback_value", 3) })
callback = DB.Callback().Create().Get("gorm:test_callback")
if callback == nil {
t.Errorf("`gorm:test_callback` should be non-nil")
}
callback(scope)
if v, ok := scope.Get("gorm:test_callback_value"); !ok || v != 3 {
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
}
}

func TestUseDefaultCallback(t *testing.T) {
createCallbackName := "gorm:test_use_default_callback_for_create"
gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
// nop
})
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
t.Errorf("`%s` expected non-nil, but got nil", createCallbackName)
}
gorm.DefaultCallback.Create().Remove(createCallbackName)
if gorm.DefaultCallback.Create().Get(createCallbackName) != nil {
t.Errorf("`%s` expected nil, but got non-nil", createCallbackName)
}

updateCallbackName := "gorm:test_use_default_callback_for_update"
scopeValueName := "gorm:test_use_default_callback_for_update_value"
gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
scope.Set(scopeValueName, 1)
})
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
scope.Set(scopeValueName, 2)
})

scope := DB.NewScope(nil)
callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
callback(scope)
if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
}
}
Loading

0 comments on commit 27d8d83

Please sign in to comment.