Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support serializing/deserializing baml_py.Image, baml_py.Audio with pydantic #1062

Merged
merged 7 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions docs/docs/snippets/supported-types.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,25 @@ See [Union(|)](#union-) for more details.
See [calling a function with multimodal types](/docs/snippets/calling-baml/multi-modal)
and [testing image inputs](/docs/snippets/test-cases#images)

<Accordion title="Implementation details: runtime and security considerations">
BAML's multimodal types are designed for ease of use: we have deliberately made it
easy for you to construct a `image` or `audio` instance from a URL. Under the
hood, depending on the model you're using, BAML may need to download the image
and transcode it (usually as base64) for the model to consume.

This ease-of-use does come with some tradeoffs; namely, if you construct
an `image` or `audio` instance using untrusted user input, you may be exposing
yourself to [server-side request forgery (SSRF) attacks][ssrf]. Attackers may be
able to fetch files on your internal network, on external networks using your
application's identity, or simply excessively drive up your cloud network
bandwidth bill.

To prevent this, we recommend only using URLs from trusted sources/users or
validating them using allowlists or denylists.

[ssrf]: https://portswigger.net/web-security/ssrf
</Accordion>

### `image`

You can use an image like this for models that support them:
Expand Down Expand Up @@ -114,6 +133,36 @@ end
```
</CodeBlocks>

<Accordion title="Pydantic compatibility">
If using Pydantic, the following are valid ways to construct the `Image` type.

```json
{
"url": "https://upload.wikimedia.org/wikipedia/en/4/4d/Shrek_%28character%29.png"
}
```

```json
{
"url": "https://upload.wikimedia.org/wikipedia/en/4/4d/Shrek_%28character%29.png",
"media_type": "image/png"
}
```

```json
{
"base64": "iVBORw0K....",
}
```

```json
{
"base64": "iVBORw0K....",
"media_type": "image/png"
}
```
</Accordion>

### `audio`

Example
Expand Down
8 changes: 6 additions & 2 deletions engine/baml-lib/baml-types/src/media.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@ impl fmt::Display for BamlMediaType {
}
}

// We rely on the serialization and deserialization of this struct for:
// - prompt rendering (going into minijinja rendering and coming out)
/// We rely on the serialization and deserialization of this struct for:
///
/// - prompt rendering (going into minijinja rendering and coming out)
/// i.e. when we render a prompt, minijinja operates on a string; that
/// string needs to encode BamlMedia instances, and this is how we do
/// that
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BamlMedia {
pub media_type: BamlMediaType,
Expand Down
6 changes: 6 additions & 0 deletions engine/language_client_python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ license = "Apache-2.0"
name = "baml_py"
crate-type = ["cdylib"]

[lints.rust]
dead_code = "deny"
unused_imports = "deny"
unused_must_use = "deny"
unused_variables = "deny"

[dependencies]
anyhow.workspace = true
baml-types.workspace = true
Expand Down
5 changes: 5 additions & 0 deletions engine/language_client_python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,8 @@ features = ["pyo3/extension-module"]

[project.scripts]
baml-cli = "baml_py:invoke_runtime_cli"

# NOTE: dependencies is _deliberately_ empty so that different flavors of BAML
# don't have to share dependencies, e.g. we currently generate python/pydantic
# but if we generate python/vanilla in the future, we don't want to mess
# with that.
14 changes: 2 additions & 12 deletions engine/language_client_python/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
use baml_runtime::{
errors::ExposedError, internal::llm_client::LLMResponse, scope_diagnostics::ScopeStack,
};
use pyo3::prelude::pyclass;
use pyo3::types::PyModule;
use pyo3::{
create_exception, py_run, pyfunction, pymodule, wrap_pyfunction, wrap_pymodule, Bound, PyClass,
PyErr, PyResult, Python,
};
use pyo3::{create_exception, pymodule, Bound, PyErr, PyResult, Python};

create_exception!(baml_py, BamlError, pyo3::exceptions::PyException);
// Existing exception definitions
Expand Down Expand Up @@ -66,13 +62,7 @@ impl BamlError {
} => {
// Assuming ValidationError has fields that correspond to prompt, message, and raw_output
// If not, you may need to adjust this part based on the actual structure of ValidationError
Python::with_gil(|py| {
raise_baml_validation_error(
prompt.clone(),
message.clone(),
raw_output.clone(),
)
})
raise_baml_validation_error(prompt.clone(), message.clone(), raw_output.clone())
}
}
} else if let Some(er) = err.downcast_ref::<ScopeStack>() {
Expand Down
35 changes: 18 additions & 17 deletions engine/language_client_python/src/types/audio.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use baml_types::BamlMediaContent;
use pyo3::prelude::{pymethods, PyAnyMethods, PyModule, PyResult};
use pyo3::prelude::{pymethods, PyResult};
use pyo3::types::PyType;
use pyo3::{Bound, Py, PyAny, PyObject, Python, ToPyObject};
use pyo3::{Bound, PyAny, PyObject, Python};
use pythonize::{depythonize_bound, pythonize};

use crate::errors::BamlError;

use super::media_repr::{self, UserFacingBamlMedia};
crate::lang_wrapper!(BamlAudioPy, baml_types::BamlMedia);

#[pymethods]
Expand Down Expand Up @@ -63,31 +66,29 @@ impl BamlAudioPy {
}
}

// Makes it work with Pydantic
#[classmethod]
pub fn __get_pydantic_core_schema__(
_cls: Bound<'_, PyType>,
_source_type: Bound<'_, PyAny>,
_handler: Bound<'_, PyAny>,
) -> PyResult<PyObject> {
Python::with_gil(|py| {
let code = r#"
from pydantic_core import core_schema

def get_schema():
# No validation
return core_schema.any_schema()
media_repr::__get_pydantic_core_schema__(_cls, _source_type, _handler)
}

ret = get_schema()
"#;
// py.run(code, None, Some(ret_dict));
let fun: Py<PyAny> = PyModule::from_code_bound(py, code, "", "")?
.getattr("ret")?
.into();
Ok(fun.to_object(py)) // Return the PyObject
#[staticmethod]
fn baml_deserialize(data: PyObject, py: Python<'_>) -> PyResult<Self> {
let data: UserFacingBamlMedia = depythonize_bound(data.into_bound(py))?;
Ok(BamlAudioPy {
inner: data.to_baml_media(baml_types::BamlMediaType::Audio),
})
}

pub fn baml_serialize(&self, py: Python<'_>) -> PyResult<PyObject> {
let s: UserFacingBamlMedia = (&self.inner).try_into().map_err(BamlError::from_anyhow)?;
let s = serde_json::to_value(&s).map_err(|e| BamlError::from_anyhow(e.into()))?;
Ok(pythonize(py, &s)?)
}

pub fn __eq__(&self, other: &Self) -> bool {
self.inner == other.inner
}
Expand Down
3 changes: 1 addition & 2 deletions engine/language_client_python/src/types/function_results.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use baml_types::BamlValue;
use pyo3::prelude::{pymethods, PyResult};
use pyo3::types::{PyAnyMethods, PyListMethods, PyModule};
use pyo3::types::{PyAnyMethods, PyModule};
use pyo3::{Bound, IntoPy, PyObject, Python};
use pythonize::pythonize;

use crate::errors::BamlError;

Expand Down
37 changes: 19 additions & 18 deletions engine/language_client_python/src/types/image.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use pyo3::prelude::{pymethods, PyAnyMethods, PyModule, PyResult};
use pyo3::prelude::{pymethods, PyResult};
use pyo3::types::PyType;
use pyo3::{Bound, Py, PyAny, PyObject, Python, ToPyObject};
use pyo3::{Bound, PyAny, PyObject, Python};
use pythonize::{depythonize_bound, pythonize};

use crate::errors::BamlInvalidArgumentError;
use crate::errors::{BamlError, BamlInvalidArgumentError};

use super::media_repr::{self, UserFacingBamlMedia};
crate::lang_wrapper!(BamlImagePy, baml_types::BamlMedia);

#[pymethods]
Expand Down Expand Up @@ -62,31 +65,29 @@ impl BamlImagePy {
}
}

// Makes it work with Pydantic
#[classmethod]
pub fn __get_pydantic_core_schema__(
_cls: Bound<'_, PyType>,
_source_type: Bound<'_, PyAny>,
_handler: Bound<'_, PyAny>,
) -> PyResult<PyObject> {
Python::with_gil(|py| {
let code = r#"
from pydantic_core import core_schema

def get_schema():
# No validation
return core_schema.any_schema()
media_repr::__get_pydantic_core_schema__(_cls, _source_type, _handler)
}

ret = get_schema()
"#;
// py.run(code, None, Some(ret_dict));
let fun: Py<PyAny> = PyModule::from_code_bound(py, code, "", "")?
.getattr("ret")?
.into();
Ok(fun.to_object(py)) // Return the PyObject
#[staticmethod]
fn baml_deserialize(data: PyObject, py: Python<'_>) -> PyResult<Self> {
let data: UserFacingBamlMedia = depythonize_bound(data.into_bound(py))?;
Ok(BamlImagePy {
inner: data.to_baml_media(baml_types::BamlMediaType::Image),
})
}

pub fn baml_serialize(&self, py: Python<'_>) -> PyResult<PyObject> {
let s: UserFacingBamlMedia = (&self.inner).try_into().map_err(BamlError::from_anyhow)?;
let s = serde_json::to_value(&s).map_err(|e| BamlError::from_anyhow(e.into()))?;
Ok(pythonize(py, &s)?)
}

pub fn __eq__(&self, other: &Self) -> bool {
self.inner == other.inner
}
Expand Down
137 changes: 137 additions & 0 deletions engine/language_client_python/src/types/media_repr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use anyhow::Result;
use baml_types::{BamlMedia, BamlMediaContent, BamlMediaType, MediaBase64, MediaUrl};
use pyo3::{
types::{PyAnyMethods, PyModule, PyType},
Bound, Py, PyAny, PyObject, PyResult, Python, ToPyObject,
};
use serde::{Deserialize, Serialize};

/// We rely on the serialization and deserialization of this struct for:
///
/// - pydantic serialization (JSON->FastAPI->Pydantic->baml_py), so that
/// users can include BAML types directly in their user-facing requests
#[derive(Debug, Serialize, Deserialize)]
pub struct UserFacingBamlMedia {
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "media_type")]
pub mime_type: Option<String>,
#[serde(flatten)]
pub content: UserFacingBamlMediaContent,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum UserFacingBamlMediaContent {
Url { url: String },
Base64 { base64: String },
}

impl UserFacingBamlMedia {
pub fn to_baml_media(self, media_type: BamlMediaType) -> BamlMedia {
BamlMedia {
media_type,
mime_type: self.mime_type,
aaronvg marked this conversation as resolved.
Show resolved Hide resolved
content: match self.content {
UserFacingBamlMediaContent::Url { url } => BamlMediaContent::Url(MediaUrl { url }),
UserFacingBamlMediaContent::Base64 { base64 } => {
BamlMediaContent::Base64(MediaBase64 { base64 })
}
},
}
}
}

impl TryInto<UserFacingBamlMedia> for &BamlMedia {
type Error = anyhow::Error;

fn try_into(self) -> Result<UserFacingBamlMedia> {
Ok(UserFacingBamlMedia {
mime_type: self.mime_type.clone(),
aaronvg marked this conversation as resolved.
Show resolved Hide resolved
content: match &self.content {
BamlMediaContent::Url(url) => UserFacingBamlMediaContent::Url {
url: url.url.clone(),
},
BamlMediaContent::Base64(base64) => UserFacingBamlMediaContent::Base64 {
base64: base64.base64.clone(),
},
BamlMediaContent::File(_) => {
anyhow::bail!("Cannot convert file media to user facing media")
}
},
})
}
}

/// This function is used for Pydantic compatibility in three ways:
///
/// - allows constructing Pydantic models containing a BamlImagePy instance
/// - allows FastAPI requests to deserialize BamlImagePy instances in JSON format
/// - allows serializing BamlImagePy instances in JSON format
///
/// Ideally this belongs in baml_py.internal_monkeypatch, so that we can get
/// ruff-based type checking, but this depends on the pydantic libraries, so we
/// can't implement this in internal_monkeypatch without adding a hard dependency
/// on pydantic. And we don't want to do _that_, because that will make it harder
/// to implement output_type python/vanilla in the future.
pub fn __get_pydantic_core_schema__(
_cls: Bound<'_, PyType>,
_source_type: Bound<'_, PyAny>,
_handler: Bound<'_, PyAny>,
) -> PyResult<PyObject> {
Python::with_gil(|py| {
let code = r#"
from pydantic_core import core_schema, SchemaValidator

def deserialize(data):
from baml_py.baml_py import BamlImagePy
if isinstance(data, BamlImagePy):
return data
else:
SchemaValidator(
core_schema.union_schema([
core_schema.model_fields_schema({
'url': core_schema.model_field(core_schema.str_schema()),
'media_type': core_schema.model_field(
core_schema.with_default_schema(
core_schema.union_schema([
core_schema.str_schema(),
core_schema.none_schema(),
]),
default=None,
),
),
}),
core_schema.model_fields_schema({
'base64': core_schema.model_field(core_schema.str_schema()),
'media_type': core_schema.model_field(
core_schema.with_default_schema(
core_schema.union_schema([
core_schema.str_schema(),
core_schema.none_schema(),
]),
default=None,
),
),
}),
])
).validate_python(data)
return BamlImagePy.baml_deserialize(data)

def get_schema():
return core_schema.no_info_after_validator_function(
deserialize,
core_schema.any_schema(),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda v: v.baml_serialize(),
)
)

ret = get_schema()
"#;
// py.run(code, None, Some(ret_dict));
let fun: Py<PyAny> = PyModule::from_code_bound(py, code, "", "")?
.getattr("ret")?
.into();
Ok(fun.to_object(py))
})
}
Loading
Loading