Skip to content

Commit

Permalink
casting bug (#1059)
Browse files Browse the repository at this point in the history
- **Adding unit test**
- **Adding commit to fix parsing**

<!-- ELLIPSIS_HIDDEN -->


----

> [!IMPORTANT]
> Replaces `coerce` with `cast` for type handling, introduces
`DifferentiateUnions` for union types, and updates Python, TypeScript,
and Ruby client implementations.
> 
>   - **Behavior**:
> - Replaces `coerce` with `cast` in `async_client.py.j2`,
`sync_client.py.j2`, and `baml_py.pyi`.
> - Introduces `DifferentiateUnions` for handling `OriginalA` and
`OriginalB` union types.
>   - **TypeScript**:
> - Updates `sync_client.ts`, `type_builder.ts`, and `types.ts` for
`DifferentiateUnions` and casting logic.
>   - **Ruby**:
> - Updates `client.rb`, `inlined.rb`, and `partial-types.rb` for
`DifferentiateUnions` and casting changes.
>   - **Tests**:
> - Adds `test_differing_unions` in `test_functions.py` to validate
`DifferentiateUnions`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 3eb3ec3. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
hellovai authored Oct 18, 2024
1 parent 5eae0a7 commit 830b0cb
Show file tree
Hide file tree
Showing 25 changed files with 1,195 additions and 1,373 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, TypeVar, Union, TypedDict, Type, Literal
from typing import Any, Dict, List, Optional, TypeVar, Union, TypedDict, Type, Literal, cast
from typing_extensions import NotRequired
import pprint

Expand All @@ -12,16 +12,6 @@ from .globals import DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX, D

OutputType = TypeVar('OutputType')

def coerce(cls: Type[BaseModel], parsed: Any) -> Any:
try:
return cls.model_validate({"inner": parsed}).inner # type: ignore
except ValidationError as e:
raise TypeError(
"Internal BAML error while casting output to {}\n{}".format(
cls.__name__,
pprint.pformat(parsed)
)
) from e

# Define the TypedDict with optional parameters having default values
class BamlCallOptions(TypedDict, total=False):
Expand Down Expand Up @@ -69,8 +59,7 @@ class BamlAsyncClient:
tb,
__cr__,
)
mdl = create_model("{{ fn.name }}ReturnType", inner=({{ fn.return_type }}, ...))
return coerce(mdl, raw.parsed())
return cast({{fn.return_type}}, raw.cast_to(types, types))
{% endfor %}


Expand Down Expand Up @@ -110,13 +99,10 @@ class BamlStreamClient:
__cr__,
)

mdl = create_model("{{ fn.name }}ReturnType", inner=({{ fn.return_type }}, ...))
partial_mdl = create_model("{{ fn.name }}PartialReturnType", inner=({{ fn.partial_return_type }}, ...))

return baml_py.BamlStream[{{ fn.partial_return_type }}, {{ fn.return_type }}](
raw,
lambda x: coerce(partial_mdl, x),
lambda x: coerce(mdl, x),
lambda x: cast({{fn.partial_return_type}}, x.cast_to(types, partial_types)),
lambda x: cast({{fn.return_type}}, x.cast_to(types, types)),
self.__ctx_manager.get(),
)
{% endfor %}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, TypeVar, Union, TypedDict, Type, Literal
from typing import Any, Dict, List, Optional, TypeVar, Union, TypedDict, Type, Literal, cast
from typing_extensions import NotRequired
import pprint

Expand All @@ -11,17 +11,6 @@ from .globals import DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX, D

OutputType = TypeVar('OutputType')

def coerce(cls: Type[BaseModel], parsed: Any) -> Any:
try:
return cls.model_validate({"inner": parsed}).inner # type: ignore
except ValidationError as e:
raise TypeError(
"Internal BAML error while casting output to {}\n{}".format(
cls.__name__,
pprint.pformat(parsed)
)
) from e

# Define the TypedDict with optional parameters having default values
class BamlCallOptions(TypedDict, total=False):
tb: NotRequired[TypeBuilder]
Expand Down Expand Up @@ -67,8 +56,7 @@ class BamlSyncClient:
tb,
__cr__,
)
mdl = create_model("{{ fn.name }}ReturnType", inner=({{ fn.return_type }}, ...))
return coerce(mdl, raw.parsed())
return cast({{fn.return_type}}, raw.cast_to(types, types))
{% endfor %}


