feat: use msgspec JSON encoder
clintval committed Nov 13, 2024
1 parent 4bf0d29 commit 7f687b3
Showing 4 changed files with 52 additions and 40 deletions.
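Note on the test changes below: msgspec's JSON encoder emits compact JSON with no spaces after separators, while the previously used json.dumps pads separators by default (e.g. "[1, 2, 3]" vs "[1,2,3]"), which is why the expected strings in the tests change shape. A minimal standalone sketch of the difference (not part of this commit):

import json

from msgspec.json import Encoder

payload = {"field1": 10, "field2": "hi-mom", "field3": None}

# Standard library: separators are padded -> '{"field1": 10, "field2": "hi-mom", "field3": null}'
print(json.dumps(payload))

# msgspec: compact bytes -> b'{"field1":10,"field2":"hi-mom","field3":null}'
print(Encoder().encode(payload).decode("utf-8"))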
36 changes: 18 additions & 18 deletions tests/test_reader.py
@@ -121,28 +121,28 @@ def test_reader_will_write_a_complicated_record(tmp_path: Path) -> None:
},
field10=True,
field11=None,
field12=1,
field12=0.2,
)
with TsvRecordWriter.from_path(tmp_path / "test.txt", ComplexMetric) as writer:
assert (tmp_path / "test.txt").read_text() == ""
writer.write(metric)
assert (tmp_path / "test.txt").read_text() == "\t".join([
"1",
"'my\tname'",
"0.2",
"[1, 2, 3]",
"[3, 4, 5]",
"[5, 6, 7]",
'{"field1": 1, "field2": 2}',
'{"field1": 10, "field2": "hi-mom", "field3": null}',
", ".join([
r'{"first": {"field1": 2, "field2": "hi-dad", "field3": 0.2}',
r'"second": {"field1": 3, "field2": "hi-all", "field3": 0.3}}',
]),
"true",
"null",
"1\n",
])

expected: str = (
"1"
+ "\t'my\tname'"
+ "\t0.2"
+ "\t[1,2,3]"
+ "\t[3,4,5]"
+ "\t[5,6,7]"
+ '\t{"field1":1,"field2":2}'
+ '\t{"field1":10,"field2":"hi-mom","field3":null}'
+ '\t{"first":{"field1":2,"field2":"hi-dad","field3":0.2}'
+ ',"second":{"field1":3,"field2":"hi-all","field3":0.3}}'
+ "\ttrue"
+ "\tnull"
+ "\t0.2\n"
)
assert (tmp_path / "test.txt").read_text() == expected

with TsvRecordReader.from_path(tmp_path / "test.txt", ComplexMetric, header=False) as reader:
assert list(reader) == [metric]
33 changes: 16 additions & 17 deletions tests/test_writer.py
@@ -96,23 +96,22 @@ def test_writer_will_write_a_complicated_record(tmp_path: Path) -> None:
with TsvRecordWriter.from_path(tmp_path / "test.txt", ComplexMetric) as writer:
assert (tmp_path / "test.txt").read_text() == ""
writer.write(metric)
assert (tmp_path / "test.txt").read_text() == "\t".join([
"1",
"'my\tname'",
"0.2",
"[1, 2, 3]",
"[3, 4, 5]",
"[5, 6, 7]",
'{"field1": 1, "field2": 2}',
'{"field1": 10, "field2": "hi-mom", "field3": null}',
", ".join([
r'{"first": {"field1": 2, "field2": "hi-dad", "field3": 0.2}',
r'"second": {"field1": 3, "field2": "hi-all", "field3": 0.3}}',
]),
"true",
"null",
"0.2\n",
])
expected: str = (
"1"
+ "\t'my\tname'"
+ "\t0.2"
+ "\t[1,2,3]"
+ "\t[3,4,5]"
+ "\t[5,6,7]"
+ '\t{"field1":1,"field2":2}'
+ '\t{"field1":10,"field2":"hi-mom","field3":null}'
+ '\t{"first":{"field1":2,"field2":"hi-dad","field3":0.2}'
+ ',"second":{"field1":3,"field2":"hi-all","field3":0.3}}'
+ "\ttrue"
+ "\tnull"
+ "\t0.2\n"
)
assert (tmp_path / "test.txt").read_text() == expected


def test_writer_can_write_with_a_custom_callback(tmp_path: Path) -> None:
2 changes: 1 addition & 1 deletion typeline/_reader.py
@@ -211,7 +211,7 @@ def from_path(
"""Construct a delimited data reader from a file path.
Args:
path: the pat to the file to read delimited data from.
path: the path to the file to read delimited data from.
record_type: the type of the object we will be reading.
header: whether we expect the first line to be a header or not.
comment_prefixes: skip lines that have any of these string prefixes.
21 changes: 17 additions & 4 deletions typeline/_writer.py
@@ -1,5 +1,4 @@
import csv
import json
from abc import ABC
from abc import abstractmethod
from contextlib import AbstractContextManager
@@ -16,6 +15,7 @@
from typing import final

from msgspec import to_builtins
from msgspec.json import Encoder as JSONEncoder
from typing_extensions import Self
from typing_extensions import override

@@ -39,12 +39,18 @@ def __init__(self, handle: TextIOWrapper, record_type: type[RecordType]) -> None
if not is_dataclass(record_type):
raise ValueError("record_type is not a dataclass but must be!")

# Initialize and save internal attributes of this class.
self._handle: TextIOWrapper = handle
self._record_type: type[RecordType] = record_type

# Inspect the record type and save the fields, field names, and field types.
self._fields: tuple[Field[Any], ...] = fields_of(record_type)
self._header: list[str] = [field.name for field in fields_of(record_type)]

# Build a JSON encoder for intermediate data conversion (after dataclass, before delimited).
self._encoder: JSONEncoder = JSONEncoder()

# Build the delimited dictionary writer using the record's field names as the header.
self._writer: DictWriter[str] = DictWriter(
handle,
fieldnames=self._header,
@@ -90,8 +96,10 @@ def write(self, record: RecordType) -> None:
)

encoded = {name: self._encode(getattr(record, name)) for name in self._header}
builtin = {
name: (json.dumps(value) if not isinstance(value, str) else value)
builtin: dict[str, str] = {
name: (
self._encoder.encode(value).decode("utf-8") if not isinstance(value, str) else value
)
for name, value in cast(dict[str, Any], to_builtins(encoded, str_keys=True)).items()
}
self._writer.writerow(builtin)
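The per-cell conversion introduced above can be sketched in isolation. The helper name to_cell below is illustrative, not part of typeline's API, and it skips the writer's custom _encode hook; it mirrors only the string-vs-JSON branch:

from typing import Any

from msgspec import to_builtins
from msgspec.json import Encoder

encoder = Encoder()

def to_cell(value: Any) -> str:
    # Strings pass through as-is; everything else is rendered as compact JSON text.
    builtin = to_builtins(value, str_keys=True)
    return builtin if isinstance(builtin, str) else encoder.encode(builtin).decode("utf-8")

assert to_cell("hi-mom") == "hi-mom"
assert to_cell([1, 2, 3]) == "[1,2,3]"
assert to_cell({"field1": 1, "field2": 2}) == '{"field1":1,"field2":2}'
assert to_cell(True) == "true"
assert to_cell(None) == "null"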
@@ -112,7 +120,12 @@ def close(self) -> None:
def from_path(
cls, path: Path | str, record_type: type[RecordType]
) -> "DelimitedRecordWriter[RecordType]":
"""Construct a delimited struct writer from a file path."""
"""Construct a delimited data writer from a file path.
Args:
path: the path to the file to write delimited data to.
record_type: the type of the object we will be writing.
"""
writer = cls(Path(path).open("w"), record_type)
return writer
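End to end, the writer and reader exercised in the tests round-trip a record as sketched below. This is a sketch under assumptions: SimpleMetric is a made-up dataclass and the top-level typeline import path is assumed rather than taken from this commit:

from dataclasses import dataclass
from pathlib import Path

from typeline import TsvRecordReader, TsvRecordWriter

@dataclass
class SimpleMetric:
    field1: int
    field2: str
    field3: float

metric = SimpleMetric(field1=1, field2="hi-mom", field3=0.2)

# Write one record; non-string fields are serialized as compact JSON cells.
with TsvRecordWriter.from_path(Path("test.txt"), SimpleMetric) as writer:
    writer.write(metric)

# Read it back; as in the tests, the file has no header row.
with TsvRecordReader.from_path(Path("test.txt"), SimpleMetric, header=False) as reader:
    assert list(reader) == [metric]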
