Skip to content

Commit f262257

Browse files
committed
update code
1 parent dff0ff8 commit f262257

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

examples/post_training_llava/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ source .venv/bin/activate
2020

2121
## Example
2222

23-
Please update the fields `annotation_path` and `media_path` in `configs/sft.toml` to your custom dataset. `media_path` can be left as empty (`""`) if the paths in your annotation are absolute paths.
23+
Please update the fields `annotation_path` and `media_path` in `configs/sft.toml` to your custom dataset. `media_path` can be left as empty (`""`) if the paths in your annotation are absolute paths.
2424

2525
Here is one example of downloading the [Llava-Instruct-150K] (https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) dataset and [COCO] (https://cocodataset.org/#home) images:
2626

examples/post_training_llava/scripts/custom_sft.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import argparse
1919
import json
2020
import os
21+
import re
2122
from pathlib import Path
2223

2324
import cosmos_rl.launcher.worker_entry
@@ -30,6 +31,7 @@
3031
from cosmos_reason1_utils.text import create_conversation
3132
from cosmos_reason1_utils.vision import VisionConfig
3233

34+
3335
class CustomDatasetConfig(pydantic.BaseModel):
3436
annotation_path: str = pydantic.Field()
3537
"""Dataset annotation path."""
@@ -85,7 +87,6 @@ def __getitem__(self, idx: int) -> list[dict]:
8587
videos = [os.path.join(self.media_path, vid) for vid in videos]
8688

8789
# Remove image and video tags from user prompt
88-
import re
8990
user_prompt = re.sub(r"(\n)?</?(image|video)>(\n)?", "", user_prompt)
9091

9192
conversations = create_conversation(
@@ -97,6 +98,7 @@ def __getitem__(self, idx: int) -> list[dict]:
9798
)
9899
return conversations
99100

101+
100102
if __name__ == "__main__":
101103
parser = argparse.ArgumentParser(description=__doc__)
102104
parser.add_argument(

0 commit comments

Comments
 (0)