Skip to content

Commit 80b086f

Browse files
Ensure TopKEncoder has correct outputs when model is saved (#1225)
* Remove output_names from base BaseModel. * Add assertion for output signature of saved model to test_topk_encoder * Move compile method from BaseModel to Model * Correct name of structured outputs * Move compile method back and add special case for TopKOutput
1 parent 16d289a commit 80b086f

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

merlin/models/tf/models/base.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
from merlin.models.tf.outputs.base import ModelOutput, ModelOutputType
7070
from merlin.models.tf.outputs.classification import CategoricalOutput
7171
from merlin.models.tf.outputs.contrastive import ContrastiveOutput
72+
from merlin.models.tf.outputs.topk import TopKOutput
7273
from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask
7374
from merlin.models.tf.transforms.features import PrepareFeatures, expected_input_cols_from_schema
7475
from merlin.models.tf.transforms.sequence import SequenceTransform
@@ -446,7 +447,10 @@ def compile(
446447
if num_v1_blocks > 0:
447448
self.output_names = [task.task_name for task in self.prediction_tasks]
448449
else:
449-
self.output_names = [block.full_name for block in self.model_outputs]
450+
if num_v2_blocks == 1 and isinstance(self.model_outputs[0], TopKOutput):
451+
pass
452+
else:
453+
self.output_names = [block.full_name for block in self.model_outputs]
450454

451455
# This flag will make Keras change the metric-names which is not needed in v2
452456
from_serialized = kwargs.pop("from_serialized", num_v2_blocks > 0)

tests/unit/tf/core/test_encoder.py

+9
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,15 @@ def test_topk_encoder(music_streaming_data: Dataset):
122122
loaded_topk_encoder = tf.keras.models.load_model(tmpdir)
123123
batch_output = loaded_topk_encoder(batch[0])
124124

125+
output_signature = loaded_topk_encoder.signatures["serving_default"].structured_outputs
126+
assert len(output_signature) == 2
127+
assert output_signature["scores"] == tf.TensorSpec(
128+
shape=(None, TOP_K), dtype=tf.float32, name="scores"
129+
)
130+
assert output_signature["identifiers"] == tf.TensorSpec(
131+
shape=(None, TOP_K), dtype=tf.int32, name="identifiers"
132+
)
133+
125134
assert list(batch_output.scores.shape) == [BATCH_SIZE, TOP_K]
126135
tf.debugging.assert_equal(
127136
topk_encoder.topk_layer._candidates,

0 commit comments

Comments
 (0)