diff --git a/.mapping.json b/.mapping.json index 95f3b9643baa..a336d61e8041 100644 --- a/.mapping.json +++ b/.mapping.json @@ -1865,6 +1865,7 @@ "grpc/include/userver/ugrpc/impl/deadline_timepoint.hpp":"taxi/uservices/userver/grpc/include/userver/ugrpc/impl/deadline_timepoint.hpp", "grpc/include/userver/ugrpc/impl/internal_tag_fwd.hpp":"taxi/uservices/userver/grpc/include/userver/ugrpc/impl/internal_tag_fwd.hpp", "grpc/include/userver/ugrpc/impl/maybe_owned_string.hpp":"taxi/uservices/userver/grpc/include/userver/ugrpc/impl/maybe_owned_string.hpp", + "grpc/include/userver/ugrpc/impl/protobuf_collector.hpp":"taxi/uservices/userver/grpc/include/userver/ugrpc/impl/protobuf_collector.hpp", "grpc/include/userver/ugrpc/impl/queue_runner.hpp":"taxi/uservices/userver/grpc/include/userver/ugrpc/impl/queue_runner.hpp", "grpc/include/userver/ugrpc/impl/span.hpp":"taxi/uservices/userver/grpc/include/userver/ugrpc/impl/span.hpp", "grpc/include/userver/ugrpc/impl/static_metadata.hpp":"taxi/uservices/userver/grpc/include/userver/ugrpc/impl/static_metadata.hpp", @@ -1958,6 +1959,7 @@ "grpc/src/ugrpc/impl/internal_tag.hpp":"taxi/uservices/userver/grpc/src/ugrpc/impl/internal_tag.hpp", "grpc/src/ugrpc/impl/logging.cpp":"taxi/uservices/userver/grpc/src/ugrpc/impl/logging.cpp", "grpc/src/ugrpc/impl/logging.hpp":"taxi/uservices/userver/grpc/src/ugrpc/impl/logging.hpp", + "grpc/src/ugrpc/impl/protobuf_collector.cpp":"taxi/uservices/userver/grpc/src/ugrpc/impl/protobuf_collector.cpp", "grpc/src/ugrpc/impl/protobuf_utils.cpp":"taxi/uservices/userver/grpc/src/ugrpc/impl/protobuf_utils.cpp", "grpc/src/ugrpc/impl/protobuf_utils.hpp":"taxi/uservices/userver/grpc/src/ugrpc/impl/protobuf_utils.hpp", "grpc/src/ugrpc/impl/queue_runner.cpp":"taxi/uservices/userver/grpc/src/ugrpc/impl/queue_runner.cpp", @@ -2030,6 +2032,7 @@ "grpc/tests/generic_server_test.cpp":"taxi/uservices/userver/grpc/tests/generic_server_test.cpp", "grpc/tests/logging_test.cpp":"taxi/uservices/userver/grpc/tests/logging_test.cpp", "grpc/tests/middlewares_test.cpp":"taxi/uservices/userver/grpc/tests/middlewares_test.cpp", + "grpc/tests/protobuf_collector_test.cpp":"taxi/uservices/userver/grpc/tests/protobuf_collector_test.cpp", "grpc/tests/protobuf_visit_test.cpp":"taxi/uservices/userver/grpc/tests/protobuf_visit_test.cpp", "grpc/tests/secret_fields_test.cpp":"taxi/uservices/userver/grpc/tests/secret_fields_test.cpp", "grpc/tests/serialization_test.cpp":"taxi/uservices/userver/grpc/tests/serialization_test.cpp", diff --git a/grpc/include/userver/ugrpc/impl/protobuf_collector.hpp b/grpc/include/userver/ugrpc/impl/protobuf_collector.hpp new file mode 100644 index 000000000000..c6780052e3cf --- /dev/null +++ b/grpc/include/userver/ugrpc/impl/protobuf_collector.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include +#include + +#include + +USERVER_NAMESPACE_BEGIN + +namespace ugrpc::impl { + +/// @brief Registers multiple message types during static initialization time +void RegisterMessageTypes(std::initializer_list type_names); + +/// @brief Find all known messages +/// +/// @warning This is probably not an exhaustive list! +DescriptorList GetGeneratedMessages(); + +} // namespace ugrpc::impl + +USERVER_NAMESPACE_END diff --git a/grpc/include/userver/ugrpc/protobuf_visit.hpp b/grpc/include/userver/ugrpc/protobuf_visit.hpp index 0c5ae08d7224..122141b05b4c 100644 --- a/grpc/include/userver/ugrpc/protobuf_visit.hpp +++ b/grpc/include/userver/ugrpc/protobuf_visit.hpp @@ -3,11 +3,22 @@ /// @file userver/ugrpc/protobuf_visit.hpp /// @brief Utilities for visiting the fields of protobufs +#include +#include +#include +#include +#include + +#include + #include +#include +#include namespace google::protobuf { class Message; +class Descriptor; class FieldDescriptor; } // namespace google::protobuf @@ -16,22 +27,293 @@ USERVER_NAMESPACE_BEGIN namespace ugrpc { -using MessageVisitor = utils::function_ref; +using MessageVisitCallback = + utils::function_ref; -using FieldVisitor = utils::function_ref; /// @brief Execute a callback for all non-empty fields of the message. -void VisitFields(google::protobuf::Message& message, FieldVisitor callback); +void VisitFields(google::protobuf::Message& message, + FieldVisitCallback callback); /// @brief Execute a callback for the message and its non-empty submessages. void VisitMessagesRecursive(google::protobuf::Message& message, - MessageVisitor callback); + MessageVisitCallback callback); /// @brief Execute a callback for all fields /// of the message and its non-empty submessages. void VisitFieldsRecursive(google::protobuf::Message& message, - FieldVisitor callback); + FieldVisitCallback callback); + +using DescriptorList = std::vector; + +using FieldDescriptorList = + std::vector; + +/// @brief Get the descriptors of fields in the message. +FieldDescriptorList GetFieldDescriptors( + const google::protobuf::Descriptor& descriptor); + +/// @brief Get the descriptors of current and nested messages. +DescriptorList GetNestedMessageDescriptors( + const google::protobuf::Descriptor& descriptor); + +/// @brief Find a generated type by name. +const google::protobuf::Descriptor* FindGeneratedMessage(std::string_view name); + +/// @brief Find the field of a generated type by name. +const google::protobuf::FieldDescriptor* FindField( + const google::protobuf::Descriptor* descriptor, std::string_view field); + +/// @brief Base class for @ref FieldsVisitor and @ref MessagesVisitor. +/// Provides the interface and manages the descriptor graph +/// to enable the visitors to find all selected structures. +template +class BaseVisitor { + public: + enum class LockBehavior { + /// @brief Do not take shared_mutex locks for any operation on the visitor + kNone = 0, + + /// @brief Take shared_lock for all read operations on the visitor + /// and unique_lock for all Compile operations + kShared = 1 + }; + + BaseVisitor(BaseVisitor&&) = delete; + BaseVisitor(const BaseVisitor&) = delete; + + /// @brief Compiles the visitor for the given message type + /// and its dependent types + void Compile(const google::protobuf::Descriptor* descriptor); + + /// @brief Compiles the visitor for the given message types + /// and their dependent types + void Compile(const DescriptorList& descriptors); + + /// @brief Compiles the visitor for the given + /// generated message type and its dependent types + void CompileGenerated(std::string_view message_name) { + Compile(FindGeneratedMessage(message_name)); + } + + /// @brief Compiles the visitor for the given + /// generated message type and their dependent types + void CompileGenerated(utils::span message_names) { + DescriptorList descriptors; + for (const std::string_view& message_name : message_names) { + descriptors.push_back(FindGeneratedMessage(message_name)); + } + Compile(descriptors); + } + + /// @brief Execute a callback without recursion + /// + /// Equivalent to @ref VisitFields + /// but utilizes the precompilation data from @ref Compile + void Visit(google::protobuf::Message& message, Callback callback); + + /// @brief Execute a callback recursively + /// + /// Equivalent to @ref VisitFieldsRecursive and @ref VisitMessagesRecursive + /// but utilizes the precompilation data from @ref Compile + void VisitRecursive(google::protobuf::Message& message, Callback callback); + + /// @cond + /// Only for internal use. + using Dependencies = std::unordered_map< + const google::protobuf::Descriptor*, + std::unordered_set>; + + /// Only for internal use. + using DescriptorSet = std::unordered_set; + + /// Only for internal use. + using FieldDescriptorSet = + std::unordered_set; + + /// Only for internal use. + const Dependencies& GetFieldsWithSelectedChildren( + utils::impl::InternalTag) const { + return fields_with_selected_children_; + } + + /// Only for internal use. + const Dependencies& GetReverseEdges(utils::impl::InternalTag) const { + return reverse_edges_; + } + + /// Only for internal use. + const DescriptorSet& GetPropagated(utils::impl::InternalTag) const { + return propagated_; + } + + /// Only for internal use. + const DescriptorSet& GetCompiled(utils::impl::InternalTag) const { + return compiled_; + } + /// @endcond + + protected: + /// @cond + explicit BaseVisitor(LockBehavior lock_behavior) + : lock_behavior_(lock_behavior) {} + + // Disallow destruction via pointer to base + ~BaseVisitor() = default; + + /// @brief Compile one message without nested. + virtual void CompileOne(const google::protobuf::Descriptor& descriptor) = 0; + + /// @brief Checks if the message is selected or has anything selected. + virtual bool IsSelected(const google::protobuf::Descriptor&) const = 0; + + /// @brief Execute a callback without recursion + virtual void DoVisit(google::protobuf::Message& message, + Callback callback) = 0; + /// @endcond + + private: + /// @brief Gets all submessages of the given messages. + DescriptorSet GetFullSubtrees(const DescriptorList& descriptors) const; + + /// @brief Propagate the selection information upwards + void PropagateSelected(const google::protobuf::Descriptor* descriptor); + + /// @brief Safe version with recursion_limit + void VisitRecursiveImpl(google::protobuf::Message& message, Callback callback, + int recursion_limit); + + std::shared_mutex mutex_; + const LockBehavior lock_behavior_; + + Dependencies fields_with_selected_children_; + Dependencies reverse_edges_; + DescriptorSet propagated_; + DescriptorSet compiled_; +}; + +/// @brief Collects knowledge of the structure of the protobuf messages +/// allowing for efficient loops over fields to apply a callback to the ones +/// selected by the 'selector' function. +/// +/// If you do not have static knowledge of the required fields, you should +/// use @ref VisitFields or @ref VisitFieldsRecursive that are equivalent to +/// FieldsVisitor with a `return true;` selector. +/// +/// @warning You should not construct this at runtime as it performs significant +/// computations in the constructor to precompile the visitors. +/// You should create this ones at start-up. +/// +/// Example usage: @snippet grpc/src/ugrpc/impl/protobuf_utils.cpp +class FieldsVisitor final : public BaseVisitor { + public: + using Selector = + utils::function_ref; + + /// @brief Creates the visitor with the given selector + /// and compiles it for the message types we can find. + explicit FieldsVisitor(Selector selector); + + /// @brief Creates the visitor with the given selector + /// and compiles it for the given message types and their fields recursively. + FieldsVisitor(Selector selector, const DescriptorList& descriptors); + + /// @brief Creates the visitor with custom thread locking behavior + /// and the given selector for runtime compilation. + /// + /// @warning Do not use this unless you know what you are doing. + FieldsVisitor(Selector selector, LockBehavior lock_behavior); + + /// @brief Creates the visitor with custom thread locking behavior + /// and the given selector; compiles it for the given message types. + /// + /// @warning Do not use this unless you know what you are doing. + FieldsVisitor(Selector selector, const DescriptorList& descriptors, + LockBehavior lock_behavior); + + /// @cond + /// Only for internal use. + const Dependencies& GetSelectedFields(utils::impl::InternalTag) const { + return selected_fields_; + } + /// @endcond + + private: + void CompileOne(const google::protobuf::Descriptor& descriptor) override; + + bool IsSelected( + const google::protobuf::Descriptor& descriptor) const override { + return selected_fields_.find(&descriptor) != selected_fields_.end(); + } + + void DoVisit(google::protobuf::Message& message, + FieldVisitCallback callback) override; + + Dependencies selected_fields_; + const Selector selector_; +}; + +/// @brief Collects knowledge of the structure of the protobuf messages +/// allowing for efficient loops over nested messages to apply a callback +/// to the ones selected by the 'selector' function. +/// +/// If you do not have static knowledge of the required messages, you should +/// use @ref VisitMessagesRecursive that is equivalent to +/// MessagesVisitor with a 'return true' selector. +/// +/// @warning You should not construct this at runtime as it performs significant +/// computations in the constructor to precompile the visitors. +/// You should create this ones at start-up. +class MessagesVisitor final : public BaseVisitor { + public: + using Selector = + utils::function_ref; + + /// @brief Creates the visitor with the given selector for runtime compilation + /// and compiles it for the message types we can find. + explicit MessagesVisitor(Selector selector); + + /// @brief Creates the visitor with the given selector + /// and compiles it for the given message types and their fields recursively. + MessagesVisitor(Selector selector, const DescriptorList& descriptors); + + /// @brief Creates the visitor with custom thread locking behavior + /// and the given selector for runtime compilation. + /// + /// @warning Do not use this unless you know what you are doing. + MessagesVisitor(Selector selector, LockBehavior lock_behavior); + + /// @brief Creates the visitor with custom thread locking behavior + /// and the given selector; compiles it for the given message types. + /// + /// @warning Do not use this unless you know what you are doing. + MessagesVisitor(Selector selector, const DescriptorList& descriptors, + LockBehavior lock_behavior); + + /// @cond + /// Only for internal use. + const DescriptorSet& GetSelectedMessages(utils::impl::InternalTag) const { + return selected_messages_; + } + /// @endcond + + private: + void CompileOne(const google::protobuf::Descriptor& descriptor) override; + + bool IsSelected( + const google::protobuf::Descriptor& descriptor) const override { + return selected_messages_.find(&descriptor) != selected_messages_.end(); + } + + void DoVisit(google::protobuf::Message& message, + MessageVisitCallback callback) override; + + DescriptorSet selected_messages_; + const Selector selector_; +}; } // namespace ugrpc diff --git a/grpc/proto/tests/protobuf.proto b/grpc/proto/tests/protobuf.proto index 1defbd89a88c..353f28eb2e07 100644 --- a/grpc/proto/tests/protobuf.proto +++ b/grpc/proto/tests/protobuf.proto @@ -3,26 +3,45 @@ syntax = "proto3"; package sample.ugrpc; import "google/protobuf/struct.proto"; +import "google/protobuf/descriptor.proto"; + +message FieldOptions { + bool selected = 1; +} + +message MessageOptions { + bool selected = 1; +} + +extend google.protobuf.FieldOptions { + FieldOptions field = 35784; +} + +extend google.protobuf.MessageOptions { + MessageOptions message = 35785; +} // A message with fields of many different types message MessageWithDifferentTypes { // Nested message message NestedMessage { - string required_string = 1; + option (sample.ugrpc.message).selected = true; + + string required_string = 1 [(sample.ugrpc.field).selected = true]; optional string optional_string = 2; - uint32 required_int = 3; + uint32 required_int = 3 [(sample.ugrpc.field).selected = true]; optional uint32 optional_int = 4; } // Strings string required_string = 1; - optional string optional_string = 2; + optional string optional_string = 2 [(sample.ugrpc.field).selected = true]; // Integers uint32 required_int = 3; - optional uint32 optional_int = 4; + optional uint32 optional_int = 4 [(sample.ugrpc.field).selected = true]; // Nested messages NestedMessage required_nested = 5; @@ -36,7 +55,7 @@ message MessageWithDifferentTypes { repeated string repeated_primitive = 9; // Repeated message - repeated NestedMessage repeated_message = 10; + repeated NestedMessage repeated_message = 10 [(sample.ugrpc.field).selected = true]; // Map of primitives map primitives_map = 11; @@ -53,4 +72,102 @@ message MessageWithDifferentTypes { // Google type google.protobuf.Value google_value = 16; + + // Weird map key (not an integer or string) + map weird_map = 17; +} + +/* +Component 1: + -> E + | +A (+) -> B <-----> C (+) + | | + -> D (+) <- +*/ + +message Msg1A { + option (sample.ugrpc.message).selected = true; + + string value1 = 1 [(sample.ugrpc.field).selected = true]; + string value2 = 2 [(sample.ugrpc.field).selected = true]; + Msg1B nested = 3; +}; + +message Msg1B { + Msg1C recursive_1 = 1; + Msg1C recursive_2 = 2; + Msg1D nested_secret_1 = 3; + Msg1D nested_secret_2 = 4; + Msg1D nested_secret_3 = 5; + Msg1E nested_nosecret_1 = 6; + Msg1E nested_nosecret_2 = 7; + Msg1E nested_nosecret_3 = 8; +}; + +message Msg1C { + option (sample.ugrpc.message).selected = true; + + Msg1B recursive_1 = 1; + Msg1B recursive_2 = 2; + string value = 3 [(sample.ugrpc.field).selected = true]; + Msg1D nested = 4; +}; + +message Msg1D { + option (sample.ugrpc.message).selected = true; + + string value = 1 [(sample.ugrpc.field).selected = true]; +}; + +message Msg1E { + string value = 1; +}; + +/* +Component 2: +A (+) +*/ + +message Msg2A { + option (sample.ugrpc.message).selected = true; + + string value1 = 1 [(sample.ugrpc.field).selected = true]; + uint32 value2 = 2; + bool value3 = 3 [(sample.ugrpc.field).selected = true]; +} + +/* +Component 3: +A -> B +*/ + +message Msg3A { + string value = 1; + Msg3B nested = 2; +} + +message Msg3B { + string value = 1; +} + +/* +Component 4: +A -> B (+) -> C -> A -> ... +*/ + +message Msg4A { + Msg4B nested = 1; +} + +message Msg4B { + option (sample.ugrpc.message).selected = true; + + string value = 1 [(sample.ugrpc.field).selected = true]; + Msg4C nested = 2; +} + +message Msg4C { + Msg4A nested_1 = 1; + Msg4A nested_2 = 2; } diff --git a/grpc/proto/tests/unit_test.proto b/grpc/proto/tests/unit_test.proto index 5ea7fce838fa..c526addcc053 100644 --- a/grpc/proto/tests/unit_test.proto +++ b/grpc/proto/tests/unit_test.proto @@ -9,6 +9,9 @@ syntax = "proto3"; package sample.ugrpc; import "tests/messages.proto"; +import "tests/protobuf.proto"; + +import "google/protobuf/empty.proto"; service UnitTestService { // Simple RPC @@ -19,4 +22,7 @@ service UnitTestService { rpc WriteMany(stream StreamGreetingRequest) returns(StreamGreetingResponse) {} // Bidirectional streaming RPC rpc Chat(stream StreamGreetingRequest) returns(stream StreamGreetingResponse) {} + + // Simple RPC with less simple response + rpc GetData(google.protobuf.Empty) returns(MessageWithDifferentTypes) {} } diff --git a/grpc/src/ugrpc/impl/protobuf_collector.cpp b/grpc/src/ugrpc/impl/protobuf_collector.cpp new file mode 100644 index 000000000000..c952851c7c4b --- /dev/null +++ b/grpc/src/ugrpc/impl/protobuf_collector.cpp @@ -0,0 +1,45 @@ +#include + +#include + +#include + +#include +#include +#include + +USERVER_NAMESPACE_BEGIN + +namespace ugrpc::impl { + +namespace { + +std::unordered_set& GetGeneratedMessagesImpl() { + static std::unordered_set messages; + return messages; +} + +} // namespace + +void RegisterMessageTypes(std::initializer_list type_names) { + utils::impl::AssertStaticRegistrationAllowed( + "Calling ugrpc::impl::RegisterMessageTypes()"); + GetGeneratedMessagesImpl().merge(std::unordered_set(type_names)); +} + +DescriptorList GetGeneratedMessages() { + utils::impl::AssertStaticRegistrationFinished(); + + DescriptorList result; + for (const std::string& service_name : GetGeneratedMessagesImpl()) { + const google::protobuf::Descriptor* descriptor = + ugrpc::FindGeneratedMessage(service_name); + UINVARIANT(descriptor, "descriptor is nullptr"); + result.push_back(descriptor); + } + return result; +} + +} // namespace ugrpc::impl + +USERVER_NAMESPACE_END diff --git a/grpc/src/ugrpc/impl/protobuf_utils.cpp b/grpc/src/ugrpc/impl/protobuf_utils.cpp index fb9de515fa37..38adfabbb1ed 100644 --- a/grpc/src/ugrpc/impl/protobuf_utils.cpp +++ b/grpc/src/ugrpc/impl/protobuf_utils.cpp @@ -7,20 +7,20 @@ USERVER_NAMESPACE_BEGIN namespace ugrpc::impl { -bool IsFieldSecret(const google::protobuf::FieldDescriptor* field) { - return GetFieldOptions(field).secret(); -} - void TrimSecrets(google::protobuf::Message& message) { - ugrpc::VisitFieldsRecursive( + static ugrpc::FieldsVisitor kSecretVisitor( + [](const google::protobuf::Descriptor&, + const google::protobuf::FieldDescriptor& field) { + return GetFieldOptions(field).secret(); + }); + + kSecretVisitor.VisitRecursive( message, [](google::protobuf::Message& message, const google::protobuf::FieldDescriptor& field) { - if (IsFieldSecret(&field)) { - const google::protobuf::Reflection* reflection = - message.GetReflection(); - UINVARIANT(reflection, "reflection is nullptr"); - reflection->ClearField(&message, &field); - } + const google::protobuf::Reflection* reflection = + message.GetReflection(); + UINVARIANT(reflection, "reflection is nullptr"); + reflection->ClearField(&message, &field); }); } diff --git a/grpc/src/ugrpc/impl/protobuf_utils.hpp b/grpc/src/ugrpc/impl/protobuf_utils.hpp index 6a118b042c28..bc37f1e7e275 100644 --- a/grpc/src/ugrpc/impl/protobuf_utils.hpp +++ b/grpc/src/ugrpc/impl/protobuf_utils.hpp @@ -6,8 +6,8 @@ USERVER_NAMESPACE_BEGIN namespace ugrpc::impl { -inline auto GetFieldOptions(const google::protobuf::FieldDescriptor* field) { - return field->options().GetExtension(userver::field); +inline auto GetFieldOptions(const google::protobuf::FieldDescriptor& field) { + return field.options().GetExtension(userver::field); } void TrimSecrets(google::protobuf::Message& message); diff --git a/grpc/src/ugrpc/protobuf_visit.cpp b/grpc/src/ugrpc/protobuf_visit.cpp index a1f8c9194f73..f83dd24d114d 100644 --- a/grpc/src/ugrpc/protobuf_visit.cpp +++ b/grpc/src/ugrpc/protobuf_visit.cpp @@ -1,9 +1,15 @@ #include +#include +#include + #include +#include #include +#include #include +#include #include USERVER_NAMESPACE_BEGIN @@ -14,8 +20,40 @@ namespace { constexpr int kMaxRecursionLimit = 100; +bool IsMessage(const google::protobuf::FieldDescriptor& field) { + return field.type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE || + field.type() == google::protobuf::FieldDescriptor::TYPE_GROUP; +} + +void CallNestedMessage(google::protobuf::Message& message, + const google::protobuf::FieldDescriptor& field, + MessageVisitCallback callback) { + UINVARIANT(IsMessage(field), "Not a nested message"); + + // Get reflection + const google::protobuf::Reflection* reflection = message.GetReflection(); + UINVARIANT(reflection, "reflection is nullptr"); + + if (field.is_repeated()) { + // Repeated types (including maps) + const int repeated_size = reflection->FieldSize(message, &field); + for (int i = 0; i < repeated_size; ++i) { + google::protobuf::Message* msg = + reflection->MutableRepeatedMessage(&message, &field, i); + UINVARIANT(msg, "msg is nullptr"); + callback(*msg); + } + } else if (reflection->HasField(message, &field)) { + // Primitive types + google::protobuf::Message* msg = + reflection->MutableMessage(&message, &field); + UINVARIANT(msg, "msg is nullptr"); + callback(*msg); + } +} + void VisitMessagesRecursiveImpl(google::protobuf::Message& message, - MessageVisitor callback, + MessageVisitCallback callback, const int recursion_limit) { UINVARIANT(recursion_limit > 0, "Recursion limit reached while traversing protobuf Message."); @@ -30,72 +68,67 @@ void VisitMessagesRecursiveImpl(google::protobuf::Message& message, google::protobuf::Message& message, const google::protobuf::FieldDescriptor& field) -> void { // Not a nested message - if (field.type() != google::protobuf::FieldDescriptor::TYPE_MESSAGE && - field.type() != google::protobuf::FieldDescriptor::TYPE_GROUP) { - return; - } - - // Get reflection - const google::protobuf::Reflection* reflection = - message.GetReflection(); - UINVARIANT(reflection, "reflection is nullptr"); - - if (field.is_repeated()) { - // Repeated types (including maps) - const int repeated_size = reflection->FieldSize(message, &field); - for (int i = 0; i < repeated_size; ++i) { - google::protobuf::Message* msg = - reflection->MutableRepeatedMessage(&message, &field, i); - UINVARIANT(msg, "msg is nullptr"); - VisitMessagesRecursiveImpl(*msg, callback, recursion_limit - 1); - } - } else if (reflection->HasField(message, &field)) { - // Primitive types - google::protobuf::Message* msg = - reflection->MutableMessage(&message, &field); - UINVARIANT(msg, "msg is nullptr"); - VisitMessagesRecursiveImpl(*msg, callback, recursion_limit - 1); - } + if (!IsMessage(field)) return; + + CallNestedMessage(message, field, [&](google::protobuf::Message& msg) { + VisitMessagesRecursiveImpl(msg, callback, recursion_limit - 1); + }); }); } -} // namespace +void GetNestedMessageDescriptorsImpl( + const google::protobuf::Descriptor& descriptor, + std::unordered_set& result) { + if (!result.insert(&descriptor).second) { + // It is already there. Stop to avoid infinite recursion. + return; + } -void VisitFields(google::protobuf::Message& message, FieldVisitor callback) { - // Get descriptor - const google::protobuf::Descriptor* descriptor = message.GetDescriptor(); - UINVARIANT(descriptor, "descriptor is nullptr"); + // Check the fields + for (const google::protobuf::FieldDescriptor* field : + GetFieldDescriptors(descriptor)) { + UINVARIANT(field, "field is nullptr"); + // Not a nested message + if (!IsMessage(*field)) continue; + + const google::protobuf::Descriptor* msg = field->message_type(); + UINVARIANT(msg, "msg is nullptr"); + GetNestedMessageDescriptorsImpl(*msg, result); + } +} + +std::unordered_set +GetNestedMessageDescriptorsSet(const google::protobuf::Descriptor& descriptor) { + std::unordered_set result; + GetNestedMessageDescriptorsImpl(descriptor, result); + return result; +} + +} // namespace + +void VisitFields(google::protobuf::Message& message, + FieldVisitCallback callback) { // Get reflection const google::protobuf::Reflection* reflection = message.GetReflection(); UINVARIANT(reflection, "reflection is nullptr"); - for (int field_index = 0; field_index < descriptor->field_count(); - ++field_index) { - // Get field descriptor - const google::protobuf::FieldDescriptor* field = - descriptor->field(field_index); - UINVARIANT(field, "field is nullptr"); + std::vector fields; + reflection->ListFields(message, &fields); - if (field->is_repeated()) { - // Repeated types (including maps) - if (reflection->FieldSize(message, field) > 0) { - callback(message, *field); - } - } else if (reflection->HasField(message, field)) { - // Primitive types - callback(message, *field); - } + for (const google::protobuf::FieldDescriptor* field : fields) { + UINVARIANT(field, "field is nullptr"); + callback(message, *field); } } void VisitMessagesRecursive(google::protobuf::Message& message, - MessageVisitor callback) { + MessageVisitCallback callback) { VisitMessagesRecursiveImpl(message, callback, kMaxRecursionLimit); } void VisitFieldsRecursive(google::protobuf::Message& message, - FieldVisitor callback) { + FieldVisitCallback callback) { VisitMessagesRecursiveImpl( message, [&](google::protobuf::Message& message) -> void { @@ -104,6 +137,315 @@ void VisitFieldsRecursive(google::protobuf::Message& message, kMaxRecursionLimit); } +FieldDescriptorList GetFieldDescriptors( + const google::protobuf::Descriptor& descriptor) { + FieldDescriptorList result; + result.reserve(descriptor.field_count()); + for (int idx = 0; idx < descriptor.field_count(); ++idx) { + const google::protobuf::FieldDescriptor* field = descriptor.field(idx); + UINVARIANT(field, "field is nullptr"); + result.push_back(field); + } + return result; +} + +DescriptorList GetNestedMessageDescriptors( + const google::protobuf::Descriptor& descriptor) { + const auto set = GetNestedMessageDescriptorsSet(descriptor); + return DescriptorList(set.begin(), set.end()); +} + +const google::protobuf::Descriptor* FindGeneratedMessage( + std::string_view name) { + const google::protobuf::DescriptorPool* pool = + google::protobuf::DescriptorPool::generated_pool(); + UINVARIANT(pool, "pool is nullptr"); +#if GOOGLE_PROTOBUF_VERSION >= 3014000 + return pool->FindMessageTypeByName(name); +#else + return pool->FindMessageTypeByName(std::string(name)); +#endif +} + +const google::protobuf::FieldDescriptor* FindField( + const google::protobuf::Descriptor* descriptor, std::string_view field) { + UINVARIANT(descriptor, "descriptor is nullptr"); +#if GOOGLE_PROTOBUF_VERSION >= 3014000 + return descriptor->FindFieldByName(field); +#else + return descriptor->FindFieldByName(std::string(field)); +#endif +} + +template +void BaseVisitor::Compile( + const google::protobuf::Descriptor* descriptor) { + UINVARIANT(descriptor, "descriptor is nullptr"); + Compile(DescriptorList{descriptor}); +} + +template +void BaseVisitor::Compile(const DescriptorList& descriptors) { + { + std::shared_lock read_lock(mutex_, std::defer_lock); + if (lock_behavior_ == LockBehavior::kShared) { + read_lock.lock(); + } + + bool are_compiled = true; + for (const google::protobuf::Descriptor* descriptor : descriptors) { + if (compiled_.find(descriptor) == compiled_.end()) { + // Something is not compiled. Need to compile. + are_compiled = false; + break; + } + } + if (are_compiled) { + // Everything is already compiled. Stop. + return; + } + } + + std::unique_lock write_lock(mutex_, std::defer_lock); + if (lock_behavior_ == LockBehavior::kShared) { + write_lock.lock(); + } + + for (const google::protobuf::Descriptor* descriptor : + GetFullSubtrees(descriptors)) { + UINVARIANT(descriptor, "descriptor is nullptr"); + + // We have already compiled this. Skip. + if (!compiled_.insert(descriptor).second) continue; + + // Compile the selection data + CompileOne(*descriptor); + + // Update everything else + for (const google::protobuf::FieldDescriptor* field : + GetFieldDescriptors(*descriptor)) { + UINVARIANT(field, "field is nullptr"); + + // Not a nested message + if (!IsMessage(*field)) continue; + + // Sync the reverse edges. + // Even from unknown types - we might need to compile them in the future. + const google::protobuf::Descriptor* msg = field->message_type(); + UINVARIANT(msg, "msg is nullptr"); + reverse_edges_[msg].insert(field); + + // Compile the direct edge + propagated_.erase(msg); + PropagateSelected(msg); + } + + // Compile the connections to this message using the reverse edges + PropagateSelected(descriptor); + } +} + +template +void BaseVisitor::Visit(google::protobuf::Message& message, + Callback callback) { + // Compile if not yet compiled + Compile(message.GetDescriptor()); + + std::shared_lock read_lock(mutex_, std::defer_lock); + if (lock_behavior_ == LockBehavior::kShared) { + read_lock.lock(); + } + DoVisit(message, callback); +} + +template +void BaseVisitor::VisitRecursive(google::protobuf::Message& message, + Callback callback) { + // Compile if not yet compiled + Compile(message.GetDescriptor()); + + std::shared_lock read_lock(mutex_, std::defer_lock); + if (lock_behavior_ == LockBehavior::kShared) { + read_lock.lock(); + } + VisitRecursiveImpl(message, callback, kMaxRecursionLimit); +} + +template +typename BaseVisitor::DescriptorSet +BaseVisitor::GetFullSubtrees( + const DescriptorList& descriptors) const { + DescriptorSet result; + for (const google::protobuf::Descriptor* descriptor : descriptors) { + UINVARIANT(descriptor, "descriptor is nullptr"); + if (result.find(descriptor) != result.end()) { + // We have already parsed this + continue; + } + result.merge(GetNestedMessageDescriptorsSet(*descriptor)); + } + return result; +} + +template +void BaseVisitor::PropagateSelected( + const google::protobuf::Descriptor* descriptor) { + UINVARIANT(descriptor, "descriptor is nullptr"); + if (!IsSelected(*descriptor) && + fields_with_selected_children_.find(descriptor) == + fields_with_selected_children_.end()) { + // This does not need to be propagated + return; + } + + if (!propagated_.insert(descriptor).second) { + // We have already propagated this before + return; + } + + const auto it = reverse_edges_.find(descriptor); + if (it == reverse_edges_.end()) return; // No edges + + const FieldDescriptorSet& fields = it->second; + for (const google::protobuf::FieldDescriptor* field : fields) { + UINVARIANT(field, "field is nullptr"); + + const google::protobuf::Descriptor* msg = field->containing_type(); + UINVARIANT(msg, "msg is nullptr"); + + // Save the connection + fields_with_selected_children_[msg].insert(field); + + // Go further over reverse_edges + PropagateSelected(msg); + } +} + +template +void BaseVisitor::VisitRecursiveImpl( + google::protobuf::Message& message, Callback callback, + int recursion_limit) { + UINVARIANT(recursion_limit > 0, + "Recursion limit reached while traversing protobuf Message."); + + // Loop over this message + DoVisit(message, callback); + + // Recurse into nested messages + const auto it = fields_with_selected_children_.find(message.GetDescriptor()); + if (it == fields_with_selected_children_.end()) return; + + const FieldDescriptorSet& fields = it->second; + for (const google::protobuf::FieldDescriptor* field : fields) { + UINVARIANT(field, "field is nullptr"); + CallNestedMessage(message, *field, [&](google::protobuf::Message& msg) { + VisitRecursiveImpl(msg, callback, recursion_limit - 1); + }); + } +} + +FieldsVisitor::FieldsVisitor(Selector selector) + : BaseVisitor(LockBehavior::kShared), + selector_(selector) { + Compile(impl::GetGeneratedMessages()); +} + +FieldsVisitor::FieldsVisitor(Selector selector, + const DescriptorList& descriptors) + : BaseVisitor(LockBehavior::kShared), + selector_(selector) { + Compile(descriptors); +} + +FieldsVisitor::FieldsVisitor(Selector selector, LockBehavior lock_behavior) + : BaseVisitor(lock_behavior), selector_(selector) { + Compile(impl::GetGeneratedMessages()); +} + +FieldsVisitor::FieldsVisitor(Selector selector, + const DescriptorList& descriptors, + LockBehavior lock_behavior) + : BaseVisitor(lock_behavior), selector_(selector) { + Compile(descriptors); +} + +void FieldsVisitor::CompileOne(const google::protobuf::Descriptor& descriptor) { + for (const google::protobuf::FieldDescriptor* field : + GetFieldDescriptors(descriptor)) { + UINVARIANT(field, "field is nullptr"); + if (selector_(descriptor, *field)) { + selected_fields_[&descriptor].insert(field); + } + } +} + +void FieldsVisitor::DoVisit(google::protobuf::Message& message, + FieldVisitCallback callback) { + const auto it = selected_fields_.find(message.GetDescriptor()); + if (it == selected_fields_.end()) return; + + // Get reflection + const google::protobuf::Reflection* reflection = message.GetReflection(); + UINVARIANT(reflection, "reflection is nullptr"); + + const FieldDescriptorSet& fields = it->second; + for (const google::protobuf::FieldDescriptor* field : fields) { + // Repeated types (including maps) + if (field->is_repeated()) { + if (reflection->FieldSize(message, field) > 0) { + callback(message, *field); + } + } else { + // Primitive types + if (reflection->HasField(message, field)) { + callback(message, *field); + } + } + } +} + +MessagesVisitor::MessagesVisitor(Selector selector) + : BaseVisitor(LockBehavior::kShared), + selector_(selector) { + Compile(impl::GetGeneratedMessages()); +} + +MessagesVisitor::MessagesVisitor(Selector selector, + const DescriptorList& descriptors) + : BaseVisitor(LockBehavior::kShared), + selector_(selector) { + Compile(descriptors); +} + +MessagesVisitor::MessagesVisitor(Selector selector, LockBehavior lock_behavior) + : BaseVisitor(lock_behavior), selector_(selector) { + Compile(impl::GetGeneratedMessages()); +} + +MessagesVisitor::MessagesVisitor(Selector selector, + const DescriptorList& descriptors, + LockBehavior lock_behavior) + : BaseVisitor(lock_behavior), selector_(selector) { + Compile(descriptors); +} + +void MessagesVisitor::CompileOne( + const google::protobuf::Descriptor& descriptor) { + if (selector_(descriptor)) { + selected_messages_.insert(&descriptor); + } +} + +void MessagesVisitor::DoVisit(google::protobuf::Message& message, + MessageVisitCallback callback) { + const auto it = selected_messages_.find(message.GetDescriptor()); + if (it == selected_messages_.end()) return; + callback(message); +} + +template class BaseVisitor; +template class BaseVisitor; + } // namespace ugrpc USERVER_NAMESPACE_END diff --git a/grpc/tests/protobuf_collector_test.cpp b/grpc/tests/protobuf_collector_test.cpp new file mode 100644 index 000000000000..b67562721b0b --- /dev/null +++ b/grpc/tests/protobuf_collector_test.cpp @@ -0,0 +1,43 @@ +#include + +#include +#include + +#include +#include + +#include +#include + +USERVER_NAMESPACE_BEGIN + +TEST(GetGeneratedMessages, Ok) { + const ugrpc::DescriptorList generated_message = + ugrpc::impl::GetGeneratedMessages(); + + // Currently, there are 8 different request/response types in the binary. + // You should adjust this number if you start using new types + // as requests/responses in the tests' protobufs + EXPECT_EQ(generated_message.size(), 8); + + EXPECT_THAT(generated_message, testing::Contains(ugrpc::FindGeneratedMessage( + "sample.ugrpc.SendRequest"))); + EXPECT_THAT(generated_message, testing::Contains(ugrpc::FindGeneratedMessage( + "sample.ugrpc.SendResponse"))); + EXPECT_THAT(generated_message, testing::Contains(ugrpc::FindGeneratedMessage( + "sample.ugrpc.GreetingRequest"))); + EXPECT_THAT(generated_message, testing::Contains(ugrpc::FindGeneratedMessage( + "sample.ugrpc.GreetingResponse"))); + EXPECT_THAT(generated_message, testing::Contains(ugrpc::FindGeneratedMessage( + "sample.ugrpc.StreamGreetingRequest"))); + EXPECT_THAT(generated_message, testing::Contains(ugrpc::FindGeneratedMessage( + "sample.ugrpc.StreamGreetingResponse"))); + EXPECT_THAT( + generated_message, + testing::Contains(ugrpc::FindGeneratedMessage("google.protobuf.Empty"))); + EXPECT_THAT(generated_message, + testing::Contains(ugrpc::FindGeneratedMessage( + "sample.ugrpc.MessageWithDifferentTypes"))); +} + +USERVER_NAMESPACE_END diff --git a/grpc/tests/protobuf_visit_test.cpp b/grpc/tests/protobuf_visit_test.cpp index 9e8b56fa524e..c3ddfb4f8275 100644 --- a/grpc/tests/protobuf_visit_test.cpp +++ b/grpc/tests/protobuf_visit_test.cpp @@ -1,12 +1,17 @@ #include +#include +#include +#include +#include +#include #include #include #include #include -#include +#include USERVER_NAMESPACE_BEGIN @@ -16,8 +21,8 @@ sample::ugrpc::MessageWithDifferentTypes::NestedMessage ConstructNestedMessage() { sample::ugrpc::MessageWithDifferentTypes::NestedMessage message; message.set_required_string("string1"); - // Leave required_int as empty required field message.set_optional_string("string2"); + // Leave required_int as empty required field // Leave optional_int as empty optional field return message; } @@ -32,13 +37,16 @@ sample::ugrpc::MessageWithDifferentTypes ConstructMessage() { message.set_optional_int(456654); message.mutable_required_nested()->set_required_string("string1"); - // Leave required_int an empty required field message.mutable_required_nested()->set_optional_string("string2"); + // Leave required_int an empty required field // Leave optional_int an empty optional field // leave optional_nested empty - // leave recursive messages empty: the recursion should ignore them + message.mutable_required_recursive()->set_required_string("string1"); + message.mutable_required_recursive()->set_optional_string("string2"); + + // leave optional_recursive empty message.add_repeated_primitive("string1"); message.add_repeated_primitive("string2"); @@ -61,6 +69,251 @@ sample::ugrpc::MessageWithDifferentTypes ConstructMessage() { return message; } +std::pair +MakeDependency(std::string_view message, std::vector fields) { + std::unordered_set field_desc; + for (const std::string_view field : fields) { + field_desc.insert( + ugrpc::FindField(ugrpc::FindGeneratedMessage(message), field)); + } + return {ugrpc::FindGeneratedMessage(message), std::move(field_desc)}; +} + +std::pair +MakeDependency(std::string_view message, std::string_view fields_message, + std::vector fields) { + std::unordered_set field_desc; + for (const std::string_view field : fields) { + field_desc.insert( + ugrpc::FindField(ugrpc::FindGeneratedMessage(fields_message), field)); + } + return {ugrpc::FindGeneratedMessage(message), std::move(field_desc)}; +} + +std::unordered_map> ToStrings( + const ugrpc::FieldsVisitor::Dependencies& dependencies) { + std::unordered_map> result; + for (const auto& [msg, fields] : dependencies) { + for (const google::protobuf::FieldDescriptor* field : fields) { + result[msg->full_name()].insert(field->name()); + } + } + return result; +} + +std::unordered_set ToStrings( + const ugrpc::FieldsVisitor::DescriptorSet& messages) { + std::unordered_set result; + for (const auto& msg : messages) { + result.insert(msg->full_name()); + } + return result; +} + +std::unordered_set ToStrings( + const ugrpc::FieldsVisitor::FieldDescriptorSet& fields) { + std::unordered_set result; + for (const auto& field : fields) { + result.insert(field->name()); + } + return result; +} + +template +std::unordered_set ToSet(const std::vector& vector) { + return {vector.begin(), vector.end()}; +} + +void MyExpectEq(const ugrpc::FieldsVisitor::FieldDescriptorSet& val1, + const ugrpc::FieldsVisitor::FieldDescriptorSet& val2) { + EXPECT_EQ(ToStrings(val1), ToStrings(val2)); + EXPECT_EQ(val1, val2); +} + +void MyExpectEq(const ugrpc::FieldsVisitor::DescriptorSet& val1, + const ugrpc::FieldsVisitor::DescriptorSet& val2) { + EXPECT_EQ(ToStrings(val1), ToStrings(val2)); + EXPECT_EQ(val1, val2); +} + +void MyExpectEq(const ugrpc::FieldsVisitor::Dependencies& val1, + const ugrpc::FieldsVisitor::Dependencies& val2) { + EXPECT_EQ(ToStrings(val1), ToStrings(val2)); + EXPECT_EQ(val1, val2); +} + +template +bool ContainsMessage(const MapOrSet& collection, std::string_view value) { + return collection.find(ugrpc::FindGeneratedMessage(value)) != + collection.end(); +} + +bool FieldSelector(const google::protobuf::Descriptor&, + const google::protobuf::FieldDescriptor& field) { + return field.options().GetExtension(sample::ugrpc::field).selected(); +} + +bool MessageSelector(const google::protobuf::Descriptor& message) { + return message.options().GetExtension(sample::ugrpc::message).selected(); +} + +namespace component1 { + +ugrpc::DescriptorList Get() { + return {ugrpc::FindGeneratedMessage("sample.ugrpc.Msg1A"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg1B"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg1C"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg1D"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg1E")}; +} + +ugrpc::FieldsVisitor::Dependencies GetSelectedFields() { + return { + MakeDependency("sample.ugrpc.Msg1A", {"value1", "value2"}), + MakeDependency("sample.ugrpc.Msg1C", {"value"}), + MakeDependency("sample.ugrpc.Msg1D", {"value"}), + }; +} + +ugrpc::FieldsVisitor::DescriptorSet GetSelectedMessages() { + return { + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg1A"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg1C"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg1D"), + }; +} + +ugrpc::FieldsVisitor::Dependencies GetFieldsWithSelectedChildren() { + return { + MakeDependency("sample.ugrpc.Msg1A", {"nested"}), + MakeDependency("sample.ugrpc.Msg1B", + {"recursive_1", "recursive_2", "nested_secret_1", + "nested_secret_2", "nested_secret_3"}), + MakeDependency("sample.ugrpc.Msg1C", + {"recursive_1", "recursive_2", "nested"}), + }; +} + +} // namespace component1 + +namespace component2 { + +ugrpc::DescriptorList Get() { + return {ugrpc::FindGeneratedMessage("sample.ugrpc.Msg2A")}; +} + +ugrpc::FieldsVisitor::Dependencies GetSelectedFields() { + return { + MakeDependency("sample.ugrpc.Msg2A", {"value1", "value3"}), + }; +} + +ugrpc::FieldsVisitor::DescriptorSet GetSelectedMessages() { + return { + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg2A"), + }; +} + +ugrpc::FieldsVisitor::Dependencies GetFieldsWithSelectedChildren() { + return {}; +} + +} // namespace component2 + +namespace component3 { + +ugrpc::DescriptorList Get() { + return {ugrpc::FindGeneratedMessage("sample.ugrpc.Msg3A"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg3B")}; +} + +ugrpc::FieldsVisitor::Dependencies GetSelectedFields() { return {}; } + +ugrpc::FieldsVisitor::DescriptorSet GetSelectedMessages() { return {}; } + +ugrpc::FieldsVisitor::Dependencies GetFieldsWithSelectedChildren() { + return {}; +} + +} // namespace component3 + +namespace component4 { + +ugrpc::DescriptorList Get() { + return {ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4A"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4B"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4C")}; +} + +ugrpc::FieldsVisitor::Dependencies GetSelectedFields() { + return { + MakeDependency("sample.ugrpc.Msg4B", {"value"}), + }; +} + +ugrpc::FieldsVisitor::DescriptorSet GetSelectedMessages() { + return { + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4B"), + }; +} + +ugrpc::FieldsVisitor::Dependencies GetFieldsWithSelectedChildren() { + return { + MakeDependency("sample.ugrpc.Msg4A", {"nested"}), + MakeDependency("sample.ugrpc.Msg4B", {"nested"}), + MakeDependency("sample.ugrpc.Msg4C", {"nested_1", "nested_2"}), + }; +} + +} // namespace component4 + +namespace diff_types { + +ugrpc::DescriptorList Get() { + return {ugrpc::FindGeneratedMessage("sample.ugrpc.MessageWithDifferentTypes"), + ugrpc::FindGeneratedMessage( + "sample.ugrpc.MessageWithDifferentTypes.NestedMessage"), + ugrpc::FindGeneratedMessage( + "sample.ugrpc.MessageWithDifferentTypes.NestedMapEntry"), + ugrpc::FindGeneratedMessage( + "sample.ugrpc.MessageWithDifferentTypes.PrimitivesMapEntry"), + ugrpc::FindGeneratedMessage("google.protobuf.Value"), + ugrpc::FindGeneratedMessage("google.protobuf.ListValue")}; +} + +ugrpc::FieldsVisitor::Dependencies GetSelectedFields() { + return { + MakeDependency("sample.ugrpc.MessageWithDifferentTypes", + {"optional_string", "optional_int", "repeated_message"}), + MakeDependency("sample.ugrpc.MessageWithDifferentTypes.NestedMessage", + {"required_string", "required_int"}), + }; +} + +ugrpc::FieldsVisitor::DescriptorSet GetSelectedMessages() { + return { + ugrpc::FindGeneratedMessage( + "sample.ugrpc.MessageWithDifferentTypes.NestedMessage"), + }; +} + +ugrpc::FieldsVisitor::Dependencies GetFieldsWithSelectedChildren() { + return { + MakeDependency( + "sample.ugrpc.MessageWithDifferentTypes", + {"required_nested", "optional_nested", "required_recursive", + "optional_recursive", "repeated_message", "nested_map", + "oneof_nested", "weird_map"}), + MakeDependency("sample.ugrpc.MessageWithDifferentTypes.NestedMapEntry", + {"value"}), + MakeDependency("sample.ugrpc.MessageWithDifferentTypes.WeirdMapEntry", + {"value"})}; +} + +} // namespace diff_types + } // namespace TEST(VisitFields, TestEmptyMessage) { @@ -85,6 +338,7 @@ TEST(VisitFields, TestMessage) { 1 + // required_int 1 + // optional_int 1 + // required_nested + 1 + // required_recursive 1 + // repeated_primitive 1 + // repeated_message 1 + // primitives_map @@ -113,6 +367,7 @@ TEST(VisitMessagesRecursive, TestMessage) { message, [&calls](google::protobuf::Message&) { ++calls; }); const std::size_t expected_calls = 1 + // root object 1 + // required_nested + 1 + // required_recursive 2 + // repeated_message 2 + // primitives_map ({ key, value }) 2 + // nested_map ({ key, value }) @@ -147,6 +402,9 @@ TEST(VisitFieldsRecursive, TestMessage) { 1 + // required_nested 1 + // required_nested required_string 1 + // required_nested optional_string + 1 + // required_recursive + 1 + // required_recursive required_string + 1 + // required_recursive optional_string 1 + // repeated_primitive 1 + // repeated_message 2 + // repeated_message required_string @@ -167,4 +425,547 @@ TEST(VisitFieldsRecursive, TestMessage) { message, ConstructMessage())); } +TEST(GetFieldDescriptors, MessageWithDifferentTypes) { + constexpr auto msg = "sample.ugrpc.MessageWithDifferentTypes"; + MyExpectEq( + ToSet(ugrpc::GetFieldDescriptors(*ugrpc::FindGeneratedMessage(msg))), + { + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), "required_string"), + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), "optional_string"), + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), "required_int"), + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), "optional_int"), + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), "required_nested"), + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), "optional_nested"), + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), + "required_recursive"), + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), + "optional_recursive"), + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), + "repeated_primitive"), + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), + "repeated_message"), + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), "primitives_map"), + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), "nested_map"), + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), "oneof_string"), + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), "oneof_int"), + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), "oneof_nested"), + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), "google_value"), + ugrpc::FindField(ugrpc::FindGeneratedMessage(msg), "weird_map"), + }); +} + +TEST(GetNestedMessageDescriptors, MessageWithDifferentTypes) { + const std::string msg = "sample.ugrpc.MessageWithDifferentTypes"; + MyExpectEq( + ToSet(ugrpc::GetNestedMessageDescriptors( + *ugrpc::FindGeneratedMessage(msg))), + { + ugrpc::FindGeneratedMessage(msg), + ugrpc::FindGeneratedMessage(msg + ".NestedMessage"), + ugrpc::FindGeneratedMessage(msg + ".NestedMapEntry"), + ugrpc::FindGeneratedMessage(msg + ".PrimitivesMapEntry"), + ugrpc::FindGeneratedMessage(msg + ".WeirdMapEntry"), + ugrpc::FindGeneratedMessage("google.protobuf.Value"), + ugrpc::FindGeneratedMessage("google.protobuf.Struct"), + ugrpc::FindGeneratedMessage("google.protobuf.Struct.FieldsEntry"), + ugrpc::FindGeneratedMessage("google.protobuf.ListValue"), + }); +} + +TEST(FieldsVisitorCompile, OneLeafNoSelected) { + ugrpc::FieldsVisitor visitor(FieldSelector, ugrpc::DescriptorList{}); + visitor.CompileGenerated("sample.ugrpc.Msg3B"); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), {}); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + {}); + MyExpectEq(visitor.GetReverseEdges(utils::impl::InternalTag()), {}); + MyExpectEq(visitor.GetPropagated(utils::impl::InternalTag()), {}); + MyExpectEq(visitor.GetCompiled(utils::impl::InternalTag()), + {ugrpc::FindGeneratedMessage("sample.ugrpc.Msg3B")}); +} + +TEST(FieldsVisitorCompile, OneNonLeafNoSelected) { + ugrpc::FieldsVisitor visitor(FieldSelector, ugrpc::DescriptorList{}); + visitor.CompileGenerated("sample.ugrpc.Msg3A"); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), {}); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + {}); + MyExpectEq( + visitor.GetReverseEdges(utils::impl::InternalTag()), + {MakeDependency("sample.ugrpc.Msg3B", "sample.ugrpc.Msg3A", {"nested"})}); + MyExpectEq(visitor.GetPropagated(utils::impl::InternalTag()), {}); + MyExpectEq(visitor.GetCompiled(utils::impl::InternalTag()), + {ugrpc::FindGeneratedMessage("sample.ugrpc.Msg3A"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg3B")}); +} + +TEST(FieldsVisitorCompile, OneLeafSelected) { + constexpr auto msg = "sample.ugrpc.MessageWithDifferentTypes.NestedMessage"; + ugrpc::FieldsVisitor visitor(FieldSelector, ugrpc::DescriptorList{}); + visitor.CompileGenerated(msg); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + {MakeDependency(msg, {"required_string", "required_int"})}); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + {}); + MyExpectEq(visitor.GetReverseEdges(utils::impl::InternalTag()), {}); + MyExpectEq(visitor.GetPropagated(utils::impl::InternalTag()), + {ugrpc::FindGeneratedMessage(msg)}); + MyExpectEq(visitor.GetCompiled(utils::impl::InternalTag()), + {ugrpc::FindGeneratedMessage(msg)}); +} + +TEST(FieldsVisitorCompile, OneNonLeafSelected) { + const std::string msg = "sample.ugrpc.MessageWithDifferentTypes"; + ugrpc::FieldsVisitor visitor(FieldSelector, ugrpc::DescriptorList{}); + visitor.CompileGenerated(msg); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + diff_types::GetSelectedFields()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + diff_types::GetFieldsWithSelectedChildren()); + EXPECT_GT(visitor.GetReverseEdges(utils::impl::InternalTag()).size(), 3); + MyExpectEq(visitor.GetPropagated(utils::impl::InternalTag()), + {ugrpc::FindGeneratedMessage(msg), + ugrpc::FindGeneratedMessage(msg + ".NestedMessage"), + ugrpc::FindGeneratedMessage(msg + ".NestedMapEntry"), + ugrpc::FindGeneratedMessage(msg + ".WeirdMapEntry")}); + EXPECT_GT(visitor.GetCompiled(utils::impl::InternalTag()).size(), 7); +} + +TEST(FieldsVisitorCompile, OneLoop) { + ugrpc::FieldsVisitor visitor(FieldSelector, ugrpc::DescriptorList{}); + visitor.CompileGenerated("sample.ugrpc.Msg4A"); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + component4::GetSelectedFields()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + component4::GetFieldsWithSelectedChildren()); + EXPECT_EQ(visitor.GetReverseEdges(utils::impl::InternalTag()).size(), 3); + MyExpectEq(visitor.GetPropagated(utils::impl::InternalTag()), + {ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4A"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4B"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4C")}); + MyExpectEq(visitor.GetCompiled(utils::impl::InternalTag()), + {ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4A"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4B"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4C")}); +} + +TEST(FieldsVisitorCompile, TwoAB) { + const std::string msg = "sample.ugrpc.MessageWithDifferentTypes"; + ugrpc::FieldsVisitor visitor(FieldSelector, ugrpc::DescriptorList{}); + visitor.CompileGenerated(msg); + visitor.CompileGenerated(msg + ".NestedMessage"); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + diff_types::GetSelectedFields()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + diff_types::GetFieldsWithSelectedChildren()); + EXPECT_GT(visitor.GetReverseEdges(utils::impl::InternalTag()).size(), 3); + MyExpectEq(visitor.GetPropagated(utils::impl::InternalTag()), + {ugrpc::FindGeneratedMessage(msg), + ugrpc::FindGeneratedMessage(msg + ".NestedMessage"), + ugrpc::FindGeneratedMessage(msg + ".NestedMapEntry"), + ugrpc::FindGeneratedMessage(msg + ".WeirdMapEntry")}); + EXPECT_GT(visitor.GetCompiled(utils::impl::InternalTag()).size(), 7); +} + +TEST(FieldsVisitorCompile, TwoBA) { + const std::string msg = "sample.ugrpc.MessageWithDifferentTypes"; + ugrpc::FieldsVisitor visitor(FieldSelector, ugrpc::DescriptorList{}); + visitor.CompileGenerated(msg + ".NestedMessage"); + visitor.CompileGenerated(msg); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + diff_types::GetSelectedFields()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + diff_types::GetFieldsWithSelectedChildren()); + EXPECT_GT(visitor.GetReverseEdges(utils::impl::InternalTag()).size(), 3); + MyExpectEq(visitor.GetPropagated(utils::impl::InternalTag()), + {ugrpc::FindGeneratedMessage(msg), + ugrpc::FindGeneratedMessage(msg + ".NestedMessage"), + ugrpc::FindGeneratedMessage(msg + ".NestedMapEntry"), + ugrpc::FindGeneratedMessage(msg + ".WeirdMapEntry")}); + EXPECT_GT(visitor.GetCompiled(utils::impl::InternalTag()).size(), 7); +} + +TEST(FieldsVisitorCompile, ThreeABC) { + const std::string msg = "sample.ugrpc.MessageWithDifferentTypes"; + ugrpc::FieldsVisitor visitor(FieldSelector, ugrpc::DescriptorList{}); + visitor.CompileGenerated("sample.ugrpc.Msg4A"); + visitor.CompileGenerated("sample.ugrpc.Msg4B"); + visitor.CompileGenerated("sample.ugrpc.Msg4C"); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + component4::GetSelectedFields()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + component4::GetFieldsWithSelectedChildren()); + EXPECT_EQ(visitor.GetReverseEdges(utils::impl::InternalTag()).size(), 3); + MyExpectEq(visitor.GetPropagated(utils::impl::InternalTag()), + {ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4A"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4B"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4C")}); + MyExpectEq(visitor.GetCompiled(utils::impl::InternalTag()), + {ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4A"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4B"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4C")}); +} + +TEST(FieldsVisitorCompile, ThreeACB) { + const std::string msg = "sample.ugrpc.MessageWithDifferentTypes"; + ugrpc::FieldsVisitor visitor(FieldSelector, ugrpc::DescriptorList{}); + visitor.CompileGenerated("sample.ugrpc.Msg4A"); + visitor.CompileGenerated("sample.ugrpc.Msg4C"); + visitor.CompileGenerated("sample.ugrpc.Msg4B"); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + component4::GetSelectedFields()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + component4::GetFieldsWithSelectedChildren()); + EXPECT_EQ(visitor.GetReverseEdges(utils::impl::InternalTag()).size(), 3); + MyExpectEq(visitor.GetPropagated(utils::impl::InternalTag()), + {ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4A"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4B"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4C")}); + MyExpectEq(visitor.GetCompiled(utils::impl::InternalTag()), + {ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4A"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4B"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4C")}); +} + +TEST(FieldsVisitorCompile, ThreeCAB) { + const std::string msg = "sample.ugrpc.MessageWithDifferentTypes"; + ugrpc::FieldsVisitor visitor(FieldSelector, ugrpc::DescriptorList{}); + visitor.CompileGenerated("sample.ugrpc.Msg4C"); + visitor.CompileGenerated("sample.ugrpc.Msg4A"); + visitor.CompileGenerated("sample.ugrpc.Msg4B"); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + component4::GetSelectedFields()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + component4::GetFieldsWithSelectedChildren()); + EXPECT_EQ(visitor.GetReverseEdges(utils::impl::InternalTag()).size(), 3); + MyExpectEq(visitor.GetPropagated(utils::impl::InternalTag()), + {ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4A"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4B"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4C")}); + MyExpectEq(visitor.GetCompiled(utils::impl::InternalTag()), + {ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4A"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4B"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg4C")}); +} + +TEST(FieldsVisitorConstructor, TestComponent1) { + ugrpc::FieldsVisitor visitor(FieldSelector, component1::Get()); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + component1::GetSelectedFields()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + component1::GetFieldsWithSelectedChildren()); +} + +TEST(FieldsVisitorConstructor, TestComponent2) { + ugrpc::FieldsVisitor visitor(FieldSelector, component2::Get()); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + component2::GetSelectedFields()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + component2::GetFieldsWithSelectedChildren()); +} + +TEST(FieldsVisitorConstructor, TestComponent3) { + ugrpc::FieldsVisitor visitor(FieldSelector, component3::Get()); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + component3::GetSelectedFields()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + component3::GetFieldsWithSelectedChildren()); +} + +TEST(FieldsVisitorConstructor, TestComponent4) { + ugrpc::FieldsVisitor visitor(FieldSelector, component4::Get()); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + component4::GetSelectedFields()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + component4::GetFieldsWithSelectedChildren()); +} + +TEST(FieldsVisitorConstructor, TestDiffTypes) { + ugrpc::FieldsVisitor visitor(FieldSelector, diff_types::Get()); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + diff_types::GetSelectedFields()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + diff_types::GetFieldsWithSelectedChildren()); +} + +TEST(FieldsVisitorConstructor, TestPartialComponent) { + ugrpc::DescriptorList messages = { + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg1A"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg1C"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg1D"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg1E")}; + + ugrpc::FieldsVisitor visitor(FieldSelector, messages); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + component1::GetSelectedFields()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + component1::GetFieldsWithSelectedChildren()); +} + +TEST(FieldsVisitorConstructor, TestMultipleComponents) { + ugrpc::DescriptorList messages; + for (const auto& msg : component1::Get()) messages.push_back(msg); + for (const auto& msg : component2::Get()) messages.push_back(msg); + for (const auto& msg : component3::Get()) messages.push_back(msg); + for (const auto& msg : component4::Get()) messages.push_back(msg); + + ugrpc::FieldsVisitor::Dependencies selected_fields; + selected_fields.merge(component1::GetSelectedFields()); + selected_fields.merge(component2::GetSelectedFields()); + selected_fields.merge(component3::GetSelectedFields()); + selected_fields.merge(component4::GetSelectedFields()); + + ugrpc::FieldsVisitor::Dependencies fields_with_selected_children; + fields_with_selected_children.merge( + component1::GetFieldsWithSelectedChildren()); + fields_with_selected_children.merge( + component2::GetFieldsWithSelectedChildren()); + fields_with_selected_children.merge( + component3::GetFieldsWithSelectedChildren()); + fields_with_selected_children.merge( + component4::GetFieldsWithSelectedChildren()); + + ugrpc::FieldsVisitor visitor(FieldSelector, messages); + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + selected_fields); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + fields_with_selected_children); +} + +TEST(FieldsVisitorConstructor, TestAllMessageTypes) { + ugrpc::FieldsVisitor visitor(FieldSelector); + + const ugrpc::FieldsVisitor::Dependencies& sf = + visitor.GetSelectedFields(utils::impl::InternalTag()); + EXPECT_TRUE(ContainsMessage(sf, "sample.ugrpc.MessageWithDifferentTypes")); + EXPECT_TRUE(ContainsMessage( + sf, "sample.ugrpc.MessageWithDifferentTypes.NestedMessage")); + + const ugrpc::FieldsVisitor::Dependencies& fwsc = + visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()); + EXPECT_TRUE(ContainsMessage(fwsc, "sample.ugrpc.MessageWithDifferentTypes")); + EXPECT_TRUE(ContainsMessage( + fwsc, "sample.ugrpc.MessageWithDifferentTypes.NestedMapEntry")); +} + +TEST(FieldsVisitorVisit, TestEmptyMessage) { + std::size_t calls = 0; + ugrpc::FieldsVisitor visitor(FieldSelector, ugrpc::DescriptorList{}); + sample::ugrpc::MessageWithDifferentTypes message; + visitor.Visit( + message, [&calls](google::protobuf::Message&, + const google::protobuf::FieldDescriptor&) { ++calls; }); + ASSERT_EQ(calls, 0); + ASSERT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + message, sample::ugrpc::MessageWithDifferentTypes())); +} + +TEST(FieldsVisitorVisit, TestMessage) { + std::size_t calls = 0; + auto message = ConstructMessage(); + ugrpc::FieldsVisitor visitor(FieldSelector, ugrpc::DescriptorList{}); + visitor.Visit( + message, [&calls](google::protobuf::Message&, + const google::protobuf::FieldDescriptor&) { ++calls; }); + + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + diff_types::GetSelectedFields()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + diff_types::GetFieldsWithSelectedChildren()); + + const std::size_t expected_calls = 1 + // optional_string + 1 + // optional_int + 1; // repeated_message + ASSERT_EQ(calls, expected_calls); + ASSERT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + message, ConstructMessage())); +} + +TEST(FieldsVisitorVisitRecursive, TestEmptyMessage) { + std::size_t calls = 0; + ugrpc::FieldsVisitor visitor(FieldSelector, ugrpc::DescriptorList{}); + sample::ugrpc::MessageWithDifferentTypes message; + visitor.VisitRecursive( + message, [&calls](google::protobuf::Message&, + const google::protobuf::FieldDescriptor&) { ++calls; }); + ASSERT_EQ(calls, 0); + ASSERT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + message, sample::ugrpc::MessageWithDifferentTypes())); +} + +TEST(FieldsVisitorVisitRecursive, TestMessage) { + std::size_t calls = 0; + auto message = ConstructMessage(); + ugrpc::FieldsVisitor visitor(FieldSelector, ugrpc::DescriptorList{}); + visitor.VisitRecursive( + message, [&calls](google::protobuf::Message&, + const google::protobuf::FieldDescriptor&) { ++calls; }); + + MyExpectEq(visitor.GetSelectedFields(utils::impl::InternalTag()), + diff_types::GetSelectedFields()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + diff_types::GetFieldsWithSelectedChildren()); + + const std::size_t expected_calls = 1 + // optional_string + 1 + // optional_int + 1 + // required_nested required_string + 1 + // required_recursive optional_string + 1 + // repeated_message + 2 + // repeated_message required_string + 2; // nested_map (values) required_string + ASSERT_EQ(calls, expected_calls); + ASSERT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + message, ConstructMessage())); +} + +TEST(MessagesVisitorConstructor, TestComponent1) { + ugrpc::MessagesVisitor visitor(MessageSelector, component1::Get()); + MyExpectEq(visitor.GetSelectedMessages(utils::impl::InternalTag()), + component1::GetSelectedMessages()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + component1::GetFieldsWithSelectedChildren()); +} + +TEST(MessagesVisitorConstructor, TestComponent2) { + ugrpc::MessagesVisitor visitor(MessageSelector, component2::Get()); + MyExpectEq(visitor.GetSelectedMessages(utils::impl::InternalTag()), + component2::GetSelectedMessages()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + component2::GetFieldsWithSelectedChildren()); +} + +TEST(MessagesVisitorConstructor, TestComponent3) { + ugrpc::MessagesVisitor visitor(MessageSelector, component3::Get()); + MyExpectEq(visitor.GetSelectedMessages(utils::impl::InternalTag()), + component3::GetSelectedMessages()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + component3::GetFieldsWithSelectedChildren()); +} + +TEST(MessagesVisitorConstructor, TestComponent4) { + ugrpc::MessagesVisitor visitor(MessageSelector, component4::Get()); + MyExpectEq(visitor.GetSelectedMessages(utils::impl::InternalTag()), + component4::GetSelectedMessages()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + component4::GetFieldsWithSelectedChildren()); +} + +TEST(MessagesVisitorConstructor, TestDiffTypes) { + ugrpc::MessagesVisitor visitor(MessageSelector, diff_types::Get()); + MyExpectEq(visitor.GetSelectedMessages(utils::impl::InternalTag()), + diff_types::GetSelectedMessages()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + diff_types::GetFieldsWithSelectedChildren()); +} + +TEST(MessagesVisitorConstructor, TestPartialComponent) { + ugrpc::DescriptorList messages = { + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg1A"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg1C"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg1D"), + ugrpc::FindGeneratedMessage("sample.ugrpc.Msg1E")}; + ugrpc::MessagesVisitor visitor(MessageSelector, messages); + MyExpectEq(visitor.GetSelectedMessages(utils::impl::InternalTag()), + component1::GetSelectedMessages()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + component1::GetFieldsWithSelectedChildren()); +} + +TEST(MessagesVisitorConstructor, TestMultipleComponents) { + ugrpc::DescriptorList messages; + for (const auto& msg : component1::Get()) messages.push_back(msg); + for (const auto& msg : component2::Get()) messages.push_back(msg); + for (const auto& msg : component3::Get()) messages.push_back(msg); + for (const auto& msg : component4::Get()) messages.push_back(msg); + + ugrpc::MessagesVisitor::DescriptorSet selected_messages; + selected_messages.merge(component1::GetSelectedMessages()); + selected_messages.merge(component2::GetSelectedMessages()); + selected_messages.merge(component3::GetSelectedMessages()); + selected_messages.merge(component4::GetSelectedMessages()); + + ugrpc::MessagesVisitor::Dependencies fields_with_selected_children; + fields_with_selected_children.merge( + component1::GetFieldsWithSelectedChildren()); + fields_with_selected_children.merge( + component2::GetFieldsWithSelectedChildren()); + fields_with_selected_children.merge( + component3::GetFieldsWithSelectedChildren()); + fields_with_selected_children.merge( + component4::GetFieldsWithSelectedChildren()); + + ugrpc::MessagesVisitor visitor(MessageSelector, messages); + MyExpectEq(visitor.GetSelectedMessages(utils::impl::InternalTag()), + selected_messages); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + fields_with_selected_children); +} + +TEST(MessagesVisitorConstructor, TestAllMessageTypes) { + ugrpc::MessagesVisitor visitor(MessageSelector); + + const ugrpc::MessagesVisitor::DescriptorSet& sm = + visitor.GetSelectedMessages(utils::impl::InternalTag()); + EXPECT_TRUE(ContainsMessage( + sm, "sample.ugrpc.MessageWithDifferentTypes.NestedMessage")); + + const ugrpc::MessagesVisitor::Dependencies& fwsc = + visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()); + EXPECT_TRUE(ContainsMessage(fwsc, "sample.ugrpc.MessageWithDifferentTypes")); + EXPECT_TRUE(ContainsMessage( + fwsc, "sample.ugrpc.MessageWithDifferentTypes.NestedMapEntry")); +} + +TEST(MessagesVisitorVisit, TestEmptyMessage) { + std::size_t calls = 0; + ugrpc::MessagesVisitor visitor(MessageSelector, ugrpc::DescriptorList{}); + sample::ugrpc::MessageWithDifferentTypes::NestedMessage message; + visitor.Visit(message, [&calls](google::protobuf::Message&) { ++calls; }); + ASSERT_EQ(calls, 1); + ASSERT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + message, sample::ugrpc::MessageWithDifferentTypes::NestedMessage())); +} + +TEST(MessagesVisitorVisit, TestMessage) { + std::size_t calls = 0; + auto message = ConstructMessage(); + ugrpc::MessagesVisitor visitor(MessageSelector, ugrpc::DescriptorList{}); + visitor.Visit(*message.mutable_required_nested(), + [&calls](google::protobuf::Message&) { ++calls; }); + ASSERT_EQ(calls, 1); + ASSERT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + message, ConstructMessage())); +} + +TEST(MessagesVisitorVisitRecursive, TestEmptyMessage) { + std::size_t calls = 0; + ugrpc::MessagesVisitor visitor(MessageSelector, ugrpc::DescriptorList{}); + sample::ugrpc::MessageWithDifferentTypes::NestedMessage message; + visitor.VisitRecursive(message, + [&calls](google::protobuf::Message&) { ++calls; }); + ASSERT_EQ(calls, 1); + ASSERT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + message, sample::ugrpc::MessageWithDifferentTypes::NestedMessage())); +} + +TEST(MessagesVisitorVisitRecursive, TestMessage) { + std::size_t calls = 0; + auto message = ConstructMessage(); + ugrpc::MessagesVisitor visitor(MessageSelector, ugrpc::DescriptorList{}); + visitor.VisitRecursive(message, + [&calls](google::protobuf::Message&) { ++calls; }); + + MyExpectEq(visitor.GetSelectedMessages(utils::impl::InternalTag()), + diff_types::GetSelectedMessages()); + MyExpectEq(visitor.GetFieldsWithSelectedChildren(utils::impl::InternalTag()), + diff_types::GetFieldsWithSelectedChildren()); + + const std::size_t expected_calls = 1 + // required_nested + 2 + // repeated_message + 2; // nested_map values + ASSERT_EQ(calls, expected_calls); + ASSERT_TRUE(google::protobuf::util::MessageDifferencer::Equals( + message, ConstructMessage())); +} + USERVER_NAMESPACE_END diff --git a/scripts/grpc/templates/client.usrv.cpp.jinja b/scripts/grpc/templates/client.usrv.cpp.jinja index 2c0612e0ccc1..9cfc55201bf6 100644 --- a/scripts/grpc/templates/client.usrv.cpp.jinja +++ b/scripts/grpc/templates/client.usrv.cpp.jinja @@ -4,24 +4,31 @@ #include "{{ proto.source_file_without_ext }}_client.usrv.pb.hpp" {# All constant includes must go inside this header #} +#include #include {{ utils.include_grpcpp(proto.source_file_without_ext) }} {% call utils.optional_namespace(proto.namespace) %} -namespace { {% for service in proto.services %} -constexpr std::string_view k{{service.name}}MethodNames[] = { +// Inline to deduplicate between client.usrv.cpp and service.usrv.cpp +inline constexpr std::string_view k{{service.name}}MethodNames[] = { {% for method in service.method %} "{{proto.package_prefix}}{{service.name}}/{{method.name}}", {% endfor %} }; -{% endfor %} -} // namespace -{% for service in proto.services %} +// Inline to deduplicate between client.usrv.cpp and service.usrv.cpp +inline const bool k{{service.name}}TypesRegistration = + (USERVER_NAMESPACE::ugrpc::impl::RegisterMessageTypes({ + {% for method in service.method %} + std::string("{{method.input_type}}").substr(1), + std::string("{{method.output_type}}").substr(1), + {% endfor %} + }), + false); {{service.name}}Client::{{service.name}}Client( USERVER_NAMESPACE::ugrpc::client::impl::ClientDependencies&& dependencies) @@ -56,6 +63,8 @@ constexpr std::string_view k{{service.name}}MethodNames[] = { USERVER_NAMESPACE::ugrpc::impl::StaticServiceMetadata {{ service.name }}Client::GetMetadata() { + (void)k{{service.name}}TypesRegistration; // odr-use + return USERVER_NAMESPACE::ugrpc::impl::MakeStaticServiceMetadata< {{utils.namespace_with_colons(proto.namespace)}}::{{service.name}}>( k{{service.name}}MethodNames); diff --git a/scripts/grpc/templates/service.usrv.cpp.jinja b/scripts/grpc/templates/service.usrv.cpp.jinja index c2307cede319..b4315bdd3ee0 100644 --- a/scripts/grpc/templates/service.usrv.cpp.jinja +++ b/scripts/grpc/templates/service.usrv.cpp.jinja @@ -4,25 +4,31 @@ #include "{{ proto.source_file_without_ext }}_service.usrv.pb.hpp" {# All constant includes must go inside this header #} +#include #include {{ utils.include_grpcpp(proto.source_file_without_ext) }} {% call utils.optional_namespace(proto.namespace) %} -namespace { {% for service in proto.services %} -constexpr std::string_view k{{service.name}}MethodNames[] = { +// Inline to deduplicate between client.usrv.cpp and service.usrv.cpp +inline constexpr std::string_view k{{service.name}}MethodNames[] = { {% for method in service.method %} "{{proto.package_prefix}}{{service.name}}/{{method.name}}", {% endfor %} }; -{% endfor %} -} // namespace - -{% for service in proto.services %} +// Inline to deduplicate between client.usrv.cpp and service.usrv.cpp +inline const bool k{{service.name}}TypesRegistration = + (USERVER_NAMESPACE::ugrpc::impl::RegisterMessageTypes({ + {% for method in service.method %} + std::string("{{method.input_type}}").substr(1), + std::string("{{method.output_type}}").substr(1), + {% endfor %} + }), + false); {% for method in service.method %} {% if method.client_streaming and method.server_streaming %} @@ -98,6 +104,8 @@ void {{service.name}}Base::{{method.name}}({{method.name}}Call& call, std::unique_ptr {{service.name}}Base::MakeWorker( USERVER_NAMESPACE::ugrpc::server::impl::ServiceSettings&& settings) { + (void)k{{service.name}}TypesRegistration; // odr-use + return USERVER_NAMESPACE::ugrpc::server::impl::MakeServiceWorker< {{utils.namespace_with_colons(proto.namespace)}}::{{service.name}}>( std::move(settings), k{{service.name}}MethodNames, *this,