Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optionally render entity requires populator function for advanced @requires use cases #2884

Merged
merged 14 commits into from
Feb 23, 2024
Merged
9 changes: 5 additions & 4 deletions codegen/config/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ import (
)

type PackageConfig struct {
Filename string `yaml:"filename,omitempty"`
Package string `yaml:"package,omitempty"`
Version int `yaml:"version,omitempty"`
ModelTemplate string `yaml:"model_template,omitempty"`
Filename string `yaml:"filename,omitempty"`
Package string `yaml:"package,omitempty"`
Version int `yaml:"version,omitempty"`
ModelTemplate string `yaml:"model_template,omitempty"`
Options map[string]bool `yaml:"options,omitempty"`
}

func (c *PackageConfig) ImportPath() string {
Expand Down
8 changes: 8 additions & 0 deletions plugin/federation/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package federation

import (
"go/types"
"strings"

"github.com/vektah/gqlparser/v2/ast"

Expand All @@ -18,6 +19,7 @@ type Entity struct {
Resolvers []*EntityResolver
Requires []*Requires
Multi bool
Type types.Type
}

type EntityResolver struct {
Expand Down Expand Up @@ -116,3 +118,9 @@ func (e *Entity) keyFields() []string {
}
return keyFields
}

// GetTypeInfo - get the imported package & type name combo. package.TypeName
func (e Entity) GetTypeInfo() string {
typeParts := strings.Split(e.Type.String(), "/")
return typeParts[len(typeParts)-1]
}
105 changes: 103 additions & 2 deletions plugin/federation/federation.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package federation
import (
_ "embed"
"fmt"
"os"
"path/filepath"
"runtime"
"sort"
"strings"

Expand All @@ -11,6 +14,7 @@ import (
"github.com/99designs/gqlgen/codegen"
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/99designs/gqlgen/internal/rewrite"
"github.com/99designs/gqlgen/plugin"
"github.com/99designs/gqlgen/plugin/federation/fieldset"
)
Expand All @@ -19,8 +23,9 @@ import (
var federationTemplate string

type federation struct {
Entities []*Entity
Version int
Entities []*Entity
Version int
PackageOptions map[string]bool
}

// New returns a federation plugin that injects
Expand Down Expand Up @@ -252,6 +257,16 @@ type Entity {
}

func (f *federation) GenerateCode(data *codegen.Data) error {
// requires imports
requiresImports := make(map[string]bool, 0)
requiresImports["context"] = true
requiresImports["fmt"] = true

requiresEntities := make(map[string]*Entity, 0)

// Save package options on f for template use
f.PackageOptions = data.Config.Federation.Options

if len(f.Entities) > 0 {
if data.Objects.ByName("Entity") != nil {
data.Objects.ByName("Entity").Root = true
Expand Down Expand Up @@ -291,9 +306,19 @@ func (f *federation) GenerateCode(data *codegen.Data) error {
fmt.Println("skipping @requires field " + reqField.Name + " in " + e.Def.Name)
continue
}
// keep track of which entities have requires
requiresEntities[e.Def.Name] = e
// make a proper import path
typeString := strings.Split(obj.Type.String(), ".")
requiresImports[strings.Join(typeString[:len(typeString)-1], ".")] = true

cgField := reqField.Field.TypeReference(obj, data.Objects)
reqField.Type = cgField.TypeReference
}

// add type info to entity
e.Type = obj.Type

}
}

Expand All @@ -314,6 +339,82 @@ func (f *federation) GenerateCode(data *codegen.Data) error {
}
}

if len(requiresEntities) > 0 {
// check for existing requires functions
type Populator struct {
FuncName string
Exists bool
Comment string
Implementation string
Entity *Entity
}
populators := make([]Populator, 0)

rewriter, err := rewrite.New(data.Config.Resolver.Dir())
if err != nil {
return err
}

for name, entity := range requiresEntities {
populator := Populator{
FuncName: fmt.Sprintf("Populate%sRequires", name),
Entity: entity,
}

populator.Comment = strings.TrimSpace(strings.TrimLeft(rewriter.GetMethodComment("executionContext", populator.FuncName), `\`))
populator.Implementation = strings.TrimSpace(rewriter.GetMethodBody("executionContext", populator.FuncName))

if populator.Implementation == "" {
populator.Exists = false
populator.Implementation = fmt.Sprintf("panic(fmt.Errorf(\"not implemented: %v\"))", populator.FuncName)
}
populators = append(populators, populator)
}

if data.Config.Federation.Options["explicit_requires"] {

// find and read requires template
_, callerFile, _, _ := runtime.Caller(0)
currentDir := filepath.Dir(callerFile)
requiresTemplate, err := os.ReadFile(currentDir + "/requires.gotpl")

if err != nil {
return err
}

requiresFile := data.Config.Federation.Dir() + "/federation.requires.go"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ldebruijn I've reverted this from your proposed change as it breaks the implementation and the test code generation.

existingImports := rewriter.ExistingImports(requiresFile)
for _, imp := range existingImports {
if imp.Alias == "" {
// import exists in both places, remove
delete(requiresImports, imp.ImportPath)
}
}

for k := range requiresImports {
existingImports = append(existingImports, rewrite.Import{ImportPath: k})
}

// render requires populators
err = templates.Render(templates.Options{
PackageName: data.Config.Federation.Package,
Filename: requiresFile,
Data: struct {
federation
ExistingImports []rewrite.Import
Populators []Populator
OriginalSource string
}{*f, existingImports, populators, ""},
GeneratedHeader: false,
Packages: data.Config.Packages,
Template: string(requiresTemplate),
})
if err != nil {
return err
}
}
}

return templates.Render(templates.Options{
PackageName: data.Config.Federation.Package,
Filename: data.Config.Federation.Filename,
Expand Down
14 changes: 11 additions & 3 deletions plugin/federation/federation.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
{{ reserveImport "sync" }}

{{ reserveImport "github.com/99designs/gqlgen/plugin/federation/fedruntime" }}
{{ $options := .PackageOptions }}

var (
ErrUnknownType = errors.New("unknown type")
Expand Down Expand Up @@ -103,11 +104,18 @@ func (ec *executionContext) __resolve_entities(ctx context.Context, representati
if err != nil {
return fmt.Errorf(`resolving Entity "{{$entity.Def.Name}}": %w`, err)
}
{{ range $entity.Requires }}
entity.{{.Field.JoinGo `.`}}, err = ec.{{.Type.UnmarshalFunc}}(ctx, rep["{{.Field.Join `"].(map[string]interface{})["`}}"])
{{ if and (index $options "explicit_requires") $entity.Requires }}
err = ec.Populate{{$entity.Def.Name}}Requires(ctx, entity, rep)
if err != nil {
return err
return fmt.Errorf(`populating requires for Entity "{{$entity.Def.Name}}": %w`, err)
}
{{- else }}
{{ range $entity.Requires }}
entity.{{.Field.JoinGo `.`}}, err = ec.{{.Type.UnmarshalFunc}}(ctx, rep["{{.Field.Join `"].(map[string]interface{})["`}}"])
if err != nil {
return err
}
{{- end }}
{{- end }}
list[idx[i]] = entity
return nil
Expand Down
Loading