From 12fa78bcc92c56c8e730be736f7fb8f382ac2b32 Mon Sep 17 00:00:00 2001 From: Harvey Tuch Date: Thu, 23 Jan 2020 15:36:50 -0500 Subject: [PATCH] protobuf: refactor proto visitor pattern. Move the proto traversal handling in version_converter.cc to a standalone library. This lets us replace existing proto visitor patterns in common/protobuf/utility.cc for unexpected field checks. The redaction code is actually a bit more involved, so I'm not refactoring this; it needs to recurse through Any/TypedStruct. Ultimately we might want something like this, but it doesn't seem super helpful given we only have a the single instance of this right now. Risk level: Low Testing: Existing tests continue to pass. Signed-off-by: Harvey Tuch --- source/common/config/BUILD | 1 + source/common/config/version_converter.cc | 53 ++----------- source/common/protobuf/BUILD | 8 ++ source/common/protobuf/utility.cc | 94 +++++++++++------------ source/common/protobuf/visitor.cc | 50 ++++++++++++ source/common/protobuf/visitor.h | 43 +++++++++++ 6 files changed, 155 insertions(+), 94 deletions(-) create mode 100644 source/common/protobuf/visitor.cc create mode 100644 source/common/protobuf/visitor.h diff --git a/source/common/config/BUILD b/source/common/config/BUILD index 24769e4df38d..f03188f8d40f 100644 --- a/source/common/config/BUILD +++ b/source/common/config/BUILD @@ -352,6 +352,7 @@ envoy_cc_library( deps = [ ":api_type_oracle_lib", "//source/common/protobuf", + "//source/common/protobuf:visitor_lib", "//source/common/protobuf:well_known_lib", "@envoy_api//envoy/config/core/v3:pkg_cc_proto", ], diff --git a/source/common/config/version_converter.cc b/source/common/config/version_converter.cc index 54b14e996214..c23533f3141d 100644 --- a/source/common/config/version_converter.cc +++ b/source/common/config/version_converter.cc @@ -2,6 +2,7 @@ #include "common/common/assert.h" #include "common/config/api_type_oracle.h" +#include "common/protobuf/visitor.h" #include "common/protobuf/well_known.h" #include "absl/strings/match.h" @@ -13,46 +14,6 @@ namespace { const char DeprecatedFieldShadowPrefix[] = "hidden_envoy_deprecated_"; -class ProtoVisitor { -public: - virtual ~ProtoVisitor() = default; - - // Invoked when a field is visited, with the message, field descriptor and - // context. Returns a new context for use when traversing the sub-message in a - // field. - virtual const void* onField(Protobuf::Message&, const Protobuf::FieldDescriptor&, - const void* ctxt) { - return ctxt; - } - - // Invoked when a message is visited, with the message and a context. - virtual void onMessage(Protobuf::Message&, const void*){}; -}; - -// TODO(htuch): refactor these message visitor patterns into utility.cc and share with -// MessageUtil::checkForUnexpectedFields. -void traverseMutableMessage(ProtoVisitor& visitor, Protobuf::Message& message, const void* ctxt) { - visitor.onMessage(message, ctxt); - const Protobuf::Descriptor* descriptor = message.GetDescriptor(); - const Protobuf::Reflection* reflection = message.GetReflection(); - for (int i = 0; i < descriptor->field_count(); ++i) { - const Protobuf::FieldDescriptor* field = descriptor->field(i); - const void* field_ctxt = visitor.onField(message, *field, ctxt); - // If this is a message, recurse to scrub deprecated fields in the sub-message. - if (field->cpp_type() == Protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { - if (field->is_repeated()) { - const int size = reflection->FieldSize(message, field); - for (int j = 0; j < size; ++j) { - traverseMutableMessage(visitor, *reflection->MutableRepeatedMessage(&message, field, j), - field_ctxt); - } - } else if (reflection->HasField(message, field)) { - traverseMutableMessage(visitor, *reflection->MutableMessage(&message, field), field_ctxt); - } - } - } -} - // Reinterpret a Protobuf message as another Protobuf message by converting to // wire format and back. This only works for messages that can be effectively // duck typed this way, e.g. with a subtype relationship modulo field name. @@ -86,7 +47,7 @@ DynamicMessagePtr createForDescriptorWithCast(const Protobuf::Message& message, // internally, we later want to recover their original types. void annotateWithOriginalType(const Protobuf::Descriptor& prev_descriptor, Protobuf::Message& next_message) { - class TypeAnnotatingProtoVisitor : public ProtoVisitor { + class TypeAnnotatingProtoVisitor : public ProtobufMessage::ProtoVisitor { public: void onMessage(Protobuf::Message& message, const void* ctxt) override { const Protobuf::Descriptor* descriptor = message.GetDescriptor(); @@ -125,7 +86,7 @@ void annotateWithOriginalType(const Protobuf::Descriptor& prev_descriptor, } }; TypeAnnotatingProtoVisitor proto_visitor; - traverseMutableMessage(proto_visitor, next_message, &prev_descriptor); + ProtobufMessage::traverseMutableMessage(proto_visitor, next_message, &prev_descriptor); } } // namespace @@ -138,7 +99,7 @@ void VersionConverter::upgrade(const Protobuf::Message& prev_message, } void VersionConverter::eraseOriginalTypeInformation(Protobuf::Message& message) { - class TypeErasingProtoVisitor : public ProtoVisitor { + class TypeErasingProtoVisitor : public ProtobufMessage::ProtoVisitor { public: void onMessage(Protobuf::Message& message, const void*) override { const Protobuf::Reflection* reflection = message.GetReflection(); @@ -147,7 +108,7 @@ void VersionConverter::eraseOriginalTypeInformation(Protobuf::Message& message) } }; TypeErasingProtoVisitor proto_visitor; - traverseMutableMessage(proto_visitor, message, nullptr); + ProtobufMessage::traverseMutableMessage(proto_visitor, message, nullptr); } DynamicMessagePtr VersionConverter::recoverOriginal(const Protobuf::Message& upgraded_message) { @@ -224,7 +185,7 @@ void VersionConverter::prepareMessageForGrpcWire(Protobuf::Message& message, } void VersionUtil::scrubHiddenEnvoyDeprecated(Protobuf::Message& message) { - class HiddenFieldScrubbingProtoVisitor : public ProtoVisitor { + class HiddenFieldScrubbingProtoVisitor : public ProtobufMessage::ProtoVisitor { public: const void* onField(Protobuf::Message& message, const Protobuf::FieldDescriptor& field, const void*) override { @@ -236,7 +197,7 @@ void VersionUtil::scrubHiddenEnvoyDeprecated(Protobuf::Message& message) { } }; HiddenFieldScrubbingProtoVisitor proto_visitor; - traverseMutableMessage(proto_visitor, message, nullptr); + ProtobufMessage::traverseMutableMessage(proto_visitor, message, nullptr); } } // namespace Config diff --git a/source/common/protobuf/BUILD b/source/common/protobuf/BUILD index 34bd1f7e18d4..9a9aa1f30624 100644 --- a/source/common/protobuf/BUILD +++ b/source/common/protobuf/BUILD @@ -68,12 +68,20 @@ envoy_cc_library( "//source/common/common:utility_lib", "//source/common/config:api_type_oracle_lib", "//source/common/config:version_converter_lib", + "//source/common/protobuf:visitor_lib", "@com_github_cncf_udpa//udpa/annotations:pkg_cc_proto", "@envoy_api//envoy/annotations:pkg_cc_proto", "@envoy_api//envoy/type/v3:pkg_cc_proto", ], ) +envoy_cc_library( + name = "visitor_lib", + srcs = ["visitor.cc"], + hdrs = ["visitor.h"], + deps = [":protobuf"], +) + envoy_cc_library( name = "well_known_lib", hdrs = ["well_known.h"], diff --git a/source/common/protobuf/utility.cc b/source/common/protobuf/utility.cc index 6ed24e465434..ba73c7e6600a 100644 --- a/source/common/protobuf/utility.cc +++ b/source/common/protobuf/utility.cc @@ -13,6 +13,7 @@ #include "common/config/version_converter.h" #include "common/protobuf/message_validator_impl.h" #include "common/protobuf/protobuf.h" +#include "common/protobuf/visitor.h" #include "common/protobuf/well_known.h" #include "absl/strings/match.h" @@ -375,42 +376,25 @@ void checkForDeprecatedNonRepeatedEnumValue(const Protobuf::Message& message, } } -void checkForUnexpectedFields(const Protobuf::Message& message, - ProtobufMessage::ValidationVisitor& validation_visitor, - Runtime::Loader* runtime) { - // Reject unknown fields. - const auto& unknown_fields = message.GetReflection()->GetUnknownFields(message); - if (!unknown_fields.empty()) { - std::string error_msg; - for (int n = 0; n < unknown_fields.field_count(); ++n) { - if (unknown_fields.field(n).number() == ProtobufWellKnown::OriginalTypeFieldNumber) { - continue; - } - error_msg += absl::StrCat(n > 0 ? ", " : "", unknown_fields.field(n).number()); - } - // We use the validation visitor but have hard coded behavior below for deprecated fields. - // TODO(htuch): Unify the deprecated and unknown visitor handling behind the validation - // visitor pattern. https://github.com/envoyproxy/envoy/issues/8092. - if (!error_msg.empty()) { - validation_visitor.onUnknownField("type " + message.GetTypeName() + - " with unknown field set {" + error_msg + "}"); - } - } +class UnexpectedFieldProtoVisitor : public ProtobufMessage::ConstProtoVisitor { +public: + UnexpectedFieldProtoVisitor(ProtobufMessage::ValidationVisitor& validation_visitor, + Runtime::Loader* runtime) + : validation_visitor_(validation_visitor), runtime_(runtime) {} - const Protobuf::Descriptor* descriptor = message.GetDescriptor(); - const Protobuf::Reflection* reflection = message.GetReflection(); - for (int i = 0; i < descriptor->field_count(); ++i) { - const Protobuf::FieldDescriptor* field = descriptor->field(i); - absl::string_view filename = filenameFromPath(field->file()->name()); + const void* onField(const Protobuf::Message& message, const Protobuf::FieldDescriptor& field, + const void*) override { + const Protobuf::Reflection* reflection = message.GetReflection(); + absl::string_view filename = filenameFromPath(field.file()->name()); // Before we check to see if the field is in use, see if there's a // deprecated default enum value. - checkForDeprecatedNonRepeatedEnumValue(message, filename, field, reflection, runtime); + checkForDeprecatedNonRepeatedEnumValue(message, filename, &field, reflection, runtime_); // If this field is not in use, continue. - if ((field->is_repeated() && reflection->FieldSize(message, field) == 0) || - (!field->is_repeated() && !reflection->HasField(message, field))) { - continue; + if ((field.is_repeated() && reflection->FieldSize(message, &field) == 0) || + (!field.is_repeated() && !reflection->HasField(message, &field))) { + return nullptr; } #ifdef ENVOY_DISABLE_DEPRECATED_FEATURES @@ -421,22 +405,22 @@ void checkForUnexpectedFields(const Protobuf::Message& message, // Allow runtime to be null both to not crash if this is called before server initialization, // and so proto validation works in context where runtime singleton is not set up (e.g. // standalone config validation utilities) - if (runtime && field->options().deprecated()) { + if (runtime_ && field.options().deprecated()) { // This is set here, rather than above, so that in the absence of a // registry (i.e. test) the default for if a feature is allowed or not is // based on ENVOY_DISABLE_DEPRECATED_FEATURES. - warn_only &= !field->options().GetExtension(envoy::annotations::disallowed_by_default); - warn_only = runtime->snapshot().deprecatedFeatureEnabled( - absl::StrCat("envoy.deprecated_features:", field->full_name()), warn_only); + warn_only &= !field.options().GetExtension(envoy::annotations::disallowed_by_default); + warn_only = runtime_->snapshot().deprecatedFeatureEnabled( + absl::StrCat("envoy.deprecated_features:", field.full_name()), warn_only); } // If this field is deprecated, warn or throw an error. - if (field->options().deprecated()) { + if (field.options().deprecated()) { std::string err = fmt::format( "Using deprecated option '{}' from file {}. This configuration will be removed from " "Envoy soon. Please see https://www.envoyproxy.io/docs/envoy/latest/intro/deprecated " "for details.", - field->full_name(), filename); + field.full_name(), filename); if (warn_only) { ENVOY_LOG_MISC(warn, "{}", err); } else { @@ -448,29 +432,43 @@ void checkForUnexpectedFields(const Protobuf::Message& message, throw ProtoValidationException(err + fatal_error, message); } } + return nullptr; + } - // If this is a message, recurse to check for deprecated fields in the sub-message. - if (field->cpp_type() == Protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { - if (field->is_repeated()) { - const int size = reflection->FieldSize(message, field); - for (int j = 0; j < size; ++j) { - checkForUnexpectedFields(reflection->GetRepeatedMessage(message, field, j), - validation_visitor, runtime); + void onMessage(const Protobuf::Message& message, const void*) override { + // Reject unknown fields. + const auto& unknown_fields = message.GetReflection()->GetUnknownFields(message); + if (!unknown_fields.empty()) { + std::string error_msg; + for (int n = 0; n < unknown_fields.field_count(); ++n) { + if (unknown_fields.field(n).number() == ProtobufWellKnown::OriginalTypeFieldNumber) { + continue; } - } else { - checkForUnexpectedFields(reflection->GetMessage(message, field), validation_visitor, - runtime); + error_msg += absl::StrCat(n > 0 ? ", " : "", unknown_fields.field(n).number()); + } + // We use the validation visitor but have hard coded behavior below for deprecated fields. + // TODO(htuch): Unify the deprecated and unknown visitor handling behind the validation + // visitor pattern. https://github.com/envoyproxy/envoy/issues/8092. + if (!error_msg.empty()) { + validation_visitor_.onUnknownField("type " + message.GetTypeName() + + " with unknown field set {" + error_msg + "}"); } } } -} + +private: + ProtobufMessage::ValidationVisitor& validation_visitor_; + Runtime::Loader* runtime_; +}; } // namespace void MessageUtil::checkForUnexpectedFields(const Protobuf::Message& message, ProtobufMessage::ValidationVisitor& validation_visitor, Runtime::Loader* runtime) { - ::Envoy::checkForUnexpectedFields(API_RECOVER_ORIGINAL(message), validation_visitor, runtime); + UnexpectedFieldProtoVisitor unexpected_field_visitor(validation_visitor, runtime); + ProtobufMessage::traverseMessage(unexpected_field_visitor, API_RECOVER_ORIGINAL(message), + nullptr); } std::string MessageUtil::getYamlStringFromMessage(const Protobuf::Message& message, diff --git a/source/common/protobuf/visitor.cc b/source/common/protobuf/visitor.cc new file mode 100644 index 000000000000..b7a453c406bb --- /dev/null +++ b/source/common/protobuf/visitor.cc @@ -0,0 +1,50 @@ +#include "common/protobuf/visitor.h" + +namespace Envoy { +namespace ProtobufMessage { + +void traverseMutableMessage(ProtoVisitor& visitor, Protobuf::Message& message, const void* ctxt) { + visitor.onMessage(message, ctxt); + const Protobuf::Descriptor* descriptor = message.GetDescriptor(); + const Protobuf::Reflection* reflection = message.GetReflection(); + for (int i = 0; i < descriptor->field_count(); ++i) { + const Protobuf::FieldDescriptor* field = descriptor->field(i); + const void* field_ctxt = visitor.onField(message, *field, ctxt); + // If this is a message, recurse to scrub deprecated fields in the sub-message. + if (field->cpp_type() == Protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + if (field->is_repeated()) { + const int size = reflection->FieldSize(message, field); + for (int j = 0; j < size; ++j) { + traverseMutableMessage(visitor, *reflection->MutableRepeatedMessage(&message, field, j), + field_ctxt); + } + } else if (reflection->HasField(message, field)) { + traverseMutableMessage(visitor, *reflection->MutableMessage(&message, field), field_ctxt); + } + } + } +} +void traverseMessage(ConstProtoVisitor& visitor, const Protobuf::Message& message, + const void* ctxt) { + visitor.onMessage(message, ctxt); + const Protobuf::Descriptor* descriptor = message.GetDescriptor(); + const Protobuf::Reflection* reflection = message.GetReflection(); + for (int i = 0; i < descriptor->field_count(); ++i) { + const Protobuf::FieldDescriptor* field = descriptor->field(i); + const void* field_ctxt = visitor.onField(message, *field, ctxt); + // If this is a message, recurse to scrub deprecated fields in the sub-message. + if (field->cpp_type() == Protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + if (field->is_repeated()) { + const int size = reflection->FieldSize(message, field); + for (int j = 0; j < size; ++j) { + traverseMessage(visitor, reflection->GetRepeatedMessage(message, field, j), field_ctxt); + } + } else if (reflection->HasField(message, field)) { + traverseMessage(visitor, reflection->GetMessage(message, field), field_ctxt); + } + } + } +} + +} // namespace ProtobufMessage +} // namespace Envoy diff --git a/source/common/protobuf/visitor.h b/source/common/protobuf/visitor.h new file mode 100644 index 000000000000..93ec07af93b7 --- /dev/null +++ b/source/common/protobuf/visitor.h @@ -0,0 +1,43 @@ +#pragma once + +#include "common/protobuf/protobuf.h" + +namespace Envoy { +namespace ProtobufMessage { + +class ProtoVisitor { +public: + virtual ~ProtoVisitor() = default; + + // Invoked when a field is visited, with the message, field descriptor and context. Returns a new + // context for use when traversing the sub-message in a field. + virtual const void* onField(Protobuf::Message&, const Protobuf::FieldDescriptor&, + const void* ctxt) { + return ctxt; + } + + // Invoked when a message is visited, with the message and a context. + virtual void onMessage(Protobuf::Message&, const void*){}; +}; + +class ConstProtoVisitor { +public: + virtual ~ConstProtoVisitor() = default; + + // Invoked when a field is visited, with the message, field descriptor and context. Returns a new + // context for use when traversing the sub-message in a field. + virtual const void* onField(const Protobuf::Message&, const Protobuf::FieldDescriptor&, + const void* ctxt) { + return ctxt; + } + + // Invoked when a message is visited, with the message and a context. + virtual void onMessage(const Protobuf::Message&, const void*){}; +}; + +void traverseMutableMessage(ProtoVisitor& visitor, Protobuf::Message& message, const void* ctxt); +void traverseMessage(ConstProtoVisitor& visitor, const Protobuf::Message& message, + const void* ctxt); + +} // namespace ProtobufMessage +} // namespace Envoy