Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removing tf-ranking as a dependency untill it supports tf 2.16 #7636

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion nightly_test_constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ tensorflow-hub==0.15.0
tensorflow-io==0.24.0
tensorflow-io-gcs-filesystem==0.24.0
tensorflow-metadata>=1.17.0.dev20241016
tensorflow-ranking==0.5.5
# tensorflow-ranking==0.5.5
tensorflow-serving-api==2.15.1
tensorflow-text==2.15.0
tensorflow-transform>=1.16.0.dev20240430
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ def run(self):
'flax': dependencies.make_extra_packages_flax(),
'kfp': dependencies.make_extra_packages_kfp(),
'tfjs': dependencies.make_extra_packages_tfjs(),
'tf-ranking': dependencies.make_extra_packages_tf_ranking(),
# This is due to TF Ranking not supporting TensorFlow 2.16, We should re-enable it when support is added.
# 'tf-ranking': dependencies.make_extra_packages_tf_ranking(),
'tfdf': dependencies.make_extra_packages_tfdf(),
'tflite-support': dependencies.make_extra_packages_tflite_support(),
'examples': dependencies.make_extra_packages_examples(),
Expand Down
35 changes: 18 additions & 17 deletions tfx/examples/ranking/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,37 @@
These names will be shared between the transform and the model.
"""

import tensorflow as tf
from tfx.examples.ranking import struct2tensor_parsing_utils
# import tensorflow as tf
# This is due to TF Ranking not supporting TensorFlow 2.16, We should re-enable it when support is added.
# from tfx.examples.ranking import struct2tensor_parsing_utils

# Labels are expected to be dense. In case of a batch of ELWCs have different
# number of documents, the shape of the label is [N, D], where N is the batch
# size, D is the maximum number of documents in the batch. If an ELWC in the
# batch has D_0 < D documents, then the value of label at D0 <= d < D must be
# negative to indicate that the document is invalid.
LABEL_PADDING_VALUE = -1
#LABEL_PADDING_VALUE = -1

# Names of features in the ELWC.
QUERY_TOKENS = 'query_tokens'
DOCUMENT_TOKENS = 'document_tokens'
LABEL = 'relevance'
#QUERY_TOKENS = 'query_tokens'
#DOCUMENT_TOKENS = 'document_tokens'
#LABEL = 'relevance'

# This "feature" does not exist in the data but will be created on the fly.
LIST_SIZE_FEATURE_NAME = 'example_list_size'
# LIST_SIZE_FEATURE_NAME = 'example_list_size'


def get_features():
"""Defines the context features and example features spec for parsing."""
#def get_features():
# """Defines the context features and example features spec for parsing."""

context_features = [
struct2tensor_parsing_utils.Feature(QUERY_TOKENS, tf.string)
]
# context_features = [
# struct2tensor_parsing_utils.Feature(QUERY_TOKENS, tf.string)
# ]

example_features = [
struct2tensor_parsing_utils.Feature(DOCUMENT_TOKENS, tf.string)
]
# example_features = [
# struct2tensor_parsing_utils.Feature(DOCUMENT_TOKENS, tf.string)
# ]

label = struct2tensor_parsing_utils.Feature(LABEL, tf.int64)
# label = struct2tensor_parsing_utils.Feature(LABEL, tf.int64)

return context_features, example_features, label
# return context_features, example_features, label
47 changes: 25 additions & 22 deletions tfx/examples/ranking/ranking_pipeline_e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
import unittest

import tensorflow as tf
from tfx.examples.ranking import ranking_pipeline
from tfx.orchestration import metadata
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
# from tfx.orchestration import metadata
# from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner

# This is due to TF Ranking not supporting TensorFlow 2.16, We should re-enable it when support is added.
# from tfx.examples.ranking import ranking_pipeline


try:
import struct2tensor # pylint: disable=g-import-not-at-top
Expand Down Expand Up @@ -62,23 +65,23 @@ def assertExecutedOnce(self, component) -> None:
execution = tf.io.gfile.listdir(os.path.join(component_path, output))
self.assertEqual(1, len(execution))

def testPipeline(self):
BeamDagRunner().run(
ranking_pipeline._create_pipeline(
pipeline_name=self._pipeline_name,
pipeline_root=self._tfx_root,
data_root=self._data_root,
module_file=self._module_file,
serving_model_dir=self._serving_model_dir,
metadata_path=self._metadata_path,
beam_pipeline_args=['--direct_num_workers=1']))
self.assertTrue(tf.io.gfile.exists(self._serving_model_dir))
self.assertTrue(tf.io.gfile.exists(self._metadata_path))
#def testPipeline(self):
# BeamDagRunner().run(
# ranking_pipeline._create_pipeline(
# pipeline_name=self._pipeline_name,
# pipeline_root=self._tfx_root,
# data_root=self._data_root,
# module_file=self._module_file,
# serving_model_dir=self._serving_model_dir,
# metadata_path=self._metadata_path,
# beam_pipeline_args=['--direct_num_workers=1']))
# self.assertTrue(tf.io.gfile.exists(self._serving_model_dir))
# self.assertTrue(tf.io.gfile.exists(self._metadata_path))

metadata_config = metadata.sqlite_metadata_connection_config(
self._metadata_path)
with metadata.Metadata(metadata_config) as m:
artifact_count = len(m.store.get_artifacts())
execution_count = len(m.store.get_executions())
self.assertGreaterEqual(artifact_count, execution_count)
self.assertEqual(9, execution_count)
# metadata_config = metadata.sqlite_metadata_connection_config(
# self._metadata_path)
# with metadata.Metadata(metadata_config) as m:
# artifact_count = len(m.store.get_artifacts())
# execution_count = len(m.store.get_executions())
# self.assertGreaterEqual(artifact_count, execution_count)
# self.assertEqual(9, execution_count)
Loading
Loading