From ecf0be308941cc62fa0b411893ad2076ff73b6c7 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Wed, 25 Sep 2024 16:04:20 -0700 Subject: [PATCH 1/8] curate: Set shared parser separately This is necessary to register subparsers recursively. --- augur/curate/__init__.py | 81 +----------------------- augur/curate/_shared.py | 78 +++++++++++++++++++++++ augur/curate/abbreviate_authors.py | 3 +- augur/curate/apply_geolocation_rules.py | 3 +- augur/curate/apply_record_annotations.py | 3 +- augur/curate/format_dates.py | 3 +- augur/curate/normalize_strings.py | 3 +- augur/curate/parse_genbank_location.py | 3 +- augur/curate/passthru.py | 3 +- augur/curate/rename.py | 3 +- augur/curate/titlecase.py | 3 +- augur/curate/transform_strain_name.py | 3 +- 12 files changed, 100 insertions(+), 89 deletions(-) create mode 100644 augur/curate/_shared.py diff --git a/augur/curate/__init__.py b/augur/curate/__init__.py index f2af6b7d9..5e8888127 100644 --- a/augur/curate/__init__.py +++ b/augur/curate/__init__.py @@ -1,18 +1,16 @@ """ A suite of commands to help with data curation. """ -import argparse import sys from collections import deque from textwrap import dedent from typing import Iterable, Set -from augur.argparse_ import ExtendOverwriteDefault, add_command_subparsers +from augur.argparse_ import add_command_subparsers from augur.errors import AugurError from augur.io.json import dump_ndjson, load_ndjson -from augur.io.metadata import DEFAULT_DELIMITERS, InvalidDelimiter, read_table_to_dict, read_metadata_with_sequences, write_records_to_tsv +from augur.io.metadata import InvalidDelimiter, read_table_to_dict, read_metadata_with_sequences, write_records_to_tsv from augur.io.sequences import write_records_to_fasta -from augur.types import DataErrorMethod from . import format_dates, normalize_strings, passthru, titlecase, apply_geolocation_rules, apply_record_annotations, abbreviate_authors, parse_genbank_location, transform_strain_name, rename @@ -31,79 +29,7 @@ ] -def create_shared_parser(): - """ - Creates an argparse.ArgumentParser that is intended to be used as a parent - parser¹ for all `augur curate` subcommands. This should include all options - that are intended to be shared across the subcommands. - - Note that any options strings used here cannot be used in individual subcommand - subparsers unless the subparser specifically sets `conflict_handler='resolve'` ², - then the subparser option will override the option defined here. - - Based on https://stackoverflow.com/questions/23296695/permit-argparse-global-flags-after-subcommand/23296874#23296874 - - ¹ https://docs.python.org/3/library/argparse.html#parents - ² https://docs.python.org/3/library/argparse.html#conflict-handler - """ - shared_parser = argparse.ArgumentParser(add_help=False) - - shared_inputs = shared_parser.add_argument_group( - title="INPUTS", - description=""" - Input options shared by all `augur curate` commands. - If no input options are provided, commands will try to read NDJSON records from stdin. - """) - shared_inputs.add_argument("--metadata", - help="Input metadata file. May be plain text (TSV, CSV) or an Excel or OpenOffice spreadsheet workbook file. When an Excel or OpenOffice workbook, only the first visible worksheet will be read and initial empty rows/columns will be ignored. Accepts '-' to read plain text from stdin.") - shared_inputs.add_argument("--id-column", - help="Name of the metadata column that contains the record identifier for reporting duplicate records. " - "Uses the first column of the metadata file if not provided. " - "Ignored if also providing a FASTA file input.") - shared_inputs.add_argument("--metadata-delimiters", default=DEFAULT_DELIMITERS, nargs="+", action=ExtendOverwriteDefault, - help="Delimiters to accept when reading a plain text metadata file. Only one delimiter will be inferred.") - - shared_inputs.add_argument("--fasta", - help="Plain or gzipped FASTA file. Headers can only contain the sequence id used to match a metadata record. " + - "Note that an index file will be generated for the FASTA file as .fasta.fxi") - shared_inputs.add_argument("--seq-id-column", - help="Name of metadata column that contains the sequence id to match sequences in the FASTA file.") - shared_inputs.add_argument("--seq-field", - help="The name to use for the sequence field when joining sequences from a FASTA file.") - - shared_inputs.add_argument("--unmatched-reporting", - type=DataErrorMethod.argtype, - choices=list(DataErrorMethod), - default=DataErrorMethod.ERROR_FIRST, - help="How unmatched records from combined metadata/FASTA input should be reported.") - shared_inputs.add_argument("--duplicate-reporting", - type=DataErrorMethod.argtype, - choices=list(DataErrorMethod), - default=DataErrorMethod.ERROR_FIRST, - help="How should duplicate records be reported.") - - shared_outputs = shared_parser.add_argument_group( - title="OUTPUTS", - description=""" - Output options shared by all `augur curate` commands. - If no output options are provided, commands will output NDJSON records to stdout. - """) - shared_outputs.add_argument("--output-metadata", - help="Output metadata TSV file. Accepts '-' to output TSV to stdout.") - - shared_outputs.add_argument("--output-fasta", - help="Output FASTA file.") - shared_outputs.add_argument("--output-id-field", - help="The record field to use as the sequence identifier in the FASTA output.") - shared_outputs.add_argument("--output-seq-field", - help="The record field that contains the sequence for the FASTA output. " - "This field will be deleted from the metadata output.") - - return shared_parser - - def register_parser(parent_subparsers): - shared_parser = create_shared_parser() parser = parent_subparsers.add_parser("curate", help=__doc__) # Add print_help so we can run it when no subcommands are called @@ -111,9 +37,6 @@ def register_parser(parent_subparsers): # Add subparsers for subcommands subparsers = parser.add_subparsers(dest="subcommand", required=False) - # Add the shared_parser to make it available for subcommands - # to include in their own parser - subparsers.shared_parser = shared_parser # Using a subcommand attribute so subcommands are not directly # run by top level Augur. Process I/O in `curate`` so individual # subcommands do not have to worry about it. diff --git a/augur/curate/_shared.py b/augur/curate/_shared.py new file mode 100644 index 000000000..590033cbe --- /dev/null +++ b/augur/curate/_shared.py @@ -0,0 +1,78 @@ +import argparse +from augur.argparse_ import ExtendOverwriteDefault +from augur.io.metadata import DEFAULT_DELIMITERS +from augur.types import DataErrorMethod + + +def create_shared_parser(): + """ + Creates an argparse.ArgumentParser that is intended to be used as a parent + parser¹ for all `augur curate` subcommands. This should include all options + that are intended to be shared across the subcommands. + + Note that any options strings used here cannot be used in individual subcommand + subparsers unless the subparser specifically sets `conflict_handler='resolve'` ², + then the subparser option will override the option defined here. + + Based on https://stackoverflow.com/questions/23296695/permit-argparse-global-flags-after-subcommand/23296874#23296874 + + ¹ https://docs.python.org/3/library/argparse.html#parents + ² https://docs.python.org/3/library/argparse.html#conflict-handler + """ + shared_parser = argparse.ArgumentParser(add_help=False) + + shared_inputs = shared_parser.add_argument_group( + title="INPUTS", + description=""" + Input options shared by all `augur curate` commands. + If no input options are provided, commands will try to read NDJSON records from stdin. + """) + shared_inputs.add_argument("--metadata", + help="Input metadata file. May be plain text (TSV, CSV) or an Excel or OpenOffice spreadsheet workbook file. When an Excel or OpenOffice workbook, only the first visible worksheet will be read and initial empty rows/columns will be ignored. Accepts '-' to read plain text from stdin.") + shared_inputs.add_argument("--id-column", + help="Name of the metadata column that contains the record identifier for reporting duplicate records. " + "Uses the first column of the metadata file if not provided. " + "Ignored if also providing a FASTA file input.") + shared_inputs.add_argument("--metadata-delimiters", default=DEFAULT_DELIMITERS, nargs="+", action=ExtendOverwriteDefault, + help="Delimiters to accept when reading a plain text metadata file. Only one delimiter will be inferred.") + + shared_inputs.add_argument("--fasta", + help="Plain or gzipped FASTA file. Headers can only contain the sequence id used to match a metadata record. " + + "Note that an index file will be generated for the FASTA file as .fasta.fxi") + shared_inputs.add_argument("--seq-id-column", + help="Name of metadata column that contains the sequence id to match sequences in the FASTA file.") + shared_inputs.add_argument("--seq-field", + help="The name to use for the sequence field when joining sequences from a FASTA file.") + + shared_inputs.add_argument("--unmatched-reporting", + type=DataErrorMethod.argtype, + choices=list(DataErrorMethod), + default=DataErrorMethod.ERROR_FIRST, + help="How unmatched records from combined metadata/FASTA input should be reported.") + shared_inputs.add_argument("--duplicate-reporting", + type=DataErrorMethod.argtype, + choices=list(DataErrorMethod), + default=DataErrorMethod.ERROR_FIRST, + help="How should duplicate records be reported.") + + shared_outputs = shared_parser.add_argument_group( + title="OUTPUTS", + description=""" + Output options shared by all `augur curate` commands. + If no output options are provided, commands will output NDJSON records to stdout. + """) + shared_outputs.add_argument("--output-metadata", + help="Output metadata TSV file. Accepts '-' to output TSV to stdout.") + + shared_outputs.add_argument("--output-fasta", + help="Output FASTA file.") + shared_outputs.add_argument("--output-id-field", + help="The record field to use as the sequence identifier in the FASTA output.") + shared_outputs.add_argument("--output-seq-field", + help="The record field that contains the sequence for the FASTA output. " + "This field will be deleted from the metadata output.") + + return shared_parser + + +shared_parser = create_shared_parser() diff --git a/augur/curate/abbreviate_authors.py b/augur/curate/abbreviate_authors.py index f95b8b723..ae9a161f5 100644 --- a/augur/curate/abbreviate_authors.py +++ b/augur/curate/abbreviate_authors.py @@ -10,6 +10,7 @@ from typing import Generator, List from augur.io.print import print_err from augur.utils import first_line +from ._shared import shared_parser def parse_authors( @@ -52,7 +53,7 @@ def register_parser( ) -> argparse._SubParsersAction: parser = parent_subparsers.add_parser( "abbreviate-authors", - parents=[parent_subparsers.shared_parser], # type: ignore + parents=[shared_parser], # type: ignore help=first_line(__doc__), ) diff --git a/augur/curate/apply_geolocation_rules.py b/augur/curate/apply_geolocation_rules.py index e14974890..26e227f72 100644 --- a/augur/curate/apply_geolocation_rules.py +++ b/augur/curate/apply_geolocation_rules.py @@ -5,6 +5,7 @@ from augur.errors import AugurError from augur.io.print import print_err from augur.utils import first_line +from ._shared import shared_parser class CyclicGeolocationRulesError(AugurError): @@ -188,7 +189,7 @@ def transform_geolocations(geolocation_rules, geolocation): def register_parser(parent_subparsers): parser = parent_subparsers.add_parser("apply-geolocation-rules", - parents=[parent_subparsers.shared_parser], + parents=[shared_parser], help=first_line(__doc__)) parser.add_argument("--region-field", default="region", diff --git a/augur/curate/apply_record_annotations.py b/augur/curate/apply_record_annotations.py index a650226d3..1cffed35d 100644 --- a/augur/curate/apply_record_annotations.py +++ b/augur/curate/apply_record_annotations.py @@ -7,11 +7,12 @@ from augur.errors import AugurError from augur.io.print import print_err from augur.utils import first_line +from ._shared import shared_parser def register_parser(parent_subparsers): parser = parent_subparsers.add_parser("apply-record-annotations", - parents=[parent_subparsers.shared_parser], + parents=[shared_parser], help=first_line(__doc__)) parser.add_argument("--annotations", metavar="TSV", required=True, diff --git a/augur/curate/format_dates.py b/augur/curate/format_dates.py index 063096de4..7e181b44c 100644 --- a/augur/curate/format_dates.py +++ b/augur/curate/format_dates.py @@ -10,6 +10,7 @@ from augur.io.print import print_err from augur.types import DataErrorMethod from .format_dates_directives import YEAR_DIRECTIVES, YEAR_MONTH_DIRECTIVES, YEAR_MONTH_DAY_DIRECTIVES +from ._shared import shared_parser # Default date formats that this command should parse @@ -24,7 +25,7 @@ def register_parser(parent_subparsers): parser = parent_subparsers.add_parser("format-dates", - parents=[parent_subparsers.shared_parser], + parents=[shared_parser], help=__doc__) required = parser.add_argument_group(title="REQUIRED") diff --git a/augur/curate/normalize_strings.py b/augur/curate/normalize_strings.py index ac98f4a6f..55cc366ff 100644 --- a/augur/curate/normalize_strings.py +++ b/augur/curate/normalize_strings.py @@ -7,11 +7,12 @@ import unicodedata from augur.utils import first_line +from ._shared import shared_parser def register_parser(parent_subparsers): parser = parent_subparsers.add_parser("normalize-strings", - parents=[parent_subparsers.shared_parser], + parents=[shared_parser], help=first_line(__doc__)) optional = parser.add_argument_group(title="OPTIONAL") diff --git a/augur/curate/parse_genbank_location.py b/augur/curate/parse_genbank_location.py index f1123d595..19a2dafe5 100644 --- a/augur/curate/parse_genbank_location.py +++ b/augur/curate/parse_genbank_location.py @@ -10,6 +10,7 @@ from typing import Generator, List from augur.io.print import print_err from augur.utils import first_line +from ._shared import shared_parser def parse_location( @@ -50,7 +51,7 @@ def register_parser( ) -> argparse._SubParsersAction: parser = parent_subparsers.add_parser( "parse-genbank-location", - parents=[parent_subparsers.shared_parser], # type: ignore + parents=[shared_parser], # type: ignore help=first_line(__doc__), ) diff --git a/augur/curate/passthru.py b/augur/curate/passthru.py index f31f2dc45..cbd3dce11 100644 --- a/augur/curate/passthru.py +++ b/augur/curate/passthru.py @@ -2,11 +2,12 @@ Pass through records without doing any data transformations. Useful for testing, troubleshooting, or just converting file formats. """ +from ._shared import shared_parser def register_parser(parent_subparsers): return parent_subparsers.add_parser("passthru", - parents=[parent_subparsers.shared_parser], + parents=[shared_parser], help=__doc__) diff --git a/augur/curate/rename.py b/augur/curate/rename.py index 68730fa28..3d7580dbb 100644 --- a/augur/curate/rename.py +++ b/augur/curate/rename.py @@ -6,10 +6,11 @@ import argparse from augur.io.print import print_err from augur.errors import AugurError +from ._shared import shared_parser def register_parser(parent_subparsers): parser = parent_subparsers.add_parser("rename", - parents = [parent_subparsers.shared_parser], + parents = [shared_parser], help = __doc__) required = parser.add_argument_group(title="REQUIRED") diff --git a/augur/curate/titlecase.py b/augur/curate/titlecase.py index e0d6017b8..8cb1de4ab 100644 --- a/augur/curate/titlecase.py +++ b/augur/curate/titlecase.py @@ -7,10 +7,11 @@ from augur.errors import AugurError from augur.io.print import print_err from augur.types import DataErrorMethod +from ._shared import shared_parser def register_parser(parent_subparsers): parser = parent_subparsers.add_parser("titlecase", - parents = [parent_subparsers.shared_parser], + parents = [shared_parser], help = __doc__) required = parser.add_argument_group(title="REQUIRED") diff --git a/augur/curate/transform_strain_name.py b/augur/curate/transform_strain_name.py index ae128893d..dae6a735a 100644 --- a/augur/curate/transform_strain_name.py +++ b/augur/curate/transform_strain_name.py @@ -9,6 +9,7 @@ from typing import Generator, List from augur.io.print import print_err from augur.utils import first_line +from ._shared import shared_parser def transform_name( @@ -42,7 +43,7 @@ def register_parser( ) -> argparse._SubParsersAction: parser = parent_subparsers.add_parser( "transform-strain-name", - parents=[parent_subparsers.shared_parser], # type: ignore + parents=[shared_parser], # type: ignore help=first_line(__doc__), ) From 6babb82acaf515296c34dc7a35099949dc90584a Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Thu, 31 Oct 2024 16:42:48 -0700 Subject: [PATCH 2/8] curate: Validate records with a wrapper This is necessary to register subparsers recursively. --- augur/curate/__init__.py | 137 ----------------------- augur/curate/_shared.py | 137 ++++++++++++++++++++++- augur/curate/abbreviate_authors.py | 3 +- augur/curate/apply_geolocation_rules.py | 3 +- augur/curate/apply_record_annotations.py | 3 +- augur/curate/format_dates.py | 3 +- augur/curate/normalize_strings.py | 3 +- augur/curate/parse_genbank_location.py | 3 +- augur/curate/passthru.py | 3 +- augur/curate/rename.py | 3 +- augur/curate/titlecase.py | 3 +- augur/curate/transform_strain_name.py | 3 +- 12 files changed, 156 insertions(+), 148 deletions(-) diff --git a/augur/curate/__init__.py b/augur/curate/__init__.py index 5e8888127..d162ab40b 100644 --- a/augur/curate/__init__.py +++ b/augur/curate/__init__.py @@ -1,16 +1,7 @@ """ A suite of commands to help with data curation. """ -import sys -from collections import deque -from textwrap import dedent -from typing import Iterable, Set - from augur.argparse_ import add_command_subparsers -from augur.errors import AugurError -from augur.io.json import dump_ndjson, load_ndjson -from augur.io.metadata import InvalidDelimiter, read_table_to_dict, read_metadata_with_sequences, write_records_to_tsv -from augur.io.sequences import write_records_to_fasta from . import format_dates, normalize_strings, passthru, titlecase, apply_geolocation_rules, apply_record_annotations, abbreviate_authors, parse_genbank_location, transform_strain_name, rename @@ -43,131 +34,3 @@ def register_parser(parent_subparsers): add_command_subparsers(subparsers, SUBCOMMANDS, SUBCOMMAND_ATTRIBUTE) return parser - - -def validate_records(records: Iterable[dict], subcmd_name: str, is_input: bool) -> Iterable[dict]: - """ - Validate that the provided *records* all have the same fields. - Uses the keys of the first record to check against all other records. - - Parameters - ---------- - records: iterable of dict - - subcmd_name: str - The name of the subcommand whose output is being validated; used in - error messages displayed to the user. - - is_input: bool - Whether the provided records come directly from user provided input - """ - error_message = "Records do not have the same fields! " - if is_input: - error_message += "Please check your input data has the same fields." - else: - # Hopefully users should not run into this error as it means we are - # not uniformly adding/removing fields from records - error_message += dedent(f"""\ - Something unexpected happened during the augur curate {subcmd_name} command. - To report this, please open a new issue including the original command: - - """) - - first_record_keys: Set[str] = set() - for idx, record in enumerate(records): - if idx == 0: - first_record_keys.update(record.keys()) - else: - if set(record.keys()) != first_record_keys: - raise AugurError(error_message) - yield record - - -def run(args): - # Print help if no subcommands are used - if not getattr(args, SUBCOMMAND_ATTRIBUTE, None): - return args.print_help() - - # Check provided args are valid and required combination of args are provided - if not args.fasta and (args.seq_id_column or args.seq_field): - raise AugurError("The --seq-id-column and --seq-field options should only be used when providing a FASTA file.") - - if args.fasta and (not args.seq_id_column or not args.seq_field): - raise AugurError("The --seq-id-column and --seq-field options are required for a FASTA file input.") - - if not args.output_fasta and (args.output_id_field or args.output_seq_field): - raise AugurError("The --output-id-field and --output-seq-field options should only be used when requesting a FASTA output.") - - if args.output_fasta and (not args.output_id_field or not args.output_seq_field): - raise AugurError("The --output-id-field and --output-seq-field options are required for a FASTA output.") - - # Read inputs - # Special case single hyphen as stdin - if args.metadata == '-': - args.metadata = sys.stdin.buffer - - if args.metadata and args.fasta: - try: - records = read_metadata_with_sequences( - args.metadata, - args.metadata_delimiters, - args.fasta, - args.seq_id_column, - args.seq_field, - args.unmatched_reporting, - args.duplicate_reporting) - except InvalidDelimiter: - raise AugurError( - f"Could not determine the delimiter of {args.metadata!r}. " - f"Valid delimiters are: {args.metadata_delimiters!r}. " - "This can be changed with --metadata-delimiters." - ) - elif args.metadata: - try: - records = read_table_to_dict(args.metadata, args.metadata_delimiters, args.duplicate_reporting, args.id_column) - except InvalidDelimiter: - raise AugurError( - f"Could not determine the delimiter of {args.metadata!r}. " - f"Valid delimiters are: {args.metadata_delimiters!r}. " - "This can be changed with --metadata-delimiters." - ) - elif not sys.stdin.isatty(): - records = load_ndjson(sys.stdin) - else: - raise AugurError(dedent("""\ - No valid inputs were provided. - NDJSON records can be streamed from stdin or - input files can be provided via the command line options `--metadata` and `--fasta`. - See the command's help message for more details.""")) - - # Get the name of the subcmd being run - subcmd_name = args.subcommand - - # Validate records have the same input fields - validated_input_records = validate_records(records, subcmd_name, True) - - # Run subcommand to get modified records - modified_records = getattr(args, SUBCOMMAND_ATTRIBUTE).run(args, validated_input_records) - - # Validate modified records have the same output fields - validated_output_records = validate_records(modified_records, subcmd_name, False) - - # Output modified records - # First output FASTA, since the write fasta function yields the records again - # and removes the sequences from the records - if args.output_fasta: - validated_output_records = write_records_to_fasta( - validated_output_records, - args.output_fasta, - args.output_id_field, - args.output_seq_field) - - if args.output_metadata: - write_records_to_tsv(validated_output_records, args.output_metadata) - - if not (args.output_fasta or args.output_metadata): - dump_ndjson(validated_output_records) - else: - # Exhaust generator to ensure we run through all records - # when only a FASTA output is requested but not a metadata output - deque(validated_output_records, maxlen=0) diff --git a/augur/curate/_shared.py b/augur/curate/_shared.py index 590033cbe..2fd725d1d 100644 --- a/augur/curate/_shared.py +++ b/augur/curate/_shared.py @@ -1,6 +1,14 @@ import argparse +import sys +from collections import deque +from textwrap import dedent +from typing import Callable, Iterable, Set + from augur.argparse_ import ExtendOverwriteDefault -from augur.io.metadata import DEFAULT_DELIMITERS +from augur.errors import AugurError +from augur.io.json import dump_ndjson, load_ndjson +from augur.io.metadata import DEFAULT_DELIMITERS, InvalidDelimiter, read_table_to_dict, read_metadata_with_sequences, write_records_to_tsv +from augur.io.sequences import write_records_to_fasta from augur.types import DataErrorMethod @@ -76,3 +84,130 @@ def create_shared_parser(): shared_parser = create_shared_parser() + + +def validate_records(records: Iterable[dict], subcmd_name: str, is_input: bool) -> Iterable[dict]: + """ + Validate that the provided *records* all have the same fields. + Uses the keys of the first record to check against all other records. + + Parameters + ---------- + records: iterable of dict + + subcmd_name: str + The name of the subcommand whose output is being validated; used in + error messages displayed to the user. + + is_input: bool + Whether the provided records come directly from user provided input + """ + error_message = "Records do not have the same fields! " + if is_input: + error_message += "Please check your input data has the same fields." + else: + # Hopefully users should not run into this error as it means we are + # not uniformly adding/removing fields from records + error_message += dedent(f"""\ + Something unexpected happened during the augur curate {subcmd_name} command. + To report this, please open a new issue including the original command: + + """) + + first_record_keys: Set[str] = set() + for idx, record in enumerate(records): + if idx == 0: + first_record_keys.update(record.keys()) + else: + if set(record.keys()) != first_record_keys: + raise AugurError(error_message) + yield record + + +def validate(command_run: Callable[[argparse.Namespace, Iterable[dict]], Iterable[dict]]): + def run(args: argparse.Namespace) -> None: + # Check provided args are valid and required combination of args are provided + if not args.fasta and (args.seq_id_column or args.seq_field): + raise AugurError("The --seq-id-column and --seq-field options should only be used when providing a FASTA file.") + + if args.fasta and (not args.seq_id_column or not args.seq_field): + raise AugurError("The --seq-id-column and --seq-field options are required for a FASTA file input.") + + if not args.output_fasta and (args.output_id_field or args.output_seq_field): + raise AugurError("The --output-id-field and --output-seq-field options should only be used when requesting a FASTA output.") + + if args.output_fasta and (not args.output_id_field or not args.output_seq_field): + raise AugurError("The --output-id-field and --output-seq-field options are required for a FASTA output.") + + # Read inputs + # Special case single hyphen as stdin + if args.metadata == '-': + args.metadata = sys.stdin.buffer + + if args.metadata and args.fasta: + try: + records = read_metadata_with_sequences( + args.metadata, + args.metadata_delimiters, + args.fasta, + args.seq_id_column, + args.seq_field, + args.unmatched_reporting, + args.duplicate_reporting) + except InvalidDelimiter: + raise AugurError( + f"Could not determine the delimiter of {args.metadata!r}. " + f"Valid delimiters are: {args.metadata_delimiters!r}. " + "This can be changed with --metadata-delimiters." + ) + elif args.metadata: + try: + records = read_table_to_dict(args.metadata, args.metadata_delimiters, args.duplicate_reporting, args.id_column) + except InvalidDelimiter: + raise AugurError( + f"Could not determine the delimiter of {args.metadata!r}. " + f"Valid delimiters are: {args.metadata_delimiters!r}. " + "This can be changed with --metadata-delimiters." + ) + elif not sys.stdin.isatty(): + records = load_ndjson(sys.stdin) + else: + raise AugurError(dedent("""\ + No valid inputs were provided. + NDJSON records can be streamed from stdin or + input files can be provided via the command line options `--metadata` and `--fasta`. + See the command's help message for more details.""")) + + # Get the name of the subcmd being run + subcmd_name = args.subcommand + + # Validate records have the same input fields + validated_input_records = validate_records(records, subcmd_name, True) + + # Run subcommand to get modified records + modified_records = command_run(args, validated_input_records) + + # Validate modified records have the same output fields + validated_output_records = validate_records(modified_records, subcmd_name, False) + + # Output modified records + # First output FASTA, since the write fasta function yields the records again + # and removes the sequences from the records + if args.output_fasta: + validated_output_records = write_records_to_fasta( + validated_output_records, + args.output_fasta, + args.output_id_field, + args.output_seq_field) + + if args.output_metadata: + write_records_to_tsv(validated_output_records, args.output_metadata) + + if not (args.output_fasta or args.output_metadata): + dump_ndjson(validated_output_records) + else: + # Exhaust generator to ensure we run through all records + # when only a FASTA output is requested but not a metadata output + deque(validated_output_records, maxlen=0) + + return run diff --git a/augur/curate/abbreviate_authors.py b/augur/curate/abbreviate_authors.py index ae9a161f5..0a2bdc8a8 100644 --- a/augur/curate/abbreviate_authors.py +++ b/augur/curate/abbreviate_authors.py @@ -10,7 +10,7 @@ from typing import Generator, List from augur.io.print import print_err from augur.utils import first_line -from ._shared import shared_parser +from ._shared import shared_parser, validate def parse_authors( @@ -76,6 +76,7 @@ def register_parser( return parser +@validate def run(args: argparse.Namespace, records: List[dict]) -> Generator[dict, None, None]: for index, record in enumerate(records): parse_authors( diff --git a/augur/curate/apply_geolocation_rules.py b/augur/curate/apply_geolocation_rules.py index 26e227f72..15ab27032 100644 --- a/augur/curate/apply_geolocation_rules.py +++ b/augur/curate/apply_geolocation_rules.py @@ -5,7 +5,7 @@ from augur.errors import AugurError from augur.io.print import print_err from augur.utils import first_line -from ._shared import shared_parser +from ._shared import shared_parser, validate class CyclicGeolocationRulesError(AugurError): @@ -211,6 +211,7 @@ def register_parser(parent_subparsers): return parser +@validate def run(args, records): location_fields = [args.region_field, args.country_field, args.division_field, args.location_field] diff --git a/augur/curate/apply_record_annotations.py b/augur/curate/apply_record_annotations.py index 1cffed35d..c79269856 100644 --- a/augur/curate/apply_record_annotations.py +++ b/augur/curate/apply_record_annotations.py @@ -7,7 +7,7 @@ from augur.errors import AugurError from augur.io.print import print_err from augur.utils import first_line -from ._shared import shared_parser +from ._shared import shared_parser, validate def register_parser(parent_subparsers): @@ -28,6 +28,7 @@ def register_parser(parent_subparsers): return parser +@validate def run(args, records): annotations = defaultdict(dict) with open(args.annotations, 'r', newline='') as annotations_fh: diff --git a/augur/curate/format_dates.py b/augur/curate/format_dates.py index 7e181b44c..1cc2fb6bc 100644 --- a/augur/curate/format_dates.py +++ b/augur/curate/format_dates.py @@ -10,7 +10,7 @@ from augur.io.print import print_err from augur.types import DataErrorMethod from .format_dates_directives import YEAR_DIRECTIVES, YEAR_MONTH_DIRECTIVES, YEAR_MONTH_DAY_DIRECTIVES -from ._shared import shared_parser +from ._shared import shared_parser, validate # Default date formats that this command should parse @@ -179,6 +179,7 @@ def format_date(date_string, expected_formats): return None +@validate def run(args, records): failures = [] failure_reporting = args.failure_reporting diff --git a/augur/curate/normalize_strings.py b/augur/curate/normalize_strings.py index 55cc366ff..e4ef76ca4 100644 --- a/augur/curate/normalize_strings.py +++ b/augur/curate/normalize_strings.py @@ -7,7 +7,7 @@ import unicodedata from augur.utils import first_line -from ._shared import shared_parser +from ._shared import shared_parser, validate def register_parser(parent_subparsers): @@ -45,6 +45,7 @@ def normalize_strings(record, form='NFC'): } +@validate def run(args, records): for record in records: yield normalize_strings(record, args.form) diff --git a/augur/curate/parse_genbank_location.py b/augur/curate/parse_genbank_location.py index 19a2dafe5..b25cef190 100644 --- a/augur/curate/parse_genbank_location.py +++ b/augur/curate/parse_genbank_location.py @@ -10,7 +10,7 @@ from typing import Generator, List from augur.io.print import print_err from augur.utils import first_line -from ._shared import shared_parser +from ._shared import shared_parser, validate def parse_location( @@ -65,6 +65,7 @@ def register_parser( return parser +@validate def run( args: argparse.Namespace, records: List[dict], diff --git a/augur/curate/passthru.py b/augur/curate/passthru.py index cbd3dce11..29868ad81 100644 --- a/augur/curate/passthru.py +++ b/augur/curate/passthru.py @@ -2,7 +2,7 @@ Pass through records without doing any data transformations. Useful for testing, troubleshooting, or just converting file formats. """ -from ._shared import shared_parser +from ._shared import shared_parser, validate def register_parser(parent_subparsers): @@ -11,6 +11,7 @@ def register_parser(parent_subparsers): help=__doc__) +@validate def run(args, records): yield from records diff --git a/augur/curate/rename.py b/augur/curate/rename.py index 3d7580dbb..4ee080792 100644 --- a/augur/curate/rename.py +++ b/augur/curate/rename.py @@ -6,7 +6,7 @@ import argparse from augur.io.print import print_err from augur.errors import AugurError -from ._shared import shared_parser +from ._shared import shared_parser, validate def register_parser(parent_subparsers): parser = parent_subparsers.add_parser("rename", @@ -87,6 +87,7 @@ def transform_columns(existing_fields: List[str], field_map: List[Tuple[str,str] return m +@validate def run(args: argparse.Namespace, records: Iterable[dict]) -> Iterable[dict]: col_map: Union[Literal[False], List[Tuple[str,str]]] = False for record in records: diff --git a/augur/curate/titlecase.py b/augur/curate/titlecase.py index 8cb1de4ab..b26f1f175 100644 --- a/augur/curate/titlecase.py +++ b/augur/curate/titlecase.py @@ -7,7 +7,7 @@ from augur.errors import AugurError from augur.io.print import print_err from augur.types import DataErrorMethod -from ._shared import shared_parser +from ._shared import shared_parser, validate def register_parser(parent_subparsers): parser = parent_subparsers.add_parser("titlecase", @@ -74,6 +74,7 @@ def changecase(index, word): return ''.join(changecase(i, w) for i, w in words) +@validate def run(args, records): failures = [] failure_reporting = args.failure_reporting diff --git a/augur/curate/transform_strain_name.py b/augur/curate/transform_strain_name.py index dae6a735a..556f7282b 100644 --- a/augur/curate/transform_strain_name.py +++ b/augur/curate/transform_strain_name.py @@ -9,7 +9,7 @@ from typing import Generator, List from augur.io.print import print_err from augur.utils import first_line -from ._shared import shared_parser +from ._shared import shared_parser, validate def transform_name( @@ -65,6 +65,7 @@ def register_parser( return parser +@validate def run(args: argparse.Namespace, records: List[dict]) -> Generator[dict, None, None]: strain_name_pattern = re.compile(args.strain_regex) From 11a3e15b6a6acb1e735df946d483ac1ae96d5ce5 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Thu, 31 Oct 2024 16:24:54 -0700 Subject: [PATCH 3/8] curate: Pass command name to validate wrapper This information will be lost when registering subparsers recursively. --- augur/curate/_shared.py | 176 ++++++++++++----------- augur/curate/abbreviate_authors.py | 7 +- augur/curate/apply_geolocation_rules.py | 7 +- augur/curate/apply_record_annotations.py | 7 +- augur/curate/format_dates.py | 7 +- augur/curate/normalize_strings.py | 7 +- augur/curate/parse_genbank_location.py | 7 +- augur/curate/passthru.py | 7 +- augur/curate/rename.py | 8 +- augur/curate/titlecase.py | 8 +- augur/curate/transform_strain_name.py | 7 +- 11 files changed, 141 insertions(+), 107 deletions(-) diff --git a/augur/curate/_shared.py b/augur/curate/_shared.py index 2fd725d1d..2cde7afb7 100644 --- a/augur/curate/_shared.py +++ b/augur/curate/_shared.py @@ -124,90 +124,92 @@ def validate_records(records: Iterable[dict], subcmd_name: str, is_input: bool) yield record -def validate(command_run: Callable[[argparse.Namespace, Iterable[dict]], Iterable[dict]]): - def run(args: argparse.Namespace) -> None: - # Check provided args are valid and required combination of args are provided - if not args.fasta and (args.seq_id_column or args.seq_field): - raise AugurError("The --seq-id-column and --seq-field options should only be used when providing a FASTA file.") - - if args.fasta and (not args.seq_id_column or not args.seq_field): - raise AugurError("The --seq-id-column and --seq-field options are required for a FASTA file input.") - - if not args.output_fasta and (args.output_id_field or args.output_seq_field): - raise AugurError("The --output-id-field and --output-seq-field options should only be used when requesting a FASTA output.") - - if args.output_fasta and (not args.output_id_field or not args.output_seq_field): - raise AugurError("The --output-id-field and --output-seq-field options are required for a FASTA output.") - - # Read inputs - # Special case single hyphen as stdin - if args.metadata == '-': - args.metadata = sys.stdin.buffer - - if args.metadata and args.fasta: - try: - records = read_metadata_with_sequences( - args.metadata, - args.metadata_delimiters, - args.fasta, - args.seq_id_column, - args.seq_field, - args.unmatched_reporting, - args.duplicate_reporting) - except InvalidDelimiter: - raise AugurError( - f"Could not determine the delimiter of {args.metadata!r}. " - f"Valid delimiters are: {args.metadata_delimiters!r}. " - "This can be changed with --metadata-delimiters." - ) - elif args.metadata: - try: - records = read_table_to_dict(args.metadata, args.metadata_delimiters, args.duplicate_reporting, args.id_column) - except InvalidDelimiter: - raise AugurError( - f"Could not determine the delimiter of {args.metadata!r}. " - f"Valid delimiters are: {args.metadata_delimiters!r}. " - "This can be changed with --metadata-delimiters." - ) - elif not sys.stdin.isatty(): - records = load_ndjson(sys.stdin) - else: - raise AugurError(dedent("""\ - No valid inputs were provided. - NDJSON records can be streamed from stdin or - input files can be provided via the command line options `--metadata` and `--fasta`. - See the command's help message for more details.""")) - - # Get the name of the subcmd being run - subcmd_name = args.subcommand - - # Validate records have the same input fields - validated_input_records = validate_records(records, subcmd_name, True) - - # Run subcommand to get modified records - modified_records = command_run(args, validated_input_records) - - # Validate modified records have the same output fields - validated_output_records = validate_records(modified_records, subcmd_name, False) - - # Output modified records - # First output FASTA, since the write fasta function yields the records again - # and removes the sequences from the records - if args.output_fasta: - validated_output_records = write_records_to_fasta( - validated_output_records, - args.output_fasta, - args.output_id_field, - args.output_seq_field) - - if args.output_metadata: - write_records_to_tsv(validated_output_records, args.output_metadata) - - if not (args.output_fasta or args.output_metadata): - dump_ndjson(validated_output_records) - else: - # Exhaust generator to ensure we run through all records - # when only a FASTA output is requested but not a metadata output - deque(validated_output_records, maxlen=0) - - return run +def validate(command_name: str): + + def decorator_with_context(command_run: Callable[[argparse.Namespace, Iterable[dict]], Iterable[dict]]): + + def run(args: argparse.Namespace) -> None: + # Check provided args are valid and required combination of args are provided + if not args.fasta and (args.seq_id_column or args.seq_field): + raise AugurError("The --seq-id-column and --seq-field options should only be used when providing a FASTA file.") + + if args.fasta and (not args.seq_id_column or not args.seq_field): + raise AugurError("The --seq-id-column and --seq-field options are required for a FASTA file input.") + + if not args.output_fasta and (args.output_id_field or args.output_seq_field): + raise AugurError("The --output-id-field and --output-seq-field options should only be used when requesting a FASTA output.") + + if args.output_fasta and (not args.output_id_field or not args.output_seq_field): + raise AugurError("The --output-id-field and --output-seq-field options are required for a FASTA output.") + + # Read inputs + # Special case single hyphen as stdin + if args.metadata == '-': + args.metadata = sys.stdin.buffer + + if args.metadata and args.fasta: + try: + records = read_metadata_with_sequences( + args.metadata, + args.metadata_delimiters, + args.fasta, + args.seq_id_column, + args.seq_field, + args.unmatched_reporting, + args.duplicate_reporting) + except InvalidDelimiter: + raise AugurError( + f"Could not determine the delimiter of {args.metadata!r}. " + f"Valid delimiters are: {args.metadata_delimiters!r}. " + "This can be changed with --metadata-delimiters." + ) + elif args.metadata: + try: + records = read_table_to_dict(args.metadata, args.metadata_delimiters, args.duplicate_reporting, args.id_column) + except InvalidDelimiter: + raise AugurError( + f"Could not determine the delimiter of {args.metadata!r}. " + f"Valid delimiters are: {args.metadata_delimiters!r}. " + "This can be changed with --metadata-delimiters." + ) + elif not sys.stdin.isatty(): + records = load_ndjson(sys.stdin) + else: + raise AugurError(dedent("""\ + No valid inputs were provided. + NDJSON records can be streamed from stdin or + input files can be provided via the command line options `--metadata` and `--fasta`. + See the command's help message for more details.""")) + + # Validate records have the same input fields + validated_input_records = validate_records(records, command_name, True) + + # Run subcommand to get modified records + modified_records = command_run(args, validated_input_records) + + # Validate modified records have the same output fields + validated_output_records = validate_records(modified_records, command_name, False) + + # Output modified records + # First output FASTA, since the write fasta function yields the records again + # and removes the sequences from the records + if args.output_fasta: + validated_output_records = write_records_to_fasta( + validated_output_records, + args.output_fasta, + args.output_id_field, + args.output_seq_field) + + if args.output_metadata: + write_records_to_tsv(validated_output_records, args.output_metadata) + + if not (args.output_fasta or args.output_metadata): + dump_ndjson(validated_output_records) + else: + # Exhaust generator to ensure we run through all records + # when only a FASTA output is requested but not a metadata output + deque(validated_output_records, maxlen=0) + + return run + + return decorator_with_context diff --git a/augur/curate/abbreviate_authors.py b/augur/curate/abbreviate_authors.py index 0a2bdc8a8..399247b42 100644 --- a/augur/curate/abbreviate_authors.py +++ b/augur/curate/abbreviate_authors.py @@ -13,6 +13,9 @@ from ._shared import shared_parser, validate +COMMAND_NAME = "abbreviate-authors" + + def parse_authors( record: dict, authors_field: str, @@ -52,7 +55,7 @@ def register_parser( parent_subparsers: argparse._SubParsersAction, ) -> argparse._SubParsersAction: parser = parent_subparsers.add_parser( - "abbreviate-authors", + COMMAND_NAME, parents=[shared_parser], # type: ignore help=first_line(__doc__), ) @@ -76,7 +79,7 @@ def register_parser( return parser -@validate +@validate(COMMAND_NAME) def run(args: argparse.Namespace, records: List[dict]) -> Generator[dict, None, None]: for index, record in enumerate(records): parse_authors( diff --git a/augur/curate/apply_geolocation_rules.py b/augur/curate/apply_geolocation_rules.py index 15ab27032..ecf7ab7ef 100644 --- a/augur/curate/apply_geolocation_rules.py +++ b/augur/curate/apply_geolocation_rules.py @@ -8,6 +8,9 @@ from ._shared import shared_parser, validate +COMMAND_NAME = "apply-geolocation-rules" + + class CyclicGeolocationRulesError(AugurError): pass @@ -188,7 +191,7 @@ def transform_geolocations(geolocation_rules, geolocation): def register_parser(parent_subparsers): - parser = parent_subparsers.add_parser("apply-geolocation-rules", + parser = parent_subparsers.add_parser(COMMAND_NAME, parents=[shared_parser], help=first_line(__doc__)) @@ -211,7 +214,7 @@ def register_parser(parent_subparsers): return parser -@validate +@validate(COMMAND_NAME) def run(args, records): location_fields = [args.region_field, args.country_field, args.division_field, args.location_field] diff --git a/augur/curate/apply_record_annotations.py b/augur/curate/apply_record_annotations.py index c79269856..c725ee39f 100644 --- a/augur/curate/apply_record_annotations.py +++ b/augur/curate/apply_record_annotations.py @@ -10,8 +10,11 @@ from ._shared import shared_parser, validate +COMMAND_NAME = "apply-record-annotations" + + def register_parser(parent_subparsers): - parser = parent_subparsers.add_parser("apply-record-annotations", + parser = parent_subparsers.add_parser(COMMAND_NAME, parents=[shared_parser], help=first_line(__doc__)) @@ -28,7 +31,7 @@ def register_parser(parent_subparsers): return parser -@validate +@validate(COMMAND_NAME) def run(args, records): annotations = defaultdict(dict) with open(args.annotations, 'r', newline='') as annotations_fh: diff --git a/augur/curate/format_dates.py b/augur/curate/format_dates.py index 1cc2fb6bc..c9e2fe8fe 100644 --- a/augur/curate/format_dates.py +++ b/augur/curate/format_dates.py @@ -13,6 +13,9 @@ from ._shared import shared_parser, validate +COMMAND_NAME = "format-dates" + + # Default date formats that this command should parse # without additional input from the user. DEFAULT_EXPECTED_DATE_FORMATS = [ @@ -24,7 +27,7 @@ def register_parser(parent_subparsers): - parser = parent_subparsers.add_parser("format-dates", + parser = parent_subparsers.add_parser(COMMAND_NAME, parents=[shared_parser], help=__doc__) @@ -179,7 +182,7 @@ def format_date(date_string, expected_formats): return None -@validate +@validate(COMMAND_NAME) def run(args, records): failures = [] failure_reporting = args.failure_reporting diff --git a/augur/curate/normalize_strings.py b/augur/curate/normalize_strings.py index e4ef76ca4..b45d87332 100644 --- a/augur/curate/normalize_strings.py +++ b/augur/curate/normalize_strings.py @@ -10,8 +10,11 @@ from ._shared import shared_parser, validate +COMMAND_NAME = "normalize-strings" + + def register_parser(parent_subparsers): - parser = parent_subparsers.add_parser("normalize-strings", + parser = parent_subparsers.add_parser(COMMAND_NAME, parents=[shared_parser], help=first_line(__doc__)) @@ -45,7 +48,7 @@ def normalize_strings(record, form='NFC'): } -@validate +@validate(COMMAND_NAME) def run(args, records): for record in records: yield normalize_strings(record, args.form) diff --git a/augur/curate/parse_genbank_location.py b/augur/curate/parse_genbank_location.py index b25cef190..e28e7028a 100644 --- a/augur/curate/parse_genbank_location.py +++ b/augur/curate/parse_genbank_location.py @@ -13,6 +13,9 @@ from ._shared import shared_parser, validate +COMMAND_NAME = "parse-genbank-location" + + def parse_location( record: dict, location_field_name: str, @@ -50,7 +53,7 @@ def register_parser( parent_subparsers: argparse._SubParsersAction, ) -> argparse._SubParsersAction: parser = parent_subparsers.add_parser( - "parse-genbank-location", + COMMAND_NAME, parents=[shared_parser], # type: ignore help=first_line(__doc__), ) @@ -65,7 +68,7 @@ def register_parser( return parser -@validate +@validate(COMMAND_NAME) def run( args: argparse.Namespace, records: List[dict], diff --git a/augur/curate/passthru.py b/augur/curate/passthru.py index 29868ad81..a40e58c78 100644 --- a/augur/curate/passthru.py +++ b/augur/curate/passthru.py @@ -5,13 +5,16 @@ from ._shared import shared_parser, validate +COMMAND_NAME = "passthru" + + def register_parser(parent_subparsers): - return parent_subparsers.add_parser("passthru", + return parent_subparsers.add_parser(COMMAND_NAME, parents=[shared_parser], help=__doc__) -@validate +@validate(COMMAND_NAME) def run(args, records): yield from records diff --git a/augur/curate/rename.py b/augur/curate/rename.py index 4ee080792..28882827d 100644 --- a/augur/curate/rename.py +++ b/augur/curate/rename.py @@ -8,8 +8,12 @@ from augur.errors import AugurError from ._shared import shared_parser, validate + +COMMAND_NAME = "rename" + + def register_parser(parent_subparsers): - parser = parent_subparsers.add_parser("rename", + parser = parent_subparsers.add_parser(COMMAND_NAME, parents = [shared_parser], help = __doc__) @@ -87,7 +91,7 @@ def transform_columns(existing_fields: List[str], field_map: List[Tuple[str,str] return m -@validate +@validate(COMMAND_NAME) def run(args: argparse.Namespace, records: Iterable[dict]) -> Iterable[dict]: col_map: Union[Literal[False], List[Tuple[str,str]]] = False for record in records: diff --git a/augur/curate/titlecase.py b/augur/curate/titlecase.py index b26f1f175..f2869dc3e 100644 --- a/augur/curate/titlecase.py +++ b/augur/curate/titlecase.py @@ -9,8 +9,12 @@ from augur.types import DataErrorMethod from ._shared import shared_parser, validate + +COMMAND_NAME = "titlecase" + + def register_parser(parent_subparsers): - parser = parent_subparsers.add_parser("titlecase", + parser = parent_subparsers.add_parser(COMMAND_NAME, parents = [shared_parser], help = __doc__) @@ -74,7 +78,7 @@ def changecase(index, word): return ''.join(changecase(i, w) for i, w in words) -@validate +@validate(COMMAND_NAME) def run(args, records): failures = [] failure_reporting = args.failure_reporting diff --git a/augur/curate/transform_strain_name.py b/augur/curate/transform_strain_name.py index 556f7282b..8e4e4ab3a 100644 --- a/augur/curate/transform_strain_name.py +++ b/augur/curate/transform_strain_name.py @@ -12,6 +12,9 @@ from ._shared import shared_parser, validate +COMMAND_NAME = "transform-strain-name" + + def transform_name( record: dict, index: int, @@ -42,7 +45,7 @@ def register_parser( parent_subparsers: argparse._SubParsersAction, ) -> argparse._SubParsersAction: parser = parent_subparsers.add_parser( - "transform-strain-name", + COMMAND_NAME, parents=[shared_parser], # type: ignore help=first_line(__doc__), ) @@ -65,7 +68,7 @@ def register_parser( return parser -@validate +@validate(COMMAND_NAME) def run(args: argparse.Namespace, records: List[dict]) -> Generator[dict, None, None]: strain_name_pattern = re.compile(args.strain_regex) From 2457a6dadbab7516c0258ea97c6e731674f9b79a Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Wed, 25 Sep 2024 16:10:58 -0700 Subject: [PATCH 4/8] Register subparsers recursively Keeps consistent with Nextstrain CLI. --- augur/__init__.py | 5 ++--- augur/argparse_.py | 15 ++++++++++----- augur/curate/__init__.py | 11 +---------- augur/export.py | 12 +++++------- augur/import_/__init__.py | 8 +++----- augur/measurements/__init__.py | 7 +++---- 6 files changed, 24 insertions(+), 34 deletions(-) diff --git a/augur/__init__.py b/augur/__init__.py index af4c23817..00d6e609e 100644 --- a/augur/__init__.py +++ b/augur/__init__.py @@ -14,7 +14,7 @@ from .debug import DEBUGGING from .errors import AugurError from .io.print import print_err -from .argparse_ import add_command_subparsers, add_default_command +from .argparse_ import register_commands, add_default_command DEFAULT_AUGUR_RECURSION_LIMIT = 10000 sys.setrecursionlimit(int(os.environ.get("AUGUR_RECURSION_LIMIT") or DEFAULT_AUGUR_RECURSION_LIMIT)) @@ -58,8 +58,7 @@ def make_parser(): add_default_command(parser) add_version_alias(parser) - subparsers = parser.add_subparsers() - add_command_subparsers(subparsers, COMMANDS) + register_commands(parser, COMMANDS) return parser diff --git a/augur/argparse_.py b/augur/argparse_.py index 6084fdd5e..081fb51c0 100644 --- a/augur/argparse_.py +++ b/augur/argparse_.py @@ -1,6 +1,7 @@ """ Custom helpers for the argparse standard library. """ +import argparse from argparse import Action, ArgumentDefaultsHelpFormatter, ArgumentParser, _ArgumentGroup from typing import Union from .types import ValidationMode @@ -32,16 +33,14 @@ def run(args): parser.set_defaults(__command__ = default_command) -def add_command_subparsers(subparsers, commands, command_attribute='__command__'): +def register_commands(parser: argparse.ArgumentParser, commands, command_attribute='__command__'): """ Add subparsers for each command module. Parameters ---------- - subparsers: argparse._SubParsersAction - The special subparsers action object created by the parent parser - via `parser.add_subparsers()`. - + parser + ArgumentParser object. commands: list[types.ModuleType] A list of modules that are commands that require their own subparser. Each module is required to have a `register_parser` function to add its own @@ -51,6 +50,8 @@ def add_command_subparsers(subparsers, commands, command_attribute='__command__' Optional attribute name for the commands. The default is `__command__`, which allows top level augur to run commands directly via `args.__command__.run()`. """ + subparsers = parser.add_subparsers() + for command in commands: # Allow each command to register its own subparser subparser = command.register_parser(subparsers) @@ -70,6 +71,10 @@ def add_command_subparsers(subparsers, commands, command_attribute='__command__' if not getattr(command, "run", None): add_default_command(subparser) + # Recursively register any subcommands + if getattr(subparser, "subcommands", None): + register_commands(subparser, subparser.subcommands) + class HideAsFalseAction(Action): """ diff --git a/augur/curate/__init__.py b/augur/curate/__init__.py index d162ab40b..6b92b58f2 100644 --- a/augur/curate/__init__.py +++ b/augur/curate/__init__.py @@ -1,7 +1,6 @@ """ A suite of commands to help with data curation. """ -from augur.argparse_ import add_command_subparsers from . import format_dates, normalize_strings, passthru, titlecase, apply_geolocation_rules, apply_record_annotations, abbreviate_authors, parse_genbank_location, transform_strain_name, rename @@ -23,14 +22,6 @@ def register_parser(parent_subparsers): parser = parent_subparsers.add_parser("curate", help=__doc__) - # Add print_help so we can run it when no subcommands are called - parser.set_defaults(print_help = parser.print_help) - - # Add subparsers for subcommands - subparsers = parser.add_subparsers(dest="subcommand", required=False) - # Using a subcommand attribute so subcommands are not directly - # run by top level Augur. Process I/O in `curate`` so individual - # subcommands do not have to worry about it. - add_command_subparsers(subparsers, SUBCOMMANDS, SUBCOMMAND_ATTRIBUTE) + parser.subcommands = SUBCOMMANDS return parser diff --git a/augur/export.py b/augur/export.py index 38c025b80..050c94475 100644 --- a/augur/export.py +++ b/augur/export.py @@ -1,7 +1,8 @@ """ Export JSON files suitable for visualization with auspice. + +Augur export now needs you to define the JSON version you want, e.g. `augur export v2`. """ -from .argparse_ import add_command_subparsers from . import export_v1, export_v2 SUBCOMMANDS = [ @@ -12,10 +13,7 @@ def register_parser(parent_subparsers): parser = parent_subparsers.add_parser("export", help=__doc__) - # Add subparsers for subcommands - metavar_msg ="Augur export now needs you to define the JSON version " + \ - "you want, e.g. `augur export v2`." - subparsers = parser.add_subparsers(title="JSON SCHEMA", - metavar=metavar_msg) - add_command_subparsers(subparsers, SUBCOMMANDS) + + parser.subcommands = SUBCOMMANDS + return parser diff --git a/augur/import_/__init__.py b/augur/import_/__init__.py index d2f925a29..ef9a9ad0c 100644 --- a/augur/import_/__init__.py +++ b/augur/import_/__init__.py @@ -1,7 +1,6 @@ """ Import analyses into augur pipeline from other systems """ -from augur.argparse_ import add_command_subparsers from augur.utils import first_line from . import beast @@ -11,8 +10,7 @@ def register_parser(parent_subparsers): parser = parent_subparsers.add_parser("import", help=first_line(__doc__)) - metavar_msg = "Import analyses into augur pipeline from other systems" - subparsers = parser.add_subparsers(title="TYPE", - metavar=metavar_msg) - add_command_subparsers(subparsers, SUBCOMMANDS) + + parser.subcommands = SUBCOMMANDS + return parser diff --git a/augur/measurements/__init__.py b/augur/measurements/__init__.py index 3fbc93273..54a4638bc 100644 --- a/augur/measurements/__init__.py +++ b/augur/measurements/__init__.py @@ -1,7 +1,6 @@ """ Create JSON files suitable for visualization within the measurements panel of Auspice. """ -from augur.argparse_ import add_command_subparsers from augur.utils import first_line from . import export, concat @@ -13,7 +12,7 @@ def register_parser(parent_subparsers): parser = parent_subparsers.add_parser("measurements", help=first_line(__doc__)) - # Add subparsers for subcommands - subparsers = parser.add_subparsers(dest='subcommand') - add_command_subparsers(subparsers, SUBCOMMANDS) + + parser.subcommands = SUBCOMMANDS + return parser From b365f9d24cd2602d034310e6c194e49029bd8ed4 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Tue, 24 Sep 2024 14:35:18 -0700 Subject: [PATCH 5/8] Get subparser formatter class from top-level parser Set a single source of truth. --- augur/__init__.py | 6 ++++-- augur/argparse_.py | 7 +++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/augur/__init__.py b/augur/__init__.py index 00d6e609e..96cf036cc 100644 --- a/augur/__init__.py +++ b/augur/__init__.py @@ -52,8 +52,10 @@ def make_parser(): parser = argparse.ArgumentParser( - prog = "augur", - description = "Augur: A bioinformatics toolkit for phylogenetic analysis.") + prog = "augur", + description = "Augur: A bioinformatics toolkit for phylogenetic analysis.", + formatter_class = argparse.ArgumentDefaultsHelpFormatter, + ) add_default_command(parser) add_version_alias(parser) diff --git a/augur/argparse_.py b/augur/argparse_.py index 081fb51c0..ac1861ccb 100644 --- a/augur/argparse_.py +++ b/augur/argparse_.py @@ -2,7 +2,7 @@ Custom helpers for the argparse standard library. """ import argparse -from argparse import Action, ArgumentDefaultsHelpFormatter, ArgumentParser, _ArgumentGroup +from argparse import Action, ArgumentParser, _ArgumentGroup from typing import Union from .types import ValidationMode @@ -60,9 +60,8 @@ def register_commands(parser: argparse.ArgumentParser, commands, command_attribu if command_attribute: subparser.set_defaults(**{command_attribute: command}) - # Use the same formatting class for every command for consistency. - # Set here to avoid repeating it in every command's register_parser(). - subparser.formatter_class = ArgumentDefaultsHelpFormatter + # Ensure all subparsers format like the top-level parser + subparser.formatter_class = parser.formatter_class if not subparser.description and command.__doc__: subparser.description = command.__doc__ From 6581f104e980b6f9e2e038b09b22217ad4da4d18 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Thu, 31 Oct 2024 16:58:04 -0700 Subject: [PATCH 6/8] =?UTF-8?q?=F0=9F=9A=A7=20figure=20out=20why=20this=20?= =?UTF-8?q?breaks=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- augur/argparse_.py | 6 +----- augur/curate/__init__.py | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/augur/argparse_.py b/augur/argparse_.py index ac1861ccb..d979b1acf 100644 --- a/augur/argparse_.py +++ b/augur/argparse_.py @@ -33,7 +33,7 @@ def run(args): parser.set_defaults(__command__ = default_command) -def register_commands(parser: argparse.ArgumentParser, commands, command_attribute='__command__'): +def register_commands(parser: argparse.ArgumentParser, commands): """ Add subparsers for each command module. @@ -56,10 +56,6 @@ def register_commands(parser: argparse.ArgumentParser, commands, command_attribu # Allow each command to register its own subparser subparser = command.register_parser(subparsers) - # Add default attribute for command module - if command_attribute: - subparser.set_defaults(**{command_attribute: command}) - # Ensure all subparsers format like the top-level parser subparser.formatter_class = parser.formatter_class diff --git a/augur/curate/__init__.py b/augur/curate/__init__.py index 6b92b58f2..2b21ead8f 100644 --- a/augur/curate/__init__.py +++ b/augur/curate/__init__.py @@ -4,7 +4,6 @@ from . import format_dates, normalize_strings, passthru, titlecase, apply_geolocation_rules, apply_record_annotations, abbreviate_authors, parse_genbank_location, transform_strain_name, rename -SUBCOMMAND_ATTRIBUTE = '_curate_subcommand' SUBCOMMANDS = [ passthru, normalize_strings, From f5e9472f6002dc2dfdaf1b08b3fa572fd983de0c Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Thu, 31 Oct 2024 16:58:50 -0700 Subject: [PATCH 7/8] =?UTF-8?q?Revert=20"=F0=9F=9A=A7=20figure=20out=20why?= =?UTF-8?q?=20this=20breaks=20tests"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 6581f104e980b6f9e2e038b09b22217ad4da4d18. --- augur/argparse_.py | 6 +++++- augur/curate/__init__.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/augur/argparse_.py b/augur/argparse_.py index d979b1acf..ac1861ccb 100644 --- a/augur/argparse_.py +++ b/augur/argparse_.py @@ -33,7 +33,7 @@ def run(args): parser.set_defaults(__command__ = default_command) -def register_commands(parser: argparse.ArgumentParser, commands): +def register_commands(parser: argparse.ArgumentParser, commands, command_attribute='__command__'): """ Add subparsers for each command module. @@ -56,6 +56,10 @@ def register_commands(parser: argparse.ArgumentParser, commands): # Allow each command to register its own subparser subparser = command.register_parser(subparsers) + # Add default attribute for command module + if command_attribute: + subparser.set_defaults(**{command_attribute: command}) + # Ensure all subparsers format like the top-level parser subparser.formatter_class = parser.formatter_class diff --git a/augur/curate/__init__.py b/augur/curate/__init__.py index 2b21ead8f..6b92b58f2 100644 --- a/augur/curate/__init__.py +++ b/augur/curate/__init__.py @@ -4,6 +4,7 @@ from . import format_dates, normalize_strings, passthru, titlecase, apply_geolocation_rules, apply_record_annotations, abbreviate_authors, parse_genbank_location, transform_strain_name, rename +SUBCOMMAND_ATTRIBUTE = '_curate_subcommand' SUBCOMMANDS = [ passthru, normalize_strings, From a4ee061672bd4f2b0c04567b7d84f14d91eaac0e Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Thu, 31 Oct 2024 17:09:42 -0700 Subject: [PATCH 8/8] fixup! curate: Validate records with a wrapper --- tests/io/test_curate_validate_records.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/io/test_curate_validate_records.py b/tests/io/test_curate_validate_records.py index 72f5083f4..42d4f1e15 100644 --- a/tests/io/test_curate_validate_records.py +++ b/tests/io/test_curate_validate_records.py @@ -1,5 +1,5 @@ import pytest -from augur.curate import validate_records +from augur.curate._shared import validate_records from augur.errors import AugurError