diff --git a/docs/scripts.md b/docs/scripts.md
new file mode 100644
index 00000000..e809ad1e
--- /dev/null
+++ b/docs/scripts.md
@@ -0,0 +1,112 @@
+# Script Documentation
+
+## Summary
+
+Some tasks are small enough that the project architecture should not change, but large enough that they should not be performed by hand.
+Files in the `scripts` directory exist to fill this space.
+
+Currently, the following scripts are provided.
+
+* `get_counts.py`
+  * get docket, document, and comment counts from regulations.gov, a Mirrulations dashboard, or a Mirrulations Redis instance as JSON
+  * when using regulations.gov, a timestamp can be given so that all dockets, documents, and comments from before that timestamp are counted as already downloaded
+* `correct_counts.py`
+  * correct possible errors within a counts JSON file generated by `get_counts.py`
+* `set_counts.py`
+  * set values in a Mirrulations Redis instance using JSON generated by `get_counts.py`
+* `get_correct_set.sh`
+ * run `get_counts.py`, `correct_counts.py`, and `set_counts.py`, logging relevant information
+
+All of the scripts above share a common JSON format, generated by `get_counts.py`:
+
+```json
+{
+    "creation_timestamp": "2024-10-16 15:00:00",
+    "queue_size": 10,
+ "dockets": {
+ "downloaded": 253807,
+ "jobs": 0,
+ "total": 253807,
+ "last_timestamp": "2024-10-13 04:04:18"
+ },
+ "documents": {
+ "downloaded": 1843774,
+ "jobs": 0,
+ "total": 1843774,
+ "last_timestamp": "2024-10-13 04:04:18"
+ },
+ "comments": {
+ "downloaded": 22240501,
+ "jobs": 10,
+ "total": 22240511,
+ "last_timestamp": "2024-10-13 04:04:18"
+ }
+}
+```
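+
+The `CountsEncoder` and `CountsDecoder` helpers in `scripts/counts.py` convert between this format and Python `datetime` objects. A minimal round-trip sketch (the input file name is hypothetical):
+
+```python
+import json
+
+# CountsEncoder/CountsDecoder are defined in scripts/counts.py; the decoder
+# turns "YYYY-MM-DD HH:MM:SS" strings into datetime objects, and the encoder
+# turns them back into strings.
+from counts import CountsDecoder, CountsEncoder
+
+# "counts.json" is a hypothetical file in the format shown above.
+with open("counts.json") as fp:
+    counts = json.load(fp, cls=CountsDecoder)
+
+# last_timestamp is now a datetime, and the count fields are plain ints.
+remaining = counts["comments"]["total"] - counts["comments"]["downloaded"]
+print(f"{remaining} comments left as of {counts['comments']['last_timestamp']}")
+
+# Serialize back out; datetimes are re-encoded as timestamp strings.
+print(json.dumps(counts, cls=CountsEncoder))
+```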
+
+## Description
+
+### `get_correct_set.sh`
+
+`get_correct_set.sh` uses `get_counts.py` to read counts from Redis, corrects them with `correct_counts.py`, and on success sets them with `set_counts.py`.
+It attempts to log to `/var/log/mirrulations_counts.log`.
+By default, it expects a virtual environment with all required dependencies in `/home/cs334/mirrulations/scripts/.venv`.
+
+### `get_counts.py`
+
+`get_counts.py` gets counts from one of three sources: regulations.gov, a Mirrulations Redis instance, or a Mirrulations dashboard via HTTP.
+
+When reading from regulations.gov, a UTC timestamp can be specified to mock having downloaded all dockets, documents, and comments from before that timestamp.
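+
+The regulations.gov logic is exposed as `get_regulation(api_key, last_timestamp)` inside `get_counts.py`, which makes exactly six API calls. A minimal sketch of calling it directly (assumes the `API_KEY` environment variable is set; the cutoff value is just an example):
+
+```python
+import datetime as dt
+import os
+
+# get_regulation is defined in scripts/get_counts.py.
+from get_counts import get_regulation
+
+# Treat everything modified before this cutoff as already downloaded.
+cutoff = dt.datetime(2024, 8, 6, 6, 20, 50, tzinfo=dt.timezone.utc)
+counts = get_regulation(os.environ["API_KEY"], cutoff)
+print(counts["dockets"]["downloaded"], "of", counts["dockets"]["total"], "dockets")
+```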
+
+When reading from a dashboard, a UTC timestamp must be specified, since the dashboard API does not provide one.
+
+### `correct_counts.py`
+
+`correct_counts.py` corrects counts from `get_counts.py` using one of two strategies: set the downloaded count for each type to the minimum of `downloaded` and `total` for that type (`cap_with_total`), or set it to the minimum of `downloaded` and `total - jobs` (`diff_total_with_jobs`).
+By default, any queued jobs cause the script to exit without output; this behavior can be changed with the `--ignore-queue` flag.
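+
+The two strategies differ only when jobs are queued. A small illustration using the functions from `correct_counts.py` (the numbers are made up, and `last_timestamp` is elided):
+
+```python
+from correct_counts import strategy_cap, strategy_diff
+
+# A made-up counts dict in which 10 comment jobs are still queued and the
+# comment downloaded count has drifted above the total.
+sample = {
+    "queue_size": 10,
+    "dockets": {"downloaded": 100, "jobs": 0, "total": 100, "last_timestamp": None},
+    "documents": {"downloaded": 200, "jobs": 0, "total": 200, "last_timestamp": None},
+    "comments": {"downloaded": 305, "jobs": 10, "total": 300, "last_timestamp": None},
+}
+
+# cap_with_total: downloaded = min(total, downloaded) -> 300
+print(strategy_cap(sample, ignore_queue=True)["comments"]["downloaded"])
+# diff_total_with_jobs: downloaded = min(total - jobs, downloaded) -> 290
+print(strategy_diff(sample, ignore_queue=True)["comments"]["downloaded"])
+```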
+
+### `set_counts.py`
+
+`set_counts.py` sets values from `get_counts.py` in a Redis instance.
+By default the script will prompt for user input before changing any values.
+This behavior can be changed using the `--yes` flag, which should be used **WITH GREAT CARE, ESPECIALLY IN PRODUCTION!!!**.
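+
+Internally, `set_counts.py` uses two helpers, `show_changes` and `set_values`. A minimal sketch of applying counts programmatically rather than via the CLI (the Redis location and the input file name are illustrative):
+
+```python
+import json
+import redis
+
+from counts import CountsDecoder
+from set_counts import set_values, show_changes
+
+# Assumes a local Redis on the default port; adjust as needed.
+db = redis.Redis("localhost", 6379, 0, decode_responses=True)
+
+# "corrected.json" is a hypothetical file in the common format.
+with open("corrected.json") as fp:
+    counts = json.load(fp, cls=CountsDecoder)
+
+show_changes(db, counts)  # print a colorized old ---> new comparison
+set_values(db, counts)    # write the new values to Redis
+```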
+
+## Setup
+
+First, create a virtual environment to install the dependencies into.
+
+```bash
+cd scripts
+python3 -m venv .venv
+source .venv/bin/activate
+pip install -r requirements.txt
+```
+
+Make sure you are in the correct environment when running the scripts.
+
+## Examples
+
+### Cap Docket, Document, and Comment downloaded counts by the counts from Regulations.gov
+
+```bash
+./get_counts.py redis | ./correct_counts.py | ./set_counts.py -y
+```
+
+### Set Docket, Document, and Comment downloaded counts while jobs are in the queue
+
+```bash
+./get_counts.py dashboard 2024-10-16T15:00:00 | ./correct_counts.py --ignore-queue --strategy diff_total_with_jobs | ./set_counts.py -y
+```
+
+### Download Counts for a Certain Time from Regulations.gov
+
+```bash
+export API_KEY=
+
+./get_counts.py -o aug_06_2024.json regulations --api-key $API_KEY -t 2024-08-06T06:20:50Z
+
+./get_counts.py -o oct_01_2024.json regulations --last-timestamp 2024-10-01T15:30:10Z
+./set_counts.py -i oct_01_2024.json
+```
diff --git a/scripts/correct_counts.py b/scripts/correct_counts.py
new file mode 100755
index 00000000..e276e994
--- /dev/null
+++ b/scripts/correct_counts.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+
+from copy import deepcopy
+import json
+import pathlib
+import sys
+from json import JSONDecodeError
+from counts import Counts, CountsEncoder, CountsDecoder
+
+import argparse
+
+
+class JobsInQueueException(Exception):
+ pass
+
+
+def strategy_cap(received: Counts, ignore_queue: bool) -> Counts:
+    filtered = deepcopy(received)
+    if filtered["queue_size"] != 0 and not ignore_queue:
+        raise JobsInQueueException(f'Found jobs in job queue: {filtered["queue_size"]}')
+ for entity_type in ("dockets", "documents", "comments"):
+ total_ = filtered[entity_type]["total"]
+ downloaded = filtered[entity_type]["downloaded"]
+ filtered[entity_type]["downloaded"] = min(total_, downloaded)
+
+ return filtered
+
+
+def strategy_diff(received: Counts, ignore_queue: bool) -> Counts:
+    filtered = deepcopy(received)
+ for entity_type in ("dockets", "documents", "comments"):
+ total_ = filtered[entity_type]["total"]
+ downloaded = filtered[entity_type]["downloaded"]
+ jobs = filtered[entity_type]["jobs"]
+ if jobs > 0 and not ignore_queue:
+ raise JobsInQueueException(
+ f'{entity_type} has {filtered[entity_type]["jobs"]} in queue'
+ )
+ filtered[entity_type]["downloaded"] = min(total_ - jobs, downloaded)
+
+ return filtered
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ "Correct Counts",
+ description="Correct counts in json format by either capping downloaded with `total` or capping with `total - jobs`",
+ )
+ parser.add_argument(
+ "-o",
+ "--output",
+ metavar="OUTPUT_PATH",
+ type=str,
+ default="-",
+ help="file to output to, use '-' for stdout (default '%(default)s')",
+ )
+ parser.add_argument(
+ "-i",
+ "--input",
+ metavar="INPUT_PATH",
+ type=str,
+ default="-",
+ help="file to read from, use '-' for stdin (default '%(default)s')",
+ )
+ parser.add_argument(
+ "-s",
+ "--strategy",
+ type=str,
+ default="cap_with_total",
+ choices=("cap_with_total", "diff_total_with_jobs"),
+ help="the correction strategy to use (default '%(default)s')",
+ )
+ parser.add_argument(
+ "--ignore-queue",
+ action="store_true",
+ help="continue even if there are queued jobs",
+ )
+
+ args = parser.parse_args()
+
+ try:
+ if args.input == "-":
+ input_counts: Counts = json.load(sys.stdin, cls=CountsDecoder)
+ else:
+ try:
+ with open(pathlib.Path(args.input), "r") as fp:
+ input_counts = json.load(fp, cls=CountsDecoder)
+ except FileNotFoundError:
+ print(f"Missing file {args.input}, exitting", file=sys.stderr)
+ sys.exit(2)
+ except JSONDecodeError:
+ print(f"Malformed input file {args.input}, exitting", file=sys.stderr)
+ sys.exit(2)
+
+ try:
+ if args.strategy == "cap_with_total":
+ modified_counts = strategy_cap(input_counts, args.ignore_queue)
+ elif args.strategy == "diff_total_with_jobs":
+ modified_counts = strategy_diff(input_counts, args.ignore_queue)
+ else:
+ print(f"Unrecognized strategy {args.strategy}, exitting", file=sys.stderr)
+ sys.exit(1)
+ except JobsInQueueException as e:
+ print(
+ f"Found jobs in queue: {e}\nUse `--ignore-queue` to continue",
+ file=sys.stderr,
+ )
+ sys.exit(2)
+
+ if args.output == "-":
+ json.dump(modified_counts, sys.stdout, cls=CountsEncoder)
+ else:
+ with open(pathlib.Path(args.output), "w") as fp:
+ json.dump(modified_counts, fp, cls=CountsEncoder)
diff --git a/scripts/counts.py b/scripts/counts.py
new file mode 100644
index 00000000..28b066ca
--- /dev/null
+++ b/scripts/counts.py
@@ -0,0 +1,38 @@
+import json
+import datetime as dt
+from typing import Any, TypedDict
+
+
+class EntityCount(TypedDict):
+ downloaded: int
+ jobs: int
+ total: int
+ last_timestamp: dt.datetime
+
+
+class Counts(TypedDict):
+ creation_timestamp: dt.datetime
+ queue_size: int
+ dockets: EntityCount
+ documents: EntityCount
+ comments: EntityCount
+
+
+class CountsEncoder(json.JSONEncoder):
+ def default(self, o: Any) -> Any:
+ if isinstance(o, dt.datetime):
+ return o.strftime("%Y-%m-%d %H:%M:%S")
+ return super().default(o)
+
+
+class CountsDecoder(json.JSONDecoder):
+ def __init__(self, *args, **kwargs):
+ super().__init__(object_hook=self.object_hook, *args, **kwargs)
+
+ def object_hook(self, obj: Any) -> Any:
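+        # Try to parse every string value as a "YYYY-MM-DD HH:MM:SS"
+        # timestamp; values that fail to parse are left unchanged.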
+ for key, value in obj.items():
+ try:
+ obj[key] = dt.datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
+ except (ValueError, TypeError):
+ pass
+ return obj
diff --git a/scripts/get_correct_set.sh b/scripts/get_correct_set.sh
new file mode 100755
index 00000000..6fbca092
--- /dev/null
+++ b/scripts/get_correct_set.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+WORK_DIR="/home/cs334/mirrulations/scripts/"
+LOG_FILE=/var/log/mirrulations_counts.log
+START_TIME=$(date -u -Iseconds)
+echo "$START_TIME: RUnning" > $LOG_FILE
+cd "$WORK_DIR" || exit 1
+
+PYTHON=".venv/bin/python3"
+
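+# Get counts from Redis, correct them, and set them only if each prior step
+# succeeded; stderr from every step is appended to the log file.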
+$PYTHON get_counts.py -o "/tmp/mirrulations_$START_TIME.json" redis 2>> "$LOG_FILE" &&
+    $PYTHON correct_counts.py -i "/tmp/mirrulations_$START_TIME.json" -o "/tmp/mirrulations_${START_TIME}_corrected.json" 2>> "$LOG_FILE" &&
+    $PYTHON set_counts.py -y -i "/tmp/mirrulations_${START_TIME}_corrected.json" 2>> "$LOG_FILE"
+
+rm "/tmp/mirrulations_${START_TIME}_corrected.json" "/tmp/mirrulations_$START_TIME.json"
diff --git a/scripts/get_counts.py b/scripts/get_counts.py
new file mode 100755
index 00000000..05f2a93a
--- /dev/null
+++ b/scripts/get_counts.py
@@ -0,0 +1,254 @@
+#!/usr/bin/env python3
+
+import argparse
+import datetime as dt
+import json
+import os
+import pathlib
+import sys
+from counts import Counts, CountsEncoder
+from job_queue import RabbitMQ
+
+import redis
+import requests
+
+REGULATIONS_BASE_URL = "https://api.regulations.gov/v4/"
+
+
+class MissingRedisKeyException(Exception):
+ pass
+
+
+def _download_regulation_count(
+ url: str, headers: dict[str, str], params: dict[str, str]
+) -> int:
+ response = requests.get(
+ url,
+ headers=headers,
+ params=params,
+ )
+ response.raise_for_status()
+ return response.json()["meta"]["totalElements"]
+
+
+def get_regulation(api_key: str, last_timestamp: dt.datetime) -> Counts:
+ """Get counts from regulations.gov given a last_timestamp
+
+ Exactly 6 Regulations.gov API calls are made during this function
+ """
+ output: Counts = {
+ "creation_timestamp": dt.datetime.now(dt.timezone.utc),
+ "queue_size": 0,
+ "dockets": {
+ "downloaded": -1,
+ "jobs": 0,
+ "total": -1,
+ "last_timestamp": last_timestamp,
+ },
+ "documents": {
+ "downloaded": -1,
+ "jobs": 0,
+ "total": -1,
+ "last_timestamp": last_timestamp,
+ },
+ "comments": {
+ "downloaded": -1,
+ "jobs": 0,
+ "total": -1,
+ "last_timestamp": last_timestamp,
+ },
+ }
+
+ headers = {"X-Api-Key": api_key}
+    # NOTE: we set the page size to 5 since we only care about the metadata
+    downloaded_filter = {
+        "filter[lastModifiedDate][le]": last_timestamp.strftime("%Y-%m-%d %H:%M:%S"),
+        "page[size]": "5",
+    }
+
+ for entity_type in ("dockets", "documents", "comments"):
+ downloaded = _download_regulation_count(
+ REGULATIONS_BASE_URL + entity_type, headers, downloaded_filter
+ )
+ total = _download_regulation_count(
+ REGULATIONS_BASE_URL + entity_type, headers, {"page[size]": "5"}
+ )
+ output[entity_type]["downloaded"] = downloaded
+ output[entity_type]["total"] = total
+
+ return output
+
+
+def get_dashboard(dashboard_url: str, last_timestamp: dt.datetime) -> Counts:
+ """Get the counts of a running mirrulations instance via it's dashboard"""
+ response = requests.get(dashboard_url + "/data")
+ response.raise_for_status()
+
+ content = response.json()
+
+ counts: Counts = {
+ "creation_timestamp": dt.datetime.now(dt.timezone.utc),
+ "queue_size": content["num_jobs_waiting"],
+ "dockets": {
+ "downloaded": content["num_dockets_done"],
+ "jobs": content["num_jobs_dockets_queued"],
+ "total": content["regulations_total_dockets"],
+ "last_timestamp": last_timestamp,
+ },
+ "documents": {
+ "downloaded": content["num_documents_done"],
+ "jobs": content["num_jobs_documents_queued"],
+ "total": content["regulations_total_documents"],
+ "last_timestamp": last_timestamp,
+ },
+ "comments": {
+ "downloaded": content["num_comments_done"],
+ "jobs": content["num_jobs_comments_queued"],
+ "total": content["regulations_total_comments"],
+ "last_timestamp": last_timestamp,
+ },
+ }
+
+ return counts
+
+
+def _get_key_or_raise(db: redis.Redis, key: str) -> str:
+ value: str | None = db.get(key)
+ if value is None:
+ raise MissingRedisKeyException(f"missing redis key: {key}")
+
+ return value
+
+
+def get_redis(db: redis.Redis) -> Counts:
+ """Get the counts of a running mirrulations instance via a Redis connection"""
+
+ counts: Counts = {
+ "creation_timestamp": dt.datetime.now(dt.timezone.utc),
+ }
+ queue = RabbitMQ("jobs_waiting_queue")
+ counts["queue_size"] = queue.size()
+
+ for entity_type in ("dockets", "documents", "comments"):
+ # Getting any of these values can raise an exception
+ downloaded = _get_key_or_raise(db, f"num_{entity_type}_done")
+ jobs = _get_key_or_raise(db, f"num_jobs_{entity_type}_waiting")
+ total = _get_key_or_raise(db, f"regulations_total_{entity_type}")
+ last_timestamp = _get_key_or_raise(db, f"{entity_type}_last_timestamp")
+
+ counts[entity_type] = {
+ "downloaded": int(downloaded),
+ "jobs": int(jobs),
+ "total": int(total),
+ "last_timestamp": dt.datetime.strptime(last_timestamp, "%Y-%m-%d %H:%M:%S"),
+ }
+
+ return counts
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ "Get Counts",
+ description="Get Docket, Document, and Comment counts from multiple sources",
+ )
+ parser.add_argument(
+ "-o",
+ "--output",
+ metavar="PATH",
+ type=str,
+ default="-",
+ help="file to output to, use '-' for stdout (default '%(default)s')",
+ )
+ subparsers = parser.add_subparsers(
+ dest="source", required=True, help="The source to get counts from"
+ )
+
+ regulations = subparsers.add_parser(
+ "regulations", help="download counts from regulations.gov"
+ )
+ regulations.add_argument(
+ "-a",
+ "--api-key",
+ help="Regulations.gov api key, defaults to value of `API_KEY` environment variable",
+ default=os.getenv("API_KEY"),
+ type=str,
+ )
+ regulations.add_argument(
+ "-t",
+ "--last-timestamp",
+ metavar="TIMESTAMP",
+ type=dt.datetime.fromisoformat,
+ default=dt.datetime.now(dt.timezone.utc).isoformat(timespec="seconds"),
+ help="last timestamp that is assumed to have been downloaded in ISO 8601 format 'YYYY-MM-DDTHH:mm:ssZ' (default '%(default)s')",
+ )
+
+ dashboard = subparsers.add_parser(
+ "dashboard", help="get counts from a mirrulations dashboard"
+ )
+ dashboard.add_argument(
+ "-u",
+ "--url",
+ metavar="DASHBOARD_URL",
+ default="http://localhost",
+ help="dashboard url (default '%(default)s')",
+ )
+    dashboard.add_argument(
+        "last_timestamp",
+        type=dt.datetime.fromisoformat,
+        help="last timestamp that is assumed to have been downloaded, in ISO 8601 format 'YYYY-MM-DDTHH:mm:ss'",
+    )
+
+ redis_args = subparsers.add_parser("redis", help="get counts from redis")
+ redis_args.add_argument(
+ "--hostname",
+ metavar="HOSTNAME",
+ default="localhost",
+ help="redis server hostname (default '%(default)s')",
+ )
+ redis_args.add_argument(
+ "-p",
+ "--port",
+ metavar="PORT",
+ type=int,
+ default=6379,
+ help="port for redis server (default '%(default)s')",
+ )
+ redis_args.add_argument(
+ "-n",
+ "--db",
+ metavar="DB_NUMBER",
+ type=int,
+ default=0,
+ help="redis database number (default '%(default)s')",
+ )
+
+ args = parser.parse_args()
+
+ source = args.source
+ if source == "regulations":
+ api_key = args.api_key
+ if api_key is None or api_key == "":
+ print("No api key found, exitting", file=sys.stderr)
+ sys.exit(1)
+ output = get_regulation(api_key, args.last_timestamp)
+ elif source == "dashboard":
+ output = get_dashboard(args.url, args.last_timestamp)
+ elif source == "redis":
+ db = redis.Redis(
+ host=args.hostname, port=args.port, db=args.db, decode_responses=True
+ )
+ try:
+ output = get_redis(db)
+ except MissingRedisKeyException as e:
+ print(f"Missing a redis key, exitting\n{e}", file=sys.stderr)
+ sys.exit(1)
+ else:
+ print("Unrecognized source, exitting", file=sys.stderr)
+ sys.exit(1)
+
+ if args.output == "-":
+ json.dump(output, sys.stdout, cls=CountsEncoder)
+ else:
+ with open(pathlib.Path(args.output), "w") as fp:
+ json.dump(output, fp, cls=CountsEncoder)
diff --git a/scripts/job_queue.py b/scripts/job_queue.py
new file mode 100644
index 00000000..5658210f
--- /dev/null
+++ b/scripts/job_queue.py
@@ -0,0 +1,41 @@
+# pylint: disable=too-many-arguments
+import sys
+import pika
+
+
+class RabbitMQ:
+ """
+ Encapsulate calls to RabbitMQ in one place
+ """
+
+ def __init__(self, queue_name):
+ """
+ Create a new RabbitMQ object
+ @param queue_name: the name of the queue to use
+ """
+ self.queue_name = queue_name
+ self.connection = None
+ self.channel = None
+
+ def _ensure_channel(self):
+ if self.connection is None or not self.connection.is_open:
+ connection_parameter = pika.ConnectionParameters("localhost")
+ self.connection = pika.BlockingConnection(connection_parameter)
+ self.channel = self.connection.channel()
+ self.channel.queue_declare(self.queue_name, durable=True)
+
+    def size(self) -> int:
+        """
+        Get the number of jobs in the queue.
+        The channel may be lost between _ensure_channel() and the
+        queue_declare() call, which is why the call is wrapped in a
+        try/except.
+        @return: a non-negative integer
+        """
+ self._ensure_channel()
+ try:
+ queue = self.channel.queue_declare(self.queue_name, durable=True)
+ return queue.method.message_count
+ except pika.exceptions.StreamLostError:
+ print("FAILURE: RabbitMQ Channel Connection Lost", file=sys.stderr)
+ return 0
diff --git a/scripts/requirements.txt b/scripts/requirements.txt
new file mode 100644
index 00000000..5e8dc051
--- /dev/null
+++ b/scripts/requirements.txt
@@ -0,0 +1,3 @@
+requests
+redis
+pika
diff --git a/scripts/set_counts.py b/scripts/set_counts.py
new file mode 100755
index 00000000..8e2803cd
--- /dev/null
+++ b/scripts/set_counts.py
@@ -0,0 +1,164 @@
+#!/usr/bin/env python3
+
+import argparse
+import json
+import pathlib
+import redis
+import sys
+
+from counts import Counts, CountsDecoder
+
+ANSI_RESET = "\033[0m"
+ANSI_BOLD = "\033[1m"
+ANSI_BLINK = "\033[5m"
+ANSI_BLINK_OFF = "\033[25m"
+ANSI_FG_RED = "\033[31m"
+
+
+def _get_vals(db: redis.Redis, entity_type: str) -> dict[str, int | str]:
+ done_raw: str | None = db.get(f"num_{entity_type}_done")
+ if done_raw is not None:
+ done = int(done_raw)
+ else:
+ done = "None"
+
+ total_raw: str | None = db.get(f"regulations_total_{entity_type}")
+ if total_raw is not None:
+ total = int(total_raw)
+ else:
+ total = "None"
+
+ timestamp: str = db.get(f"{entity_type}_last_timestamp") or "None"
+
+ return {"done": done, "timestamp": timestamp, "total": total}
+
+
+def _print_changes(info: str, original: str, new: str) -> None:
+ if original != new:
+ print(
+ info,
+ ANSI_FG_RED + ANSI_BOLD + original,
+ f"{ANSI_BLINK}--->{ANSI_BLINK_OFF}",
+ new + ANSI_RESET,
+ )
+ else:
+ print(info, original, "--->", new)
+
+
+def show_changes(db: redis.Redis, counts: Counts) -> None:
+ for entity_type in ("dockets", "documents", "comments"):
+ vals = _get_vals(db, entity_type)
+ _print_changes(
+ f"num_{entity_type}_done:\n ",
+ str(vals["done"]),
+ str(counts[entity_type]["downloaded"]),
+ )
+ _print_changes(
+ f"regulations_total_{entity_type}:\n ",
+ str(vals["total"]),
+ str(counts[entity_type]["total"]),
+ )
+ _print_changes(
+ f"{entity_type}_last_timestamp:\n ",
+ str(vals["timestamp"]),
+ counts[entity_type]["last_timestamp"].strftime("%Y-%m-%d %H:%M:%S"),
+ )
+ print()
+
+
+def set_values(db: redis.Redis, counts: Counts):
+ for entity_type in ("dockets", "documents", "comments"):
+ try:
+ db.set(f"num_{entity_type}_done", counts[entity_type]["downloaded"])
+ db.set(f"regulations_total_{entity_type}", counts[entity_type]["total"])
+ db.set(
+ f"{entity_type}_last_timestamp",
+ counts[entity_type]["last_timestamp"].strftime("%Y-%m-%d %H:%M:%S"),
+ )
+        except Exception as e:
+            print(
+                f"Error occurred while setting values for {entity_type}, exiting",
+                file=sys.stderr,
+            )
+            print(e, file=sys.stderr)
+ return
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ "Set Counts", description="Set counts in Redis database from json"
+ )
+ parser.add_argument(
+ "-i",
+ "--input",
+ metavar="INPUT_PATH",
+ type=str,
+ default="-",
+ help="file to read from, use '-' for stdin (default '%(default)s')",
+ )
+ parser.add_argument(
+ "-y",
+ "--yes",
+ dest="changes_confirmed",
+ action="store_true",
+ help="Do not check for confirmation when setting values",
+ )
+ parser.add_argument(
+ "--host",
+ metavar="HOSTNAME",
+ default="localhost",
+ help="redis server hostname (default '%(default)s')",
+ )
+ parser.add_argument(
+ "-p",
+ "--port",
+ metavar="PORT",
+ type=int,
+ default=6379,
+ help="port for redis server (default '%(default)s')",
+ )
+ parser.add_argument(
+ "-n",
+ "--db",
+ metavar="DB_NUMBER",
+ type=int,
+ default=0,
+ help="redis database number (default '%(default)s')",
+ )
+
+ args = parser.parse_args()
+
+ try:
+ if args.input == "-":
+ input_counts: Counts = json.load(sys.stdin, cls=CountsDecoder)
+ else:
+ try:
+ with open(pathlib.Path(args.input), "r") as fp:
+ input_counts = json.load(fp, cls=CountsDecoder)
+ except FileNotFoundError:
+ print(
+ f"Input file {args.input} does not exist, exitting", file=sys.stderr
+ )
+ sys.exit(2)
+ except json.JSONDecodeError:
+ print(f"Malformed input file {args.input}, exitting", file=sys.stderr)
+ sys.exit(2)
+
+ db = redis.Redis(args.host, args.port, args.db, decode_responses=True)
+ changes_confirmed: bool = args.changes_confirmed
+
+ if changes_confirmed:
+ set_values(db, input_counts)
+ else:
+ show_changes(db, input_counts)
+ response = (
+ input("Are you sure you want to make the above changes [y/n]: ")
+ .strip()
+ .lower()
+ )
+        changes_confirmed = response in ("y", "yes")
+ if changes_confirmed:
+ set_values(db, input_counts)
+ else:
+ print("No values set, exitting")
+ sys.exit()