Expand Down Expand Up @@ -109,13 +97,10 @@ class BamlStreamClient:
__cr__,
)

mdl = create_model("{{ fn.name }}ReturnType", inner=({{ fn.return_type }}, ...))
partial_mdl = create_model("{{ fn.name }}PartialReturnType", inner=({{ fn.partial_return_type }}, ...))

return baml_py.BamlSyncStream[{{ fn.partial_return_type }}, {{ fn.return_type }}](
raw,
lambda x: coerce(partial_mdl, x),
lambda x: coerce(mdl, x),
lambda x: cast({{fn.partial_return_type}}, x.cast_to(types, partial_types)),
lambda x: cast({{fn.return_type}}, x.cast_to(types, types)),
self.__ctx_manager.get(),
)
{% endfor %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ class FunctionResult:
"""

def __str__(self) -> str: ...
def parsed(self) -> Any: ...
# Returns True if the function call was successful, False otherwise
def is_ok(self) -> bool: ...
def cast_to(self, enum_module: Any, class_module: Any) -> Any: ...

class FunctionResultStream:
"""The result of a BAML function stream.
Expand Down
8 changes: 4 additions & 4 deletions engine/language_client_python/python_src/baml_py/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ async def __aiter__(self):
if event is None:
break
if event.is_ok():
yield self.__partial_coerce(event.parsed())
yield self.__partial_coerce(event)

async def get_final_response(self):
final = self.__drive_to_completion_in_bg()
return self.__final_coerce((await asyncio.wrap_future(final)).parsed())
return self.__final_coerce((await asyncio.wrap_future(final)))


class BamlSyncStream(Generic[PartialOutputType, FinalOutputType]):
Expand Down Expand Up @@ -138,7 +138,7 @@ def __iter__(self):
if event is None:
break
if event.is_ok():
yield self.__partial_coerce(event.parsed())
yield self.__partial_coerce(event)

def get_final_response(self):
self.__drive_to_completion_in_bg()
Expand All @@ -153,4 +153,4 @@ def get_final_response(self):
"BAML Internal error: Stream did not complete successfully. Please report this issue."
)

return self.__final_coerce(self.__result.parsed())
return self.__final_coerce(self.__result)
93 changes: 90 additions & 3 deletions engine/language_client_python/src/types/function_results.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use baml_types::BamlValue;
use pyo3::prelude::{pymethods, PyResult};
use pyo3::{PyObject, Python};
use pyo3::types::{PyAnyMethods, PyListMethods, PyModule};
use pyo3::{Bound, IntoPy, PyObject, Python};
use pythonize::pythonize;

use crate::errors::BamlError;

use super::{BamlAudioPy, BamlImagePy};

crate::lang_wrapper!(FunctionResult, baml_runtime::FunctionResult);

#[pymethods]
Expand All @@ -17,12 +20,96 @@ impl FunctionResult {
self.inner.parsed_content().is_ok()
}

fn parsed(&self, py: Python<'_>) -> PyResult<PyObject> {
// Cast the parsed value to a specific type
// the module is the module that the type is defined in
fn cast_to(
&self,
py: Python<'_>,
enum_module: Bound<'_, PyModule>,
cls_module: Bound<'_, PyModule>,
) -> PyResult<PyObject> {
let parsed = self
.inner
.parsed_content()
.map_err(BamlError::from_anyhow)?;

Ok(pythonize(py, &BamlValue::from(parsed))?)
let parsed = BamlValue::from(parsed);
let parsed = pythonize_strict(py, &parsed, &enum_module, &cls_module)?;

Ok(parsed)
}
}

