Skip to content

Commit

Permalink
API: Add template list endpoint
Browse files Browse the repository at this point in the history
Fetches all template names that a user has in the templates directory
for chat completions.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
kingbri1 committed Dec 30, 2023
1 parent dce8c74 commit 79a5758
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 3 deletions.
9 changes: 9 additions & 0 deletions OAI/types/template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from pydantic import BaseModel, Field
from typing import List


class TemplateList(BaseModel):
"""Represents a list of templates."""

object: str = "list"
data: List[str] = Field(default_factory=list)
11 changes: 10 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ModelLoadResponse,
ModelCardParameters,
)
from OAI.types.template import TemplateList
from OAI.types.token import (
TokenEncodeRequest,
TokenEncodeResponse,
Expand All @@ -45,7 +46,7 @@
create_chat_completion_response,
create_chat_completion_stream_chunk,
)
from templating import get_prompt_from_template
from templating import get_all_templates, get_prompt_from_template
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
from logger import init_logger

Expand Down Expand Up @@ -244,6 +245,14 @@ async def unload_model():
MODEL_CONTAINER = None


@app.get("/v1/templates", dependencies=[Depends(check_api_key)])
@app.get("/v1/template/list", dependencies=[Depends(check_api_key)])
async def get_templates():
templates = get_all_templates()
template_strings = list(map(lambda template: template.stem, templates))
return TemplateList(data=template_strings)


# Lora list endpoint
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
Expand Down
11 changes: 9 additions & 2 deletions templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,18 @@ def raise_exception(message):
return jinja_template


def get_all_templates():
"""Fetches all templates from the templates directory"""

template_directory = pathlib.Path("templates")
return template_directory.glob("*.jinja")


def find_template_from_model(model_path: pathlib.Path):
"""Find a matching template name from a model path."""
model_name = model_path.name
template_directory = pathlib.Path("templates")
for filepath in template_directory.glob("*.jinja"):
template_files = get_all_templates()
for filepath in template_files:
template_name = filepath.stem.lower()

# Check if the template name is present in the model name
Expand Down

0 comments on commit 79a5758

Please sign in to comment.