Skip to content

Commit

Permalink
Ahead of time compilation with auto layouts for parameters and decode.
Browse files Browse the repository at this point in the history
  • Loading branch information
s2bk committed Feb 8, 2025
1 parent ff59fb3 commit 63a2ea7
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 62 deletions.
82 changes: 56 additions & 26 deletions MaxText/inference_microbenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,32 +39,32 @@
# pylint: disable=too-many-positional-arguments


def prefill_benchmark_loop(engine_prefill, params, tokens, true_length, iters):
  """Inner loop for benchmarking prefill step.

  Args:
    engine_prefill: Callable invoked as
      `engine_prefill(params, tokens, true_length, rng)`. It must return a
      2-tuple; only the first element (the prefill result) is kept.
    params: Model parameters, forwarded unchanged to `engine_prefill`.
    tokens: Padded prompt tokens, forwarded unchanged to `engine_prefill`.
    true_length: Forwarded unchanged to `engine_prefill`.
    iters: Number of prefill calls to issue inside the timed region.

  Returns:
    Elapsed wall-clock time of the timed region, in seconds (float).
  """
  start = datetime.datetime.now()
  rng = jax.random.PRNGKey(1234)
  # Pre-initialised so the synchronisation below is well defined when iters == 0.
  prefill_result = None
  for _ in range(iters):
    # A fresh key is split off for every call.
    rng, rng_prefill = jax.random.split(rng)
    prefill_result, _ = engine_prefill(params, tokens, true_length, rng_prefill)
  # NOTE(review): the scraped diff lost its indentation; the sync is placed
  # after the loop because `prefill_result` is pre-initialised to None - confirm.
  jax.block_until_ready(prefill_result)
  end = datetime.datetime.now()
  del prefill_result
  return (end - start).total_seconds()


def prefill_benchmark(config, engine, params, tokens, true_length, num_model_params, iters):
def prefill_benchmark(config, engine_prefill, params, tokens, true_length, num_model_params, iters):
"""Handles warmup, running prefill benchmark, and printing results."""
rng = jax.random.PRNGKey(1234)
prefill_result = None
for _ in range(_WARMUP_ITERS):
rng, rng_prefill = jax.random.split(rng)
prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill)
prefill_result, _ = engine_prefill(params, tokens, true_length, rng_prefill)
jax.block_until_ready(prefill_result)
del prefill_result

