Skip to content

Commit

Permalink
Fix data race in compiler when resolver provides unlinked descriptor …
Browse files Browse the repository at this point in the history
…proto (#103)

From looking at the data race errors in bug report #102, this can happen
only if a `*descriptorpb.FileDescriptorProto` or `parser.Result` is
provided as a search result and is shared across parallel compiler
operations.

Importantly, that means that most usages of the compiler are safe and
won't ever observe any data races (including the way that the Buf CLI
uses it).
  • Loading branch information
jhump authored Mar 2, 2023
1 parent 80a64ae commit 2c27603
Show file tree
Hide file tree
Showing 11 changed files with 604 additions and 22 deletions.
25 changes: 22 additions & 3 deletions compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"sync"

"golang.org/x/sync/semaphore"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/descriptorpb"

"github.com/bufbuild/protocompile/ast"
"github.com/bufbuild/protocompile/linker"
Expand Down Expand Up @@ -400,6 +402,11 @@ func (t *task) asFile(ctx context.Context, name string, r SearchResult) (linker.
if err != nil {
return nil, err
}
if linkRes, ok := parseRes.(linker.Result); ok {
// if resolver returned a parse result that was actually a link result,
// use the link result directly (no other steps needed)
return linkRes, nil
}

var deps []linker.File
fileDescriptorProto := parseRes.FileDescriptorProto()
Expand Down Expand Up @@ -578,7 +585,7 @@ func (t *task) link(parseRes parser.Result, deps linker.Files) (linker.File, err
file.CheckForUnusedImports(t.h)
}

if t.e.c.SourceInfoMode != SourceInfoNone && parseRes.AST() != nil {
if needsSourceInfo(parseRes, t.e.c.SourceInfoMode) {
switch t.e.c.SourceInfoMode {
case SourceInfoStandard:
parseRes.FileDescriptorProto().SourceCodeInfo = sourceinfo.GenerateSourceInfo(parseRes.AST(), optsIndex)
Expand All @@ -594,19 +601,31 @@ func (t *task) link(parseRes parser.Result, deps linker.Files) (linker.File, err
return file, nil
}

func needsSourceInfo(parseRes parser.Result, mode SourceInfoMode) bool {
return mode != SourceInfoNone && parseRes.AST() != nil && parseRes.FileDescriptorProto().SourceCodeInfo == nil
}

func (t *task) asParseResult(name string, r SearchResult) (parser.Result, error) {
if r.ParseResult != nil {
if r.ParseResult.FileDescriptorProto().GetName() != name {
return nil, fmt.Errorf("search result for %q returned descriptor for %q", name, r.ParseResult.FileDescriptorProto().GetName())
}
return r.ParseResult, nil
// If the file descriptor needs linking, it will be mutated during the
// next stage. So to make anu mutations thread-safe, we must make a
// defensive copy.
res := parser.Clone(r.ParseResult)
return res, nil
}

if r.Proto != nil {
if r.Proto.GetName() != name {
return nil, fmt.Errorf("search result for %q returned descriptor for %q", name, r.Proto.GetName())
}
return parser.ResultWithoutAST(r.Proto), nil
// If the file descriptor needs linking, it will be mutated during the
// next stage. So to make any mutations thread-safe, we must make a
// defensive copy.
descProto := proto.Clone(r.Proto).(*descriptorpb.FileDescriptorProto) //nolint:errcheck
return parser.ResultWithoutAST(descProto), nil
}

file, err := t.asAST(name, r)
Expand Down
128 changes: 128 additions & 0 deletions compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,27 @@
package protocompile

import (
"bytes"
"context"
"errors"
"os"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/descriptorpb"

"github.com/bufbuild/protocompile/internal"
"github.com/bufbuild/protocompile/internal/prototest"
"github.com/bufbuild/protocompile/linker"
"github.com/bufbuild/protocompile/parser"
"github.com/bufbuild/protocompile/reporter"
)

func TestParseFilesMessageComments(t *testing.T) {
Expand Down Expand Up @@ -255,6 +262,127 @@ message Foo {
assert.Equal(t, int64(123), barVal.Int())
}

func TestDataRace(t *testing.T) {
t.Parallel()
if !internal.IsRace {
t.Skip("only useful when race detector enabled")
return
}

data, err := os.ReadFile("./internal/testdata/desc_test_complex.proto")
require.NoError(t, err)
ast, err := parser.Parse("desc_test_complex.proto", bytes.NewReader(data), reporter.NewHandler(nil))
require.NoError(t, err)
parseResult, err := parser.ResultFromAST(ast, true, reporter.NewHandler(nil))
require.NoError(t, err)
// let's also produce a resolved proto
files, err := (&Compiler{
Resolver: WithStandardImports(&SourceResolver{
ImportPaths: []string{"./internal/testdata"},
}),
SourceInfoMode: SourceInfoStandard,
}).Compile(context.Background(), "desc_test_complex.proto")
require.NoError(t, err)
resolvedProto := files[0].(linker.Result).FileDescriptorProto()

descriptor, err := protoregistry.GlobalFiles.FindFileByPath(descriptorProtoPath)
require.NoError(t, err)
descriptorProto := protodesc.ToFileDescriptorProto(descriptor)

// We will share this descriptor/parse result (which needs to be modified by the linker
// to resolve all references) from multiple concurrent operations to make sure the race
// detector is not triggered.
testCases := []struct {
name string
resolver Resolver
}{
// TODO: Sadly, interpreting options does not actually work when no AST is provided.
// Uncomment this test case when this is fixed.
// {
// name: "share unresolved descriptor",
// resolver: WithStandardImports(ResolverFunc(func(name string) (SearchResult, error) {
// if name == "desc_test_complex.proto" {
// return SearchResult{
// Proto: parseResult.FileDescriptorProto(),
// }, nil
// }
// return SearchResult{}, os.ErrNotExist
// })),
// },
{
name: "share resolved descriptor",
resolver: WithStandardImports(ResolverFunc(func(name string) (SearchResult, error) {
if name == "desc_test_complex.proto" {
return SearchResult{
Proto: resolvedProto,
}, nil
}
return SearchResult{}, os.ErrNotExist
})),
},
{
name: "share unresolved parse result",
resolver: WithStandardImports(ResolverFunc(func(name string) (SearchResult, error) {
if name == "desc_test_complex.proto" {
return SearchResult{
ParseResult: parseResult,
}, nil
}
return SearchResult{}, os.ErrNotExist
})),
},
{
name: "share google/protobuf/descriptor.proto",
resolver: WithStandardImports(ResolverFunc(func(name string) (SearchResult, error) {
// we'll parse our test proto from source, but its implicit dep on
// descriptor.proto will use a
switch name {
case "desc_test_complex.proto":
return SearchResult{
Source: bytes.NewReader(data),
}, nil
case "google/protobuf/descriptor.proto":
return SearchResult{
Proto: descriptorProto,
}, nil
default:
return SearchResult{}, os.ErrNotExist
}
})),
},
}

for i := range testCases {
testCase := testCases[i]
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
compiler1 := &Compiler{
Resolver: testCase.resolver,
SourceInfoMode: SourceInfoStandard,
}
compiler2 := &Compiler{
Resolver: testCase.resolver,
SourceInfoMode: SourceInfoStandard,
}
grp, ctx := errgroup.WithContext(context.Background())
grp.Go(func() error {
_, err := compiler1.Compile(ctx, "desc_test_complex.proto")
return err
})
grp.Go(func() error {
// We need to start this *after* the one above, but we can't
// use any sychronizing event or that would not be a race.
// So we assume a one second delay is sufficient.
time.Sleep(time.Second)
_, err := compiler2.Compile(ctx, "desc_test_complex.proto")
return err
})
err := grp.Wait()
require.NoError(t, err)
})
}
}

func TestPanicHandling(t *testing.T) {
t.Parallel()
c := Compiler{
Expand Down
4 changes: 2 additions & 2 deletions parser/norace_test.go → internal/norace.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@

//go:build !race

package parser
package internal

const isRace = false
const IsRace = false
4 changes: 2 additions & 2 deletions parser/race_test.go → internal/race.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@

//go:build race

package parser
package internal

const isRace = true
const IsRace = true
37 changes: 29 additions & 8 deletions linker/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,10 @@ func resolveFieldTypes(f *fldDescriptor, handler *reporter.Handler, s *Symbols,
return handler.HandleErrorf(file.NodeInfo(node.FieldExtendee()).Start(), "extendee is invalid: %s is %s, not a message", dsc.FullName(), descriptorTypeWithArticle(dsc))
}
f.extendee = extd
fld.Extendee = proto.String("." + string(dsc.FullName()))
extendeeName := "." + string(dsc.FullName())
if fld.GetExtendee() != extendeeName {
fld.Extendee = proto.String(extendeeName)
}
// make sure the tag number is in range
found := false
tag := protoreflect.FieldNumber(fld.GetNumber())
Expand Down Expand Up @@ -403,10 +406,15 @@ func resolveFieldTypes(f *fldDescriptor, handler *reporter.Handler, s *Symbols,
return handler.HandleErrorf(file.NodeInfo(node.FieldType()).Start(), "%s: %s is a synthetic map entry and may not be referenced explicitly", scope, dsc.FullName())
}
}
fld.TypeName = proto.String("." + string(dsc.FullName()))
// if type was tentatively unset, we now know it's actually a message
typeName := "." + string(dsc.FullName())
if fld.GetTypeName() != typeName {
fld.TypeName = proto.String(typeName)
}
if fld.Type == nil {
// if type was tentatively unset, we now know it's actually a message
fld.Type = descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum()
} else if fld.GetType() != descriptorpb.FieldDescriptorProto_TYPE_MESSAGE && fld.GetType() != descriptorpb.FieldDescriptorProto_TYPE_GROUP {
return handler.HandleErrorf(file.NodeInfo(node.FieldType()).Start(), "%s: descriptor proto indicates type %v but should be %v", scope, fld.GetType(), descriptorpb.FieldDescriptorProto_TYPE_MESSAGE)
}
f.msgType = dsc
case protoreflect.EnumDescriptor:
Expand All @@ -416,9 +424,16 @@ func resolveFieldTypes(f *fldDescriptor, handler *reporter.Handler, s *Symbols,
// fields in a proto3 message cannot refer to proto2 enums
return handler.HandleErrorf(file.NodeInfo(node.FieldType()).Start(), "%s: cannot use proto2 enum %s in a proto3 message", scope, fld.GetTypeName())
}
fld.TypeName = proto.String("." + string(dsc.FullName()))
// the type was tentatively unset, but now we know it's actually an enum
fld.Type = descriptorpb.FieldDescriptorProto_TYPE_ENUM.Enum()
typeName := "." + string(dsc.FullName())
if fld.GetTypeName() != typeName {
fld.TypeName = proto.String(typeName)
}
if fld.Type == nil {
// the type was tentatively unset, but now we know it's actually an enum
fld.Type = descriptorpb.FieldDescriptorProto_TYPE_ENUM.Enum()
} else if fld.GetType() != descriptorpb.FieldDescriptorProto_TYPE_ENUM {
return handler.HandleErrorf(file.NodeInfo(node.FieldType()).Start(), "%s: descriptor proto indicates type %v but should be %v", scope, fld.GetType(), descriptorpb.FieldDescriptorProto_TYPE_ENUM)
}
f.enumType = dsc
default:
return handler.HandleErrorf(file.NodeInfo(node.FieldType()).Start(), "%s: invalid type: %s is %s, not a message or enum", scope, dsc.FullName(), descriptorTypeWithArticle(dsc))
Expand Down Expand Up @@ -461,7 +476,10 @@ func resolveMethodTypes(m *mtdDescriptor, handler *reporter.Handler, scopes []sc
return err
}
} else {
mtd.InputType = proto.String("." + string(dsc.FullName()))
typeName := "." + string(dsc.FullName())
if mtd.GetInputType() != typeName {
mtd.InputType = proto.String(typeName)
}
m.inputType = msg
}

Expand All @@ -480,7 +498,10 @@ func resolveMethodTypes(m *mtdDescriptor, handler *reporter.Handler, scopes []sc
return err
}
} else {
mtd.OutputType = proto.String("." + string(dsc.FullName()))
typeName := "." + string(dsc.FullName())
if mtd.GetOutputType() != typeName {
mtd.OutputType = proto.String(typeName)
}
m.outputType = msg
}

Expand Down
Loading

0 comments on commit 2c27603

Please sign in to comment.