@@ -5,7 +5,7 @@ use axum::{
55 routing:: get,
66 Extension , Router ,
77} ;
8- use collab:: api:: billing:: poll_stripe_events_periodically;
8+ use collab:: { api:: billing:: poll_stripe_events_periodically, llm :: LlmState , ServiceMode } ;
99use collab:: {
1010 api:: fetch_extensions_from_blob_store_periodically, db, env, executor:: Executor ,
1111 rpc:: ResultExt , AppState , Config , RateLimiter , Result ,
@@ -56,88 +56,99 @@ async fn main() -> Result<()> {
5656 collab:: seed:: seed ( & config, & db, true ) . await ?;
5757 }
5858 Some ( "serve" ) => {
59- let ( is_api, is_collab) = if let Some ( next) = args. next ( ) {
60- ( next == "api" , next == "collab" )
61- } else {
62- ( true , true )
59+ let mode = match args. next ( ) . as_deref ( ) {
60+ Some ( "collab" ) => ServiceMode :: Collab ,
61+ Some ( "api" ) => ServiceMode :: Api ,
62+ Some ( "llm" ) => ServiceMode :: Llm ,
63+ Some ( "all" ) => ServiceMode :: All ,
64+ _ => {
65+ return Err ( anyhow ! (
66+ "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
67+ ) ) ?;
68+ }
6369 } ;
64- if !is_api && !is_collab {
65- Err ( anyhow ! (
66- "usage: collab <version | migrate | seed | serve [api|collab]>"
67- ) ) ?;
68- }
6970
7071 let config = envy:: from_env :: < Config > ( ) . expect ( "error loading config" ) ;
7172 init_tracing ( & config) ;
73+ let mut app = Router :: new ( )
74+ . route ( "/" , get ( handle_root) )
75+ . route ( "/healthz" , get ( handle_liveness_probe) )
76+ . layer ( Extension ( mode) ) ;
7277
73- run_migrations ( & config) . await ?;
74-
75- let state = AppState :: new ( config, Executor :: Production ) . await ?;
76-
77- let listener = TcpListener :: bind ( & format ! ( "0.0.0.0:{}" , state. config. http_port) )
78+ let listener = TcpListener :: bind ( & format ! ( "0.0.0.0:{}" , config. http_port) )
7879 . expect ( "failed to bind TCP listener" ) ;
7980
80- let rpc_server = if is_collab {
81- let epoch = state
82- . db
83- . create_server ( & state. config . zed_environment )
84- . await ?;
85- let rpc_server = collab:: rpc:: Server :: new ( epoch, state. clone ( ) ) ;
86- rpc_server. start ( ) . await ?;
87-
88- Some ( rpc_server)
89- } else {
90- None
91- } ;
81+ let mut on_shutdown = None ;
9282
93- if is_collab {
94- state. db . purge_old_embeddings ( ) . await . trace_err ( ) ;
95- RateLimiter :: save_periodically ( state. rate_limiter . clone ( ) , state. executor . clone ( ) ) ;
96- }
83+ if mode. is_llm ( ) {
84+ let state = LlmState :: new ( config. clone ( ) , Executor :: Production ) . await ?;
9785
98- if is_api {
99- poll_stripe_events_periodically ( state. clone ( ) ) ;
100- fetch_extensions_from_blob_store_periodically ( state. clone ( ) ) ;
86+ app = app. layer ( Extension ( state. clone ( ) ) ) ;
10187 }
10288
103- let mut app = collab:: api:: routes ( rpc_server. clone ( ) , state. clone ( ) ) ;
104- if let Some ( rpc_server) = rpc_server. clone ( ) {
105- app = app. merge ( collab:: rpc:: routes ( rpc_server) )
106- }
107- app = app
108- . merge (
109- Router :: new ( )
110- . route ( "/" , get ( handle_root) )
111- . route ( "/healthz" , get ( handle_liveness_probe) )
112- . merge ( collab:: api:: extensions:: router ( ) )
89+ if mode. is_collab ( ) || mode. is_api ( ) {
90+ run_migrations ( & config) . await ?;
91+
92+ let state = AppState :: new ( config, Executor :: Production ) . await ?;
93+
94+ if mode. is_collab ( ) {
95+ state. db . purge_old_embeddings ( ) . await . trace_err ( ) ;
96+ RateLimiter :: save_periodically (
97+ state. rate_limiter . clone ( ) ,
98+ state. executor . clone ( ) ,
99+ ) ;
100+
101+ let epoch = state
102+ . db
103+ . create_server ( & state. config . zed_environment )
104+ . await ?;
105+ let rpc_server = collab:: rpc:: Server :: new ( epoch, state. clone ( ) ) ;
106+ rpc_server. start ( ) . await ?;
107+
108+ app = app
109+ . merge ( collab:: api:: routes ( rpc_server. clone ( ) ) )
110+ . merge ( collab:: rpc:: routes ( rpc_server. clone ( ) ) ) ;
111+
112+ on_shutdown = Some ( Box :: new ( move || rpc_server. teardown ( ) ) ) ;
113+ }
114+
115+ if mode. is_api ( ) {
116+ poll_stripe_events_periodically ( state. clone ( ) ) ;
117+ fetch_extensions_from_blob_store_periodically ( state. clone ( ) ) ;
118+
119+ app = app
113120 . merge ( collab:: api:: events:: router ( ) )
114- . layer ( Extension ( state. clone ( ) ) ) ,
115- )
116- . layer (
117- TraceLayer :: new_for_http ( )
118- . make_span_with ( |request : & Request < _ > | {
119- let matched_path = request
120- . extensions ( )
121- . get :: < MatchedPath > ( )
122- . map ( MatchedPath :: as_str) ;
123-
124- tracing:: info_span!(
125- "http_request" ,
126- method = ?request. method( ) ,
127- matched_path,
128- )
129- } )
130- . on_response (
131- |response : & Response < _ > , latency : Duration , _: & tracing:: Span | {
132- let duration_ms = latency. as_micros ( ) as f64 / 1000. ;
133- tracing:: info!(
134- duration_ms,
135- status = response. status( ) . as_u16( ) ,
136- "finished processing request"
137- ) ;
138- } ,
139- ) ,
140- ) ;
121+ . merge ( collab:: api:: extensions:: router ( ) )
122+ }
123+
124+ app = app. layer ( Extension ( state. clone ( ) ) ) ;
125+ }
126+
127+ app = app. layer (
128+ TraceLayer :: new_for_http ( )
129+ . make_span_with ( |request : & Request < _ > | {
130+ let matched_path = request
131+ . extensions ( )
132+ . get :: < MatchedPath > ( )
133+ . map ( MatchedPath :: as_str) ;
134+
135+ tracing:: info_span!(
136+ "http_request" ,
137+ method = ?request. method( ) ,
138+ matched_path,
139+ )
140+ } )
141+ . on_response (
142+ |response : & Response < _ > , latency : Duration , _: & tracing:: Span | {
143+ let duration_ms = latency. as_micros ( ) as f64 / 1000. ;
144+ tracing:: info!(
145+ duration_ms,
146+ status = response. status( ) . as_u16( ) ,
147+ "finished processing request"
148+ ) ;
149+ } ,
150+ ) ,
151+ ) ;
141152
142153 #[ cfg( unix) ]
143154 let signal = async move {
@@ -174,16 +185,16 @@ async fn main() -> Result<()> {
174185 signal. await ;
175186 tracing:: info!( "Received interrupt signal" ) ;
176187
177- if let Some ( rpc_server ) = rpc_server {
178- rpc_server . teardown ( ) ;
188+ if let Some ( on_shutdown ) = on_shutdown {
189+ on_shutdown ( ) ;
179190 }
180191 } )
181192 . await
182193 . map_err ( |e| anyhow ! ( e) ) ?;
183194 }
184195 _ => {
185196 Err ( anyhow ! (
186- "usage: collab <version | migrate | seed | serve [ api|collab] >"
197+ "usage: collab <version | migrate | seed | serve < api|collab|llm|all> >"
187198 ) ) ?;
188199 }
189200 }
@@ -222,12 +233,23 @@ async fn run_migrations(config: &Config) -> Result<()> {
222233 return Ok ( ( ) ) ;
223234}
224235
225- async fn handle_root ( ) -> String {
226- format ! ( "collab v{} ({})" , VERSION , REVISION . unwrap_or( "unknown" ) )
236+ async fn handle_root ( Extension ( mode) : Extension < ServiceMode > ) -> String {
237+ format ! (
238+ "collab {mode:?} v{VERSION} ({})" ,
239+ REVISION . unwrap_or( "unknown" )
240+ )
227241}
228242
229- async fn handle_liveness_probe ( Extension ( state) : Extension < Arc < AppState > > ) -> Result < String > {
230- state. db . get_all_users ( 0 , 1 ) . await ?;
243+ async fn handle_liveness_probe (
244+ app_state : Option < Extension < Arc < AppState > > > ,
245+ llm_state : Option < Extension < Arc < LlmState > > > ,
246+ ) -> Result < String > {
247+ if let Some ( state) = app_state {
248+ state. db . get_all_users ( 0 , 1 ) . await ?;
249+ }
250+
251+ if let Some ( _llm_state) = llm_state { }
252+
231253 Ok ( "ok" . to_string ( ) )
232254}
233255
0 commit comments