print(f"Prefill benchmark results for length {tokens.size}:\n")
time_in_s = prefill_benchmark_loop(engine, params, tokens, true_length, iters)
time_in_s = prefill_benchmark_loop(engine_prefill, params, tokens, true_length, iters)
prefill_average_ms = 1000 * time_in_s / iters
prefill_tflops_per_device, _, _ = maxtext_utils.calculate_prefill_tflops_per_device(num_model_params, tokens.size, config)
tflops_per_sec_per_device = prefill_tflops_per_device / prefill_average_ms * 1000.0
Expand All @@ -82,7 +82,7 @@ def prefill_benchmark(config, engine, params, tokens, true_length, num_model_par


def prefill_insert_benchmark_loop(
config, engine, decode_state, params, total_slots, tokens, true_length, iters, profile_name
config, engine_insert, decode_state, params, total_slots, tokens, true_length, iters, profile_name
):
"""Inner loop for benchmarking prefill and insert step."""
prof = profiler.Profiler(config)
Expand All @@ -91,59 +91,57 @@ def prefill_insert_benchmark_loop(
rng = jax.random.PRNGKey(1234)
for i in range(iters):
rng, rng_prefill = jax.random.split(rng)
prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill)
decode_state = engine.insert(prefill_result, decode_state, int(i % total_slots))
del prefill_result
decode_state = engine_insert(tokens, true_length, rng_prefill, decode_state, int(i % total_slots), params)
jax.block_until_ready(decode_state)
end = datetime.datetime.now()
prof.deactivate()
return (end - start).total_seconds(), decode_state


def prefill_insert_benchmark(config, engine_insert, decode_state, params, total_slots, tokens, true_length, iters):
  """Handles warmup, running insert benchmark, and printing results.

  Args:
    config: Forwarded to `prefill_insert_benchmark_loop`.
    engine_insert: Callable invoked as
      `engine_insert(tokens, true_length, rng, decode_state, slot, params)`
      that returns the updated decode state.
    decode_state: Decode state threaded through every call; the value
      returned by one call is the input of the next.
    params: Model parameters, forwarded unchanged to `engine_insert`.
    total_slots: Number of slots; call `i` targets slot `i % total_slots`.
    tokens: Padded prompt tokens; `tokens.size` is used in the printed
      messages and in the profile name.
    true_length: Forwarded unchanged to `engine_insert`.
    iters: Number of timed iterations run by the inner loop.

  Returns:
    A tuple `(result_dict, decode_state)` where `result_dict` is
    `{"time_in_ms": <average milliseconds per iteration>}` and
    `decode_state` is the state returned by the timed loop.
  """
  # Warmup: the same call pattern as the timed loop, but untimed.
  rng = jax.random.PRNGKey(1234)
  for i in range(_WARMUP_ITERS):
    rng, rng_prefill = jax.random.split(rng)
    decode_state = engine_insert(tokens, true_length, rng_prefill, decode_state, int(i % total_slots), params)
  # NOTE(review): the scraped diff lost its indentation; the sync is assumed
  # to sit after the warmup loop rather than inside it - confirm.
  jax.block_until_ready(decode_state)

  print(f"Prefill and insert benchmark results for length {tokens.size}:\n")
  time_in_s, decode_state = prefill_insert_benchmark_loop(
      config, engine_insert, decode_state, params, total_slots, tokens, true_length, iters, f"prefill_insert_{tokens.size}"
  )
  prefill_insert_average_ms = time_in_s / iters * 1000.0
  print(f"\tPrefill + Insert step average time: {prefill_insert_average_ms:.3f} ms\n\n\n\n")
  result_dict = {"time_in_ms": prefill_insert_average_ms}
  return result_dict, decode_state


def ar_benchmark_loop(config, engine_generate, params, decode_state, iters, profile_name):
  """Inner loop for benchmarking ar step.

  Args:
    config: Passed to `profiler.Profiler`.
    engine_generate: Callable invoked as
      `engine_generate(params, decode_state, rng)`. It must return a 2-tuple
      whose first element is the updated decode state.
    params: Model parameters, forwarded unchanged to `engine_generate`.
    decode_state: Decode state threaded through every call.
    iters: Number of generate calls to issue inside the timed region.
    profile_name: Passed to the profiler as `optional_postfix`.

  Returns:
    A tuple `(elapsed_seconds, decode_state)`: the wall-clock duration of the
    timed region as a float, and the decode state after the last call.
  """
  prof = profiler.Profiler(config)
  prof.activate(optional_postfix=profile_name)
  try:
    start = datetime.datetime.now()
    rng = jax.random.PRNGKey(1234)
    for _ in range(iters):
      # A fresh key is split off for every step.
      rng, rng_generate = jax.random.split(rng)
      decode_state, _ = engine_generate(params, decode_state, rng_generate)
    # NOTE(review): the scraped diff lost its indentation; the sync is assumed
    # to sit after the loop rather than inside it - confirm.
    jax.block_until_ready(decode_state)
    end = datetime.datetime.now()
  finally:
    # Deactivate the profiler even if a generate step raises, so an activated
    # profiler is never left behind.
    prof.deactivate()
  return (end - start).total_seconds(), decode_state


def ar_benchmark(config, engine, params, decode_state, global_batch_size, cache_size, model_size, iters):
def ar_benchmark(config, engine_generate, params, decode_state, global_batch_size, cache_size, model_size, iters):
"""Handles warmup, running ar benchmark, and printing results."""
rng = jax.random.PRNGKey(1234)
for _ in range(_WARMUP_ITERS):
rng, rng_generate = jax.random.split(rng)
decode_state, _ = engine.generate(params, decode_state, rng=rng_generate)
decode_state, _ = engine_generate(params, decode_state, rng_generate)
jax.block_until_ready(decode_state)

time_in_s, decode_state = ar_benchmark_loop(config, engine, params, decode_state, iters, profile_name="autoregress")
time_in_s, decode_state = ar_benchmark_loop(
config, engine_generate, params, decode_state, iters, profile_name="autoregress"
)
seconds_per_step = time_in_s / iters
ar_average_ms = seconds_per_step * 1000
total_throughput = global_batch_size / seconds_per_step
Expand Down Expand Up @@ -224,11 +222,11 @@ def print_results_for_analyze(results):
print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['autoregressive']['step_in_ms_per_seq']}")


