diff --git a/Cargo.lock b/Cargo.lock index ed186cac6f..0c2694102a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -613,6 +613,7 @@ dependencies = [ "burn-hip", "burn-ndarray", "burn-remote", + "burn-router", "burn-tch", "burn-tensor", "burn-wgpu", @@ -805,6 +806,7 @@ dependencies = [ name = "burn-remote" version = "0.16.0" dependencies = [ + "async-channel", "axum", "burn-common", "burn-remote", diff --git a/Cargo.toml b/Cargo.toml index c092a8177b..7fbca03c1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,6 @@ members = [ exclude = [ "examples/notebook", "examples/raspberry-pi-pico", # will cause dependency building issues otherwise - # "crates/burn-cuda", # comment this line to work on burn-cuda ] [workspace.package] @@ -168,4 +167,3 @@ tracel-xtask = { version = "~1.1" } [profile.dev] debug = 0 # Speed up compilation time and not necessary. -opt-level = 2 diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index c82b27265a..1ef80fd4f1 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -22,9 +22,10 @@ default = [ "burn-tch?/default", "burn-tensor/default", "burn-wgpu?/default", + "burn-router?/default", "burn-cuda?/default", - "burn-hip?/default", "burn-autodiff?/default", + "burn-hip?/default", ] doc = [ "std", @@ -40,6 +41,7 @@ doc = [ "vision", "autodiff", "remote", + "router", "server", # Doc features "burn-candle/doc", @@ -49,6 +51,7 @@ doc = [ "burn-tch/doc", "burn-tensor/doc", "burn-wgpu/doc", + "burn-router/doc", "burn-cuda/doc", "burn-hip/doc", ] @@ -63,6 +66,7 @@ std = [ "burn-ndarray?/std", "burn-tensor/std", "burn-wgpu?/std", + "burn-router?/std", "burn-cuda?/std", "burn-hip?/std", "flate2", @@ -89,6 +93,7 @@ openblas = ["burn-ndarray?/blas-openblas"] openblas-system = ["burn-ndarray?/blas-openblas-system"] template = ["burn-wgpu?/template"] remote = ["burn-remote/client"] +router = ["burn-router"] server = ["burn-remote/server"] candle = ["burn-candle"] @@ -136,6 +141,7 @@ burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true, burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true } burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false } burn-remote = { path = "../burn-remote", version = "0.16.0", default-features = false, optional = true } +burn-router = { path = "../burn-router", version = "0.16.0", default-features = false, optional = true } data-encoding = { workspace = true } uuid = { workspace = true } diff --git a/crates/burn-core/src/backend.rs b/crates/burn-core/src/backend.rs index 5608ca6d1a..bd4c959302 100644 --- a/crates/burn-core/src/backend.rs +++ b/crates/burn-core/src/backend.rs @@ -44,3 +44,6 @@ pub use burn_tch as libtorch; #[cfg(feature = "tch")] pub use burn_tch::LibTorch; + +#[cfg(feature = "router")] +pub use burn_router::Router; diff --git a/crates/burn-fusion/src/stream/base.rs b/crates/burn-fusion/src/stream/base.rs index b6aca1ecb9..29b31df880 100644 --- a/crates/burn-fusion/src/stream/base.rs +++ b/crates/burn-fusion/src/stream/base.rs @@ -1,6 +1,10 @@ +use std::collections::BTreeSet; + use super::{execution::Operation, OperationConverter, RelativeOps}; use crate::FusionRuntime; -use burn_tensor::repr::OperationDescription; +use burn_tensor::repr::{OperationDescription, TensorId}; + +pub use burn_common::id::StreamId; /// A growing list of [tensor operation descriptions](OperationDescription). 
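Reviewer note: a minimal usage sketch (not part of the diff) of the `router` feature wired up above: burn-core gains `router = ["burn-router"]`, the `burn` crate forwards it, and `burn::backend` re-exports `Router`. Backend names follow the doc example added to `burn-router/src/lib.rs` later in this diff; having the `ndarray` and `wgpu` features enabled is an assumption.

```rust
// Sketch only: assumes the `router`, `ndarray`, and `wgpu` features of the `burn` crate.
use burn::backend::{NdArray, Router, Wgpu};

// Operations are routed to whichever backend holds the tensor; data moves between the
// two backends through `TensorData`, as described in the burn-router crate docs.
type MyBackend = Router<(NdArray, Wgpu)>;
```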
pub struct OperationQueue { @@ -15,6 +19,7 @@ pub struct OperationQueue { pub(crate) relative: Vec, pub(crate) converter: OperationConverter, pub(crate) operations: Vec>>, + pub(crate) ids: BTreeSet, } impl Default for OperationQueue { @@ -23,44 +28,6 @@ impl Default for OperationQueue { } } -/// The stream id. -#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] -pub struct StreamId { - #[cfg(feature = "std")] - value: std::thread::ThreadId, - #[cfg(not(feature = "std"))] - value: (), -} - -impl StreamId { - /// Get the current stream id. - pub fn current() -> Self { - Self { - #[cfg(feature = "std")] - value: Self::id(), - #[cfg(not(feature = "std"))] - value: (), - } - } - - #[cfg(feature = "std")] - fn id() -> std::thread::ThreadId { - std::thread_local! { - static ID: std::cell::OnceCell:: = const { std::cell::OnceCell::new() }; - }; - - // Getting the current thread is expensive, so we cache the value into a thread local - // variable, which is very fast. - ID.with(|cell| *cell.get_or_init(|| std::thread::current().id())) - } -} - -impl core::fmt::Display for StreamId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!("StreamID({:?})", self.value)) - } -} - impl OperationQueue { /// Create a new empty queue. pub fn new() -> Self { @@ -69,6 +36,7 @@ impl OperationQueue { relative: Vec::new(), converter: OperationConverter::default(), operations: Vec::new(), + ids: BTreeSet::new(), } } @@ -78,6 +46,9 @@ impl OperationQueue { /// representation that can be reused when the same pattern emerge in different but similar /// scenario, so that the same optimization can be used. pub fn add(&mut self, global: OperationDescription, operation: Box>) { + for node in global.nodes() { + self.ids.insert(node.id); + } let relative = global.to_relative(&mut self.converter); self.relative.push(relative); self.global.push(global); diff --git a/crates/burn-fusion/src/stream/multi.rs b/crates/burn-fusion/src/stream/multi.rs index 9c39b329cf..ca49630006 100644 --- a/crates/burn-fusion/src/stream/multi.rs +++ b/crates/burn-fusion/src/stream/multi.rs @@ -1,4 +1,4 @@ -use burn_tensor::repr::{HandleContainer, OperationDescription}; +use burn_tensor::repr::{HandleContainer, OperationDescription, TensorDescription}; use super::{ execution::{ExecutionMode, Operation, Processor, StreamSegment}, @@ -32,7 +32,7 @@ impl MultiStream { operation: Box>, handles: &mut HandleContainer, ) { - let id = self.maybe_drain(streams, handles); + let id = self.resolve_streams(streams, handles, &desc); let stream = match self.streams.get_mut(&id) { Some(stream) => stream, @@ -64,7 +64,7 @@ impl MultiStream { } } - /// Drain the streams + /// Drain a stream pub fn drain(&mut self, handles: &mut HandleContainer, id: StreamId) { if let Some(mut stream) = self.streams.remove(&id) { stream.processor.process( @@ -79,29 +79,42 @@ impl MultiStream { /// When one of the provided streams is different from the current stream, we drain them. /// /// Returns the current stream id. - fn maybe_drain( + fn resolve_streams( &mut self, streams: Vec, handles: &mut HandleContainer, + op: &OperationDescription, ) -> StreamId { let streams = Self::remove_duplicate(streams); let current = StreamId::current(); - if streams.len() == 1 { - // The only case where we don't need to drain, because we will process - // the operation queue of the current stream right after this. 
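Reviewer note: a standalone sketch (simplified, assumed names) of the bookkeeping this hunk adds to `OperationQueue` and that `resolve_stream` relies on below: each queue remembers the `TensorId`s touched by its pending operations, so another stream only has to be drained when it actually shares a tensor with the incoming operation.

```rust
use std::collections::BTreeSet;

// `TensorId` stands in for burn_tensor::repr::TensorId.
type TensorId = u64;

#[derive(Default)]
struct Queue {
    ids: BTreeSet<TensorId>,
}

impl Queue {
    // Mirrors `OperationQueue::add`: record every tensor used by a queued operation.
    fn add(&mut self, nodes: &[TensorId]) {
        for id in nodes {
            self.ids.insert(*id);
        }
    }

    // Mirrors the check in `MultiStream::resolve_stream`: another stream is drained only
    // when it has pending work on one of the tensors used by the new operation.
    fn shares_tensor_with(&self, nodes: &[TensorId]) -> bool {
        nodes.iter().any(|id| self.ids.contains(id))
    }
}
```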
- if streams[0] == current { - return current; - } - } - for id in streams { - self.drain(handles, id); + if id != current { + self.resolve_stream(handles, id, op.nodes()); + } } current } + /// Drain the stream only if one of the tensor in the given nodes is also included in the + /// stream queue. + fn resolve_stream( + &mut self, + handles: &mut HandleContainer, + id: StreamId, + nodes: Vec<&TensorDescription>, + ) { + if let Some(stream) = self.streams.get(&id) { + for node in nodes { + if stream.queue.ids.contains(&node.id) { + self.drain(handles, id); + return; + } + } + } + } + fn remove_duplicate(items: Vec) -> Vec { if items.len() == 1 { return items; diff --git a/crates/burn-hip/src/lib.rs b/crates/burn-hip/src/lib.rs index 35442dfa4f..89b91243c6 100644 --- a/crates/burn-hip/src/lib.rs +++ b/crates/burn-hip/src/lib.rs @@ -18,13 +18,15 @@ pub type Hip = JitBackend; #[cfg(feature = "fusion")] pub type Hip = burn_fusion::Fusion>; -#[cfg(target_os = "linux")] -#[cfg(test)] -mod tests { - use burn_jit::JitBackend; - - pub type TestRuntime = cubecl::hip::HipRuntime; - pub use half::{bf16, f16}; - - burn_jit::testgen_all!(); -} +// TODO: Hang the computer when AMD isn't available. +// +// #[cfg(target_os = "linux")] +// #[cfg(test)] +// mod tests { +// use burn_jit::JitBackend; +// +// pub type TestRuntime = cubecl::hip::HipRuntime; +// pub use half::{bf16, f16}; +// +// burn_jit::testgen_all!(); +// } diff --git a/crates/burn-remote/Cargo.toml b/crates/burn-remote/Cargo.toml index 65ca12e1ee..920843fff8 100644 --- a/crates/burn-remote/Cargo.toml +++ b/crates/burn-remote/Cargo.toml @@ -14,7 +14,7 @@ version.workspace = true [features] default = [] doc = [] -client = ["tokio-tungstenite"] +client = ["tokio-tungstenite", "async-channel", "tokio/sync"] server = ["axum", "tracing-core", "tracing-subscriber"] @@ -28,13 +28,14 @@ derive-new = {workspace = true } log = { workspace = true } # Shared dependencies -tokio = { version = "1.37", features = ["sync", "rt-multi-thread"] } +tokio = { version = "1.37", features = ["rt-multi-thread"] } serde = { workspace = true, features = ["derive"] } serde_bytes = { workspace = true } rmp-serde = { workspace = true } futures-util = { version = "0.3" } # Client dependencies +async-channel = { workspace = true, optional = true } tokio-tungstenite = { version = "0.24", optional = true } # Server dependencies diff --git a/crates/burn-remote/src/client/base.rs b/crates/burn-remote/src/client/base.rs index 29057f7886..2d070b3fcd 100644 --- a/crates/burn-remote/src/client/base.rs +++ b/crates/burn-remote/src/client/base.rs @@ -1,12 +1,12 @@ use super::worker::{ClientRequest, ClientWorker}; use crate::shared::{ComputeTask, ConnectionId, Task, TaskResponseContent}; +use async_channel::Sender; use burn_common::id::StreamId; use burn_tensor::repr::TensorId; use std::{ future::Future, sync::{atomic::AtomicU64, Arc}, }; -use tokio::sync::mpsc::Sender; pub use super::WsDevice; @@ -46,22 +46,19 @@ pub(crate) struct WsSender { } impl WsSender { - pub(crate) fn send(&self, task: ComputeTask) -> impl Future + Send { + pub(crate) fn send(&self, task: ComputeTask) { let position = self .position_counter .fetch_add(1, std::sync::atomic::Ordering::Relaxed); let stream_id = StreamId::current(); let sender = self.sender.clone(); - async move { - sender - .send(ClientRequest::WithoutCallback(Task::Compute( - task, - ConnectionId::new(position, stream_id), - ))) - .await - .unwrap(); - } + sender + .send_blocking(ClientRequest::WithoutCallback(Task::Compute( + task, + 
ConnectionId::new(position, stream_id), + ))) + .unwrap(); } pub(crate) fn new_tensor_id(&self) -> TensorId { @@ -79,20 +76,18 @@ impl WsSender { .fetch_add(1, std::sync::atomic::Ordering::Relaxed); let stream_id = StreamId::current(); let sender = self.sender.clone(); - let (callback_sender, mut callback_recv) = tokio::sync::mpsc::channel(1); + let (callback_sender, callback_recv) = async_channel::bounded(1); + sender + .send_blocking(ClientRequest::WithSyncCallback( + Task::Compute(task, ConnectionId::new(position, stream_id)), + callback_sender, + )) + .unwrap(); async move { - sender - .send(ClientRequest::WithSyncCallback( - Task::Compute(task, ConnectionId::new(position, stream_id)), - callback_sender, - )) - .await - .unwrap(); - match callback_recv.recv().await { - Some(val) => val, - None => panic!(""), + Ok(val) => val, + Err(err) => panic!("{err:?}"), } } } diff --git a/crates/burn-remote/src/client/runner.rs b/crates/burn-remote/src/client/runner.rs index f75bc52173..97c031cb2b 100644 --- a/crates/burn-remote/src/client/runner.rs +++ b/crates/burn-remote/src/client/runner.rs @@ -3,7 +3,7 @@ use burn_tensor::{ backend::{DeviceId, DeviceOps}, DType, TensorData, }; -use std::sync::Arc; +use std::{future::Future, sync::Arc}; use crate::shared::{ComputeTask, TaskResponseContent}; @@ -18,10 +18,8 @@ impl RunnerClient for WsClient { type Device = WsDevice; fn register(&self, op: burn_tensor::repr::OperationDescription) { - let fut = self - .sender + self.sender .send(ComputeTask::RegisterOperation(Box::new(op))); - self.runtime.block_on(fut); } fn read_tensor( @@ -44,9 +42,7 @@ impl RunnerClient for WsClient { let shape = data.shape.clone(); let dtype = data.dtype; - let fut = self.sender.send(ComputeTask::RegisterTensor(id, data)); - - self.runtime.block_on(fut); + self.sender.send(ComputeTask::RegisterTensor(id, data)); RouterTensor::new(Arc::new(id), shape, dtype, self.clone()) } @@ -74,22 +70,20 @@ impl RunnerClient for WsClient { } fn register_orphan(&self, id: &burn_tensor::repr::TensorId) { - let fut = self.sender.send(ComputeTask::RegisterOrphan(*id)); - self.runtime.block_on(fut); + self.sender.send(ComputeTask::RegisterOrphan(*id)); } - fn sync(&self) { + fn sync(&self) -> impl Future + Send + 'static { // Important for ordering to call the creation of the future sync. 
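Reviewer note: a sketch of the channel pattern the client switches to here, assuming only the `async-channel` and `tokio` crates; all names are illustrative. `send_blocking` lets the synchronous backend API enqueue work without holding a runtime handle, while the websocket worker keeps awaiting on the receiving side.

```rust
// Assumes tokio with the "rt-multi-thread" feature, as in burn-remote's Cargo.toml.
fn sketch() {
    let (sender, receiver) = async_channel::bounded::<String>(10);

    // Async side: the websocket worker drains the channel inside the tokio runtime.
    let worker = async move {
        while let Ok(msg) = receiver.recv().await {
            // forward `msg` over the websocket ...
            let _ = msg;
        }
    };

    // Sync side: backend calls enqueue work without being async themselves.
    // `send_blocking` parks the caller only when the bounded channel is full.
    sender.send_blocking("register-operation".to_string()).unwrap();

    drop(sender); // closing the channel ends the worker loop
    tokio::runtime::Runtime::new().unwrap().block_on(worker);
}
```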
let fut = self.sender.send_callback(ComputeTask::SyncBackend); + let runtime = self.runtime.clone(); - let fut = async move { - match fut.await { + async move { + match runtime.block_on(fut) { TaskResponseContent::SyncBackend => {} _ => panic!("Invalid message type"), }; - }; - - self.runtime.block_on(fut) + } } fn seed(&self, _seed: u64) { diff --git a/crates/burn-remote/src/client/worker.rs b/crates/burn-remote/src/client/worker.rs index 75209dfa2a..dd0236d377 100644 --- a/crates/burn-remote/src/client/worker.rs +++ b/crates/burn-remote/src/client/worker.rs @@ -7,7 +7,7 @@ use tokio_tungstenite::{ tungstenite::protocol::{Message, WebSocketConfig}, }; -pub type CallbackSender = tokio::sync::mpsc::Sender; +pub type CallbackSender = async_channel::Sender; pub enum ClientRequest { WithSyncCallback(Task, CallbackSender), @@ -45,7 +45,7 @@ impl ClientWorker { .unwrap(), ); - let (sender, mut rec) = tokio::sync::mpsc::channel(10); + let (sender, rec) = async_channel::bounded(10); let address_request = format!("{}/{}", device.address.as_str(), "request"); let address_response = format!("{}/{}", device.address.as_str(), "response"); @@ -117,11 +117,11 @@ impl ClientWorker { // Channel async worker sending operations to the server. tokio::spawn(async move { - while let Some(req) = rec.recv().await { + while let Ok(req) = rec.recv().await { let task = match req { ClientRequest::WithSyncCallback(task, callback) => { - let mut state = state.lock().await; if let Task::Compute(_content, id) = &task { + let mut state = state.lock().await; state.register_callback(*id, callback); } task diff --git a/crates/burn-remote/src/server/base.rs b/crates/burn-remote/src/server/base.rs index 169c262014..d3960b1603 100644 --- a/crates/burn-remote/src/server/base.rs +++ b/crates/burn-remote/src/server/base.rs @@ -1,5 +1,3 @@ -use std::{net::SocketAddr, sync::Arc}; - use axum::{ extract::{ ws::{self, WebSocket, WebSocketUpgrade}, @@ -9,6 +7,7 @@ use axum::{ routing::any, Router, }; +use std::{net::SocketAddr, sync::Arc}; use burn_tensor::{ backend::{Backend, BackendBridge}, @@ -90,37 +89,33 @@ where let packet = socket.recv().await; let msg = match packet { Some(msg) => msg, - None => { - log::info!("Still no message"); - panic!(""); - } + None => panic!("Still no message"), }; - if let Ok(ws::Message::Binary(bytes)) = msg { - let task = match rmp_serde::from_slice::(&bytes) { - Ok(val) => val, - Err(err) => { - log::info!("Only bytes messages are supported {err:?}"); - panic!(""); - } - }; - let id = match task { - Task::Init(id) => id, - _ => panic!(""), - }; + match msg { + Ok(ws::Message::Binary(bytes)) => { + let task = match rmp_serde::from_slice::(&bytes) { + Ok(val) => val, + Err(err) => panic!("Only bytes messages are supported {err:?}"), + }; + let id = match task { + Task::Init(id) => id, + _ => panic!("Response handler not initialized."), + }; - let receiver = self.state.register_responder(id).await; + let receiver = self.state.register_responder(id); - log::info!("Response handler connection active"); + log::info!("Response handler connection active"); - while let Ok(callback) = receiver.recv() { - let response = callback.recv().unwrap(); - let bytes = rmp_serde::to_vec(&response).unwrap(); + while let Ok(callback) = receiver.recv() { + let response = callback.recv().unwrap(); + let bytes = rmp_serde::to_vec(&response).unwrap(); - socket.send(ws::Message::Binary(bytes)).await.unwrap(); + socket.send(ws::Message::Binary(bytes)).await.unwrap(); + } } - } else { - panic!(""); + Err(err) => panic!("Can't 
start the response handler {err:?}"), + _ => panic!("Unsupported message type"), } } @@ -147,14 +142,13 @@ where } }; - let (stream, connection_id, task) = - match self.state.stream(&mut session_id, task).await { - Some(val) => val, - None => { - log::info!("Ops session activated {session_id:?}"); - continue; - } - }; + let (stream, connection_id, task) = match self.state.stream(&mut session_id, task) { + Some(val) => val, + None => { + log::info!("Ops session activated {session_id:?}"); + continue; + } + }; match task { ComputeTask::RegisterOperation(op) => { @@ -180,7 +174,7 @@ where } log::info!("Closing connection"); - self.state.close(session_id).await; + self.state.close(session_id); } } diff --git a/crates/burn-remote/src/server/processor.rs b/crates/burn-remote/src/server/processor.rs index 1328827748..273753b5ca 100644 --- a/crates/burn-remote/src/server/processor.rs +++ b/crates/burn-remote/src/server/processor.rs @@ -5,7 +5,7 @@ use burn_tensor::{ TensorData, }; use core::marker::PhantomData; -use std::sync::mpsc::Sender; +use std::sync::mpsc::{Sender, SyncSender}; use crate::shared::{ConnectionId, TaskResponse, TaskResponseContent}; @@ -21,7 +21,6 @@ pub enum ProcessorTask { RegisterTensor(TensorId, TensorData), ReadTensor(ConnectionId, TensorDescription, Callback), Sync(ConnectionId, Callback), - Fence(Callback<()>), RegisterOrphan(TensorId), Close, } @@ -32,8 +31,8 @@ where <::FullPrecisionBridge as BackendBridge>::Target: ReprBackend, { - pub fn start(runner: Runner) -> Sender { - let (sender, rec) = std::sync::mpsc::channel(); + pub fn start(runner: Runner) -> SyncSender { + let (sender, rec) = std::sync::mpsc::sync_channel(1); std::thread::spawn(move || { for item in rec.iter() { @@ -45,7 +44,8 @@ where runner.register_orphan(&id); } ProcessorTask::Sync(id, callback) => { - runner.sync(); + let fut = runner.sync(); + burn_common::future::block_on(fut); callback .send(TaskResponse { content: TaskResponseContent::SyncBackend, @@ -57,7 +57,8 @@ where runner.register_tensor_data_id(id, data); } ProcessorTask::ReadTensor(id, tensor, callback) => { - let tensor = burn_common::future::block_on(runner.read_tensor(tensor)); + let fut = runner.read_tensor(tensor); + let tensor = burn_common::future::block_on(fut); callback .send(TaskResponse { content: TaskResponseContent::ReadTensor(tensor), @@ -67,14 +68,12 @@ where } ProcessorTask::Close => { let device = runner.device(); - runner.sync(); + let fut = runner.sync(); + burn_common::future::block_on(fut); core::mem::drop(runner); B::sync(&device); return; } - ProcessorTask::Fence(sender) => { - sender.send(()).unwrap(); - } } } }); diff --git a/crates/burn-remote/src/server/session.rs b/crates/burn-remote/src/server/session.rs index e0ac508129..6932a1e892 100644 --- a/crates/burn-remote/src/server/session.rs +++ b/crates/burn-remote/src/server/session.rs @@ -1,15 +1,15 @@ use burn_common::id::StreamId; +use burn_common::stub::Mutex; use burn_router::Runner; use burn_tensor::{ backend::{Backend, BackendBridge}, - repr::{ReprBackend, TensorDescription, TensorId, TensorStatus}, + repr::ReprBackend, Device, }; use std::{ collections::HashMap, - sync::mpsc::{Receiver, Sender}, + sync::mpsc::{Receiver, SyncSender}, }; -use tokio::sync::Mutex; use crate::shared::{ComputeTask, ConnectionId, SessionId, Task, TaskResponse}; @@ -21,14 +21,13 @@ use super::stream::Stream; /// a native backend would have. 
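Reviewer note: the processor changes above keep a dedicated thread that owns the runner and now resolve the runner's async calls with `burn_common::future::block_on`. A reduced sketch of that loop, with `futures::executor::block_on` standing in for the burn helper and the task/response types simplified to placeholders:

```rust
use std::sync::mpsc;
use std::thread;

enum Task {
    Sync(mpsc::Sender<&'static str>),
    Close,
}

fn start() -> mpsc::SyncSender<Task> {
    // Bounded channel, as with `sync_channel(1)` in the diff, so senders get backpressure.
    let (sender, receiver) = mpsc::sync_channel::<Task>(1);

    thread::spawn(move || {
        for task in receiver.iter() {
            match task {
                Task::Sync(callback) => {
                    // In the real processor this is `runner.sync()`, an `impl Future`.
                    let fut = async { "synced" };
                    let result = futures::executor::block_on(fut);
                    callback.send(result).unwrap();
                }
                Task::Close => return,
            }
        }
    });

    sender
}
```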
pub struct SessionManager { runner: Runner, - sessions: tokio::sync::Mutex>>, + sessions: Mutex>>, } struct Session { runner: Runner, - tensors: HashMap>, streams: HashMap>, - sender: Sender>, + sender: SyncSender>, receiver: Option>>, } @@ -47,12 +46,9 @@ where /// Register a new responder for the session. Only one responder can exist for a session for /// now. - pub async fn register_responder( - &self, - session_id: SessionId, - ) -> Receiver> { + pub fn register_responder(&self, session_id: SessionId) -> Receiver> { log::info!("Register responder for session {session_id}"); - let mut sessions = self.sessions.lock().await; + let mut sessions = self.sessions.lock().unwrap(); self.register_session(&mut sessions, session_id); let session = sessions.get_mut(&session_id).unwrap(); @@ -60,12 +56,12 @@ where } /// Get the stream for the current session and task. - pub async fn stream( + pub fn stream( &self, session_id: &mut Option, task: Task, ) -> Option<(Stream, ConnectionId, ComputeTask)> { - let mut sessions = self.sessions.lock().await; + let mut sessions = self.sessions.lock().unwrap(); let session_id = match session_id { Some(id) => *id, @@ -86,19 +82,17 @@ where Task::Compute(task, connection_id) => (task, connection_id), _ => panic!("Only support compute tasks."), }; - let stream = session.select(connection_id.stream_id, &task); + let stream = session.select(connection_id.stream_id); Some((stream, connection_id, task)) } - None => { - panic!("To be initialized"); - } + None => panic!("To be initialized"), } } /// Close the session with the given id. - pub async fn close(&self, session_id: Option) { + pub fn close(&self, session_id: Option) { if let Some(id) = session_id { - let mut sessions = self.sessions.lock().await; + let mut sessions = self.sessions.lock().unwrap(); if let Some(session) = sessions.get_mut(&id) { session.close(); } @@ -121,10 +115,9 @@ where ReprBackend, { fn new(runner: Runner) -> Self { - let (sender, reveiver) = std::sync::mpsc::channel(); + let (sender, reveiver) = std::sync::mpsc::sync_channel(1); Self { runner, - tensors: Default::default(), streams: Default::default(), sender, receiver: Some(reveiver), @@ -138,58 +131,7 @@ where } /// Select the current [stream](Stream) based on the given task. - fn select(&mut self, stream_id: StreamId, task: &ComputeTask) -> Stream { - // We have to check every streams involved in the last operation, making - // sure the backend is up-to-date with those operations. - // - // 1. We update the tensor status of all tensors in the task. - // 2. We don't keep track of tensors that are used for the last time. - let mut fences = Vec::new(); - for (tensor_id, status) in task.tensors_info() { - let tensor_stream_ids = match self.tensors.get(&tensor_id) { - Some(val) => val, - None => { - if status != TensorStatus::ReadWrite { - // Add the first stream that created the tensor that may be used by other - // streams later. - self.register_tensor(tensor_id, stream_id); - } - continue; - } - }; - - let current_stream_already_synced = tensor_stream_ids.contains(&stream_id); - - if !current_stream_already_synced { - // We only need to sync to the first stream that created the tensor. - if let Some(id) = tensor_stream_ids.iter().next() { - fences.push(*id); - } - } - - // We add the stream to the list of updated stream to avoid needed to flush other - // operations that might use this tensor. 
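Reviewer note: with the per-tensor tracking removed, `select` reduces to a lazy lookup of one `Stream` per client `StreamId`. A simplified sketch (placeholder types; `entry` used instead of the explicit match in the diff):

```rust
use std::collections::HashMap;

type StreamId = u64;

// In the diff, `Stream` wraps the processor and writer channels and is `Clone`.
#[derive(Clone, Default)]
struct Stream;

#[derive(Default)]
struct Session {
    streams: HashMap<StreamId, Stream>,
}

impl Session {
    // One stream per client `StreamId`, created on first use; the per-tensor bookkeeping
    // and fence synchronization removed in this diff are no longer needed here.
    fn select(&mut self, stream_id: StreamId) -> Stream {
        self.streams.entry(stream_id).or_default().clone()
    }
}
```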
- self.register_tensor(tensor_id, stream_id); - - // If the tensor has the status `read_write`, it means no other stream can reuse it - // afterward, so we remove it from the state. - if status == TensorStatus::ReadWrite { - self.tensors.remove(&tensor_id); - } - } - - // Cleanup orphans. - if let ComputeTask::RegisterOrphan(tensor_id) = task { - self.tensors.remove(tensor_id); - } - - // We have to wait for the streams to be updated. - for stream_id in fences { - if let Some(stream) = self.streams.get(&stream_id) { - stream.fence_sync(); - } - } - + fn select(&mut self, stream_id: StreamId) -> Stream { // We return the stream. match self.streams.get(&stream_id) { Some(stream) => stream.clone(), @@ -201,17 +143,6 @@ where } } - fn register_tensor(&mut self, tensor_id: TensorId, stream_id: StreamId) { - match self.tensors.get_mut(&tensor_id) { - Some(ids) => { - ids.push(stream_id); - } - None => { - self.tensors.insert(tensor_id, vec![stream_id]); - } - } - } - // Close all streams created in the session. fn close(&mut self) { for (id, stream) in self.streams.drain() { @@ -220,23 +151,3 @@ where } } } - -impl ComputeTask { - fn tensors_info(&self) -> Vec<(TensorId, TensorStatus)> { - fn from_descriptions(desc: &[&TensorDescription]) -> Vec<(TensorId, TensorStatus)> { - desc.iter().map(|t| (t.id, t.status.clone())).collect() - } - - match self { - ComputeTask::RegisterOperation(op) => from_descriptions(&op.nodes()), - ComputeTask::RegisterTensor(tensor_id, _tensor_data) => { - vec![(*tensor_id, TensorStatus::NotInit)] - } - ComputeTask::RegisterOrphan(tensor_id) => { - vec![(*tensor_id, TensorStatus::ReadWrite)] - } - ComputeTask::ReadTensor(tensor_description) => from_descriptions(&[tensor_description]), - ComputeTask::SyncBackend => vec![], - } - } -} diff --git a/crates/burn-remote/src/server/stream.rs b/crates/burn-remote/src/server/stream.rs index 5ade2994c6..50a6f7d5a4 100644 --- a/crates/burn-remote/src/server/stream.rs +++ b/crates/burn-remote/src/server/stream.rs @@ -1,5 +1,5 @@ use core::marker::PhantomData; -use std::sync::mpsc::{Receiver, Sender}; +use std::sync::mpsc::{Receiver, SyncSender}; use crate::shared::{ConnectionId, TaskResponse}; @@ -15,8 +15,8 @@ use burn_tensor::{ /// server, protentially waiting to reconstruct consistency. #[derive(Clone)] pub struct Stream { - compute_sender: Sender, - writer_sender: Sender>, + compute_sender: SyncSender, + writer_sender: SyncSender>, _p: PhantomData, } @@ -26,7 +26,7 @@ where <::FullPrecisionBridge as BackendBridge>::Target: ReprBackend, { - pub fn new(runner: Runner, writer_sender: Sender>) -> Self { + pub fn new(runner: Runner, writer_sender: SyncSender>) -> Self { let sender = Processor::start(runner); Self { @@ -74,20 +74,6 @@ where self.writer_sender.send(callback_rec).unwrap(); } - // Ensure that all tasks are sent to the backend. - // - // It doesn't mean that the computation is done, but it means the backend has received the - // tasks, which may be queued. 
- pub fn fence_sync(&self) { - let (callback_sender, callback_rec) = std::sync::mpsc::channel(); - - self.compute_sender - .send(ProcessorTask::Fence(callback_sender.clone())) - .unwrap(); - - callback_rec.recv().unwrap(); - } - pub fn close(&self) { self.compute_sender.send(ProcessorTask::Close).unwrap(); } diff --git a/crates/burn-router/src/backend.rs b/crates/burn-router/src/backend.rs index 00877dd2cf..86338f9b57 100644 --- a/crates/burn-router/src/backend.rs +++ b/crates/burn-router/src/backend.rs @@ -74,7 +74,7 @@ impl Backend for BackendRouter { fn sync(device: &Self::Device) { let client = get_client::(device); - client.sync(); + burn_common::future::block_on(client.sync()); } } diff --git a/crates/burn-router/src/client/base.rs b/crates/burn-router/src/client/base.rs index 47781996fb..35723000b3 100644 --- a/crates/burn-router/src/client/base.rs +++ b/crates/burn-router/src/client/base.rs @@ -43,7 +43,7 @@ pub trait RunnerClient: Clone + Send + Sync + Sized { /// Drop the tensor with the given [tensor id](TensorId). fn register_orphan(&self, id: &TensorId); /// Sync the runner, ensure that all computations are finished. - fn sync(&self); + fn sync(&self) -> impl Future + Send + 'static; /// Seed the runner. fn seed(&self, seed: u64); } diff --git a/crates/burn-router/src/lib.rs b/crates/burn-router/src/lib.rs index 49cf9f14cd..644f65ee67 100644 --- a/crates/burn-router/src/lib.rs +++ b/crates/burn-router/src/lib.rs @@ -25,6 +25,15 @@ pub use types::*; /// It transfers tensors between backends via the underlying [tensor data](burn_tensor::TensorData). pub type DirectByteChannel = DirectChannel>; +/// Router backend. +/// +/// # Example +/// +/// ```ignore +/// type MyBackend = Router<(NdArray, Wgpu)>; +/// ``` +pub type Router = BackendRouter>; + extern crate alloc; #[cfg(test)] diff --git a/crates/burn-router/src/runner.rs b/crates/burn-router/src/runner.rs index 2b2c82a667..8945f47e8c 100644 --- a/crates/burn-router/src/runner.rs +++ b/crates/burn-router/src/runner.rs @@ -1,6 +1,5 @@ use alloc::{sync::Arc, vec::Vec}; -use spin::Mutex; - +use burn_common::stub::Mutex; use burn_tensor::{ backend::{Backend, BackendBridge}, ops::FullPrecisionBackend, @@ -12,6 +11,7 @@ use burn_tensor::{ }, DType, Element, ElementConversion, Shape, TensorData, }; +use core::future::Future; use super::{RouterTensor, RunnerClient}; use crate::{ @@ -70,7 +70,7 @@ where /// Get the tensor handle for the given [tensor description](TensorDescription). pub(crate) fn get_tensor_handle(&self, tensor: &TensorDescription) -> B::Handle { - let handles = &mut self.context.lock().handles; + let handles = &mut self.context.lock().unwrap().handles; handles.get_tensor_handle(tensor).handle } @@ -82,7 +82,7 @@ where dtype: DType, client: C, ) -> RouterTensor { - let mut ctx = self.context.lock(); + let mut ctx = self.context.lock().unwrap(); let id = ctx.create_empty_handle(); ctx.handles.register_handle(*id.as_ref(), handle); @@ -93,7 +93,7 @@ where /// Register a tensor from its data and id. pub fn register_tensor_data_id(&self, id: TensorId, data: TensorData) { - let mut ctx = self.context.lock(); + let mut ctx = self.context.lock().unwrap(); let dtype = data.dtype; if dtype.is_float() { @@ -114,7 +114,7 @@ where /// Register a tensor and returns its description. 
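Reviewer note: a sketch of the new `sync` shape introduced above, assuming return-position `impl Trait` in traits (Rust 1.75+), which the diff itself relies on; `futures::executor::block_on` stands in for `burn_common::future::block_on`, and the trait name is illustrative.

```rust
use core::future::Future;

trait Client {
    // The runner/client promises a future that owns everything it needs ('static),
    // so callers can decide where to await or block on it.
    fn sync(&self) -> impl Future<Output = ()> + Send + 'static;
}

struct Local;

impl Client for Local {
    fn sync(&self) -> impl Future<Output = ()> + Send + 'static {
        // Capture by value so the future can outlive `&self`.
        async move { /* B::sync(&device) in the runner implementation */ }
    }
}

fn backend_sync(client: &impl Client) {
    // `Backend::sync` stays synchronous by blocking on the returned future.
    futures::executor::block_on(client.sync());
}
```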
pub fn register_tensor_data_desc(&self, data: TensorData) -> TensorDescription { - let mut ctx = self.context.lock(); + let mut ctx = self.context.lock().unwrap(); let id = ctx.create_empty_handle(); let shape = data.shape.clone(); let dtype = data.dtype; @@ -144,7 +144,7 @@ where /// Register an empty tensor and returns its description. pub fn register_empty_tensor_desc(&self, shape: Vec, dtype: DType) -> TensorDescription { - let mut ctx = self.context.lock(); + let mut ctx = self.context.lock().unwrap(); let id = ctx.create_empty_handle(); core::mem::drop(ctx); @@ -181,7 +181,7 @@ where /// Execute a tensor operation. fn register(&self, op: OperationDescription) { // Remove unused tensor handles - let mut ctx = self.context.lock(); + let mut ctx = self.context.lock().unwrap(); ctx.free_orphans(); let handles = &mut ctx.handles; @@ -1208,22 +1208,36 @@ where } } - async fn read_tensor(&self, tensor: TensorDescription) -> TensorData { - let mut ctx = self.context.lock(); + fn read_tensor(&self, tensor: TensorDescription) -> impl Future + Send { + let mut ctx = self.context.lock().unwrap(); + + enum Output { + Float(B::FloatTensorPrimitive), + Int(B::IntTensorPrimitive), + Bool(B::BoolTensorPrimitive), + } - if tensor.dtype.is_float() { + let tensor = if tensor.dtype.is_float() { let tensor = ctx.handles.get_float_tensor::(&tensor); - B::float_into_data(tensor).await + Output::::Float(tensor) } else if tensor.dtype.is_int() { let tensor = ctx.handles.get_int_tensor::(&tensor); - B::int_into_data(tensor).await + Output::Int(tensor) } else if tensor.dtype.is_bool() { let tensor = ctx.handles.get_bool_tensor::(&tensor); - B::bool_into_data(tensor).await + Output::Bool(tensor) } else if let DType::QFloat(_) = tensor.dtype { todo!() } else { unimplemented!() + }; + + async move { + match tensor { + Output::Float(val) => B::float_into_data(val).await, + Output::Int(val) => B::int_into_data(val).await, + Output::Bool(val) => B::bool_into_data(val).await, + } } } @@ -1247,11 +1261,15 @@ where } fn register_orphan(&self, id: &TensorId) { - self.context.lock().drop_tensor_handle(*id) + self.context.lock().unwrap().drop_tensor_handle(*id) } - fn sync(&self) { - B::sync(&self.device); + fn sync(&self) -> impl Future + Send + 'static { + let device = self.device.clone(); + + async move { + B::sync(&device); + } } fn seed(&self, seed: u64) { diff --git a/crates/burn-router/src/types.rs b/crates/burn-router/src/types.rs index fe160ed1fa..f5c552abf6 100644 --- a/crates/burn-router/src/types.rs +++ b/crates/burn-router/src/types.rs @@ -179,12 +179,16 @@ macro_rules! 
impl_multi_backend_types { } } - fn sync(&self) { - match self { - Self::$DefaultBackend(runner) => runner.sync(), + fn sync(&self) -> impl core::future::Future + Send + 'static { + let fut: core::pin::Pin + Send + 'static>> = match self { + Self::$DefaultBackend(runner) => Box::pin(runner.sync()), $( - Self::$OtherBackend(runner) => runner.sync(), + Self::$OtherBackend(runner) => Box::pin(runner.sync()), )+ + }; + + async move { + fut.await; } } diff --git a/crates/burn-train/src/metric/acc.rs b/crates/burn-train/src/metric/acc.rs index ad24fe077f..cf6faaf03d 100644 --- a/crates/burn-train/src/metric/acc.rs +++ b/crates/burn-train/src/metric/acc.rs @@ -40,23 +40,22 @@ impl Metric for AccuracyMetric { type Input = AccuracyInput; fn update(&mut self, input: &AccuracyInput, _metadata: &MetricMetadata) -> MetricEntry { - let [batch_size, _n_classes] = input.outputs.dims(); + let targets = input.targets.clone(); + let outputs = input.outputs.clone(); - let targets = input.targets.clone().to_device(&B::Device::default()); - let outputs = input - .outputs - .clone() - .argmax(1) - .to_device(&B::Device::default()) - .reshape([batch_size]); + let [batch_size, _n_classes] = outputs.dims(); + + let outputs = outputs.argmax(1).reshape([batch_size]); let accuracy = match self.pad_token { Some(pad_token) => { let mask = targets.clone().equal_elem(pad_token as i64); - let matches = outputs.equal(targets).int().mask_fill(mask.clone(), 0); - let num_pad = mask.int().sum().into_scalar().elem::(); + let matches = outputs.equal(targets).float().mask_fill(mask.clone(), 0); + let num_pad = mask.float().sum(); + + let acc = matches.sum() / (num_pad.neg() + batch_size as f32); - matches.sum().into_scalar().elem::() / (batch_size as f64 - num_pad) + acc.into_scalar().elem::() } None => { outputs diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index de28869e36..adfb90786a 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -57,6 +57,7 @@ wgpu = ["burn-core/wgpu"] wgpu-spirv = ["burn-core/wgpu-spirv"] remote = ["burn-core/remote"] server = ["burn-core/server"] +router = ["burn-core/router"] # Network utils network = ["burn-core/network"] diff --git a/examples/server/src/lib.rs b/examples/server/src/lib.rs index d206771ea0..92cba57a2a 100644 --- a/examples/server/src/lib.rs +++ b/examples/server/src/lib.rs @@ -9,10 +9,10 @@ pub fn start() { cfg_if::cfg_if! { if #[cfg(feature = "ndarray")]{ burn::server::start::(Default::default(), port); - } else if #[cfg(feature = "wgpu")] { - burn::server::start::(Default::default(), port); } else if #[cfg(feature = "cuda-jit")]{ burn::server::start::(Default::default(), port); + } else if #[cfg(feature = "wgpu")] { + burn::server::start::(Default::default(), port); } else { panic!("No backend selected, can't start server on port {port}"); }
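Reviewer note on the `impl_multi_backend_types` change above: each enum variant's `sync()` returns a distinct opaque future type, so the match arms are first boxed into one `Pin<Box<dyn Future>>` before being awaited, at the cost of a single allocation per call. A standalone sketch of the pattern with made-up type names:

```rust
use core::future::Future;
use core::pin::Pin;

enum Multi<A, B> {
    First(A),
    Second(B),
}

trait Client {
    fn sync(&self) -> impl Future<Output = ()> + Send + 'static;
}

impl<A: Client, B: Client> Multi<A, B> {
    fn sync(&self) -> impl Future<Output = ()> + Send + 'static {
        // Boxing unifies the different opaque future types produced by each arm.
        let fut: Pin<Box<dyn Future<Output = ()> + Send + 'static>> = match self {
            Multi::First(client) => Box::pin(client.sync()),
            Multi::Second(client) => Box::pin(client.sync()),
        };

        async move {
            fut.await;
        }
    }
}
```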