Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow any file to make use of custom descriptor.proto #97

Merged
merged 8 commits into from
Feb 23, 2023
74 changes: 68 additions & 6 deletions compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ type executor struct {
cancel context.CancelFunc
sym *linker.Symbols

descriptorProtoCheck sync.Once
descriptorProtoIsCustom bool

mu sync.Mutex
results map[string]*result
}
Expand Down Expand Up @@ -316,6 +319,18 @@ func (e errFailedToResolve) Unwrap() error {
return e.err
}

func (e *executor) hasOverrideDescriptorProto() bool {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do this as a way to cheaply detect if descriptors are overridden. That way, if the file provided is not overridden, we can avoid the cost of wrapping the descriptor in a linker.File (which is not terribly expensive, but also not free since it must create an index of all symbols in the file).

e.descriptorProtoCheck.Do(func() {
defer func() {
// ignore a panic here; just assume no custom descriptor.proto
_ = recover()
}()
res, err := e.c.Resolver.FindFileByPath(descriptorProtoPath)
e.descriptorProtoIsCustom = err == nil && res.Desc != standardImports[descriptorProtoPath]
})
return e.descriptorProtoIsCustom
}

func (e *executor) doCompile(ctx context.Context, file string, r *result) {
t := task{e: e, h: e.h.SubHandler(), r: r}
if err := e.s.Acquire(ctx, 1); err != nil {
Expand All @@ -326,7 +341,7 @@ func (e *executor) doCompile(ctx context.Context, file string, r *result) {

sr, err := e.c.Resolver.FindFileByPath(file)
if err != nil {
r.fail(errFailedToResolve{err, file})
r.fail(errFailedToResolve{err: err, path: file})
return
}

Expand Down Expand Up @@ -371,6 +386,8 @@ func (t *task) release() {
}
}

const descriptorProtoPath = "google/protobuf/descriptor.proto"

func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker.File, error) {
if r.Desc != nil {
if r.Desc.Path() != name {
Expand All @@ -385,12 +402,38 @@ func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker.
}

var deps []linker.File
if len(parseRes.FileDescriptorProto().Dependency) > 0 {
t.r.setBlockedOn(parseRes.FileDescriptorProto().Dependency)
fileDescriptorProto := parseRes.FileDescriptorProto()
var wantsDescriptorProto bool
imports := fileDescriptorProto.Dependency

if t.e.hasOverrideDescriptorProto() {
// we only consider implicitly including descriptor.proto if it's overridden
if name != descriptorProtoPath {
var includesDescriptorProto bool
for _, dep := range fileDescriptorProto.Dependency {
if dep == descriptorProtoPath {
includesDescriptorProto = true
break
}
}
if !includesDescriptorProto {
wantsDescriptorProto = true
// make a defensive copy so we don't inadvertently mutate
// slice's backing array when adding this implicit dep
importsCopy := make([]string, len(imports)+1)
copy(importsCopy, imports)
importsCopy[len(imports)] = descriptorProtoPath
imports = importsCopy
}
}
}

if len(imports) > 0 {
t.r.setBlockedOn(imports)

results := make([]*result, len(parseRes.FileDescriptorProto().Dependency))
results := make([]*result, len(fileDescriptorProto.Dependency))
checked := map[string]struct{}{}
for i, dep := range parseRes.FileDescriptorProto().Dependency {
for i, dep := range fileDescriptorProto.Dependency {
pos := findImportPos(parseRes, dep)
if name == dep {
// doh! file imports itself
Expand All @@ -405,7 +448,15 @@ func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker.
}
results[i] = res
}
deps = make([]linker.File, len(results))
capacity := len(results)
if wantsDescriptorProto {
capacity++
}
deps = make([]linker.File, len(results), capacity)
var descriptorProtoRes *result
if wantsDescriptorProto {
descriptorProtoRes = t.e.compile(ctx, descriptorProtoPath)
}

// release our semaphore so dependencies can be processed w/out risk of deadlock
t.e.s.Release(1)
Expand All @@ -430,6 +481,17 @@ func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker.
return nil, ctx.Err()
}
}
if descriptorProtoRes != nil {
select {
case <-descriptorProtoRes.ready:
// descriptor.proto wasn't explicitly imported, so we can ignore a failure
if descriptorProtoRes.err == nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I walked through this in the debugger and saw one non-nil error reported in the linker tests - search result for "google/protobuf/descriptor.proto" returned descriptor for "foo.proto". Is that expected?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly. What test? If you look at the resolver it's using, it may be hard-coded to return a particular descriptor, regardless of the requested file name.

In any event, since this is not an explicit dependency, I don't think we should fail here if the resolver can't supply something. If the explicitly import this file, then it will fail as expected.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll see if I can gather more details. I agree with the behavior just wanted to make sure there wasn't something we were missing.

deps = append(deps, descriptorProtoRes.res)
}
case <-ctx.Done():
return nil, ctx.Err()
}
}

// all deps resolved
t.r.setBlockedOn(nil)
Expand Down
19 changes: 16 additions & 3 deletions compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,17 @@ func TestParseFilesWithDependencies(t *testing.T) {
// Create a dependency-aware parser that should never be called.
compiler := Compiler{
Resolver: ResolverFunc(func(f string) (SearchResult, error) {
if f == "test.proto" {
switch f {
case "test.proto":
return SearchResult{Source: strings.NewReader(`syntax = "proto3";`)}, nil
case descriptorProtoPath:
// used to see if resolver provides custom descriptor.proto
return SearchResult{}, os.ErrNotExist
default:
// no other name should be passed to resolver
t.Errorf("resolver was called for unexpected filename %q", f)
return SearchResult{}, os.ErrNotExist
}
t.Errorf("resolved was called for unexpected filename %q", f)
return SearchResult{}, os.ErrNotExist
}),
}
_, err := compiler.Compile(ctx, "test.proto")
Expand Down Expand Up @@ -261,3 +267,10 @@ func TestPanicHandling(t *testing.T) {
require.True(t, ok)
t.Logf("%v\n\n%v", panicErr, panicErr.Stack)
}

func TestDescriptorProtoPath(t *testing.T) {
t.Parallel()
// sanity check our constant
path := (*descriptorpb.FileDescriptorProto)(nil).ProtoReflect().Descriptor().ParentFile().Path()
require.Equal(t, descriptorProtoPath, path)
}
18 changes: 13 additions & 5 deletions linker/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (

// File is like a super-powered protoreflect.FileDescriptor. It includes helpful
// methods for looking up elements in the descriptor and can be used to create a
// resolver for all in the file's transitive closure of dependencies. (See
// resolver for all of the file's transitive closure of dependencies. (See
// ResolverFromFile.)
type File interface {
protoreflect.FileDescriptor
Expand Down Expand Up @@ -88,6 +88,8 @@ func newFile(f protoreflect.FileDescriptor, deps Files) (File, error) {
// NewFileRecursive recursively converts a protoreflect.FileDescriptor to a File.
// If f has any dependencies/imports, they are converted, too, including any and
// all transitive dependencies.
//
// If f already implements File, it is returned unchanged.
func NewFileRecursive(f protoreflect.FileDescriptor) (File, error) {
if asFile, ok := f.(File); ok {
return asFile, nil
Expand Down Expand Up @@ -185,10 +187,16 @@ type Resolver interface {
protoregistry.ExtensionTypeResolver
}

// ResolverFromFile returns a Resolver that uses the given file plus its full
// set of transitive dependencies as the source of descriptors. If a given query
// cannot be answered with these files, the query will fail with a
// protoregistry.NotFound error.
// ResolverFromFile returns a Resolver that uses the given file plus all of its
// imports as the source of descriptors. If a given query cannot be answered with
// these files, the query will fail with a protoregistry.NotFound error. This
// does not recursively search the entire transitive closure; it only searches
// the given file and its immediate dependencies. This is useful for resolving
// elements visible to the file.
//
// If the given file is the result of a call to Link, then all dependencies
// provided in the call to Link are searched (which could actually include more
// than just the file's direct imports).
//
// Note that this function does not compute any additional indexes for efficient
// search, so queries generally take linear time, O(n) where n is the number of
Expand Down
6 changes: 6 additions & 0 deletions linker/linker.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ type Result interface {
// message that is available in this file. If no such element is available
// or if the named element is not a message, nil is returned.
ResolveMessageType(protoreflect.FullName) protoreflect.MessageDescriptor
// ResolveOptionsType returns a message descriptor for the given options
// type. This is like ResolveMessageType but searches the result's entire
// set of transitive dependencies without regard for visibility. If no
// such element is available or if the named element is not a message, nil
// is returned.
ResolveOptionsType(protoreflect.FullName) protoreflect.MessageDescriptor
// ResolveExtension returns an extension descriptor for the given named
// extension that is available in this file. If no such element is available
// or if the named element is not an extension, nil is returned.
Expand Down
16 changes: 16 additions & 0 deletions linker/linker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1583,6 +1583,22 @@ func TestLinkerValidation(t *testing.T) {
},
expectedErr: `foo.proto:14:23: option (bar): value -2147483649 is out of range for an enum`,
},
"success_custom_field_option": {
input: map[string]string{
"google/protobuf/descriptor.proto": `
syntax = "proto2";
package google.protobuf;
message FieldOptions {
optional string some_new_option = 11;
}`,
"bar.proto": `
syntax = "proto3";
package foo.bar.baz;
message Foo {
string bar = 1 [some_new_option="abc"];
}`,
},
},
}

for name, tc := range testCases {
Expand Down
21 changes: 19 additions & 2 deletions linker/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ func (r *result) ResolveMessageType(name protoreflect.FullName) protoreflect.Mes
return nil
}

func (r *result) ResolveOptionsType(name protoreflect.FullName) protoreflect.MessageDescriptor {
d, _ := ResolverFromFile(r).FindDescriptorByName(name)
md, _ := d.(protoreflect.MessageDescriptor)
if md != nil && md.ParentFile() != nil {
r.markUsed(md.ParentFile().Path())
jhump marked this conversation as resolved.
Show resolved Hide resolved
}
return md
}

func (r *result) ResolveEnumType(name protoreflect.FullName) protoreflect.EnumDescriptor {
d := r.resolveElement(name)
if ed, ok := d.(protoreflect.EnumDescriptor); ok {
Expand Down Expand Up @@ -343,7 +352,7 @@ func resolveFieldTypes(f *fldDescriptor, handler *reporter.Handler, s *Symbols,
}
} else {
// make sure tag is not a duplicate
if err := s.AddExtension(dsc.ParentFile().Package(), dsc.FullName(), tag, file.NodeInfo(node.FieldTag()).Start(), handler); err != nil {
if err := s.AddExtension(packageFor(dsc), dsc.FullName(), tag, file.NodeInfo(node.FieldTag()).Start(), handler); err != nil {
return err
}
}
Expand Down Expand Up @@ -402,7 +411,7 @@ func resolveFieldTypes(f *fldDescriptor, handler *reporter.Handler, s *Symbols,
f.msgType = dsc
case protoreflect.EnumDescriptor:
proto3 := r.Syntax() == protoreflect.Proto3
enumIsProto3 := dsc.ParentFile().Syntax() == protoreflect.Proto3
enumIsProto3 := dsc.Syntax() == protoreflect.Proto3
if fld.GetExtendee() == "" && proto3 && !enumIsProto3 {
// fields in a proto3 message cannot refer to proto2 enums
return handler.HandleErrorf(file.NodeInfo(node.FieldType()).Start(), "%s: cannot use proto2 enum %s in a proto3 message", scope, fld.GetTypeName())
Expand All @@ -417,6 +426,14 @@ func resolveFieldTypes(f *fldDescriptor, handler *reporter.Handler, s *Symbols,
return nil
}

func packageFor(dsc protoreflect.Descriptor) protoreflect.FullName {
if dsc.ParentFile() != nil {
return dsc.ParentFile().Package()
}
// Can't access package? Make a best effort guess.
return dsc.FullName().Parent()
}

func isValidMap(mapField protoreflect.FieldDescriptor, mapEntry protoreflect.MessageDescriptor) bool {
return !mapField.IsExtension() &&
mapEntry.Parent() == mapField.ContainingMessage() &&
Expand Down
34 changes: 22 additions & 12 deletions linker/symbols.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"github.com/bufbuild/protocompile/walk"
)

const unknownFilePath = "<unknown file>"

// Symbols is a symbol table that maps names for all program elements to their
// location in source. It also tracks extension tag numbers. This can be used
// to enforce uniqueness for symbol names and tag numbers across many files and
Expand Down Expand Up @@ -121,7 +123,7 @@ func (s *Symbols) importFileWithExtensions(pkg *packageSymbols, fd protoreflect.
}
pos := sourcePositionForNumber(fld)
extendee := fld.ContainingMessage()
if err := s.AddExtension(extendee.ParentFile().Package(), extendee.FullName(), fld.Number(), pos, handler); err != nil {
if err := s.AddExtension(packageFor(extendee), extendee.FullName(), fld.Number(), pos, handler); err != nil {
return err
}
return nil
Expand Down Expand Up @@ -294,9 +296,13 @@ func sourcePositionForPackage(fd protoreflect.FileDescriptor) ast.SourcePos {
}

func sourcePositionFor(d protoreflect.Descriptor) ast.SourcePos {
file := d.ParentFile()
if file == nil {
return ast.UnknownPos(unknownFilePath)
}
path, ok := computePath(d)
if !ok {
return ast.UnknownPos(d.ParentFile().Path())
return ast.UnknownPos(file.Path())
}
namePath := path
switch d.(type) {
Expand All @@ -318,36 +324,40 @@ func sourcePositionFor(d protoreflect.Descriptor) ast.SourcePos {
// NB: shouldn't really happen, but just in case fall back to path to
// descriptor, sans name field
}
loc := d.ParentFile().SourceLocations().ByPath(namePath)
loc := file.SourceLocations().ByPath(namePath)
if isZeroLoc(loc) {
loc = d.ParentFile().SourceLocations().ByPath(path)
loc = file.SourceLocations().ByPath(path)
if isZeroLoc(loc) {
return ast.UnknownPos(d.ParentFile().Path())
return ast.UnknownPos(file.Path())
}
}
return ast.SourcePos{
Filename: d.ParentFile().Path(),
Filename: file.Path(),
Line: loc.StartLine,
Col: loc.StartColumn,
}
}

func sourcePositionForNumber(fd protoreflect.FieldDescriptor) ast.SourcePos {
file := fd.ParentFile()
if file == nil {
return ast.UnknownPos(unknownFilePath)
}
path, ok := computePath(fd)
if !ok {
return ast.UnknownPos(fd.ParentFile().Path())
return ast.UnknownPos(file.Path())
}
numberPath := path
numberPath = append(numberPath, internal.FieldNumberTag)
loc := fd.ParentFile().SourceLocations().ByPath(numberPath)
loc := file.SourceLocations().ByPath(numberPath)
if isZeroLoc(loc) {
loc = fd.ParentFile().SourceLocations().ByPath(path)
loc = file.SourceLocations().ByPath(path)
if isZeroLoc(loc) {
return ast.UnknownPos(fd.ParentFile().Path())
return ast.UnknownPos(file.Path())
}
}
return ast.SourcePos{
Filename: fd.ParentFile().Path(),
Filename: file.Path(),
Line: loc.StartLine,
Col: loc.StartColumn,
}
Expand Down Expand Up @@ -401,7 +411,7 @@ func (s *Symbols) importResultWithExtensions(pkg *packageSymbols, r *result, han
node := r.FieldNode(fd.FieldDescriptorProto())
pos := file.NodeInfo(node.FieldTag()).Start()
extendee := fd.ContainingMessage()
if err := s.AddExtension(extendee.ParentFile().Package(), extendee.FullName(), fd.Number(), pos, handler); err != nil {
if err := s.AddExtension(packageFor(extendee), extendee.FullName(), fd.Number(), pos, handler); err != nil {
return err
}

Expand Down
Loading