From de73baf614d4a7128d71e7dafd4d6341dccaabae Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Mon, 8 Jul 2024 16:21:23 +0200 Subject: [PATCH] Flatten useless module hierarchy --- tonic/src/transport/mod.rs | 2 +- tonic/src/transport/service/grpc_timeout.rs | 287 ------------------- tonic/src/transport/service/mod.rs | 288 +++++++++++++++++++- 3 files changed, 287 insertions(+), 290 deletions(-) delete mode 100644 tonic/src/transport/service/grpc_timeout.rs diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index 254e5ef8d..52a710e43 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -107,7 +107,7 @@ pub use self::error::Error; #[cfg(feature = "server")] pub use self::server::Server; #[doc(inline)] -pub use self::service::grpc_timeout::TimeoutExpired; +pub use self::service::TimeoutExpired; #[cfg(feature = "tls")] pub use self::tls::Certificate; diff --git a/tonic/src/transport/service/grpc_timeout.rs b/tonic/src/transport/service/grpc_timeout.rs deleted file mode 100644 index 1858cdf5f..000000000 --- a/tonic/src/transport/service/grpc_timeout.rs +++ /dev/null @@ -1,287 +0,0 @@ -use crate::metadata::GRPC_TIMEOUT_HEADER; -use http::{HeaderMap, HeaderValue, Request}; -use pin_project::pin_project; -use std::{ - fmt, - future::Future, - pin::Pin, - task::{ready, Context, Poll}, - time::Duration, -}; -use tokio::time::Sleep; -use tower_service::Service; - -#[derive(Debug, Clone)] -pub(crate) struct GrpcTimeout { - inner: S, - server_timeout: Option, -} - -impl GrpcTimeout { - pub(crate) fn new(inner: S, server_timeout: Option) -> Self { - Self { - inner, - server_timeout, - } - } -} - -impl Service> for GrpcTimeout -where - S: Service>, - S::Error: Into, -{ - type Response = S::Response; - type Error = crate::Error; - type Future = ResponseFuture; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx).map_err(Into::into) - } - - fn call(&mut self, req: Request) -> Self::Future { - let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| { - tracing::trace!("Error parsing `grpc-timeout` header {:?}", e); - None - }); - - // Use the shorter of the two durations, if either are set - let timeout_duration = match (client_timeout, self.server_timeout) { - (None, None) => None, - (Some(dur), None) => Some(dur), - (None, Some(dur)) => Some(dur), - (Some(header), Some(server)) => { - let shorter_duration = std::cmp::min(header, server); - Some(shorter_duration) - } - }; - - ResponseFuture { - inner: self.inner.call(req), - sleep: timeout_duration - .map(tokio::time::sleep) - .map(Some) - .unwrap_or(None), - } - } -} - -#[pin_project] -pub(crate) struct ResponseFuture { - #[pin] - inner: F, - #[pin] - sleep: Option, -} - -impl Future for ResponseFuture -where - F: Future>, - E: Into, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - if let Poll::Ready(result) = this.inner.poll(cx) { - return Poll::Ready(result.map_err(Into::into)); - } - - if let Some(sleep) = this.sleep.as_pin_mut() { - ready!(sleep.poll(cx)); - return Poll::Ready(Err(TimeoutExpired(()).into())); - } - - Poll::Pending - } -} - -const SECONDS_IN_HOUR: u64 = 60 * 60; -const SECONDS_IN_MINUTE: u64 = 60; - -/// Tries to parse the `grpc-timeout` header if it is present. If we fail to parse, returns -/// the value we attempted to parse. -/// -/// Follows the [gRPC over HTTP2 spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md). -fn try_parse_grpc_timeout( - headers: &HeaderMap, -) -> Result, &HeaderValue> { - match headers.get(GRPC_TIMEOUT_HEADER) { - Some(val) => { - let (timeout_value, timeout_unit) = val - .to_str() - .map_err(|_| val) - .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })? - // `HeaderValue::to_str` only returns `Ok` if the header contains ASCII so this - // `split_at` will never panic from trying to split in the middle of a character. - // See https://docs.rs/http/0.2.4/http/header/struct.HeaderValue.html#method.to_str - // - // `len - 1` also wont panic since we just checked `s.is_empty`. - .split_at(val.len() - 1); - - // gRPC spec specifies `TimeoutValue` will be at most 8 digits - // Caping this at 8 digits also prevents integer overflow from ever occurring - if timeout_value.len() > 8 { - return Err(val); - } - - let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?; - - let duration = match timeout_unit { - // Hours - "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR), - // Minutes - "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE), - // Seconds - "S" => Duration::from_secs(timeout_value), - // Milliseconds - "m" => Duration::from_millis(timeout_value), - // Microseconds - "u" => Duration::from_micros(timeout_value), - // Nanoseconds - "n" => Duration::from_nanos(timeout_value), - _ => return Err(val), - }; - - Ok(Some(duration)) - } - None => Ok(None), - } -} - -/// Error returned if a request didn't complete within the configured timeout. -/// -/// Timeouts can be configured either with [`Endpoint::timeout`], [`Server::timeout`], or by -/// setting the [`grpc-timeout` metadata value][spec]. -/// -/// [`Endpoint::timeout`]: crate::transport::server::Server::timeout -/// [`Server::timeout`]: crate::transport::channel::Endpoint::timeout -/// [spec]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md -#[derive(Debug)] -pub struct TimeoutExpired(()); - -impl fmt::Display for TimeoutExpired { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Timeout expired") - } -} - -// std::error::Error only requires a type to impl Debug and Display -impl std::error::Error for TimeoutExpired {} - -#[cfg(test)] -mod tests { - use super::*; - use quickcheck::{Arbitrary, Gen}; - use quickcheck_macros::quickcheck; - - // Helper function to reduce the boiler plate of our test cases - fn setup_map_try_parse(val: Option<&str>) -> Result, HeaderValue> { - let mut hm = HeaderMap::new(); - if let Some(v) = val { - let hv = HeaderValue::from_str(v).unwrap(); - hm.insert(GRPC_TIMEOUT_HEADER, hv); - }; - - try_parse_grpc_timeout(&hm).map_err(|e| e.clone()) - } - - #[test] - fn test_hours() { - let parsed_duration = setup_map_try_parse(Some("3H")).unwrap().unwrap(); - assert_eq!(Duration::from_secs(3 * 60 * 60), parsed_duration); - } - - #[test] - fn test_minutes() { - let parsed_duration = setup_map_try_parse(Some("1M")).unwrap().unwrap(); - assert_eq!(Duration::from_secs(60), parsed_duration); - } - - #[test] - fn test_seconds() { - let parsed_duration = setup_map_try_parse(Some("42S")).unwrap().unwrap(); - assert_eq!(Duration::from_secs(42), parsed_duration); - } - - #[test] - fn test_milliseconds() { - let parsed_duration = setup_map_try_parse(Some("13m")).unwrap().unwrap(); - assert_eq!(Duration::from_millis(13), parsed_duration); - } - - #[test] - fn test_microseconds() { - let parsed_duration = setup_map_try_parse(Some("2u")).unwrap().unwrap(); - assert_eq!(Duration::from_micros(2), parsed_duration); - } - - #[test] - fn test_nanoseconds() { - let parsed_duration = setup_map_try_parse(Some("82n")).unwrap().unwrap(); - assert_eq!(Duration::from_nanos(82), parsed_duration); - } - - #[test] - fn test_header_not_present() { - let parsed_duration = setup_map_try_parse(None).unwrap(); - assert!(parsed_duration.is_none()); - } - - #[test] - #[should_panic(expected = "82f")] - fn test_invalid_unit() { - // "f" is not a valid TimeoutUnit - setup_map_try_parse(Some("82f")).unwrap().unwrap(); - } - - #[test] - #[should_panic(expected = "123456789H")] - fn test_too_many_digits() { - // gRPC spec states TimeoutValue will be at most 8 digits - setup_map_try_parse(Some("123456789H")).unwrap().unwrap(); - } - - #[test] - #[should_panic(expected = "oneH")] - fn test_invalid_digits() { - // gRPC spec states TimeoutValue will be at most 8 digits - setup_map_try_parse(Some("oneH")).unwrap().unwrap(); - } - - #[quickcheck] - fn fuzz(header_value: HeaderValueGen) -> bool { - let header_value = header_value.0; - - // this just shouldn't panic - let _ = setup_map_try_parse(Some(&header_value)); - - true - } - - /// Newtype to implement `Arbitrary` for generating `String`s that are valid `HeaderValue`s. - #[derive(Clone, Debug)] - struct HeaderValueGen(String); - - impl Arbitrary for HeaderValueGen { - fn arbitrary(g: &mut Gen) -> Self { - let max = g.choose(&(1..70).collect::>()).copied().unwrap(); - Self(gen_string(g, 0, max)) - } - } - - // copied from https://github.com/hyperium/http/blob/master/tests/header_map_fuzz.rs - fn gen_string(g: &mut Gen, min: usize, max: usize) -> String { - let bytes: Vec<_> = (min..max) - .map(|_| { - // Chars to pick from - g.choose(b"ABCDEFGHIJKLMNOPQRSTUVabcdefghilpqrstuvwxyz----") - .copied() - .unwrap() - }) - .collect(); - - String::from_utf8(bytes).unwrap() - } -} diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index f70d445cb..1858cdf5f 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -1,3 +1,287 @@ -pub(crate) mod grpc_timeout; +use crate::metadata::GRPC_TIMEOUT_HEADER; +use http::{HeaderMap, HeaderValue, Request}; +use pin_project::pin_project; +use std::{ + fmt, + future::Future, + pin::Pin, + task::{ready, Context, Poll}, + time::Duration, +}; +use tokio::time::Sleep; +use tower_service::Service; -pub(crate) use self::grpc_timeout::GrpcTimeout; +#[derive(Debug, Clone)] +pub(crate) struct GrpcTimeout { + inner: S, + server_timeout: Option, +} + +impl GrpcTimeout { + pub(crate) fn new(inner: S, server_timeout: Option) -> Self { + Self { + inner, + server_timeout, + } + } +} + +impl Service> for GrpcTimeout +where + S: Service>, + S::Error: Into, +{ + type Response = S::Response; + type Error = crate::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: Request) -> Self::Future { + let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| { + tracing::trace!("Error parsing `grpc-timeout` header {:?}", e); + None + }); + + // Use the shorter of the two durations, if either are set + let timeout_duration = match (client_timeout, self.server_timeout) { + (None, None) => None, + (Some(dur), None) => Some(dur), + (None, Some(dur)) => Some(dur), + (Some(header), Some(server)) => { + let shorter_duration = std::cmp::min(header, server); + Some(shorter_duration) + } + }; + + ResponseFuture { + inner: self.inner.call(req), + sleep: timeout_duration + .map(tokio::time::sleep) + .map(Some) + .unwrap_or(None), + } + } +} + +#[pin_project] +pub(crate) struct ResponseFuture { + #[pin] + inner: F, + #[pin] + sleep: Option, +} + +impl Future for ResponseFuture +where + F: Future>, + E: Into, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + if let Poll::Ready(result) = this.inner.poll(cx) { + return Poll::Ready(result.map_err(Into::into)); + } + + if let Some(sleep) = this.sleep.as_pin_mut() { + ready!(sleep.poll(cx)); + return Poll::Ready(Err(TimeoutExpired(()).into())); + } + + Poll::Pending + } +} + +const SECONDS_IN_HOUR: u64 = 60 * 60; +const SECONDS_IN_MINUTE: u64 = 60; + +/// Tries to parse the `grpc-timeout` header if it is present. If we fail to parse, returns +/// the value we attempted to parse. +/// +/// Follows the [gRPC over HTTP2 spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md). +fn try_parse_grpc_timeout( + headers: &HeaderMap, +) -> Result, &HeaderValue> { + match headers.get(GRPC_TIMEOUT_HEADER) { + Some(val) => { + let (timeout_value, timeout_unit) = val + .to_str() + .map_err(|_| val) + .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })? + // `HeaderValue::to_str` only returns `Ok` if the header contains ASCII so this + // `split_at` will never panic from trying to split in the middle of a character. + // See https://docs.rs/http/0.2.4/http/header/struct.HeaderValue.html#method.to_str + // + // `len - 1` also wont panic since we just checked `s.is_empty`. + .split_at(val.len() - 1); + + // gRPC spec specifies `TimeoutValue` will be at most 8 digits + // Caping this at 8 digits also prevents integer overflow from ever occurring + if timeout_value.len() > 8 { + return Err(val); + } + + let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?; + + let duration = match timeout_unit { + // Hours + "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR), + // Minutes + "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE), + // Seconds + "S" => Duration::from_secs(timeout_value), + // Milliseconds + "m" => Duration::from_millis(timeout_value), + // Microseconds + "u" => Duration::from_micros(timeout_value), + // Nanoseconds + "n" => Duration::from_nanos(timeout_value), + _ => return Err(val), + }; + + Ok(Some(duration)) + } + None => Ok(None), + } +} + +/// Error returned if a request didn't complete within the configured timeout. +/// +/// Timeouts can be configured either with [`Endpoint::timeout`], [`Server::timeout`], or by +/// setting the [`grpc-timeout` metadata value][spec]. +/// +/// [`Endpoint::timeout`]: crate::transport::server::Server::timeout +/// [`Server::timeout`]: crate::transport::channel::Endpoint::timeout +/// [spec]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md +#[derive(Debug)] +pub struct TimeoutExpired(()); + +impl fmt::Display for TimeoutExpired { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Timeout expired") + } +} + +// std::error::Error only requires a type to impl Debug and Display +impl std::error::Error for TimeoutExpired {} + +#[cfg(test)] +mod tests { + use super::*; + use quickcheck::{Arbitrary, Gen}; + use quickcheck_macros::quickcheck; + + // Helper function to reduce the boiler plate of our test cases + fn setup_map_try_parse(val: Option<&str>) -> Result, HeaderValue> { + let mut hm = HeaderMap::new(); + if let Some(v) = val { + let hv = HeaderValue::from_str(v).unwrap(); + hm.insert(GRPC_TIMEOUT_HEADER, hv); + }; + + try_parse_grpc_timeout(&hm).map_err(|e| e.clone()) + } + + #[test] + fn test_hours() { + let parsed_duration = setup_map_try_parse(Some("3H")).unwrap().unwrap(); + assert_eq!(Duration::from_secs(3 * 60 * 60), parsed_duration); + } + + #[test] + fn test_minutes() { + let parsed_duration = setup_map_try_parse(Some("1M")).unwrap().unwrap(); + assert_eq!(Duration::from_secs(60), parsed_duration); + } + + #[test] + fn test_seconds() { + let parsed_duration = setup_map_try_parse(Some("42S")).unwrap().unwrap(); + assert_eq!(Duration::from_secs(42), parsed_duration); + } + + #[test] + fn test_milliseconds() { + let parsed_duration = setup_map_try_parse(Some("13m")).unwrap().unwrap(); + assert_eq!(Duration::from_millis(13), parsed_duration); + } + + #[test] + fn test_microseconds() { + let parsed_duration = setup_map_try_parse(Some("2u")).unwrap().unwrap(); + assert_eq!(Duration::from_micros(2), parsed_duration); + } + + #[test] + fn test_nanoseconds() { + let parsed_duration = setup_map_try_parse(Some("82n")).unwrap().unwrap(); + assert_eq!(Duration::from_nanos(82), parsed_duration); + } + + #[test] + fn test_header_not_present() { + let parsed_duration = setup_map_try_parse(None).unwrap(); + assert!(parsed_duration.is_none()); + } + + #[test] + #[should_panic(expected = "82f")] + fn test_invalid_unit() { + // "f" is not a valid TimeoutUnit + setup_map_try_parse(Some("82f")).unwrap().unwrap(); + } + + #[test] + #[should_panic(expected = "123456789H")] + fn test_too_many_digits() { + // gRPC spec states TimeoutValue will be at most 8 digits + setup_map_try_parse(Some("123456789H")).unwrap().unwrap(); + } + + #[test] + #[should_panic(expected = "oneH")] + fn test_invalid_digits() { + // gRPC spec states TimeoutValue will be at most 8 digits + setup_map_try_parse(Some("oneH")).unwrap().unwrap(); + } + + #[quickcheck] + fn fuzz(header_value: HeaderValueGen) -> bool { + let header_value = header_value.0; + + // this just shouldn't panic + let _ = setup_map_try_parse(Some(&header_value)); + + true + } + + /// Newtype to implement `Arbitrary` for generating `String`s that are valid `HeaderValue`s. + #[derive(Clone, Debug)] + struct HeaderValueGen(String); + + impl Arbitrary for HeaderValueGen { + fn arbitrary(g: &mut Gen) -> Self { + let max = g.choose(&(1..70).collect::>()).copied().unwrap(); + Self(gen_string(g, 0, max)) + } + } + + // copied from https://github.com/hyperium/http/blob/master/tests/header_map_fuzz.rs + fn gen_string(g: &mut Gen, min: usize, max: usize) -> String { + let bytes: Vec<_> = (min..max) + .map(|_| { + // Chars to pick from + g.choose(b"ABCDEFGHIJKLMNOPQRSTUVabcdefghilpqrstuvwxyz----") + .copied() + .unwrap() + }) + .collect(); + + String::from_utf8(bytes).unwrap() + } +}