fn pythonize_strict(
py: Python<'_>,
parsed: &BamlValue,
enum_module: &Bound<'_, PyModule>,
cls_module: &Bound<'_, PyModule>,
) -> PyResult<PyObject> {
match parsed {
BamlValue::String(val) => Ok(val.into_py(py)),
BamlValue::Int(val) => Ok(val.into_py(py)),
BamlValue::Float(val) => Ok(val.into_py(py)),
BamlValue::Bool(val) => Ok(val.into_py(py)),
BamlValue::Map(index_map) => {
let dict = pyo3::types::PyDict::new_bound(py);
for (key, value) in index_map {
let key = key.into_py(py);
let value = pythonize_strict(py, value, enum_module, cls_module)?;
dict.set_item(key, value)?;
}
Ok(dict.into())
}
BamlValue::List(vec) => Ok(pyo3::types::PyList::new_bound(
py,
vec.iter()
.map(|v| pythonize_strict(py, v, enum_module, cls_module))
.collect::<PyResult<Vec<_>>>()?,
)
.into()),
BamlValue::Media(baml_media) => match baml_media.media_type {
baml_types::BamlMediaType::Image => {
Ok(BamlImagePy::from(baml_media.clone()).into_py(py))
}
baml_types::BamlMediaType::Audio => {
Ok(BamlAudioPy::from(baml_media.clone()).into_py(py))
}
},
BamlValue::Enum(enum_name, value) => {
let enum_type = match enum_module.getattr(enum_name.as_str()) {
Ok(e) => e,
// This can be true in the case of dynamic types.
Err(_) => return Ok(enum_name.into_py(py)),
};

// Call the constructor with the value
let instance = enum_type.call1((value,))?;
Ok(instance.into())
}
BamlValue::Class(class_name, index_map) => {
let properties = index_map
.iter()
.map(|(key, value)| {
let key = key.as_str();
let value = pythonize_strict(py, value, enum_module, cls_module)?;
Ok((key, value))
})
.collect::<PyResult<Vec<_>>>()?;

let properties_dict = pyo3::types::PyDict::new_bound(py);
for (key, value) in properties {
properties_dict.set_item(key, value)?;
}

let class_type = match cls_module.getattr(class_name.as_str()) {
Ok(class) => class,
// This can be true in the case of dynamic types.
Err(_) => return Ok(properties_dict.into()),
};
let instance = class_type.call_method("model_validate", (properties_dict,), None)?;

Ok(instance.into())
}
BamlValue::Null => Ok(py.None()),
}
}
17 changes: 17 additions & 0 deletions integ-tests/baml_src/test-files/dynamic/differentiate_unions.baml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
class OriginalA {
value int
}

class OriginalB {
value int
@@dynamic
}

function DifferentiateUnions() -> OriginalA | OriginalB {
client "openai/gpt-4o-mini"
prompt #"
Create a data model that represents the latter of the two classes.

{{ ctx.output_format }}
"#
}
6 changes: 4 additions & 2 deletions integ-tests/baml_src/test-files/providers/providers.baml
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,12 @@ test TestProvider {
}


function TestCaching(input: string) -> string {
function TestCaching(input: string, not_cached: string) -> string {
client ClaudeWithCaching
prompt #"
{{ _.role('user', cache_control={"type": "ephemeral"}) }}
{{ _.role('system', cache_control={"type": "ephemeral"}) }}
Describe this in 5 words: {{ input }}
{{ _.role('user') }}
{{ not_cached }}
"#
}
47 changes: 47 additions & 0 deletions integ-tests/openapi/baml_client/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,21 @@ paths:
title: DescribeImage4Response
type: string
operationId: DescribeImage4
/call/DifferentiateUnions:
post:
requestBody:
$ref: '#/components/requestBodies/DifferentiateUnions'
responses:
'200':
description: Successful operation
content:
application/json:
schema:
title: DifferentiateUnionsResponse
oneOf:
- $ref: '#/components/schemas/OriginalA'
- $ref: '#/components/schemas/OriginalB'
operationId: DifferentiateUnions
/call/DummyOutputFunction:
post:
requestBody:
Expand Down Expand Up @@ -1448,6 +1463,19 @@ components:
- classWithImage
- img2
additionalProperties: false
DifferentiateUnions:
required: true
content:
application/json:
schema:
title: DifferentiateUnionsRequest
type: object
properties:
__baml_options__:
nullable: true
$ref: '#/components/schemas/BamlOptions'
required: []
additionalProperties: false
DummyOutputFunction:
required: true
content:
Expand Down Expand Up @@ -2186,11 +2214,14 @@ components:
properties:
input:
type: string
not_cached:
type: string
__baml_options__:
nullable: true
$ref: '#/components/schemas/BamlOptions'
required:
- input
- not_cached
additionalProperties: false
TestFallbackClient:
required: true
Expand Down Expand Up @@ -3153,6 +3184,22 @@ components:
required:
- order_status
additionalProperties: false
OriginalA:
type: object
properties:
value:
type: integer
required:
- value
additionalProperties: false
OriginalB:
type: object
properties:
value:
type: integer
required:
- value
additionalProperties: false
Person:
type: object
properties:
Expand Down
Loading

0 comments on commit 830b0cb

Please sign in to comment.