diff --git a/go/ai/generate.go b/go/ai/generate.go index 0f2136f5c..424a36bae 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -25,6 +25,7 @@ import ( "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/logger" + "github.com/firebase/genkit/go/internal" "github.com/firebase/genkit/go/internal/atype" ) @@ -153,7 +154,7 @@ func validCandidate(c *Candidate, output *GenerateRequestOutput) (*Candidate, er if err != nil { return nil, fmt.Errorf("expected schema is not valid: %w", err) } - if err = core.ValidateRaw([]byte(text), schemaBytes); err != nil { + if err = internal.ValidateRaw([]byte(text), schemaBytes); err != nil { return nil, err } // TODO: Verify that it okay to replace all content with JSON. diff --git a/go/core/action.go b/go/core/action.go index dd764f852..1740032e1 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -18,13 +18,15 @@ import ( "context" "encoding/json" "fmt" - "maps" "time" "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal" "github.com/firebase/genkit/go/internal/atype" + "github.com/firebase/genkit/go/internal/common" + "github.com/firebase/genkit/go/internal/metrics" + "github.com/firebase/genkit/go/internal/registry" "github.com/invopop/jsonschema" ) @@ -39,13 +41,6 @@ type Func[In, Out, Stream any] func(context.Context, In, func(context.Context, S // TODO(jba): use a generic type alias for the above when they become available? -// NoStream indicates that the action or flow does not support streaming. -// A Func[I, O, NoStream] will ignore its streaming callback. -// Such a function corresponds to a Flow[I, O, struct{}]. -type NoStream = func(context.Context, struct{}) error - -type streamingCallback[Stream any] func(context.Context, Stream) error - // An Action is a named, observable operation. // It consists of a function that takes an input of type I and returns an output // of type O, optionally streaming values of type S incrementally by invoking a callback. @@ -69,22 +64,22 @@ type Action[In, Out, Stream any] struct { // DefineAction creates a new Action and registers it. func DefineAction[In, Out any](provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { - return defineAction(globalRegistry, provider, name, atype, metadata, fn) + return defineAction(registry.Global, provider, name, atype, metadata, fn) } -func defineAction[In, Out any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { - a := newAction(provider+"/"+name, atype, metadata, fn) - r.registerAction(a) +func defineAction[In, Out any](r *registry.Registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { + a := NewAction(provider+"/"+name, atype, metadata, fn) + r.RegisterAction(atype, a) return a } func DefineStreamingAction[In, Out, Stream any](provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { - return defineStreamingAction(globalRegistry, provider, name, atype, metadata, fn) + return defineStreamingAction(registry.Global, provider, name, atype, metadata, fn) } -func defineStreamingAction[In, Out, Stream any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { - a := newStreamingAction(provider+"/"+name, atype, metadata, fn) - r.registerAction(a) +func defineStreamingAction[In, Out, Stream any](r *registry.Registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { + a := NewStreamingAction(provider+"/"+name, atype, metadata, fn) + r.RegisterAction(atype, a) return a } @@ -103,11 +98,11 @@ func DefineActionWithInputSchema[Out any]( inputSchema *jsonschema.Schema, fn func(context.Context, any) (Out, error), ) *Action[any, Out, struct{}] { - return defineActionWithInputSchema(globalRegistry, provider, name, atype, metadata, inputSchema, fn) + return defineActionWithInputSchema(registry.Global, provider, name, atype, metadata, inputSchema, fn) } func defineActionWithInputSchema[Out any]( - r *registry, + r *registry.Registry, provider, name string, atype atype.ActionType, metadata map[string]any, @@ -115,19 +110,21 @@ func defineActionWithInputSchema[Out any]( fn func(context.Context, any) (Out, error), ) *Action[any, Out, struct{}] { a := newActionWithInputSchema(provider+"/"+name, atype, metadata, fn, inputSchema) - r.registerAction(a) + r.RegisterAction(atype, a) return a } -// newAction creates a new Action with the given name and non-streaming function. -func newAction[In, Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { - return newStreamingAction(name, atype, metadata, func(ctx context.Context, in In, cb NoStream) (Out, error) { +type noStream = func(context.Context, struct{}) error + +// NewAction creates a new Action with the given name and non-streaming function. +func NewAction[In, Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] { + return NewStreamingAction(name, atype, metadata, func(ctx context.Context, in In, cb noStream) (Out, error) { return fn(ctx, in) }) } -// newStreamingAction creates a new Action with the given name and streaming function. -func newStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { +// NewStreamingAction creates a new Action with the given name and streaming function. +func NewStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] { var i In var o Out return &Action[In, Out, Stream]{ @@ -137,8 +134,8 @@ func newStreamingAction[In, Out, Stream any](name string, atype atype.ActionType tracing.SetCustomMetadataAttr(ctx, "subtype", string(atype)) return fn(ctx, input, sc) }, - inputSchema: inferJSONSchema(i), - outputSchema: inferJSONSchema(o), + inputSchema: internal.InferJSONSchema(i), + outputSchema: internal.InferJSONSchema(o), metadata: metadata, } } @@ -153,7 +150,7 @@ func newActionWithInputSchema[Out any](name string, atype atype.ActionType, meta return fn(ctx, input) }, inputSchema: inputSchema, - outputSchema: inferJSONSchema(o), + outputSchema: internal.InferJSONSchema(o), metadata: metadata, } } @@ -161,10 +158,8 @@ func newActionWithInputSchema[Out any](name string, atype atype.ActionType, meta // Name returns the Action's Name. func (a *Action[In, Out, Stream]) Name() string { return a.name } -func (a *Action[In, Out, Stream]) actionType() atype.ActionType { return a.atype } - // setTracingState sets the action's tracing.State. -func (a *Action[In, Out, Stream]) setTracingState(tstate *tracing.State) { a.tstate = tstate } +func (a *Action[In, Out, Stream]) SetTracingState(tstate *tracing.State) { a.tstate = tstate } // Run executes the Action's function in a new trace span. func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(context.Context, Stream) error) (output Out, err error) { @@ -180,37 +175,38 @@ func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(con tstate := a.tstate if tstate == nil { // This action has probably not been registered. - tstate = globalRegistry.tstate + tstate = registry.Global.TracingState() } return tracing.RunInNewSpan(ctx, tstate, a.name, "action", false, input, func(ctx context.Context, input In) (Out, error) { start := time.Now() var err error - if err = validateValue(input, a.inputSchema); err != nil { + if err = internal.ValidateValue(input, a.inputSchema); err != nil { err = fmt.Errorf("invalid input: %w", err) } var output Out if err == nil { output, err = a.fn(ctx, input, cb) if err == nil { - if err = validateValue(output, a.outputSchema); err != nil { + if err = internal.ValidateValue(output, a.outputSchema); err != nil { err = fmt.Errorf("invalid output: %w", err) } } } latency := time.Since(start) if err != nil { - writeActionFailure(ctx, a.name, latency, err) + metrics.WriteActionFailure(ctx, a.name, latency, err) return internal.Zero[Out](), err } - writeActionSuccess(ctx, a.name, latency) + metrics.WriteActionSuccess(ctx, a.name, latency) return output, nil }) } -func (a *Action[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) { +// RunJSON runs the action with a JSON input, and returns a JSON result. +func (a *Action[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) { // Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process. - if err := validateJSON(input, a.inputSchema); err != nil { + if err := internal.ValidateJSON(input, a.inputSchema); err != nil { return nil, err } var in In @@ -238,44 +234,9 @@ func (a *Action[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMes return json.RawMessage(bytes), nil } -// action is the type that all Action[I, O, S] have in common. -type action interface { - Name() string - actionType() atype.ActionType - - // runJSON uses encoding/json to unmarshal the input, - // calls Action.Run, then returns the marshaled result. - runJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) - - // desc returns a description of the action. - // It should set all fields of actionDesc except Key, which - // the registry will set. - desc() actionDesc - - // setTracingState set's the action's tracing.State. - setTracingState(*tracing.State) -} - -// An actionDesc is a description of an Action. -// It is used to provide a list of registered actions. -type actionDesc struct { - Key string `json:"key"` // full key from the registry - Name string `json:"name"` - Description string `json:"description"` - Metadata map[string]any `json:"metadata"` - InputSchema *jsonschema.Schema `json:"inputSchema"` - OutputSchema *jsonschema.Schema `json:"outputSchema"` -} - -func (d1 actionDesc) equal(d2 actionDesc) bool { - return d1.Key == d2.Key && - d1.Name == d2.Name && - d1.Description == d2.Description && - maps.Equal(d1.Metadata, d2.Metadata) -} - -func (a *Action[I, O, S]) desc() actionDesc { - ad := actionDesc{ +// Desc returns a description of the action. +func (a *Action[I, O, S]) Desc() common.ActionDesc { + ad := common.ActionDesc{ Name: a.name, Description: a.description, Metadata: a.metadata, @@ -292,12 +253,14 @@ func (a *Action[I, O, S]) desc() actionDesc { return ad } -func inferJSONSchema(x any) (s *jsonschema.Schema) { - r := jsonschema.Reflector{ - DoNotReference: true, +// LookupActionFor returns the action for the given key in the global registry, +// or nil if there is none. +// It panics if the action is of the wrong type. +func LookupActionFor[In, Out, Stream any](typ atype.ActionType, provider, name string) *Action[In, Out, Stream] { + key := fmt.Sprintf("/%s/%s/%s", typ, provider, name) + a := registry.Global.LookupAction(key) + if a == nil { + return nil } - s = r.Reflect(x) - // TODO: Unwind this change once Monaco Editor supports newer than JSON schema draft-07. - s.Version = "" - return s + return a.(*Action[In, Out, Stream]) } diff --git a/go/core/action_test.go b/go/core/action_test.go index a73b0fd49..c1dad2722 100644 --- a/go/core/action_test.go +++ b/go/core/action_test.go @@ -21,6 +21,8 @@ import ( "testing" "github.com/firebase/genkit/go/internal/atype" + "github.com/firebase/genkit/go/internal/common" + "github.com/firebase/genkit/go/internal/registry" ) func inc(_ context.Context, x int) (int, error) { @@ -28,7 +30,7 @@ func inc(_ context.Context, x int) (int, error) { } func TestActionRun(t *testing.T) { - a := newAction("inc", atype.Custom, nil, inc) + a := NewAction("inc", atype.Custom, nil, inc) got, err := a.Run(context.Background(), 3, nil) if err != nil { t.Fatal(err) @@ -39,10 +41,10 @@ func TestActionRun(t *testing.T) { } func TestActionRunJSON(t *testing.T) { - a := newAction("inc", atype.Custom, nil, inc) + a := NewAction("inc", atype.Custom, nil, inc) input := []byte("3") want := []byte("4") - got, err := a.runJSON(context.Background(), input, nil) + got, err := a.RunJSON(context.Background(), input, nil) if err != nil { t.Fatal(err) } @@ -53,7 +55,7 @@ func TestActionRunJSON(t *testing.T) { func TestNewAction(t *testing.T) { // Verify that struct{} can occur in the function signature. - _ = newAction("f", atype.Custom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil }) + _ = NewAction("f", atype.Custom, nil, func(context.Context, int) (struct{}, error) { return struct{}{}, nil }) } // count streams the numbers from 0 to n-1, then returns n. @@ -70,7 +72,7 @@ func count(ctx context.Context, n int, cb func(context.Context, int) error) (int func TestActionStreaming(t *testing.T) { ctx := context.Background() - a := newStreamingAction("count", atype.Custom, nil, count) + a := NewStreamingAction("count", atype.Custom, nil, count) const n = 3 // Non-streaming. @@ -103,12 +105,12 @@ func TestActionStreaming(t *testing.T) { func TestActionTracing(t *testing.T) { ctx := context.Background() const actionName = "TestTracing-inc" - a := newAction(actionName, atype.Custom, nil, inc) + a := NewAction(actionName, atype.Custom, nil, inc) if _, err := a.Run(context.Background(), 3, nil); err != nil { t.Fatal(err) } // The dev TraceStore is registered by Init, called from TestMain. - ts := globalRegistry.lookupTraceStore(EnvironmentDev) + ts := registry.Global.LookupTraceStore(common.EnvironmentDev) tds, _, err := ts.List(ctx, nil) if err != nil { t.Fatal(err) diff --git a/go/core/core.go b/go/core/core.go index 53ae31af9..9f398e362 100644 --- a/go/core/core.go +++ b/go/core/core.go @@ -19,7 +19,31 @@ // Run the Go code generator on the file just created. //go:generate go run ../internal/cmd/jsonschemagen -outdir .. -config schemas.config ../../genkit-tools/genkit-schema.json core -// Package core implements Genkit actions, flows and other essential machinery. +// Package core implements Genkit actions and other essential machinery. // This package is primarily intended for Genkit internals and for plugins. // Genkit applications should use the genkit package. package core + +import ( + "context" + + "github.com/firebase/genkit/go/core/tracing" + "github.com/firebase/genkit/go/internal/common" + "github.com/firebase/genkit/go/internal/registry" + sdktrace "go.opentelemetry.io/otel/sdk/trace" +) + +// RegisterTraceStore uses the given trace.Store to record traces in the prod environment. +// (A trace.Store that writes to the local filesystem is always installed in the dev environment.) +// The returned function should be called before the program ends to ensure that +// all pending data is stored. +// RegisterTraceStore panics if called more than once. +func RegisterTraceStore(ts tracing.Store) (shutdown func(context.Context) error) { + registry.Global.RegisterTraceStore(common.EnvironmentProd, ts) + return registry.Global.TracingState().AddTraceStoreBatch(ts) +} + +// RegisterSpanProcessor registers an OpenTelemetry SpanProcessor for tracing. +func RegisterSpanProcessor(sp sdktrace.SpanProcessor) { + registry.Global.RegisterSpanProcessor(sp) +} diff --git a/go/core/file_flow_state_store.go b/go/core/file_flow_state_store.go index a4b3a7b17..1caafcb84 100644 --- a/go/core/file_flow_state_store.go +++ b/go/core/file_flow_state_store.go @@ -20,6 +20,7 @@ import ( "path/filepath" "github.com/firebase/genkit/go/internal" + "github.com/firebase/genkit/go/internal/common" ) // A FileFlowStateStore is a FlowStateStore that writes flowStates to files. @@ -36,9 +37,9 @@ func NewFileFlowStateStore(dir string) (*FileFlowStateStore, error) { return &FileFlowStateStore{dir: dir}, nil } -func (s *FileFlowStateStore) Save(ctx context.Context, id string, fs flowStater) error { - fs.lock() - defer fs.unlock() +func (s *FileFlowStateStore) Save(ctx context.Context, id string, fs common.FlowStater) error { + fs.Lock() + defer fs.Unlock() return internal.WriteJSONFile(filepath.Join(s.dir, internal.Clean(id)), fs) } diff --git a/go/core/flow_state_store.go b/go/core/flow_state_store.go index 497a8d2ff..27617e2a4 100644 --- a/go/core/flow_state_store.go +++ b/go/core/flow_state_store.go @@ -14,14 +14,18 @@ package core -import "context" +import ( + "context" + + "github.com/firebase/genkit/go/internal/common" +) // A FlowStateStore stores flow states. // Every flow state has a unique string identifier. // A durable FlowStateStore is necessary for durable flows. type FlowStateStore interface { // Save saves the FlowState to the store, overwriting an existing one. - Save(ctx context.Context, id string, fs flowStater) error + Save(ctx context.Context, id string, fs common.FlowStater) error // Load reads the FlowState with the given ID from the store. // It returns an error that is fs.ErrNotExist if there isn't one. // pfs must be a pointer to a flowState[I, O] of the correct type. @@ -31,5 +35,5 @@ type FlowStateStore interface { // nopFlowStateStore is a FlowStateStore that does nothing. type nopFlowStateStore struct{} -func (nopFlowStateStore) Save(ctx context.Context, id string, fs flowStater) error { return nil } -func (nopFlowStateStore) Load(ctx context.Context, id string, pfs any) error { return nil } +func (nopFlowStateStore) Save(ctx context.Context, id string, fs common.FlowStater) error { return nil } +func (nopFlowStateStore) Load(ctx context.Context, id string, pfs any) error { return nil } diff --git a/go/core/conformance_test.go b/go/genkit/conformance_test.go similarity index 94% rename from go/core/conformance_test.go rename to go/genkit/conformance_test.go index 9236682ab..3310b4179 100644 --- a/go/core/conformance_test.go +++ b/go/genkit/conformance_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package core +package genkit import ( "cmp" @@ -27,7 +27,10 @@ import ( "testing" "time" + "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/internal" + "github.com/firebase/genkit/go/internal/common" + "github.com/firebase/genkit/go/internal/registry" "golang.org/x/exp/maps" ) @@ -69,7 +72,7 @@ func (c *command) run(ctx context.Context, input string) (string, error) { case c.Append != nil: return input + *c.Append, nil case c.Run != nil: - return InternalRun(ctx, c.Run.Name, func() (string, error) { + return Run(ctx, c.Run.Name, func() (string, error) { return c.Run.Command.run(ctx, input) }) default: @@ -92,7 +95,7 @@ func TestFlowConformance(t *testing.T) { t.Fatal(err) } // Each test uses its own registry to avoid interference. - r, err := newRegistry() + r, err := registry.New() if err != nil { t.Fatal(err) } @@ -113,7 +116,7 @@ func TestFlowConformance(t *testing.T) { if test.Trace == nil { return } - ts := r.lookupTraceStore(EnvironmentDev) + ts := r.LookupTraceStore(common.EnvironmentDev) var gotTrace any if err := ts.LoadAny(resp.Telemetry.TraceID, &gotTrace); err != nil { t.Fatal(err) @@ -128,8 +131,8 @@ func TestFlowConformance(t *testing.T) { } // flowFunction returns a function that runs the list of commands. -func flowFunction(commands []command) Func[string, string, struct{}] { - return func(ctx context.Context, input string, cb NoStream) (string, error) { +func flowFunction(commands []command) core.Func[string, string, struct{}] { + return func(ctx context.Context, input string, cb noStream) (string, error) { result := input var err error for i, cmd := range commands { diff --git a/go/core/flow.go b/go/genkit/flow.go similarity index 80% rename from go/core/flow.go rename to go/genkit/flow.go index cb407203a..2f4767ec7 100644 --- a/go/core/flow.go +++ b/go/genkit/flow.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package core +package genkit import ( "context" @@ -24,10 +24,14 @@ import ( "sync" "time" + "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal" "github.com/firebase/genkit/go/internal/atype" + "github.com/firebase/genkit/go/internal/common" + "github.com/firebase/genkit/go/internal/metrics" + "github.com/firebase/genkit/go/internal/registry" "github.com/google/uuid" "github.com/invopop/jsonschema" otrace "go.opentelemetry.io/otel/trace" @@ -81,8 +85,6 @@ import ( // the flow invokes a callback repeatedly. When streaming is complete, the flow // returns a final result in the usual way. // -// A flow that doesn't support streaming can use [NoStream] as its third type parameter. -// // Streaming is only supported for the "start" flow instruction. Currently there is // no way to schedule or resume a flow with streaming. @@ -90,39 +92,63 @@ import ( // A Flow[In, Out, Stream] represents a function from In to Out. The Stream parameter is for // flows that support streaming: providing their results incrementally. type Flow[In, Out, Stream any] struct { - name string // The last component of the flow's key in the registry. - fn Func[In, Out, Stream] // The function to run. - stateStore FlowStateStore // Where FlowStates are stored, to support resumption. - tstate *tracing.State // set from the action when the flow is defined - inputSchema *jsonschema.Schema // Schema of the input to the flow - outputSchema *jsonschema.Schema // Schema of the output out of the flow + name string // The last component of the flow's key in the registry. + fn core.Func[In, Out, Stream] // The function to run. + stateStore core.FlowStateStore // Where FlowStates are stored, to support resumption. + tstate *tracing.State // set from the action when the flow is defined + inputSchema *jsonschema.Schema // Schema of the input to the flow + outputSchema *jsonschema.Schema // Schema of the output out of the flow // TODO(jba): scheduler // TODO(jba): experimentalDurable // TODO(jba): authPolicy // TODO(jba): middleware } -// InternalDefineFlow is for use by genkit.DefineFlow exclusively. -// It is not subject to any backwards compatibility guarantees. -func InternalDefineFlow[In, Out, Stream any](name string, fn Func[In, Out, Stream]) *Flow[In, Out, Stream] { - return defineFlow(globalRegistry, name, fn) -} +type noStream = func(context.Context, struct{}) error -func defineFlow[In, Out, Stream any](r *registry, name string, fn Func[In, Out, Stream]) *Flow[In, Out, Stream] { +// DefineFlow creates a Flow that runs fn, and registers it as an action. +// +// fn takes an input of type In and returns an output of type Out. +func DefineFlow[In, Out any]( + name string, + fn func(ctx context.Context, input In) (Out, error), +) *Flow[In, Out, struct{}] { + return defineFlow(registry.Global, name, core.Func[In, Out, struct{}]( + func(ctx context.Context, input In, cb func(ctx context.Context, _ struct{}) error) (Out, error) { + return fn(ctx, input) + })) +} + +// DefineStreamingFlow creates a streaming Flow that runs fn, and registers it as an action. +// +// fn takes an input of type In and returns an output of type Out, optionally +// streaming values of type Stream incrementally by invoking a callback. +// +// If the function supports streaming and the callback is non-nil, it should +// stream the results by invoking the callback periodically, ultimately returning +// with a final return value that includes all the streamed data. +// Otherwise, it should ignore the callback and just return a result. +func DefineStreamingFlow[In, Out, Stream any]( + name string, + fn func(ctx context.Context, input In, callback func(context.Context, Stream) error) (Out, error), +) *Flow[In, Out, Stream] { + return defineFlow(registry.Global, name, core.Func[In, Out, Stream](fn)) +} + +func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.Func[In, Out, Stream]) *Flow[In, Out, Stream] { var i In var o Out f := &Flow[In, Out, Stream]{ name: name, fn: fn, - inputSchema: inferJSONSchema(i), - outputSchema: inferJSONSchema(o), + inputSchema: internal.InferJSONSchema(i), + outputSchema: internal.InferJSONSchema(o), // TODO(jba): set stateStore? } a := f.action() - r.registerAction(a) - // TODO(jba): this is a roundabout way to transmit the tracing state. Is there a cleaner way? - f.tstate = a.tstate - r.registerFlow(f) + r.RegisterAction(atype.Flow, a) + f.tstate = r.TracingState() + r.RegisterFlow(f) return f } @@ -205,19 +231,11 @@ func newFlowState[In, Out any](id, name string, input In) *flowState[In, Out] { } } -// flowStater is the common type of all flowState[I, O] types. -type flowStater interface { - isFlowState() - lock() - unlock() - cache() map[string]json.RawMessage -} - -// isFlowState implements flowStater. -func (fs *flowState[In, Out]) isFlowState() {} -func (fs *flowState[In, Out]) lock() { fs.mu.Lock() } -func (fs *flowState[In, Out]) unlock() { fs.mu.Unlock() } -func (fs *flowState[In, Out]) cache() map[string]json.RawMessage { return fs.Cache } +// flowState implements common.FlowStater. +func (fs *flowState[In, Out]) IsFlowState() {} +func (fs *flowState[In, Out]) Lock() { fs.mu.Lock() } +func (fs *flowState[In, Out]) Unlock() { fs.mu.Unlock() } +func (fs *flowState[In, Out]) GetCache() map[string]json.RawMessage { return fs.Cache } // An operation describes the state of a Flow that may still be in progress. type operation[Out any] struct { @@ -251,21 +269,21 @@ type FlowResult[Out any] struct { // FlowResult is called FlowResponse in the javascript. // action creates an action for the flow. See the comment at the top of this file for more information. -func (f *Flow[In, Out, Stream]) action() *Action[*flowInstruction[In], *flowState[In, Out], Stream] { +func (f *Flow[In, Out, Stream]) action() *core.Action[*flowInstruction[In], *flowState[In, Out], Stream] { metadata := map[string]any{ "inputSchema": f.inputSchema, "outputSchema": f.outputSchema, } cback := func(ctx context.Context, inst *flowInstruction[In], cb func(context.Context, Stream) error) (*flowState[In, Out], error) { tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true") - return f.runInstruction(ctx, inst, streamingCallback[Stream](cb)) + return f.runInstruction(ctx, inst, common.StreamingCallback[Stream](cb)) } - return newStreamingAction(f.name, atype.Flow, metadata, cback) + return core.NewStreamingAction(f.name, atype.Flow, metadata, cback) } // runInstruction performs one of several actions on a flow, as determined by msg. // (Called runEnvelope in the js.) -func (f *Flow[In, Out, Stream]) runInstruction(ctx context.Context, inst *flowInstruction[In], cb streamingCallback[Stream]) (*flowState[In, Out], error) { +func (f *Flow[In, Out, Stream]) runInstruction(ctx context.Context, inst *flowInstruction[In], cb common.StreamingCallback[Stream]) (*flowState[In, Out], error) { switch { case inst.Start != nil: // TODO(jba): pass msg.Start.Labels. @@ -285,28 +303,22 @@ func (f *Flow[In, Out, Stream]) runInstruction(ctx context.Context, inst *flowIn } } -// flow is the type that all Flow[In, Out, Stream] have in common. -type flow interface { - Name() string - - // runJSON uses encoding/json to unmarshal the input, - // calls Flow.start, then returns the marshaled result. - runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) -} +// The following methods make Flow[I, O, S] implement the flow interface, define in servers.go. +// Name returns the name that the flow was defined with. func (f *Flow[In, Out, Stream]) Name() string { return f.name } -func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) { +func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb common.StreamingCallback[json.RawMessage]) (json.RawMessage, error) { // Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process. - if err := validateJSON(input, f.inputSchema); err != nil { - return nil, &httpError{http.StatusBadRequest, err} + if err := internal.ValidateJSON(input, f.inputSchema); err != nil { + return nil, &internal.HTTPError{Code: http.StatusBadRequest, Err: err} } var in In if err := json.Unmarshal(input, &in); err != nil { - return nil, &httpError{http.StatusBadRequest, err} + return nil, &internal.HTTPError{Code: http.StatusBadRequest, Err: err} } // If there is a callback, wrap it to turn an S into a json.RawMessage. - var callback streamingCallback[Stream] + var callback common.StreamingCallback[Stream] if cb != nil { callback = func(ctx context.Context, s Stream) error { bytes, err := json.Marshal(s) @@ -334,7 +346,7 @@ func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessa } // start starts executing the flow with the given input. -func (f *Flow[In, Out, Stream]) start(ctx context.Context, input In, cb streamingCallback[Stream]) (_ *flowState[In, Out], err error) { +func (f *Flow[In, Out, Stream]) start(ctx context.Context, input In, cb common.StreamingCallback[Stream]) (_ *flowState[In, Out], err error) { flowID, err := generateFlowID() if err != nil { return nil, err @@ -351,7 +363,7 @@ func (f *Flow[In, Out, Stream]) start(ctx context.Context, input In, cb streamin // // This function corresponds to Flow.executeSteps in the js, but does more: // it creates the flowContext and saves the state. -func (f *Flow[In, Out, Stream]) execute(ctx context.Context, state *flowState[In, Out], dispatchType string, cb streamingCallback[Stream]) { +func (f *Flow[In, Out, Stream]) execute(ctx context.Context, state *flowState[In, Out], dispatchType string, cb common.StreamingCallback[Stream]) { fctx := newFlowContext(state, f.stateStore, f.tstate) defer func() { if err := fctx.finish(ctx); err != nil { @@ -381,14 +393,14 @@ func (f *Flow[In, Out, Stream]) execute(ctx context.Context, state *flowState[In // TODO(jba): If input is missing, get it from state.input and overwrite metadata.input. start := time.Now() var err error - if err = validateValue(input, f.inputSchema); err != nil { + if err = internal.ValidateValue(input, f.inputSchema); err != nil { err = fmt.Errorf("invalid input: %w", err) } var output Out if err == nil { output, err = f.fn(ctx, input, cb) if err == nil { - if err = validateValue(output, f.outputSchema); err != nil { + if err = internal.ValidateValue(output, f.outputSchema); err != nil { err = fmt.Errorf("invalid output: %w", err) } } @@ -400,11 +412,11 @@ func (f *Flow[In, Out, Stream]) execute(ctx context.Context, state *flowState[In "path", tracing.SpanPath(ctx), "err", err.Error(), ) - writeFlowFailure(ctx, f.name, latency, err) + metrics.WriteFlowFailure(ctx, f.name, latency, err) tracing.SetCustomMetadataAttr(ctx, "flow:state", "error") } else { logger.FromContext(ctx).Info("flow succeeded", "path", tracing.SpanPath(ctx)) - writeFlowSuccess(ctx, f.name, latency) + metrics.WriteFlowSuccess(ctx, f.name, latency) tracing.SetCustomMetadataAttr(ctx, "flow:state", "done") } @@ -441,7 +453,7 @@ func generateFlowID() (string, error) { // in a context.Context so it can be accessed from within the currrently active flow. type flowContext[I, O any] struct { state *flowState[I, O] - stateStore FlowStateStore + stateStore core.FlowStateStore tstate *tracing.State mu sync.Mutex seenSteps map[string]int // number of times each name appears, to avoid duplicate names @@ -451,11 +463,11 @@ type flowContext[I, O any] struct { // flowContexter is the type of all flowContext[I, O]. type flowContexter interface { uniqueStepName(string) string - stater() flowStater + stater() common.FlowStater tracingState() *tracing.State } -func newFlowContext[I, O any](state *flowState[I, O], store FlowStateStore, tstate *tracing.State) *flowContext[I, O] { +func newFlowContext[I, O any](state *flowState[I, O], store core.FlowStateStore, tstate *tracing.State) *flowContext[I, O] { return &flowContext[I, O]{ state: state, stateStore: store, @@ -463,7 +475,7 @@ func newFlowContext[I, O any](state *flowState[I, O], store FlowStateStore, tsta seenSteps: map[string]int{}, } } -func (fc *flowContext[I, O]) stater() flowStater { return fc.state } +func (fc *flowContext[I, O]) stater() common.FlowStater { return fc.state } func (fc *flowContext[I, O]) tracingState() *tracing.State { return fc.tstate } // finish is called at the end of a flow execution. @@ -489,9 +501,14 @@ func (fc *flowContext[I, O]) uniqueStepName(name string) string { var flowContextKey = internal.NewContextKey[flowContexter]() -// InternalRun is for use by genkit.Run exclusively. -// It is not subject to any backwards compatibility guarantees. -func InternalRun[Out any](ctx context.Context, name string, f func() (Out, error)) (Out, error) { +// Run runs the function f in the context of the current flow +// and returns what f returns. +// It returns an error if no flow is active. +// +// Each call to Run results in a new step in the flow. +// A step has its own span in the trace, and its result is cached so that if the flow +// is restarted, f will not be called a second time. +func Run[Out any](ctx context.Context, name string, f func() (Out, error)) (Out, error) { // from js/flow/src/steps.ts fc := flowContextKey.FromContext(ctx) if fc == nil { @@ -511,9 +528,9 @@ func InternalRun[Out any](ctx context.Context, name string, f func() (Out, error // happen because every step has a unique cache key. // TODO(jba): don't memoize a nested flow (see context.ts) fs := fc.stater() - fs.lock() - j, ok := fs.cache()[uName] - fs.unlock() + fs.Lock() + j, ok := fs.GetCache()[uName] + fs.Unlock() if ok { var t Out if err := json.Unmarshal(j, &t); err != nil { @@ -530,9 +547,9 @@ func InternalRun[Out any](ctx context.Context, name string, f func() (Out, error if err != nil { return internal.Zero[Out](), err } - fs.lock() - fs.cache()[uName] = json.RawMessage(bytes) - fs.unlock() + fs.Lock() + fs.GetCache()[uName] = json.RawMessage(bytes) + fs.Unlock() tracing.SetCustomMetadataAttr(ctx, "flow:state", "run") return t, nil }) diff --git a/go/core/flow_test.go b/go/genkit/flow_test.go similarity index 88% rename from go/core/flow_test.go rename to go/genkit/flow_test.go index 843da267a..02bf0df0c 100644 --- a/go/core/flow_test.go +++ b/go/genkit/flow_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package core +package genkit import ( "context" @@ -21,17 +21,19 @@ import ( "slices" "testing" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/internal/registry" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" ) -func incFlow(_ context.Context, i int, _ NoStream) (int, error) { +func incFlow(_ context.Context, i int, _ noStream) (int, error) { return i + 1, nil } func TestFlowStart(t *testing.T) { - f := InternalDefineFlow("inc", incFlow) - ss, err := NewFileFlowStateStore(t.TempDir()) + f := DefineStreamingFlow("inc", incFlow) + ss, err := core.NewFileFlowStateStore(t.TempDir()) if err != nil { t.Fatal(err) } @@ -62,12 +64,12 @@ func TestFlowRun(t *testing.T) { return n, nil } - flow := InternalDefineFlow("run", func(ctx context.Context, s string, _ NoStream) ([]int, error) { - g1, err := InternalRun(ctx, "s1", stepf) + flow := DefineFlow("run", func(ctx context.Context, s string) ([]int, error) { + g1, err := Run(ctx, "s1", stepf) if err != nil { return nil, err } - g2, err := InternalRun(ctx, "s2", stepf) + g2, err := Run(ctx, "s2", stepf) if err != nil { return nil, err } @@ -89,7 +91,7 @@ func TestFlowRun(t *testing.T) { } func TestRunFlow(t *testing.T) { - reg, err := newRegistry() + reg, err := registry.New() if err != nil { t.Fatal(err) } diff --git a/go/core/gen.go b/go/genkit/gen.go similarity index 98% rename from go/core/gen.go rename to go/genkit/gen.go index 2e9864499..325b101ee 100644 --- a/go/core/gen.go +++ b/go/genkit/gen.go @@ -14,7 +14,7 @@ // This file was generated by jsonschemagen. DO NOT EDIT. -package core +package genkit import "github.com/firebase/genkit/go/core/tracing" diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 29ce9e096..0eb933d8a 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -17,9 +17,16 @@ package genkit import ( "context" + "fmt" + "log/slog" "net/http" + "os" + "os/signal" + "sync" + "syscall" - "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/internal/common" + "github.com/firebase/genkit/go/internal/registry" ) // Options are options to [Init]. @@ -50,60 +57,67 @@ type Options struct { // Thus Init(nil) will start a dev server in the "dev" environment, will always start // a flow server, and will pause execution until the flow server terminates. func Init(ctx context.Context, opts *Options) error { - return core.InternalInit(ctx, (*core.Options)(opts)) -} + if opts == nil { + opts = &Options{} + } + registry.Global.Freeze() -// DefineFlow creates a Flow that runs fn, and registers it as an action. -// -// fn takes an input of type In and returns an output of type Out. -func DefineFlow[In, Out any]( - name string, - fn func(ctx context.Context, input In) (Out, error), -) *core.Flow[In, Out, struct{}] { - return core.InternalDefineFlow(name, core.Func[In, Out, struct{}](func(ctx context.Context, input In, cb func(ctx context.Context, _ struct{}) error) (Out, error) { - return fn(ctx, input) - })) -} + var mu sync.Mutex + var servers []*http.Server + var wg sync.WaitGroup + errCh := make(chan error, 2) -// DefineStreamingFlow creates a streaming Flow that runs fn, and registers it as an action. -// -// fn takes an input of type In and returns an output of type Out, optionally -// streaming values of type Stream incrementally by invoking a callback. -// Pass [NoStream] for functions that do not support streaming. -// -// If the function supports streaming and the callback is non-nil, it should -// stream the results by invoking the callback periodically, ultimately returning -// with a final return value. Otherwise, it should ignore the callback and -// just return a result. -func DefineStreamingFlow[In, Out, Stream any]( - name string, - fn func(ctx context.Context, input In, callback func(context.Context, Stream) error) (Out, error), -) *core.Flow[In, Out, Stream] { - return core.InternalDefineFlow(name, core.Func[In, Out, Stream](fn)) -} + if common.CurrentEnvironment() == common.EnvironmentDev { + wg.Add(1) + go func() { + defer wg.Done() + s := startReflectionServer(errCh) + mu.Lock() + servers = append(servers, s) + mu.Unlock() + }() + } -// Run runs the function f in the context of the current flow -// and returns what f returns. -// It returns an error if no flow is active. -// -// Each call to Run results in a new step in the flow. -// A step has its own span in the trace, and its result is cached so that if the flow -// is restarted, f will not be called a second time. -func Run[Out any](ctx context.Context, name string, f func() (Out, error)) (Out, error) { - return core.InternalRun(ctx, name, f) -} + if opts.FlowAddr != "-" { + wg.Add(1) + go func() { + defer wg.Done() + s := startFlowServer(opts.FlowAddr, opts.Flows, errCh) + mu.Lock() + servers = append(servers, s) + mu.Unlock() + }() + } -// NewFlowServeMux constructs a [net/http.ServeMux]. -// If flows is non-empty, the each of the named flows is registered as a route. -// Otherwise, all defined flows are registered. -// All routes take a single query parameter, "stream", which if true will stream the -// flow's results back to the client. (Not all flows support streaming, however.) -// -// To use the returned ServeMux as part of a server with other routes, either add routes -// to it, or install it as part of another ServeMux, like so: -// -// mainMux := http.NewServeMux() -// mainMux.Handle("POST /flow/", http.StripPrefix("/flow/", NewFlowServeMux())) -func NewFlowServeMux(flows []string) *http.ServeMux { - return core.NewFlowServeMux(flows) + serverStartCh := make(chan struct{}) + go func() { + wg.Wait() + close(serverStartCh) + }() + + // It will block here until either all servers start up or there is an error in starting one. + select { + case <-serverStartCh: + slog.Info("all servers started successfully") + case err := <-errCh: + return fmt.Errorf("failed to start servers: %w", err) + case <-ctx.Done(): + return ctx.Err() + } + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + // It will block here (i.e. servers will run) until we get an interrupt signal. + select { + case sig := <-sigCh: + slog.Info("received signal, initiating shutdown", "signal", sig) + case err := <-errCh: + slog.Error("server error", "err", err) + return err + case <-ctx.Done(): + slog.Info("context cancelled, initiating shutdown") + } + + return shutdownServers(servers) } diff --git a/go/genkit/genkit_test.go b/go/genkit/genkit_test.go index 6cedc1409..adc818cdf 100644 --- a/go/genkit/genkit_test.go +++ b/go/genkit/genkit_test.go @@ -17,15 +17,13 @@ package genkit import ( "context" "testing" - - "github.com/firebase/genkit/go/core" ) func TestStreamFlow(t *testing.T) { f := DefineStreamingFlow("count", count) iter := f.Stream(context.Background(), 2) want := 0 - iter(func(val *core.StreamFlowValue[int, int], err error) bool { + iter(func(val *StreamFlowValue[int, int], err error) bool { if err != nil { t.Fatal(err) } diff --git a/go/core/servers.go b/go/genkit/servers.go similarity index 73% rename from go/core/servers.go rename to go/genkit/servers.go index 90a5ac299..2827eb0e4 100644 --- a/go/core/servers.go +++ b/go/genkit/servers.go @@ -20,7 +20,7 @@ // The production server has a route for each flow. It // is intended for production deployments. -package core +package genkit import ( "context" @@ -41,84 +41,19 @@ import ( "github.com/firebase/genkit/go/core/logger" "github.com/firebase/genkit/go/core/tracing" + "github.com/firebase/genkit/go/internal" + "github.com/firebase/genkit/go/internal/common" + "github.com/firebase/genkit/go/internal/registry" "go.opentelemetry.io/otel/trace" ) -// InternalInit is for use by the genkit package only. -// It is not subject to compatibility guarantees. -func InternalInit(ctx context.Context, opts *Options) error { - if opts == nil { - opts = &Options{} - } - globalRegistry.freeze() - - var mu sync.Mutex - var servers []*http.Server - var wg sync.WaitGroup - errCh := make(chan error, 2) - - if currentEnvironment() == EnvironmentDev { - wg.Add(1) - go func() { - defer wg.Done() - s := startReflectionServer(errCh) - mu.Lock() - servers = append(servers, s) - mu.Unlock() - }() - } - - if opts.FlowAddr != "-" { - wg.Add(1) - go func() { - defer wg.Done() - s := startFlowServer(opts.FlowAddr, opts.Flows, errCh) - mu.Lock() - servers = append(servers, s) - mu.Unlock() - }() - } - - serverStartCh := make(chan struct{}) - go func() { - wg.Wait() - close(serverStartCh) - }() - - // It will block here until either all servers start up or there is an error in starting one. - select { - case <-serverStartCh: - slog.Info("all servers started successfully") - case err := <-errCh: - return fmt.Errorf("failed to start servers: %w", err) - case <-ctx.Done(): - return ctx.Err() - } - - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) - - // It will block here (i.e. servers will run) until we get an interrupt signal. - select { - case sig := <-sigCh: - slog.Info("received signal, initiating shutdown", "signal", sig) - case err := <-errCh: - slog.Error("server error", "err", err) - return err - case <-ctx.Done(): - slog.Info("context cancelled, initiating shutdown") - } - - return shutdownServers(servers) -} - // startReflectionServer starts the Reflection API server listening at the // value of the environment variable GENKIT_REFLECTION_PORT for the port, // or ":3100" if it is empty. func startReflectionServer(errCh chan<- error) *http.Server { slog.Info("starting reflection server") addr := serverAddress("", "GENKIT_REFLECTION_PORT", "127.0.0.1:3100") - mux := newDevServeMux(globalRegistry) + mux := newDevServeMux(registry.Global) return startServer(addr, mux, errCh) } @@ -135,6 +70,15 @@ func startFlowServer(addr string, flows []string, errCh chan<- error) *http.Serv return startServer(addr, mux, errCh) } +// flow is the type that all Flow[In, Out, Stream] have in common. +type flow interface { + Name() string + + // runJSON uses encoding/json to unmarshal the input, + // calls Flow.start, then returns the marshaled result. + runJSON(ctx context.Context, input json.RawMessage, cb common.StreamingCallback[json.RawMessage]) (json.RawMessage, error) +} + // startServer starts an HTTP server listening on the address. // It returns the server an func startServer(addr string, handler http.Handler, errCh chan<- error) *http.Server { @@ -188,22 +132,11 @@ func shutdownServers(servers []*http.Server) error { return nil } -// Options are options to [InternalInit]. -type Options struct { - // If "-", do not start a FlowServer. - // Otherwise, start a FlowServer on the given address, or the - // default if empty. - FlowAddr string - // The names of flows to serve. - // If empty, all registered flows are served. - Flows []string -} - type devServer struct { - reg *registry + reg *registry.Registry } -func newDevServeMux(r *registry) *http.ServeMux { +func newDevServeMux(r *registry.Registry) *http.ServeMux { mux := http.NewServeMux() s := &devServer{r} handle(mux, "GET /api/__health", func(w http.ResponseWriter, _ *http.Request) error { @@ -228,7 +161,7 @@ func (s *devServer) handleRunAction(w http.ResponseWriter, r *http.Request) erro } defer r.Body.Close() if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - return &httpError{http.StatusBadRequest, err} + return &internal.HTTPError{Code: http.StatusBadRequest, Err: err} } stream := false if s := r.FormValue("stream"); s != "" { @@ -241,7 +174,7 @@ func (s *devServer) handleRunAction(w http.ResponseWriter, r *http.Request) erro logger.FromContext(ctx).Debug("running action", "key", body.Key, "stream", stream) - var callback streamingCallback[json.RawMessage] + var callback common.StreamingCallback[json.RawMessage] if stream { // Stream results are newline-separated JSON. callback = func(ctx context.Context, msg json.RawMessage) error { @@ -271,16 +204,16 @@ type telemetry struct { TraceID string `json:"traceId"` } -func runAction(ctx context.Context, reg *registry, key string, input json.RawMessage, cb streamingCallback[json.RawMessage]) (*runActionResponse, error) { - action := reg.lookupAction(key) +func runAction(ctx context.Context, reg *registry.Registry, key string, input json.RawMessage, cb common.StreamingCallback[json.RawMessage]) (*runActionResponse, error) { + action := reg.LookupAction(key) if action == nil { - return nil, &httpError{http.StatusNotFound, fmt.Errorf("no action with key %q", key)} + return nil, &internal.HTTPError{Code: http.StatusNotFound, Err: fmt.Errorf("no action with key %q", key)} } var traceID string - output, err := tracing.RunInNewSpan(ctx, reg.tstate, "dev-run-action-wrapper", "", true, input, func(ctx context.Context, input json.RawMessage) (json.RawMessage, error) { + output, err := tracing.RunInNewSpan(ctx, reg.TracingState(), "dev-run-action-wrapper", "", true, input, func(ctx context.Context, input json.RawMessage) (json.RawMessage, error) { tracing.SetCustomMetadataAttr(ctx, "genkit-dev-internal", "true") traceID = trace.SpanContextFromContext(ctx).TraceID().String() - return action.runJSON(ctx, input, cb) + return action.RunJSON(ctx, input, cb) }) if err != nil { return nil, err @@ -293,8 +226,8 @@ func runAction(ctx context.Context, reg *registry, key string, input json.RawMes // handleListActions lists all the registered actions. func (s *devServer) handleListActions(w http.ResponseWriter, r *http.Request) error { - descs := s.reg.listActions() - descMap := map[string]actionDesc{} + descs := s.reg.ListActions() + descMap := map[string]common.ActionDesc{} for _, d := range descs { descMap[d.Key] = d } @@ -304,14 +237,14 @@ func (s *devServer) handleListActions(w http.ResponseWriter, r *http.Request) er // handleGetTrace returns a single trace from a TraceStore. func (s *devServer) handleGetTrace(w http.ResponseWriter, r *http.Request) error { env := r.PathValue("env") - ts := s.reg.lookupTraceStore(Environment(env)) + ts := s.reg.LookupTraceStore(common.Environment(env)) if ts == nil { - return &httpError{http.StatusNotFound, fmt.Errorf("no TraceStore for environment %q", env)} + return &internal.HTTPError{Code: http.StatusNotFound, Err: fmt.Errorf("no TraceStore for environment %q", env)} } tid := r.PathValue("traceID") td, err := ts.Load(r.Context(), tid) if errors.Is(err, fs.ErrNotExist) { - return &httpError{http.StatusNotFound, fmt.Errorf("no %s trace with ID %q", env, tid)} + return &internal.HTTPError{Code: http.StatusNotFound, Err: fmt.Errorf("no %s trace with ID %q", env, tid)} } if err != nil { return err @@ -322,22 +255,22 @@ func (s *devServer) handleGetTrace(w http.ResponseWriter, r *http.Request) error // handleListTraces returns a list of traces from a TraceStore. func (s *devServer) handleListTraces(w http.ResponseWriter, r *http.Request) error { env := r.PathValue("env") - ts := s.reg.lookupTraceStore(Environment(env)) + ts := s.reg.LookupTraceStore(common.Environment(env)) if ts == nil { - return &httpError{http.StatusNotFound, fmt.Errorf("no TraceStore for environment %q", env)} + return &internal.HTTPError{Code: http.StatusNotFound, Err: fmt.Errorf("no TraceStore for environment %q", env)} } limit := 0 if lim := r.FormValue("limit"); lim != "" { var err error limit, err = strconv.Atoi(lim) if err != nil { - return &httpError{http.StatusBadRequest, err} + return &internal.HTTPError{Code: http.StatusBadRequest, Err: err} } } ctoken := r.FormValue("continuationToken") tds, ctoken, err := ts.List(r.Context(), &tracing.Query{Limit: limit, ContinuationToken: ctoken}) if errors.Is(err, tracing.ErrBadQuery) { - return &httpError{http.StatusBadRequest, err} + return &internal.HTTPError{Code: http.StatusBadRequest, Err: err} } if err != nil { return err @@ -354,13 +287,12 @@ type listTracesResult struct { } func (s *devServer) handleListFlowStates(w http.ResponseWriter, r *http.Request) error { - // TODO(jba): implement. - return writeJSON(r.Context(), w, listFlowStatesResult{[]flowStater{}, ""}) + return writeJSON(r.Context(), w, listFlowStatesResult{[]common.FlowStater{}, ""}) } type listFlowStatesResult struct { - FlowStates []flowStater `json:"flowStates"` - ContinuationToken string `json:"continuationToken"` + FlowStates []common.FlowStater `json:"flowStates"` + ContinuationToken string `json:"continuationToken"` } // NewFlowServeMux constructs a [net/http.ServeMux]. @@ -376,16 +308,17 @@ type listFlowStatesResult struct { // mainMux := http.NewServeMux() // mainMux.Handle("POST /flow/", http.StripPrefix("/flow/", NewFlowServeMux())) func NewFlowServeMux(flows []string) *http.ServeMux { - return newFlowServeMux(globalRegistry, flows) + return newFlowServeMux(registry.Global, flows) } -func newFlowServeMux(r *registry, flows []string) *http.ServeMux { +func newFlowServeMux(r *registry.Registry, flows []string) *http.ServeMux { mux := http.NewServeMux() m := map[string]bool{} for _, f := range flows { m[f] = true } - for _, f := range r.listFlows() { + for _, f := range r.ListFlows() { + f := f.(flow) if len(flows) == 0 || m[f.Name()] { handle(mux, "POST /"+f.Name(), nonDurableFlowHandler(f)) } @@ -406,7 +339,7 @@ func nonDurableFlowHandler(f flow) func(http.ResponseWriter, *http.Request) erro } if stream { // TODO(jba): implement streaming. - return &httpError{http.StatusNotImplemented, errors.New("streaming")} + return &internal.HTTPError{Code: http.StatusNotImplemented, Err: errors.New("streaming")} } else { // TODO(jba): telemetry out, err := f.runJSON(r.Context(), json.RawMessage(input), nil) @@ -481,9 +414,9 @@ func handle(mux *http.ServeMux, pattern string, f func(w http.ResponseWriter, r if err != nil { // If the error is an httpError, serve the status code it contains. // Otherwise, assume this is an unexpected error and serve a 500. - var herr *httpError + var herr *internal.HTTPError if errors.As(err, &herr) { - http.Error(w, herr.Error(), herr.code) + http.Error(w, herr.Error(), herr.Code) } else { http.Error(w, err.Error(), http.StatusInternalServerError) } @@ -491,34 +424,18 @@ func handle(mux *http.ServeMux, pattern string, f func(w http.ResponseWriter, r }) } -type httpError struct { - code int - err error -} - -func (e *httpError) Error() string { - return fmt.Sprintf("%s: %s", http.StatusText(e.code), e.err) -} - func parseBoolQueryParam(r *http.Request, name string) (bool, error) { b := false if s := r.FormValue(name); s != "" { var err error b, err = strconv.ParseBool(s) if err != nil { - return false, &httpError{http.StatusBadRequest, err} + return false, &internal.HTTPError{Code: http.StatusBadRequest, Err: err} } } return b, nil } -func currentEnvironment() Environment { - if v := os.Getenv("GENKIT_ENV"); v != "" { - return Environment(v) - } - return EnvironmentProd -} - func writeJSON(ctx context.Context, w http.ResponseWriter, value any) error { data, err := json.Marshal(value) if err != nil { diff --git a/go/core/servers_test.go b/go/genkit/servers_test.go similarity index 88% rename from go/core/servers_test.go rename to go/genkit/servers_test.go index 130566399..88d8c0b0b 100644 --- a/go/core/servers_test.go +++ b/go/genkit/servers_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package core +package genkit import ( "context" @@ -23,26 +23,33 @@ import ( "strings" "testing" + "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal/atype" + "github.com/firebase/genkit/go/internal/common" + "github.com/firebase/genkit/go/internal/registry" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/invopop/jsonschema" ) +func inc(_ context.Context, x int) (int, error) { + return x + 1, nil +} + func dec(_ context.Context, x int) (int, error) { return x - 1, nil } func TestDevServer(t *testing.T) { - r, err := newRegistry() + r, err := registry.New() if err != nil { t.Fatal(err) } - r.registerAction(newAction("devServer/inc", atype.Custom, map[string]any{ + r.RegisterAction(atype.Custom, core.NewAction("devServer/inc", atype.Custom, map[string]any{ "foo": "bar", }, inc)) - r.registerAction(newAction("devServer/dec", atype.Custom, map[string]any{ + r.RegisterAction(atype.Custom, core.NewAction("devServer/dec", atype.Custom, map[string]any{ "bar": "baz", }, dec)) srv := httptest.NewServer(newDevServeMux(r)) @@ -80,11 +87,11 @@ func TestDevServer(t *testing.T) { if res.StatusCode != 200 { t.Fatalf("got status %d, wanted 200", res.StatusCode) } - got, err := readJSON[map[string]actionDesc](res.Body) + got, err := readJSON[map[string]common.ActionDesc](res.Body) if err != nil { t.Fatal(err) } - want := map[string]actionDesc{ + want := map[string]common.ActionDesc{ "/custom/devServer/inc": { Key: "/custom/devServer/inc", Name: "devServer/inc", @@ -122,11 +129,11 @@ func TestDevServer(t *testing.T) { } func TestProdServer(t *testing.T) { - r, err := newRegistry() + r, err := registry.New() if err != nil { t.Fatal(err) } - defineFlow(r, "inc", func(_ context.Context, i int, _ NoStream) (int, error) { + defineFlow(r, "inc", func(_ context.Context, i int, _ noStream) (int, error) { return i + 1, nil }) srv := httptest.NewServer(newFlowServeMux(r, nil)) @@ -160,8 +167,8 @@ func TestProdServer(t *testing.T) { t.Run("bad", func(t *testing.T) { check(t, "true", 400, 0) }) } -func checkActionTrace(t *testing.T, reg *registry, tid, name string) { - ts := reg.lookupTraceStore(EnvironmentDev) +func checkActionTrace(t *testing.T, reg *registry.Registry, tid, name string) { + ts := reg.LookupTraceStore(common.EnvironmentDev) td, err := ts.Load(context.Background(), tid) if err != nil { t.Fatal(err) diff --git a/go/core/testdata/conformance/basic.json b/go/genkit/testdata/conformance/basic.json similarity index 100% rename from go/core/testdata/conformance/basic.json rename to go/genkit/testdata/conformance/basic.json diff --git a/go/core/testdata/conformance/run-1.json b/go/genkit/testdata/conformance/run-1.json similarity index 100% rename from go/core/testdata/conformance/run-1.json rename to go/genkit/testdata/conformance/run-1.json diff --git a/go/internal/common/common.go b/go/internal/common/common.go new file mode 100644 index 000000000..6c2626da0 --- /dev/null +++ b/go/internal/common/common.go @@ -0,0 +1,79 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import ( + "context" + "encoding/json" + "os" + + "github.com/firebase/genkit/go/core/tracing" + "github.com/invopop/jsonschema" +) + +// Action is the type that all Action[I, O, S] have in common. +type Action interface { + Name() string + + // RunJSON uses encoding/json to unmarshal the input, + // calls Action.Run, then returns the marshaled result. + RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) + + // Desc returns a description of the action. + // It should set all fields of actionDesc except Key, which + // the registry will set. + Desc() ActionDesc + + // SetTracingState set's the action's tracing.State. + SetTracingState(*tracing.State) +} + +// An ActionDesc is a description of an Action. +// It is used to provide a list of registered actions. +type ActionDesc struct { + Key string `json:"key"` // full key from the registry + Name string `json:"name"` + Description string `json:"description"` + Metadata map[string]any `json:"metadata"` + InputSchema *jsonschema.Schema `json:"inputSchema"` + OutputSchema *jsonschema.Schema `json:"outputSchema"` +} + +// An Environment is the execution context in which the program is running. +type Environment string + +const ( + EnvironmentDev Environment = "dev" // development: testing, debugging, etc. + EnvironmentProd Environment = "prod" // production: user data, SLOs, etc. +) + +// CurentEnvironment returns the currently active environment. +func CurrentEnvironment() Environment { + if v := os.Getenv("GENKIT_ENV"); v != "" { + return Environment(v) + } + return EnvironmentProd +} + +// FlowStater is the common type of all flowState[I, O] types. +type FlowStater interface { + IsFlowState() + Lock() + Unlock() + GetCache() map[string]json.RawMessage +} + +// StreamingCallback is the type of streaming callbacks. +type StreamingCallback[Stream any] func(context.Context, Stream) error diff --git a/go/internal/doc-snippets/flows.go b/go/internal/doc-snippets/flows.go index be4a4e14c..c5a8fc0dc 100644 --- a/go/internal/doc-snippets/flows.go +++ b/go/internal/doc-snippets/flows.go @@ -21,7 +21,6 @@ import ( "net/http" "strings" - "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/genkit" ) @@ -105,7 +104,7 @@ func f3() { menuSuggestionFlow.Stream( context.Background(), "French", - )(func(sfv *core.StreamFlowValue[OutputType, StreamType], err error) bool { + )(func(sfv *genkit.StreamFlowValue[OutputType, StreamType], err error) bool { if !sfv.Done { fmt.Print(sfv.Output) return true diff --git a/go/internal/json.go b/go/internal/json.go index 49f7f1b7c..5619200a8 100644 --- a/go/internal/json.go +++ b/go/internal/json.go @@ -19,6 +19,8 @@ import ( "errors" "fmt" "os" + + "github.com/invopop/jsonschema" ) // JSONString returns json.Marshal(x) as a string. If json.Marshal returns @@ -55,3 +57,13 @@ func ReadJSONFile(filename string, pvalue any) error { defer f.Close() return json.NewDecoder(f).Decode(pvalue) } + +func InferJSONSchema(x any) (s *jsonschema.Schema) { + r := jsonschema.Reflector{ + DoNotReference: true, + } + s = r.Reflect(x) + // TODO: Unwind this change once Monaco Editor supports newer than JSON schema draft-07. + s.Version = "" + return s +} diff --git a/go/core/metrics.go b/go/internal/metrics/metrics.go similarity index 89% rename from go/core/metrics.go rename to go/internal/metrics/metrics.go index 117f53a12..63df0196f 100644 --- a/go/core/metrics.go +++ b/go/internal/metrics/metrics.go @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package core +package metrics import ( "context" + "log/slog" "sync" "time" - "github.com/firebase/genkit/go/core/logger" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -38,7 +38,7 @@ var fetchInstruments = sync.OnceValue(func() *metricInstruments { insts, err := initInstruments() if err != nil { // Do not stop the program because we can't collect metrics. - logger.FromContext(context.Background()).Error("metric initialization failed; no metrics will be collected", "err", err) + slog.Default().Error("metric initialization failed; no metrics will be collected", "err", err) return nil } return insts @@ -67,13 +67,13 @@ func initInstruments() (*metricInstruments, error) { return insts, nil } -func writeActionSuccess(ctx context.Context, actionName string, latency time.Duration) { +func WriteActionSuccess(ctx context.Context, actionName string, latency time.Duration) { recordAction(ctx, latency, attribute.String("name", actionName), attribute.String("source", "go")) } -func writeActionFailure(ctx context.Context, actionName string, latency time.Duration, err error) { +func WriteActionFailure(ctx context.Context, actionName string, latency time.Duration, err error) { recordAction(ctx, latency, attribute.String("name", actionName), attribute.Int("errorCode", errorCode(err)), // TODO(jba): Mitigate against high-cardinality dimensions that arise from @@ -97,13 +97,13 @@ func recordAction(ctx context.Context, latency time.Duration, attrs ...attribute } } -func writeFlowSuccess(ctx context.Context, flowName string, latency time.Duration) { +func WriteFlowSuccess(ctx context.Context, flowName string, latency time.Duration) { recordFlow(ctx, latency, attribute.String("name", flowName), attribute.String("source", "go")) } -func writeFlowFailure(ctx context.Context, flowName string, latency time.Duration, err error) { +func WriteFlowFailure(ctx context.Context, flowName string, latency time.Duration, err error) { recordAction(ctx, latency, attribute.String("name", flowName), attribute.Int("errorCode", errorCode(err)), // TODO(jba): Mitigate against high-cardinality dimensions that arise from diff --git a/go/internal/misc.go b/go/internal/misc.go index d13cd5985..db64756e9 100644 --- a/go/internal/misc.go +++ b/go/internal/misc.go @@ -14,7 +14,19 @@ package internal -import "net/url" +import ( + "fmt" + "net/http" + "net/url" +) + +// An Environment is the execution context in which the program is running. +type Environment string + +const ( + EnvironmentDev Environment = "dev" // development: testing, debugging, etc. + EnvironmentProd Environment = "prod" // production: user data, SLOs, etc. +) // Zero returns the Zero value for T. func Zero[T any]() T { @@ -26,3 +38,13 @@ func Zero[T any]() T { func Clean(id string) string { return url.PathEscape(id) } + +// HTTPError is an error that includes an HTTP status code. +type HTTPError struct { + Code int + Err error +} + +func (e *HTTPError) Error() string { + return fmt.Sprintf("%s: %s", http.StatusText(e.Code), e.Err) +} diff --git a/go/core/registry.go b/go/internal/registry/registry.go similarity index 54% rename from go/core/registry.go rename to go/internal/registry/registry.go index 30060d98a..e33812404 100644 --- a/go/core/registry.go +++ b/go/internal/registry/registry.go @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -package core +package registry import ( - "context" "crypto/md5" "fmt" "log" @@ -27,6 +26,7 @@ import ( "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal/atype" + "github.com/firebase/genkit/go/internal/common" sdktrace "go.opentelemetry.io/otel/sdk/trace" "golang.org/x/exp/maps" ) @@ -35,44 +35,46 @@ import ( // The global registry, used in non-test code. // A test may create their own registries to avoid conflicting with other tests. -var globalRegistry *registry +var Global *Registry func init() { // Initialize the global registry, along with a dev tracer, at program startup. var err error - globalRegistry, err = newRegistry() + Global, err = New() if err != nil { log.Fatal(err) } } -type registry struct { +type Registry struct { tstate *tracing.State mu sync.Mutex frozen bool // when true, no more additions - actions map[string]action - flows []flow - // TraceStores, at most one for each [Environment]. + actions map[string]common.Action + flows []Flow + // TraceStores, at most one for each [common.Environment]. // Only the prod trace store is actually registered; the dev one is // always created automatically. But it's simpler if we keep them together here. - traceStores map[Environment]tracing.Store + traceStores map[common.Environment]tracing.Store } -func newRegistry() (*registry, error) { - r := ®istry{ - actions: map[string]action{}, - traceStores: map[Environment]tracing.Store{}, +func New() (*Registry, error) { + r := &Registry{ + actions: map[string]common.Action{}, + traceStores: map[common.Environment]tracing.Store{}, } tstore, err := newDevStore() if err != nil { return nil, err } - r.registerTraceStore(EnvironmentDev, tstore) + r.RegisterTraceStore(common.EnvironmentDev, tstore) r.tstate = tracing.NewState() r.tstate.AddTraceStoreImmediate(tstore) return r, nil } +func (r *Registry) TracingState() *tracing.State { return r.tstate } + func newDevStore() (tracing.Store, error) { programName := filepath.Base(os.Args[0]) rootHash := fmt.Sprintf("%02x", md5.Sum([]byte(programName))) @@ -84,19 +86,11 @@ func newDevStore() (tracing.Store, error) { return tracing.NewFileStore(dir) } -// An Environment is the execution context in which the program is running. -type Environment string - -const ( - EnvironmentDev Environment = "dev" // development: testing, debugging, etc. - EnvironmentProd Environment = "prod" // production: user data, SLOs, etc. -) - -// registerAction records the action in the registry. +// RegisterAction records the action in the registry. // It panics if an action with the same type, provider and name is already // registered. -func (r *registry) registerAction(a action) { - key := fmt.Sprintf("/%s/%s", a.actionType(), a.Name()) +func (r *Registry) RegisterAction(typ atype.ActionType, a common.Action) { + key := fmt.Sprintf("/%s/%s", typ, a.Name()) r.mu.Lock() defer r.mu.Unlock() if r.frozen { @@ -105,80 +99,63 @@ func (r *registry) registerAction(a action) { if _, ok := r.actions[key]; ok { panic(fmt.Sprintf("action %q is already registered", key)) } - a.setTracingState(r.tstate) + a.SetTracingState(r.tstate) r.actions[key] = a slog.Info("RegisterAction", - "type", a.actionType(), + "type", typ, "name", a.Name()) } -func (r *registry) freeze() { +func (r *Registry) Freeze() { r.mu.Lock() defer r.mu.Unlock() r.frozen = true } -// lookupAction returns the action for the given key, or nil if there is none. -func (r *registry) lookupAction(key string) action { +// LookupAction returns the action for the given key, or nil if there is none. +func (r *Registry) LookupAction(key string) common.Action { r.mu.Lock() defer r.mu.Unlock() return r.actions[key] } -// LookupActionFor returns the action for the given key in the global registry, -// or nil if there is none. -// It panics if the action is of the wrong type. -func LookupActionFor[In, Out, Stream any](typ atype.ActionType, provider, name string) *Action[In, Out, Stream] { - key := fmt.Sprintf("/%s/%s/%s", typ, provider, name) - a := globalRegistry.lookupAction(key) - if a == nil { - return nil - } - return a.(*Action[In, Out, Stream]) -} - -// listActions returns a list of descriptions of all registered actions. +// ListActions returns a list of descriptions of all registered actions. // The list is sorted by action name. -func (r *registry) listActions() []actionDesc { - var ads []actionDesc +func (r *Registry) ListActions() []common.ActionDesc { + var ads []common.ActionDesc r.mu.Lock() defer r.mu.Unlock() keys := maps.Keys(r.actions) slices.Sort(keys) for _, key := range keys { a := r.actions[key] - ad := a.desc() + ad := a.Desc() ad.Key = key ads = append(ads, ad) } return ads } -// registerFlow stores the flow for use by the production server (see [NewFlowServeMux]). +// Flow is the type for the flows stored in a registry. +// Since a registry just remembers flows and returns them, +// this interface is empty. +type Flow interface{} + +// RegisterFlow stores the flow for use by the production server (see [NewFlowServeMux]). // It doesn't check for duplicates because registerAction will do that. -func (r *registry) registerFlow(f flow) { +func (r *Registry) RegisterFlow(f Flow) { r.mu.Lock() defer r.mu.Unlock() r.flows = append(r.flows, f) } -func (r *registry) listFlows() []flow { +func (r *Registry) ListFlows() []Flow { r.mu.Lock() defer r.mu.Unlock() return r.flows } -// RegisterTraceStore uses the given trace.Store to record traces in the prod environment. -// (A trace.Store that writes to the local filesystem is always installed in the dev environment.) -// The returned function should be called before the program ends to ensure that -// all pending data is stored. -// RegisterTraceStore panics if called more than once. -func RegisterTraceStore(ts tracing.Store) (shutdown func(context.Context) error) { - globalRegistry.registerTraceStore(EnvironmentProd, ts) - return globalRegistry.tstate.AddTraceStoreBatch(ts) -} - -func (r *registry) registerTraceStore(env Environment, ts tracing.Store) { +func (r *Registry) RegisterTraceStore(env common.Environment, ts tracing.Store) { r.mu.Lock() defer r.mu.Unlock() if _, ok := r.traceStores[env]; ok { @@ -187,19 +164,12 @@ func (r *registry) registerTraceStore(env Environment, ts tracing.Store) { r.traceStores[env] = ts } -func (r *registry) lookupTraceStore(env Environment) tracing.Store { +func (r *Registry) LookupTraceStore(env common.Environment) tracing.Store { r.mu.Lock() defer r.mu.Unlock() return r.traceStores[env] } -// RegisterSpanProcessor registers an OpenTelemetry SpanProcessor for tracing. -func RegisterSpanProcessor(sp sdktrace.SpanProcessor) { - globalRegistry.registerSpanProcessor(sp) -} - -func (r *registry) registerSpanProcessor(sp sdktrace.SpanProcessor) { - r.mu.Lock() - defer r.mu.Unlock() +func (r *Registry) RegisterSpanProcessor(sp sdktrace.SpanProcessor) { r.tstate.RegisterSpanProcessor(sp) } diff --git a/go/core/validation.go b/go/internal/validation.go similarity index 87% rename from go/core/validation.go rename to go/internal/validation.go index 966d3f56f..5c69d2129 100644 --- a/go/core/validation.go +++ b/go/internal/validation.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package core +package internal import ( "encoding/json" @@ -23,19 +23,19 @@ import ( "github.com/xeipuuv/gojsonschema" ) -// validateValue will validate any value against the expected schema. +// ValidateValue will validate any value against the expected schema. // It will return an error if it doesn't match the schema, otherwise it will return nil. -func validateValue(data any, schema *jsonschema.Schema) error { +func ValidateValue(data any, schema *jsonschema.Schema) error { dataBytes, err := json.Marshal(data) if err != nil { return fmt.Errorf("data is not a valid JSON type: %w", err) } - return validateJSON(dataBytes, schema) + return ValidateJSON(dataBytes, schema) } -// validateJSON will validate JSON against the expected schema. +// ValidateJSON will validate JSON against the expected schema. // It will return an error if it doesn't match the schema, otherwise it will return nil. -func validateJSON(dataBytes json.RawMessage, schema *jsonschema.Schema) error { +func ValidateJSON(dataBytes json.RawMessage, schema *jsonschema.Schema) error { schemaBytes, err := schema.MarshalJSON() if err != nil { return fmt.Errorf("expected schema is not valid: %w", err)