diff --git a/docs/installation.md b/docs/installation.md
index 12f113d36..1017b627e 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -16,8 +16,8 @@ Outlines supports OpenAI, transformers, Mamba, llama.cpp and exllama2 but **you need to install them manually**
 ```bash
 pip install openai
 pip install transformers datasets accelerate torch
 pip install llama-cpp-python
-pip install exllamav2 torch
-pip install mamba_ssm torch
+pip install exllamav2 transformers torch
+pip install mamba_ssm transformers torch
 pip install vllm
 ```
diff --git a/docs/reference/models/exllamav2.md b/docs/reference/models/exllamav2.md
index e4f9ae9d0..afe542112 100644
--- a/docs/reference/models/exllamav2.md
+++ b/docs/reference/models/exllamav2.md
@@ -1,3 +1,7 @@
 # ExllamaV2
 
+```bash
+pip install exllamav2 transformers torch
+```
+
 *Coming soon*
diff --git a/docs/reference/models/mamba.md b/docs/reference/models/mamba.md
index 7a720516a..ac6db3682 100644
--- a/docs/reference/models/mamba.md
+++ b/docs/reference/models/mamba.md
@@ -1,3 +1,7 @@
 # Mamba
 
+```bash
+pip install mamba_ssm transformers torch
+```
+
 *Coming soon*
diff --git a/outlines/models/exllamav2.py b/outlines/models/exllamav2.py
index a3e97b6ff..0ec6ef033 100644
--- a/outlines/models/exllamav2.py
+++ b/outlines/models/exllamav2.py
@@ -1,11 +1,10 @@
 import os
 from typing import TYPE_CHECKING, Optional
 
-import torch
-
 if TYPE_CHECKING:
     from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Lora
     from transformers import PreTrainedTokenizer
+    import torch
 
 from .transformers import TransformerTokenizer
 
@@ -28,8 +27,9 @@ def __init__(
         self.past_seq = None
         self.lora = lora
 
-    def forward(self, input_ids: torch.LongTensor, *_):
+    def forward(self, input_ids: "torch.LongTensor", *_):
         """Compute a forward pass through the exl2 model."""
+        import torch
 
         # Caching with past_seq
         reset = True
@@ -74,7 +74,7 @@ def forward(self, input_ids: torch.LongTensor, *_):
             seq_tensor[-1:].view(1, -1), self.cache, loras=[self.lora]
         )
 
-    def __call__(self, input_ids: torch.LongTensor, *_) -> torch.FloatTensor:
+    def __call__(self, input_ids: "torch.LongTensor", *_) -> "torch.FloatTensor":
         logits = self.forward(input_ids)
         next_token_logits = logits[..., -1, :]
 
@@ -169,7 +169,7 @@ def exl2(
         from transformers import AutoTokenizer
     except ImportError:
         raise ImportError(
-            "The `exllamav2` library needs to be installed in order to use `exllamav2` models."
+            "The `exllamav2`, `transformers` and `torch` libraries need to be installed in order to use `exllamav2` models."
         )
 
     # Load tokenizer
diff --git a/outlines/models/mamba.py b/outlines/models/mamba.py
index 1375a3811..d3dabf669 100644
--- a/outlines/models/mamba.py
+++ b/outlines/models/mamba.py
@@ -1,10 +1,9 @@
 from typing import TYPE_CHECKING, Optional
 
-import torch
-
 from .transformers import TransformerTokenizer
 
 if TYPE_CHECKING:
+    import torch
     from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
     from transformers import PreTrainedTokenizer
 
@@ -22,14 +21,14 @@ def __init__(
         self.model = model
         self.tokenizer = TransformerTokenizer(tokenizer)
 
-    def forward(self, input_ids: torch.LongTensor, *_):
+    def forward(self, input_ids: "torch.LongTensor", *_):
         """Compute a forward pass through the mamba model."""
 
         output = self.model(input_ids)
         next_token_logits = output.logits[..., -1, :]
 
         return next_token_logits, None
 
-    def __call__(self, input_ids: torch.LongTensor, *_) -> torch.FloatTensor:
+    def __call__(self, input_ids: "torch.LongTensor", *_) -> "torch.FloatTensor":
         return self.forward(input_ids)
 
@@ -40,11 +39,12 @@ def mamba(
     tokenizer_kwargs: dict = {},
 ):
     try:
+        import torch
         from mamba_ssm import MambaLMHeadModel
         from transformers import AutoTokenizer
     except ImportError:
         raise ImportError(
-            "The `mamba_ssm` library needs to be installed in order to use Mamba people."
+            "The `mamba_ssm`, `torch` and `transformers` libraries need to be installed in order to use Mamba models."
         )
 
     if not torch.cuda.is_available():