Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multithreaded SocketServer #106

Closed
wants to merge 10 commits into from
16 changes: 15 additions & 1 deletion qos-core/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ pub const CID: &str = "cid";
pub const PORT: &str = "port";
/// "usock"
pub const USOCK: &str = "usock";
const MOCK: &str = "mock";
/// Name for the option to specify the quorum key file.
pub const QUORUM_FILE_OPT: &str = "quorum-file";
/// Name for the option to specify the pivot key file.
Expand All @@ -26,7 +25,11 @@ pub const PIVOT_FILE_OPT: &str = "pivot-file";
pub const EPHEMERAL_FILE_OPT: &str = "ephemeral-file";
/// Name for the option to specify the manifest file.
pub const MANIFEST_FILE_OPT: &str = "manifest-file";
/// Name for the option to specify the number of threads for the socket server's
/// thread pool.
pub const THREAD_COUNT: &str = "thread-count";
const APP_USOCK: &str = "app-usock";
const MOCK: &str = "mock";

/// CLI options for starting up the enclave server.
#[derive(Default, Clone, Debug, PartialEq)]
Expand Down Expand Up @@ -118,6 +121,12 @@ impl EnclaveOpts {
.expect("has a default value.")
.clone()
}

/// Parsed value of the `--thread-count` option, if it was supplied.
///
/// # Panics
///
/// Panics if the supplied value cannot be parsed as a `usize`.
fn thread_count(&self) -> Option<usize> {
	let raw = self.parsed.single(THREAD_COUNT)?;
	Some(raw.parse().expect("failed to parse `--thread-count`"))
}
}

/// Enclave server CLI.
Expand All @@ -143,6 +152,7 @@ impl CLI {
opts.nsm(),
opts.addr(),
opts.app_addr(),
opts.thread_count(),
);
}
}
Expand Down Expand Up @@ -198,6 +208,10 @@ impl GetParserForOptions for EnclaveParser {
.takes_value(true)
.default_value(SEC_APP_SOCK)
)
.token(
Token::new(THREAD_COUNT, "count of threads for the socket servers thread pool")
emostov marked this conversation as resolved.
Show resolved Hide resolved
.takes_value(true)
)
}
}

Expand Down
2 changes: 1 addition & 1 deletion qos-core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl From<borsh::maybestd::io::Error> for ClientError {
}

