diff --git a/keras_cv_attention_models/common_layers.py b/keras_cv_attention_models/common_layers.py index bd34a72..e04232e 100644 --- a/keras_cv_attention_models/common_layers.py +++ b/keras_cv_attention_models/common_layers.py @@ -725,7 +725,7 @@ def __call__(self, image, resize_method="bilinear", resize_antialias=False, inpu return images -def add_pre_post_process(model, rescale_mode="tf", input_shape=None, post_process=None, featrues=None): +def add_pre_post_process(model, rescale_mode="tf", input_shape=None, post_process=None, features=None): from keras_cv_attention_models.imagenet.eval_func import decode_predictions input_shape = model.input_shape[1:] if input_shape is None else input_shape @@ -733,5 +733,5 @@ def add_pre_post_process(model, rescale_mode="tf", input_shape=None, post_proces model.decode_predictions = decode_predictions if post_process is None else post_process model.rescale_mode = rescale_mode - if featrues is not None: + if features is not None: model.extract_features = lambda: features