diff --git a/builtin/builtin.go b/builtin/builtin.go index 4995e75c..38a6f2f3 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -472,9 +472,27 @@ var Builtins = []*Function{ { Name: "now", Func: func(args ...any) (any, error) { - return time.Now(), nil + if len(args) == 0 { + return time.Now(), nil + } + if len(args) == 1 { + if tz, ok := args[0].(*time.Location); ok { + return time.Now().In(tz), nil + } + } + return nil, fmt.Errorf("invalid number of arguments (expected 0, got %d)", len(args)) + }, + Validate: func(args []reflect.Type) (reflect.Type, error) { + if len(args) == 0 { + return timeType, nil + } + if len(args) == 1 { + if args[0].AssignableTo(locationType) { + return timeType, nil + } + } + return anyType, fmt.Errorf("invalid number of arguments (expected 0, got %d)", len(args)) }, - Types: types(new(func() time.Time)), }, { Name: "duration", @@ -486,9 +504,17 @@ var Builtins = []*Function{ { Name: "date", Func: func(args ...any) (any, error) { + tz, ok := args[0].(*time.Location) + if ok { + args = args[1:] + } + date := args[0].(string) if len(args) == 2 { layout := args[1].(string) + if tz != nil { + return time.ParseInLocation(layout, date, tz) + } return time.Parse(layout, date) } if len(args) == 3 { @@ -516,18 +542,32 @@ var Builtins = []*Function{ time.RFC1123, } for _, layout := range layouts { - t, err := time.Parse(layout, date) - if err == nil { - return t, nil + if tz == nil { + t, err := time.Parse(layout, date) + if err == nil { + return t, nil + } + } else { + t, err := time.ParseInLocation(layout, date, tz) + if err == nil { + return t, nil + } } } return nil, fmt.Errorf("invalid date %s", date) }, - Types: types( - new(func(string) time.Time), - new(func(string, string) time.Time), - new(func(string, string, string) time.Time), - ), + Validate: func(args []reflect.Type) (reflect.Type, error) { + if len(args) < 1 { + return anyType, fmt.Errorf("invalid number of arguments (expected at least 1, got %d)", len(args)) + } + if args[0].AssignableTo(locationType) { + args = args[1:] + } + if len(args) > 3 { + return anyType, fmt.Errorf("invalid number of arguments (expected at most 3, got %d)", len(args)) + } + return timeType, nil + }, }, { Name: "timezone", diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index fc2395f7..09a13b9f 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -170,6 +170,7 @@ func TestBuiltin_works_with_any(t *testing.T) { config := map[string]struct { arity int }{ + "now": {0}, "get": {2}, "take": {2}, "sortBy": {2}, diff --git a/builtin/utils.go b/builtin/utils.go index 7d3b6ee8..29a95731 100644 --- a/builtin/utils.go +++ b/builtin/utils.go @@ -3,14 +3,17 @@ package builtin import ( "fmt" "reflect" + "time" ) var ( - anyType = reflect.TypeOf(new(any)).Elem() - integerType = reflect.TypeOf(0) - floatType = reflect.TypeOf(float64(0)) - arrayType = reflect.TypeOf([]any{}) - mapType = reflect.TypeOf(map[any]any{}) + anyType = reflect.TypeOf(new(any)).Elem() + integerType = reflect.TypeOf(0) + floatType = reflect.TypeOf(float64(0)) + arrayType = reflect.TypeOf([]any{}) + mapType = reflect.TypeOf(map[any]any{}) + timeType = reflect.TypeOf(new(time.Time)).Elem() + locationType = reflect.TypeOf(new(time.Location)) ) func kind(t reflect.Type) reflect.Kind { diff --git a/expr.go b/expr.go index ba786c01..83e0a167 100644 --- a/expr.go +++ b/expr.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "reflect" + "time" "github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/builtin" @@ -183,6 +184,17 @@ func WithContext(name string) Option { }) } +// Timezone sets default timezone for date() and now() builtin functions. +func Timezone(name string) Option { + tz, err := time.LoadLocation(name) + if err != nil { + panic(err) + } + return Patch(patcher.WithTimezone{ + Location: tz, + }) +} + // Compile parses and compiles given input expression to bytecode program. func Compile(input string, ops ...Option) (*vm.Program, error) { config := conf.CreateNew() diff --git a/expr_test.go b/expr_test.go index 4b4e2edf..889d82b3 100644 --- a/expr_test.go +++ b/expr_test.go @@ -584,6 +584,23 @@ func ExampleWithContext() { // Output: 42 } +func ExampleWithTimezone() { + program, err := expr.Compile(`now().Location().String()`, expr.Timezone("Asia/Kamchatka")) + if err != nil { + fmt.Printf("%v", err) + return + } + + output, err := expr.Run(program, nil) + if err != nil { + fmt.Printf("%v", err) + return + } + + fmt.Printf("%v", output) + // Output: Asia/Kamchatka +} + func TestExpr_readme_example(t *testing.T) { env := map[string]any{ "greet": "Hello, %v!", diff --git a/patcher/with_timezone.go b/patcher/with_timezone.go new file mode 100644 index 00000000..83eb28e9 --- /dev/null +++ b/patcher/with_timezone.go @@ -0,0 +1,25 @@ +package patcher + +import ( + "time" + + "github.com/expr-lang/expr/ast" +) + +// WithTimezone passes Location to date() and now() functions. +type WithTimezone struct { + Location *time.Location +} + +func (t WithTimezone) Visit(node *ast.Node) { + if btin, ok := (*node).(*ast.BuiltinNode); ok { + switch btin.Name { + case "date", "now": + loc := &ast.ConstantNode{Value: t.Location} + ast.Patch(node, &ast.BuiltinNode{ + Name: btin.Name, + Arguments: append([]ast.Node{loc}, btin.Arguments...), + }) + } + } +} diff --git a/patcher/with_timezone_test.go b/patcher/with_timezone_test.go new file mode 100644 index 00000000..2cd099f2 --- /dev/null +++ b/patcher/with_timezone_test.go @@ -0,0 +1,28 @@ +package patcher_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/expr-lang/expr" +) + +func TestWithTimezone_date(t *testing.T) { + program, err := expr.Compile(`date("2024-05-07 23:00:00")`, expr.Timezone("Europe/Zurich")) + require.NoError(t, err) + + out, err := expr.Run(program, nil) + require.NoError(t, err) + require.Equal(t, "2024-05-07T23:00:00+02:00", out.(time.Time).Format(time.RFC3339)) +} + +func TestWithTimezone_now(t *testing.T) { + program, err := expr.Compile(`now()`, expr.Timezone("Asia/Kamchatka")) + require.NoError(t, err) + + out, err := expr.Run(program, nil) + require.NoError(t, err) + require.Equal(t, "Asia/Kamchatka", out.(time.Time).Location().String()) +} diff --git a/testdata/examples.txt b/testdata/examples.txt index b02094a7..9c18442e 100644 --- a/testdata/examples.txt +++ b/testdata/examples.txt @@ -7231,7 +7231,6 @@ get(false ? f64 : 1, ok) get(false ? f64 : score, add) get(false ? false : f32, i) get(false ? i32 : list, i64) -get(false ? i32 : ok, now(div, array)) get(false ? i64 : foo, Bar) get(false ? i64 : true, f64) get(false ? score : ok, trimSuffix("bar", "bar"))