Skip to content

Commit

Permalink
wip join (need to rewrite workflow builder to support join, also ther…
Browse files Browse the repository at this point in the history
…e is a bug in cyclic workflows atm anyway)

Signed-off-by: Teo Koon Peng <[email protected]>
  • Loading branch information
koonpeng committed Dec 19, 2024
1 parent 5215cf0 commit 2a863d3
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 31 deletions.
48 changes: 45 additions & 3 deletions src/diagram.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod fork_clone;
mod fork_result;
mod impls;
mod join;
mod node_registry;
mod serialization;
mod split_serialized;
Expand All @@ -10,6 +11,7 @@ mod workflow_builder;

use fork_clone::ForkCloneOp;
use fork_result::ForkResultOp;
use join::JoinOp;
pub use node_registry::*;
pub use serialization::*;
pub use split_serialized::*;
Expand Down Expand Up @@ -184,6 +186,46 @@ pub enum DiagramOperation {
/// ```
Split(SplitOp),

/// Waits for an item to be emitted from each of the inputs, then combined the
/// oldest of each into an array.
///
/// # Examples
/// ```
/// # bevy_impulse::Diagram::from_json_str(r#"
/// {
/// "ops": {
/// "start": {
/// "type": "start",
/// "next": "split"
/// },
/// "split": {
/// "type": "split",
/// "index": ["op1", "op2"]
/// },
/// "op1": {
/// "type": "node",
/// "nodeId": "foo",
/// "next": "join"
/// },
/// "op2": {
/// "type": "node",
/// "nodeId": "bar",
/// "next": "join"
/// },
/// "join": {
/// "type": "join",
/// "next": "terminate"
/// },
/// "terminate": {
/// "type": "terminate"
/// }
/// }
/// }
/// # "#)?;
/// # Ok::<_, serde_json::Error>(())
/// ```
Join(JoinOp),

/// If the request is serializable, transforms it by running it through a [CEL](https://cel.dev/) program.
/// The context includes a "request" variable which contains the request.
///
Expand Down Expand Up @@ -320,23 +362,23 @@ impl Diagram {
scope.terminate.id()
);

let mut dyn_builder =
let mut wf_builder =
unwrap_or_return!(WorkflowBuilder::new(&scope, builder, registry, self));

// connect node operations
for (op_id, op) in self.ops.iter().filter_map(|(op_id, v)| match v {
DiagramOperation::Node(op) => Some((op_id, op)),
_ => None,
}) {
unwrap_or_return!(dyn_builder.connect_node(&scope, builder, op_id, op));
unwrap_or_return!(wf_builder.connect_node(&scope, builder, op_id, op));
}

// connect start operation, note that this consumes scope, so we need to do this last
if let Some((_, start_op)) = self.ops.iter().find_map(|(op_id, v)| match v {
DiagramOperation::Start(op) => Some((op_id, op)),
_ => None,
}) {
unwrap_or_return!(dyn_builder.connect_start(scope, builder, start_op));
unwrap_or_return!(wf_builder.connect_start(scope, builder, start_op));
}
});

Expand Down
2 changes: 1 addition & 1 deletion src/diagram/fork_clone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ where
amount: usize,
) -> Result<Vec<DynOutput>, DiagramError> {
debug!("fork clone: {:?}", output);
assert_eq!(output.type_info, TypeId::of::<T>());
assert_eq!(output.type_id, TypeId::of::<T>());

let fork_clone = output.into_output::<T>().fork_clone(builder);
let outputs = (0..amount)
Expand Down
2 changes: 1 addition & 1 deletion src/diagram/fork_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ where
output: DynOutput,
) -> Result<(DynOutput, DynOutput), DiagramError> {
debug!("fork result: {:?}", output);
assert_eq!(output.type_info, TypeId::of::<Result<T, E>>());
assert_eq!(output.type_id, TypeId::of::<Result<T, E>>());

let chain = output.into_output::<Result<T, E>>().chain(builder);
let outputs = chain.fork_result(|c| c.output().into(), |c| c.output().into());
Expand Down
108 changes: 108 additions & 0 deletions src/diagram/join.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use std::any::TypeId;

use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tracing::debug;

use crate::{Builder, IterBufferable};

use super::{
impls::DefaultImpl, register_serialize, DiagramError, DynOutput, NodeRegistry, SerializeMessage,
};

#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct JoinOp {
pub(super) next: String,
}

