
Commit 7242b6d

Merge pull request #6 from aws-neuron/1.3
Add v0.4.0 release content
2 parents (0d073c7 + 08d4bd9), commit 7242b6d

21 files changed: +738 -75 lines

.gitignore (+1)

@@ -2,3 +2,4 @@ build
 *.egg-info/
 dist/
 pip/
+.attach_pid*

Config (+2 -2)

@@ -1,10 +1,10 @@
 package.KaenaTransformers = {
     # please make sure the major.minor version matches the __version__ string in version.py
-    interfaces = (0.3);
+    interfaces = (0.4);

     build-system = custom-build;
     build-tools = {
-        0.3 = {
+        0.4 = {
             Python3PBuildTool = 2.0;
             Python-wheel = 0.x;
         };

README.md (+188 -5)

@@ -1,6 +1,6 @@
 # Transformers Neuron (``transformers-neuronx``) Developer Guide
 
-Transformers Neuron for Trn1/Inf2 is a software package that enables
+Transformers Neuron for Trn1 and Inf2 is a software package that enables
 PyTorch users to perform large language model (LLM) inference on
 second-generation Neuron hardware (See: [NeuronCore-v2](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/neuron-core-v2.html)).
 
@@ -29,7 +29,7 @@ new features are developed.
 To install the most rigorously tested stable release, use the PyPI pip wheel:
 
 ```
-pip install transformers-neuronx
+pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com
 ```
 
 ## Development Version
@@ -158,10 +158,12 @@ API via the ``HuggingFaceGenerationModelAdapter`` adapter class. In the followin
 demonstrate how to run sampling with temperature using the ``GPT2`` model:
 
 ```
+import os
 from transformers_neuronx.gpt2.model import GPT2ForSampling
 from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter
 from transformers_neuronx.module import save_pretrained_split
 from transformers import AutoModelForCausalLM, AutoTokenizer
+os.environ['NEURON_CC_FLAGS'] = '--model-type=transformer-inference'
 
 # Load and save the CPU model
 model_cpu = AutoModelForCausalLM.from_pretrained('gpt2')
@@ -193,18 +195,167 @@ sample_output = model.generate(
 print([tokenizer.decode(tok) for tok in sample_output])
 ```
 
-## Serialization support
+## int8 weight storage support
+
+Transformers Neuron supports int8 weight storage for the `GPT2` model class.
+int8 weight storage can be used to reduce memory bandwidth usage and improve
+model performance. int8 weight storage support for additional model classes
+will be added in an upcoming release. In the following example we demonstrate
+how to apply int8 weight storage to the `GPT2` model via the
+`QuantizationConfig` and `NeuronConfig` configs:
+
+```
+import os
+import torch
+from transformers_neuronx.gpt2.model import GPT2ForSampling
+from transformers_neuronx.module import save_pretrained_split
+from transformers_neuronx.config import NeuronConfig, QuantizationConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer
+os.environ['NEURON_CC_FLAGS'] = '--model-type=transformer-inference'
+
+# Cast attention and mlp layers to low precisions only; layernorms stay as f32
+def amp_callback(model, dtype):
+    for block in model.transformer.h:
+        block.attn.to(dtype)
+        block.mlp.to(dtype)
+    model.lm_head.to(dtype)
+
+# Load and save the CPU model with bfloat16 casting
+model_cpu = AutoModelForCausalLM.from_pretrained('gpt2')
+amp_callback(model_cpu, torch.bfloat16)
+save_pretrained_split(model_cpu, 'gpt2-split')
+
+# Set the weight storage config to use int8 quantization and bf16 dequantization
+neuron_config = NeuronConfig(
+    quant=QuantizationConfig(quant_dtype='s8', dequant_dtype='bf16'),
+)
+
+# Create and compile the Neuron model
+model_neuron = GPT2ForSampling.from_pretrained('gpt2-split', batch_size=1, tp_degree=2, n_positions=256, amp='bf16', neuron_config=neuron_config)
+model_neuron.to_neuron()
+
+# Get a tokenizer and example input
+tokenizer = AutoTokenizer.from_pretrained('gpt2')
+text = "Hello, I'm a language model,"
+encoded_input = tokenizer(text, return_tensors='pt')
+
+# Run inference
+with torch.inference_mode():
+    generated_sequence = model_neuron.sample(encoded_input.input_ids, sequence_length=256, start_ids=None)
+print([tokenizer.decode(tok) for tok in generated_sequence])
+
+```
+
+## Parallel Input Prompt Context Encoding
+
+Transformers Neuron supports parallel input prompt context encoding for the `GPT2`
+model class. Parallel context encoding can be used to significantly reduce
+the latency of the input prompt context encoding before the autoregressive
+decoder token generation loop. Parallel context encoding support for additional
+model classes will be added in an upcoming release.
+
+The `GPT2ForSamplingWithContextBroadcasting` class has a `context_length_estimate`
+variable that determines the number of input prompt tokens that will be processed in
+parallel. For optimal results, this should be set to a power of 2 that is
+closest to the most frequently seen input prompt length.
+In the following example we demonstrate how to apply parallel context encoding
+to the `GPT2` model via the `GPT2ForSamplingWithContextBroadcasting` class.
+In this example, we set the `context_length_estimate` to 128, which is
+the power of 2 closest to the length of the input prompt (97 tokens).
+
+```
+import os
+import math
+import torch
+from transformers_neuronx.gpt2.model import GPT2ForSamplingWithContextBroadcasting
+from transformers_neuronx.module import save_pretrained_split
+from transformers import AutoModelForCausalLM, AutoTokenizer
+os.environ['NEURON_CC_FLAGS'] = '--model-type=transformer-inference' # Apply optimal compiler flags
+
+# Load and save the CPU model with bfloat16 casting
+model_cpu = AutoModelForCausalLM.from_pretrained('gpt2')
+save_pretrained_split(model_cpu, 'gpt2-split')
+
+# Get a tokenizer and example input
+tokenizer = AutoTokenizer.from_pretrained('gpt2')
+text = "Hello, I'm a generative AI language model. Generative AI is a type of AI that can create new content and ideas, including conversations, stories, images, videos, and music. It is powered by large models that are pre-trained on vast amounts of data and commonly referred to as foundation models (FMs). With generative AI on AWS, you can reinvent your applications, create entirely new customer experiences, drive unprecedented levels of productivity, and transform your business. "
+encoded_input = tokenizer(text, return_tensors='pt')
+
+# Set the number of tokens that will be processed in parallel
+prompt_len = encoded_input.input_ids.shape[1]
+context_length_estimate = int(2 ** math.ceil(math.log(prompt_len, 2))) # Use the closest power of two bucket size
+
+# Create and compile the Neuron model
+model_neuron = GPT2ForSamplingWithContextBroadcasting.from_pretrained('gpt2-split', batch_size=1, tp_degree=2, n_positions=256, amp='bf16', context_length_estimate=context_length_estimate)
+model_neuron.to_neuron()
+
+# Run inference
+with torch.inference_mode():
+    generated_sequence = model_neuron.sample(encoded_input.input_ids, sequence_length=256, start_ids=None)
+print([tokenizer.decode(tok) for tok in generated_sequence])
+```
+
+The `GPT2ForSamplingWithContextBroadcasting` class can also process
+an input prompt that has a different batch size from the batch size of the
+autoregressive decoder output. For example, an input prompt with batch size = 1 can
+be used to produce an output of batch size = 5 to generate multiple suggestions
+for the same input prompt. The input prompt batch size can be specified using
+the `prompt_batch_size` argument and the autoregressive decoder output batch
+size can be specified using the `batch_size` argument. In the following example
+we demonstrate how to apply parallel context encoding to the `GPT2` model
+to generate 5 outputs for a single input.
+
+```
+import os
+import math
+import torch
+from transformers_neuronx.gpt2.model import GPT2ForSamplingWithContextBroadcasting
+from transformers_neuronx.module import save_pretrained_split
+from transformers import AutoModelForCausalLM, AutoTokenizer
+os.environ['NEURON_CC_FLAGS'] = '--model-type=transformer-inference'
+
+# Load and save the CPU model with bfloat16 casting
+model_cpu = AutoModelForCausalLM.from_pretrained('gpt2')
+save_pretrained_split(model_cpu, 'gpt2-split')
+
+# Get a tokenizer and example input
+tokenizer = AutoTokenizer.from_pretrained('gpt2')
+text = "Hello, I'm a generative AI language model. Generative AI is a type of AI that can create new content and ideas, including conversations, stories, images, videos, and music. It is powered by large models that are pre-trained on vast amounts of data and commonly referred to as foundation models (FMs). With generative AI on AWS, you can reinvent your applications, create entirely new customer experiences, drive unprecedented levels of productivity, and transform your business. "
+encoded_input = tokenizer(text, return_tensors='pt')
+
+# Set the number of tokens that will be processed in parallel
+prompt_len = encoded_input.input_ids.shape[1]
+context_length_estimate = int(2 ** math.ceil(math.log(prompt_len, 2))) # Use the closest power of two bucket size
+
+# Create and compile the Neuron model
+model_neuron = GPT2ForSamplingWithContextBroadcasting.from_pretrained('gpt2-split', prompt_batch_size=1, batch_size=5, tp_degree=2, n_positions=256, amp='bf16', context_length_estimate=context_length_estimate)
+model_neuron.to_neuron()
+
+# Run inference
+with torch.inference_mode():
+    generated_sequence = model_neuron.sample(encoded_input.input_ids, sequence_length=256, start_ids=None)
+for i, output in enumerate(generated_sequence):
+    print('-'*50)
+    print(f'Batch {i} output:')
+    print(tokenizer.decode(output))
+```
+
+
+## [Experimental] Serialization support
 
 Transformers Neuron supports model serialization (model saving and loading) for
-the ``GPT2`` model class. Serialization support for additional model classes
+the `GPT2` model class. Serialization support for additional model classes
 will be added in an upcoming release. In the following example we demonstrate
-how to save and load the ``GPT2`` model:
+how to save and load the `GPT2` model:
 
 ```
+import os
+import torch
 from transformers_neuronx.gpt2.model import GPT2ForSampling
 from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter
 from transformers_neuronx.module import save_pretrained_split
 from transformers import AutoModelForCausalLM, AutoTokenizer
+os.environ['NEURON_CC_FLAGS'] = '--model-type=transformer-inference'
 
 # Load and save the CPU model
 model_cpu = AutoModelForCausalLM.from_pretrained('gpt2')
@@ -221,7 +372,39 @@ model_neuron._save_compiled_artifacts('gpt2-neuron')
 model_neuron = GPT2ForSampling.from_pretrained('gpt2-split', batch_size=1, tp_degree=2, n_positions=256, amp='f32', unroll=None)
 model_neuron._load_compiled_artifacts('gpt2-neuron') # Load the compiled Neuron artifacts
 model_neuron.to_neuron() # Load the model weights but skip compilation
+# Get a tokenizer and example input
+tokenizer = AutoTokenizer.from_pretrained('gpt2')
+text = "Hello, I'm a language model,"
+encoded_input = tokenizer(text, return_tensors='pt')
+
+# Run inference
+with torch.inference_mode():
+    generated_sequence = model_neuron.sample(encoded_input.input_ids, sequence_length=256, start_ids=None)
+print([tokenizer.decode(tok) for tok in generated_sequence])
+```
+
+## model-type=transformer-inference Compiler Flag
+
+We recommend using the `--model-type=transformer-inference` compiler flag for optimized
+decoder-only LLM inference. In a future release, this compiler flag may be enabled
+by default. This compiler flag can be enabled via the `NEURON_CC_FLAGS` environment
+variable:
+
 ```
+export NEURON_CC_FLAGS="--model-type=transformer-inference"
+```
+
+## Running inference with multiple models
+
+Multiple transformers-neuronx models can be loaded at the same time as long
+as the total number of consumed NeuronCores is less than or equal to the total
+number of NeuronCores on the instance. For example, three tp-degree=8 models can be
+loaded and run in parallel on an inf2.48xlarge which has 24 NeuronCores. The
+`NEURON_RT_NUM_CORES` and `NEURON_RT_VISIBLE_CORES` environment variables
+can be used to allocate the necessary number of NeuronCores to each process
+to run multiple transformers-neuronx models in parallel. See the
+[NeuronCore Allocation and Model Placement for Inference](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/programming-guide/inference/core-placement.html#torch-neuronx-core-placement-guide)
+section for additional information about how to use these environment variables.
 
 # Examples
 
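The new "Running inference with multiple models" section above relies on the Neuron runtime environment variables to give each process its own NeuronCores. A minimal sketch (not part of the commit) of how one such process might be configured for the `tp_degree=2` GPT2 examples used elsewhere in the README; the core range and paths are illustrative:

```
import os

# Illustrative: reserve NeuronCores 0-1 for this process before any Neuron
# model is loaded; a second process could set '2-3', a third '4-5', and so on.
os.environ['NEURON_RT_VISIBLE_CORES'] = '0-1'
os.environ['NEURON_CC_FLAGS'] = '--model-type=transformer-inference'

from transformers_neuronx.gpt2.model import GPT2ForSampling

# Each process then loads its own copy of the model with a matching tp_degree.
model_neuron = GPT2ForSampling.from_pretrained('gpt2-split', batch_size=1, tp_degree=2, n_positions=256, amp='bf16')
model_neuron.to_neuron()
```

Launching several copies of such a script, each with a disjoint core range, gives the multi-model setup described in the section above.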

releasenotes.md (+19)

@@ -1,3 +1,22 @@
+# Transformers Neuron 0.4.0 Release Notes
+
+Date: 2023-06-12
+
+## What's New?
+
+- Added ``int8`` weight storage for `GPT2` models.
+- Improved prompt context encoding performance for `GPT2` models.
+- Improved collective communications performance for tp-degrees 4, 8, and 24 on Inf2.
+- Improved collective communications performance for tp-degrees 8 and 32 on Trn1.
+- Support for the ``--model-type=transformer-inference`` compiler flag for optimized decoder-only LLM inference.
+
+## Bug Fixes
+
+- Added padding to the `GPT-J` ``linear`` layer to correctly handle odd vocabulary sizes.
+- Issues where the HuggingFace `generate` method produces incorrect results when
+`beam_search` is used have been resolved.
+
+
 # Transformers Neuron 0.3.0 Release Notes
 
 Date: 2023-04-28

setup.py (+1)

@@ -61,6 +61,7 @@ def get_version():
         'gpt2_generation_demo=transformers_neuronx.gpt2.generation_demo:main',
         'gpt2_demo=transformers_neuronx.gpt2.demo:main',
         'gptj_demo=transformers_neuronx.gptj.demo:main',
+        'gptneox_demo=transformers_neuronx.gptneox.demo:main',
         'opt_demo=transformers_neuronx.opt.demo:main',
         'opt_gen_random_pretrained=transformers_neuronx.opt.gen_random_pretrained:main',
         'gen_randn_hlo_snapshot=transformers_neuronx.tools.gen_hlo_snapshot:main_randn',
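The new console_scripts entry registers a `gptneox_demo` command at install time. As a rough, illustrative equivalent (setuptools generates a thin wrapper that simply imports and calls the referenced function):

```
# Illustrative only: what the installed gptneox_demo console script boils down to.
from transformers_neuronx.gptneox.demo import main

if __name__ == '__main__':
    main()  # any argument parsing happens inside the demo's main()
```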

src/transformers_neuronx/activations.py (+9)

@@ -46,3 +46,12 @@ def relu(hidden):
     zero = dtype.Constant(constant_value=0.0)
     zero_br = dtype[sizes].Broadcast(zero, dimensions=[])
     return dtype[sizes].Maximum(hidden, zero_br)
+
+
+def sigmoid(tensor):
+    return tensor.dtype[tensor.sizes].Logistic(tensor)
+
+
+def silu(tensor):
+    logistic = sigmoid(tensor)
+    return tensor.dtype[tensor.sizes].Multiply(tensor, logistic)
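For reference, the two new builders emit the HLO `Logistic` and `Multiply` ops, which correspond to the standard sigmoid and SiLU definitions. A short PyTorch reference of the math they are meant to implement (illustrative, not part of the package):

```
import torch

def sigmoid_reference(x):
    # logistic function: 1 / (1 + exp(-x))
    return torch.sigmoid(x)

def silu_reference(x):
    # SiLU (a.k.a. swish): x * sigmoid(x), mirroring Multiply(tensor, logistic) above
    return x * torch.sigmoid(x)
```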

src/transformers_neuronx/compiler.py (+31 -1)

@@ -17,6 +17,7 @@
 import subprocess
 import tarfile
 import tempfile
+import numpy as np
 from textwrap import dedent
 import torch
 from torch_neuronx.pyhlo import xla_data_pb2
@@ -243,12 +244,39 @@ def get_debug_tensors(self):
 
 
 class ParallelKernel:
-
+    hlo_snapshot_iter = 0
     def __init__(self, hlo_module, tp_degree):
         self.hlo_module = hlo_module
         self.tp_degree = tp_degree
         self.neff_bytes = None
         self.model = None
+        self.hlo_snapshot = None
+        self.generate_hlo_snapshot()
+
+
+    def generate_hlo_snapshot(self, tensors=None):
+        if tensors is None:
+            self.hlo_snapshot_folder = os.environ.get("HLO_SNAPSHOT_PATH", None)
+            self.hlo_snapshot = self.hlo_snapshot_folder is not None
+            if self.hlo_snapshot:
+                os.makedirs(f"{self.hlo_snapshot_folder}", exist_ok=True)
+        elif self.hlo_snapshot:
+            folder = os.path.join(self.hlo_snapshot_folder, f"iter{ParallelKernel.hlo_snapshot_iter}")
+            os.makedirs(folder, exist_ok=True)
+            for i, tensor in enumerate(tensors):
+                filename = os.path.join(folder, f"{i}.npy")
+                tensor_cpu = ops.parallel_cpu(tensor)
+                if isinstance(tensor_cpu, list):
+                    tensor_cpu = tensor_cpu[0]
+                if tensor_cpu.dtype == torch.bfloat16:
+                    tensor_cpu = tensor_cpu.view(torch.int16)
+                    tensor_cpu = tensor_cpu.numpy()
+                    tensor_cpu = tensor_cpu.view('|V2')
+                else:
+                    tensor_cpu = tensor_cpu.detach().numpy()
+                np.save(filename, tensor_cpu)
+            ParallelKernel.hlo_snapshot_iter += 1
+
 
     def build_memory(self):
         return ParallelMemory(self.hlo_module, self.tp_degree)
@@ -264,6 +292,8 @@ def load(self):
         self.model.load()
 
     def __call__(self, memory):
+        if self.hlo_snapshot:
+            self.generate_hlo_snapshot(memory.input_tensors)
         return ops.parallel_run(self.model, memory.inputs, memory.outputs)
 
 
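Reading the added code, snapshotting is driven by the `HLO_SNAPSHOT_PATH` environment variable: when it is set before a `ParallelKernel` is constructed, each kernel call writes its input tensors as `.npy` files into per-call `iter<N>` subfolders. A minimal sketch of enabling and inspecting this (the directory name is arbitrary and chosen for illustration):

```
import os

# Illustrative: must be set before the Neuron model and its kernels are created.
os.environ['HLO_SNAPSHOT_PATH'] = '/tmp/hlo_snapshots'

# ... build and run a transformers-neuronx model as in the README examples ...

import numpy as np
# Inputs of the first kernel call land in iter0/ as 0.npy, 1.npy, ...
# (bfloat16 tensors are stored as raw 2-byte '|V2' records, per the code above).
first_input = np.load('/tmp/hlo_snapshots/iter0/0.npy')
print(first_input.shape, first_input.dtype)
```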