From 43e0cfc573e77bf24cdf0f1003a3f4a9db1424b3 Mon Sep 17 00:00:00 2001 From: Nate Sales Date: Tue, 19 Nov 2024 16:45:40 -0500 Subject: [PATCH] fix: local listener config --- examples/nginx/Dockerfile | 2 +- examples/ollama/Dockerfile | 2 +- main.go | 16 ++++++++++++---- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/examples/nginx/Dockerfile b/examples/nginx/Dockerfile index 1065ed3..8fce967 100644 --- a/examples/nginx/Dockerfile +++ b/examples/nginx/Dockerfile @@ -4,6 +4,6 @@ FROM nginxdemos/hello COPY --from=shim /nitro-attestation-shim /nitro-attestation-shim ENV NITRO_SHIM_PORT=6000 -ENV NITRO_UPSTREAM_PORT=80 +ENV NITRO_SHIM_UPSTREAM_PORT=80 ENTRYPOINT ["/nitro-attestation-shim", "nginx", "-g", "daemon off;"] diff --git a/examples/ollama/Dockerfile b/examples/ollama/Dockerfile index aedad15..eea6e6b 100644 --- a/examples/ollama/Dockerfile +++ b/examples/ollama/Dockerfile @@ -5,7 +5,7 @@ FROM ollama/ollama COPY --from=shim /nitro-attestation-shim /nitro-attestation-shim ENV NITRO_SHIM_PORT=6000 -ENV NITRO_UPSTREAM_PORT=11434 +ENV NITRO_SHIM_UPSTREAM_PORT=11434 RUN nohup bash -c "ollama serve &" && sleep 5 && ollama pull llama3.2:1b diff --git a/main.go b/main.go index 22046a2..b01c0ff 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "fmt" "log" + "net" "net/http" "net/http/httputil" "net/url" @@ -47,16 +48,23 @@ func getInt(env string, defaultValue int) int { func main() { listenPort := getInt("NITRO_SHIM_PORT", 6000) - upstreamHost := fmt.Sprintf("localhost:%d", getInt("NITRO_UPSTREAM_PORT", 6001)) - cid := getInt("NITRO_CID", 16) + upstreamHost := fmt.Sprintf("localhost:%d", getInt("NITRO_SHIM_UPSTREAM_PORT", 6001)) + cid := getInt("NITRO_SHIM_CID", 16) + useVsock := os.Getenv("NITRO_SHIM_LOCAL") == "" if len(os.Args) < 2 { log.Fatalf("Usage: %s [args...]", os.Args[0]) } - l, err := vsock.ListenContextID(uint32(cid), uint32(listenPort), nil) + var l net.Listener + var err error + if useVsock { + l, err = vsock.ListenContextID(uint32(cid), uint32(listenPort), nil) + } else { + l, err = net.Listen("tcp", fmt.Sprintf(":%d", listenPort)) + } if err != nil { - log.Fatalf("failed vsock.Listen: %s", err) + log.Fatalf("listen: %s", err) return } defer l.Close()