Skip to content

Commit

Permalink
Retain custom options as known fields, even with custom `descriptor.p…
Browse files Browse the repository at this point in the history
…roto` (#109)

An issue was introduced in #97 -- if a custom `descriptor.proto` is
used, then custom options may come across in the final compilation
result as unrecognized fields. The first commit in this PR is a repro case.

This happens when the options in question come from _public_
imports and when an _override_ set of options are used. We have
to use a dynamic message to represent the override definition of
the options. But the result ultimately needs to be the generated
type (not a dynamic message). So we end serializing the dynamic
message to bytes and then de-serializing that into the generated
type. During de-serialization, we were supplying an extension
resolver that was failing to find the custom option, because it
was visible via a public import.

To fix the issue, I found several code paths that could be unified
and consolidated, which ultimately resulted in removing a few
no-longer-used methods from interfaces defined in the linker
sub-package. This also changes the way we handle an override
descriptor.proto by making it an interpreter option instead of
providing it during linking (since it is really only used by the
option interpreter).

These API/interface changes are not technically backwards
compatible. But this repo is still pre-v1.0, and these changes
are in interfaces that we don't actually expect any users to
implement.
  • Loading branch information
jhump authored Mar 7, 2023
1 parent d4ce4f4 commit 49c8099
Show file tree
Hide file tree
Showing 9 changed files with 325 additions and 256 deletions.
39 changes: 19 additions & 20 deletions compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker.
}
}

var descriptorProtoRes *result
if len(imports) > 0 {
t.r.setBlockedOn(imports)

Expand All @@ -455,12 +456,7 @@ func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker.
}
results[i] = res
}
capacity := len(results)
if wantsDescriptorProto {
capacity++
}
deps = make([]linker.File, len(results), capacity)
var descriptorProtoRes *result
deps = make([]linker.File, len(results))
if wantsDescriptorProto {
descriptorProtoRes = t.e.compile(ctx, descriptorProtoPath)
}
Expand Down Expand Up @@ -488,17 +484,6 @@ func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker.
return nil, ctx.Err()
}
}
if descriptorProtoRes != nil {
select {
case <-descriptorProtoRes.ready:
// descriptor.proto wasn't explicitly imported, so we can ignore a failure
if descriptorProtoRes.err == nil {
deps = append(deps, descriptorProtoRes.res)
}
case <-ctx.Done():
return nil, ctx.Err()
}
}

// all deps resolved
t.r.setBlockedOn(nil)
Expand All @@ -509,7 +494,7 @@ func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker.
t.released = false
}

return t.link(parseRes, deps)
return t.link(ctx, parseRes, deps, descriptorProtoRes)
}

