Skip to content

Commit

Permalink
Updated the main test
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-au-922 committed Dec 3, 2023
1 parent b4ec3d8 commit 7964eeb
Show file tree
Hide file tree
Showing 42 changed files with 1,811 additions and 338 deletions.
1 change: 1 addition & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ CONSUMER_LOG_DATE_FORMAT="%Y-%m-%d %H:%M:%S"
CONSUMER_LOG_DIR=./logs/producer
CONSUMER_LOG_RETENTION=7
CONSUMER_LOG_ROTATION=midnight
CONSUMER_REPLICAS=16

CSV_PARSER_RECOGNIZED_DATETIME_FORMATS="%Y-%m-%dT%H:%M:%S.%f%z"
CSV_PARSER_DELIMITER=","
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ test_consumer:
export QUEUE_NAME=$(QUEUE_NAME) && \
export CSV_PARSER_RECOGNIZED_DATETIME_FORMATS=$(CSV_PARSER_RECOGNIZED_DATETIME_FORMATS) && \
export CSV_PARSER_DELIMITER=$(CSV_PARSER_DELIMITER) && \
COVERAGE_FILE=.coverage_consumer coverage run -m pytest -vxs consumer/tests
COVERAGE_FILE=.coverage_consumer coverage run -m pytest -vx --last-failed consumer/tests
coverage_report:
coverage combine .coverage_producer .coverage_consumer && \
coverage report -m --omit="*/tests/*"
Expand Down
94 changes: 0 additions & 94 deletions consumer/src/adapters/fetch_filenames/rabbitmq.py

This file was deleted.

165 changes: 165 additions & 0 deletions consumer/src/adapters/fetch_filenames_stream/rabbitmq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from datetime import datetime
import time
from ...usecases import FetchFilenameStreamClient
import pika
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from pika.connection import Connection
from typing import Generator, Iterator, Optional, Sequence, cast, overload
from typing_extensions import override
from collections.abc import Callable
import logging


class RabbitMQFetchFilenameStreamClient(FetchFilenameStreamClient[int]):
def __init__(
self,
host: str,
port: int,
credentials_service: Callable[[], tuple[str, str]],
queue: str = "filenames",
polling_timeout: int = 10,
) -> None:
self._host = host
self._port = port
self._credentials_service = credentials_service
self._queue = queue
self._conn: Optional[Connection] = None
self._channel: Optional[BlockingChannel] = None
self._polling_timeout = polling_timeout
self._last_poll_time: Optional[datetime] = None

@overload
def ack(self, message_receipt: int) -> bool:
...

@overload
def ack(self, message_receipt: Sequence[int]) -> list[bool]:
...

@override
def ack(self, message_receipt: int | Sequence[int]) -> bool | list[bool]:
if isinstance(message_receipt, int):
return self._ack_single(message_receipt)
return self._ack_batch(message_receipt)

def _ack_single(self, message_receipt: int) -> bool:
try:
with self._get_channel() as channel:
channel.basic_ack(delivery_tag=message_receipt, multiple=False)
return True
except Exception as e:
logging.exception(e)
return False

def _ack_batch(self, message_receipts: Sequence[int]) -> list[bool]:
#! RabbitMQ is not thread-safe, so we have to use a single thread to ack
results: list[bool] = []
for receipt in message_receipts:
results.append(self._ack_single(receipt))
return results

@overload
def reject(self, message_receipt: int) -> bool:
...

@overload
def reject(self, message_receipt: Sequence[int]) -> list[bool]:
...

@override
def reject(self, message_receipt: int | Sequence[int]) -> bool | list[bool]:
if isinstance(message_receipt, int):
return self._reject_single(message_receipt)
return self._reject_batch(message_receipt)

def _reject_single(self, message_receipt: int) -> bool:
try:
with self._get_channel() as channel:
channel.basic_nack(delivery_tag=message_receipt, requeue=True)
return True
except Exception as e:
logging.exception(e)
return False

def _reject_batch(self, message_receipts: Sequence[int]) -> list[bool]:
#! RabbitMQ is not thread-safe, so we have to use a single thread to ack
results: list[bool] = []
for receipt in message_receipts:
results.append(self._reject_single(receipt))
return results

def _reset_conn(self) -> None:
self._conn = None
self._channel = None

