-
-
Notifications
You must be signed in to change notification settings - Fork 95
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add an argparser that casts over to dictionaries of subgroups to integrate with the config. This argparser doesn't contain everything in the config due to complexity issues with CLI args, but will eventually progress to parity. In addition, it's used to override the config.yml rather than replace it. A config arg is also provided if the user wants to fully override the config yaml with another file path. Signed-off-by: kingbri <[email protected]>
- Loading branch information
Showing
4 changed files
with
173 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
"""Argparser for overriding config values""" | ||
import argparse | ||
|
||
|
||
def str_to_bool(value):
    """Convert a command-line string into a boolean.

    Matching is case-insensitive and ignores surrounding whitespace.
    Values that are already ``bool`` are returned unchanged, so the
    function is safe to apply to non-string inputs as well.

    Args:
        value: The text to interpret (e.g. ``"true"``, ``"No"``, ``"1"``).

    Returns:
        ``False`` for ``false``/``f``/``0``/``no``/``n`` and ``True`` for
        ``true``/``t``/``1``/``yes``/``y``.

    Raises:
        ValueError: If the value is not one of the recognised spellings.
    """

    # Pass real booleans straight through instead of failing on ``.lower()``.
    if isinstance(value, bool):
        return value

    # Normalise once so both membership checks see the same text.
    normalized = str(value).strip().lower()

    if normalized in {"false", "f", "0", "no", "n"}:
        return False
    if normalized in {"true", "t", "1", "yes", "y"}:
        return True

    raise ValueError(f"{value} is not a valid boolean value")
|
||
|
||
def init_argparser():
    """Build the shared CLI parser with every override group registered.

    Returns:
        An ``argparse.ArgumentParser`` carrying the network, model,
        logging and config arguments defined in this module.
    """

    epilog_text = (
        "These args are only for a subset of the config. "
        "Please edit config.yml for all options!"
    )
    cli_parser = argparse.ArgumentParser(epilog=epilog_text)

    # Keep the registration order: network, model, logging, then config.
    for register in (
        add_network_args,
        add_model_args,
        add_logging_args,
        add_config_args,
    ):
        register(cli_parser)

    return cli_parser
|
||
|
||
def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser):
    """Split parsed args into one dictionary per argument group.

    Args:
        args: The namespace produced by parsing the command line.
        parser: The parser whose argument groups define the split.

    Returns:
        A dict mapping each group title to a ``{dest: value}`` dict that
        holds only the arguments whose parsed value is not ``None``.
        Groups with no such arguments map to an empty dict.
    """

    def _supplied_values(group):
        # Pair every action's dest with its parsed value (None if absent).
        pairs = (
            (action.dest, getattr(args, action.dest, None))
            for action in group._group_actions
        )
        return {dest: parsed for dest, parsed in pairs if parsed is not None}

    # Relies on argparse's private group bookkeeping, same as before.
    return {group.title: _supplied_values(group) for group in parser._action_groups}
|
||
|
||
def add_config_args(parser: argparse.ArgumentParser):
    """Register the ``--config`` option directly on the given parser.

    Args:
        parser: The parser that receives the ``--config`` string argument.
    """

    config_option = {
        "type": str,
        "help": "Path to an overriding config.yml file",
    }
    parser.add_argument("--config", **config_option)
|
||
|
||
def add_network_args(parser: argparse.ArgumentParser):
    """Register the ``network`` argument group on the given parser.

    Adds ``--host`` (str), ``--port`` (int) and ``--disable-auth``
    (parsed with ``str_to_bool``).

    Args:
        parser: The parser that receives the ``network`` group.
    """

    network_group = parser.add_argument_group("network")
    network_group.add_argument("--host", type=str, help="The IP to host on")
    network_group.add_argument("--port", type=int, help="The port to host on")
    network_group.add_argument(
        "--disable-auth",
        type=str_to_bool,
        # Help text previously misspelled "authentication".
        help="Disable HTTP token authentication with requests",
    )
|
||
|
||
def add_model_args(parser: argparse.ArgumentParser):
    """Register the ``model`` argument group on the given parser.

    Args:
        parser: The parser that receives the ``model`` group.
    """

    # One (flag, add_argument keyword options) entry per model option,
    # listed in the order they are registered.
    # NOTE(review): the --gpu-split help says "integer array" but values are
    # parsed with type=float -- confirm which is intended.
    option_specs = (
        (
            "--model-dir",
            {"type": str, "help": "Overrides the directory to look for models"},
        ),
        ("--model-name", {"type": str, "help": "An initial model to load"}),
        (
            "--max-seq-len",
            {"type": int, "help": "Override the maximum model sequence length"},
        ),
        (
            "--override-base-seq-len",
            {"type": str_to_bool, "help": "Overrides base model context length"},
        ),
        (
            "--rope-scale",
            {"type": float, "help": "Sets rope_scale or compress_pos_emb"},
        ),
        ("--rope-alpha", {"type": float, "help": "Sets rope_alpha for NTK"}),
        (
            "--prompt-template",
            {"type": str, "help": "Set the prompt template for chat completions"},
        ),
        (
            "--gpu-split-auto",
            {"type": str_to_bool, "help": "Automatically allocate resources to GPUs"},
        ),
        (
            "--gpu-split",
            {
                "type": float,
                "nargs": "+",
                "help": (
                    "An integer array of GBs of vram to split between GPUs. "
                    "Ignored if gpu_split_auto is true"
                ),
            },
        ),
        (
            "--num-experts-per-token",
            {
                "type": int,
                "help": "Number of experts to use per token in MoE models",
            },
        ),
    )

    model_group = parser.add_argument_group("model")
    for flag, options in option_specs:
        model_group.add_argument(flag, **options)
|
||
|
||
def add_logging_args(parser: argparse.ArgumentParser):
    """Register the ``logging`` argument group on the given parser.

    Both options are parsed with ``str_to_bool``.

    Args:
        parser: The parser that receives the ``logging`` group.
    """

    group = parser.add_argument_group("logging")

    # Flag -> help text; every logging option shares the same value parser.
    help_by_flag = {
        "--log-prompt": "Enable prompt logging",
        "--log-generation-params": "Enable generation parameter logging",
    }
    for flag, help_text in help_by_flag.items():
        group.add_argument(flag, type=str_to_bool, help=help_text)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters