Skip to content

Commit

Permalink
[Go] improve API (#523)
Browse files Browse the repository at this point in the history
This PR consists solely of refactoring. There are no behavior changes.

Address the following warts on the API:

- Although flows are a key concept in Genkit, and the genkit package
  includes a DefineFlow function, the Flow type itself lived in the core
  package.

- The core package contained several exported methods that were intended
  only to be called by the genkit package, a clear sign that something
  is off in the API design.

The actual visible changes to achieve this are small, but a lot of
code was moved. Most of the changes involve moving unexported symbols
into internal packages so they could be used by other packages.

The main API changes are:

- Flow is now in the genkit package.

- There are no more InternalXXX symbols in core.

There are some other minor changes, like the removal of the NoStream
type.
  • Loading branch information
jba authored Jul 2, 2024
1 parent c14f4e1 commit 1afa05d
Show file tree
Hide file tree
Showing 23 changed files with 495 additions and 460 deletions.
3 changes: 2 additions & 1 deletion go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/internal"
"github.com/firebase/genkit/go/internal/atype"
)

Expand Down Expand Up @@ -153,7 +154,7 @@ func validCandidate(c *Candidate, output *GenerateRequestOutput) (*Candidate, er
if err != nil {
return nil, fmt.Errorf("expected schema is not valid: %w", err)
}
if err = core.ValidateRaw([]byte(text), schemaBytes); err != nil {
if err = internal.ValidateRaw([]byte(text), schemaBytes); err != nil {
return nil, err
}
// TODO: Verify that it okay to replace all content with JSON.
Expand Down
127 changes: 45 additions & 82 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ import (
"context"
"encoding/json"
"fmt"
"maps"
"time"

"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/common"
"github.com/firebase/genkit/go/internal/metrics"
"github.com/firebase/genkit/go/internal/registry"
"github.com/invopop/jsonschema"
)

Expand All @@ -39,13 +41,6 @@ type Func[In, Out, Stream any] func(context.Context, In, func(context.Context, S

// TODO(jba): use a generic type alias for the above when they become available?

// NoStream indicates that the action or flow does not support streaming.
// A Func[I, O, NoStream] will ignore its streaming callback.
// Such a function corresponds to a Flow[I, O, struct{}].
type NoStream = func(context.Context, struct{}) error

type streamingCallback[Stream any] func(context.Context, Stream) error

// An Action is a named, observable operation.
// It consists of a function that takes an input of type I and returns an output
// of type O, optionally streaming values of type S incrementally by invoking a callback.
Expand All @@ -69,22 +64,22 @@ type Action[In, Out, Stream any] struct {

// DefineAction creates a new Action and registers it.
func DefineAction[In, Out any](provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
return defineAction(globalRegistry, provider, name, atype, metadata, fn)
return defineAction(registry.Global, provider, name, atype, metadata, fn)
}

func defineAction[In, Out any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
a := newAction(provider+"/"+name, atype, metadata, fn)
r.registerAction(a)
func defineAction[In, Out any](r *registry.Registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
a := NewAction(provider+"/"+name, atype, metadata, fn)
r.RegisterAction(atype, a)
return a
}

func DefineStreamingAction[In, Out, Stream any](provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
return defineStreamingAction(globalRegistry, provider, name, atype, metadata, fn)
return defineStreamingAction(registry.Global, provider, name, atype, metadata, fn)
}

func defineStreamingAction[In, Out, Stream any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
a := newStreamingAction(provider+"/"+name, atype, metadata, fn)
r.registerAction(a)
func defineStreamingAction[In, Out, Stream any](r *registry.Registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
a := NewStreamingAction(provider+"/"+name, atype, metadata, fn)
r.RegisterAction(atype, a)
return a
}

Expand All @@ -103,31 +98,33 @@ func DefineActionWithInputSchema[Out any](
inputSchema *jsonschema.Schema,
fn func(context.Context, any) (Out, error),
) *Action[any, Out, struct{}] {
return defineActionWithInputSchema(globalRegistry, provider, name, atype, metadata, inputSchema, fn)
return defineActionWithInputSchema(registry.Global, provider, name, atype, metadata, inputSchema, fn)
}

func defineActionWithInputSchema[Out any](
r *registry,
r *registry.Registry,
provider, name string,
atype atype.ActionType,
metadata map[string]any,
inputSchema *jsonschema.Schema,
fn func(context.Context, any) (Out, error),
) *Action[any, Out, struct{}] {
a := newActionWithInputSchema(provider+"/"+name, atype, metadata, fn, inputSchema)
r.registerAction(a)
r.RegisterAction(atype, a)
return a
}

// newAction creates a new Action with the given name and non-streaming function.
func newAction[In, Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
return newStreamingAction(name, atype, metadata, func(ctx context.Context, in In, cb NoStream) (Out, error) {
type noStream = func(context.Context, struct{}) error

// NewAction creates a new Action with the given name and non-streaming function.
func NewAction[In, Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
return NewStreamingAction(name, atype, metadata, func(ctx context.Context, in In, cb noStream) (Out, error) {
return fn(ctx, in)
})
}

// newStreamingAction creates a new Action with the given name and streaming function.
func newStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
// NewStreamingAction creates a new Action with the given name and streaming function.
func NewStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
var i In
var o Out
return &Action[In, Out, Stream]{
Expand All @@ -137,8 +134,8 @@ func newStreamingAction[In, Out, Stream any](name string, atype atype.ActionType
tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype))
return fn(ctx, input, sc)
},
inputSchema: inferJSONSchema(i),
outputSchema: inferJSONSchema(o),
inputSchema: internal.InferJSONSchema(i),
outputSchema: internal.InferJSONSchema(o),
metadata: metadata,
}
}
Expand All @@ -153,18 +150,16 @@ func newActionWithInputSchema[Out any](name string, atype atype.ActionType, meta
return fn(ctx, input)
},
inputSchema: inputSchema,
outputSchema: inferJSONSchema(o),
outputSchema: internal.InferJSONSchema(o),
metadata: metadata,
}
}

// Name returns the Action's Name.
func (a *Action[In, Out, Stream]) Name() string { return a.name }

func (a *Action[In, Out, Stream]) actionType() atype.ActionType { return a.atype }

// setTracingState sets the action's tracing.State.
func (a *Action[In, Out, Stream]) setTracingState(tstate *tracing.State) { a.tstate = tstate }
func (a *Action[In, Out, Stream]) SetTracingState(tstate *tracing.State) { a.tstate = tstate }

// Run executes the Action's function in a new trace span.
func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(context.Context, Stream) error) (output Out, err error) {
Expand All @@ -180,37 +175,38 @@ func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(con
tstate := a.tstate
if tstate == nil {
// This action has probably not been registered.
tstate = globalRegistry.tstate
tstate = registry.Global.TracingState()
}
return tracing.RunInNewSpan(ctx, tstate, a.name, "action", false, input,
func(ctx context.Context, input In) (Out, error) {
start := time.Now()
var err error
if err = validateValue(input, a.inputSchema); err != nil {
if err = internal.ValidateValue(input, a.inputSchema); err != nil {
err = fmt.Errorf("invalid input: %w", err)
}
var output Out
if err == nil {
output, err = a.fn(ctx, input, cb)
if err == nil {
if err = validateValue(output, a.outputSchema); err != nil {
if err = internal.ValidateValue(output, a.outputSchema); err != nil {
err = fmt.Errorf("invalid output: %w", err)
}
}
}
latency := time.Since(start)
if err != nil {
writeActionFailure(ctx, a.name, latency, err)
metrics.WriteActionFailure(ctx, a.name, latency, err)
return internal.Zero[Out](), err
}
writeActionSuccess(ctx, a.name, latency)
metrics.WriteActionSuccess(ctx, a.name, latency)
return output, nil
})
}

func (a *Action[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) {
// RunJSON runs the action with a JSON input, and returns a JSON result.
func (a *Action[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) {
// Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process.
if err := validateJSON(input, a.inputSchema); err != nil {
if err := internal.ValidateJSON(input, a.inputSchema); err != nil {
return nil, err
}
var in In
Expand Down Expand Up @@ -238,44 +234,9 @@ func (a *Action[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMes
return json.RawMessage(bytes), nil
}

// action is the type that all Action[I, O, S] have in common.
type action interface {
Name() string
actionType() atype.ActionType

// runJSON uses encoding/json to unmarshal the input,
// calls Action.Run, then returns the marshaled result.
runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error)

// desc returns a description of the action.
// It should set all fields of actionDesc except Key, which
// the registry will set.
desc() actionDesc

// setTracingState set's the action's tracing.State.
setTracingState(*tracing.State)
}

// An actionDesc is a description of an Action.
// It is used to provide a list of registered actions.
type actionDesc struct {
Key string `json:"key"` // full key from the registry
Name string `json:"name"`
Description string `json:"description"`
Metadata map[string]any `json:"metadata"`
InputSchema *jsonschema.Schema `json:"inputSchema"`
OutputSchema *jsonschema.Schema `json:"outputSchema"`
}

func (d1 actionDesc) equal(d2 actionDesc) bool {
return d1.Key == d2.Key &&
d1.Name == d2.Name &&
d1.Description == d2.Description &&
maps.Equal(d1.Metadata, d2.Metadata)
}

func (a *Action[I, O, S]) desc() actionDesc {
ad := actionDesc{
// Desc returns a description of the action.
func (a *Action[I, O, S]) Desc() common.ActionDesc {
ad := common.ActionDesc{
Name: a.name,
Description: a.description,
Metadata: a.metadata,
Expand All @@ -292,12 +253,14 @@ func (a *Action[I, O, S]) desc() actionDesc {
return ad
}

func inferJSONSchema(x any) (s *jsonschema.Schema) {
r := jsonschema.Reflector{
DoNotReference: true,
// LookupActionFor returns the action for the given key in the global registry,
// or nil if there is none.
// It panics if the action is of the wrong type.
func LookupActionFor[In, Out, Stream any](typ atype.ActionType, provider, name string) *Action[In, Out, Stream] {
key := fmt.Sprintf("/%s/%s/%s", typ, provider, name)
a := registry.Global.LookupAction(key)
if a == nil {
return nil
}
s = r.Reflect(x)
// TODO: Unwind this change once Monaco Editor supports newer than JSON schema draft-07.
s.Version = ""
return s
return a.(*Action[In, Out, Stream])
}
16 changes: 9 additions & 7 deletions go/core/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@ import (
"testing"

"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/common"
"github.com/firebase/genkit/go/internal/registry"
)

func inc(_ context.Context, x int) (int, error) {
return x + 1, nil
}

func TestActionRun(t *testing.T) {
a := newAction("inc", atype.Custom, nil, inc)
a := NewAction("inc", atype.Custom, nil, inc)
got, err := a.Run(context.Background(), 3, nil)
if err != nil {
t.Fatal(err)
Expand All @@ -39,10 +41,10 @@ func TestActionRun(t *testing.T) {
}

func TestActionRunJSON(t *testing.T) {
a := newAction("inc", atype.Custom, nil, inc)
a := NewAction("inc", atype.Custom, nil, inc)
input := []byte("3")
want := []byte("4")
got, err := a.runJSON(context.Background(), input, nil)
got, err := a.RunJSON(context.Background(), input, nil)
if err != nil {
t.Fatal(err)
}
Expand All @@ -53,7 +55,7 @@ func TestActionRunJSON(t *testing.T) {

func TestNewAction(t *testing.T) {
// Verify that struct{} can occur in the function signature.
_ = newAction("f", atype.Custom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil })
_ = NewAction("f", atype.Custom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil })
}

// count streams the numbers from 0 to n-1, then returns n.
Expand All @@ -70,7 +72,7 @@ func count(ctx context.Context, n int, cb func(context.Context, int) error) (int

func TestActionStreaming(t *testing.T) {
ctx := context.Background()
a := newStreamingAction("count", atype.Custom, nil, count)
a := NewStreamingAction("count", atype.Custom, nil, count)
const n = 3

// Non-streaming.
Expand Down Expand Up @@ -103,12 +105,12 @@ func TestActionStreaming(t *testing.T) {
func TestActionTracing(t *testing.T) {
ctx := context.Background()
const actionName = "TestTracing-inc"
a := newAction(actionName, atype.Custom, nil, inc)
a := NewAction(actionName, atype.Custom, nil, inc)
if _, err := a.Run(context.Background(), 3, nil); err != nil {
t.Fatal(err)
}
// The dev TraceStore is registered by Init, called from TestMain.
ts := globalRegistry.lookupTraceStore(EnvironmentDev)
ts := registry.Global.LookupTraceStore(common.EnvironmentDev)
tds, _, err := ts.List(ctx, nil)
if err != nil {
t.Fatal(err)
Expand Down
26 changes: 25 additions & 1 deletion go/core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,31 @@
// Run the Go code generator on the file just created.
//go:generate go run ../internal/cmd/jsonschemagen -outdir .. -config schemas.config ../../genkit-tools/genkit-schema.json core

// Package core implements Genkit actions, flows and other essential machinery.
// Package core implements Genkit actions and other essential machinery.
// This package is primarily intended for Genkit internals and for plugins.
// Genkit applications should use the genkit package.
package core

import (
"context"

"github.com/firebase/genkit/go/core/tracing"
"github.com/firebase/genkit/go/internal/common"
"github.com/firebase/genkit/go/internal/registry"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
)

// RegisterTraceStore uses the given trace.Store to record traces in the prod environment.
// (A trace.Store that writes to the local filesystem is always installed in the dev environment.)
// The returned function should be called before the program ends to ensure that
// all pending data is stored.
// RegisterTraceStore panics if called more than once.
func RegisterTraceStore(ts tracing.Store) (shutdown func(context.Context) error) {
registry.Global.RegisterTraceStore(common.EnvironmentProd, ts)
return registry.Global.TracingState().AddTraceStoreBatch(ts)
}

// RegisterSpanProcessor registers an OpenTelemetry SpanProcessor for tracing.
func RegisterSpanProcessor(sp sdktrace.SpanProcessor) {
registry.Global.RegisterSpanProcessor(sp)
}
7 changes: 4 additions & 3 deletions go/core/file_flow_state_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"path/filepath"

"github.com/firebase/genkit/go/internal"
"github.com/firebase/genkit/go/internal/common"
)

// A FileFlowStateStore is a FlowStateStore that writes flowStates to files.
Expand All @@ -36,9 +37,9 @@ func NewFileFlowStateStore(dir string) (*FileFlowStateStore, error) {
return &FileFlowStateStore{dir: dir}, nil
}

func (s *FileFlowStateStore) Save(ctx context.Context, id string, fs flowStater) error {
fs.lock()
defer fs.unlock()
func (s *FileFlowStateStore) Save(ctx context.Context, id string, fs common.FlowStater) error {
fs.Lock()
defer fs.Unlock()
return internal.WriteJSONFile(filepath.Join(s.dir, internal.Clean(id)), fs)
}

Expand Down
Loading

0 comments on commit 1afa05d

Please sign in to comment.