
Commit e77237a

Authored by karlhigley, marcromeyn, and jperez999
Make Model creation explicit (NVIDIA-Merlin#432)
* Remove the mypy plugin for model creation return type determination. No longer needed if we make model creation explicit.

* Make `Model` creation explicit. Currently, we're assuming that if you give us something that computes a loss, we should create a `Model`, which is a bit of black magic that has some weird edge cases. This change brings the API closer to standard Keras. We're also looking forward to future changes that make the `Model` responsible for computing losses and metrics, which will break this somewhat magical functionality.

* Remove unnecessary `ml.ModelContext` params

* Switch dataset tests back to `connect`

* Add tasks to masking tests

* Set the block schema in all blocks inside `SequentialBlock`

* Fix the return type annotation on the `connect` method

Co-authored-by: Marc Romeyn <[email protected]>
Co-authored-by: Julio Perez <[email protected]>
1 parent d169d33 commit e77237a

23 files changed: +73, -121 lines
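The core of the change: connecting a prediction task to a block no longer implicitly creates a `Model`; the model is now constructed explicitly. A minimal before/after sketch, assuming a `schema` prepared elsewhere (e.g. from an NVTabular workflow); the `embedding_dim` and layer sizes are illustrative, not taken from this diff:

```python
import merlin.models.tf as ml

# Before this commit: connecting a task to a block returned a Model
# (or RetrievalModel) behind the scenes.
# model = dlrm.connect(ml.BinaryClassificationTask(schema))

# After this commit: blocks stay blocks, and the Model is created explicitly,
# closer to standard Keras.
dlrm = ml.DLRMBlock(
    schema,
    embedding_dim=64,
    bottom_block=ml.MLPBlock([512, 128]),
    top_block=ml.MLPBlock([512, 128]),
)
model = ml.Model(dlrm, ml.BinaryClassificationTask(schema))
```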

docs/source/models_overview.md

Lines changed: 5 additions & 5 deletions
@@ -41,7 +41,7 @@ High-level API:
 import merlin.models.tf as ml
 
 block = ml.TwoTowerBlock(schema, ml.MLPBlock([512, 256]))
-model = block.connect(ml.ItemRetrievalTask())
+model = ml.Model(block, ml.ItemRetrievalTask())
 ```
 
 Low-level API:
@@ -53,7 +53,7 @@ from merlin.schema import Tags
 user_tower = ml.InputBlock(schema.select_by_tag(Tags.USER), ml.MLPBlock([512, 256]))
 item_tower = ml.InputBlock(schema.select_by_tag(Tags.ITEM), ml.MLPBlock([512, 256]))
 two_tower = ml.ParallelBlock({"user": user_tower, "item": item_tower})
-model = two_tower.connect(ml.ItemRetrievalTask())
+model = ml.Model(two_tower, ml.ItemRetrievalTask())
 ```
 
 ## Ranking
@@ -78,7 +78,7 @@ dlrm = ml.DLRMBlock(
     bottom_block=ml.MLPBlock([512, 128]),
     top_block=ml.MLPBlock([512, 128])
 )
-model = dlrm.connect(ml.BinaryClassificationTask(schema))
+model = ml.Model(dlrm, ml.BinaryClassificationTask(schema))
 ```
 
 Low-level API:
@@ -140,7 +140,7 @@ inputs = ml.InputBlock(schema)
 prediction_tasks = ml.PredictionTasks(schema)
 block = ml.MLPBlock([64])
 mmoe = ml.MMOEBlock(prediction_tasks, expert_block=ml.MLPBlock([64]), num_experts=4)
-model = inputs.connect(block, mmoe, prediction_tasks)
+model = ml.Model(inputs, block, mmoe, prediction_tasks)
 ```
 
 ### Progressive Layered Extraction
@@ -163,5 +163,5 @@ block = ml.MLPBlock([64])
 cgc = ml.CGCBlock(
     prediction_tasks, expert_block=ml.MLPBlock([64]), num_task_experts=2, num_shared_experts=2
 )
-model = inputs.connect(ml.MLPBlock([64]), cgc, prediction_tasks)
+model = ml.Model(inputs, ml.MLPBlock([64]), cgc, prediction_tasks)
 ```
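A rough end-to-end sketch of the updated high-level retrieval example with the explicit constructor; `schema` and `train_ds` are assumed to come from a Merlin dataset prepared elsewhere and are not part of this diff:

```python
import merlin.models.tf as ml

# Build the two-tower block, then wrap it in a Model explicitly.
block = ml.TwoTowerBlock(schema, ml.MLPBlock([512, 256]))
model = ml.Model(block, ml.ItemRetrievalTask())

# From here on, the workflow is standard Keras.
model.compile(optimizer="adam", run_eagerly=False)
model.fit(train_ds, batch_size=1024, epochs=1)
```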

examples/06-Define-your-own-architecture-with-Merlin-Models.ipynb

Lines changed: 1 addition & 1 deletion
@@ -886,7 +886,7 @@
    }
   ],
   "source": [
-   "model = deep_dlrm_interaction.connect(binary_task)\n",
+   "model = mm.Model(deep_dlrm_interaction, binary_task)\n",
    "type(model)"
   ]
  },

merlin/models/mypy.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

merlin/models/tf/blocks/core/base.py

Lines changed: 1 addition & 13 deletions
@@ -385,7 +385,7 @@ def connect(
         *block: Union[tf.keras.layers.Layer, str],
         block_name: Optional[str] = None,
         context: Optional[ModelContext] = None,
-    ) -> Union["SequentialBlock", "Model", "RetrievalModel"]:
+    ) -> "SequentialBlock":
         """Connect the block to other blocks sequentially.
 
         Parameters
@@ -399,8 +399,6 @@ def connect(
 
         """
         from merlin.models.tf.blocks.core.combinators import SequentialBlock
-        from merlin.models.tf.models.base import Model, RetrievalBlock, RetrievalModel
-        from merlin.models.tf.prediction_tasks.retrieval import ItemRetrievalTask
 
         blocks = [self.parse(b) for b in block]
 
@@ -413,16 +411,6 @@ def connect(
             [self, *blocks], copy_layers=False, block_name=block_name, context=context
         )
 
-        if isinstance(blocks[-1], ModelLikeBlock):
-            if (
-                any(isinstance(b, RetrievalBlock) for b in blocks)
-                or isinstance(self, RetrievalBlock)
-                and any(isinstance(b, ItemRetrievalTask) for b in blocks)
-            ):
-                return RetrievalModel(output)
-
-            return Model(output)
-
         return output
 
     def connect_with_residual(
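With the magic removed, `connect` now always returns a `SequentialBlock`, and wrapping the result in a `Model` is the caller's job. A hedged sketch of the resulting low-level usage; `schema` and the block sizes are illustrative assumptions:

```python
import merlin.models.tf as ml

# `connect` only chains blocks now; it never returns a Model or RetrievalModel.
inputs = ml.InputBlock(schema)
body = inputs.connect(ml.MLPBlock([128, 64]))  # -> SequentialBlock

# The prediction task is attached when the Model is created explicitly.
model = ml.Model(body, ml.BinaryClassificationTask("click"))
```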

merlin/models/tf/blocks/core/combinators.py

Lines changed: 4 additions & 0 deletions
@@ -81,6 +81,10 @@ def __init__(
         if getattr(layers[0], "has_schema", None):
             super().set_schema(layers[0].schema)
 
+            for layer in layers[1:]:
+                if hasattr(layer, "set_schema"):
+                    layer.set_schema(layers[0].schema)
+
         layers = copy.copy(layers) if copy_layers else layers
         if filter:
             if not isinstance(filter, Filter):
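The new loop propagates the first layer's schema to every later layer that exposes `set_schema`. A toy, standalone illustration of that propagation logic (not Merlin code, just a sketch of the pattern):

```python
class ToyLayer:
    """Stand-in for a block that can receive a schema."""

    def __init__(self):
        self.schema = None

    def set_schema(self, schema):
        self.schema = schema


layers = [ToyLayer(), ToyLayer(), ToyLayer()]
layers[0].schema = {"user_id", "item_id"}  # pretend the first layer already has a schema

# Mirrors the loop added to SequentialBlock.__init__ above.
for layer in layers[1:]:
    if hasattr(layer, "set_schema"):
        layer.set_schema(layers[0].schema)

assert all(layer.schema == layers[0].schema for layer in layers)
```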

merlin/models/tf/models/benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -91,6 +91,6 @@ def NCFModel(
     ncf = ParallelBlock({"mf": mf_branch, "mlp": mlp_branch}, aggregation="concat")
 
     prediction_tasks = parse_prediction_tasks(schema, prediction_tasks)
-    model = ncf.connect(prediction_tasks)
+    model = Model(ncf, prediction_tasks)
 
     return model

merlin/models/tf/models/ranking.py

Lines changed: 3 additions & 3 deletions
@@ -68,7 +68,7 @@ def DLRMModel(
         bottom_block=bottom_block,
         top_block=top_block,
     )
-    model = dlrm_body.connect(prediction_tasks)
+    model = Model(dlrm_body, prediction_tasks)
 
     return model
 
@@ -152,7 +152,7 @@ def DCNModel(
     else:
         dcn_body = input_block.connect_branch(CrossBlock(depth), deep_block, aggregation="concat")
 
-    model = dcn_body.connect(prediction_tasks)
+    model = Model(dcn_body, prediction_tasks)
 
     return model
 
@@ -230,6 +230,6 @@ def DeepFMModel(
     )
 
     prediction_tasks = parse_prediction_tasks(schema, prediction_tasks)
-    model = deep_fm.connect(prediction_tasks)
+    model = Model(deep_fm, prediction_tasks)
 
     return model
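Because the prebuilt ranking builders now construct `Model(...)` internally, user-facing calls are unchanged. A hedged usage sketch; `schema` is assumed, and the parameter values are illustrative rather than taken from this diff:

```python
import merlin.models.tf as ml

# The builder still returns a ready-to-compile Model, now created explicitly inside.
model = ml.DLRMModel(
    schema,
    embedding_dim=64,
    bottom_block=ml.MLPBlock([128, 64]),
    top_block=ml.MLPBlock([128, 64, 32]),
    prediction_tasks=ml.BinaryClassificationTask("click"),
)
model.compile(optimizer="adam")
```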

merlin/models/tf/models/retrieval.py

Lines changed: 9 additions & 7 deletions
@@ -36,7 +36,7 @@ def MatrixFactorizationModel(
     metrics: MetricOrMetrics = ItemRetrievalTask.DEFAULT_METRICS,
     samplers: Sequence[ItemSampler] = (),
     **kwargs,
-) -> Union[Model, RetrievalModel]:
+) -> RetrievalModel:
     """Builds a matrix factorization model.
 
     Example Usage::
@@ -74,7 +74,7 @@ def MatrixFactorizationModel(
 
     Returns
     -------
-    Union[Model, RetrievalModel]
+    RetrievalModel
     """
 
     if not prediction_tasks:
@@ -99,7 +99,7 @@ def MatrixFactorizationModel(
         **kwargs,
     )
 
-    model = two_tower.connect(prediction_tasks)
+    model = RetrievalModel(two_tower, prediction_tasks)
 
     return model
 
@@ -125,7 +125,7 @@ def TwoTowerModel(
     metrics: MetricOrMetrics = ItemRetrievalTask.DEFAULT_METRICS,
     samplers: Sequence[ItemSampler] = (),
     **kwargs,
-) -> Union[Model, RetrievalModel]:
+) -> RetrievalModel:
     """Builds the Two-tower architecture, as proposed in [1].
 
     Example Usage::
@@ -178,7 +178,7 @@ def TwoTowerModel(
 
     Returns
     -------
-    Union[Model, RetrievalModel]
+    RetrievalModel
     """
 
     if not prediction_tasks:
@@ -203,7 +203,7 @@ def TwoTowerModel(
         **kwargs,
     )
 
-    model = two_tower.connect(prediction_tasks)
+    model = RetrievalModel(two_tower, prediction_tasks)
 
     return model
 
@@ -292,4 +292,6 @@ def YoutubeDNNRetrievalModel(
         num_sampled=num_sampled,
     )
 
-    return inputs.connect(top_block, task)
+    # TODO: Figure out how to make this fit as
+    # a RetrievalModel (which must have a RetrievalBlock)
+    return Model(inputs, top_block, task)
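With the narrowed annotations, `MatrixFactorizationModel` and `TwoTowerModel` now return a `RetrievalModel` rather than `Union[Model, RetrievalModel]` (YoutubeDNNRetrievalModel still returns a plain `Model`, per the TODO above). A hedged sketch of what that buys callers; `schema` and the tower sizes are assumptions:

```python
import merlin.models.tf as ml

model = ml.TwoTowerModel(schema, query_tower=ml.MLPBlock([128, 64]))

# The narrowed return type means retrieval-specific helpers are available
# without an isinstance check in user code.
assert isinstance(model, ml.RetrievalModel)

model.compile(optimizer="adam", run_eagerly=False)
```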

merlin/models/tf/utils/testing_utils.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def mark_run_eagerly_modes(*args, **kwargs):
 
 
 def assert_body_works_in_model(dataset, body, run_eagerly, num_epochs=5):
-    model = body.connect(BinaryClassificationTask("click"))
+    model = Model(body, BinaryClassificationTask("click"))
     model.compile(optimizer="adam", run_eagerly=run_eagerly)
 
     losses = model.fit(dataset, batch_size=50, epochs=num_epochs)

setup.cfg

Lines changed: 0 additions & 1 deletion
@@ -43,7 +43,6 @@ parentdir_prefix = merlin-models-
 [mypy]
 ignore_missing_imports = True
 no_implicit_optional = True
-plugins = merlin.models.mypy
 
 [codespell]
 skip = .*pb2.py,./.git,./.github,./bench,./dist,./docs/build,.*egg-info.*,versioneer.py,*.csv,*.parquet,./.mypy_cache

0 commit comments
