Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] improve API #523

Merged
merged 2 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/internal"
jba marked this conversation as resolved.
Show resolved Hide resolved
"github.com/firebase/genkit/go/internal/atype"
)

Expand Down Expand Up @@ -153,7 +154,7 @@ func validCandidate(c *Candidate, output *GenerateRequestOutput) (*Candidate, er
if err != nil {
return nil, fmt.Errorf("expected schema is not valid: %w", err)
}
if err = core.ValidateRaw([]byte(text), schemaBytes); err != nil {
if err = internal.ValidateRaw([]byte(text), schemaBytes); err != nil {
return nil, err
}
// TODO: Verify that it okay to replace all content with JSON.
Expand Down
127 changes: 45 additions & 82 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ import (
"context"
"encoding/json"
"fmt"
"maps"
"time"

"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/common"
"github.com/firebase/genkit/go/internal/metrics"
"github.com/firebase/genkit/go/internal/registry"
"github.com/invopop/jsonschema"
)

Expand All @@ -39,13 +41,6 @@ type Func[In, Out, Stream any] func(context.Context, In, func(context.Context, S

// TODO(jba): use a generic type alias for the above when they become available?

// NoStream indicates that the action or flow does not support streaming.
// A Func[I, O, NoStream] will ignore its streaming callback.
// Such a function corresponds to a Flow[I, O, struct{}].
type NoStream = func(context.Context, struct{}) error

type streamingCallback[Stream any] func(context.Context, Stream) error

// An Action is a named, observable operation.
// It consists of a function that takes an input of type I and returns an output
// of type O, optionally streaming values of type S incrementally by invoking a callback.
Expand All @@ -69,22 +64,22 @@ type Action[In, Out, Stream any] struct {

// DefineAction creates a new Action and registers it.
func DefineAction[In, Out any](provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
return defineAction(globalRegistry, provider, name, atype, metadata, fn)
return defineAction(registry.Global, provider, name, atype, metadata, fn)
}

func defineAction[In, Out any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
a := newAction(provider+"/"+name, atype, metadata, fn)
r.registerAction(a)
func defineAction[In, Out any](r *registry.Registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
a := NewAction(provider+"/"+name, atype, metadata, fn)
r.RegisterAction(atype, a)
return a
}

func DefineStreamingAction[In, Out, Stream any](provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
return defineStreamingAction(globalRegistry, provider, name, atype, metadata, fn)
return defineStreamingAction(registry.Global, provider, name, atype, metadata, fn)
}

func defineStreamingAction[In, Out, Stream any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
a := newStreamingAction(provider+"/"+name, atype, metadata, fn)
r.registerAction(a)
func defineStreamingAction[In, Out, Stream any](r *registry.Registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
a := NewStreamingAction(provider+"/"+name, atype, metadata, fn)
r.RegisterAction(atype, a)
return a
}

Expand All @@ -103,31 +98,33 @@ func DefineActionWithInputSchema[Out any](
inputSchema *jsonschema.Schema,
fn func(context.Context, any) (Out, error),
) *Action[any, Out, struct{}] {
return defineActionWithInputSchema(globalRegistry, provider, name, atype, metadata, inputSchema, fn)
return defineActionWithInputSchema(registry.Global, provider, name, atype, metadata, inputSchema, fn)
}

func defineActionWithInputSchema[Out any](
r *registry,
r *registry.Registry,
provider, name string,
atype atype.ActionType,
metadata map[string]any,
inputSchema *jsonschema.Schema,
fn func(context.Context, any) (Out, error),
) *Action[any, Out, struct{}] {
a := newActionWithInputSchema(provider+"/"+name, atype, metadata, fn, inputSchema)
r.registerAction(a)
r.RegisterAction(atype, a)
return a
}

// newAction creates a new Action with the given name and non-streaming function.
func newAction[In, Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
return newStreamingAction(name, atype, metadata, func(ctx context.Context, in In, cb NoStream) (Out, error) {
type noStream = func(context.Context, struct{}) error

// NewAction creates a new Action with the given name and non-streaming function.
func NewAction[In, Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
return NewStreamingAction(name, atype, metadata, func(ctx context.Context, in In, cb noStream) (Out, error) {
return fn(ctx, in)
})
}

// newStreamingAction creates a new Action with the given name and streaming function.
func newStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
// NewStreamingAction creates a new Action with the given name and streaming function.
func NewStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
var i In
var o Out
return &Action[In, Out, Stream]{
Expand All @@ -137,8 +134,8 @@ func newStreamingAction[In, Out, Stream any](name string, atype atype.ActionType
tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype))
return fn(ctx, input, sc)
},
inputSchema: inferJSONSchema(i),
outputSchema: inferJSONSchema(o),
inputSchema: internal.InferJSONSchema(i),
outputSchema: internal.InferJSONSchema(o),
metadata: metadata,
}
}
Expand All @@ -153,18 +150,16 @@ func newActionWithInputSchema[Out any](name string, atype atype.ActionType, meta
return fn(ctx, input)
},
inputSchema: inputSchema,
outputSchema: inferJSONSchema(o),
outputSchema: internal.InferJSONSchema(o),
metadata: metadata,
}
}

// Name returns the Action's Name.
func (a *Action[In, Out, Stream]) Name() string { return a.name }

func (a *Action[In, Out, Stream]) actionType() atype.ActionType { return a.atype }

// setTracingState sets the action's tracing.State.
func (a *Action[In, Out, Stream]) setTracingState(tstate *tracing.State) { a.tstate = tstate }
func (a *Action[In, Out, Stream]) SetTracingState(tstate *tracing.State) { a.tstate = tstate }

// Run executes the Action's function in a new trace span.
func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(context.Context, Stream) error) (output Out, err error) {
Expand All @@ -180,37 +175,38 @@ func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(con
tstate := a.tstate
if tstate == nil {
// This action has probably not been registered.
tstate = globalRegistry.tstate
tstate = registry.Global.TracingState()
}
return tracing.RunInNewSpan(ctx, tstate, a.name, "action", false, input,
func(ctx context.Context, input In) (Out, error) {
start := time.Now()
var err error
if err = validateValue(input, a.inputSchema); err != nil {
if err = internal.ValidateValue(input, a.inputSchema); err != nil {
err = fmt.Errorf("invalid input: %w", err)
}
var output Out
if err == nil {
output, err = a.fn(ctx, input, cb)
if err == nil {
if err = validateValue(output, a.outputSchema); err != nil {
if err = internal.ValidateValue(output, a.outputSchema); err != nil {
err = fmt.Errorf("invalid output: %w", err)
}
}
}
latency := time.Since(start)
if err != nil {
writeActionFailure(ctx, a.name, latency, err)
metrics.WriteActionFailure(ctx, a.name, latency, err)
return internal.Zero[Out](), err
}
writeActionSuccess(ctx, a.name, latency)
metrics.WriteActionSuccess(ctx, a.name, latency)
return output, nil
})
}

func (a *Action[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) {
// RunJSON runs the action with a JSON input, and returns a JSON result.
func (a *Action[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) {
// Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process.
if err := validateJSON(input, a.inputSchema); err != nil {
if err := internal.ValidateJSON(input, a.inputSchema); err != nil {
return nil, err
}
var in In
Expand Down Expand Up @@ -238,44 +234,9 @@ func (a *Action[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMes
return json.RawMessage(bytes), nil
}

// action is the type that all Action[I, O, S] have in common.
type action interface {
Name() string
actionType() atype.ActionType

// runJSON uses encoding/json to unmarshal the input,
// calls Action.Run, then returns the marshaled result.
runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error)

// desc returns a description of the action.
// It should set all fields of actionDesc except Key, which
// the registry will set.
desc() actionDesc

// setTracingState set's the action's tracing.State.
setTracingState(*tracing.State)
}

// An actionDesc is a description of an Action.
// It is used to provide a list of registered actions.
type actionDesc struct {
Key string `json:"key"` // full key from the registry
Name string `json:"name"`
Description string `json:"description"`
Metadata map[string]any `json:"metadata"`
InputSchema *jsonschema.Schema `json:"inputSchema"`
OutputSchema *jsonschema.Schema `json:"outputSchema"`
}

func (d1 actionDesc) equal(d2 actionDesc) bool {
return d1.Key == d2.Key &&
d1.Name == d2.Name &&
d1.Description == d2.Description &&
maps.Equal(d1.Metadata, d2.Metadata)
}

func (a *Action[I, O, S]) desc() actionDesc {
ad := actionDesc{
// Desc returns a description of the action.
func (a *Action[I, O, S]) Desc() common.ActionDesc {
ad := common.ActionDesc{
Name: a.name,
Description: a.description,
Metadata: a.metadata,
Expand All @@ -292,12 +253,14 @@ func (a *Action[I, O, S]) desc() actionDesc {
return ad
}

func inferJSONSchema(x any) (s *jsonschema.Schema) {
r := jsonschema.Reflector{
DoNotReference: true,
// LookupActionFor returns the action for the given key in the global registry,
// or nil if there is none.
// It panics if the action is of the wrong type.
func LookupActionFor[In, Out, Stream any](typ atype.ActionType, provider, name string) *Action[In, Out, Stream] {
key := fmt.Sprintf("/%s/%s/%s", typ, provider, name)
a := registry.Global.LookupAction(key)
if a == nil {
return nil
}
s = r.Reflect(x)
// TODO: Unwind this change once Monaco Editor supports newer than JSON schema draft-07.
s.Version = ""
return s
return a.(*Action[In, Out, Stream])
}
16 changes: 9 additions & 7 deletions go/core/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@ import (
"testing"

"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/common"
"github.com/firebase/genkit/go/internal/registry"
)

func inc(_ context.Context, x int) (int, error) {
return x + 1, nil
}

func TestActionRun(t *testing.T) {
a := newAction("inc", atype.Custom, nil, inc)
a := NewAction("inc", atype.Custom, nil, inc)
got, err := a.Run(context.Background(), 3, nil)
if err != nil {
t.Fatal(err)
Expand All @@ -39,10 +41,10 @@ func TestActionRun(t *testing.T) {
}

func TestActionRunJSON(t *testing.T) {
a := newAction("inc", atype.Custom, nil, inc)
a := NewAction("inc", atype.Custom, nil, inc)
input := []byte("3")
want := []byte("4")
got, err := a.runJSON(context.Background(), input, nil)
got, err := a.RunJSON(context.Background(), input, nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -53,7 +55,7 @@ func TestActionRunJSON(t *testing.T) {

func TestNewAction(t *testing.T) {
// Verify that struct{} can occur in the function signature.
_ = newAction("f", atype.Custom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil })
_ = NewAction("f", atype.Custom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil })
}

// count streams the numbers from 0 to n-1, then returns n.
Expand All @@ -70,7 +72,7 @@ func count(ctx context.Context, n int, cb func(context.Context, int) error) (int

func TestActionStreaming(t *testing.T) {
ctx := context.Background()
a := newStreamingAction("count", atype.Custom, nil, count)
a := NewStreamingAction("count", atype.Custom, nil, count)
const n = 3

// Non-streaming.
Expand Down Expand Up @@ -103,12 +105,12 @@ func TestActionStreaming(t *testing.T) {
func TestActionTracing(t *testing.T) {
ctx := context.Background()
const actionName = "TestTracing-inc"
a := newAction(actionName, atype.Custom, nil, inc)
a := NewAction(actionName, atype.Custom, nil, inc)
if _, err := a.Run(context.Background(), 3, nil); err != nil {
t.Fatal(err)
}
// The dev TraceStore is registered by Init, called from TestMain.
ts := globalRegistry.lookupTraceStore(EnvironmentDev)
ts := registry.Global.LookupTraceStore(common.EnvironmentDev)
tds, _, err := ts.List(ctx, nil)
if err != nil {
t.Fatal(err)
Expand Down
26 changes: 25 additions & 1 deletion go/core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,31 @@
// Run the Go code generator on the file just created.
//go:generate go run ../internal/cmd/jsonschemagen -outdir .. -config schemas.config ../../genkit-tools/genkit-schema.json core

// Package core implements Genkit actions, flows and other essential machinery.
// Package core implements Genkit actions and other essential machinery.
// This package is primarily intended for Genkit internals and for plugins.
// Genkit applications should use the genkit package.
package core

import (
"context"

"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal/common"
"github.com/firebase/genkit/go/internal/registry"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
)

// RegisterTraceStore uses the given trace.Store to record traces in the prod environment.
// (A trace.Store that writes to the local filesystem is always installed in the dev environment.)
// The returned function should be called before the program ends to ensure that
// all pending data is stored.
// RegisterTraceStore panics if called more than once.
func RegisterTraceStore(ts tracing.Store) (shutdown func(context.Context) error) {
registry.Global.RegisterTraceStore(common.EnvironmentProd, ts)
return registry.Global.TracingState().AddTraceStoreBatch(ts)
}

// RegisterSpanProcessor registers an OpenTelemetry SpanProcessor for tracing.
func RegisterSpanProcessor(sp sdktrace.SpanProcessor) {
registry.Global.RegisterSpanProcessor(sp)
}
7 changes: 4 additions & 3 deletions go/core/file_flow_state_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"path/filepath"

"github.com/firebase/genkit/go/internal"
"github.com/firebase/genkit/go/internal/common"
)

// A FileFlowStateStore is a FlowStateStore that writes flowStates to files.
Expand All @@ -36,9 +37,9 @@ func NewFileFlowStateStore(dir string) (*FileFlowStateStore, error) {
return &FileFlowStateStore{dir: dir}, nil
}

func (s *FileFlowStateStore) Save(ctx context.Context, id string, fs flowStater) error {
fs.lock()
defer fs.unlock()
func (s *FileFlowStateStore) Save(ctx context.Context, id string, fs common.FlowStater) error {
fs.Lock()
defer fs.Unlock()
return internal.WriteJSONFile(filepath.Join(s.dir, internal.Clean(id)), fs)
}

Expand Down
Loading
Loading