Skip to content

Commit

Permalink
Other OpenAPI improvements (#2302)
Browse files Browse the repository at this point in the history
* Improve casing of operationId

* Add normalization of $ref, as those can be sent back from SDKs in many cases (e.g. for self referencing data structures)

* Revert "Improve casing of operationId"

This reverts commit e1da196
  • Loading branch information
slinkydeveloper authored Nov 15, 2024
1 parent bb522b7 commit 19d1aee
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 27 deletions.
145 changes: 119 additions & 26 deletions crates/types/src/schema/openapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ use restate_utoipa::openapi::path::{Operation, Parameter, ParameterIn};
use restate_utoipa::openapi::request_body::RequestBody;
use restate_utoipa::openapi::*;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::{json, Value};
use std::collections::HashMap;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceOpenAPI {
paths: Paths,
components: Components,
}

impl ServiceOpenAPI {
Expand All @@ -20,6 +21,8 @@ impl ServiceOpenAPI {
service_type: ServiceType,
handlers: &HashMap<String, HandlerSchemas>,
) -> Self {
let mut schemas_collector = Vec::new();

let root_path = if service_type.is_keyed() {
format!("/{service_name}/{{key}}/")
} else {
Expand All @@ -36,15 +39,17 @@ impl ServiceOpenAPI {

let mut paths = Paths::builder();
for (handler_name, handler_schemas) in handlers {
let operation_id = format!("{service_name}-{handler_name}");
let operation_id = handler_name;

if !handler_schemas.target_meta.public {
// We don't generate the OpenAPI route for that.
continue;
}

let request_body = infer_handler_request_body(handler_schemas);
let response = infer_handler_response(handler_schemas);
let request_body =
infer_handler_request_body(operation_id, handler_schemas, &mut schemas_collector);
let response =
infer_handler_response(operation_id, handler_schemas, &mut schemas_collector);

let call_item = PathItem::builder()
.summary(Some(format!("Call {service_name}/{handler_name}")))
Expand Down Expand Up @@ -78,7 +83,7 @@ impl ServiceOpenAPI {
.operation(
HttpMethod::Post,
Operation::builder()
.operation_id(Some(format!("{operation_id}-send")))
.operation_id(Some(format!("{operation_id}Send")))
.description(Some(
handler_schemas
.documentation
Expand All @@ -103,6 +108,12 @@ impl ServiceOpenAPI {

ServiceOpenAPI {
paths: paths.build(),
components: Components::builder()
.schemas_from_iter(schemas_collector.into_iter().map(|(schema_name, schema)| {
let ref_refix = format!("#/components/schemas/{schema_name}");
(schema_name, normalize_schema_refs(&ref_refix, schema))
}))
.build(),
}
}

Expand All @@ -112,7 +123,12 @@ impl ServiceOpenAPI {
service_name: &str,
documentation: Option<&str>,
revision: ServiceRevision,
) -> serde_json::Value {
) -> Value {
let mut components = restate_components();
components
.schemas
.append(&mut self.components.schemas.clone());

// TODO how to add servers?! :(
serde_json::to_value(
OpenApi::builder()
Expand All @@ -124,7 +140,7 @@ impl ServiceOpenAPI {
.build(),
)
.paths(self.paths.clone())
.components(Some(restate_components()))
.components(Some(components))
.build(),
)
.expect("Mapping OpenAPI to JSON should never fail")
Expand All @@ -135,6 +151,7 @@ impl ServiceOpenAPI {
pub fn empty() -> Self {
Self {
paths: Default::default(),
components: Default::default(),
}
}

Expand All @@ -143,7 +160,33 @@ impl ServiceOpenAPI {
}
}

fn infer_handler_request_body(handler_schemas: &HandlerSchemas) -> Option<RequestBody> {
fn request_schema_name(operation_id: &str) -> String {
format!("{operation_id}Request")
}

fn request_schema_ref(operation_id: &str) -> Ref {
Ref::new(format!(
"#/components/schemas/{}",
request_schema_name(operation_id)
))
}

fn response_schema_name(operation_id: &str) -> String {
format!("{operation_id}Response")
}

fn response_schema_ref(operation_id: &str) -> Ref {
Ref::new(format!(
"#/components/schemas/{}",
response_schema_name(operation_id)
))
}

fn infer_handler_request_body(
operation_id: &str,
handler_schemas: &HandlerSchemas,
schemas_collector: &mut Vec<(String, Schema)>,
) -> Option<RequestBody> {
let mut is_required = true;
if handler_schemas
.target_meta
Expand Down Expand Up @@ -179,10 +222,16 @@ fn infer_handler_request_body(handler_schemas: &HandlerSchemas) -> Option<Reques
InputValidationRule::JsonValue {
content_type,
schema,
} => Some((
content_type.to_string(),
Content::new(Some(Schema::new(schema.clone()))),
)),
} => {
let schema = Schema::new(schema.clone());

schemas_collector.push((request_schema_name(operation_id), schema));

Some((
content_type.to_string(),
Content::new(Some(request_schema_ref(operation_id))),
))
}
_ => unreachable!(),
}
} else {
Expand All @@ -205,7 +254,11 @@ fn infer_handler_request_body(handler_schemas: &HandlerSchemas) -> Option<Reques
}
}

fn infer_handler_response(handler_schemas: &HandlerSchemas) -> Response {
fn infer_handler_response(
operation_id: &str,
handler_schemas: &HandlerSchemas,
schemas_collector: &mut Vec<(String, Schema)>,
) -> Response {
match (
&handler_schemas.target_meta.output_rules.json_schema,
&handler_schemas.target_meta.output_rules.content_type_rule,
Expand All @@ -219,14 +272,19 @@ fn infer_handler_response(handler_schemas: &HandlerSchemas) -> Response {
Content::builder().build(),
)
.build(),
(Some(schema), OutputContentTypeRule::Set { content_type, .. }) => Response::builder()
.content(
content_type
.to_str()
.expect("content_type should have been checked before during registration"),
Content::new(Some(Schema::new(schema.clone()))),
)
.build(),
(Some(schema), OutputContentTypeRule::Set { content_type, .. }) => {
let schema = Schema::new(schema.clone());
schemas_collector.push((response_schema_name(operation_id), schema));

Response::builder()
.content(
content_type
.to_str()
.expect("content_type should have been checked before during registration"),
Content::new(Some(response_schema_ref(operation_id))),
)
.build()
}
}
}

Expand Down Expand Up @@ -256,7 +314,7 @@ fn delay_parameter() -> Parameter {
.schema(Some(
string_json_schema()
))
.example(Some(serde_json::Value::String("10s".to_string())))
.example(Some(Value::String("10s".to_string())))
.required(Required::False)
.description(Some("Specify the delay to execute the operation, for more info check the [delay documentation](https://docs.restate.dev/invoke/http#sending-a-delayed-message-over-http)"))
.build()
Expand Down Expand Up @@ -307,7 +365,7 @@ fn error_response() -> Response {
}

// Ideally we code generate this
fn error_response_json_schema() -> serde_json::Value {
fn error_response_json_schema() -> Value {
json!({
"type": "object",
"title": "Error",
Expand All @@ -332,7 +390,7 @@ fn error_response_json_schema() -> serde_json::Value {
}

// Ideally we code generate this
fn error_response_example() -> serde_json::Value {
fn error_response_example() -> Value {
json!({
"code": 500,
"message": "Internal server error",
Expand All @@ -356,7 +414,7 @@ fn send_response() -> Response {
}

// Ideally we code generate this
fn send_response_json_schema() -> serde_json::Value {
fn send_response_json_schema() -> Value {
json!({
"type": "object",
"properties": {
Expand All @@ -379,7 +437,7 @@ fn send_response_json_schema() -> serde_json::Value {
}

// Ideally we code generate this
fn send_response_example() -> serde_json::Value {
fn send_response_example() -> Value {
json!({
"invocationId": "inv_1gdJBtdVEcM942bjcDmb1c1khoaJe11Hbz",
"status": "Accepted",
Expand All @@ -389,3 +447,38 @@ fn send_response_example() -> serde_json::Value {
fn string_json_schema() -> Schema {
Schema::new(json!({"type": "string"}))
}

// We need to normalize all the $refs in the schema to append the ref_prefix
fn normalize_schema_refs(ref_prefix: &str, mut schema: Schema) -> Schema {
normalize_schema_refs_inner(ref_prefix, &mut schema.0);
schema
}

fn normalize_schema_refs_inner(ref_prefix: &str, schema: &mut Value) {
match schema {
Value::Array(array_value) => {
for val in array_value.iter_mut() {
normalize_schema_refs_inner(ref_prefix, val)
}
}
Value::Object(obj_value) => {
// Replace $ref attribute, if existing and starts with #
if let Some(ref_value) = obj_value.get_mut("$ref") {
if let Some(str_ref_value) = ref_value.as_str() {
// Local refs always start with #
if str_ref_value.starts_with('#') {
*ref_value = Value::String(format!(
"{ref_prefix}{}",
str_ref_value.trim_start_matches('#')
));
}
}
}

for val in obj_value.values_mut() {
normalize_schema_refs_inner(ref_prefix, val)
}
}
_ => {}
}
}
2 changes: 1 addition & 1 deletion crates/utoipa/src/openapi/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ impl ComponentsBuilder {
#[derive(Serialize, Deserialize, Clone, PartialEq)]
#[cfg_attr(feature = "debug", derive(Debug))]
#[serde(transparent)]
pub struct Schema(serde_json::Value);
pub struct Schema(pub serde_json::Value);

impl Default for Schema {
fn default() -> Self {
Expand Down

0 comments on commit 19d1aee

Please sign in to comment.