func (e *executor) checkForDependencyCycle(res *result, sequence []string, pos ast.SourcePos, checked map[string]struct{}) error {
Expand Down Expand Up @@ -568,12 +553,26 @@ func findImportPos(res parser.Result, dep string) ast.SourcePos {
return ast.UnknownPos(res.FileNode().Name())
}

func (t *task) link(parseRes parser.Result, deps linker.Files) (linker.File, error) {
func (t *task) link(ctx context.Context, parseRes parser.Result, deps linker.Files, descriptorProtoRes *result) (linker.File, error) {
file, err := linker.Link(parseRes, deps, t.e.sym, t.h)
if err != nil {
return nil, err
}
optsIndex, err := options.InterpretOptions(file, t.h)

var interpretOpts []options.InterpreterOption
if descriptorProtoRes != nil {
select {
case <-descriptorProtoRes.ready:
// descriptor.proto wasn't explicitly imported, so we can ignore a failure
if descriptorProtoRes.err == nil {
interpretOpts = []options.InterpreterOption{options.WithOverrideDescriptorProto(descriptorProtoRes.res)}
}
case <-ctx.Done():
return nil, ctx.Err()
}
}

optsIndex, err := options.InterpretOptions(file, t.h, interpretOpts...)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/benchmarks/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ go 1.19
require (
github.com/bufbuild/protocompile v0.0.0-20221004230924-06a336f5b6be
github.com/igrmk/treemap/v2 v2.0.1
github.com/jhump/protoreflect v1.13.0
github.com/jhump/protoreflect v1.14.1
github.com/stretchr/testify v1.8.0
google.golang.org/protobuf v1.28.2-0.20220831092852-f930b1dc76e8
)
Expand Down
4 changes: 2 additions & 2 deletions internal/benchmarks/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSl
github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI=
github.com/jhump/goprotoc v0.5.0/go.mod h1:VrbvcYrQOrTi3i0Vf+m+oqQWk9l72mjkJCYo7UvLHRQ=
github.com/jhump/protoreflect v1.11.0/go.mod h1:U7aMIjN0NWq9swDP7xDdoMfRHb35uiuTd3Z9nFXJf5E=
github.com/jhump/protoreflect v1.13.0 h1:zrrZqa7JAc2YGgPSzZZkmUXJ5G6NRPdxOg/9t7ISImA=
github.com/jhump/protoreflect v1.13.0/go.mod h1:JytZfP5d0r8pVNLZvai7U/MCuTWITgrI4tTg7puQFKI=
github.com/jhump/protoreflect v1.14.1 h1:N88q7JkxTHWFEqReuTsYH1dPIwXxA0ITNQp7avLY10s=
github.com/jhump/protoreflect v1.14.1/go.mod h1:JytZfP5d0r8pVNLZvai7U/MCuTWITgrI4tTg7puQFKI=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
Expand Down
6 changes: 1 addition & 5 deletions linker/descriptors.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func asSourceLocations(srcInfoProtos []*descriptorpb.SourceCodeInfo_Location) []
func pathStr(p protoreflect.SourcePath) string {
var buf bytes.Buffer
for _, v := range p {
fmt.Fprintf(&buf, "%x:", v)
_, _ = fmt.Fprintf(&buf, "%x:", v)
}
return buf.String()
}
Expand Down Expand Up @@ -1869,10 +1869,6 @@ func (r *result) FindDescriptorByName(name protoreflect.FullName) protoreflect.D
return r.descriptors[fqn]
}

func (r *result) importsAsFiles() Files {
return r.deps
}

func (r *result) hasSource() bool {
n := r.FileNode()
_, ok := n.(*ast.FileNode)
Expand Down
108 changes: 47 additions & 61 deletions linker/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (

// File is like a super-powered protoreflect.FileDescriptor. It includes helpful
// methods for looking up elements in the descriptor and can be used to create a
// resolver for all of the file's transitive closure of dependencies. (See
// resolver for all the file's transitive closure of dependencies. (See
// ResolverFromFile.)
type File interface {
protoreflect.FileDescriptor
Expand All @@ -42,10 +42,6 @@ type File interface {
// that extends the given message name. If no such extension is defined in this
// file, nil is returned.
FindExtensionByNumber(message protoreflect.FullName, tag protoreflect.FieldNumber) protoreflect.ExtensionTypeDescriptor
// Imports returns this file's imports. These are only the files directly
// imported by the file. Indirect transitive dependencies will not be in
// the returned slice.
importsAsFiles() Files
}

// NewFile converts a protoreflect.FileDescriptor to a File. The given deps must
Expand Down Expand Up @@ -147,10 +143,6 @@ func (f *file) FindExtensionByNumber(msg protoreflect.FullName, tag protoreflect
return findExtension(f, msg, tag)
}

func (f *file) importsAsFiles() Files {
return f.deps
}

var _ File = (*file)(nil)

// Files represents a set of protobuf files. It is a slice of File values, but
Expand Down Expand Up @@ -187,58 +179,53 @@ type Resolver interface {
protoregistry.ExtensionTypeResolver
}

// ResolverFromFile returns a Resolver that uses the given file plus all of its
// imports as the source of descriptors. If a given query cannot be answered with
// these files, the query will fail with a protoregistry.NotFound error. This
// does not recursively search the entire transitive closure; it only searches
// the given file and its immediate dependencies. This is useful for resolving
// elements visible to the file.
//
// If the given file is the result of a call to Link, then all dependencies
// provided in the call to Link are searched (which could actually include more
// than just the file's direct imports).
// ResolverFromFile returns a Resolver that can resolve any element that is
// visible to the given file. It will search the given file, its imports, and
// any transitive public imports.
//
// Note that this function does not compute any additional indexes for efficient
// search, so queries generally take linear time, O(n) where n is the number of
// files in the transitive closure of the given file. Queries for an extension
// files whose elements are visible to the given file. Queries for an extension
// by number are linear with the number of messages and extensions defined across
// all the files.
// those files.
func ResolverFromFile(f File) Resolver {
return fileResolver{
f: f,
deps: f.importsAsFiles().AsResolver(),
}
return fileResolver{f: f}
}

type fileResolver struct {
f File
deps Resolver
f File
}

func (r fileResolver) FindFileByPath(path string) (protoreflect.FileDescriptor, error) {
if r.f.Path() == path {
return r.f, nil
}
return r.deps.FindFileByPath(path)
return resolveInFile(r.f, false, nil, func(f File) (protoreflect.FileDescriptor, error) {
if f.Path() == path {
return f, nil
}
return nil, protoregistry.NotFound
})
}

func (r fileResolver) FindDescriptorByName(name protoreflect.FullName) (protoreflect.Descriptor, error) {
d := r.f.FindDescriptorByName(name)
if d != nil {
return d, nil
}
return r.deps.FindDescriptorByName(name)
return resolveInFile(r.f, false, nil, func(f File) (protoreflect.Descriptor, error) {
if d := f.FindDescriptorByName(name); d != nil {
return d, nil
}
return nil, protoregistry.NotFound
})
}

func (r fileResolver) FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) {
d := r.f.FindDescriptorByName(message)
if d != nil {
if md, ok := d.(protoreflect.MessageDescriptor); ok {
return resolveInFile(r.f, false, nil, func(f File) (protoreflect.MessageType, error) {
d := f.FindDescriptorByName(message)
if d != nil {
md, ok := d.(protoreflect.MessageDescriptor)
if !ok {
return nil, fmt.Errorf("%q is %s, not a message", message, descriptorTypeWithArticle(d))
}
return dynamicpb.NewMessageType(md), nil
}
return nil, protoregistry.NotFound
}
return r.deps.FindMessageByName(message)
})
}

func (r fileResolver) FindMessageByURL(url string) (protoreflect.MessageType, error) {
Expand All @@ -248,35 +235,34 @@ func (r fileResolver) FindMessageByURL(url string) (protoreflect.MessageType, er

func messageNameFromURL(url string) string {
lastSlash := strings.LastIndexByte(url, '/')
var fullName string
if lastSlash >= 0 {
fullName = url[lastSlash+1:]
} else {
fullName = url
}
return fullName
return url[lastSlash+1:]
}

func (r fileResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
d := r.f.FindDescriptorByName(field)
if d != nil {
if extd, ok := d.(protoreflect.ExtensionTypeDescriptor); ok {
return extd.Type(), nil
}
if fld, ok := d.(protoreflect.FieldDescriptor); ok && fld.IsExtension() {
return resolveInFile(r.f, false, nil, func(f File) (protoreflect.ExtensionType, error) {
d := f.FindDescriptorByName(field)
if d != nil {
fld, ok := d.(protoreflect.FieldDescriptor)
if !ok || !fld.IsExtension() {
return nil, fmt.Errorf("%q is %s, not an extension", field, descriptorTypeWithArticle(d))
}
if extd, ok := fld.(protoreflect.ExtensionTypeDescriptor); ok {
return extd.Type(), nil
}
return dynamicpb.NewExtensionType(fld), nil
}
return nil, protoregistry.NotFound
}
return r.deps.FindExtensionByName(field)
})
}

func (r fileResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
ext := findExtension(r.f, message, field)
if ext != nil {
return ext.Type(), nil
}
return r.deps.FindExtensionByNumber(message, field)
return resolveInFile(r.f, false, nil, func(f File) (protoreflect.ExtensionType, error) {
ext := findExtension(f, message, field)
if ext != nil {
return ext.Type(), nil
}
return nil, protoregistry.NotFound
})
}

type filesResolver []File
Expand Down
19 changes: 1 addition & 18 deletions linker/linker.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,24 +106,7 @@ func Link(parsed parser.Result, dependencies Files, symbols *Symbols, handler *r
type Result interface {
File
parser.Result
// ResolveEnumType returns an enum descriptor for the given named enum that
// is available in this file. If no such element is available or if the
// named element is not an enum, nil is returned.
ResolveEnumType(protoreflect.FullName) protoreflect.EnumDescriptor
// ResolveMessageType returns a message descriptor for the given named
// message that is available in this file. If no such element is available
// or if the named element is not a message, nil is returned.
ResolveMessageType(protoreflect.FullName) protoreflect.MessageDescriptor
// ResolveOptionsType returns a message descriptor for the given options
// type. This is like ResolveMessageType but searches the result's entire
// set of transitive dependencies without regard for visibility. If no
// such element is available or if the named element is not a message, nil
// is returned.
ResolveOptionsType(protoreflect.FullName) protoreflect.MessageDescriptor
// ResolveExtension returns an extension descriptor for the given named
// extension that is available in this file. If no such element is available
// or if the named element is not an extension, nil is returned.
ResolveExtension(protoreflect.FullName) protoreflect.ExtensionTypeDescriptor

// ResolveMessageLiteralExtensionName returns the fully qualified name for
// an identifier for extension field names in message literals.
ResolveMessageLiteralExtensionName(ast.IdentValueNode) string
Expand Down
Loading

0 comments on commit 49c8099

Please sign in to comment.