Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change _fix_chat_template in case a template has both endif and endfor #1388

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions unsloth/tokenizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,26 +585,43 @@ def load_correct_tokenizer(
pass


def _find_end_position(template, endfor, endif):
where_endfor = template.find(endfor)
where_endif = template.find(endif)
if where_endfor == where_endif == -1:
return None
elif where_endfor > where_endif:
return endfor
else:
return endif
pass
pass


def _fix_chat_template(chat_template):
endfor = "{% endif %}"
where = chat_template.find(endfor)
if where == -1:
endfor = "{%- endif %}"
where = chat_template.find(endfor)
if where == -1:
endfor = "{% endfor %}"
endif = "{% endif %}"
chosen_end = _find_end_position(chat_template, endfor, endif)
if chosen_end is None:
endfor = "{%- endfor %}"
endif = "{%- endif %}"
chosen_end = _find_end_position(chat_template, endfor, endif)
if chosen_end is None:
return chat_template

where = chat_template.find(chosen_end)

after_endfor = chat_template[where + len(endfor):]
after_endfor = chat_template[where + len(chosen_end):]

dash = "-" if endfor.startswith("{%-") else ""
dash = "-" if chosen_end.startswith("{%-") else ""

if "{%" + dash + " if" not in after_endfor and "{%" + dash + " set " not in after_endfor and \
after_endfor.startswith("{{") and after_endfor.endswith("}}") and \
after_endfor.count("{{") == 1 and after_endfor.count("}}") == 1:

after_endfor = "{%" + dash + " if add_generation_prompt %}" + after_endfor + endfor
after_endfor = "{%" + dash + " if add_generation_prompt %}" + after_endfor + endif

chat_template = chat_template[:where + len(endfor)] + after_endfor
chat_template = chat_template[:where + len(chosen_end)] + after_endfor
pass
return chat_template
pass
Expand Down