Skip to content

Commit 6c5a34e

Browse files
committed
Add optional system prompt in post_training_llava
1 parent e29f3c7 commit 6c5a34e

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

examples/post_training_llava/configs/sft.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ redis = "12800"
1818
[custom.dataset]
1919
annotation_path = "data/sft/annotations.json"
2020
media_path = "data/sft/train2017"
21+
system_prompt = ""
2122

2223
[train]
2324
output_dir = "outputs/sft"

examples/post_training_llava/scripts/custom_sft.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,8 @@ class CustomDatasetConfig(pydantic.BaseModel):
3737
"""Dataset annotation path."""
3838
media_path: str = pydantic.Field(default="")
3939
"""Dataset media path."""
40+
system_prompt: str = pydantic.Field(default="")
41+
"""System prompt for post-training."""
4042

4143

4244
class CustomConfig(pydantic.BaseModel):
@@ -60,6 +62,7 @@ def __init__(
6062
):
6163
self.annotation = json.load(open(custom_config.dataset.annotation_path))
6264
self.media_path = custom_config.dataset.media_path
65+
self.system_prompt = custom_config.dataset.system_prompt
6366
self.config = config
6467
self.custom_config = custom_config
6568
self.vision_kwargs = custom_config.vision.model_dump(exclude_none=True)
@@ -90,6 +93,7 @@ def __getitem__(self, idx: int) -> list[dict]:
9093
user_prompt = re.sub(r"(\n)?</?(image|video)>(\n)?", "", user_prompt)
9194

9295
conversations = create_conversation(
96+
system_prompt=self.system_prompt,
9397
user_prompt=user_prompt,
9498
response=response,
9599
images=images,

0 commit comments

Comments (0)