Skip to content

Commit 10c95ab

Browse files
authored
Merge pull request #109 from alex-dixon/timestamp-timezone
use datetime with timezone on backend (#108)
2 parents 5c2b682 + c08d355 commit 10c95ab

File tree

6 files changed

+131
-19
lines changed

6 files changed

+131
-19
lines changed

ell-studio/src/components/LMPDetailsSidePanel.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function VersionItem({ version, index, totalVersions, currentLmpId }) {
2828
{isLatest && <span className="text-xs bg-green-500 text-white px-2 py-0.5 rounded">Latest</span>}
2929
</div>
3030
<div className="text-xs text-gray-500 mt-1">
31-
{getTimeAgo(new Date(version.created_at + "Z"))}
31+
{getTimeAgo(new Date(version.created_at))}
3232
</div>
3333
</Link>
3434
</div>

src/ell/decorators/track.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from ell.types import SerializedLStr
2+
from ell.types import SerializedLStr, utc_now
33
import ell.util.closure
44
from ell.configurator import config
55
from ell.lstr import lstr
@@ -86,7 +86,7 @@ def wrapper(*fn_args, **fn_kwargs) -> str:
8686
logger.info(f"Attempted to use cache on {func_to_track.__qualname__} but it was not cached, or did not exist in the store. Refreshing cache...")
8787

8888

89-
_start_time = datetime.now()
89+
_start_time = utc_now()
9090

9191
# XXX: thread saftey note, if I prevent yielding right here and get the global context I should be fine re: cache key problem
9292

@@ -96,7 +96,7 @@ def wrapper(*fn_args, **fn_kwargs) -> str:
9696
if not lmp
9797
else fn(*fn_args, _invocation_origin=invocation_id, **fn_kwargs, )
9898
)
99-
latency_ms = (datetime.now() - _start_time).total_seconds() * 1000
99+
latency_ms = (utc_now() - _start_time).total_seconds() * 1000
100100
usage = metadata.get("usage", {})
101101
prompt_tokens=usage.get("prompt_tokens", 0)
102102
completion_tokens=usage.get("completion_tokens", 0)
@@ -145,7 +145,7 @@ def _serialize_lmp(func, name, fn_closure, is_lmp, lm_kwargs):
145145
config._store.write_lmp(
146146
lmp_id=func.__ell_hash__,
147147
name=name,
148-
created_at=datetime.now(),
148+
created_at=utc_now(),
149149
source=fn_closure[0],
150150
dependencies=fn_closure[1],
151151
commit_message=commit,
@@ -162,7 +162,7 @@ def _write_invocation(func, invocation_id, latency_ms, prompt_tokens, completion
162162
config._store.write_invocation(
163163
id=invocation_id,
164164
lmp_id=func.__ell_hash__,
165-
created_at=datetime.now(),
165+
created_at=utc_now(),
166166
global_vars=get_immutable_vars(func.__ell_closure__[2]),
167167
free_vars=get_immutable_vars(func.__ell_closure__[3]),
168168
latency_ms=latency_ms,

src/ell/store.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from abc import ABC, abstractmethod
22
from contextlib import contextmanager
3+
from datetime import datetime
34
from typing import Any, Optional, Dict, List, Set, Union
45
from ell.lstr import lstr
56
from ell.types import InvocableLM
@@ -15,7 +16,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]
1516
version_number: int,
1617
uses: Dict[str, Any],
1718
commit_message: Optional[str] = None,
18-
created_at: Optional[float]=None) -> Optional[Any]:
19+
created_at: Optional[datetime]=None) -> Optional[Any]:
1920
"""
2021
Write an LMP (Language Model Package) to the storage.
2122
@@ -33,7 +34,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]
3334

3435
@abstractmethod
3536
def write_invocation(self, id: str, lmp_id: str, args: str, kwargs: str, result: Union[lstr, List[lstr]], invocation_kwargs: Dict[str, Any],
36-
created_at: Optional[float], consumes: Set[str], prompt_tokens: Optional[int] = None,
37+
created_at: Optional[datetime], consumes: Set[str], prompt_tokens: Optional[int] = None,
3738
completion_tokens: Optional[int] = None, latency_ms: Optional[float] = None,
3839
state_cache_key: Optional[str] = None,
3940
cost_estimate: Optional[float] = None) -> Optional[Any]:

src/ell/stores/sql.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import datetime
1+
from datetime import datetime
22
import json
33
import os
44
from typing import Any, Optional, Dict, List, Set, Union
@@ -7,7 +7,7 @@
77
import cattrs
88
import numpy as np
99
from sqlalchemy.sql import text
10-
from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLMPUses, SerializedLStr
10+
from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLMPUses, SerializedLStr, utc_now
1111
from ell.lstr import lstr
1212
from sqlalchemy import or_, func, and_
1313

@@ -26,7 +26,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]
2626
global_vars: Dict[str, Any],
2727
free_vars: Dict[str, Any],
2828
commit_message: Optional[str] = None,
29-
created_at: Optional[float]=None) -> Optional[Any]:
29+
created_at: Optional[datetime]=None) -> Optional[Any]:
3030
with Session(self.engine) as session:
3131
lmp = session.query(SerializedLMP).filter(SerializedLMP.lmp_id == lmp_id).first()
3232

@@ -42,7 +42,7 @@ def write_lmp(self, lmp_id: str, name: str, source: str, dependencies: List[str]
4242
dependencies=dependencies,
4343
initial_global_vars=global_vars,
4444
initial_free_vars=free_vars,
45-
created_at= created_at or datetime.datetime.utcnow(),
45+
created_at= created_at or utc_now(),
4646
is_lm=is_lmp,
4747
lm_kwargs=lm_kwargs,
4848
commit_message=commit_message

src/ell/types.py

+30-7
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
# Let's define the core types.
22
from dataclasses import dataclass
3-
from typing import Callable, Dict, List, Union
3+
from typing import Callable, Dict, List, Union, Any, Optional
44

5-
from typing import Any
65
from ell.lstr import lstr
76
from ell.util.dict_sync_meta import DictSyncMeta
87

9-
from datetime import datetime
8+
from datetime import datetime, timezone
109
from typing import Any, List, Optional
11-
from sqlmodel import Field, SQLModel, Relationship, JSON, ARRAY, Column, Float
10+
from sqlmodel import Field, SQLModel, Relationship, JSON, Column
11+
from sqlalchemy import func
12+
import sqlalchemy.types as types
1213

1314
_lstr_generic = Union[lstr, str]
1415

@@ -43,6 +44,14 @@ class Message(dict, metaclass=DictSyncMeta):
4344
InvocableLM = Callable[..., _lstr_generic]
4445

4546

47+
def utc_now() -> datetime:
48+
"""
49+
Returns the current UTC timestamp.
50+
Serializes to ISO-8601.
51+
"""
52+
return datetime.now(tz=timezone.utc)
53+
54+
4655
class SerializedLMPUses(SQLModel, table=True):
4756
"""
4857
Represents the many-to-many relationship between SerializedLMPs.
@@ -54,6 +63,16 @@ class SerializedLMPUses(SQLModel, table=True):
5463
lmp_using_id: Optional[str] = Field(default=None, foreign_key="serializedlmp.lmp_id", primary_key=True, index=True) # ID of the LMP that is using the other LMP
5564

5665

66+
class UTCTimestamp(types.TypeDecorator[datetime]):
67+
impl = types.TIMESTAMP
68+
def process_result_value(self, value: datetime, dialect:Any):
69+
return value.replace(tzinfo=timezone.utc)
70+
71+
def UTCTimestampField(index:bool=False, **kwargs:Any):
72+
return Field(
73+
sa_column= Column(UTCTimestamp(timezone=True),index=index, **kwargs))
74+
75+
5776

5877
class SerializedLMP(SQLModel, table=True):
5978
"""
@@ -65,7 +84,11 @@ class SerializedLMP(SQLModel, table=True):
6584
name: str = Field(index=True) # Name of the LMP
6685
source: str # Source code or reference for the LMP
6786
dependencies: str # List of dependencies for the LMP, stored as a string
68-
created_at: datetime = Field(default_factory=datetime.utcnow, index=True) # Timestamp of when the LMP was created
87+
# Timestamp of when the LMP was created
88+
created_at: datetime = UTCTimestampField(
89+
index=True,
90+
nullable=False
91+
)
6992
is_lm: bool # Boolean indicating if it is an LM (Language Model) or an LMP
7093
lm_kwargs: dict = Field(sa_column=Column(JSON)) # Additional keyword arguments for the LMP
7194

@@ -131,8 +154,8 @@ class Invocation(SQLModel, table=True):
131154
completion_tokens: Optional[int] = Field(default=None)
132155
state_cache_key: Optional[str] = Field(default=None)
133156

134-
135-
created_at: datetime = Field(default_factory=datetime.utcnow) # Timestamp of when the invocation was created
157+
# Timestamp of when the invocation was created
158+
created_at: datetime = UTCTimestampField(default=func.now(), nullable=False)
136159
invocation_kwargs: dict = Field(default_factory=dict, sa_column=Column(JSON)) # Additional keyword arguments for the invocation
137160

138161
# Relationships

tests/test_sql_store.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import pytest
2+
from datetime import datetime, timezone
3+
from sqlmodel import Session, select
4+
from ell.stores.sql import SQLStore, SerializedLMP
5+
from sqlalchemy import Engine, create_engine
6+
7+
from ell.types import utc_now
8+
9+
@pytest.fixture
10+
def in_memory_db():
11+
return create_engine("sqlite:///:memory:")
12+
13+
@pytest.fixture
14+
def sql_store(in_memory_db: Engine) -> SQLStore:
15+
store = SQLStore("sqlite:///:memory:")
16+
store.engine = in_memory_db
17+
SerializedLMP.metadata.create_all(in_memory_db)
18+
return store
19+
20+
def test_write_lmp(sql_store: SQLStore):
21+
# Arrange
22+
lmp_id = "test_lmp_1"
23+
name = "Test LMP"
24+
source = "def test_function(): pass"
25+
dependencies = str(["dep1", "dep2"])
26+
is_lmp = True
27+
lm_kwargs = '{"param1": "value1"}'
28+
version_number = 1
29+
uses = {"used_lmp_1": {}, "used_lmp_2": {}}
30+
global_vars = {"global_var1": "value1"}
31+
free_vars = {"free_var1": "value2"}
32+
commit_message = "Initial commit"
33+
created_at = utc_now()
34+
assert created_at.tzinfo is not None
35+
36+
# Act
37+
sql_store.write_lmp(
38+
lmp_id=lmp_id,
39+
name=name,
40+
source=source,
41+
dependencies=dependencies,
42+
is_lmp=is_lmp,
43+
lm_kwargs=lm_kwargs,
44+
version_number=version_number,
45+
uses=uses,
46+
global_vars=global_vars,
47+
free_vars=free_vars,
48+
commit_message=commit_message,
49+
created_at=created_at
50+
)
51+
52+
# Assert
53+
with Session(sql_store.engine) as session:
54+
result = session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)).first()
55+
56+
assert result is not None
57+
assert result.lmp_id == lmp_id
58+
assert result.name == name
59+
assert result.source == source
60+
assert result.dependencies == str(dependencies)
61+
assert result.is_lm == is_lmp
62+
assert result.lm_kwargs == lm_kwargs
63+
assert result.version_number == version_number
64+
assert result.initial_global_vars == global_vars
65+
assert result.initial_free_vars == free_vars
66+
assert result.commit_message == commit_message
67+
# we want to assert created_at has timezone information
68+
assert result.created_at.tzinfo is not None
69+
70+
# Test that writing the same LMP again doesn't create a duplicate
71+
sql_store.write_lmp(
72+
lmp_id=lmp_id,
73+
name=name,
74+
source=source,
75+
dependencies=dependencies,
76+
is_lmp=is_lmp,
77+
lm_kwargs=lm_kwargs,
78+
version_number=version_number,
79+
uses=uses,
80+
global_vars=global_vars,
81+
free_vars=free_vars,
82+
commit_message=commit_message,
83+
created_at=created_at
84+
)
85+
86+
with Session(sql_store.engine) as session:
87+
count = session.query(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id).count()
88+
assert count == 1

0 commit comments

Comments
 (0)