From e0c3f04dba90b434c534fe42bd1312c8fd90a455 Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Wed, 29 Jan 2025 09:59:37 +0000 Subject: [PATCH] Parse https://ollama.com/library/ syntax People search for ollama models using the web ui, this change allows one to copy the url from the browser and for it to be compatible with ramalama run. Signed-off-by: Eric Curtin --- ramalama/cli.py | 12 ++++++------ ramalama/common.py | 9 +++++++++ ramalama/huggingface.py | 7 +++---- ramalama/oci.py | 16 ++++------------ ramalama/ollama.py | 6 ++++-- ramalama/url.py | 12 ++++-------- 6 files changed, 30 insertions(+), 32 deletions(-) diff --git a/ramalama/cli.py b/ramalama/cli.py index 0b3577b4..e853a35a 100644 --- a/ramalama/cli.py +++ b/ramalama/cli.py @@ -857,14 +857,14 @@ def rm_cli(args): def New(model, args): - if model.startswith("huggingface://") or model.startswith("hf://") or model.startswith("hf.co/"): + if model.startswith("hf://") or model.startswith("huggingface://") or model.startswith("hf.co/"): return Huggingface(model) - if model.startswith("ollama"): - return Ollama(model) - if model.startswith("oci://") or model.startswith("docker://"): - return OCI(model, args.engine) - if model.startswith("http://") or model.startswith("https://") or model.startswith("file://"): + elif ( + model.startswith("https://") or model.startswith("http://") or model.startswith("file://") + ) and not model.startswith("https://ollama.com/library/"): return URL(model) + elif model.startswith("oci://") or model.startswith("docker://"): + return OCI(model, args.engine) transport = config.get("transport", "ollama") if transport == "huggingface": diff --git a/ramalama/common.py b/ramalama/common.py index 7f4f2e4e..88a002f1 100644 --- a/ramalama/common.py +++ b/ramalama/common.py @@ -243,3 +243,12 @@ def get_env_vars(): # env_vars[gpu_type] = str(gpu_num) return env_vars + + +def rm_until_substring(model, substring): + pos = model.find(substring) + if pos == -1: + return model + + # Create a new string starting after the found substring + return ''.join(model[i] for i in range(pos + len(substring), len(model))) diff --git a/ramalama/huggingface.py b/ramalama/huggingface.py index 184a107d..1212ccb4 100644 --- a/ramalama/huggingface.py +++ b/ramalama/huggingface.py @@ -1,7 +1,7 @@ import os import pathlib import urllib.request -from ramalama.common import available, run_cmd, exec_cmd, download_file, verify_checksum, perror +from ramalama.common import available, run_cmd, exec_cmd, download_file, verify_checksum, perror, rm_until_substring from ramalama.model import Model missing_huggingface = """ @@ -33,9 +33,8 @@ def fetch_checksum_from_api(url): class Huggingface(Model): def __init__(self, model): - model = model.removeprefix("huggingface://") - model = model.removeprefix("hf://") - model = model.removeprefix("hf.co/") + model = rm_until_substring(model, "hf.co/") + model = rm_until_substring(model, "://") super().__init__(model) self.type = "huggingface" split = self.model.rsplit("/", 1) diff --git a/ramalama/oci.py b/ramalama/oci.py index 6f5e50ff..f83cfeaf 100644 --- a/ramalama/oci.py +++ b/ramalama/oci.py @@ -4,14 +4,8 @@ import tempfile import ramalama.annotations as annotations -from ramalama.model import Model, MODEL_TYPES -from ramalama.common import ( - engine_version, - exec_cmd, - MNT_FILE, - perror, - run_cmd, -) +from ramalama.model import Model +from ramalama.common import engine_version, exec_cmd, MNT_FILE, perror, run_cmd, rm_until_substring prefix = "oci://" @@ -103,10 +97,8 @@ def list_models(args): class OCI(Model): def __init__(self, model, conman): - super().__init__(model.removeprefix(prefix).removeprefix("docker://")) - for t in MODEL_TYPES: - if self.model.startswith(t + "://"): - raise ValueError(f"{model} invalid: Only OCI Model types supported") + model = rm_until_substring(model, "://") + super().__init__(model) self.type = "OCI" self.conman = conman diff --git a/ramalama/ollama.py b/ramalama/ollama.py index 32aa4565..8713d1cd 100644 --- a/ramalama/ollama.py +++ b/ramalama/ollama.py @@ -1,7 +1,7 @@ import os import urllib.request import json -from ramalama.common import run_cmd, verify_checksum, download_file +from ramalama.common import run_cmd, verify_checksum, download_file, rm_until_substring from ramalama.model import Model @@ -60,7 +60,9 @@ def init_pull(repos, accept, registry_head, model_name, model_tag, models, model class Ollama(Model): def __init__(self, model): - super().__init__(model.removeprefix("ollama://")) + model = rm_until_substring(model, "ollama.com/library/") + model = rm_until_substring(model, "://") + super().__init__(model) self.type = "Ollama" def _local(self, args): diff --git a/ramalama/url.py b/ramalama/url.py index 8e11dfe4..593bfff9 100644 --- a/ramalama/url.py +++ b/ramalama/url.py @@ -1,17 +1,13 @@ import os -from ramalama.common import download_file +from ramalama.common import download_file, rm_until_substring from ramalama.model import Model +from urllib.parse import urlparse class URL(Model): def __init__(self, model): - self.type = "" - for prefix in ["file", "http", "https"]: - if model.startswith(f"{prefix}://"): - self.type = prefix - model = model.removeprefix(f"{prefix}://") - break - + self.type = urlparse(model).scheme + model = rm_until_substring(model, "://") super().__init__(model) split = self.model.rsplit("/", 1) self.directory = split[0].removeprefix("/") if len(split) > 1 else ""