Skip to content

Commit

Permalink
sticky node ids #16
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Jul 31, 2024
1 parent 980d5d2 commit 0bfb8e3
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 3 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ __pycache__/
.venv
test_weights.npz
.exo_used_ports
.exo_node_id
.idea

# Byte-compiled / optimized / DLL files
Expand Down Expand Up @@ -166,4 +167,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
#.idea/
34 changes: 34 additions & 0 deletions exo/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import random
import platform
import psutil
import uuid
from pathlib import Path

DEBUG = int(os.getenv("DEBUG", default="0"))
DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
Expand Down Expand Up @@ -167,3 +169,35 @@ def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
return None

return max(matches, key=lambda x: len(x[0]))

def is_valid_uuid(val):
try:
uuid.UUID(str(val))
return True
except ValueError:
return False

def get_or_create_node_id():
NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__))) / ".exo_node_id"
try:
if NODE_ID_FILE.is_file():
with open(NODE_ID_FILE, "r") as f:
stored_id = f.read().strip()
if is_valid_uuid(stored_id):
if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}")
return stored_id
else:
if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.")

new_id = str(uuid.uuid4())
with open(NODE_ID_FILE, "w") as f:
f.write(new_id)

if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}")
return new_id
except IOError as e:
if DEBUG >= 2: print(f"IO error creating node_id: {e}")
return str(uuid.uuid4())
except Exception as e:
if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
return str(uuid.uuid4())
5 changes: 3 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from exo.networking.grpc.grpc_discovery import GRPCDiscovery
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from exo.api import ChatGPTAPI
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id

# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
parser.add_argument("--node-id", type=str, default=str(uuid.uuid4()), help="Node ID")
parser.add_argument("--node-id", type=str, default=None, help="Node ID")
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
parser.add_argument("--node-port", type=int, default=None, help="Node port")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
Expand Down Expand Up @@ -40,6 +40,7 @@
args.node_port = find_available_port(args.node_host)
if DEBUG >= 1: print(f"Using available port: {args.node_port}")

args.node_id = args.node_id or get_or_create_node_id()
discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
node = StandardNode(
args.node_id,
Expand Down

0 comments on commit 0bfb8e3

Please sign in to comment.