Commit 5202520

support mps backend.

1 parent 3bde61c

File tree: 6 files changed, +23 -4 lines changed


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -162,3 +162,6 @@ cython_debug/
 .vscode
 .infinity_cache
 libs/infinity_emb/data/*
+
+# macOS
+.DS_Store

README.md

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ Infinity is a high-throughput, low-latency REST API for serving vector embedding
 ## Why Infinity:
 Infinity provides the following features:
 - **Deploy virtually any SentenceTransformer** - deploy the model you know from [SentenceTransformers](https://github.com/UKPLab/sentence-transformers/)
-- **Fast inference backends**: The inference server is built on top of [torch](https://github.com/pytorch/pytorch), [fastembed(onnx-cpu)](https://github.com/qdrant/fastembed) and [CTranslate2](https://github.com/OpenNMT/CTranslate2), getting most out of your **CUDA** or **CPU** hardware.
+- **Fast inference backends**: The inference server is built on top of [torch](https://github.com/pytorch/pytorch), [fastembed(onnx-cpu)](https://github.com/qdrant/fastembed) and [CTranslate2](https://github.com/OpenNMT/CTranslate2), getting most out of your **CUDA**, **CPU** or **MPS** hardware.
 - **Dynamic batching**: New embedding requests are queued while GPU is busy with the previous ones. New requests are squeezed into your GPU/CPU as soon as ready.
 - **Correct and tested implementation**: Unit and end-to-end tested. Embeddings via infinity are identical to [SentenceTransformers](https://github.com/UKPLab/sentence-transformers/) (up to numerical precision). Lets API users create embeddings till infinity and beyond.
 - **Easy to use**: The API is built on top of [FastAPI](https://fastapi.tiangolo.com/), [Swagger](https://swagger.io/) makes it fully documented. API are aligned to [OpenAI's Embedding specs](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings). See below on how to get started.

libs/infinity_emb/infinity_emb/primitives.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
 class Device(enum.Enum):
     cpu = "cpu"
     cuda = "cuda"
+    mps = "mps"
     auto = None
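For context on how `Device.auto` can pick up the new backend: below is a minimal sketch of runtime device resolution, assuming the standard torch checks `torch.cuda.is_available()` and `torch.backends.mps.is_available()` (the latter available since PyTorch 1.12). The helper name `resolve_device` is illustrative; infinity_emb's actual auto-resolution logic may differ.

```python
import enum

import torch


class Device(enum.Enum):
    cpu = "cpu"
    cuda = "cuda"
    mps = "mps"
    auto = None


def resolve_device(device: Device) -> str:
    # Explicit choices pass through unchanged.
    if device is not Device.auto:
        return device.value
    # Prefer CUDA, then Apple's Metal Performance Shaders, then CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```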
libs/infinity_emb/infinity_emb/transformer/classifier/torch.py

Lines changed: 6 additions & 1 deletion
@@ -31,7 +31,12 @@ def __init__(self, model_name_or_path, device: Optional[str] = None) -> None:
             top_k=None,
             torch_dtype=torch.float32 if used_device == "cpu" else torch.float16,
         )
-        self._pipe.model = to_bettertransformer(self._pipe.model, logger=logger)
+        if used_device == "mps":
+            logger.info(
+                "Disable Optimizations via Huggingface optimum for MPS Backend. "
+            )
+        else:
+            self._pipe.model = to_bettertransformer(self._pipe.model, logger)

         self._infinity_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

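The same guard recurs in the cross-encoder and sentence-transformer wrappers below: on MPS the project's `to_bettertransformer` helper is skipped and the eager model is kept. As a hedged sketch of what that helper presumably wraps, here is the pattern expressed with optimum's `BetterTransformer.transform` API; the standalone function is illustrative, not infinity_emb's actual helper.

```python
import logging

from optimum.bettertransformer import BetterTransformer

logger = logging.getLogger(__name__)


def maybe_to_bettertransformer(model, device_type: str):
    # Mirrors this commit's guard: BetterTransformer optimizations are
    # skipped on the MPS backend, and the model is returned unchanged.
    if device_type == "mps":
        logger.info("Disable Optimizations via Huggingface optimum for MPS Backend.")
        return model
    # On CUDA/CPU, swap in BetterTransformer's fastpath attention kernels.
    return BetterTransformer.transform(model)
```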
libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py

Lines changed: 6 additions & 1 deletion
@@ -47,7 +47,12 @@ def __init__(self, model_name_or_path, **kwargs):
         self._infinity_tokenizer = copy.deepcopy(self.tokenizer)
         self.model.eval()

-        self.model = to_bettertransformer(self.model, logger)
+        if self._target_device.type == "mps":
+            logger.info(
+                "Disable Optimizations via Huggingface optimum for MPS Backend. "
+            )
+        else:
+            self.model = to_bettertransformer(self.model, logger)

         if self._target_device.type == "cuda" and not os.environ.get(
             "INFINITY_DISABLE_HALF", ""

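Note the unchanged context lines at the end of the hunk: half-precision casting stays gated to CUDA even with MPS enabled. A minimal sketch of that gate, reusing the `INFINITY_DISABLE_HALF` environment variable shown in the diff; the helper name is illustrative.

```python
import os

import torch


def maybe_half(model: torch.nn.Module, device_type: str) -> torch.nn.Module:
    # float16 casting is applied only on CUDA, and can be vetoed by
    # setting the INFINITY_DISABLE_HALF environment variable.
    if device_type == "cuda" and not os.environ.get("INFINITY_DISABLE_HALF", ""):
        model = model.half()
    return model
```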
libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py

Lines changed: 6 additions & 1 deletion
@@ -54,7 +54,12 @@ def __init__(self, model_name_or_path, **kwargs):
         self._infinity_tokenizer = copy.deepcopy(fm.tokenizer)
         self.eval()

-        fm.auto_model = to_bettertransformer(fm.auto_model, logger)
+        if self._target_device.type == "mps":
+            logger.info(
+                "Disable Optimizations via Huggingface optimum for MPS Backend. "
+            )
+        else:
+            fm.auto_model = to_bettertransformer(fm.auto_model, logger)

         if self._target_device.type == "cuda" and not os.environ.get(
             "INFINITY_DISABLE_HALF", ""

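With `mps` added to the Device enum, the new backend can be exercised end to end. A hedged usage sketch at the sentence-transformers level (assumes PyTorch >= 1.12 on Apple silicon; the model name is just an example, and the infinity_emb server wires the same device string through the wrappers above):

```python
import torch
from sentence_transformers import SentenceTransformer

# Fall back to CPU when the MPS backend is not available.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Any SentenceTransformer checkpoint works; this one is illustrative.
model = SentenceTransformer("BAAI/bge-small-en-v1.5", device=device)
embeddings = model.encode(["embeddings till infinity and beyond"])
print(embeddings.shape)
```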