pub(super) fn register_join_impl<T, Serializer>(registry: &mut NodeRegistry)
where
T: Send + Sync + 'static,
Serializer: SerializeMessage<Vec<T>>,
{
if registry.join_impls.contains_key(&TypeId::of::<T>()) {
return;
}

registry
.join_impls
.insert(TypeId::of::<T>(), Box::new(join_impl::<T>));

// also need to register serialize for Vec<T>
register_serialize::<Vec<T>, Serializer>(registry);
}

fn join_impl<T>(builder: &mut Builder, outputs: Vec<DynOutput>) -> Result<DynOutput, DiagramError>
where
T: Send + Sync + 'static,
{
debug!("join outputs");

let first_type = outputs[0].type_id;

let outputs = outputs
.into_iter()
.map(|o| {
if o.type_id != first_type {
Err(DiagramError::TypeMismatch)
} else {
Ok(o.into_output::<T>())
}
})
.collect::<Result<Vec<_>, _>>()?;

// we don't know the number of items at compile time, so we just use a sensible number.
Ok(outputs.join_vec::<4>(builder).output().into())
}

#[cfg(test)]
mod tests {
use serde_json::json;
use test_log::test;

use crate::{diagram::testing::DiagramTestFixture, Diagram};

#[test]
fn test_join() {
let mut fixture = DiagramTestFixture::new();

let diagram = Diagram::from_json(json!({
"ops": {
"start": {
"type": "start",
"next": "split"
},
"unzip": {
"type": "unzip",
"next": ["op1", "op2"]
},
"op1": {
"type": "node",
"nodeId": "multiply3",
"next": "join"
},
"op2": {
"type": "node",
"nodeId": "multiply3",
"next": "join"
},
"join": {
"type": "join",
"next": "terminate"
},
"terminate": {
"type": "terminate"
}
}
}))
.unwrap();

let result = fixture
.spawn_and_run(&diagram, serde_json::Value::from([1, 2]))
.unwrap();
assert_eq!(result.as_array().unwrap().len(), 2);
assert_eq!(result[0], 3);
assert_eq!(result[1], 6);
}
}
18 changes: 14 additions & 4 deletions src/diagram/node_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ where
pub struct DynOutput {
scope: Entity,
target: Entity,
pub(super) type_info: TypeId,
pub(super) type_id: TypeId,
}

impl DynOutput {
Expand Down Expand Up @@ -97,7 +97,7 @@ where
Self {
scope: output.scope(),
target: output.id(),
type_info: TypeId::of::<T>(),
type_id: TypeId::of::<T>(),
}
}
}
Expand Down Expand Up @@ -306,8 +306,8 @@ impl<'a, DeserializeImpl, SerializeImpl, ForkCloneImpl, UnzipImpl, ForkResultImp
register_deserialize::<Request, DeserializeImpl>(self.registry);
register_serialize::<Response, SerializeImpl>(self.registry);

UnzipImpl::register_serialize(&mut self.registry);
SplitImpl::register_serialize(&mut self.registry);
UnzipImpl::on_register(&mut self.registry);
SplitImpl::on_register(&mut self.registry);
}

/// Mark the node as having a non deserializable request. This allows nodes with
Expand Down Expand Up @@ -431,6 +431,11 @@ pub struct NodeRegistry {

pub(super) serialize_impls:
HashMap<TypeId, Box<dyn Fn(&mut Builder, DynOutput) -> Output<serde_json::Value>>>,

pub(super) join_impls: HashMap<
TypeId,
Box<dyn Fn(&mut Builder, Vec<DynOutput>) -> Result<DynOutput, DiagramError>>,
>,
}

impl Default for NodeRegistry {
Expand All @@ -442,11 +447,16 @@ impl Default for NodeRegistry {
gen: SchemaGenerator::new(settings),
deserialize_impls: HashMap::new(),
serialize_impls: HashMap::new(),
join_impls: HashMap::new(),
}
}
}

