Skip to content

Commit

Permalink
Merge pull request #51 from joerunde/streams
Browse files Browse the repository at this point in the history
Streams
  • Loading branch information
joerunde authored May 11, 2023
2 parents d0b57a6 + 92aea25 commit 9d874e7
Show file tree
Hide file tree
Showing 5 changed files with 395 additions and 146 deletions.
14 changes: 11 additions & 3 deletions py_to_proto/descriptor_to_message_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,17 @@ def descriptor_to_message_class(
try:
message_class = descriptor._concrete_class
except (TypeError, SystemError, AttributeError):
message_class = reflection.message_factory.MessageFactory().GetPrototype(
descriptor
)
# protobuf version compatibility
if hasattr(reflection.message_factory, "GetMessageClass"):
# Newer protobuf versions use GetMessageClass
message_class = reflection.message_factory.GetMessageClass(
descriptor
) # pragma: no cover
else:
# Older protobuf versions require creating an instance of a MessageFactory
message_class = (
reflection.message_factory.MessageFactory().GetPrototype(descriptor)
) # pragma: no cover

# Recursively add nested messages
for nested_message_descriptor in descriptor.nested_types:
Expand Down
231 changes: 150 additions & 81 deletions py_to_proto/json_to_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
import types

# Third Party
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pb2
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message, service
from google.protobuf.descriptor import ServiceDescriptor
from google.protobuf import service
from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor
from google.protobuf.service import Service
from google.protobuf.service_reflection import GeneratedServiceType
import grpc
Expand Down Expand Up @@ -36,7 +35,11 @@
"input_type": {"type": "string"},
"name": {"type": "string"},
"output_type": {"type": "string"},
}
},
"optionalProperties": {
"server_streaming": {"type": "boolean"},
"client_streaming": {"type": "boolean"},
},
}
}
}
Expand All @@ -52,14 +55,22 @@
ServiceJsonType = Dict[str, Dict[str, List[Dict[str, str]]]]


@dataclasses.dataclass
class GRPCService:
descriptor: ServiceDescriptor
registration_function: Callable[[Service, grpc.Server], None]
client_stub_class: Type
service_class: Type[service.Service]


