Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: request body parsing #956

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ serde_json = "1.0.109"
once_cell = "1.8.0"
actix-multipart = "0.6.1"
parking_lot = "0.12.3"
percent-encoding = "2.3"
sansyrox marked this conversation as resolved.
Show resolved Hide resolved

[features]
io-uring = ["actix-web/experimental-io-uring"]
Expand Down
18 changes: 18 additions & 0 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,13 @@ def sync_multipart_file(request: Request):
return {"file_names": list(file_names)}


@app.post("/sync/multipart-file-form")
def sync_multipart_file_form(request: Request):
files = request.files
file_names = files.keys()
return {"form_data": request.form_data, "file_names": list(file_names)}


# Queries


Expand Down Expand Up @@ -584,6 +591,11 @@ def sync_form_data(request: Request):
return request.headers["Content-Type"]


@app.post("/sync/simple_form_data")
def sync_simple_form_data(request: Request):
return request.form_data["list_field"]


# JSON Request


Expand All @@ -609,6 +621,12 @@ async def request_json(request: Request):
return json["key"]


@app.post("/sync/request_json/list")
async def request_json_list(request: Request):
json = request.json()
return json["field"]


# --- PUT ---

# dict
Expand Down
3 changes: 2 additions & 1 deletion integration_tests/helpers/http_methods_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def post(

def multipart_post(
endpoint: str,
form_data: Optional[dict] = None,
files: Optional[dict] = None,
expected_status_code: int = 200,
should_check_response: bool = True,
Expand All @@ -77,7 +78,7 @@ def multipart_post(
"""

endpoint = endpoint.strip("/")
response = requests.post(f"{BASE_URL}/{endpoint}", files=files)
response = requests.post(f"{BASE_URL}/{endpoint}", files=files, data=form_data)
if should_check_response:
check_response(response, expected_status_code)
return response
Expand Down
17 changes: 16 additions & 1 deletion integration_tests/test_multipart_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from integration_tests.helpers.http_methods_helpers import multipart_post
from integration_tests.helpers.http_methods_helpers import multipart_post, post


@pytest.mark.benchmark
Expand All @@ -14,3 +14,18 @@ def test_form_data(function_type: str, session):
def test_multipart_file(function_type: str, session):
res = multipart_post(f"/{function_type}/multipart-file", files={"hello": "world"})
assert "hello" in res.text


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync"])
def test_multipart_file_and_form_data(function_type: str, session):
res = multipart_post(f"/{function_type}/multipart-file-form", files={"hello": "world"}, form_data={"list_field": ["a=", "b"]})
assert "hello" in res.text and '["a=","b"]' in res.text


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync"])
def test_simple_form_data(function_type: str, session):
data = {"hello": "world", "list_field": ["a=", "b"]}
res = post(f"/{function_type}/simple_form_data", data=data)
assert str(data["list_field"]) == res.text
1 change: 1 addition & 0 deletions integration_tests/test_request_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
("/sync/request_json", '{"hello": "world"', "None"),
("/async/request_json", '{"hello": "world"}', "<class 'dict'>"),
("/async/request_json", '{"hello": "world"', "None"),
("/sync/request_json/list", '{"hello": "world", "field": ["a=", "b"]}', "['a=', 'b']"),
],
)
def test_request(route, body, expected_result):
Expand Down
8 changes: 4 additions & 4 deletions robyn/robyn.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,8 @@ class Request:
body (Union[str, bytes]): The body of the request. If the request is a JSON, it will be a dict.
method (str): The method of the request. e.g. GET, POST, PUT etc.
url (Url): The url of the request. e.g. https://localhost/user
form_data (dict[str, str]): The form data of the request. e.g. {"name": "John"}
files (dict[str, bytes]): The files of the request. e.g. {"file": b"file"}
form_data (dict[str, Union[list[str], str]]): The form data of the request. e.g. {"name": "John"}
files (dict[str, dict[str, Union[list[bytes], bytes]]]): The files of the request. e.g. {"file": {"file_name": b"file"}}
ip_addr (Optional[str]): The IP Address of the client
identity (Optional[Identity]): The identity of the client
"""
Expand All @@ -266,8 +266,8 @@ class Request:
body: Union[str, bytes]
method: str
url: Url
form_data: dict[str, str]
files: dict[str, bytes]
form_data: dict[str, Union[list[str], str]]
files: dict[str, dict[str, Union[list[bytes], bytes]]]
ip_addr: Optional[str]
identity: Optional[Identity]

Expand Down
59 changes: 58 additions & 1 deletion src/types/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use std::collections::HashMap;
use log::debug;
use pyo3::{
exceptions::PyValueError,
prelude::*,
types::{PyBytes, PyString},
types::{PyBool, PyBytes, PyDict, PyFloat, PyInt, PyList, PyString, PyTuple},
};
use serde_json::Value;

pub mod function_info;
pub mod headers;
Expand Down Expand Up @@ -84,6 +86,61 @@ pub fn get_body_from_pyobject(body: &PyAny) -> PyResult<Vec<u8>> {
}
}

pub fn get_form_data_from_pyobject(form_data: &PyAny) -> PyResult<Option<HashMap<String, Value>>> {
if let Ok(py_dict) = form_data.downcast::<PyDict>() {
let mut map = HashMap::new();
for (key, value) in py_dict.iter() {
let key_str: String = key.extract()?;
let json_value: Value = pyany_to_value(value)?;
map.insert(key_str, json_value);
}
Ok(Some(map))
} else {
debug!("Could not convert specified form data");
Ok(None)
}
}

fn pyany_to_value(obj: &PyAny) -> PyResult<Value> {
if obj.is_none() {
Ok(Value::Null)
} else if let Ok(val) = obj.downcast::<PyBool>() {
Ok(Value::Bool(val.is_true()))
} else if let Ok(val) = obj.downcast::<PyInt>() {
let int_val: i64 = val.extract()?;
Ok(Value::Number(int_val.into()))
} else if let Ok(val) = obj.downcast::<PyFloat>() {
let float_val: f64 = val.extract()?;
Ok(Value::Number(serde_json::Number::from_f64(float_val).ok_or_else(|| {
PyValueError::new_err("Failed to convert float")
})?))
} else if let Ok(val) = obj.downcast::<PyString>() {
let str_val: String = val.extract()?;
Ok(Value::String(str_val))
} else if let Ok(val) = obj.downcast::<PyBytes>() {
let bytes_val = val.extract::<Vec<u8>>()?.into_iter().map(|c| Value::Number(c.into())).collect();
Ok(Value::Array(bytes_val))
} else if let Ok(dict) = obj.downcast::<PyDict>() {
let mut map = serde_json::Map::new();
for (key, value) in dict.iter() {
let key_str: String = key.extract()?;
let json_value = pyany_to_value(value)?;
map.insert(key_str, json_value);
}
Ok(Value::Object(map))
} else if let Ok(list) = obj.downcast::<PyList>() {
let vec = list.iter().map(pyany_to_value).collect::<Result<Vec<_>, _>>()?;
Ok(Value::Array(vec))
} else if let Ok(tuple) = obj.downcast::<PyTuple>() {
let vec = tuple.iter().map(pyany_to_value).collect::<Result<Vec<_>, _>>()?;
Ok(Value::Array(vec))
} else {
Err(PyValueError::new_err(
"Unsupported Python type for conversion to JSON",
))
}
}

pub fn get_description_from_pyobject(description: &PyAny) -> PyResult<Vec<u8>> {
if let Ok(s) = description.downcast::<PyString>() {
Ok(s.to_string().into_bytes())
Expand Down
Loading