Skip to content

Commit

Permalink
Refactor/burn compute (#1580)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Apr 23, 2024
1 parent c579686 commit 886a1de
Show file tree
Hide file tree
Showing 26 changed files with 809 additions and 612 deletions.
499 changes: 261 additions & 238 deletions Cargo.lock

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions crates/burn-compute/src/channel/base.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
use crate::server::{ComputeServer, Handle};
use crate::server::{Binding, ComputeServer, Handle};
use alloc::vec::Vec;
use burn_common::reader::Reader;

/// The ComputeChannel trait links the ComputeClient to the ComputeServer
/// while ensuring thread-safety
pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send + Sync {
/// Given a handle, returns owned resource as bytes
fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>>;
/// Given a binding, returns owned resource as bytes
fn read(&self, binding: Binding<Server>) -> Reader<Vec<u8>>;

/// Given a resource as bytes, stores it and returns the resource handle
fn create(&self, data: &[u8]) -> Handle<Server>;

/// Reserves `size` bytes in the storage, and returns a handle over them
fn empty(&self, size: usize) -> Handle<Server>;

/// Executes the `kernel` over the given `handles`.
fn execute(&self, kernel: Server::Kernel, handles: &[&Handle<Server>]);
/// Executes the `kernel` over the given `bindings`.
fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>);

/// Wait for the completion of every task in the server.
fn sync(&self);
Expand Down
10 changes: 5 additions & 5 deletions crates/burn-compute/src/channel/cell.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::ComputeChannel;
use crate::server::{ComputeServer, Handle};
use crate::server::{Binding, ComputeServer, Handle};
use alloc::sync::Arc;
use alloc::vec::Vec;
use burn_common::reader::Reader;
Expand Down Expand Up @@ -42,8 +42,8 @@ impl<Server> ComputeChannel<Server> for RefCellComputeChannel<Server>
where
Server: ComputeServer,
{
fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>> {
self.server.borrow_mut().read(handle)
fn read(&self, binding: Binding<Server>) -> Reader<Vec<u8>> {
self.server.borrow_mut().read(binding)
}

fn create(&self, resource: &[u8]) -> Handle<Server> {
Expand All @@ -54,10 +54,10 @@ where
self.server.borrow_mut().empty(size)
}

fn execute(&self, kernel_description: Server::Kernel, handles: &[&Handle<Server>]) {
fn execute(&self, kernel_description: Server::Kernel, bindings: Vec<Binding<Server>>) {
self.server
.borrow_mut()
.execute(kernel_description, handles)
.execute(kernel_description, bindings)
}

fn sync(&self) {
Expand Down
29 changes: 11 additions & 18 deletions crates/burn-compute/src/channel/mpsc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
use burn_common::reader::Reader;

use super::ComputeChannel;
use crate::server::{ComputeServer, Handle};
use crate::server::{Binding, ComputeServer, Handle};

/// Create a channel using the [multi-producer, single-consumer channel](mpsc) to communicate with
/// the compute server spawn on its own thread.
Expand All @@ -33,10 +33,10 @@ enum Message<Server>
where
Server: ComputeServer,
{
Read(Handle<Server>, Callback<Reader<Vec<u8>>>),
Read(Binding<Server>, Callback<Reader<Vec<u8>>>),
Create(Vec<u8>, Callback<Handle<Server>>),
Empty(usize, Callback<Handle<Server>>),
ExecuteKernel(Server::Kernel, Vec<Handle<Server>>),
ExecuteKernel(Server::Kernel, Vec<Binding<Server>>),
Sync(Callback<()>),
}

Expand All @@ -51,9 +51,8 @@ where
let _handle = thread::spawn(move || {
while let Ok(message) = receiver.recv() {
match message {
Message::Read(handle, callback) => {
let data = server.read(&handle);
core::mem::drop(handle);
Message::Read(binding, callback) => {
let data = server.read(binding);
callback.send(data).unwrap();
}
Message::Create(data, callback) => {
Expand All @@ -64,8 +63,8 @@ where
let handle = server.empty(size);
callback.send(handle).unwrap();
}
Message::ExecuteKernel(kernel, handles) => {
server.execute(kernel, &handles.iter().collect::<Vec<_>>());
Message::ExecuteKernel(kernel, bindings) => {
server.execute(kernel, bindings);
}
Message::Sync(callback) => {
server.sync();
Expand Down Expand Up @@ -93,12 +92,12 @@ impl<Server> ComputeChannel<Server> for MpscComputeChannel<Server>
where
Server: ComputeServer + 'static,
{
fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>> {
fn read(&self, binding: Binding<Server>) -> Reader<Vec<u8>> {
let (callback, response) = mpsc::channel();

self.state
.sender
.send(Message::Read(handle.clone(), callback))
.send(Message::Read(binding, callback))
.unwrap();

self.response(response)
Expand Down Expand Up @@ -126,16 +125,10 @@ where
self.response(response)
}

fn execute(&self, kernel: Server::Kernel, handles: &[&Handle<Server>]) {
fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>) {
self.state
.sender
.send(Message::ExecuteKernel(
kernel,
handles
.iter()
.map(|h| (*h).clone())
.collect::<Vec<Handle<Server>>>(),
))
.send(Message::ExecuteKernel(kernel, bindings))
.unwrap()
}

Expand Down
6 changes: 3 additions & 3 deletions crates/burn-compute/src/channel/mutex.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::ComputeChannel;
use crate::server::{ComputeServer, Handle};
use crate::server::{Binding, ComputeServer, Handle};
use alloc::sync::Arc;
use alloc::vec::Vec;
use burn_common::reader::Reader;
Expand Down Expand Up @@ -35,7 +35,7 @@ impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
where
Server: ComputeServer,
{
fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>> {
fn read(&self, handle: Binding<Server>) -> Reader<Vec<u8>> {
self.server.lock().read(handle)
}

Expand All @@ -47,7 +47,7 @@ where
self.server.lock().empty(size)
}

fn execute(&self, kernel: Server::Kernel, handles: &[&Handle<Server>]) {
fn execute(&self, kernel: Server::Kernel, handles: Vec<Binding<Server>>) {
self.server.lock().execute(kernel, handles)
}

Expand Down
14 changes: 7 additions & 7 deletions crates/burn-compute/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
channel::ComputeChannel,
server::{ComputeServer, Handle},
server::{Binding, ComputeServer, Handle},
tune::{AutotuneOperationSet, Tuner},
};
use alloc::vec::Vec;
Expand Down Expand Up @@ -39,9 +39,9 @@ where
Self { channel, tuner }
}

/// Given a handle, returns owned resource as bytes.
pub fn read(&self, handle: &Handle<Server>) -> Reader<Vec<u8>> {
self.channel.read(handle)
/// Given a binding, returns owned resource as bytes.
pub fn read(&self, binding: Binding<Server>) -> Reader<Vec<u8>> {
self.channel.read(binding)
}

/// Given a resource, stores it and returns the resource handle.
Expand All @@ -54,9 +54,9 @@ where
self.channel.empty(size)
}

/// Executes the `kernel` over the given `handles`.
pub fn execute(&self, kernel: Server::Kernel, handles: &[&Handle<Server>]) {
self.channel.execute(kernel, handles)
/// Executes the `kernel` over the given `bindings`.
pub fn execute(&self, kernel: Server::Kernel, bindings: Vec<Binding<Server>>) {
self.channel.execute(kernel, bindings)
}

/// Wait for the completion of every task in the server.
Expand Down
140 changes: 128 additions & 12 deletions crates/burn-compute/src/id.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
use alloc::sync::Arc;

#[macro_export(local_inner_macros)]
/// Create a new storage ID type.
macro_rules! storage_id_type {
($name:ident) => {
#[derive(Clone, Hash, PartialEq, Eq)]
/// Storage ID.
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct $name {
id: alloc::sync::Arc<alloc::string::String>,
value: usize,
}

impl $name {
/// Create a new ID.
pub fn new() -> Self {
Self {
id: alloc::sync::Arc::new(burn_common::id::IdGenerator::generate()),
use core::sync::atomic::{AtomicUsize, Ordering};

static COUNTER: AtomicUsize = AtomicUsize::new(0);

let value = COUNTER.fetch_add(1, Ordering::Relaxed);
if value == usize::MAX {
core::panic!("Memory ID overflowed");
}
Self { value }
}
}

Expand All @@ -25,26 +33,134 @@ macro_rules! storage_id_type {
};
}

/// Reference to a buffer handle.
#[derive(Clone, Debug)]
pub struct HandleRef<Id> {
id: Arc<Id>,
all: Arc<()>,
}

/// Reference to buffer binding.
#[derive(Clone, Debug)]
pub struct BindingRef<Id> {
id: Id,
_all: Arc<()>,
}

impl<Id> BindingRef<Id>
where
Id: Clone + core::fmt::Debug,
{
/// The id associated to the buffer.
pub(crate) fn id(&self) -> &Id {
&self.id
}
}

impl<Id> HandleRef<Id>
where
Id: Clone + core::fmt::Debug,
{
/// Create a new handle.
pub(crate) fn new(id: Id) -> Self {
Self {
id: Arc::new(id),
all: Arc::new(()),
}
}

/// The id associated to the handle.
pub(crate) fn id(&self) -> &Id {
&self.id
}

/// Get the binding.
pub(crate) fn binding(self) -> BindingRef<Id> {
BindingRef {
id: self.id.as_ref().clone(),
_all: self.all,
}
}

/// If the handle can be mut.
pub(crate) fn can_mut(&self) -> bool {
// 1 memory management reference with 1 tensor reference.
Arc::strong_count(&self.id) <= 2
}

/// If the resource is free.
pub(crate) fn is_free(&self) -> bool {
Arc::strong_count(&self.all) <= 1
}
}

#[macro_export(local_inner_macros)]
/// Create a new memory ID type.
/// Create new memory ID types.
macro_rules! memory_id_type {
($name:ident) => {
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
($id:ident, $handle:ident, $binding:ident) => {
/// Memory Handle.
#[derive(Clone, Debug)]
pub struct $handle {
value: $crate::id::HandleRef<$id>,
}

/// Binding of a memory handle.
#[derive(Clone, Debug)]
pub struct $binding {
value: $crate::id::BindingRef<$id>,
}

/// Memory ID.
pub struct $name {
id: alloc::sync::Arc<alloc::string::String>,
#[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
pub struct $id {
value: usize,
}

impl $name {
impl $handle {
/// Create a new ID.
pub(crate) fn new() -> Self {
let value = Self::gen_id();
Self {
id: alloc::sync::Arc::new(burn_common::id::IdGenerator::generate()),
value: $crate::id::HandleRef::new($id { value }),
}
}

pub(crate) fn binding(self) -> $binding {
$binding {
value: self.value.binding(),
}
}

fn gen_id() -> usize {
static COUNTER: core::sync::atomic::AtomicUsize =
core::sync::atomic::AtomicUsize::new(0);

let value = COUNTER.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
if value == usize::MAX {
core::panic!("Memory ID overflowed");
}

value
}
}

impl Default for $name {
impl core::ops::Deref for $handle {
type Target = $crate::id::HandleRef<$id>;

fn deref(&self) -> &Self::Target {
&self.value
}
}

impl core::ops::Deref for $binding {
type Target = $crate::id::BindingRef<$id>;

fn deref(&self) -> &Self::Target {
&self.value
}
}

impl Default for $handle {
fn default() -> Self {
Self::new()
}
Expand Down
Loading

0 comments on commit 886a1de

Please sign in to comment.