Skip to content

Commit

Permalink
Merge pull request #159 from gjp4tw/main
Browse files Browse the repository at this point in the history
fix typo
  • Loading branch information
leondgarse authored Apr 9, 2024
2 parents a143ffb + 47c9c09 commit 488ecac
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions keras_cv_attention_models/common_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,13 +725,13 @@ def __call__(self, image, resize_method="bilinear", resize_antialias=False, inpu
return images


def add_pre_post_process(model, rescale_mode="tf", input_shape=None, post_process=None, featrues=None):
def add_pre_post_process(model, rescale_mode="tf", input_shape=None, post_process=None, features=None):
from keras_cv_attention_models.imagenet.eval_func import decode_predictions

input_shape = model.input_shape[1:] if input_shape is None else input_shape
model.preprocess_input = PreprocessInput(input_shape, rescale_mode=rescale_mode)
model.decode_predictions = decode_predictions if post_process is None else post_process
model.rescale_mode = rescale_mode

if featrues is not None:
if features is not None:
model.extract_features = lambda: features

0 comments on commit 488ecac

Please sign in to comment.