Skip to content

Commit

Permalink
Implement Ingress service
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper committed Jun 19, 2023
1 parent 1c44a47 commit cd69ef0
Show file tree
Hide file tree
Showing 13 changed files with 454 additions and 96 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/ingress_grpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ drain = { workspace = true }
arc-swap = { workspace = true }
thiserror = { workspace = true }
schemars = { workspace = true, optional = true }
once_cell = "1.18"

[dev-dependencies]
hyper = { workspace = true, features = ["client"] }
Expand Down
7 changes: 6 additions & 1 deletion src/ingress_grpc/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@ fn main() -> std::io::Result<()> {
.compile_protos(
&[
"proto/grpc/reflection/v1alpha/reflection.proto",
"proto/dev/restate/services.proto",
"tests/proto/greeter.proto",
],
&["proto/grpc/reflection/v1alpha", "tests/proto"],
&[
"proto/grpc/reflection/v1alpha",
"proto/dev/restate",
"tests/proto",
],
)?;
Ok(())
}
71 changes: 49 additions & 22 deletions src/ingress_grpc/src/handler.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use super::options::JsonOptions;
use super::pb::grpc::reflection::{
server_reflection_server::ServerReflection, server_reflection_server::ServerReflectionServer,
};
use super::protocol::{BoxBody, Protocol};
use super::reflection::{ServerReflection, ServerReflectionServer};
use super::*;

