Skip to content

Commit 27779e3

Browse files
maxbrunsfeldmaxdeviantjvmncs
authored
Refactor: Restructure collab main function to prepare for new subcommand: serve llm (zed-industries#15824)
This is just a refactor that we're landing ahead of any functional changes to make sure we haven't broken anything. Release Notes: - N/A Co-authored-by: Marshall <[email protected]> Co-authored-by: Jason <[email protected]>
1 parent 705f7e7 commit 27779e3

File tree

5 files changed

+144
-87
lines changed

5 files changed

+144
-87
lines changed

Procfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve
1+
collab: RUST_LOG=${RUST_LOG:-info} cargo run --package=collab serve all
22
livekit: livekit-server --dev
33
blob_store: ./script/run-local-minio

crates/collab/src/api.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ impl std::fmt::Display for CloudflareIpCountryHeader {
6161
}
6262
}
6363

64-
pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Router<(), Body> {
64+
pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
6565
Router::new()
6666
.route("/user", get(get_authenticated_user))
6767
.route("/users/:id/access_tokens", post(create_access_token))
@@ -70,7 +70,6 @@ pub fn routes(rpc_server: Option<Arc<rpc::Server>>, state: Arc<AppState>) -> Rou
7070
.merge(contributors::router())
7171
.layer(
7272
ServiceBuilder::new()
73-
.layer(Extension(state))
7473
.layer(Extension(rpc_server))
7574
.layer(middleware::from_fn(validate_api_token)),
7675
)
@@ -152,12 +151,8 @@ struct CreateUserParams {
152151
}
153152

154153
async fn get_rpc_server_snapshot(
155-
Extension(rpc_server): Extension<Option<Arc<rpc::Server>>>,
154+
Extension(rpc_server): Extension<Arc<rpc::Server>>,
156155
) -> Result<ErasedJson> {
157-
let Some(rpc_server) = rpc_server else {
158-
return Err(Error::Internal(anyhow!("rpc server is not available")));
159-
};
160-
161156
Ok(ErasedJson::pretty(rpc_server.snapshot().await))
162157
}
163158

crates/collab/src/lib.rs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ pub mod auth;
33
pub mod db;
44
pub mod env;
55
pub mod executor;
6+
pub mod llm;
67
mod rate_limiter;
78
pub mod rpc;
89
pub mod seed;
@@ -124,7 +125,7 @@ impl std::fmt::Display for Error {
124125

125126
impl std::error::Error for Error {}
126127

127-
#[derive(Deserialize)]
128+
#[derive(Clone, Deserialize)]
128129
pub struct Config {
129130
pub http_port: u16,
130131
pub database_url: String,
@@ -176,6 +177,29 @@ impl Config {
176177
}
177178
}
178179

180+
/// The service mode that collab should run in.
181+
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
182+
pub enum ServiceMode {
183+
Api,
184+
Collab,
185+
Llm,
186+
All,
187+
}
188+
189+
impl ServiceMode {
190+
pub fn is_collab(&self) -> bool {
191+
matches!(self, Self::Collab | Self::All)
192+
}
193+
194+
pub fn is_api(&self) -> bool {
195+
matches!(self, Self::Api | Self::All)
196+
}
197+
198+
pub fn is_llm(&self) -> bool {
199+
matches!(self, Self::Llm | Self::All)
200+
}
201+
}
202+
179203
pub struct AppState {
180204
pub db: Arc<Database>,
181205
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,

crates/collab/src/llm.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
use std::sync::Arc;
2+
3+
use crate::{executor::Executor, Config, Result};
4+
5+
pub struct LlmState {
6+
pub config: Config,
7+
pub executor: Executor,
8+
}
9+
10+
impl LlmState {
11+
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
12+
let this = Self { config, executor };
13+
14+
Ok(Arc::new(this))
15+
}
16+
}

crates/collab/src/main.rs

Lines changed: 100 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use axum::{
55
routing::get,
66
Extension, Router,
77
};
8-
use collab::api::billing::poll_stripe_events_periodically;
8+
use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode};
99
use collab::{
1010
api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
1111
rpc::ResultExt, AppState, Config, RateLimiter, Result,
@@ -56,88 +56,99 @@ async fn main() -> Result<()> {
5656
collab::seed::seed(&config, &db, true).await?;
5757
}
5858
Some("serve") => {
59-
let (is_api, is_collab) = if let Some(next) = args.next() {
60-
(next == "api", next == "collab")
61-
} else {
62-
(true, true)
59+
let mode = match args.next().as_deref() {
60+
Some("collab") => ServiceMode::Collab,
61+
Some("api") => ServiceMode::Api,
62+
Some("llm") => ServiceMode::Llm,
63+
Some("all") => ServiceMode::All,
64+
_ => {
65+
return Err(anyhow!(
66+
"usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
67+
))?;
68+
}
6369
};
64-
if !is_api && !is_collab {
65-
Err(anyhow!(
66-
"usage: collab <version | migrate | seed | serve [api|collab]>"
67-
))?;
68-
}
6970

7071
let config = envy::from_env::<Config>().expect("error loading config");
7172
init_tracing(&config);
73+
let mut app = Router::new()
74+
.route("/", get(handle_root))
75+
.route("/healthz", get(handle_liveness_probe))
76+
.layer(Extension(mode));
7277

73-
run_migrations(&config).await?;
74-
75-
let state = AppState::new(config, Executor::Production).await?;
76-
77-
let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
78+
let listener = TcpListener::bind(&format!("0.0.0.0:{}", config.http_port))
7879
.expect("failed to bind TCP listener");
7980

80-
let rpc_server = if is_collab {
81-
let epoch = state
82-
.db
83-
.create_server(&state.config.zed_environment)
84-
.await?;
85-
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
86-
rpc_server.start().await?;
87-
88-
Some(rpc_server)
89-
} else {
90-
None
91-
};
81+
let mut on_shutdown = None;
9282

93-
if is_collab {
94-
state.db.purge_old_embeddings().await.trace_err();
95-
RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone());
96-
}
83+
if mode.is_llm() {
84+
let state = LlmState::new(config.clone(), Executor::Production).await?;
9785

98-
if is_api {
99-
poll_stripe_events_periodically(state.clone());
100-
fetch_extensions_from_blob_store_periodically(state.clone());
86+
app = app.layer(Extension(state.clone()));
10187
}
10288

103-
let mut app = collab::api::routes(rpc_server.clone(), state.clone());
104-
if let Some(rpc_server) = rpc_server.clone() {
105-
app = app.merge(collab::rpc::routes(rpc_server))
106-
}
107-
app = app
108-
.merge(
109-
Router::new()
110-
.route("/", get(handle_root))
111-
.route("/healthz", get(handle_liveness_probe))
112-
.merge(collab::api::extensions::router())
89+
if mode.is_collab() || mode.is_api() {
90+
run_migrations(&config).await?;
91+
92+
let state = AppState::new(config, Executor::Production).await?;
93+
94+
if mode.is_collab() {
95+
state.db.purge_old_embeddings().await.trace_err();
96+
RateLimiter::save_periodically(
97+
state.rate_limiter.clone(),
98+
state.executor.clone(),
99+
);
100+
101+
let epoch = state
102+
.db
103+
.create_server(&state.config.zed_environment)
104+
.await?;
105+
let rpc_server = collab::rpc::Server::new(epoch, state.clone());
106+
rpc_server.start().await?;
107+
108+
app = app
109+
.merge(collab::api::routes(rpc_server.clone()))
110+
.merge(collab::rpc::routes(rpc_server.clone()));
111+
112+
on_shutdown = Some(Box::new(move || rpc_server.teardown()));
113+
}
114+
115+
if mode.is_api() {
116+
poll_stripe_events_periodically(state.clone());
117+
fetch_extensions_from_blob_store_periodically(state.clone());
118+
119+
app = app
113120
.merge(collab::api::events::router())
114-
.layer(Extension(state.clone())),
115-
)
116-
.layer(
117-
TraceLayer::new_for_http()
118-
.make_span_with(|request: &Request<_>| {
119-
let matched_path = request
120-
.extensions()
121-
.get::<MatchedPath>()
122-
.map(MatchedPath::as_str);
123-
124-
tracing::info_span!(
125-
"http_request",
126-
method = ?request.method(),
127-
matched_path,
128-
)
129-
})
130-
.on_response(
131-
|response: &Response<_>, latency: Duration, _: &tracing::Span| {
132-
let duration_ms = latency.as_micros() as f64 / 1000.;
133-
tracing::info!(
134-
duration_ms,
135-
status = response.status().as_u16(),
136-
"finished processing request"
137-
);
138-
},
139-
),
140-
);
121+
.merge(collab::api::extensions::router())
122+
}
123+
124+
app = app.layer(Extension(state.clone()));
125+
}
126+
127+
app = app.layer(
128+
TraceLayer::new_for_http()
129+
.make_span_with(|request: &Request<_>| {
130+
let matched_path = request
131+
.extensions()
132+
.get::<MatchedPath>()
133+
.map(MatchedPath::as_str);
134+
135+
tracing::info_span!(
136+
"http_request",
137+
method = ?request.method(),
138+
matched_path,
139+
)
140+
})
141+
.on_response(
142+
|response: &Response<_>, latency: Duration, _: &tracing::Span| {
143+
let duration_ms = latency.as_micros() as f64 / 1000.;
144+
tracing::info!(
145+
duration_ms,
146+
status = response.status().as_u16(),
147+
"finished processing request"
148+
);
149+
},
150+
),
151+
);
141152

142153
#[cfg(unix)]
143154
let signal = async move {
@@ -174,16 +185,16 @@ async fn main() -> Result<()> {
174185
signal.await;
175186
tracing::info!("Received interrupt signal");
176187

177-
if let Some(rpc_server) = rpc_server {
178-
rpc_server.teardown();
188+
if let Some(on_shutdown) = on_shutdown {
189+
on_shutdown();
179190
}
180191
})
181192
.await
182193
.map_err(|e| anyhow!(e))?;
183194
}
184195
_ => {
185196
Err(anyhow!(
186-
"usage: collab <version | migrate | seed | serve [api|collab]>"
197+
"usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
187198
))?;
188199
}
189200
}
@@ -222,12 +233,23 @@ async fn run_migrations(config: &Config) -> Result<()> {
222233
return Ok(());
223234
}
224235

225-
async fn handle_root() -> String {
226-
format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
236+
async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
237+
format!(
238+
"collab {mode:?} v{VERSION} ({})",
239+
REVISION.unwrap_or("unknown")
240+
)
227241
}
228242

229-
async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
230-
state.db.get_all_users(0, 1).await?;
243+
async fn handle_liveness_probe(
244+
app_state: Option<Extension<Arc<AppState>>>,
245+
llm_state: Option<Extension<Arc<LlmState>>>,
246+
) -> Result<String> {
247+
if let Some(state) = app_state {
248+
state.db.get_all_users(0, 1).await?;
249+
}
250+
251+
if let Some(_llm_state) = llm_state {}
252+
231253
Ok("ok".to_string())
232254
}
233255

0 commit comments

Comments
 (0)