Skip to content

Commit

Permalink
Config: Add override argparser
Browse files Browse the repository at this point in the history
Add an argparser that casts over to dictionaries of subgroups to
integrate with the config.

This argparser doesn't contain everything in the config due to complexity
issues with CLI args, but will eventually progress to parity. In addition,
it's used to override the config.yml rather than replace it.

A config arg is also provided if the user wants to fully override the
config yaml with another file path.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
kingbri1 committed Jan 1, 2024
1 parent 7176fa6 commit bb7a8e4
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 9 deletions.
122 changes: 122 additions & 0 deletions args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""Argparser for overriding config values"""
import argparse


def str_to_bool(value):
"""Converts a string into a boolean value"""

if value.lower() in {"false", "f", "0", "no", "n"}:
return False
elif value.lower() in {"true", "t", "1", "yes", "y"}:
return True
raise ValueError(f"{value} is not a valid boolean value")


def init_argparser():
"""Creates an argument parser that any function can use"""

parser = argparse.ArgumentParser(
epilog="These args are only for a subset of the config. "
+ "Please edit config.yml for all options!"
)
add_network_args(parser)
add_model_args(parser)
add_logging_args(parser)
add_config_args(parser)

return parser


def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser):
"""Broad conversion of surface level arg groups to dictionaries"""

arg_groups = {}
for group in parser._action_groups:
group_dict = {}
for arg in group._group_actions:
value = getattr(args, arg.dest, None)
if value is not None:
group_dict[arg.dest] = value

arg_groups[group.title] = group_dict

return arg_groups


def add_config_args(parser: argparse.ArgumentParser):
"""Adds config arguments"""

parser.add_argument(
"--config", type=str, help="Path to an overriding config.yml file"
)


def add_network_args(parser: argparse.ArgumentParser):
"""Adds networking arguments"""

network_group = parser.add_argument_group("network")
network_group.add_argument("--host", type=str, help="The IP to host on")
network_group.add_argument("--port", type=int, help="The port to host on")
network_group.add_argument(
"--disable-auth",
type=str_to_bool,
help="Disable HTTP token authenticaion with requests",
)


def add_model_args(parser: argparse.ArgumentParser):
"""Adds model arguments"""

model_group = parser.add_argument_group("model")
model_group.add_argument(
"--model-dir", type=str, help="Overrides the directory to look for models"
)
model_group.add_argument("--model-name", type=str, help="An initial model to load")
model_group.add_argument(
"--max-seq-len", type=int, help="Override the maximum model sequence length"
)
model_group.add_argument(
"--override-base-seq-len",
type=str_to_bool,
help="Overrides base model context length",
)
model_group.add_argument(
"--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb"
)
model_group.add_argument("--rope-alpha", type=float, help="Sets rope_alpha for NTK")
model_group.add_argument(
"--prompt-template",
type=str,
help="Set the prompt template for chat completions",
)
model_group.add_argument(
"--gpu-split-auto",
type=str_to_bool,
help="Automatically allocate resources to GPUs",
)
model_group.add_argument(
"--gpu-split",
type=float,
nargs="+",
help="An integer array of GBs of vram to split between GPUs. "
+ "Ignored if gpu_split_auto is true",
)
model_group.add_argument(
"--num-experts-per-token",
type=int,
help="Number of experts to use per token in MoE models",
)


def add_logging_args(parser: argparse.ArgumentParser):
"""Adds logging arguments"""

logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
"--log-prompt", type=str_to_bool, help="Enable prompt logging"
)
logging_group.add_argument(
"--log-generation-params",
type=str_to_bool,
help="Enable generation parameter logging",
)
31 changes: 31 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,37 @@ def read_config_from_file(config_path: pathlib.Path):
GLOBAL_CONFIG = {}


def override_config_from_args(args: dict):
"""Overrides the config based on a dict representation of args"""

config_override = unwrap(args.get("options", {}).get("config"))
if config_override:
logger.info("Attempting to override config.yml from args.")
read_config_from_file(pathlib.Path(config_override))
return

# Network config
network_override = args.get("network")
if network_override:
network_config = get_network_config()
GLOBAL_CONFIG["network"] = {**network_config, **network_override}

# Model config
model_override = args.get("model")
if model_override:
model_config = get_model_config()
GLOBAL_CONFIG["model"] = {**model_config, **model_override}

# Logging config
logging_override = args.get("logging")
if logging_override:
logging_config = get_gen_logging_config()
GLOBAL_CONFIG["logging"] = {
**logging_config,
**{k.replace("log_", ""): logging_override[k] for k in logging_override},
}


def get_model_config():
"""Returns the model config from the global config"""
return unwrap(GLOBAL_CONFIG.get("model"), {})
Expand Down
11 changes: 10 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
from progress.bar import IncrementalBar

import gen_logging
from args import convert_args_to_dict, init_argparser
from auth import check_admin_key, check_api_key, load_auth_keys
from config import (
override_config_from_args,
read_config_from_file,
get_gen_logging_config,
get_model_config,
Expand Down Expand Up @@ -493,13 +495,20 @@ async def generator():
return response


def entrypoint():
def entrypoint(args: Optional[dict] = None):
"""Entry function for program startup"""
global MODEL_CONTAINER

# Load from YAML config
read_config_from_file(pathlib.Path("config.yml"))

# Parse and override config from args
if args is None:
parser = init_argparser()
args = convert_args_to_dict(parser.parse_args(), parser)

override_config_from_args(args)

network_config = get_network_config()

# Initialize auth keys
Expand Down
18 changes: 10 additions & 8 deletions start.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import pathlib
import subprocess
from args import convert_args_to_dict, init_argparser


def get_requirements_file():
Expand All @@ -24,28 +25,29 @@ def get_requirements_file():
return requirements_name


def get_argparser():
"""Fetches the argparser for this script"""
parser = argparse.ArgumentParser()
parser.add_argument(
def add_start_args(parser: argparse.ArgumentParser):
"""Add start script args to the provided parser"""
start_group = parser.add_argument_group("start")
start_group.add_argument(
"-iu",
"--ignore-upgrade",
action="store_true",
help="Ignore requirements upgrade",
)
parser.add_argument(
start_group.add_argument(
"-nw",
"--nowheel",
action="store_true",
help="Don't upgrade wheel dependencies (exllamav2, torch)",
)
return parser


if __name__ == "__main__":
subprocess.run(["pip", "-V"])

parser = get_argparser()
# Create an argparser and add extra startup script args
parser = init_argparser()
add_start_args(parser)
args = parser.parse_args()

if args.ignore_upgrade:
Expand All @@ -59,4 +61,4 @@ def get_argparser():
# Import entrypoint after installing all requirements
from main import entrypoint

entrypoint()
entrypoint(convert_args_to_dict(args, parser))

0 comments on commit bb7a8e4

Please sign in to comment.