use std::sync::Arc;
Expand All @@ -12,6 +14,7 @@ use http::{Request, Response, StatusCode};
use http_body::Body;
use hyper::Body as HyperBody;
use opentelemetry::trace::{SpanContext, TraceContextExt};
use prost::Message;
use restate_common::types::{IngressId, ServiceInvocationResponseSink, SpanRelation};
use restate_service_metadata::MethodDescriptorRegistry;
use tokio::sync::Semaphore;
Expand Down Expand Up @@ -87,7 +90,7 @@ impl<InvocationFactory, MethodRegistry, ReflectionService> Service<Request<Hyper
for Handler<InvocationFactory, MethodRegistry, ReflectionService>
where
InvocationFactory: ServiceInvocationFactory + Clone + Send + 'static,
MethodRegistry: MethodDescriptorRegistry,
MethodRegistry: MethodDescriptorRegistry + Clone + Send + 'static,
ReflectionService: ServerReflection,
{
type Response = Response<BoxBody>;
Expand Down Expand Up @@ -144,6 +147,8 @@ where
let method_name = path_parts.remove(2).to_string();
let service_name = path_parts.remove(1).to_string();

// --- Special Restate services
// Reflections
if ServerReflectionServer::<ReflectionService>::NAME == service_name {
return self
.reflection_server
Expand All @@ -155,20 +160,6 @@ where
.boxed();
}

// Find the service method descriptor
let descriptor = if let Some(desc) = self
.method_registry
.resolve_method_descriptor(&service_name, &method_name)
{
desc
} else {
debug!("{}/{} not found", service_name, method_name);
return ok(protocol.encode_status(Status::not_found(format!(
"{service_name}/{method_name} not found"
))))
.boxed();
};

// Encapsulate in this closure the remaining part of the processing
let ingress_id = self.ingress_id;
let invocation_factory = self.invocation_factory.clone();
Expand Down Expand Up @@ -197,12 +188,30 @@ where
let ingress_span_context = ingress_span.context().span().span_context().clone();

async move {
let mut service_name = req_headers.service_name;
let mut method_name = req_headers.method_name;
let mut req_payload = req_payload;
let mut response_sink = Some(ServiceInvocationResponseSink::Ingress(ingress_id));
let mut wait_response = true;

// Ingress built-in service
if is_ingress_invoke(&service_name, &method_name) {
let invoke_request = pb::restate::services::InvokeRequest::decode(req_payload)
.map_err(|e| Status::invalid_argument(e.to_string()))?;

service_name = invoke_request.service;
method_name = invoke_request.method;
req_payload = invoke_request.argument;
response_sink = None;
wait_response = false;
}

// Create the service_invocation
let (service_invocation, service_invocation_span) = match invocation_factory.create(
&req_headers.service_name,
&req_headers.method_name,
&service_name,
&method_name,
req_payload,
Some(ServiceInvocationResponseSink::Ingress(ingress_id)),
response_sink,
SpanRelation::Parent(ingress_span_context)
) {
Ok(i) => i,
Expand All @@ -222,8 +231,22 @@ where
// https://docs.rs/tracing/latest/tracing/struct.Span.html#in-asynchronous-code
let enter_service_invocation_span = service_invocation_span.enter();

// More trace info
trace!(restate.invocation.request_headers = ?req_headers);
// Ingress built-in service just sends a fire and forget and closes
if !wait_response {
let sid = service_invocation.id.to_string();

if dispatcher_command_sender.send(Command::fire_and_forget(
service_invocation
)).is_err() {
debug!("Ingress dispatcher is closed while there is still an invocation in flight.");
return Err(Status::unavailable("Unavailable"));
}
return Ok(
pb::restate::services::InvokeResponse {
sid,
}.encode_to_vec().into()
)
}

// Send the service invocation
let (service_invocation_command, response_rx) =
Expand Down Expand Up @@ -259,7 +282,7 @@ where
let result_fut = protocol.handle_request(
service_name,
method_name,
descriptor,
self.method_registry.clone(),
self.json.clone(),
req,
ingress_request_handler,
Expand All @@ -283,3 +306,7 @@ fn span_relation(request_span: &SpanContext) -> SpanRelation {
SpanRelation::None
}
}

fn is_ingress_invoke(service_name: &str, method_name: &str) -> bool {
"dev.restate.Ingress" == service_name && "Invoke" == method_name
}
16 changes: 16 additions & 0 deletions src/ingress_grpc/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
mod dispatcher;
mod handler;
mod options;
mod pb;
mod protocol;
mod reflection;
mod server;

pub use dispatcher::{IngressDispatcherLoop, IngressDispatcherLoopError};
pub use options::Options;
pub use pb::MethodDescriptorRegistryWithIngressService;
pub use reflection::{ReflectionRegistry, RegistrationError};
pub use server::{HyperServerIngress, IngressServerError, StartSignal};

Expand Down Expand Up @@ -223,9 +225,23 @@ mod mocks {
pub(super) fn test_descriptor_registry() -> InMemoryMethodDescriptorRegistry {
let registry = InMemoryMethodDescriptorRegistry::default();
registry.register(greeter_service_descriptor());
registry.register(ingress_service_descriptor());
registry
}

pub(super) fn ingress_service_descriptor() -> ServiceDescriptor {
crate::pb::DEV_RESTATE_DESCRIPTOR_POOL
.get_service_by_name("dev.restate.Ingress")
.unwrap()
}

pub(super) fn ingress_invoke_method_descriptor() -> MethodDescriptor {
ingress_service_descriptor()
.methods()
.find(|m| m.name() == "Invoke")
.unwrap()
}

pub(super) fn greeter_service_descriptor() -> ServiceDescriptor {
test_descriptor_pool()
.services()
Expand Down
11 changes: 8 additions & 3 deletions src/ingress_grpc/src/options.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use super::reflection::ServerReflection;
use super::pb::grpc::reflection::server_reflection_server::ServerReflection;
use super::HyperServerIngress;
use super::*;

use crate::pb::MethodDescriptorRegistryWithIngressService;
use prost_reflect::{DeserializeOptions, SerializeOptions};
use restate_common::types::IngressId;
use restate_service_metadata::MethodDescriptorRegistry;
Expand Down Expand Up @@ -169,7 +170,11 @@ impl Options {
channel_size: usize,
) -> (
IngressDispatcherLoop,
HyperServerIngress<DescriptorRegistry, InvocationFactory, ReflectionService>,
HyperServerIngress<
MethodDescriptorRegistryWithIngressService<DescriptorRegistry>,
InvocationFactory,
ReflectionService,
>,
)
where
DescriptorRegistry: MethodDescriptorRegistry + Clone + Send + 'static,
Expand All @@ -189,7 +194,7 @@ impl Options {
concurrency_limit,
json,
ingress_id,
descriptor_registry,
MethodDescriptorRegistryWithIngressService::new(descriptor_registry),
invocation_factory,
reflection_service,
ingress_dispatcher_loop.create_command_sender(),
Expand Down
71 changes: 71 additions & 0 deletions src/ingress_grpc/src/pb.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use once_cell::sync::Lazy;
use prost_reflect::{DescriptorPool, MethodDescriptor};
use restate_service_metadata::MethodDescriptorRegistry;
use std::collections::HashMap;
use std::convert::AsRef;

pub(crate) mod grpc {
pub(crate) mod reflection {
#![allow(warnings)]
#![allow(clippy::all)]
#![allow(unknown_lints)]
include!(concat!(env!("OUT_DIR"), "/grpc.reflection.v1alpha.rs"));
}
}
pub(crate) mod restate {
pub(crate) mod services {
#![allow(warnings)]
#![allow(clippy::all)]
#![allow(unknown_lints)]
include!(concat!(env!("OUT_DIR"), "/dev.restate.rs"));
}
}

pub(crate) static DEV_RESTATE_DESCRIPTOR_POOL: Lazy<DescriptorPool> = Lazy::new(|| {
DescriptorPool::decode(
include_bytes!(concat!(env!("OUT_DIR"), "/file_descriptor_set.bin")).as_ref(),
)
.expect("The built-in descriptor pool should be valid")
});

// TODO this is a temporary solution until we have a schema registry where we can distinguish between ingress only services
// see https://github.com/restatedev/restate/issues/43#issuecomment-1597174972
#[derive(Clone)]
pub struct MethodDescriptorRegistryWithIngressService<MDR> {
method_descriptor_registry: MDR,
}

impl<MDR> MethodDescriptorRegistryWithIngressService<MDR> {
pub(crate) fn new(method_descriptor_registry: MDR) -> Self {
Self {
method_descriptor_registry,
}
}
}

impl<MDR: MethodDescriptorRegistry> MethodDescriptorRegistry
for MethodDescriptorRegistryWithIngressService<MDR>
{
fn resolve_method_descriptor(
&self,
svc_name: &str,
method_name: &str,
) -> Option<MethodDescriptor> {
if svc_name.starts_with("dev.restate") {
return DEV_RESTATE_DESCRIPTOR_POOL
.get_service_by_name(svc_name)
.and_then(|s| s.methods().find(|m| m.name() == method_name));
}
self.method_descriptor_registry
.resolve_method_descriptor(svc_name, method_name)
}

fn list_methods(&self, svc_name: &str) -> Option<HashMap<String, MethodDescriptor>> {
if svc_name.starts_with("dev.restate") {
return DEV_RESTATE_DESCRIPTOR_POOL
.get_service_by_name(svc_name)
.map(|s| s.methods().map(|m| (m.name().to_string(), m)).collect());
}
self.method_descriptor_registry.list_methods(svc_name)
}
}
Loading

0 comments on commit cd69ef0

Please sign in to comment.