diff --git a/augur/__init__.py b/augur/__init__.py index af4c23817..96cf036cc 100644 --- a/augur/__init__.py +++ b/augur/__init__.py @@ -14,7 +14,7 @@ from .debug import DEBUGGING from .errors import AugurError from .io.print import print_err -from .argparse_ import add_command_subparsers, add_default_command +from .argparse_ import register_commands, add_default_command DEFAULT_AUGUR_RECURSION_LIMIT = 10000 sys.setrecursionlimit(int(os.environ.get("AUGUR_RECURSION_LIMIT") or DEFAULT_AUGUR_RECURSION_LIMIT)) @@ -52,14 +52,15 @@ def make_parser(): parser = argparse.ArgumentParser( - prog = "augur", - description = "Augur: A bioinformatics toolkit for phylogenetic analysis.") + prog = "augur", + description = "Augur: A bioinformatics toolkit for phylogenetic analysis.", + formatter_class = argparse.ArgumentDefaultsHelpFormatter, + ) add_default_command(parser) add_version_alias(parser) - subparsers = parser.add_subparsers() - add_command_subparsers(subparsers, COMMANDS) + register_commands(parser, COMMANDS) return parser diff --git a/augur/argparse_.py b/augur/argparse_.py index 6084fdd5e..ac1861ccb 100644 --- a/augur/argparse_.py +++ b/augur/argparse_.py @@ -1,7 +1,8 @@ """ Custom helpers for the argparse standard library. """ -from argparse import Action, ArgumentDefaultsHelpFormatter, ArgumentParser, _ArgumentGroup +import argparse +from argparse import Action, ArgumentParser, _ArgumentGroup from typing import Union from .types import ValidationMode @@ -32,16 +33,14 @@ def run(args): parser.set_defaults(__command__ = default_command) -def add_command_subparsers(subparsers, commands, command_attribute='__command__'): +def register_commands(parser: argparse.ArgumentParser, commands, command_attribute='__command__'): """ Add subparsers for each command module. Parameters ---------- - subparsers: argparse._SubParsersAction - The special subparsers action object created by the parent parser - via `parser.add_subparsers()`. - + parser + ArgumentParser object. commands: list[types.ModuleType] A list of modules that are commands that require their own subparser. Each module is required to have a `register_parser` function to add its own @@ -51,6 +50,8 @@ def add_command_subparsers(subparsers, commands, command_attribute='__command__' Optional attribute name for the commands. The default is `__command__`, which allows top level augur to run commands directly via `args.__command__.run()`. """ + subparsers = parser.add_subparsers() + for command in commands: # Allow each command to register its own subparser subparser = command.register_parser(subparsers) @@ -59,9 +60,8 @@ def add_command_subparsers(subparsers, commands, command_attribute='__command__' if command_attribute: subparser.set_defaults(**{command_attribute: command}) - # Use the same formatting class for every command for consistency. - # Set here to avoid repeating it in every command's register_parser(). - subparser.formatter_class = ArgumentDefaultsHelpFormatter + # Ensure all subparsers format like the top-level parser + subparser.formatter_class = parser.formatter_class if not subparser.description and command.__doc__: subparser.description = command.__doc__ @@ -70,6 +70,10 @@ def add_command_subparsers(subparsers, commands, command_attribute='__command__' if not getattr(command, "run", None): add_default_command(subparser) + # Recursively register any subcommands + if getattr(subparser, "subcommands", None): + register_commands(subparser, subparser.subcommands) + class HideAsFalseAction(Action): """ diff --git a/augur/curate/__init__.py b/augur/curate/__init__.py index f2af6b7d9..6b92b58f2 100644 --- a/augur/curate/__init__.py +++ b/augur/curate/__init__.py @@ -1,18 +1,6 @@ """ A suite of commands to help with data curation. """ -import argparse -import sys -from collections import deque -from textwrap import dedent -from typing import Iterable, Set - -from augur.argparse_ import ExtendOverwriteDefault, add_command_subparsers -from augur.errors import AugurError -from augur.io.json import dump_ndjson, load_ndjson -from augur.io.metadata import DEFAULT_DELIMITERS, InvalidDelimiter, read_table_to_dict, read_metadata_with_sequences, write_records_to_tsv -from augur.io.sequences import write_records_to_fasta -from augur.types import DataErrorMethod from . import format_dates, normalize_strings, passthru, titlecase, apply_geolocation_rules, apply_record_annotations, abbreviate_authors, parse_genbank_location, transform_strain_name, rename @@ -31,220 +19,9 @@ ] -def create_shared_parser(): - """ - Creates an argparse.ArgumentParser that is intended to be used as a parent - parser¹ for all `augur curate` subcommands. This should include all options - that are intended to be shared across the subcommands. - - Note that any options strings used here cannot be used in individual subcommand - subparsers unless the subparser specifically sets `conflict_handler='resolve'` ², - then the subparser option will override the option defined here. - - Based on https://stackoverflow.com/questions/23296695/permit-argparse-global-flags-after-subcommand/23296874#23296874 - - ¹ https://docs.python.org/3/library/argparse.html#parents - ² https://docs.python.org/3/library/argparse.html#conflict-handler - """ - shared_parser = argparse.ArgumentParser(add_help=False) - - shared_inputs = shared_parser.add_argument_group( - title="INPUTS", - description=""" - Input options shared by all `augur curate` commands. - If no input options are provided, commands will try to read NDJSON records from stdin. - """) - shared_inputs.add_argument("--metadata", - help="Input metadata file. May be plain text (TSV, CSV) or an Excel or OpenOffice spreadsheet workbook file. When an Excel or OpenOffice workbook, only the first visible worksheet will be read and initial empty rows/columns will be ignored. Accepts '-' to read plain text from stdin.") - shared_inputs.add_argument("--id-column", - help="Name of the metadata column that contains the record identifier for reporting duplicate records. " - "Uses the first column of the metadata file if not provided. " - "Ignored if also providing a FASTA file input.") - shared_inputs.add_argument("--metadata-delimiters", default=DEFAULT_DELIMITERS, nargs="+", action=ExtendOverwriteDefault, - help="Delimiters to accept when reading a plain text metadata file. Only one delimiter will be inferred.") - - shared_inputs.add_argument("--fasta", - help="Plain or gzipped FASTA file. Headers can only contain the sequence id used to match a metadata record. " + - "Note that an index file will be generated for the FASTA file as .fasta.fxi") - shared_inputs.add_argument("--seq-id-column", - help="Name of metadata column that contains the sequence id to match sequences in the FASTA file.") - shared_inputs.add_argument("--seq-field", - help="The name to use for the sequence field when joining sequences from a FASTA file.") - - shared_inputs.add_argument("--unmatched-reporting", - type=DataErrorMethod.argtype, - choices=list(DataErrorMethod), - default=DataErrorMethod.ERROR_FIRST, - help="How unmatched records from combined metadata/FASTA input should be reported.") - shared_inputs.add_argument("--duplicate-reporting", - type=DataErrorMethod.argtype, - choices=list(DataErrorMethod), - default=DataErrorMethod.ERROR_FIRST, - help="How should duplicate records be reported.") - - shared_outputs = shared_parser.add_argument_group( - title="OUTPUTS", - description=""" - Output options shared by all `augur curate` commands. - If no output options are provided, commands will output NDJSON records to stdout. - """) - shared_outputs.add_argument("--output-metadata", - help="Output metadata TSV file. Accepts '-' to output TSV to stdout.") - - shared_outputs.add_argument("--output-fasta", - help="Output FASTA file.") - shared_outputs.add_argument("--output-id-field", - help="The record field to use as the sequence identifier in the FASTA output.") - shared_outputs.add_argument("--output-seq-field", - help="The record field that contains the sequence for the FASTA output. " - "This field will be deleted from the metadata output.") - - return shared_parser - - def register_parser(parent_subparsers): - shared_parser = create_shared_parser() parser = parent_subparsers.add_parser("curate", help=__doc__) - # Add print_help so we can run it when no subcommands are called - parser.set_defaults(print_help = parser.print_help) - - # Add subparsers for subcommands - subparsers = parser.add_subparsers(dest="subcommand", required=False) - # Add the shared_parser to make it available for subcommands - # to include in their own parser - subparsers.shared_parser = shared_parser - # Using a subcommand attribute so subcommands are not directly - # run by top level Augur. Process I/O in `curate`` so individual - # subcommands do not have to worry about it. - add_command_subparsers(subparsers, SUBCOMMANDS, SUBCOMMAND_ATTRIBUTE) + parser.subcommands = SUBCOMMANDS return parser - - -def validate_records(records: Iterable[dict], subcmd_name: str, is_input: bool) -> Iterable[dict]: - """ - Validate that the provided *records* all have the same fields. - Uses the keys of the first record to check against all other records. - - Parameters - ---------- - records: iterable of dict - - subcmd_name: str - The name of the subcommand whose output is being validated; used in - error messages displayed to the user. - - is_input: bool - Whether the provided records come directly from user provided input - """ - error_message = "Records do not have the same fields! " - if is_input: - error_message += "Please check your input data has the same fields." - else: - # Hopefully users should not run into this error as it means we are - # not uniformly adding/removing fields from records - error_message += dedent(f"""\ - Something unexpected happened during the augur curate {subcmd_name} command. - To report this, please open a new issue including the original command: - - """) - - first_record_keys: Set[str] = set() - for idx, record in enumerate(records): - if idx == 0: - first_record_keys.update(record.keys()) - else: - if set(record.keys()) != first_record_keys: - raise AugurError(error_message) - yield record - - -def run(args): - # Print help if no subcommands are used - if not getattr(args, SUBCOMMAND_ATTRIBUTE, None): - return args.print_help() - - # Check provided args are valid and required combination of args are provided - if not args.fasta and (args.seq_id_column or args.seq_field): - raise AugurError("The --seq-id-column and --seq-field options should only be used when providing a FASTA file.") - - if args.fasta and (not args.seq_id_column or not args.seq_field): - raise AugurError("The --seq-id-column and --seq-field options are required for a FASTA file input.") - - if not args.output_fasta and (args.output_id_field or args.output_seq_field): - raise AugurError("The --output-id-field and --output-seq-field options should only be used when requesting a FASTA output.") - - if args.output_fasta and (not args.output_id_field or not args.output_seq_field): - raise AugurError("The --output-id-field and --output-seq-field options are required for a FASTA output.") - - # Read inputs - # Special case single hyphen as stdin - if args.metadata == '-': - args.metadata = sys.stdin.buffer - - if args.metadata and args.fasta: - try: - records = read_metadata_with_sequences( - args.metadata, - args.metadata_delimiters, - args.fasta, - args.seq_id_column, - args.seq_field, - args.unmatched_reporting, - args.duplicate_reporting) - except InvalidDelimiter: - raise AugurError( - f"Could not determine the delimiter of {args.metadata!r}. " - f"Valid delimiters are: {args.metadata_delimiters!r}. " - "This can be changed with --metadata-delimiters." - ) - elif args.metadata: - try: - records = read_table_to_dict(args.metadata, args.metadata_delimiters, args.duplicate_reporting, args.id_column) - except InvalidDelimiter: - raise AugurError( - f"Could not determine the delimiter of {args.metadata!r}. " - f"Valid delimiters are: {args.metadata_delimiters!r}. " - "This can be changed with --metadata-delimiters." - ) - elif not sys.stdin.isatty(): - records = load_ndjson(sys.stdin) - else: - raise AugurError(dedent("""\ - No valid inputs were provided. - NDJSON records can be streamed from stdin or - input files can be provided via the command line options `--metadata` and `--fasta`. - See the command's help message for more details.""")) - - # Get the name of the subcmd being run - subcmd_name = args.subcommand - - # Validate records have the same input fields - validated_input_records = validate_records(records, subcmd_name, True) - - # Run subcommand to get modified records - modified_records = getattr(args, SUBCOMMAND_ATTRIBUTE).run(args, validated_input_records) - - # Validate modified records have the same output fields - validated_output_records = validate_records(modified_records, subcmd_name, False) - - # Output modified records - # First output FASTA, since the write fasta function yields the records again - # and removes the sequences from the records - if args.output_fasta: - validated_output_records = write_records_to_fasta( - validated_output_records, - args.output_fasta, - args.output_id_field, - args.output_seq_field) - - if args.output_metadata: - write_records_to_tsv(validated_output_records, args.output_metadata) - - if not (args.output_fasta or args.output_metadata): - dump_ndjson(validated_output_records) - else: - # Exhaust generator to ensure we run through all records - # when only a FASTA output is requested but not a metadata output - deque(validated_output_records, maxlen=0) diff --git a/augur/curate/_shared.py b/augur/curate/_shared.py new file mode 100644 index 000000000..2cde7afb7 --- /dev/null +++ b/augur/curate/_shared.py @@ -0,0 +1,215 @@ +import argparse +import sys +from collections import deque +from textwrap import dedent +from typing import Callable, Iterable, Set + +from augur.argparse_ import ExtendOverwriteDefault +from augur.errors import AugurError +from augur.io.json import dump_ndjson, load_ndjson +from augur.io.metadata import DEFAULT_DELIMITERS, InvalidDelimiter, read_table_to_dict, read_metadata_with_sequences, write_records_to_tsv +from augur.io.sequences import write_records_to_fasta +from augur.types import DataErrorMethod + + +def create_shared_parser(): + """ + Creates an argparse.ArgumentParser that is intended to be used as a parent + parser¹ for all `augur curate` subcommands. This should include all options + that are intended to be shared across the subcommands. + + Note that any options strings used here cannot be used in individual subcommand + subparsers unless the subparser specifically sets `conflict_handler='resolve'` ², + then the subparser option will override the option defined here. + + Based on https://stackoverflow.com/questions/23296695/permit-argparse-global-flags-after-subcommand/23296874#23296874 + + ¹ https://docs.python.org/3/library/argparse.html#parents + ² https://docs.python.org/3/library/argparse.html#conflict-handler + """ + shared_parser = argparse.ArgumentParser(add_help=False) + + shared_inputs = shared_parser.add_argument_group( + title="INPUTS", + description=""" + Input options shared by all `augur curate` commands. + If no input options are provided, commands will try to read NDJSON records from stdin. + """) + shared_inputs.add_argument("--metadata", + help="Input metadata file. May be plain text (TSV, CSV) or an Excel or OpenOffice spreadsheet workbook file. When an Excel or OpenOffice workbook, only the first visible worksheet will be read and initial empty rows/columns will be ignored. Accepts '-' to read plain text from stdin.") + shared_inputs.add_argument("--id-column", + help="Name of the metadata column that contains the record identifier for reporting duplicate records. " + "Uses the first column of the metadata file if not provided. " + "Ignored if also providing a FASTA file input.") + shared_inputs.add_argument("--metadata-delimiters", default=DEFAULT_DELIMITERS, nargs="+", action=ExtendOverwriteDefault, + help="Delimiters to accept when reading a plain text metadata file. Only one delimiter will be inferred.") + + shared_inputs.add_argument("--fasta", + help="Plain or gzipped FASTA file. Headers can only contain the sequence id used to match a metadata record. " + + "Note that an index file will be generated for the FASTA file as .fasta.fxi") + shared_inputs.add_argument("--seq-id-column", + help="Name of metadata column that contains the sequence id to match sequences in the FASTA file.") + shared_inputs.add_argument("--seq-field", + help="The name to use for the sequence field when joining sequences from a FASTA file.") + + shared_inputs.add_argument("--unmatched-reporting", + type=DataErrorMethod.argtype, + choices=list(DataErrorMethod), + default=DataErrorMethod.ERROR_FIRST, + help="How unmatched records from combined metadata/FASTA input should be reported.") + shared_inputs.add_argument("--duplicate-reporting", + type=DataErrorMethod.argtype, + choices=list(DataErrorMethod), + default=DataErrorMethod.ERROR_FIRST, + help="How should duplicate records be reported.") + + shared_outputs = shared_parser.add_argument_group( + title="OUTPUTS", + description=""" + Output options shared by all `augur curate` commands. + If no output options are provided, commands will output NDJSON records to stdout. + """) + shared_outputs.add_argument("--output-metadata", + help="Output metadata TSV file. Accepts '-' to output TSV to stdout.") + + shared_outputs.add_argument("--output-fasta", + help="Output FASTA file.") + shared_outputs.add_argument("--output-id-field", + help="The record field to use as the sequence identifier in the FASTA output.") + shared_outputs.add_argument("--output-seq-field", + help="The record field that contains the sequence for the FASTA output. " + "This field will be deleted from the metadata output.") + + return shared_parser + + +shared_parser = create_shared_parser() + + +def validate_records(records: Iterable[dict], subcmd_name: str, is_input: bool) -> Iterable[dict]: + """ + Validate that the provided *records* all have the same fields. + Uses the keys of the first record to check against all other records. + + Parameters + ---------- + records: iterable of dict + + subcmd_name: str + The name of the subcommand whose output is being validated; used in + error messages displayed to the user. + + is_input: bool + Whether the provided records come directly from user provided input + """ + error_message = "Records do not have the same fields! " + if is_input: + error_message += "Please check your input data has the same fields." + else: + # Hopefully users should not run into this error as it means we are + # not uniformly adding/removing fields from records + error_message += dedent(f"""\ + Something unexpected happened during the augur curate {subcmd_name} command. + To report this, please open a new issue including the original command: + + """) + + first_record_keys: Set[str] = set() + for idx, record in enumerate(records): + if idx == 0: + first_record_keys.update(record.keys()) + else: + if set(record.keys()) != first_record_keys: + raise AugurError(error_message) + yield record + + +def validate(command_name: str): + + def decorator_with_context(command_run: Callable[[argparse.Namespace, Iterable[dict]], Iterable[dict]]): + + def run(args: argparse.Namespace) -> None: + # Check provided args are valid and required combination of args are provided + if not args.fasta and (args.seq_id_column or args.seq_field): + raise AugurError("The --seq-id-column and --seq-field options should only be used when providing a FASTA file.") + + if args.fasta and (not args.seq_id_column or not args.seq_field): + raise AugurError("The --seq-id-column and --seq-field options are required for a FASTA file input.") + + if not args.output_fasta and (args.output_id_field or args.output_seq_field): + raise AugurError("The --output-id-field and --output-seq-field options should only be used when requesting a FASTA output.") + + if args.output_fasta and (not args.output_id_field or not args.output_seq_field): + raise AugurError("The --output-id-field and --output-seq-field options are required for a FASTA output.") + + # Read inputs + # Special case single hyphen as stdin + if args.metadata == '-': + args.metadata = sys.stdin.buffer + + if args.metadata and args.fasta: + try: + records = read_metadata_with_sequences( + args.metadata, + args.metadata_delimiters, + args.fasta, + args.seq_id_column, + args.seq_field, + args.unmatched_reporting, + args.duplicate_reporting) + except InvalidDelimiter: + raise AugurError( + f"Could not determine the delimiter of {args.metadata!r}. " + f"Valid delimiters are: {args.metadata_delimiters!r}. " + "This can be changed with --metadata-delimiters." + ) + elif args.metadata: + try: + records = read_table_to_dict(args.metadata, args.metadata_delimiters, args.duplicate_reporting, args.id_column) + except InvalidDelimiter: + raise AugurError( + f"Could not determine the delimiter of {args.metadata!r}. " + f"Valid delimiters are: {args.metadata_delimiters!r}. " + "This can be changed with --metadata-delimiters." + ) + elif not sys.stdin.isatty(): + records = load_ndjson(sys.stdin) + else: + raise AugurError(dedent("""\ + No valid inputs were provided. + NDJSON records can be streamed from stdin or + input files can be provided via the command line options `--metadata` and `--fasta`. + See the command's help message for more details.""")) + + # Validate records have the same input fields + validated_input_records = validate_records(records, command_name, True) + + # Run subcommand to get modified records + modified_records = command_run(args, validated_input_records) + + # Validate modified records have the same output fields + validated_output_records = validate_records(modified_records, command_name, False) + + # Output modified records + # First output FASTA, since the write fasta function yields the records again + # and removes the sequences from the records + if args.output_fasta: + validated_output_records = write_records_to_fasta( + validated_output_records, + args.output_fasta, + args.output_id_field, + args.output_seq_field) + + if args.output_metadata: + write_records_to_tsv(validated_output_records, args.output_metadata) + + if not (args.output_fasta or args.output_metadata): + dump_ndjson(validated_output_records) + else: + # Exhaust generator to ensure we run through all records + # when only a FASTA output is requested but not a metadata output + deque(validated_output_records, maxlen=0) + + return run + + return decorator_with_context diff --git a/augur/curate/abbreviate_authors.py b/augur/curate/abbreviate_authors.py index f95b8b723..399247b42 100644 --- a/augur/curate/abbreviate_authors.py +++ b/augur/curate/abbreviate_authors.py @@ -10,6 +10,10 @@ from typing import Generator, List from augur.io.print import print_err from augur.utils import first_line +from ._shared import shared_parser, validate + + +COMMAND_NAME = "abbreviate-authors" def parse_authors( @@ -51,8 +55,8 @@ def register_parser( parent_subparsers: argparse._SubParsersAction, ) -> argparse._SubParsersAction: parser = parent_subparsers.add_parser( - "abbreviate-authors", - parents=[parent_subparsers.shared_parser], # type: ignore + COMMAND_NAME, + parents=[shared_parser], # type: ignore help=first_line(__doc__), ) @@ -75,6 +79,7 @@ def register_parser( return parser +@validate(COMMAND_NAME) def run(args: argparse.Namespace, records: List[dict]) -> Generator[dict, None, None]: for index, record in enumerate(records): parse_authors( diff --git a/augur/curate/apply_geolocation_rules.py b/augur/curate/apply_geolocation_rules.py index e14974890..ecf7ab7ef 100644 --- a/augur/curate/apply_geolocation_rules.py +++ b/augur/curate/apply_geolocation_rules.py @@ -5,6 +5,10 @@ from augur.errors import AugurError from augur.io.print import print_err from augur.utils import first_line +from ._shared import shared_parser, validate + + +COMMAND_NAME = "apply-geolocation-rules" class CyclicGeolocationRulesError(AugurError): @@ -187,8 +191,8 @@ def transform_geolocations(geolocation_rules, geolocation): def register_parser(parent_subparsers): - parser = parent_subparsers.add_parser("apply-geolocation-rules", - parents=[parent_subparsers.shared_parser], + parser = parent_subparsers.add_parser(COMMAND_NAME, + parents=[shared_parser], help=first_line(__doc__)) parser.add_argument("--region-field", default="region", @@ -210,6 +214,7 @@ def register_parser(parent_subparsers): return parser +@validate(COMMAND_NAME) def run(args, records): location_fields = [args.region_field, args.country_field, args.division_field, args.location_field] diff --git a/augur/curate/apply_record_annotations.py b/augur/curate/apply_record_annotations.py index a650226d3..c725ee39f 100644 --- a/augur/curate/apply_record_annotations.py +++ b/augur/curate/apply_record_annotations.py @@ -7,11 +7,15 @@ from augur.errors import AugurError from augur.io.print import print_err from augur.utils import first_line +from ._shared import shared_parser, validate + + +COMMAND_NAME = "apply-record-annotations" def register_parser(parent_subparsers): - parser = parent_subparsers.add_parser("apply-record-annotations", - parents=[parent_subparsers.shared_parser], + parser = parent_subparsers.add_parser(COMMAND_NAME, + parents=[shared_parser], help=first_line(__doc__)) parser.add_argument("--annotations", metavar="TSV", required=True, @@ -27,6 +31,7 @@ def register_parser(parent_subparsers): return parser +@validate(COMMAND_NAME) def run(args, records): annotations = defaultdict(dict) with open(args.annotations, 'r', newline='') as annotations_fh: diff --git a/augur/curate/format_dates.py b/augur/curate/format_dates.py index 063096de4..c9e2fe8fe 100644 --- a/augur/curate/format_dates.py +++ b/augur/curate/format_dates.py @@ -10,6 +10,10 @@ from augur.io.print import print_err from augur.types import DataErrorMethod from .format_dates_directives import YEAR_DIRECTIVES, YEAR_MONTH_DIRECTIVES, YEAR_MONTH_DAY_DIRECTIVES +from ._shared import shared_parser, validate + + +COMMAND_NAME = "format-dates" # Default date formats that this command should parse @@ -23,8 +27,8 @@ def register_parser(parent_subparsers): - parser = parent_subparsers.add_parser("format-dates", - parents=[parent_subparsers.shared_parser], + parser = parent_subparsers.add_parser(COMMAND_NAME, + parents=[shared_parser], help=__doc__) required = parser.add_argument_group(title="REQUIRED") @@ -178,6 +182,7 @@ def format_date(date_string, expected_formats): return None +@validate(COMMAND_NAME) def run(args, records): failures = [] failure_reporting = args.failure_reporting diff --git a/augur/curate/normalize_strings.py b/augur/curate/normalize_strings.py index ac98f4a6f..b45d87332 100644 --- a/augur/curate/normalize_strings.py +++ b/augur/curate/normalize_strings.py @@ -7,11 +7,15 @@ import unicodedata from augur.utils import first_line +from ._shared import shared_parser, validate + + +COMMAND_NAME = "normalize-strings" def register_parser(parent_subparsers): - parser = parent_subparsers.add_parser("normalize-strings", - parents=[parent_subparsers.shared_parser], + parser = parent_subparsers.add_parser(COMMAND_NAME, + parents=[shared_parser], help=first_line(__doc__)) optional = parser.add_argument_group(title="OPTIONAL") @@ -44,6 +48,7 @@ def normalize_strings(record, form='NFC'): } +@validate(COMMAND_NAME) def run(args, records): for record in records: yield normalize_strings(record, args.form) diff --git a/augur/curate/parse_genbank_location.py b/augur/curate/parse_genbank_location.py index f1123d595..e28e7028a 100644 --- a/augur/curate/parse_genbank_location.py +++ b/augur/curate/parse_genbank_location.py @@ -10,6 +10,10 @@ from typing import Generator, List from augur.io.print import print_err from augur.utils import first_line +from ._shared import shared_parser, validate + + +COMMAND_NAME = "parse-genbank-location" def parse_location( @@ -49,8 +53,8 @@ def register_parser( parent_subparsers: argparse._SubParsersAction, ) -> argparse._SubParsersAction: parser = parent_subparsers.add_parser( - "parse-genbank-location", - parents=[parent_subparsers.shared_parser], # type: ignore + COMMAND_NAME, + parents=[shared_parser], # type: ignore help=first_line(__doc__), ) @@ -64,6 +68,7 @@ def register_parser( return parser +@validate(COMMAND_NAME) def run( args: argparse.Namespace, records: List[dict], diff --git a/augur/curate/passthru.py b/augur/curate/passthru.py index f31f2dc45..a40e58c78 100644 --- a/augur/curate/passthru.py +++ b/augur/curate/passthru.py @@ -2,14 +2,19 @@ Pass through records without doing any data transformations. Useful for testing, troubleshooting, or just converting file formats. """ +from ._shared import shared_parser, validate + + +COMMAND_NAME = "passthru" def register_parser(parent_subparsers): - return parent_subparsers.add_parser("passthru", - parents=[parent_subparsers.shared_parser], + return parent_subparsers.add_parser(COMMAND_NAME, + parents=[shared_parser], help=__doc__) +@validate(COMMAND_NAME) def run(args, records): yield from records diff --git a/augur/curate/rename.py b/augur/curate/rename.py index 68730fa28..28882827d 100644 --- a/augur/curate/rename.py +++ b/augur/curate/rename.py @@ -6,10 +6,15 @@ import argparse from augur.io.print import print_err from augur.errors import AugurError +from ._shared import shared_parser, validate + + +COMMAND_NAME = "rename" + def register_parser(parent_subparsers): - parser = parent_subparsers.add_parser("rename", - parents = [parent_subparsers.shared_parser], + parser = parent_subparsers.add_parser(COMMAND_NAME, + parents = [shared_parser], help = __doc__) required = parser.add_argument_group(title="REQUIRED") @@ -86,6 +91,7 @@ def transform_columns(existing_fields: List[str], field_map: List[Tuple[str,str] return m +@validate(COMMAND_NAME) def run(args: argparse.Namespace, records: Iterable[dict]) -> Iterable[dict]: col_map: Union[Literal[False], List[Tuple[str,str]]] = False for record in records: diff --git a/augur/curate/titlecase.py b/augur/curate/titlecase.py index e0d6017b8..f2869dc3e 100644 --- a/augur/curate/titlecase.py +++ b/augur/curate/titlecase.py @@ -7,10 +7,15 @@ from augur.errors import AugurError from augur.io.print import print_err from augur.types import DataErrorMethod +from ._shared import shared_parser, validate + + +COMMAND_NAME = "titlecase" + def register_parser(parent_subparsers): - parser = parent_subparsers.add_parser("titlecase", - parents = [parent_subparsers.shared_parser], + parser = parent_subparsers.add_parser(COMMAND_NAME, + parents = [shared_parser], help = __doc__) required = parser.add_argument_group(title="REQUIRED") @@ -73,6 +78,7 @@ def changecase(index, word): return ''.join(changecase(i, w) for i, w in words) +@validate(COMMAND_NAME) def run(args, records): failures = [] failure_reporting = args.failure_reporting diff --git a/augur/curate/transform_strain_name.py b/augur/curate/transform_strain_name.py index ae128893d..8e4e4ab3a 100644 --- a/augur/curate/transform_strain_name.py +++ b/augur/curate/transform_strain_name.py @@ -9,6 +9,10 @@ from typing import Generator, List from augur.io.print import print_err from augur.utils import first_line +from ._shared import shared_parser, validate + + +COMMAND_NAME = "transform-strain-name" def transform_name( @@ -41,8 +45,8 @@ def register_parser( parent_subparsers: argparse._SubParsersAction, ) -> argparse._SubParsersAction: parser = parent_subparsers.add_parser( - "transform-strain-name", - parents=[parent_subparsers.shared_parser], # type: ignore + COMMAND_NAME, + parents=[shared_parser], # type: ignore help=first_line(__doc__), ) @@ -64,6 +68,7 @@ def register_parser( return parser +@validate(COMMAND_NAME) def run(args: argparse.Namespace, records: List[dict]) -> Generator[dict, None, None]: strain_name_pattern = re.compile(args.strain_regex) diff --git a/augur/export.py b/augur/export.py index 38c025b80..050c94475 100644 --- a/augur/export.py +++ b/augur/export.py @@ -1,7 +1,8 @@ """ Export JSON files suitable for visualization with auspice. + +Augur export now needs you to define the JSON version you want, e.g. `augur export v2`. """ -from .argparse_ import add_command_subparsers from . import export_v1, export_v2 SUBCOMMANDS = [ @@ -12,10 +13,7 @@ def register_parser(parent_subparsers): parser = parent_subparsers.add_parser("export", help=__doc__) - # Add subparsers for subcommands - metavar_msg ="Augur export now needs you to define the JSON version " + \ - "you want, e.g. `augur export v2`." - subparsers = parser.add_subparsers(title="JSON SCHEMA", - metavar=metavar_msg) - add_command_subparsers(subparsers, SUBCOMMANDS) + + parser.subcommands = SUBCOMMANDS + return parser diff --git a/augur/import_/__init__.py b/augur/import_/__init__.py index d2f925a29..ef9a9ad0c 100644 --- a/augur/import_/__init__.py +++ b/augur/import_/__init__.py @@ -1,7 +1,6 @@ """ Import analyses into augur pipeline from other systems """ -from augur.argparse_ import add_command_subparsers from augur.utils import first_line from . import beast @@ -11,8 +10,7 @@ def register_parser(parent_subparsers): parser = parent_subparsers.add_parser("import", help=first_line(__doc__)) - metavar_msg = "Import analyses into augur pipeline from other systems" - subparsers = parser.add_subparsers(title="TYPE", - metavar=metavar_msg) - add_command_subparsers(subparsers, SUBCOMMANDS) + + parser.subcommands = SUBCOMMANDS + return parser diff --git a/augur/measurements/__init__.py b/augur/measurements/__init__.py index 3fbc93273..54a4638bc 100644 --- a/augur/measurements/__init__.py +++ b/augur/measurements/__init__.py @@ -1,7 +1,6 @@ """ Create JSON files suitable for visualization within the measurements panel of Auspice. """ -from augur.argparse_ import add_command_subparsers from augur.utils import first_line from . import export, concat @@ -13,7 +12,7 @@ def register_parser(parent_subparsers): parser = parent_subparsers.add_parser("measurements", help=first_line(__doc__)) - # Add subparsers for subcommands - subparsers = parser.add_subparsers(dest='subcommand') - add_command_subparsers(subparsers, SUBCOMMANDS) + + parser.subcommands = SUBCOMMANDS + return parser diff --git a/tests/io/test_curate_validate_records.py b/tests/io/test_curate_validate_records.py index 72f5083f4..42d4f1e15 100644 --- a/tests/io/test_curate_validate_records.py +++ b/tests/io/test_curate_validate_records.py @@ -1,5 +1,5 @@ import pytest -from augur.curate import validate_records +from augur.curate._shared import validate_records from augur.errors import AugurError