Commit fbe20e7: update to Ovis2
1 parent d248e34 · commit fbe20e7

23 files changed: +899, -92 lines

README.md (29 additions, 37 deletions)
@@ -7,26 +7,25 @@ Ovis (Open VISion) is a novel Multimodal Large Language Model (MLLM) architectur
 </div>
 
 ## Release
-- [11/26] 🔥 Announcing [Ovis1.6-Gemma2-27B](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-27B)!
-- [11/04] 🔥 Announcing quantized versions of Ovis1.6: [Ovis1.6-Gemma2-9B-GPTQ-Int4](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B-GPTQ-Int4) and [Ovis1.6-Llama3.2-3B-GPTQ-Int4](https://huggingface.co/AIDC-AI/Ovis1.6-Llama3.2-3B-GPTQ-Int4)!
-- [10/22] 🔥 Announcing Ovis1.6-Llama3.2-3B ([Model](https://huggingface.co/AIDC-AI/Ovis1.6-Llama3.2-3B), [Demo](https://huggingface.co/spaces/AIDC-AI/Ovis1.6-Llama3.2-3B))!
-- [09/19] 🔥 Announcing Ovis1.6-Gemma2-9B ([Model](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B), [Demo](https://huggingface.co/spaces/AIDC-AI/Ovis1.6-Gemma2-9B))! This latest release further enhances high-resolution image processing, is trained on a larger, more diverse, and higher-quality dataset, and refines the training process with DPO training following instruction-tuning.
-- [07/24] 🔥 Introducing Ovis1.5, featuring improved high-resolution image processing and optimized training data for enhanced performance.
-- [06/14] 🔥 Launch of Ovis1.0, the inaugural version of the Ovis model.
+- [25/01/26] 🔥 Launch of [Ovis2-1/2/4/8/16/34B](https://huggingface.co/AIDC-AI/Ovis2-34B), the latest version of Ovis models, featuring breakthrough small-model performance, enhanced reasoning capabilities, advanced video and multi-image processing, expanded multilingual OCR support, and improved high-resolution image handling.
+- [24/11/26] 🔥 Announcing [Ovis1.6-Gemma2-27B](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-27B)!
+- [24/11/04] 🔥 Announcing quantized versions of Ovis1.6: [Ovis1.6-Gemma2-9B-GPTQ-Int4](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B-GPTQ-Int4) and [Ovis1.6-Llama3.2-3B-GPTQ-Int4](https://huggingface.co/AIDC-AI/Ovis1.6-Llama3.2-3B-GPTQ-Int4)!
+- [24/10/22] 🔥 Announcing Ovis1.6-Llama3.2-3B ([Model](https://huggingface.co/AIDC-AI/Ovis1.6-Llama3.2-3B), [Demo](https://huggingface.co/spaces/AIDC-AI/Ovis1.6-Llama3.2-3B))!
+- [24/09/19] 🔥 Announcing Ovis1.6-Gemma2-9B ([Model](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B), [Demo](https://huggingface.co/spaces/AIDC-AI/Ovis1.6-Gemma2-9B))! This release further enhances high-resolution image processing, is trained on a larger, more diverse, and higher-quality dataset, and refines the training process with DPO training following instruction-tuning.
+- [24/07/24] 🔥 Introducing Ovis1.5, featuring improved high-resolution image processing and optimized training data for enhanced performance.
+- [24/06/14] 🔥 Launch of Ovis1.0, the inaugural version of the Ovis model.
 
 ## Contents
 - [Install](#install)
 - [Model](#model)
 - [Performance](#performance)
-- [Finetune](#finetune)
 - [Inference](#inference)
-- [Quantization](#quantization)
 - [Citation](#citation)
 - [Team](#team)
 - [License](#license)
 
 ## Install
-Ovis has been tested with Python 3.10, Torch 2.4.0, Transformers 4.46.2, and DeepSpeed 0.15.4. For a comprehensive list of package dependencies, please consult the `requirements.txt` file. Before finetuning or inference, please install Ovis as follows.
+Ovis has been tested with Python 3.10, Torch 2.4.0, Transformers 4.46.2, and DeepSpeed 0.15.4. For a comprehensive list of package dependencies, please consult the `requirements.txt` file.
 ```bash
 git clone git@github.com:AIDC-AI/Ovis.git
 conda create -n ovis python=3.10 -y
@@ -39,27 +38,30 @@ pip install -e .
 ## Model
 Ovis can be instantiated with popular LLMs. We provide the following Ovis MLLMs:
 
-| Ovis MLLMs | ViT | LLM | Model Weights | Demo |
-|:------------------|:-----------:|:------------------:|:---------------------------------------------------------------:|:----------------------------------------------------------------:|
-| Ovis1.6-Gemma2-27B | Siglip-400M | Gemma2-27B-It | [Huggingface](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-27B) | - |
-| Ovis1.6-Gemma2-9B | Siglip-400M | Gemma2-9B-It | [Huggingface](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B) | [Space](https://huggingface.co/spaces/AIDC-AI/Ovis1.6-Gemma2-9B) |
-| Ovis1.6-Llama3.2-3B | Siglip-400M | Llama-3.2-3B-Instruct | [Huggingface](https://huggingface.co/AIDC-AI/Ovis1.6-Llama3.2-3B) | [Space](https://huggingface.co/spaces/AIDC-AI/Ovis1.6-Llama3.2-3B) |
+| Ovis MLLMs | ViT | LLM | Model Weights | Demo |
+|:-----------|:-----------------------:|:---------------------:|:-------------------------------------------------------:|:--------------------------------------------------------:|
+| Ovis2-1B | aimv2-large-patch14-448 | Qwen2.5-0.5B-Instruct | [Huggingface](https://huggingface.co/AIDC-AI/Ovis2-1B) | [Space](https://huggingface.co/spaces/AIDC-AI/Ovis2-1B) |
+| Ovis2-2B | aimv2-large-patch14-448 | Qwen2.5-1.5B-Instruct | [Huggingface](https://huggingface.co/AIDC-AI/Ovis2-2B) | [Space](https://huggingface.co/spaces/AIDC-AI/Ovis2-2B) |
+| Ovis2-4B | aimv2-huge-patch14-448 | Qwen2.5-3B-Instruct | [Huggingface](https://huggingface.co/AIDC-AI/Ovis2-4B) | [Space](https://huggingface.co/spaces/AIDC-AI/Ovis2-4B) |
+| Ovis2-8B | aimv2-huge-patch14-448 | Qwen2.5-7B-Instruct | [Huggingface](https://huggingface.co/AIDC-AI/Ovis2-8B) | [Space](https://huggingface.co/spaces/AIDC-AI/Ovis2-8B) |
+| Ovis2-16B | aimv2-huge-patch14-448 | Qwen2.5-14B-Instruct | [Huggingface](https://huggingface.co/AIDC-AI/Ovis2-16B) | [Space](https://huggingface.co/spaces/AIDC-AI/Ovis2-16B) |
+| Ovis2-34B | aimv2-1B-patch14-448 | Qwen2.5-32B-Instruct | [Huggingface](https://huggingface.co/AIDC-AI/Ovis2-34B) | - |
 
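For orientation while reading the new Model table above: loading an Ovis2 checkpoint typically goes through the standard `trust_remote_code` path. The snippet below is a hedged sketch based on the public Hugging Face model cards rather than on this commit; the checkpoint name, dtype, and the `multimodal_max_length` keyword are assumptions.

```python
import torch
from transformers import AutoModelForCausalLM

# Hedged sketch: identifiers follow the public Ovis2 model cards, not this diff.
model = AutoModelForCausalLM.from_pretrained(
    "AIDC-AI/Ovis2-8B",              # any checkpoint from the table above
    torch_dtype=torch.bfloat16,
    multimodal_max_length=32768,     # matches the new default in configuration_ovis.py below
    trust_remote_code=True,
).cuda()
text_tokenizer = model.get_text_tokenizer()
visual_tokenizer = model.get_visual_tokenizer()
```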
 ## Performance
-With **29B** parameters, **Ovis1.6-Gemma2-27B** achieves exceptional performance in the [OpenCompass](https://github.com/open-compass/VLMEvalKit) benchmark, ranking among the top-tier open-source MLLMs.
 
-![performance-Ovis1_6-Gemma2-27B](docs/performance/Ovis1_6-Gemma2-27B.png)
+![performance-Ovis2](docs/performance/Ovis2.png)
 
-With just **10B** parameters, **Ovis1.6-Gemma2-9B** leads the [OpenCompass](https://github.com/open-compass/VLMEvalKit) benchmark among open-source MLLMs within **30B** parameters.
-
-![performance-Ovis1_6-Gemma2-9B](docs/performance/Ovis1_6-Gemma2-9B.png)
-
-**Ovis1.6-Llama3.2-3B** leads the [OpenCompass](https://github.com/open-compass/VLMEvalKit) benchmark among open-source MLLMs under **4B** parameters, even surpassing Llama-3.2-11B-Vision-Instruct.
-
-![performance-Ovis1_6-Llama3_2-3B](docs/performance/Ovis1_6-Llama3_2-3B.png)
-
-## Finetune
-Finetuning Ovis1.6-Gemma2-9B is supported in [ms-swift](https://github.com/modelscope/ms-swift).
+|Benchmark|Ovis2-1B|Ovis2-2B|Ovis2-4B|Ovis2-8B|Ovis2-16B|Ovis2-34B|
+|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+|MMBench-V1.1<sub>test</sub>|68.5|77.2|81.4|83.3|85.2|86.2|
+|MMStar|52.0|59.0|61.7|64.4|66.9|69.4|
+|MMMU<sub>val</sub>|36.0|45.3|48.0|59.0|59.6|65.6|
+|MathVista<sub>testmini</sub>|59.5|64.4|69.1|71.4|74.9|77.0|
+|HallBench<sub>avg</sub>|44.5|50.2|54.0|56.0|55.9|58.8|
+|AI2D<sub>test</sub>|76.8|82.6|85.5|86.8|86.1|88.4|
+|OCRBench|88.7|87.5|91.0|89.3|88.2|89.8|
+|MMVet|50.3|58.6|65.5|68.5|68.4|75.5|
+|Average|59.5|65.6|69.5|72.3|73.1|76.3|
 
 ## Inference
 We provide an inference wrapper in `ovis/serve/runner.py`, which can be used as:
@@ -77,16 +79,6 @@ Based on [Gradio](https://github.com/gradio-app/gradio), Ovis can also be access
 python ovis/serve/server.py --model_path MODEL_PATH --port PORT
 ```
 
-## Quantization
-We quantized Ovis1.6 using AutoGPTQ. For detailed information on running and creating your own quantized version, please refer to the respective Huggingface model cards: [Ovis1.6-Gemma2-9B-GPTQ-Int4](https://huggingface.co/AIDC-AI/Ovis1.6-Gemma2-9B-GPTQ-Int4) and [Ovis1.6-Llama3.2-3B-GPTQ-Int4](https://huggingface.co/AIDC-AI/Ovis1.6-Llama3.2-3B-GPTQ-Int4). Quantized Ovis1.6 maintains performance comparable to its non-quantized counterpart while requiring less GPU memory:
-
-- Benchmark performance:
-![performance-Ovis1_6-Gemma2-9B-GPTQ-Int4](docs/performance/Ovis1_6-Gemma2-9B-GPTQ-Int4.png)
-![performance-Ovis1_6-Llama3_2-3B-GPTQ-Int4](docs/performance/Ovis1_6-Llama3_2-3B-GPTQ-Int4.png)
-
-- GPU memory usage (max_partition=9):
-![performance-Ovis1_6-VRAM-Comparison](docs/performance/Ovis1_6-VRAM-Comparison.png)
-
 ## Citation
 If you find Ovis useful, please cite the paper
 ```
@@ -99,7 +91,7 @@ If you find Ovis useful, please cite the paper
 ```
 
 ## Team
-This work is a collaborative effort by the MarcoVL team. We would also like to provide links to the following MLLM papers from our team:
+This work is a collaborative effort by the Alibaba Ovis team. We would also like to provide links to the following MLLM papers from our team:
 - [Parrot: Multilingual Visual Instruction Tuning](https://arxiv.org/abs/2406.02539)
 - [Wings: Learning Multimodal LLMs without Text-only Forgetting](https://arxiv.org/abs/2406.03496)
 
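The runner usage referenced under the "## Inference" heading falls outside the displayed hunks. As orientation only, a wrapper of this kind is usually driven as sketched below; the `RunnerArguments`/`OvisRunner` names and the `run([...])` signature are assumptions inferred from the `ovis/serve/runner.py` path, not something this commit shows.

```python
from PIL import Image

# Hedged sketch: class names and call signature are assumptions, not taken from this commit.
from ovis.serve.runner import RunnerArguments, OvisRunner

runner = OvisRunner(RunnerArguments(model_path='AIDC-AI/Ovis2-8B'))
image = Image.open('example.png')
generation = runner.run([image, 'Describe this image.'])
print(generation)
```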
docs/performance/Ovis2.png (binary image, 584 KB)

ovis/model/__init__.py (7 additions, 0 deletions)

@@ -1,2 +1,9 @@
+from transformers import AutoConfig, AutoModel
+from .visual_tokenizer.configuration_aimv2 import AIMv2Config
+from .visual_tokenizer.modeling_aimv2 import AIMv2Model
 from .visual_tokenizer.clip_visual_tokenizer import ClipVisualTokenizerConfig, ClipVisualTokenizer
 from .visual_tokenizer.siglip_visual_tokenizer import SiglipVisualTokenizerConfig, SiglipVisualTokenizer
+from .visual_tokenizer.aimv2_visual_tokenizer import Aimv2VisualTokenizerConfig, Aimv2VisualTokenizer
+
+AutoConfig.register('aimv2', AIMv2Config)
+AutoModel.register(AIMv2Config, AIMv2Model)
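This registration is what lets the rest of the code resolve the new AIMv2 backbone through the generic `Auto*` classes. The snippet below shows the same `transformers` registration pattern in isolation; the toy `model_type` and classes are invented for illustration and are not part of Ovis.

```python
import torch
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class ToyConfig(PretrainedConfig):
    # Illustrative model_type; any string not already claimed by transformers works.
    model_type = "toy_backbone"

    def __init__(self, hidden_size=32, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


class ToyModel(PreTrainedModel):
    config_class = ToyConfig

    def __init__(self, config):
        super().__init__(config)
        self.proj = torch.nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x):
        return self.proj(x)


# Same two-step pattern as in the __init__.py diff above: map the model_type string
# to the config class, then map the config class to the model class.
AutoConfig.register("toy_backbone", ToyConfig)
AutoModel.register(ToyConfig, ToyModel)

config = AutoConfig.for_model("toy_backbone", hidden_size=16)
model = AutoModel.from_config(config)  # resolves to ToyModel through the registry
```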

ovis/model/configuration_ovis.py (1 addition, 1 deletion)

@@ -10,7 +10,7 @@ def __init__(
         self,
         llm_config: Optional[Union[PretrainedConfig, dict]] = None,
         visual_tokenizer_config: Optional[Union[PretrainedConfig, dict]] = None,
-        multimodal_max_length=8192,
+        multimodal_max_length=32768,
         hidden_size=None,
         conversation_formatter_class=None,
         llm_attn_implementation=None,

ovis/model/conversation_formatter.py (1 addition, 0 deletions)

@@ -15,6 +15,7 @@ def __init__(self, tokenizer):
         self.image_token = IMAGE_TOKEN
         self.image_token_id = IMAGE_TOKEN_ID
         self.ignore_id = IGNORE_ID
+        self.im_end = None
 
     def _tokenize_with_image_symbol(self, text):
         text_chunks = [self.tokenizer(chunk, add_special_tokens=False).input_ids for chunk in

ovis/model/modeling_ovis.py (29 additions, 22 deletions)

@@ -83,8 +83,7 @@ def _merge_modules(modules_list: tuple):
         self.is_parallelizable = all((self.llm.is_parallelizable, self.visual_tokenizer.is_parallelizable))
         self.supports_gradient_checkpointing = all(
             (self.llm.supports_gradient_checkpointing, self.visual_tokenizer.supports_gradient_checkpointing))
-        self._supports_flash_attn_2 = all(
-            (self.llm._supports_flash_attn_2, self.visual_tokenizer._supports_flash_attn_2))
+        self._supports_flash_attn_2 = True
         self._supports_sdpa = all((self.llm._supports_sdpa, self.visual_tokenizer._supports_sdpa))
 
     def get_text_tokenizer(self):
@@ -147,7 +146,7 @@ def forward(
         pixel_values: List[Optional[torch.Tensor]],
         **kwargs
     ):
-        assert self.training, "`forward` can only be used in training. For inference, use `generate`."
+        # assert self.training, "`forward` can only be used in training. For inference, use `generate`."
         _, inputs_embeds, labels, attention_mask = self.merge_multimodal(
             text_input_ids=input_ids,
             text_attention_masks=attention_mask,
@@ -161,7 +160,8 @@ def merge_multimodal(
         text_input_ids: torch.Tensor,
         text_attention_masks: torch.Tensor,
         text_labels: Optional[torch.Tensor],
-        pixel_values: List[Optional[torch.Tensor]]
+        pixel_values: List[Optional[torch.Tensor]],
+        left_padding: bool = False
     ):
         input_device = text_input_ids.device
         visual_vocab_szie = self.get_visual_tokenizer().config.vocab_size
@@ -202,7 +202,8 @@ def merge_multimodal(
             visual_input_ids = [None] * len(num_images)
             visual_labels = [None] * len(num_images)
             # just placeholders
-            text_labels = torch.full(text_input_ids.shape, IGNORE_ID, dtype=torch.long, device=input_device)
+            if text_labels is None:
+                text_labels = torch.full(text_input_ids.shape, IGNORE_ID, dtype=torch.long, device=input_device)
 
         input_embeds = []
         attention_masks = []
@@ -254,29 +255,30 @@ def merge_multimodal(
             attention_masks.append(attention_mask)
             labels.append(label)
 
-        if self.training: # padding to self.config.multimodal_max_length for increased training speed
-            padding_size = max(0, self.config.multimodal_max_length - len(input_embeds[0]))
-            input_embeds[0] = torch.nn.ConstantPad2d((0, 0, 0, padding_size), 0.0)(input_embeds[0])
-            attention_masks[0] = torch.nn.ConstantPad1d((0, padding_size), False)(attention_masks[0])
-            labels[0] = torch.nn.ConstantPad1d((0, padding_size), IGNORE_ID)(labels[0])
-        batch_input_embeds = torch.nn.utils.rnn.pad_sequence(input_embeds, batch_first=True, padding_value=0.0)[:,
-                             :self.config.multimodal_max_length, :]
-        batch_attention_mask = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=False)[
-                               :,
-                               :self.config.multimodal_max_length]
-        batch_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_ID)[:,
-                       :self.config.multimodal_max_length]
+        batch_input_embeds = self.pad_truncate_sequence(input_embeds, batch_first=True, padding_value=0.0, left_padding=left_padding)
+        batch_attention_mask = self.pad_truncate_sequence(attention_masks, batch_first=True, padding_value=False, left_padding=left_padding)
+        batch_labels = self.pad_truncate_sequence(labels, batch_first=True, padding_value=IGNORE_ID, left_padding=left_padding)
 
         return visual_input_ids, batch_input_embeds, batch_labels, batch_attention_mask
 
+    def pad_truncate_sequence(self, sequences: List[torch.Tensor], batch_first: bool = True, padding_value: float = 0.0, left_padding: bool = False) -> torch.Tensor:
+        if not left_padding:
+            pad_sequence = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=batch_first, padding_value=padding_value)
+            return pad_sequence[:, :self.config.multimodal_max_length]
+        else:
+            pad_sequence = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[0]) for i in sequences], batch_first=True, padding_value=padding_value).flip(dims=[1])
+            return pad_sequence[:, -self.config.multimodal_max_length:]
+
     def preprocess_inputs(
         self,
         text_or_conversations: Union[List[Dict], str],
         images: Optional[List[PIL.Image.Image]],
         max_partition=9,
         generation_preface='',
         return_labels=False,
-        propagate_exception=True
+        propagate_exception=True,
+        frame_selector=None,
+        frame_selector_kwargs=None
     ):
         # convert text to conversations
         if isinstance(text_or_conversations, str):
290292
raise ValueError(f'Invalid type of `text_or_conversations`, expected `List[Dict]` or `str`,'
291293
f' but got {type(text_or_conversations)}')
292294

295+
if frame_selector is not None:
296+
frame_selector_kwargs = frame_selector_kwargs or {}
297+
conversations, images = frame_selector(conversations=conversations, frames=images, **frame_selector_kwargs)
298+
293299
# format conversations
294300
prompt, raw_input_ids, raw_labels = self.get_conversation_formatter().format(
295301
conversations, generation_preface=generation_preface)
@@ -408,22 +414,23 @@ def _get_hybrid_cache_for_llm(self, batch_size: int, max_cache_len: int):
         llm._cache.reset()
         return llm._cache
 
-    # TODO: support batch generation
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
         **kwargs
     ) -> Union[GenerateOutput, torch.LongTensor]:
-        assert inputs.shape[0] == 1, 'Currently, only support `batch_size=1`'
         _, inputs_embeds, labels, attention_mask = self.merge_multimodal(
             text_input_ids=inputs,
             text_attention_masks=kwargs.pop('attention_mask'),
             text_labels=None,
-            pixel_values=kwargs.pop('pixel_values')
+            pixel_values=kwargs.pop('pixel_values'),
+            left_padding=True
         )
+        inputs_embeds = inputs_embeds.detach()
+        torch.cuda.empty_cache()
         if getattr(self.generation_config, 'cache_implementation') == 'hybrid': # mainly for Gemma2
             kwargs['past_key_values'] = self._get_hybrid_cache_for_llm(
-                getattr(kwargs, "num_beams", 1), kwargs['max_new_tokens'] + inputs_embeds.shape[-2])
+                getattr(kwargs, "num_beams", inputs_embeds.shape[0]), kwargs['max_new_tokens'] + inputs_embeds.shape[-2])
             self.get_llm()._supports_cache_class = True
             kwargs['cache_implementation'] = None
 
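With the `batch_size=1` assertion gone and inputs now left-padded, `generate` can be driven with batched inputs. The snippet below is a hedged, single-image sketch of the surrounding call pattern based on the public Ovis model cards, not on this commit; the `preprocess_inputs` return values and the generation kwargs are assumptions, and `model`, `text_tokenizer`, and `visual_tokenizer` are the objects from the loading sketch earlier.

```python
import torch
from PIL import Image

image = Image.open('example.png')
query = '<image>\nDescribe this image.'

# Assumed return signature: (prompt, input_ids, pixel_values).
prompt, input_ids, pixel_values = model.preprocess_inputs(query, [image], max_partition=9)
attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
input_ids = input_ids.unsqueeze(0).to(device=model.device)
attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        max_new_tokens=1024,
        do_sample=False,
    )[0]
    print(text_tokenizer.decode(output_ids, skip_special_tokens=True))
```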
ovis/model/visual_tokenizer/aimv2_visual_tokenizer.py (new file, 43 additions, 0 deletions)

@@ -0,0 +1,43 @@
+from transformers import AutoConfig, AutoModel
+from transformers import CLIPImageProcessor
+from .modeling_aimv2 import AIMv2Model
+from .base_visual_tokenizer import BaseVisualTokenizerConfig, BaseVisualTokenizer
+
+MODEL_TYPE = "aimv2_visual_tokenizer"
+
+
+class Aimv2VisualTokenizerConfig(BaseVisualTokenizerConfig):
+    model_type = MODEL_TYPE
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        if self.drop_cls_token:
+            self.drop_cls_token = False
+        if self.depths:
+            assert len(self.depths) == 1
+            self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
+
+
+class Aimv2VisualTokenizer(BaseVisualTokenizer):
+    config_class = Aimv2VisualTokenizerConfig
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["AIMv2ViTPreprocessor", "AIMv2Block"]
+    _image_processor_class = CLIPImageProcessor
+    _image_processor_kwargs = dict(do_center_crop=False)
+    _backbone_class = AIMv2Model
+
+    def get_monitor_tensors(self):
+        return dict(
+            backbone_bottom=self.backbone.trunk.blocks[0].attn.qkv.weight,
+            backbone_top=self.backbone.trunk.blocks[-1].attn.qkv.weight,
+            head=self.head[0].weight
+        )
+
+    def get_image_size(self):
+        height = self.image_processor.crop_size["height"]
+        width = self.image_processor.crop_size["width"]
+        return height, width
+
+
+AutoConfig.register(MODEL_TYPE, Aimv2VisualTokenizerConfig)
+AutoModel.register(Aimv2VisualTokenizerConfig, Aimv2VisualTokenizer)
