Skip to content

Commit

Permalink
feat(math): add safe arithmetic (cosmos#18552)
Browse files Browse the repository at this point in the history
  • Loading branch information
ocnc committed Nov 28, 2023
1 parent 7d5c2db commit 7d5e9f1
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 23 deletions.
1 change: 1 addition & 0 deletions math/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ Ref: https://github.com/commitizen/conventional-commit-types/blob/v3.0.0/index.j
## [Unreleased]

### Features
* [#18552](https://github.com/cosmos/cosmos-sdk/pull/18552) Add safe arithmetic operations for `math.Int` that return an error in case of an overflow or any mishap.

* [#18421](https://github.com/cosmos/cosmos-sdk/pull/18421) Add mutative api for `LegacyDec.BigInt()`.

Expand Down
100 changes: 79 additions & 21 deletions math/int.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package math
import (
"encoding"
"encoding/json"
"errors"
"fmt"
"math/big"
"strings"
Expand All @@ -13,6 +14,14 @@ import (
// MaxBitLen defines the maximum bit length supported bit Int and Uint types.
const MaxBitLen = 256

// Integer errors
var (
// ErrIntOverflow is the error returned when an integer overflow occurs
ErrIntOverflow = errors.New("Integer overflow")
// ErrDivideByZero is the error returned when a divide by zero occurs
ErrDivideByZero = errors.New("Divide by zero")
)

func newIntegerFromString(s string) (*big.Int, bool) {
return new(big.Int).SetString(s, 0)
}
Expand Down Expand Up @@ -259,80 +268,129 @@ func (i Int) LTE(i2 Int) bool {

// Add adds Int from another
func (i Int) Add(i2 Int) (res Int) {
res = Int{add(i.i, i2.i)}
// Check overflow
if res.i.BitLen() > MaxBitLen {
panic("Int overflow")
x, err := i.SafeAdd(i2)
if err != nil {
panic(err)
}
return
return x
}

// AddRaw adds int64 to Int
func (i Int) AddRaw(i2 int64) Int {
return i.Add(NewInt(i2))
}

// SafeAdd adds Int from another and returns an error if overflow
func (i Int) SafeAdd(i2 Int) (res Int, err error) {
res = Int{add(i.i, i2.i)}
// Check overflow
if res.i.BitLen() > MaxBitLen {
return Int{}, ErrIntOverflow
}
return res, nil
}

// Sub subtracts Int from another
func (i Int) Sub(i2 Int) (res Int) {
res = Int{sub(i.i, i2.i)}
// Check overflow
if res.i.BitLen() > MaxBitLen {
panic("Int overflow")
x, err := i.SafeSub(i2)
if err != nil {
panic(err)
}
return
return x
}

// SubRaw subtracts int64 from Int
func (i Int) SubRaw(i2 int64) Int {
return i.Sub(NewInt(i2))
}

// SafeSub subtracts Int from another and returns an error if overflow or underflow
func (i Int) SafeSub(i2 Int) (res Int, err error) {
res = Int{sub(i.i, i2.i)}
// Check overflow/underflow
if res.i.BitLen() > MaxBitLen {
return Int{}, ErrIntOverflow
}
return res, nil
}

// Mul multiples two Ints
func (i Int) Mul(i2 Int) (res Int) {
// Check overflow
if i.i.BitLen()+i2.i.BitLen()-1 > MaxBitLen {
panic("Int overflow")
}
res = Int{mul(i.i, i2.i)}
// Check overflow if sign of both are same
if res.i.BitLen() > MaxBitLen {
panic("Int overflow")
x, err := i.SafeMul(i2)
if err != nil {
panic(err)
}
return
return x
}

// MulRaw multipies Int and int64
func (i Int) MulRaw(i2 int64) Int {
return i.Mul(NewInt(i2))
}

// SafeMul multiples Int from another and returns an error if overflow
func (i Int) SafeMul(i2 Int) (res Int, err error) {
// Check overflow
if i.i.BitLen()+i2.i.BitLen()-1 > MaxBitLen {
return Int{}, ErrIntOverflow
}
res = Int{mul(i.i, i2.i)}
// Check overflow if sign of both are same
if res.i.BitLen() > MaxBitLen {
return Int{}, ErrIntOverflow
}
return res, nil
}

// Quo divides Int with Int
func (i Int) Quo(i2 Int) (res Int) {
// Check division-by-zero
if i2.i.Sign() == 0 {
x, err := i.SafeQuo(i2)
if err != nil {
panic("Division by zero")
}
return Int{div(i.i, i2.i)}
return x
}

// QuoRaw divides Int with int64
func (i Int) QuoRaw(i2 int64) Int {
return i.Quo(NewInt(i2))
}

// SafeQuo divides Int with Int and returns an error if division by zero
func (i Int) SafeQuo(i2 Int) (res Int, err error) {
// Check division-by-zero
if i2.i.Sign() == 0 {
return Int{}, ErrDivideByZero
}
return Int{div(i.i, i2.i)}, nil
}

// Mod returns remainder after dividing with Int
func (i Int) Mod(i2 Int) Int {
if i2.Sign() == 0 {
panic("division-by-zero")
x, err := i.SafeMod(i2)
if err != nil {
panic(err)
}
return Int{mod(i.i, i2.i)}
return x
}

// ModRaw returns remainder after dividing with int64
func (i Int) ModRaw(i2 int64) Int {
return i.Mod(NewInt(i2))
}

// SafeMod returns remainder after dividing with Int and returns an error if division by zero
func (i Int) SafeMod(i2 Int) (res Int, err error) {
if i2.Sign() == 0 {
return Int{}, ErrDivideByZero
}
return Int{mod(i.i, i2.i)}, nil
}

// Neg negates Int
func (i Int) Neg() (res Int) {
return Int{neg(i.i)}
Expand Down
44 changes: 42 additions & 2 deletions math/int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,32 +111,66 @@ func (s *intTestSuite) TestIntPanic() {
s.Require().NotPanics(func() { i1.Add(i1) })
s.Require().NotPanics(func() { i2.Add(i2) })
s.Require().Panics(func() { i3.Add(i3) })
_, err := i1.SafeAdd(i1)
s.Require().Nil(err)
_, err = i2.SafeAdd(i2)
s.Require().Nil(err)
_, err = i3.SafeAdd(i3)
s.Require().Error(err)

s.Require().NotPanics(func() { i1.Sub(i1.Neg()) })
s.Require().NotPanics(func() { i2.Sub(i2.Neg()) })
s.Require().Panics(func() { i3.Sub(i3.Neg()) })
_, err = i1.SafeSub(i1.Neg())
s.Require().Nil(err)
_, err = i2.SafeSub(i2.Neg())
s.Require().Nil(err)
_, err = i3.SafeSub(i3.Neg())
s.Require().Error(err)

s.Require().Panics(func() { i1.Mul(i1) })
s.Require().Panics(func() { i2.Mul(i2) })
s.Require().Panics(func() { i3.Mul(i3) })
_, err = i1.SafeMul(i1)
s.Require().Error(err)
_, err = i2.SafeMul(i2)
s.Require().Error(err)
_, err = i3.SafeMul(i3)
s.Require().Error(err)

s.Require().Panics(func() { i1.Neg().Mul(i1.Neg()) })
s.Require().Panics(func() { i2.Neg().Mul(i2.Neg()) })
s.Require().Panics(func() { i3.Neg().Mul(i3.Neg()) })

// // Underflow check
_, err = i1.Neg().SafeMul(i1.Neg())
s.Require().Error(err)
_, err = i2.Neg().SafeMul(i2.Neg())
s.Require().Error(err)
_, err = i3.Neg().SafeMul(i3.Neg())
s.Require().Error(err)

// Underflow check
i3n := i3.Neg()
s.Require().NotPanics(func() { i3n.Sub(i1) })
s.Require().NotPanics(func() { i3n.Sub(i2) })
s.Require().Panics(func() { i3n.Sub(i3) })
_, err = i3n.SafeSub(i3)
s.Require().Error(err)

s.Require().NotPanics(func() { i3n.Add(i1.Neg()) })
s.Require().NotPanics(func() { i3n.Add(i2.Neg()) })
s.Require().Panics(func() { i3n.Add(i3.Neg()) })
_, err = i3n.SafeAdd(i3.Neg())
s.Require().Error(err)

s.Require().Panics(func() { i1.Mul(i1.Neg()) })
s.Require().Panics(func() { i2.Mul(i2.Neg()) })
s.Require().Panics(func() { i3.Mul(i3.Neg()) })
_, err = i1.SafeMul(i1.Neg())
s.Require().Error(err)
_, err = i2.SafeMul(i2.Neg())
s.Require().Error(err)
_, err = i3.SafeMul(i3.Neg())
s.Require().Error(err)

// Bound check
intmax := math.NewIntFromBigInt(new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil), big.NewInt(1)))
Expand All @@ -145,12 +179,18 @@ func (s *intTestSuite) TestIntPanic() {
s.Require().NotPanics(func() { intmin.Sub(math.ZeroInt()) })
s.Require().Panics(func() { intmax.Add(math.OneInt()) })
s.Require().Panics(func() { intmin.Sub(math.OneInt()) })
_, err = intmax.SafeAdd(math.OneInt())
s.Require().Error(err)
_, err = intmin.SafeSub(math.OneInt())
s.Require().Error(err)

s.Require().NotPanics(func() { math.NewIntFromBigInt(nil) })
s.Require().True(math.NewIntFromBigInt(nil).IsNil())

// Division-by-zero check
s.Require().Panics(func() { i1.Quo(math.NewInt(0)) })
_, err = i1.SafeQuo(math.NewInt(0))
s.Require().Error(err)

s.Require().NotPanics(func() { math.Int{}.BigInt() })
}
Expand Down

0 comments on commit 7d5e9f1

Please sign in to comment.