@contextmanager
def _get_amqp_conn(self) -> Iterator[Connection]:
if self._conn is None or self._conn.is_closed:
username, password = self._credentials_service()
credentials = pika.PlainCredentials(username, password)
conn_parameters = pika.ConnectionParameters(
host=self._host,
port=self._port,
credentials=credentials,
)
self._conn = pika.BlockingConnection(conn_parameters)
yield self._conn

@contextmanager
def _get_channel(self) -> Iterator[BlockingChannel]:
if self._channel is None or self._channel.is_closed:
with self._get_amqp_conn() as connection:
self._channel = connection.channel()
yield self._channel

def _wait(self) -> None:
time.sleep(0.5)

@override
def fetch_stream(self) -> Generator[tuple[str, int], None, None]:
while True:
try:
method: Optional[Basic.Deliver] = None
with self._get_channel() as channel:
channel.queue_declare(queue=self._queue, durable=True)
properties: Optional[BasicProperties]
body: Optional[bytes]

method, properties, body = channel.basic_get(
queue=self._queue, auto_ack=False
)

if method is None and properties is None and body is None:
if self._last_poll_time is None:
self._last_poll_time = datetime.now()
if (
datetime.now() - self._last_poll_time
).total_seconds() > self._polling_timeout:
break
self._wait()
continue

self._last_poll_time = None

yield body.decode(), cast(int, method.delivery_tag)

except Exception as e:
logging.exception(e)
if method is not None:
self.reject(method.delivery_tag)
self._reset_conn()

@override
def close(self) -> bool:
try:
if self._channel is not None:
self._channel.close()
if self._conn is not None:
self._conn.close()
return True
except Exception as e:
logging.exception(e)
return False
31 changes: 21 additions & 10 deletions consumer/src/adapters/file_parse_iot_records/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ...usecases import FileParseIOTRecordsClient
import csv
import logging
from pathlib import Path


class CSVParseIOTRecordsClient(FileParseIOTRecordsClient):
Expand All @@ -22,29 +23,40 @@ def __init__(
self._file_extension = file_extension

@overload
def parse(self, filename: str) -> list[IOTRecord]:
def parse(self, filename: str) -> Optional[list[IOTRecord]]:
...

@overload
def parse(self, filename: Sequence[str]) -> list[list[IOTRecord]]:
def parse(self, filename: Sequence[str]) -> list[Optional[list[IOTRecord]]]:
...

@override
def parse(
self, filename: str | Sequence[str]
) -> list[IOTRecord] | list[list[IOTRecord]]:
) -> Optional[list[IOTRecord]] | list[Optional[list[IOTRecord]]]:
if isinstance(filename, str):
return self._parse_single(filename)
return self._parse_batch(filename)

def _basic_file_check(self, filename: str) -> bool:
if not Path(filename).exists():
raise ValueError("File path must exist!")
if not Path(filename).is_file():
raise ValueError("File path must be a file!")
if not filename.endswith(self._file_extension):
raise ValueError(f"File extension must be {self._file_extension}")

@override
def parse_stream(self, filename: str) -> Iterator[IOTRecord]:
try:
if not filename.endswith(self._file_extension):
raise ValueError(f"File extension must be {self._file_extension}")
self._basic_file_check(filename)
with open(filename) as csvfile:
reader = csv.reader(csvfile, delimiter=self._delimiter, strict=True)
yield from self._parse_iter(reader)
except OSError as e:
logging.exception(e)
logging.error(f"Failed to read stream from {filename}!")
raise e
except Exception as e:
logging.error(f"Failed to parse {filename}")
logging.exception(e)
Expand Down Expand Up @@ -84,17 +96,16 @@ def _parse_iter(self, reader: Iterator[list[str]]) -> Iterator[IOTRecord]:
)
return iot_records

def _parse_single(self, filename: str) -> list[IOTRecord]:
def _parse_single(self, filename: str) -> Optional[list[IOTRecord]]:
try:
if not filename.endswith(self._file_extension):
raise ValueError(f"File extension must be {self._file_extension}")
self._basic_file_check(filename)
with open(filename) as csvfile:
reader = csv.reader(csvfile, delimiter=self._delimiter)
return list(self._parse_iter(reader))
except Exception as e:
logging.error(f"Failed to parse {filename}")
logging.exception(e)
return []
logging.error(f"Failed to parse {filename}")
return None

def _parse_batch(self, filenames: Sequence[str]) -> list[list[IOTRecord]]:
with ThreadPoolExecutor() as executor:
Expand Down
Loading

0 comments on commit 7964eeb

Please sign in to comment.