Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 35 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ impl Mock {
self
}

pub fn add_matcher(&mut self, matcher: impl Match + Send + Sync + 'static) {
pub fn add_matcher(mut self, matcher: impl Match + Send + Sync + 'static) -> Self {
self.matcher.push(Arc::new(matcher));
self
}

/// You can use this to verify the mock separately to the one you put into the server (if
Expand All @@ -92,28 +93,36 @@ async fn ws_handler_pathless(
ws_handler(ws, Path(String::new()), headers, params, mocks).await
}

#[inline(always)]
fn can_consider(match_res: Option<bool>) -> bool {
matches!(match_res, Some(true) | None)
}

async fn ws_handler(
ws: WebSocketUpgrade,
Path(path): Path<String>,
headers: HeaderMap,
Query(params): Query<HashMap<String, String>>,
mocks: Extension<MockList>,
) -> Response {
let mut active_mocks = vec![];
{
debug!("checking request level matches");
let mocks = mocks.read().await.clone();
for mock in &mocks {
for (index, mock) in mocks.iter().enumerate() {
if mock
.matcher
.iter()
.all(|x| x.request_match(&path, &headers, &params))
.map(|x| x.request_match(&path, &headers, &params))
.all(can_consider)
{
mock.calls.fetch_add(1, Ordering::Acquire);
active_mocks.push(index);
}
}
}
debug!("about to upgrade websocket connection");
ws.on_upgrade(|socket| async move { handle_socket(socket, mocks.0).await })
ws.on_upgrade(|socket| async move { handle_socket(socket, mocks.0, active_mocks).await })
}

fn convert_message(msg: AxumMessage) -> Message {
Expand All @@ -129,16 +138,25 @@ fn convert_message(msg: AxumMessage) -> Message {
}
}

async fn handle_socket(mut socket: WebSocket, mocks: MockList) {
async fn handle_socket(mut socket: WebSocket, mocks: MockList, active_mocks: Vec<usize>) {
// Clone the mocks present when the connection comes in
let mocks: Vec<Mock> = mocks.read().await.clone();
let active_mocks = active_mocks
.iter()
.filter_map(|m| mocks.get(*m))
.collect::<Vec<&Mock>>();
println!("{} mocks loaded", mocks.len());
while let Some(msg) = socket.recv().await {
if let Ok(msg) = msg {
let msg = convert_message(msg);
debug!("Checking: {:?}", msg);
for mock in &mocks {
if mock.matcher.iter().all(|x| x.unary_match(&msg)) {
for mock in &active_mocks {
if mock
.matcher
.iter()
.map(|x| x.unary_match(&msg))
.all(can_consider)
{
mock.calls.fetch_add(1, Ordering::Acquire);
}
}
Expand Down Expand Up @@ -206,6 +224,10 @@ impl MockServer {
assert!(mock.verify())
}
}

pub async fn mocks_pass(&self) -> bool {
self.mocks.read().await.iter().all(|x| x.verify())
}
}

impl Drop for MockServer {
Expand All @@ -221,12 +243,12 @@ pub trait Match {
path: &str,
headers: &HeaderMap,
query: &HashMap<String, String>,
) -> bool {
false
) -> Option<bool> {
None
}

fn unary_match(&self, message: &Message) -> bool {
false
fn unary_match(&self, message: &Message) -> Option<bool> {
None
}
}

Expand All @@ -235,7 +257,7 @@ where
F: Fn(&Message) -> bool,
F: Send + Sync,
{
fn unary_match(&self, msg: &Message) -> bool {
self(msg)
fn unary_match(&self, msg: &Message) -> Option<bool> {
Some(self(msg))
}
}
36 changes: 17 additions & 19 deletions src/matchers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ impl Match for PathExactMatcher {
path: &str,
_headers: &HeaderMap,
_query: &HashMap<String, String>,
) -> bool {
self.0 == path
) -> Option<bool> {
Some(self.0 == path)
}
}

Expand All @@ -38,11 +38,9 @@ impl Match for HeaderExactMatcher {
_path: &str,
headers: &HeaderMap,
_query: &HashMap<String, String>,
) -> bool {
) -> Option<bool> {
let all_values = headers.get_all(&self.0);
println!("All: {:?}", all_values);
println!("Checking against: {:?}: {:?}", self.0, self.1);
self.1.iter().all(|x| all_values.iter().any(|v| v == x))
Some(self.1.iter().all(|x| all_values.iter().any(|v| v == x)))
}
}

Expand Down Expand Up @@ -71,8 +69,8 @@ impl Match for HeaderExistsMatcher {
_path: &str,
headers: &HeaderMap,
_query: &HashMap<String, String>,
) -> bool {
headers.contains_key(&self.0)
) -> Option<bool> {
Some(headers.contains_key(&self.0))
}
}

Expand All @@ -98,8 +96,8 @@ impl Match for QueryParamExactMatcher {
_path: &str,
_headers: &HeaderMap,
query: &HashMap<String, String>,
) -> bool {
query.get(&self.name) == Some(&self.value)
) -> Option<bool> {
Some(query.get(&self.name) == Some(&self.value))
}
}

Expand All @@ -123,11 +121,11 @@ impl Match for QueryParamContainsMatcher {
_path: &str,
_headers: &HeaderMap,
query: &HashMap<String, String>,
) -> bool {
) -> Option<bool> {
if let Some(s) = query.get(&self.name) {
s.contains(&self.value)
Some(s.contains(&self.value))
} else {
false
Some(false)
}
}
}
Expand All @@ -149,8 +147,8 @@ impl Match for QueryParamIsMissingMatcher {
_path: &str,
_headers: &HeaderMap,
query: &HashMap<String, String>,
) -> bool {
!query.contains_key(&self.0)
) -> Option<bool> {
Some(!query.contains_key(&self.0))
}
}

Expand All @@ -171,11 +169,11 @@ pub mod json {
pub struct ValidJsonMatcher;

impl Match for ValidJsonMatcher {
fn unary_match(&self, msg: &Message) -> bool {
fn unary_match(&self, msg: &Message) -> Option<bool> {
match msg {
Message::Text(t) => serde_json::from_str::<Value>(&t).is_ok(),
Message::Binary(b) => serde_json::from_slice::<Value>(b.as_ref()).is_ok(),
_ => false, // We can't be judging pings/pongs/closes
Message::Text(t) => Some(serde_json::from_str::<Value>(&t).is_ok()),
Message::Binary(b) => Some(serde_json::from_slice::<Value>(b.as_ref()).is_ok()),
_ => None,
}
}
}
Expand Down
42 changes: 39 additions & 3 deletions tests/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ async fn header_doesnt_match() {
async fn query_param_matchers() {
let server = MockServer::start().await;

let mut mock = Mock::given(QueryParamExactMatcher::new("hello", "world"));
mock.add_matcher(QueryParamContainsMatcher::new("foo", "ar"));
mock.add_matcher(QueryParamIsMissingMatcher::new("not_here"));
let mut mock = Mock::given(QueryParamExactMatcher::new("hello", "world"))
.add_matcher(QueryParamContainsMatcher::new("foo", "ar"))
.add_matcher(QueryParamIsMissingMatcher::new("not_here"));

server.register(mock.expect(1..)).await;

Expand All @@ -197,3 +197,39 @@ async fn query_param_matchers() {

server.verify().await;
}

#[tokio::test]
#[traced_test]
async fn combine_request_and_content_matchers() {
let server = MockServer::start().await;

server
.register(
Mock::given(path("api/stream"))
.add_matcher(ValidJsonMatcher)
.expect(1..),
)
.await;

let (mut stream, response) = connect_async(format!("{}/api", server.uri()))
.await
.unwrap();

// Send a message just to show it doesn't change anything.
let val = json!({"hello": "world"});
stream.send(Message::text(val.to_string())).await.unwrap();
sleep(Duration::from_millis(200)).await;

assert!(!server.mocks_pass().await);

let (mut stream, response) = connect_async(format!("{}/api/stream", server.uri()))
.await
.unwrap();

// Send a message just to show it doesn't change anything.
let val = json!({"hello": "world"});
stream.send(Message::text(val.to_string())).await.unwrap();
sleep(Duration::from_millis(200)).await;

assert!(server.mocks_pass().await);
}