/// Client for communicating with the enclave [`crate::server::SocketServer`].
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Client {
	// Address of the enclave server socket this client talks to.
	// NOTE(review): `Clone` was presumably added so each worker thread can
	// hold its own client — confirm against the multithreaded server usage.
	addr: SocketAddress,
}
Expand Down
5 changes: 3 additions & 2 deletions qos-core/src/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ impl Coordinator {
/// - If waiting for the pivot errors.
pub fn execute(
handles: &Handles,
nsm: Box<dyn NsmProvider + Send>,
nsm: Box<dyn NsmProvider>,
addr: SocketAddress,
app_addr: SocketAddress,
thread_count: Option<usize>,
) {
let handles2 = handles.clone();
std::thread::spawn(move || {
let executor = Executor::new(nsm, handles2, app_addr);
SocketServer::listen(addr, executor).unwrap();
SocketServer::listen(addr, executor, thread_count).unwrap();
});

loop {
Expand Down
1 change: 1 addition & 0 deletions qos-core/src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
//! within this module.

mod stream;
pub mod threadpool;

pub use stream::SocketAddress;
pub(crate) use stream::{Listener, Stream};
Expand Down
1 change: 0 additions & 1 deletion qos-core/src/io/stream.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//! Abstractions to handle connection based socket streams.

use std::{mem::size_of, os::unix::io::RawFd};

#[cfg(feature = "vm")]
Expand Down
166 changes: 166 additions & 0 deletions qos-core/src/io/threadpool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
//! Simple thread pool for running concurrent jobs on separate threads.

use std::{
sync::{mpsc, Arc, Mutex},
thread,
};

/// A heap-allocated closure that can be sent to a worker thread and run once.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// Errors for a [`ThreadPool`].
#[derive(Debug)]
pub enum ThreadPoolError {
	/// Wrapper for `std::sync::mpsc::SendError<Message>`.
	MpscSendError(std::sync::mpsc::SendError<Message>),
}

impl std::fmt::Display for ThreadPoolError {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		match self {
			// Include the underlying send error so callers get the full
			// context when propagating with `?` / boxed errors.
			Self::MpscSendError(e) => {
				write!(f, "failed to send job to thread pool: {e}")
			}
		}
	}
}

// Implementing `std::error::Error` lets this interoperate with
// `Box<dyn Error>` and error-chaining in callers.
impl std::error::Error for ThreadPoolError {}

/// An abstraction for executing jobs concurrently across a fixed number of
/// threads.
pub struct ThreadPool {
	// Worker threads; each handle is joined when the pool is dropped.
	workers: Vec<Worker>,
	// Sending half of the channel used to dispatch `Message`s to workers.
	sender: mpsc::Sender<Message>,
}

/// Message sent to a worker thread in the thread pool.
pub enum Message {
	/// Start a new Job.
	NewJob(Job),
	/// Terminate the thread. Sent once per worker when the pool is dropped.
	Terminate,
}

impl ThreadPool {
/// Create a new instance of [`Self`].
///
/// # Arguments
///
/// * `size` - Number of threads in pool.
///
/// # Panics
///
/// Panics if the `size` is zero.
#[must_use]
pub fn new(size: usize) -> ThreadPool {
assert!(size > 0);

let (sender, receiver) = mpsc::channel();

let receiver = Arc::new(Mutex::new(receiver));

let mut workers = Vec::with_capacity(size);

for _ in 0..size {
workers.push(Worker::new(Arc::clone(&receiver)));
}

ThreadPool { workers, sender }
}

/// Execute `f` in the next free thread. This is non blocking.
///
/// # Errors
///
/// Returns an error if the `f` could not be sent to a worker thread.
pub fn execute<F>(&self, f: F) -> Result<(), ThreadPoolError>
where
F: FnOnce() + Send + 'static,
{
let job = Box::new(f);

self.sender
.send(Message::NewJob(job))
.map_err(ThreadPoolError::MpscSendError)?;
Ok(())
}
}

impl Drop for ThreadPool {
	fn drop(&mut self) {
		// Send 1 termination signal per worker thread. We don't know exactly
		// which worker will receive each message, but since we know that a
		// worker will stop receiving after getting the terminate message, we
		// can be confident that non-terminated threads will receive the
		// terminate message exactly once and terminated threads will never
		// receive the message. Thus, if we have N workers and send N terminate
		// messages we will terminate all worker threads.
		for _ in &self.workers {
			let _ = self
				.sender
				.send(Message::Terminate)
				.map_err(|e| eprintln!("`ThreadPool::drop`: {:?}", e));
		}

		// Join every worker. Because the channel is FIFO, each worker drains
		// any jobs queued ahead of its terminate message before exiting, so
		// drop blocks until all previously submitted work has run.
		for worker in &mut self.workers {
			if let Some(thread) = worker.thread.take() {
				let _ = thread.join().map_err(|e| {
					eprintln!("`ThreadPool::drop: failed to join: {:?}`", e);
				});
			}
		}
	}
}

// A single worker thread in the pool. `thread` holds `Some(handle)` until the
// pool's `Drop` impl `take`s the handle in order to join it.
struct Worker {
	thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
	/// Spawn a worker thread that loops pulling [`Message`]s off the shared
	/// receiver: a `NewJob` is executed, a `Terminate` breaks the loop.
	fn new(receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Worker {
		let thread = thread::spawn(move || loop {
			// The mutex guard is a temporary that is dropped at the end of
			// this statement, so the lock is NOT held while the job runs —
			// only while this worker waits to receive the next message.
			let message = receiver
				.lock()
				.expect("channel receiver mutex poisoned")
				// Fix: panic message previously said "chanel".
				.recv()
				.expect("tried to receive on a closed channel");

			match message {
				Message::NewJob(job) => {
					job();
				}
				Message::Terminate => {
					break;
				}
			}
		});

		Worker { thread: Some(thread) }
	}
}

#[cfg(test)]
mod test {
	use std::{
		collections::HashMap,
		sync::{Arc, Mutex},
	};

	use super::ThreadPool;

	#[test]
	fn graceful_shutdown_works() {
		const KEY: &str = "key";
		const EXECUTIONS: usize = 500;

		let mut db = HashMap::new();
		db.insert(KEY, 0);

		let db = Arc::new(Mutex::new(db));

		// Create a job that increments the counter stored under `KEY` and
		// queue it `EXECUTIONS` times on the pool.
		let thread_pool = ThreadPool::new(128);

		for _ in 0..EXECUTIONS {
			let db2 = db.clone();
			thread_pool
				.execute(move || {
					*db2.lock().unwrap().get_mut(KEY).unwrap() += 1;
				})
				.unwrap();
		}

		// Graceful shutdown: dropping the pool joins every worker, so all
		// queued jobs have run by the time we assert.
		drop(thread_pool);

		assert_eq!(*db.lock().unwrap().get(KEY).unwrap(), EXECUTIONS);
	}
}
1 change: 1 addition & 0 deletions qos-core/src/protocol/attestor/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub const MOCK_NSM_ATTESTATION_DOCUMENT: &[u8] =
include_bytes!("./static/mock_attestation_doc");

/// Mock Nitro Secure Module endpoint that should only ever be used for testing.
// NOTE(review): `Clone` appears to have been added so the mock NSM can be
// shared across the multithreaded socket server's worker threads — confirm
// against `SocketServer::listen` / `Executor` usage.
#[derive(Clone)]
pub struct MockNsm;
impl NsmProvider for MockNsm {
fn nsm_process_request(
Expand Down
2 changes: 1 addition & 1 deletion qos-core/src/protocol/attestor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub mod types;
/// generic so mock providers can be subbed in for testing. In production use
/// [`Nsm`].
// https://github.com/aws/aws-nitro-enclaves-nsm-api/blob/main/docs/attestation_process.md
pub trait NsmProvider {
pub trait NsmProvider: Send + Sync {
/// Create a message with input data and output capacity from a given
/// request, then send it to the NSM driver via `ioctl()` and wait
/// for the driver's response.
Expand Down
Loading