def summarize_prefill_result(engine, params, tokens, true_length):
def summarize_prefill_result(engine_prefill, params, tokens, true_length):
"""Summarize Prefill result."""
print(f"Prefill result of length {tokens.size}:\n")
rng = jax.random.PRNGKey(1234)
prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng)
prefill_result, _ = engine_prefill(params, tokens, true_length, rng)
jax.block_until_ready(prefill_result)
num_prefill_logits_params, total_prefill_logits_size, avg_prefill_logits_param_size = max_utils.summarize_pytree_data(
prefill_result["logits"], name="Prefill Logits", raw=True
Expand Down Expand Up @@ -261,7 +259,10 @@ def run_benchmarks(config):
metadata = engine.get_tokenizer()
vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
rng, rng_init_decode = jax.random.split(rng)
decode_state = engine.init_decode_state(rng_init_decode)

generate_executable, params, decode_state_executable = engine.aot_compile(params, pass_rng_shape=True)
decode_state = decode_state_executable(rng_init_decode)

_, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
num_model_params, model_size, _ = max_utils.summarize_pytree_data(params, name="Model")

Expand All @@ -273,19 +274,41 @@ def run_benchmarks(config):
benchmark_results["insert"] = {}
prefill_tokens = {}
prefill_true_lengths = {}
prefill_executable = {}
prefill_insert_executable = {}
i32_scalar = jax.ShapeDtypeStruct((), int)
rng_shape = jax.ShapeDtypeStruct([4], jax.numpy.dtype("uint32"))

for prefill_length in prefill_lengths:
prefill_tokens[prefill_length], prefill_true_lengths[prefill_length] = token_utils.tokenize_and_pad(
text, vocab, is_bos=True, prefill_lengths=[prefill_length]
)

key_shape = jax.ShapeDtypeStruct([prefill_length], jax.numpy.dtype("int32"))
prefill_executable[prefill_length] = (
jax.jit(
engine.prefill_aot,
in_shardings=(engine.param_layouts, None, None, None),
).lower(params, key_shape, i32_scalar, rng_shape)
).compile(compiler_options=None)

prefill_insert_executable[prefill_length] = (
jax.jit(
engine.prefill_insert,
in_shardings=(None, None, None, engine.decode_state_layouts, None, engine.param_layouts),
out_shardings=(engine.decode_state_layouts),
donate_argnames=("decode_state",),
).lower(key_shape, i32_scalar, rng_shape, engine.decode_state_shapes, i32_scalar, params)
).compile(compiler_options=None)

benchmark_results["prefill-result-sizes"][prefill_length] = summarize_prefill_result(
engine, params, prefill_tokens[prefill_length], prefill_true_lengths[prefill_length]
prefill_executable[prefill_length], params, prefill_tokens[prefill_length], prefill_true_lengths[prefill_length]
)

for prefill_length in prefill_lengths:
benchmark_results["prefill"][prefill_length] = prefill_benchmark(
config,
engine,
prefill_executable[prefill_length],
params,
prefill_tokens[prefill_length],
prefill_true_lengths[prefill_length],
Expand All @@ -295,7 +318,7 @@ def run_benchmarks(config):

prefill_insert_time, decode_state = prefill_insert_benchmark(
config,
engine,
prefill_insert_executable[prefill_length],
decode_state,
params,
engine.max_concurrent_decodes,
Expand All @@ -310,7 +333,14 @@ def run_benchmarks(config):

if "generate" in stages_to_benchmark:
benchmark_results["autoregressive"], decode_state = ar_benchmark(
config, engine, params, decode_state, engine.max_concurrent_decodes, cache_size, model_size, benchmark_loop_iters
config,
generate_executable,
params,
decode_state,
engine.max_concurrent_decodes,
cache_size,
model_size,
benchmark_loop_iters,
)

results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
Expand Down
100 changes: 71 additions & 29 deletions MaxText/inference_mlperf/offline_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,12 @@ def __init__(self, engine: engine_api.Engine, params, base_engine: engine_api.En
self.live = False
self.engine = engine
self.decode_state = None
self.decode_state_executable = None
if params is None:
self.relayout_params = True
params = engine.load_params()
else:
self.relayout_params = False
rng = jax.random.PRNGKey(0)
set_engine_vars_from_base_engine(engine, base_engine, rng)
self.params = params
Expand All @@ -80,12 +83,21 @@ def __init__(self, engine: engine_api.Engine, params, base_engine: engine_api.En
self.detokenize_backlog = queue.Queue(10)
self.prefill_buckets = defaultdict(list)

self._decode_state_executable = None

def init_decode_state(self):
  """Create `self.decode_state` if it has not been created yet.

  Does nothing when `self.decode_state` is already set. Otherwise calls
  `self._decode_state_executable` with `None` as its only argument and stores
  the result in `self.decode_state`.

  Raises:
    AssertionError: If `self._decode_state_executable` is still None.
  """
  if self.decode_state is not None:
    return
  # Explicit raise instead of `assert`, so the guard also runs under
  # `python -O`; the exception type and message are unchanged.
  if self._decode_state_executable is None:
    raise AssertionError("Decode state executable is none")
  self.decode_state = self._decode_state_executable(None)

def warmup(self, max_length, warmup_samples):

self._cached_generate, self.params, self._decode_state_executable = self.engine.aot_compile(
self.params, pass_rng_shape=False
)

self.init_decode_state()

interesting_buckets = [
64,
128,
Expand All @@ -95,18 +107,31 @@ def warmup(self, max_length, warmup_samples):
2048,
4096,
]
i32_scalar = jax.ShapeDtypeStruct((), int)

for length in interesting_buckets:
if length > max_length:
break
log.info(f"Compiling prefill: {length}")
input_data = jax.ShapeDtypeStruct((length,), jnp.dtype("int32"))
self._cached_pref[length] = (
jax.jit(self._prefill_insert, donate_argnums=(4,))
.lower(self.params, tokens=input_data, slot=0, true_length=length - 1, decode_state=self.decode_state)
.compile()

insert_with_layout = jax.jit(
self._prefill_insert,
in_shardings=(self.engine.param_layouts, None, None, None, self.engine.decode_state_layouts),
out_shardings=(
None,
self.engine.decode_state_layouts,
),
donate_argnames=("decode_state"),
)
lowered_insert = insert_with_layout.lower(
self.params, input_data, i32_scalar, i32_scalar, self.engine.decode_state_shapes
)
self._cached_pref[length] = lowered_insert.compile(compiler_options=None)

if length == 64 or length == 1024:
continue

input_data_batch = jax.ShapeDtypeStruct((max_length,), jnp.dtype("int32"))
min_num_prompts = max_length // length
max_num_prompts = max_length // (length // 2)
Expand All @@ -116,6 +141,20 @@ def warmup(self, max_length, warmup_samples):
self._cached_pref_batch[(length, num_prompts)] = (
jax.jit(
self._prefill_insert_batch,
in_shardings=(
self.engine.param_layouts,
None,
None,
None,
None,
None,
None,
self.engine.decode_state_layouts,
),
out_shardings=(
None,
self.engine.decode_state_layouts,
),
static_argnames=(
"num_prompts",
"padded_length",
Expand All @@ -124,21 +163,19 @@ def warmup(self, max_length, warmup_samples):
)
.lower(
self.params,
tokens=input_data_batch,
slots=jnp.arange(0, 16, dtype=int),
num_prompts=num_prompts,
decoder_positions=jnp.arange(0, max_length, dtype=int),
decoder_segment_ids=jnp.ones(max_length, dtype=int),
start_pos=jnp.arange(0, max_length, 64, dtype=int),
padded_length=length,
true_lengths=jnp.full(16, length, dtype=int),
decode_state=self.decode_state,
input_data_batch,
jnp.arange(0, 16, dtype=int),
num_prompts,
jnp.arange(0, max_length, dtype=int),
jnp.ones(max_length, dtype=int),
jnp.arange(0, max_length, 64, dtype=int),
length,
jnp.full(16, length, dtype=int),
self.engine.decode_state_shapes,
)
.compile()
.compile(compiler_options=None)
)
self._cached_generate = (
jax.jit(self.engine.generate, donate_argnums=(1,)).lower(self.params, self.decode_state).compile()
)

self.batch_inference(warmup_samples, desc="warmup")

def _prefill_insert(self, params, tokens, slot, true_length, decode_state):
Expand Down Expand Up @@ -208,10 +245,11 @@ def prefill(prefill_bucket, prefill_len):
prefill_fn = self._prefill_insert
if (cached := self._cached_pref.get(prefill_len)) is not None:
prefill_fn = cached
else:
assert False, "prefill fn not found"

for slot, row in prefill_bucket:
first_token, self.decode_state = prefill_fn(
self.params, tokens=row.tokens, slot=slot, true_length=row.true_length, decode_state=self.decode_state
)
first_token, self.decode_state = prefill_fn(self.params, row.tokens, slot, row.true_length, self.decode_state)
prefill_result.append((first_token, slot, row))
return prefill_result
else:
Expand Down Expand Up @@ -250,16 +288,18 @@ def pad_num_prompts_len_array(array_to_pad, pad_len):
log.info(f"invoking compiled function with length {prefill_len} num_prompts {num_prompts}")
if (cached := self._cached_pref_batch.get((prefill_len, num_prompts))) is not None:
prefill_fn = cached
else:
assert False, "prefill batch not found"

first_tokens, self.decode_state = prefill_fn(
self.params,
tokens=tokens,
slots=slots,
decoder_positions=positions,
decoder_segment_ids=sequence_indicator,
start_pos=start_pos,
true_lengths=true_lengths,
decode_state=self.decode_state,
tokens,
slots,
positions,
sequence_indicator,
start_pos,
true_lengths,
self.decode_state,
) # pytype: disable=missing-parameter
prefill_result = [(first_tokens[idx], slot, row) for (idx, (slot, row)) in enumerate(prefill_bucket)]

Expand Down Expand Up @@ -297,9 +337,11 @@ def decode():
gen_fn = self.engine.generate
if self._cached_generate is not None:
gen_fn = self._cached_generate
else:
assert False, "no generate fn"
result_tokens_l = []
for i in range(5):
self.decode_state, result_tokens = gen_fn(self.params, self.decode_state)
self.decode_state, result_tokens = gen_fn(self.params, self.decode_state, None)
result_tokens_l.append(result_tokens)
for i in range(5):
# result_tokens.copy_to_host_async()
Expand Down
Loading

0 comments on commit 63a2ea7

Please sign in to comment.