def json_to_service(
name: str,
package: str,
json_service_def: ServiceJsonType,
*,
descriptor_pool: Optional[_descriptor_pool.DescriptorPool] = None,
) -> _descriptor.ServiceDescriptor:
"""Convert a JSON representation of an RPC service into a ServiceDescriptor.
) -> GRPCService:
"""Convert a JSON representation of an RPC service into a GRPCService.
Reference: https://jsontypedef.com/docs/jtd-in-5-minutes/
Expand All @@ -77,21 +88,71 @@ def json_to_service(
message descriptors
Returns:
descriptor: google.protobuf.descriptor.ServiceDescriptor
The ServiceDescriptor corresponding to this json definition
grpc_service: GRPCService
The GRPCService container with the service descriptor and other associated
grpc bits required to boot a server:
- Servicer registration function
- Client stub class
- Servicer base class
"""
# Ensure we have a valid service spec
log.debug2("Validating service json")
if not validate_jtd(json_service_def, SERVICE_JTD_SCHEMA, EXTENDED_TYPE_VALIDATORS):
raise ValueError("Invalid service json")

method_descriptor_protos: List[descriptor_pb2.MethodDescriptorProto] = []
imports: List[str] = []

# And descriptor pool
if descriptor_pool is None:
log.debug2("Using the default descriptor pool")
descriptor_pool = _descriptor_pool.Default()

# First get the descriptor proto:
service_fd_proto = _json_to_service_file_descriptor_proto(
name, package, json_service_def, descriptor_pool=descriptor_pool
)
assert (
len(service_fd_proto.service) == 1
), f"File Descriptor {service_fd_proto.name} should only have one service"
service_descriptor_proto = service_fd_proto.service[0]

# Then put that in the pool to get the real descriptor back
log.debug("Adding Descriptors to DescriptorPool")
safe_add_fd_to_pool(service_fd_proto, descriptor_pool)
service_fullname = name if not package else ".".join([package, name])
service_descriptor = descriptor_pool.FindServiceByName(service_fullname)

# Then the client stub:
client_stub = _service_descriptor_to_client_stub(
service_descriptor, service_descriptor_proto
)

# And the registration function:
registration_function = _service_descriptor_to_server_registration_function(
service_descriptor, service_descriptor_proto
)

# And service class!
service_class = _service_descriptor_to_service(service_descriptor)

return GRPCService(
descriptor=service_descriptor,
service_class=service_class,
client_stub_class=client_stub,
registration_function=registration_function,
)


def _json_to_service_file_descriptor_proto(
name: str,
package: str,
json_service_def: ServiceJsonType,
*,
descriptor_pool: Optional[_descriptor_pool.DescriptorPool] = None,
) -> descriptor_pb2.FileDescriptorProto:
"""Creates the FileDescriptorProto for the service definition"""

method_descriptor_protos: List[descriptor_pb2.MethodDescriptorProto] = []
imports: List[str] = []

json_service = json_service_def["service"]
rpcs_def = json_service["rpcs"]
for rpc_def in rpcs_def:
Expand All @@ -106,6 +167,8 @@ def json_to_service(
name=rpc_def["name"],
input_type=input_descriptor.full_name,
output_type=output_descriptor.full_name,
client_streaming=rpc_def.get("client_streaming", False),
server_streaming=rpc_def.get("server_streaming", False),
)
)
imports.append(input_descriptor.file.name)
Expand All @@ -126,18 +189,11 @@ def json_to_service(
service=[service_descriptor_proto],
)

# Add the FileDescriptorProto to the Descriptor Pool
log.debug("Adding Descriptors to DescriptorPool")
safe_add_fd_to_pool(fd_proto, descriptor_pool)

# Return the descriptor for the top-level message
fullname = name if not package else ".".join([package, name])

return descriptor_pool.FindServiceByName(fullname)
return fd_proto


def service_descriptor_to_service(
service_descriptor: _descriptor.ServiceDescriptor,
def _service_descriptor_to_service(
service_descriptor: ServiceDescriptor,
) -> Type[service.Service]:
"""Create a service class from a service descriptor
Expand All @@ -147,8 +203,8 @@ def service_descriptor_to_service(
Returns:
Type[google.protobuf.service.Service]
A new class with metaclass google.protobuf.service_reflection.GeneratedServiceType containing the methods
from the service_descriptor
A new class with metaclass google.protobuf.service_reflection.GeneratedServiceType
containing the methods from the service_descriptor
"""
service_class = types.new_class(
service_descriptor.name,
Expand All @@ -161,28 +217,49 @@ def service_descriptor_to_service(
return service_class


def service_descriptor_to_client_stub(
service_descriptor: _descriptor.ServiceDescriptor,
def _service_descriptor_to_client_stub(
service_descriptor: ServiceDescriptor,
service_descriptor_proto: descriptor_pb2.ServiceDescriptorProto,
) -> Type:
"""Generates a new client stub class from the service descriptor
Args:
service_descriptor (google.protobuf.descriptor.ServiceDescriptor):
service_descriptor: google.protobuf.descriptor.ServiceDescriptor
The ServiceDescriptor to generate a service interface for
service_descriptor_proto: google.protobuf.descriptor_pb2.ServiceDescriptorProto
The descriptor proto for that service. This holds the I/O streaming information
for each method
"""
methods = _get_rpc_methods(service_descriptor)
_assert_method_lists_same(service_descriptor, service_descriptor_proto)

def _get_channel_func(
channel: grpc.Channel, method: descriptor_pb2.MethodDescriptorProto
) -> Callable:
if method.client_streaming and method.server_streaming:
return channel.stream_stream
if not method.client_streaming and method.server_streaming:
return channel.unary_stream
if method.client_streaming and not method.server_streaming:
return channel.stream_unary
return channel.unary_unary

# Initializer
def initializer(self, channel):
def initializer(self, channel: grpc.Channel):
f"""Initializes a client stub with for the {service_descriptor.name} Service"""
for method in methods:
for method, method_proto in zip(
service_descriptor.methods, service_descriptor_proto.method
):
setattr(
self,
method.name,
channel.unary_unary(
method.fullname,
request_serializer=method.input_message_class.SerializeToString,
response_deserializer=method.output_message_class.FromString,
_get_channel_func(channel, method_proto)(
_get_method_fullname(method),
request_serializer=descriptor_to_message_class(
method.input_type
).SerializeToString,
response_deserializer=descriptor_to_message_class(
method.output_type
).FromString,
),
)

Expand All @@ -196,29 +273,48 @@ def initializer(self, channel):
)


def service_descriptor_to_server_registration_function(
service_descriptor: _descriptor.ServiceDescriptor,
def _service_descriptor_to_server_registration_function(
service_descriptor: ServiceDescriptor,
service_descriptor_proto: descriptor_pb2.ServiceDescriptorProto,
) -> Callable[[Service, grpc.Server], None]:
"""Generates a server registration function from the service descriptor
Args:
service_descriptor: google.protobuf.descriptor.ServiceDescriptor
The ServiceDescriptor to generate a service interface for
service_descriptor_proto: google.protobuf.descriptor_pb2.ServiceDescriptorProto
The descriptor proto for that service. This holds the I/O streaming information
for each method
Returns:
function: Server registration function to add service handlers to a server
"""
methods = _get_rpc_methods(service_descriptor)
_assert_method_lists_same(service_descriptor, service_descriptor_proto)

def _get_handler(method: descriptor_pb2.MethodDescriptorProto):
if method.client_streaming and method.server_streaming:
return grpc.stream_stream_rpc_method_handler
if not method.client_streaming and method.server_streaming:
return grpc.unary_stream_rpc_method_handler
if method.client_streaming and not method.server_streaming:
return grpc.stream_unary_rpc_method_handler
return grpc.unary_unary_rpc_method_handler

def registration_function(servicer: Service, server: grpc.Server):
"""Server registration function"""
rpc_method_handlers = {
method.name: grpc.unary_unary_rpc_method_handler(
method.name: _get_handler(method_proto)(
getattr(servicer, method.name),
request_deserializer=method.input_message_class.FromString,
response_serializer=method.output_message_class.SerializeToString,
request_deserializer=descriptor_to_message_class(
method.input_type
).FromString,
response_serializer=descriptor_to_message_class(
method.output_type
).SerializeToString,
)
for method, method_proto in zip(
service_descriptor.methods, service_descriptor_proto.method
)
for method in methods
}
generic_handler = grpc.method_handlers_generic_handler(
service_descriptor.full_name, rpc_method_handlers
Expand All @@ -228,47 +324,20 @@ def registration_function(servicer: Service, server: grpc.Server):
return registration_function


@dataclasses.dataclass
class _RPCMethod:
name: str
fullname: str
input_message_class: Type[message.Message]
output_message_class: Type[message.Message]


def _get_rpc_methods(service_descriptor: ServiceDescriptor) -> List[_RPCMethod]:
"""Get list of RPC methods from a service descriptor
def _get_method_fullname(method: MethodDescriptor):
method_name_parts = method.full_name.split(".")
return f"/{'.'.join(method_name_parts[:-1])}/{method_name_parts[-1]}"

Args:
service_descriptor: google.protobuf.descriptor.ServiceDescriptor
The ServiceDescriptor to get RPC methods for
Returns:
List of RPC methods
"""
# For each method, need to know input / output message
methods: List[_RPCMethod] = []

for method in service_descriptor.methods:
method: _descriptor.MethodDescriptor

input_descriptor: _descriptor.Descriptor = method.input_type
output_descriptor: _descriptor.Descriptor = method.output_type

input_message_class = descriptor_to_message_class(input_descriptor)
output_message_class = descriptor_to_message_class(output_descriptor)

method_name_parts = method.full_name.split(".")
method_full_name = (
f"/{'.'.join(method_name_parts[:-1])}/{method_name_parts[-1]}"
)
methods.append(
_RPCMethod(
name=method.name,
fullname=method_full_name,
input_message_class=input_message_class,
output_message_class=output_message_class,
)
)
def _assert_method_lists_same(
service_descriptor: ServiceDescriptor,
service_descriptor_proto: descriptor_pb2.ServiceDescriptorProto,
):
assert len(service_descriptor.methods) == len(service_descriptor_proto.method), (
f"Method count mismatch: {service_descriptor.full_name} has"
f" {len(service_descriptor.methods)} methods but proto descriptor"
f" {service_descriptor_proto.name} has {len(service_descriptor_proto.method)} methods"
)

return methods
for m1, m2 in zip(service_descriptor.methods, service_descriptor_proto.method):
assert m1.name == m2.name, f"Method mismatch: {m1.name}, {m2.name}"
1 change: 1 addition & 0 deletions requirements_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pytest>=6.2.5
pytest-cov>=3.0.0
pytest-xdist>=2.5.0
tls_test_tools>=0.1.1

# Round-trip proto compilation
grpcio-tools>=1.46.3
4 changes: 2 additions & 2 deletions tests/test_descriptor_to_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def test_descriptor_to_file_service_descriptor(temp_dpool):
}
},
descriptor_pool=temp_dpool,
)
).descriptor
# TODO: type annotation fixup
res = descriptor_to_file(service_descriptor)
assert "service FooService {" in res
Expand Down Expand Up @@ -347,6 +347,6 @@ def test_descriptor_to_file_compilable_proto_with_service_descriptor(temp_dpool)
}
},
descriptor_pool=temp_dpool,
)
).descriptor
res = descriptor_to_file(service_descriptor)
assert compile_proto_module(res, imported_file_contents=imported_files)
Loading

0 comments on commit 9d874e7

Please sign in to comment.