impl NodeRegistry {
pub fn new() -> Self {
Self::default()
}

/// Create a new [`RegistrationBuilder`]. By default, it is configured for nodes with
/// deserializable request and serializable responses and without support for any interconnect
/// operations like "fork_clone" and "unzip". See [`RegistrationBuilder`] for more information
Expand Down
10 changes: 6 additions & 4 deletions src/diagram/split_serialized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use crate::{

use super::{
impls::{DefaultImpl, NotSupported},
join::register_join_impl,
register_serialize, DiagramError, DynOutput, NodeRegistry, OperationId, SerializeMessage,
};

Expand Down Expand Up @@ -213,7 +214,7 @@ pub trait DynSplit<T, Serializer> {
split_op: &'a SplitOp,
) -> Result<DynSplitOutputs<'a>, DiagramError>;

fn register_serialize(registry: &mut NodeRegistry);
fn on_register(registry: &mut NodeRegistry);
}

impl<T, Serializer> DynSplit<T, Serializer> for NotSupported {
Expand All @@ -227,14 +228,14 @@ impl<T, Serializer> DynSplit<T, Serializer> for NotSupported {
Err(DiagramError::NotSplittable)
}

fn register_serialize(_registry: &mut NodeRegistry) {}
fn on_register(_registry: &mut NodeRegistry) {}
}

impl<T, Serializer> DynSplit<T, Serializer> for DefaultImpl
where
T: Send + Sync + 'static + Splittable,
T::Key: FromSequential + FromSpecific<SpecificKey = String> + ForRemaining,
Serializer: SerializeMessage<T::Item>,
Serializer: SerializeMessage<T::Item> + SerializeMessage<Vec<T::Item>>,
{
const SUPPORTED: bool = true;

Expand All @@ -247,8 +248,9 @@ where
split_chain(chain, split_op)
}

fn register_serialize(registry: &mut NodeRegistry) {
fn on_register(registry: &mut NodeRegistry) {
register_serialize::<T::Item, Serializer>(registry);
register_join_impl::<T::Item, Serializer>(registry);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/diagram/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ pub(super) fn transform_output(
) -> Result<Output<serde_json::Value>, DiagramError> {
debug!("transform output: {:?}, op: {:?}", output, transform_op);

let json_output = if output.type_info == TypeId::of::<serde_json::Value>() {
let json_output = if output.type_id == TypeId::of::<serde_json::Value>() {
output.into_output()
} else {
let serialize = registry
.serialize_impls
.get(&output.type_info)
.get(&output.type_id)
.ok_or(DiagramError::NotSerializable)?;
serialize(builder, output)
};
Expand Down
19 changes: 13 additions & 6 deletions src/diagram/unzip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::Builder;

use super::{
impls::{DefaultImpl, NotSupported},
join::register_join_impl,
register_serialize as register_serialize_impl, DiagramError, DynOutput, NodeRegistry,
OperationId, SerializeMessage,
};
Expand All @@ -22,9 +23,8 @@ pub trait DynUnzip<T, Serializer> {

fn dyn_unzip(builder: &mut Builder, output: DynOutput) -> Result<Vec<DynOutput>, DiagramError>;

/// Register serialize functions for all items in the tuple.
/// For a tuple of (T1, T2, T3), registers serialize for T1, T2 and T3.
fn register_serialize(registry: &mut NodeRegistry);
/// Called when a node is registered.
fn on_register(registry: &mut NodeRegistry);
}

impl<T, Serializer> DynUnzip<T, Serializer> for NotSupported {
Expand All @@ -37,15 +37,15 @@ impl<T, Serializer> DynUnzip<T, Serializer> for NotSupported {
Err(DiagramError::NotUnzippable)
}

fn register_serialize(_registry: &mut NodeRegistry) {}
fn on_register(_registry: &mut NodeRegistry) {}
}

macro_rules! dyn_unzip_impl {
($len:literal, $(($P:ident, $o:ident)),*) => {
impl<$($P),*, Serializer> DynUnzip<($($P,)*), Serializer> for DefaultImpl
where
$($P: Send + Sync + 'static),*,
Serializer: $(SerializeMessage<$P> +)*
Serializer: $(SerializeMessage<$P> +)* $(SerializeMessage<Vec<$P>> +)*,
{
const UNZIP_SLOTS: usize = $len;

Expand All @@ -66,11 +66,18 @@ macro_rules! dyn_unzip_impl {
Ok(outputs)
}

fn register_serialize(registry: &mut NodeRegistry)
fn on_register(registry: &mut NodeRegistry)
{
// Register serialize functions for all items in the tuple.
// For a tuple of (T1, T2, T3), registers serialize for T1, T2 and T3.
$(
register_serialize_impl::<$P, Serializer>(registry);
)*

// Register join impls for T1, T2, T3...
$(
register_join_impl::<$P, Serializer>(registry);
)*
}
}
};
Expand Down
Loading

0 comments on commit 2a863d3

Please sign in to comment.