Skip to content

Commit f4e3818

Browse files
aws-yishanmhannanjgaws
authored and committed
enforce max context_length_estimate <= max n_positions
GitOrigin-RevId: 3e9481845f4cb4ac6b0e6af94383e48bb56e4722
1 parent c9359b9 commit f4e3818

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

src/transformers_neuronx/module.py

+6
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,12 @@ def from_pretrained(cls, pretrained_model_path, *model_args, **kwargs):
398398
def _sanity_check(**kwargs):
399399
context_length_estimate = kwargs.get("context_length_estimate", None)
400400
n_positions = kwargs.get("n_positions", 2048)
401+
max_n_pos = max(n_positions) if isinstance(n_positions, list) else n_positions
402+
max_cle = max(context_length_estimate) if isinstance(context_length_estimate, list) else context_length_estimate
403+
# max_n_pos or max_cle could be None if customer intends to use defaults
404+
if isinstance(max_n_pos, int) and isinstance(max_cle, int):
405+
assert max_n_pos >= max_cle, \
406+
f"Max context_length_estimate {max_cle} cannot be more than max n_positions {max_n_pos}."
401407
neuron_config = kwargs.get("neuron_config", None)
402408
bsh_cache_layout = False
403409
if neuron_config is not None:

0 commit comments

Comments
 (0)