diff --git a/crates/types/src/schema/openapi.rs b/crates/types/src/schema/openapi.rs index e8b461114..c1b00f59b 100644 --- a/crates/types/src/schema/openapi.rs +++ b/crates/types/src/schema/openapi.rs @@ -6,12 +6,13 @@ use restate_utoipa::openapi::path::{Operation, Parameter, ParameterIn}; use restate_utoipa::openapi::request_body::RequestBody; use restate_utoipa::openapi::*; use serde::{Deserialize, Serialize}; -use serde_json::json; +use serde_json::{json, Value}; use std::collections::HashMap; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ServiceOpenAPI { paths: Paths, + components: Components, } impl ServiceOpenAPI { @@ -20,6 +21,8 @@ impl ServiceOpenAPI { service_type: ServiceType, handlers: &HashMap, ) -> Self { + let mut schemas_collector = Vec::new(); + let root_path = if service_type.is_keyed() { format!("/{service_name}/{{key}}/") } else { @@ -36,15 +39,17 @@ impl ServiceOpenAPI { let mut paths = Paths::builder(); for (handler_name, handler_schemas) in handlers { - let operation_id = format!("{service_name}-{handler_name}"); + let operation_id = handler_name; if !handler_schemas.target_meta.public { // We don't generate the OpenAPI route for that. continue; } - let request_body = infer_handler_request_body(handler_schemas); - let response = infer_handler_response(handler_schemas); + let request_body = + infer_handler_request_body(operation_id, handler_schemas, &mut schemas_collector); + let response = + infer_handler_response(operation_id, handler_schemas, &mut schemas_collector); let call_item = PathItem::builder() .summary(Some(format!("Call {service_name}/{handler_name}"))) @@ -78,7 +83,7 @@ impl ServiceOpenAPI { .operation( HttpMethod::Post, Operation::builder() - .operation_id(Some(format!("{operation_id}-send"))) + .operation_id(Some(format!("{operation_id}Send"))) .description(Some( handler_schemas .documentation @@ -103,6 +108,12 @@ impl ServiceOpenAPI { ServiceOpenAPI { paths: paths.build(), + components: Components::builder() + .schemas_from_iter(schemas_collector.into_iter().map(|(schema_name, schema)| { + let ref_refix = format!("#/components/schemas/{schema_name}"); + (schema_name, normalize_schema_refs(&ref_refix, schema)) + })) + .build(), } } @@ -112,7 +123,12 @@ impl ServiceOpenAPI { service_name: &str, documentation: Option<&str>, revision: ServiceRevision, - ) -> serde_json::Value { + ) -> Value { + let mut components = restate_components(); + components + .schemas + .append(&mut self.components.schemas.clone()); + // TODO how to add servers?! :( serde_json::to_value( OpenApi::builder() @@ -124,7 +140,7 @@ impl ServiceOpenAPI { .build(), ) .paths(self.paths.clone()) - .components(Some(restate_components())) + .components(Some(components)) .build(), ) .expect("Mapping OpenAPI to JSON should never fail") @@ -135,6 +151,7 @@ impl ServiceOpenAPI { pub fn empty() -> Self { Self { paths: Default::default(), + components: Default::default(), } } @@ -143,7 +160,33 @@ impl ServiceOpenAPI { } } -fn infer_handler_request_body(handler_schemas: &HandlerSchemas) -> Option { +fn request_schema_name(operation_id: &str) -> String { + format!("{operation_id}Request") +} + +fn request_schema_ref(operation_id: &str) -> Ref { + Ref::new(format!( + "#/components/schemas/{}", + request_schema_name(operation_id) + )) +} + +fn response_schema_name(operation_id: &str) -> String { + format!("{operation_id}Response") +} + +fn response_schema_ref(operation_id: &str) -> Ref { + Ref::new(format!( + "#/components/schemas/{}", + response_schema_name(operation_id) + )) +} + +fn infer_handler_request_body( + operation_id: &str, + handler_schemas: &HandlerSchemas, + schemas_collector: &mut Vec<(String, Schema)>, +) -> Option { let mut is_required = true; if handler_schemas .target_meta @@ -179,10 +222,16 @@ fn infer_handler_request_body(handler_schemas: &HandlerSchemas) -> Option Some(( - content_type.to_string(), - Content::new(Some(Schema::new(schema.clone()))), - )), + } => { + let schema = Schema::new(schema.clone()); + + schemas_collector.push((request_schema_name(operation_id), schema)); + + Some(( + content_type.to_string(), + Content::new(Some(request_schema_ref(operation_id))), + )) + } _ => unreachable!(), } } else { @@ -205,7 +254,11 @@ fn infer_handler_request_body(handler_schemas: &HandlerSchemas) -> Option Response { +fn infer_handler_response( + operation_id: &str, + handler_schemas: &HandlerSchemas, + schemas_collector: &mut Vec<(String, Schema)>, +) -> Response { match ( &handler_schemas.target_meta.output_rules.json_schema, &handler_schemas.target_meta.output_rules.content_type_rule, @@ -219,14 +272,19 @@ fn infer_handler_response(handler_schemas: &HandlerSchemas) -> Response { Content::builder().build(), ) .build(), - (Some(schema), OutputContentTypeRule::Set { content_type, .. }) => Response::builder() - .content( - content_type - .to_str() - .expect("content_type should have been checked before during registration"), - Content::new(Some(Schema::new(schema.clone()))), - ) - .build(), + (Some(schema), OutputContentTypeRule::Set { content_type, .. }) => { + let schema = Schema::new(schema.clone()); + schemas_collector.push((response_schema_name(operation_id), schema)); + + Response::builder() + .content( + content_type + .to_str() + .expect("content_type should have been checked before during registration"), + Content::new(Some(response_schema_ref(operation_id))), + ) + .build() + } } } @@ -256,7 +314,7 @@ fn delay_parameter() -> Parameter { .schema(Some( string_json_schema() )) - .example(Some(serde_json::Value::String("10s".to_string()))) + .example(Some(Value::String("10s".to_string()))) .required(Required::False) .description(Some("Specify the delay to execute the operation, for more info check the [delay documentation](https://docs.restate.dev/invoke/http#sending-a-delayed-message-over-http)")) .build() @@ -307,7 +365,7 @@ fn error_response() -> Response { } // Ideally we code generate this -fn error_response_json_schema() -> serde_json::Value { +fn error_response_json_schema() -> Value { json!({ "type": "object", "title": "Error", @@ -332,7 +390,7 @@ fn error_response_json_schema() -> serde_json::Value { } // Ideally we code generate this -fn error_response_example() -> serde_json::Value { +fn error_response_example() -> Value { json!({ "code": 500, "message": "Internal server error", @@ -356,7 +414,7 @@ fn send_response() -> Response { } // Ideally we code generate this -fn send_response_json_schema() -> serde_json::Value { +fn send_response_json_schema() -> Value { json!({ "type": "object", "properties": { @@ -379,7 +437,7 @@ fn send_response_json_schema() -> serde_json::Value { } // Ideally we code generate this -fn send_response_example() -> serde_json::Value { +fn send_response_example() -> Value { json!({ "invocationId": "inv_1gdJBtdVEcM942bjcDmb1c1khoaJe11Hbz", "status": "Accepted", @@ -389,3 +447,38 @@ fn send_response_example() -> serde_json::Value { fn string_json_schema() -> Schema { Schema::new(json!({"type": "string"})) } + +// We need to normalize all the $refs in the schema to append the ref_prefix +fn normalize_schema_refs(ref_prefix: &str, mut schema: Schema) -> Schema { + normalize_schema_refs_inner(ref_prefix, &mut schema.0); + schema +} + +fn normalize_schema_refs_inner(ref_prefix: &str, schema: &mut Value) { + match schema { + Value::Array(array_value) => { + for val in array_value.iter_mut() { + normalize_schema_refs_inner(ref_prefix, val) + } + } + Value::Object(obj_value) => { + // Replace $ref attribute, if existing and starts with # + if let Some(ref_value) = obj_value.get_mut("$ref") { + if let Some(str_ref_value) = ref_value.as_str() { + // Local refs always start with # + if str_ref_value.starts_with('#') { + *ref_value = Value::String(format!( + "{ref_prefix}{}", + str_ref_value.trim_start_matches('#') + )); + } + } + } + + for val in obj_value.values_mut() { + normalize_schema_refs_inner(ref_prefix, val) + } + } + _ => {} + } +} diff --git a/crates/utoipa/src/openapi/schema.rs b/crates/utoipa/src/openapi/schema.rs index 459c10557..a6a7118d4 100644 --- a/crates/utoipa/src/openapi/schema.rs +++ b/crates/utoipa/src/openapi/schema.rs @@ -215,7 +215,7 @@ impl ComponentsBuilder { #[derive(Serialize, Deserialize, Clone, PartialEq)] #[cfg_attr(feature = "debug", derive(Debug))] #[serde(transparent)] -pub struct Schema(serde_json::Value); +pub struct Schema(pub serde_json::Value); impl Default for Schema { fn default() -> Self {