diff --git a/grpcreflect/adapt.go b/grpcreflect/adapt.go index 0d5615fe..661b9250 100644 --- a/grpcreflect/adapt.go +++ b/grpcreflect/adapt.go @@ -83,48 +83,48 @@ func toV1AlphaRequest(v1 *refv1.ServerReflectionRequest) *refv1alpha.ServerRefle return &v1alpha } -func toV1AlphaResponse(v1 *refv1.ServerReflectionResponse) *refv1alpha.ServerReflectionResponse { - var v1alpha refv1alpha.ServerReflectionResponse - v1alpha.ValidHost = v1.ValidHost - if v1.OriginalRequest != nil { - v1alpha.OriginalRequest = toV1AlphaRequest(v1.OriginalRequest) +func toV1Response(v1alpha *refv1alpha.ServerReflectionResponse) *refv1.ServerReflectionResponse { + var v1 refv1.ServerReflectionResponse + v1.ValidHost = v1alpha.ValidHost + if v1alpha.OriginalRequest != nil { + v1.OriginalRequest = toV1Request(v1alpha.OriginalRequest) } - switch mr := v1.MessageResponse.(type) { - case *refv1.ServerReflectionResponse_FileDescriptorResponse: + switch mr := v1alpha.MessageResponse.(type) { + case *refv1alpha.ServerReflectionResponse_FileDescriptorResponse: if mr != nil { - v1alpha.MessageResponse = &refv1alpha.ServerReflectionResponse_FileDescriptorResponse{ - FileDescriptorResponse: &refv1alpha.FileDescriptorResponse{ + v1.MessageResponse = &refv1.ServerReflectionResponse_FileDescriptorResponse{ + FileDescriptorResponse: &refv1.FileDescriptorResponse{ FileDescriptorProto: mr.FileDescriptorResponse.GetFileDescriptorProto(), }, } } - case *refv1.ServerReflectionResponse_AllExtensionNumbersResponse: + case *refv1alpha.ServerReflectionResponse_AllExtensionNumbersResponse: if mr != nil { - v1alpha.MessageResponse = &refv1alpha.ServerReflectionResponse_AllExtensionNumbersResponse{ - AllExtensionNumbersResponse: &refv1alpha.ExtensionNumberResponse{ + v1.MessageResponse = &refv1.ServerReflectionResponse_AllExtensionNumbersResponse{ + AllExtensionNumbersResponse: &refv1.ExtensionNumberResponse{ BaseTypeName: mr.AllExtensionNumbersResponse.GetBaseTypeName(), ExtensionNumber: mr.AllExtensionNumbersResponse.GetExtensionNumber(), }, } } - case *refv1.ServerReflectionResponse_ListServicesResponse: + case *refv1alpha.ServerReflectionResponse_ListServicesResponse: if mr != nil { - svcs := make([]*refv1alpha.ServiceResponse, len(mr.ListServicesResponse.GetService())) + svcs := make([]*refv1.ServiceResponse, len(mr.ListServicesResponse.GetService())) for i, svc := range mr.ListServicesResponse.GetService() { - svcs[i] = &refv1alpha.ServiceResponse{ + svcs[i] = &refv1.ServiceResponse{ Name: svc.GetName(), } } - v1alpha.MessageResponse = &refv1alpha.ServerReflectionResponse_ListServicesResponse{ - ListServicesResponse: &refv1alpha.ListServiceResponse{ + v1.MessageResponse = &refv1.ServerReflectionResponse_ListServicesResponse{ + ListServicesResponse: &refv1.ListServiceResponse{ Service: svcs, }, } } - case *refv1.ServerReflectionResponse_ErrorResponse: + case *refv1alpha.ServerReflectionResponse_ErrorResponse: if mr != nil { - v1alpha.MessageResponse = &refv1alpha.ServerReflectionResponse_ErrorResponse{ - ErrorResponse: &refv1alpha.ErrorResponse{ + v1.MessageResponse = &refv1.ServerReflectionResponse_ErrorResponse{ + ErrorResponse: &refv1.ErrorResponse{ ErrorCode: mr.ErrorResponse.GetErrorCode(), ErrorMessage: mr.ErrorResponse.GetErrorMessage(), }, @@ -133,5 +133,5 @@ func toV1AlphaResponse(v1 *refv1.ServerReflectionResponse) *refv1alpha.ServerRef default: // no value set } - return &v1alpha + return &v1 } diff --git a/grpcreflect/client.go b/grpcreflect/client.go index cb6bf568..1a35540c 100644 --- a/grpcreflect/client.go +++ b/grpcreflect/client.go @@ -7,6 +7,7 @@ import ( "io" "reflect" "runtime" + "sort" "sync" "sync/atomic" "time" @@ -17,6 +18,9 @@ import ( refv1 "google.golang.org/grpc/reflection/grpc_reflection_v1" refv1alpha "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" "google.golang.org/grpc/status" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" "github.com/jhump/protoreflect/desc" @@ -134,26 +138,37 @@ type extDesc struct { extensionNumber int32 } +type resolvers struct { + descriptorResolver protodesc.Resolver + extensionResolver protoregistry.ExtensionTypeResolver +} + +type fileEntry struct { + fd *desc.FileDescriptor + fallback bool +} + // Client is a client connection to a server for performing reflection calls // and resolving remote symbols. type Client struct { - ctx context.Context - now func() time.Time - stubV1 refv1.ServerReflectionClient - stubV1Alpha refv1alpha.ServerReflectionClient - allowMissing atomic.Bool + ctx context.Context + now func() time.Time + stubV1 refv1.ServerReflectionClient + stubV1Alpha refv1alpha.ServerReflectionClient + allowMissing atomic.Bool + fallbackResolver atomic.Pointer[resolvers] connMu sync.Mutex cancel context.CancelFunc - stream refv1alpha.ServerReflection_ServerReflectionInfoClient + stream refv1.ServerReflection_ServerReflectionInfoClient useV1Alpha bool lastTriedV1 time.Time cacheMu sync.RWMutex protosByName map[string]*descriptorpb.FileDescriptorProto - filesByName map[string]*desc.FileDescriptor - filesBySymbol map[string]*desc.FileDescriptor - filesByExtension map[extDesc]*desc.FileDescriptor + filesByName map[string]fileEntry + filesBySymbol map[string]fileEntry + filesByExtension map[extDesc]fileEntry } // NewClient creates a new Client with the given root context and using the @@ -173,6 +188,12 @@ func NewClientV1Alpha(ctx context.Context, stub refv1alpha.ServerReflectionClien return newClient(ctx, nil, stub) } +// NewClientV1 creates a new Client using the v1 version of reflection with the +// given root context and using the given RPC stub for talking to the server. +func NewClientV1(ctx context.Context, stub refv1.ServerReflectionClient) *Client { + return newClient(ctx, stub, nil) +} + func newClient(ctx context.Context, stubv1 refv1.ServerReflectionClient, stubv1alpha refv1alpha.ServerReflectionClient) *Client { cr := &Client{ ctx: ctx, @@ -180,9 +201,9 @@ func newClient(ctx context.Context, stubv1 refv1.ServerReflectionClient, stubv1a stubV1: stubv1, stubV1Alpha: stubv1alpha, protosByName: map[string]*descriptorpb.FileDescriptorProto{}, - filesByName: map[string]*desc.FileDescriptor{}, - filesBySymbol: map[string]*desc.FileDescriptor{}, - filesByExtension: map[extDesc]*desc.FileDescriptor{}, + filesByName: map[string]fileEntry{}, + filesBySymbol: map[string]fileEntry{}, + filesByExtension: map[extDesc]fileEntry{}, } // don't leak a grpc stream runtime.SetFinalizer(cr, (*Client).Reset) @@ -214,28 +235,49 @@ func (cr *Client) AllowMissingFileDescriptors() { cr.allowMissing.Store(true) } -// TODO: We should also have a NewClientV1. However that should not refer to internal -// generated code. So it will have to wait until the grpc-go team fixes this issue: -// https://github.com/grpc/grpc-go/issues/5684 +// AllowFallbackResolver configures the client to allow falling back to the +// given resolvers if the server is unable to supply descriptors for a particular +// query. This allows working around issues where servers' reflection service +// provides an incomplete set of descriptors, but the client has knowledge of +// the missing descriptors from another source. It is usually most appropriate +// to pass [protoregistry.GlobalFiles] and [protoregistry.GlobalTypes] as the +// resolver values. +// +// The first value is used as a fallback for FileByFilename and FileContainingSymbol +// queries. The second value is used as a fallback for FileContainingExtension. It +// can also be used as a fallback for AllExtensionNumbersForType if it provides +// a method with the following signature (which *[protoregistry.Types] provides): +// +// RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) +func (cr *Client) AllowFallbackResolver(descriptors protodesc.Resolver, exts protoregistry.ExtensionTypeResolver) { + if descriptors == nil && exts == nil { + cr.fallbackResolver.Store(nil) + } else { + cr.fallbackResolver.Store(&resolvers{ + descriptorResolver: descriptors, + extensionResolver: exts, + }) + } +} // FileByFilename asks the server for a file descriptor for the proto file with // the given name. func (cr *Client) FileByFilename(filename string) (*desc.FileDescriptor, error) { // hit the cache first cr.cacheMu.RLock() - if fd, ok := cr.filesByName[filename]; ok { + if entry, ok := cr.filesByName[filename]; ok { cr.cacheMu.RUnlock() - return fd, nil + return entry.fd, nil } + // not there? see if we've downloaded the proto fdp, ok := cr.protosByName[filename] cr.cacheMu.RUnlock() - // not there? see if we've downloaded the proto if ok { return cr.descriptorFromProto(fdp) } - req := &refv1alpha.ServerReflectionRequest{ - MessageRequest: &refv1alpha.ServerReflectionRequest_FileByFilename{ + req := &refv1.ServerReflectionRequest{ + MessageRequest: &refv1.ServerReflectionRequest_FileByFilename{ FileByFilename: filename, }, } @@ -245,23 +287,37 @@ func (cr *Client) FileByFilename(filename string) (*desc.FileDescriptor, error) fd, err := cr.getAndCacheFileDescriptors(req, filename, "", accept) if isNotFound(err) { - // file not found? see if we can look up via alternate name + // File not found? see if we can look up via alternate name if alternate, ok := internal.StdFileAliases[filename]; ok { - req := &refv1alpha.ServerReflectionRequest{ - MessageRequest: &refv1alpha.ServerReflectionRequest_FileByFilename{ + req := &refv1.ServerReflectionRequest{ + MessageRequest: &refv1.ServerReflectionRequest_FileByFilename{ FileByFilename: alternate, }, } fd, err = cr.getAndCacheFileDescriptors(req, alternate, filename, accept) - if isNotFound(err) { - err = fileNotFound(filename, nil) + } + } + if isNotFound(err) { + // Still no? See if we can use a fallback resolver + resolver := cr.fallbackResolver.Load() + if resolver != nil && resolver.descriptorResolver != nil { + fileDesc, fallbackErr := resolver.descriptorResolver.FindFileByPath(filename) + if fallbackErr == nil { + var wrapErr error + fd, wrapErr = desc.WrapFile(fileDesc) + if wrapErr == nil { + fd = cr.cacheFile(fd, true) + err = nil // clear error since we've succeeded via the fallback + } } - } else { - err = fileNotFound(filename, nil) } + } + if isNotFound(err) { + err = fileNotFound(filename, nil) } else if e, ok := err.(*elementNotFoundError); ok { err = fileNotFound(filename, e) } + return fd, err } @@ -270,14 +326,14 @@ func (cr *Client) FileByFilename(filename string) (*desc.FileDescriptor, error) func (cr *Client) FileContainingSymbol(symbol string) (*desc.FileDescriptor, error) { // hit the cache first cr.cacheMu.RLock() - fd, ok := cr.filesBySymbol[symbol] + entry, ok := cr.filesBySymbol[symbol] cr.cacheMu.RUnlock() if ok { - return fd, nil + return entry.fd, nil } - req := &refv1alpha.ServerReflectionRequest{ - MessageRequest: &refv1alpha.ServerReflectionRequest_FileContainingSymbol{ + req := &refv1.ServerReflectionRequest{ + MessageRequest: &refv1.ServerReflectionRequest_FileContainingSymbol{ FileContainingSymbol: symbol, }, } @@ -285,6 +341,21 @@ func (cr *Client) FileContainingSymbol(symbol string) (*desc.FileDescriptor, err return fd.FindSymbol(symbol) != nil } fd, err := cr.getAndCacheFileDescriptors(req, "", "", accept) + if isNotFound(err) { + // Symbol not found? See if we can use a fallback resolver + resolver := cr.fallbackResolver.Load() + if resolver != nil && resolver.descriptorResolver != nil { + d, fallbackErr := resolver.descriptorResolver.FindDescriptorByName(protoreflect.FullName(symbol)) + if fallbackErr == nil { + var wrapErr error + fd, wrapErr = desc.WrapFile(d.ParentFile()) + if wrapErr == nil { + fd = cr.cacheFile(fd, true) + err = nil // clear error since we've succeeded via the fallback + } + } + } + } if isNotFound(err) { err = symbolNotFound(symbol, symbolTypeUnknown, nil) } else if e, ok := err.(*elementNotFoundError); ok { @@ -299,15 +370,15 @@ func (cr *Client) FileContainingSymbol(symbol string) (*desc.FileDescriptor, err func (cr *Client) FileContainingExtension(extendedMessageName string, extensionNumber int32) (*desc.FileDescriptor, error) { // hit the cache first cr.cacheMu.RLock() - fd, ok := cr.filesByExtension[extDesc{extendedMessageName, extensionNumber}] + entry, ok := cr.filesByExtension[extDesc{extendedMessageName, extensionNumber}] cr.cacheMu.RUnlock() if ok { - return fd, nil + return entry.fd, nil } - req := &refv1alpha.ServerReflectionRequest{ - MessageRequest: &refv1alpha.ServerReflectionRequest_FileContainingExtension{ - FileContainingExtension: &refv1alpha.ExtensionRequest{ + req := &refv1.ServerReflectionRequest{ + MessageRequest: &refv1.ServerReflectionRequest_FileContainingExtension{ + FileContainingExtension: &refv1.ExtensionRequest{ ContainingType: extendedMessageName, ExtensionNumber: extensionNumber, }, @@ -317,6 +388,21 @@ func (cr *Client) FileContainingExtension(extendedMessageName string, extensionN return fd.FindExtension(extendedMessageName, extensionNumber) != nil } fd, err := cr.getAndCacheFileDescriptors(req, "", "", accept) + if isNotFound(err) { + // Extension not found? See if we can use a fallback resolver + resolver := cr.fallbackResolver.Load() + if resolver != nil && resolver.extensionResolver != nil { + extType, fallbackErr := resolver.extensionResolver.FindExtensionByNumber(protoreflect.FullName(extendedMessageName), protoreflect.FieldNumber(extensionNumber)) + if fallbackErr == nil { + var wrapErr error + fd, wrapErr = desc.WrapFile(extType.TypeDescriptor().ParentFile()) + if wrapErr == nil { + fd = cr.cacheFile(fd, true) + err = nil // clear error since we've succeeded via the fallback + } + } + } + } if isNotFound(err) { err = extensionNotFound(extendedMessageName, extensionNumber, nil) } else if e, ok := err.(*elementNotFoundError); ok { @@ -325,7 +411,7 @@ func (cr *Client) FileContainingExtension(extendedMessageName string, extensionN return fd, err } -func (cr *Client) getAndCacheFileDescriptors(req *refv1alpha.ServerReflectionRequest, expectedName, alias string, accept func(*desc.FileDescriptor) bool) (*desc.FileDescriptor, error) { +func (cr *Client) getAndCacheFileDescriptors(req *refv1.ServerReflectionRequest, expectedName, alias string, accept func(*desc.FileDescriptor) bool) (*desc.FileDescriptor, error) { resp, err := cr.send(req) if err != nil { return nil, err @@ -412,96 +498,159 @@ func (cr *Client) descriptorFromProto(fd *descriptorpb.FileDescriptorProto) (*de } return nil, err } - d = cr.cacheFile(d) + d = cr.cacheFile(d, false) return d, nil } -func (cr *Client) cacheFile(fd *desc.FileDescriptor) *desc.FileDescriptor { +func (cr *Client) cacheFile(fd *desc.FileDescriptor, fallback bool) *desc.FileDescriptor { cr.cacheMu.Lock() defer cr.cacheMu.Unlock() - // cache file descriptor by name, but don't overwrite existing entry - // (existing entry could come from concurrent caller) - if existingFd, ok := cr.filesByName[fd.GetName()]; ok { - return existingFd + // Cache file descriptor by name. If we can't overwrite an existing + // entry, return it. (Existing entry could come from concurrent caller.) + if existing, ok := cr.filesByName[fd.GetName()]; ok && !canOverwrite(existing, fallback) { + return existing.fd } - cr.filesByName[fd.GetName()] = fd + entry := fileEntry{fd: fd, fallback: fallback} + cr.filesByName[fd.GetName()] = entry // also cache by symbols and extensions for _, m := range fd.GetMessageTypes() { - cr.cacheMessageLocked(fd, m) + cr.cacheMessageLocked(m, entry) } for _, e := range fd.GetEnumTypes() { - cr.filesBySymbol[e.GetFullyQualifiedName()] = fd + if !cr.maybeCacheFileBySymbol(e.GetFullyQualifiedName(), entry) { + continue + } for _, v := range e.GetValues() { - cr.filesBySymbol[v.GetFullyQualifiedName()] = fd + cr.maybeCacheFileBySymbol(v.GetFullyQualifiedName(), entry) } } for _, e := range fd.GetExtensions() { - cr.filesBySymbol[e.GetFullyQualifiedName()] = fd - cr.filesByExtension[extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}] = fd + if !cr.maybeCacheFileBySymbol(e.GetFullyQualifiedName(), entry) { + continue + } + cr.maybeCacheFileByExtension(extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}, entry) } for _, s := range fd.GetServices() { - cr.filesBySymbol[s.GetFullyQualifiedName()] = fd + if !cr.maybeCacheFileBySymbol(s.GetFullyQualifiedName(), entry) { + continue + } for _, m := range s.GetMethods() { - cr.filesBySymbol[m.GetFullyQualifiedName()] = fd + cr.maybeCacheFileBySymbol(m.GetFullyQualifiedName(), entry) } } return fd } -func (cr *Client) cacheMessageLocked(fd *desc.FileDescriptor, md *desc.MessageDescriptor) { - cr.filesBySymbol[md.GetFullyQualifiedName()] = fd +func (cr *Client) cacheMessageLocked(md *desc.MessageDescriptor, entry fileEntry) { + if !cr.maybeCacheFileBySymbol(md.GetFullyQualifiedName(), entry) { + return + } for _, f := range md.GetFields() { - cr.filesBySymbol[f.GetFullyQualifiedName()] = fd + cr.maybeCacheFileBySymbol(f.GetFullyQualifiedName(), entry) } for _, o := range md.GetOneOfs() { - cr.filesBySymbol[o.GetFullyQualifiedName()] = fd + cr.maybeCacheFileBySymbol(o.GetFullyQualifiedName(), entry) } for _, e := range md.GetNestedEnumTypes() { - cr.filesBySymbol[e.GetFullyQualifiedName()] = fd + if !cr.maybeCacheFileBySymbol(e.GetFullyQualifiedName(), entry) { + continue + } for _, v := range e.GetValues() { - cr.filesBySymbol[v.GetFullyQualifiedName()] = fd + cr.maybeCacheFileBySymbol(v.GetFullyQualifiedName(), entry) } } for _, e := range md.GetNestedExtensions() { - cr.filesBySymbol[e.GetFullyQualifiedName()] = fd - cr.filesByExtension[extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}] = fd + if !cr.maybeCacheFileBySymbol(e.GetFullyQualifiedName(), entry) { + continue + } + cr.maybeCacheFileByExtension(extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}, entry) } for _, m := range md.GetNestedMessageTypes() { - cr.cacheMessageLocked(fd, m) // recurse + cr.cacheMessageLocked(m, entry) // recurse } } +func canOverwrite(existing fileEntry, fallback bool) bool { + return !fallback && existing.fallback +} + +func (cr *Client) maybeCacheFileBySymbol(symbol string, entry fileEntry) bool { + existing, ok := cr.filesBySymbol[symbol] + if ok && !canOverwrite(existing, entry.fallback) { + return false + } + cr.filesBySymbol[symbol] = entry + return true +} + +func (cr *Client) maybeCacheFileByExtension(ext extDesc, entry fileEntry) { + existing, ok := cr.filesByExtension[ext] + if ok && !canOverwrite(existing, entry.fallback) { + return + } + cr.filesByExtension[ext] = entry +} + // AllExtensionNumbersForType asks the server for all known extension numbers // for the given fully-qualified message name. func (cr *Client) AllExtensionNumbersForType(extendedMessageName string) ([]int32, error) { - req := &refv1alpha.ServerReflectionRequest{ - MessageRequest: &refv1alpha.ServerReflectionRequest_AllExtensionNumbersOfType{ + req := &refv1.ServerReflectionRequest{ + MessageRequest: &refv1.ServerReflectionRequest_AllExtensionNumbersOfType{ AllExtensionNumbersOfType: extendedMessageName, }, } resp, err := cr.send(req) - if err != nil { - if isNotFound(err) { - return nil, symbolNotFound(extendedMessageName, symbolTypeMessage, nil) - } + var exts []int32 + if err != nil && !isNotFound(err) { + // If the server doesn't know about the message type and returns "not found", + // we'll treat that as "no known extensions" instead of returning an error. return nil, err } + if err == nil { + extResp := resp.GetAllExtensionNumbersResponse() + if extResp == nil { + return nil, &ProtocolError{reflect.TypeOf(extResp).Elem()} + } + exts = extResp.ExtensionNumber + } - extResp := resp.GetAllExtensionNumbersResponse() - if extResp == nil { - return nil, &ProtocolError{reflect.TypeOf(extResp).Elem()} + resolver := cr.fallbackResolver.Load() + if resolver != nil && resolver.extensionResolver != nil { + type extRanger interface { + RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) + } + if ranger, ok := resolver.extensionResolver.(extRanger); ok { + // Merge results with fallback resolver + extSet := map[int32]struct{}{} + ranger.RangeExtensionsByMessage(protoreflect.FullName(extendedMessageName), func(extType protoreflect.ExtensionType) bool { + extSet[int32(extType.TypeDescriptor().Number())] = struct{}{} + return true + }) + if len(extSet) > 0 { + // De-dupe with the set of extension numbers we got + // from the server and merge the results back into exts. + for _, ext := range exts { + extSet[ext] = struct{}{} + } + exts = make([]int32, 0, len(extSet)) + for ext := range extSet { + exts = append(exts, ext) + } + sort.Slice(exts, func(i, j int) bool { return exts[i] < exts[j] }) + } + } } - return extResp.ExtensionNumber, nil + return exts, nil } // ListServices asks the server for the fully-qualified names of all exposed // services. func (cr *Client) ListServices() ([]string, error) { - req := &refv1alpha.ServerReflectionRequest{ - MessageRequest: &refv1alpha.ServerReflectionRequest_ListServices{ + req := &refv1.ServerReflectionRequest{ + MessageRequest: &refv1.ServerReflectionRequest_ListServices{ // proto doesn't indicate any purpose for this value and server impl // doesn't actually use it... ListServices: "*", @@ -523,7 +672,7 @@ func (cr *Client) ListServices() ([]string, error) { return serviceNames, nil } -func (cr *Client) send(req *refv1alpha.ServerReflectionRequest) (*refv1alpha.ServerReflectionResponse, error) { +func (cr *Client) send(req *refv1.ServerReflectionRequest) (*refv1.ServerReflectionResponse, error) { // we allow one immediate retry, in case we have a stale stream // (e.g. closed by server) resp, err := cr.doSend(req) @@ -548,7 +697,7 @@ func isNotFound(err error) bool { return ok && s.Code() == codes.NotFound } -func (cr *Client) doSend(req *refv1alpha.ServerReflectionRequest) (*refv1alpha.ServerReflectionResponse, error) { +func (cr *Client) doSend(req *refv1.ServerReflectionRequest) (*refv1.ServerReflectionResponse, error) { // TODO: Streams are thread-safe, so we shouldn't need to lock. But without locking, we'll need more machinery // (goroutines and channels) to ensure that responses are correctly correlated with their requests and thus // delivered in correct oder. @@ -557,7 +706,7 @@ func (cr *Client) doSend(req *refv1alpha.ServerReflectionRequest) (*refv1alpha.S return cr.doSendLocked(0, nil, req) } -func (cr *Client) doSendLocked(attemptCount int, prevErr error, req *refv1alpha.ServerReflectionRequest) (*refv1alpha.ServerReflectionResponse, error) { +func (cr *Client) doSendLocked(attemptCount int, prevErr error, req *refv1.ServerReflectionRequest) (*refv1.ServerReflectionResponse, error) { if attemptCount >= 3 && prevErr != nil { return nil, prevErr } @@ -610,7 +759,7 @@ func (cr *Client) initStreamLocked() error { // try the v1 API streamv1, err := cr.stubV1.ServerReflectionInfo(newCtx) if err == nil { - cr.stream = adaptStreamFromV1{streamv1} + cr.stream = streamv1 return nil } if status.Code(err) != codes.Unimplemented { @@ -622,7 +771,11 @@ func (cr *Client) initStreamLocked() error { cr.lastTriedV1 = cr.now() } var err error - cr.stream, err = cr.stubV1Alpha.ServerReflectionInfo(newCtx) + streamv1alpha, err := cr.stubV1Alpha.ServerReflectionInfo(newCtx) + if err == nil { + cr.stream = adaptStreamFromV1Alpha{streamv1alpha} + return nil + } return err } @@ -640,7 +793,7 @@ func (cr *Client) Reset() { func (cr *Client) resetLocked() { if cr.stream != nil { - cr.stream.CloseSend() + _ = cr.stream.CloseSend() for { // drain the stream, this covers io.EOF too if _, err := cr.stream.Recv(); err != nil { @@ -847,19 +1000,19 @@ func (mde msgDescriptorExtensions) nestedScopes() []extensionScope { return scopes } -type adaptStreamFromV1 struct { - refv1.ServerReflection_ServerReflectionInfoClient +type adaptStreamFromV1Alpha struct { + refv1alpha.ServerReflection_ServerReflectionInfoClient } -func (a adaptStreamFromV1) Send(request *refv1alpha.ServerReflectionRequest) error { - v1req := toV1Request(request) +func (a adaptStreamFromV1Alpha) Send(request *refv1.ServerReflectionRequest) error { + v1req := toV1AlphaRequest(request) return a.ServerReflection_ServerReflectionInfoClient.Send(v1req) } -func (a adaptStreamFromV1) Recv() (*refv1alpha.ServerReflectionResponse, error) { +func (a adaptStreamFromV1Alpha) Recv() (*refv1.ServerReflectionResponse, error) { v1resp, err := a.ServerReflection_ServerReflectionInfoClient.Recv() if err != nil { return nil, err } - return toV1AlphaResponse(v1resp), nil + return toV1Response(v1resp), nil } diff --git a/grpcreflect/client_test.go b/grpcreflect/client_test.go index 67ba0fc5..d744905a 100644 --- a/grpcreflect/client_test.go +++ b/grpcreflect/client_test.go @@ -19,6 +19,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/reflection" + reflectv1 "google.golang.org/grpc/reflection/grpc_reflection_v1" reflectv1alpha "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" "google.golang.org/grpc/status" "google.golang.org/protobuf/reflect/protodesc" @@ -208,8 +209,9 @@ func TestAllExtensionNumbersForType(t *testing.T) { sort.Ints(inums) testutil.Eq(t, []int{100, 101, 102, 103, 200}, inums) - _, err = client.AllExtensionNumbersForType("does not exist") - testutil.Require(t, IsElementNotFoundError(err)) + nums, err = client.AllExtensionNumbersForType("does not exist") + testutil.Ok(t, err) + testutil.Eq(t, 0, len(nums)) } func TestListServices(t *testing.T) { @@ -287,7 +289,7 @@ func TestMultipleFiles(t *testing.T) { dialCtx, dialCancel := context.WithTimeout(ctx, 3*time.Second) defer dialCancel() cc, err := grpc.DialContext(dialCtx, l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) - testutil.Ok(t, err, "failed ot dial %v", l.Addr().String()) + testutil.Ok(t, err, "failed to dial %v", l.Addr().String()) cl := reflectv1alpha.NewServerReflectionClient(cc) client := NewClientV1Alpha(ctx, cl) @@ -331,7 +333,7 @@ func TestAllowMissingFileDescriptors(t *testing.T) { dialCtx, dialCancel := context.WithTimeout(ctx, 3*time.Second) defer dialCancel() cc, err := grpc.DialContext(dialCtx, l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) - testutil.Ok(t, err, "failed ot dial %v", l.Addr().String()) + testutil.Ok(t, err, "failed to dial %v", l.Addr().String()) cl := reflectv1alpha.NewServerReflectionClient(cc) client := NewClientV1Alpha(ctx, cl) @@ -351,14 +353,106 @@ func TestAllowMissingFileDescriptors(t *testing.T) { testutil.Ok(t, err) testutil.Require(t, file != nil) testutil.Eq(t, "foo/bar/this.proto", file.GetName()) + file, err = client.FileContainingSymbol("foo.bar.Bar") + testutil.Ok(t, err) + testutil.Require(t, file != nil) + testutil.Eq(t, "foo/bar/this.proto", file.GetName()) + file, err = client.FileContainingExtension("google.protobuf.MessageOptions", 10101) + testutil.Ok(t, err) + testutil.Require(t, file != nil) + testutil.Eq(t, "test/imported.proto", file.GetName()) +} + +func TestAllowFallbackResolver(t *testing.T) { + svr := grpc.NewServer() + reflection.RegisterV1(svr) + + l, err := net.Listen("tcp", "127.0.0.1:0") + testutil.Ok(t, err, "failed to listen") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + defer cancel() + if err := svr.Serve(l); err != nil { + t.Logf("serve returned error: %v", err) + } + }() + time.Sleep(100 * time.Millisecond) // give server a chance to start + testutil.Ok(t, ctx.Err(), "failed to start server") + defer func() { + svr.Stop() + }() + + dialCtx, dialCancel := context.WithTimeout(ctx, 3*time.Second) + defer dialCancel() + cc, err := grpc.DialContext(dialCtx, l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) + testutil.Ok(t, err, "failed to dial %v", l.Addr().String()) + cl := reflectv1.NewServerReflectionClient(cc) + + client := NewClientV1(ctx, cl) + defer client.Reset() + + // First sanity-check that the well-known types are there. + file, err := client.FileByFilename("google/protobuf/descriptor.proto") + testutil.Ok(t, err) + testutil.Eq(t, "google/protobuf/descriptor.proto", file.GetName()) + // Now we try some things that should fail due to missing descriptors. + _, err = client.FileByFilename("foo/bar/this.proto") + testutil.Nok(t, err) _, err = client.FileContainingSymbol("foo.bar.Bar") + testutil.Nok(t, err) + file, err = client.FileContainingExtension("google.protobuf.MessageOptions", 23456) + testutil.Nok(t, err) + nums, err := client.AllExtensionNumbersForType("google.protobuf.MessageOptions") + testutil.Ok(t, err) + withoutFallbackExts := len(nums) + + // Now we configure a fallback. + fdp := &descriptorpb.FileDescriptorProto{ + Name: proto.String("foo/bar/this.proto"), + Package: proto.String("foo.bar"), + Dependency: []string{"google/protobuf/descriptor.proto"}, + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Bar"), + }, + }, + Extension: []*descriptorpb.FieldDescriptorProto{ + { + Name: proto.String("opt"), + Extendee: proto.String(".google.protobuf.MessageOptions"), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(), + TypeName: proto.String(".foo.bar.Bar"), + Number: proto.Int32(23456), + }, + }, + } + fd, err := protodesc.NewFile(fdp, protoregistry.GlobalFiles) + testutil.Ok(t, err) + var files files + err = files.RegisterFile(fd) + testutil.Ok(t, err) + + client.AllowFallbackResolver(&files, &files) + + // The above queries should now succeed. + file, err = client.FileByFilename("foo/bar/this.proto") testutil.Ok(t, err) testutil.Require(t, file != nil) testutil.Eq(t, "foo/bar/this.proto", file.GetName()) - _, err = client.FileContainingExtension("google.protobuf.MessageOptions", 10101) + file, err = client.FileContainingSymbol("foo.bar.Bar") testutil.Ok(t, err) testutil.Require(t, file != nil) testutil.Eq(t, "foo/bar/this.proto", file.GetName()) + file, err = client.FileContainingExtension("google.protobuf.MessageOptions", 23456) + testutil.Ok(t, err) + testutil.Require(t, file != nil) + testutil.Eq(t, "foo/bar/this.proto", file.GetName()) + nums, err = client.AllExtensionNumbersForType("google.protobuf.MessageOptions") + testutil.Ok(t, err) + // The same extensions as before, plus an extra one provided by the fallback. + testutil.Eq(t, withoutFallbackExts+1, len(nums)) } func TestFileWithoutDeps(t *testing.T) {