Skip to content

Commit 29882a6

Browse files
committed
init project
1 parent 01e267e commit 29882a6

File tree

2 files changed

+58
-19
lines changed

2 files changed

+58
-19
lines changed

ddtrace/llmobs/_llmobs.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ class LLMObs(Service):
183183
enabled = False
184184
_app_key: str = os.getenv("DD_APP_KEY", "")
185185
_project_name: str = os.getenv("DD_LLMOBS_PROJECT_NAME", DEFAULT_PROJECT_NAME)
186+
_project_id: str = ""
186187

187188
def __init__(
188189
self,
@@ -212,6 +213,7 @@ def __init__(
212213
interval=float(os.getenv("_DD_LLMOBS_WRITER_INTERVAL", 1.0)),
213214
timeout=float(os.getenv("_DD_LLMOBS_WRITER_TIMEOUT", 5.0)),
214215
_app_key=self._app_key,
216+
_default_project_id=self._project_id,
215217
is_agentless=True, # agent proxy doesn't seem to work for experiments
216218
)
217219

@@ -613,6 +615,15 @@ def enable(
613615
cls.enabled = True
614616
cls._instance.start()
615617

618+
try:
619+
cls._project_id = cls._instance._dne_client.project_create_or_get(cls._project_name)
620+
except Exception as e:
621+
log.error(
622+
"failed to get project ID with %s, dataset & experiments features may not be functional: %s",
623+
cls._project_name,
624+
e,
625+
)
626+
616627
# Register hooks for span events
617628
core.on("trace.span_start", cls._instance._on_span_start)
618629
core.on("trace.span_finish", cls._instance._on_span_finish)
@@ -654,15 +665,23 @@ def enable(
654665
)
655666

656667
@classmethod
657-
def pull_dataset(cls, name: str) -> Dataset:
658-
ds = cls._instance._dne_client.dataset_get_with_records(name)
668+
def pull_dataset(cls, dataset_name: str, project_name: Optional[str] = None) -> Dataset:
669+
ds = cls._instance._dne_client.dataset_get_with_records(
670+
dataset_name, cls._project_name if project_name is None else project_name
671+
)
659672
return ds
660673

661674
@classmethod
662-
def create_dataset(cls, name: str, description: str = "", records: Optional[List[DatasetRecord]] = None) -> Dataset:
675+
def create_dataset(
676+
cls,
677+
dataset_name: str,
678+
project_name: Optional[str] = None,
679+
description: str = "",
680+
records: Optional[List[DatasetRecord]] = None,
681+
) -> Dataset:
663682
if records is None:
664683
records = []
665-
ds = cls._instance._dne_client.dataset_create(name, description)
684+
ds = cls._instance._dne_client.dataset_create(dataset_name, project_name, description)
666685
for r in records:
667686
ds.append(r)
668687
if len(records) > 0:
@@ -678,13 +697,14 @@ def create_dataset_from_csv(
678697
expected_output_columns: Optional[List[str]] = None,
679698
metadata_columns: Optional[List[str]] = None,
680699
csv_delimiter: str = ",",
681-
description="",
700+
description: str = "",
701+
project_name: Optional[str] = None,
682702
) -> Dataset:
683703
if expected_output_columns is None:
684704
expected_output_columns = []
685705
if metadata_columns is None:
686706
metadata_columns = []
687-
ds = cls._instance._dne_client.dataset_create(dataset_name, description)
707+
ds = cls._instance._dne_client.dataset_create(dataset_name, project_name, description)
688708

689709
# Store the original field size limit to restore it later
690710
original_field_size_limit = csv.field_size_limit()

ddtrace/llmobs/_writer.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def __init__(
140140
_api_key: str = "",
141141
_app_key: str = "",
142142
_override_url: str = "",
143+
_default_project_id: str = "",
143144
) -> None:
144145
super(BaseLLMObsWriter, self).__init__(interval=interval)
145146
self._lock = forksafe.RLock()
@@ -150,6 +151,7 @@ def __init__(
150151
self._site: str = _site or config._dd_site
151152
self._app_key: str = _app_key
152153
self._override_url: str = _override_url or os.environ.get("DD_LLMOBS_OVERRIDE_ORIGIN", "")
154+
self._default_project_id: str = _default_project_id
153155

154156
self._agentless: bool = is_agentless
155157
self._intake: str = self._override_url or (
@@ -362,23 +364,28 @@ def dataset_delete(self, dataset_id: str) -> None:
362364
raise ValueError(f"Failed to delete dataset {id}: {resp.get_json()}")
363365
return None
364366

365-
def dataset_create(self, name: str, description: str) -> Dataset:
366-
path = "/api/unstable/llm-obs/v1/datasets"
367+
def dataset_create(
368+
self, dataset_name: str, project_name: Optional[str], description: str,
369+
) -> Dataset:
370+
project_id = self.project_create_or_get(project_name)
371+
logger.debug("getting records with project ID %s for %s", project_id, project_name)
372+
373+
path = f"/api/unstable/llm-obs/v1/{project_id}/datasets"
367374
body: JSONType = {
368375
"data": {
369376
"type": "datasets",
370-
"attributes": {"name": name, "description": description},
377+
"attributes": {"name": dataset_name, "description": description},
371378
}
372379
}
373380
resp = self.request("POST", path, body)
374381
if resp.status != 200:
375-
raise ValueError(f"Failed to create dataset {name}: {resp.status} {resp.get_json()}")
382+
raise ValueError(f"Failed to create dataset {dataset_name}: {resp.status} {resp.get_json()}")
376383
response_data = resp.get_json()
377384
dataset_id = response_data["data"]["id"]
378385
if dataset_id is None or dataset_id == "":
379386
raise ValueError(f"unexpected dataset state, invalid ID (is None: {dataset_id is None})")
380387
curr_version = response_data["data"]["attributes"]["current_version"]
381-
return Dataset(name, dataset_id, [], description, curr_version, _dne_client=self)
388+
return Dataset(dataset_name, dataset_id, [], description, curr_version, _dne_client=self)
382389

383390
@staticmethod
384391
def _get_record_json(record: Union[UpdatableDatasetRecord, DatasetRecordRaw], is_update: bool) -> JSONType:
@@ -436,16 +443,19 @@ def dataset_batch_update(
436443
new_record_ids: List[str] = [r["id"] for r in data] if data else []
437444
return new_version, new_record_ids
438445

439-
def dataset_get_with_records(self, name: str) -> Dataset:
440-
path = f"/api/unstable/llm-obs/v1/datasets?filter[name]={quote(name)}"
446+
def dataset_get_with_records(self, dataset_name: str, project_name: Optional[str] = None) -> Dataset:
447+
project_id = self.project_create_or_get(project_name)
448+
logger.debug("getting records with project ID %s for %s", project_id, project_name)
449+
450+
path = f"/api/unstable/llm-obs/v1/{project_id}/datasets?filter[name]={quote(dataset_name)}"
441451
resp = self.request("GET", path)
442452
if resp.status != 200:
443-
raise ValueError(f"Failed to pull dataset {name}: {resp.status}")
453+
raise ValueError(f"Failed to pull dataset {dataset_name} from project {project_name}: {resp.status}")
444454

445455
response_data = resp.get_json()
446456
data = response_data["data"]
447457
if not data:
448-
raise ValueError(f"Dataset '{name}' not found")
458+
raise ValueError(f"Dataset '{dataset_name}' not found in project {project_name}")
449459

450460
curr_version = data[0]["attributes"]["current_version"]
451461
dataset_description = data[0]["attributes"].get("description", "")
@@ -460,7 +470,8 @@ def dataset_get_with_records(self, name: str) -> Dataset:
460470
resp = self.request("GET", list_path, timeout=self.LIST_RECORDS_TIMEOUT)
461471
if resp.status != 200:
462472
raise ValueError(
463-
f"Failed to pull {page_num}th page of dataset records {name}: {resp.status} {resp.get_json()}"
473+
f"Failed to pull {page_num}th page of dataset records {dataset_name}: "
474+
f"{resp.status} {resp.get_json()}"
464475
)
465476
records_data = resp.get_json()
466477

@@ -481,7 +492,7 @@ def dataset_get_with_records(self, name: str) -> Dataset:
481492
list_path = f"{list_base_path}?page[cursor]={next_cursor}"
482493
logger.debug("next list records request path %s", list_path)
483494
page_num += 1
484-
return Dataset(name, dataset_id, class_records, dataset_description, curr_version, _dne_client=self)
495+
return Dataset(dataset_name, dataset_id, class_records, dataset_description, curr_version, _dne_client=self)
485496

486497
def dataset_bulk_upload(self, dataset_id: str, records: List[DatasetRecord]):
487498
with tempfile.NamedTemporaryFile(suffix=".csv") as tmp:
@@ -534,7 +545,10 @@ def dataset_bulk_upload(self, dataset_id: str, records: List[DatasetRecord]):
534545
raise ValueError(f"Failed to upload dataset from file: {resp.status} {resp.get_json()}")
535546
logger.debug("successfully uploaded with code %d", resp.status)
536547

537-
def project_create_or_get(self, name: str) -> str:
548+
def project_create_or_get(self, name: Optional[str] = None) -> str:
549+
if name is None or name == "":
550+
return self._default_project_id
551+
538552
path = "/api/unstable/llm-obs/v1/projects"
539553
resp = self.request(
540554
"POST",
@@ -544,7 +558,12 @@ def project_create_or_get(self, name: str) -> str:
544558
if resp.status != 200:
545559
raise ValueError(f"Failed to create project {name}: {resp.status} {resp.get_json()}")
546560
response_data = resp.get_json()
547-
return response_data["data"]["id"]
561+
project_id = response_data["data"]["id"]
562+
563+
if project_id is None or project_id == "":
564+
raise ValueError(f"project ID is required for dataset & experiments features (project name: {name})")
565+
566+
return project_id
548567

549568
def experiment_create(
550569
self,

0 commit comments

Comments
 (0)