Commit 18acd45: support phi3 (#1064)
1 parent 022c8b0

9 files changed (+1313 -0)

SUPPORT_MODEL.md (+1)
@@ -86,6 +86,7 @@ The table below represents the current support in the library for each of those
 | OpenELM |||
 | OPT |||
 | Phi2 |||
+| Phi3 |||
 | Pagasus |||
 | Pop2piano |||
 | Qwen2 |||

llm/inference/phi3/run_phi3.py (+32)
@@ -0,0 +1,32 @@
+import mindspore
+from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+
+mindspore.set_seed(0)
+
+model = AutoModelForCausalLM.from_pretrained(
+    "microsoft/Phi-3-mini-128k-instruct",
+    ms_dtype="auto",
+)
+tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
+
+messages = [
+    {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
+    {"role": "assistant", "content": "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey."},
+    {"role": "user", "content": "What about solving an 2x + 3 = 7 equation?"},
+]
+
+pipe = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+)
+
+generation_args = {
+    "max_new_tokens": 500,
+    "return_full_text": False,
+    "temperature": 0.0,
+    "do_sample": False,
+}
+
+output = pipe(messages, **generation_args)
+print(output[0]['generated_text'])
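
The script above drives generation through the high-level pipeline helper. For reference, here is a rough equivalent without the pipeline. This is an editorial sketch, not part of the commit; it assumes mindnlp mirrors the transformers chat-template API, including `tokenizer.apply_chat_template` and `return_tensors="ms"` for MindSpore tensors.

```python
import mindspore
from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct", ms_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")

messages = [{"role": "user", "content": "What about solving an 2x + 3 = 7 equation?"}]

# Render the chat history into Phi-3's prompt format and tokenize it
# (assumes the tokenizer supports apply_chat_template with "ms" tensors).
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="ms"
)
# Greedy decoding, matching do_sample=False in the script above.
output_ids = model.generate(input_ids, max_new_tokens=500, do_sample=False)
# Drop the prompt tokens before decoding, mirroring return_full_text=False.
print(tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True))
```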

mindnlp/transformers/models/__init__.py (+3)
@@ -98,6 +98,7 @@
     opt,
     pegasus,
     phi,
+    phi3,
     pop2piano,
     qwen2,
     qwen2_moe,
@@ -204,6 +205,7 @@
 from .opt import *
 from .pegasus import *
 from .phi import *
+from .phi3 import *
 from .pop2piano import *
 from .qwen2 import *
 from .qwen2_moe import *
@@ -310,6 +312,7 @@
 __all__.extend(opt.__all__)
 __all__.extend(pegasus.__all__)
 __all__.extend(phi.__all__)
+__all__.extend(phi3.__all__)
 __all__.extend(pop2piano.__all__)
 __all__.extend(qwen2.__all__)
 __all__.extend(qwen2_moe.__all__)

mindnlp/transformers/models/auto/configuration_auto.py (+2)
@@ -88,6 +88,7 @@
         ("opt", "OPTConfig"),
         ("pegasus", "PegasusConfig"),
         ("phi", "PhiConfig"),
+        ("phi3", "Phi3Config"),
         ("qwen2", "Qwen2Config"),
         ("qwen2_moe", "Qwen2MoeConfig"),
         ("reformer", "ReformerConfig"),
@@ -497,6 +498,7 @@
         ("perceiver", "Perceiver"),
         ("persimmon", "Persimmon"),
         ("phi", "Phi"),
+        ("phi3", "Phi3"),
         ("phobert", "PhoBERT"),
         ("pix2struct", "Pix2Struct"),
         ("plbart", "PLBart"),

mindnlp/transformers/models/auto/modeling_auto.py (+4)
@@ -86,6 +86,7 @@
         ("opt", "OPTModel"),
         ("pegasus", "PegasusModel"),
         ("phi", "PhiModel"),
+        ("phi3", "Phi3Model"),
         ("qwen2", "Qwen2Model"),
         ("qwen2_moe", "Qwen2MoeModel"),
         ("reformer", "ReformerModel"),
@@ -179,6 +180,7 @@
         ("opt", "OPTForCausalLM"),
         ("pegasus", "PegasusForCausalLM"),
         ("phi", "PhiForCausalLM"),
+        ("phi3", "Phi3ForCausalLM"),
         ("qwen2", "Qwen2ForCausalLM"),
         ("qwen2_moe", "Qwen2MoeForCausalLM"),
         ("reformer", "ReformerModelWithLMHead"),
@@ -392,6 +394,7 @@
         ("mixtral", "MixtralForSequenceClassification"),
         ("opt", "OPTForSequenceClassification"),
         ("phi", "PhiForSequenceClassification"),
+        ("phi3", "Phi3ForSequenceClassification"),
         ("qwen2", "Qwen2ForSequenceClassification"),
         ("qwen2_moe", "Qwen2MoeForSequenceClassification"),
         ("reformer", "ReformerForSequenceClassification"),
@@ -538,6 +541,7 @@
         ("nezha", "NezhaForTokenClassification"),
         ("nystromformer", "NystromformerForTokenClassification"),
         ("phi", "PhiForTokenClassification"),
+        ("phi3", "Phi3ForTokenClassification"),
         ("qdqbert", "QDQBertForTokenClassification"),
         ("rembert", "RemBertForTokenClassification"),
         ("roberta", "RobertaForTokenClassification"),

mindnlp/transformers/models/auto/tokenization_auto.py (+1)
@@ -347,6 +347,7 @@
                 "LlamaTokenizerFast" if is_tokenizers_available() else None,
             ),
         ),
+        ("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
         ("phi", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
         ("phobert", ("PhobertTokenizer", None)),
         # ("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),

mindnlp/transformers/models/phi3/__init__.py (+25)
@@ -0,0 +1,25 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Phi3 Model init
+"""
+from . import configuration_phi3, modeling_phi3
+
+from .configuration_phi3 import *
+from .modeling_phi3 import *
+
+__all__ = []
+__all__.extend(configuration_phi3.__all__)
+__all__.extend(modeling_phi3.__all__)

mindnlp/transformers/models/phi3/configuration_phi3.py (+214)
@@ -0,0 +1,214 @@
+# coding=utf-8
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Phi-3 model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    "microsoft/Phi-3-mini-4k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json",
+    "microsoft/Phi-3-mini-128k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json",
+}
+
+
+class Phi3Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with
+    the defaults will yield a similar configuration to that of the
+    [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32064):
+            Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`Phi3Model`].
+        hidden_size (`int`, *optional*, defaults to 3072):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 8192):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
+            constructed by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
+            `num_attention_heads`.
+        resid_pdrop (`float`, *optional*, defaults to 0.0):
+            Dropout probability for the MLP outputs.
+        embd_pdrop (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the embeddings.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio after computing the attention scores.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model might ever be used with.
+        original_max_position_embeddings (`int`, *optional*, defaults to 4096):
+            The maximum sequence length that this model was trained with. This is used to determine the size of the
+            original RoPE embeddings when using long scaling.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon value used for the RMSNorm.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings or not.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`dict`, *optional*):
+            The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
+            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or
+            `yarn`, and `short_factor` and `long_factor` must be lists of numbers whose length equals the hidden
+            size divided by the number of attention heads divided by 2.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
+        eos_token_id (`int`, *optional*, defaults to 32000):
+            The id of the "end-of-sequence" token.
+        pad_token_id (`int`, *optional*, defaults to 32000):
+            The id of the padding token.
+        sliding_window (`int`, *optional*):
+            Sliding window attention window size. If `None`, no sliding window is applied.
+
+    Example:
+
+    ```python
+    >>> from transformers import Phi3Model, Phi3Config
+
+    >>> # Initializing a Phi-3 style configuration
+    >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+
+    >>> # Initializing a model from the configuration
+    >>> model = Phi3Model(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "phi3"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=32064,
+        hidden_size=3072,
+        intermediate_size=8192,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        resid_pdrop=0.0,
+        embd_pdrop=0.0,
+        attention_dropout=0.0,
+        hidden_act="silu",
+        max_position_embeddings=4096,
+        original_max_position_embeddings=4096,
+        initializer_range=0.02,
+        rms_norm_eps=1e-5,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        bos_token_id=1,
+        eos_token_id=32000,
+        pad_token_id=32000,
+        sliding_window=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.resid_pdrop = resid_pdrop
+        self.embd_pdrop = embd_pdrop
+        self.attention_dropout = attention_dropout
+        self.hidden_act = hidden_act
+        self.max_position_embeddings = max_position_embeddings
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self._rope_scaling_validation()
+        self.sliding_window = sliding_window
+
+        super().__init__(
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            pad_token_id=pad_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+    def _rope_scaling_validation(self):
+        """
+        Validate the `rope_scaling` configuration.
+        """
+        if self.rope_scaling is None:
+            return
+
+        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
+            raise ValueError(
+                "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
+                f"got {self.rope_scaling}"
+            )
+        rope_scaling_type = self.rope_scaling.get("type", None)
+        rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
+        rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
+        if rope_scaling_type is None or rope_scaling_type not in ["su", "yarn"]:
+            raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
+        if not (
+            isinstance(rope_scaling_short_factor, list)
+            and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
+        ):
+            raise ValueError(
+                f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
+            )
+        if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2:
+            raise ValueError(
+                f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
+            )
+        if not (
+            isinstance(rope_scaling_long_factor, list)
+            and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
+        ):
+            raise ValueError(
+                f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
+            )
+        if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2:
+            raise ValueError(
+                f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
+            )
+
+__all__ = ['Phi3Config']
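
The length checks in `_rope_scaling_validation` pin both factor lists to half the head dimension: `hidden_size // num_attention_heads // 2`, which is 48 with the defaults above (3072 // 32 // 2). A minimal sketch of a valid and an invalid `rope_scaling` (an editorial illustration, not part of the commit; it assumes `Phi3Config` is re-exported as the `__init__.py` diffs above arrange):

```python
from mindnlp.transformers import Phi3Config

half_head_dim = 3072 // 32 // 2  # hidden_size // num_attention_heads // 2 == 48

# Valid: a type plus two factor lists of length 48.
config = Phi3Config(
    rope_scaling={
        "type": "su",
        "short_factor": [1.0] * half_head_dim,
        "long_factor": [1.5] * half_head_dim,
    }
)

# Invalid: a wrong list length raises ValueError at construction time.
try:
    Phi3Config(rope_scaling={"type": "su", "short_factor": [1.0], "long_factor": [1.0]})
except ValueError as err:
    print(err)  # `rope_scaling`'s short_factor field must have length 48, got 1
```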
