diff --git a/go/ai/gen.go b/go/ai/gen.go index c150622c7..927b279fc 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -23,7 +23,7 @@ type Candidate struct { Custom any `json:"custom,omitempty"` FinishMessage string `json:"finishMessage,omitempty"` FinishReason FinishReason `json:"finishReason,omitempty"` - Index int `json:"index,omitempty"` + Index int `json:"index"` Message *Message `json:"message,omitempty"` Usage *GenerationUsage `json:"usage,omitempty"` } diff --git a/go/ai/request_helpers.go b/go/ai/request_helpers.go index 6252a455e..1f722e003 100644 --- a/go/ai/request_helpers.go +++ b/go/ai/request_helpers.go @@ -15,7 +15,7 @@ package ai // NewGenerateRequest create a new GenerateRequest with provided config and -// messages. Use NewUserTextGenerateRequest if you have a simple text-only user message. +// messages. func NewGenerateRequest(config any, messages ...*Message) *GenerateRequest { return &GenerateRequest{ Config: config, diff --git a/go/tests/api_test.go b/go/tests/api_test.go new file mode 100644 index 000000000..6b8c51171 --- /dev/null +++ b/go/tests/api_test.go @@ -0,0 +1,262 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "gopkg.in/yaml.v2" +) + +type testFile struct { + App string + Tests []test +} + +type test struct { + Path string + Post any + Body any +} + +const hostPort = "http://localhost:3100" + +func TestReflectionAPI(t *testing.T) { + filenames, err := filepath.Glob(filepath.FromSlash("../../tests/*.yaml")) + if err != nil { + t.Fatal(err) + } + for _, fn := range filenames { + if filepath.Base(fn) == "pnpm-lock.yaml" { + continue + } + data, err := os.ReadFile(fn) + if err != nil { + t.Fatal(err) + } + var file testFile + if err := yaml.Unmarshal(data, &file); err != nil { + t.Fatal(err) + } + t.Run(file.App, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + wait, err := startGenkitApp(ctx, file.App) + if err != nil { + t.Fatal(err) + } + t.Log("started app") + defer func() { + cancel() + err := wait() + t.Logf("wait returned %v", err) + }() + for { + time.Sleep(50 * time.Millisecond) + t.Log("checking...") + if _, err := http.Get(hostPort + "/api/__health"); err == nil { + break + } + } + t.Log("app ready") + for _, test := range file.Tests { + runTest(t, test) + } + }) + } +} + +func runTest(t *testing.T, test test) { + t.Run(test.Path[1:], func(t *testing.T) { + if t.Name() == "TestReflectionAPI/test_app/api/actions" { + t.Skip("FIXME: skipping because Go and JS schemas are not aligned") + } + url := hostPort + test.Path + var ( + res *http.Response + err error + ) + if test.Post != nil { + body, err := json.Marshal(yamlToJSON(test.Post)) + if err != nil { + t.Fatal(err) + } + res, err = http.Post(url, "application/json", bytes.NewReader(body)) + } else { + res, err = http.Get(url) + } + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + t.Fatalf("got status %s", res.Status) + } + var got any + dec := json.NewDecoder(res.Body) + dec.UseNumber() + if err := dec.Decode(&got); err != nil { + t.Fatal(err) + } + want := yamlToJSON(test.Body) + msgs := compare(got, want) + if len(msgs) > 0 { + t.Logf("%s", prettyJSON(got)) + t.Fatal(strings.Join(msgs, "\n")) + } + }) +} + +func startGenkitApp(ctx context.Context, dir string) (func() error, error) { + // If we invoke `go run`, the actual server is a child of the go command, and not + // easy to kill. So instead we compile the app and run it directly. + tmp := os.TempDir() + cmd := exec.Command("go", "build", "-o", tmp, ".") + cmd.Dir = dir + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return nil, fmt.Errorf("build: %w", err) + } + + // `go build .` will use the working directory as the name of the executable. + cmd = exec.CommandContext(ctx, "./"+dir) + cmd.Dir = tmp + cmd.Env = append(os.Environ(), "GENKIT_ENV=dev") + cmd.WaitDelay = time.Second + if err := cmd.Start(); err != nil { + return nil, err + } + return cmd.Wait, nil +} + +func compare(got, want any) []string { + var msgs []string + compare1(nil, got, want, func(path []string, format string, args ...any) { + msg := fmt.Sprintf(format, args...) + prefix := strings.Join(path, ".") + if prefix == "" { + prefix = "top level" + } + msgs = append(msgs, fmt.Sprintf("%s: %s", prefix, msg)) + }) + return msgs +} + +func compare1(path []string, got, want any, add func([]string, string, ...any)) { + check := func() { + if got != want { + add(path, "\ngot %v (%[1]T)\nwant %v (%[2]T)", got, want) + } + } + + switch w := want.(type) { + case map[string]any: + g, ok := got.(map[string]any) + if !ok { + add(path, "want map, got %T", got) + return + } + for k, wv := range w { + gv, ok := g[k] + if !ok { + add(path, "missing key %q", k) + } else { + compare1(append(path, k), gv, wv, add) + } + } + case int: + if n, ok := got.(json.Number); ok { + g, err := n.Int64() + if err != nil { + add(path, "got number %s, want %d (%[2]T)", n, w) + } + if g != int64(w) { + add(path, "got %d, want %d", g, w) + } + } else { + check() + } + + case float64: + if n, ok := got.(json.Number); ok { + g, err := n.Float64() + if err != nil { + add(path, "got number %s, want %f (%[2]T)", n, w) + } + if g != w { + add(path, "got %f, want %f", g, w) + } + } else { + check() + } + + case []any: + g, ok := got.([]any) + if !ok { + add(path, "want slice, got %T", got) + return + } + if len(g) != len(w) { + add(path, "got slice length %d, want %d", len(g), len(w)) + return + } + for i, ew := range w { + compare1(append(path, strconv.Itoa(i)), g[i], ew, add) + } + default: + check() + } +} + +// The yaml package unmarshals using different types than +// the json package. Convert the yaml form to the JSON form. +func yamlToJSON(y any) any { + switch y := y.(type) { + case map[any]any: + j := map[string]any{} + for k, v := range y { + j[fmt.Sprint(k)] = yamlToJSON(v) + } + return j + case []any: + j := make([]any, len(y)) + for i, e := range y { + j[i] = yamlToJSON(e) + } + return j + default: + return y + } +} + +func prettyJSON(x any) string { + var sb strings.Builder + enc := json.NewEncoder(&sb) + enc.SetIndent("", " ") + if err := enc.Encode(x); err != nil { + panic(err) + } + return sb.String() +} diff --git a/go/tests/test_app/main.go b/go/tests/test_app/main.go new file mode 100644 index 000000000..3ae29e23e --- /dev/null +++ b/go/tests/test_app/main.go @@ -0,0 +1,58 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This program doesn't do anything interesting. +// It is used by go/tests/api_test.go. +package main + +import ( + "context" + "encoding/json" + "log" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" +) + +func main() { + model := ai.DefineModel("", "customReflector", nil, echo) + genkit.DefineFlow("testFlow", func(ctx context.Context, in string) (string, error) { + res, err := model.Generate(ctx, ai.NewGenerateRequest(nil, ai.NewUserTextMessage(in)), nil) + if err != nil { + return "", err + } + _ = res + return "TBD", nil + }) + if err := genkit.Init(context.Background(), nil); err != nil { + log.Fatal(err) + } +} + +func echo(ctx context.Context, req *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error) (*ai.GenerateResponse, error) { + jsonBytes, err := json.Marshal(req) + if err != nil { + return nil, err + } + return &ai.GenerateResponse{ + Candidates: []*ai.Candidate{{ + Index: 0, + FinishReason: "stop", + Message: &ai.Message{ + Role: "model", + Content: []*ai.Part{ai.NewTextPart(string(jsonBytes))}, + }, + }}, + }, nil +} diff --git a/tests/test_js_app.yaml b/tests/test_js_app.yaml index 53511edf3..6ef2893e2 100644 --- a/tests/test_js_app.yaml +++ b/tests/test_js_app.yaml @@ -4,7 +4,7 @@ # TODO: add more about the flow in the /api/actions test. # TODO: add more test cases. -app: test_js_app +app: test_app tests: - path: /api/runAction post: @@ -19,7 +19,7 @@ tests: message: role: model content: - - text: '{"messages":[{"role":"user","content":[{"text":"hello"}]}]}' + - text: '{"messages":[{"content":[{"text":"hello"}],"role":"user"}]}' - path: /api/actions body: diff --git a/tests/test_js_app/src/index.ts b/tests/test_js_app/src/index.ts index 0a550caff..2766866da 100644 --- a/tests/test_js_app/src/index.ts +++ b/tests/test_js_app/src/index.ts @@ -26,6 +26,11 @@ defineModel( name: 'customReflector', }, async (input) => { + // In Go, JSON object properties are output in sorted order. + // JSON.stringify uses the order they appear in the program. + // So swap the order here to match Go. + const m = input.messages[0]; + input.messages[0] = { content: m.content, role: m.role }; return { candidates: [ { @@ -61,7 +66,7 @@ export const testFlow = defineFlow( prompt: subject, }); - const want = `{"messages":[{"role":"user","content":[{"text":"${subject}"}]}],"tools":[],"output":{"format":"text"}}`; + const want = `{"messages":[{"content":[{"text":"${subject}"}],"role":"user"}],"tools":[],"output":{"format":"text"}}`; if (response.text() !== want) { throw new Error(`Expected ${want} but got ${response.text()}`); }