Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug of undefined flags of easyrec tools run with DeepRec #511

Merged
merged 10 commits into from
Jan 6, 2025
3 changes: 3 additions & 0 deletions easy_rec/python/tools/add_boundaries_to_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import json
import logging
import os
import sys

import common_io
import tensorflow as tf

from easy_rec.python.utils import config_util
from easy_rec.python.utils import io_util

if tf.__version__ >= '2.0':
tf = tf.compat.v1
Expand Down Expand Up @@ -61,4 +63,5 @@ def main(argv):


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
tf.app.run()
3 changes: 3 additions & 0 deletions easy_rec/python/tools/add_feature_info_to_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import json
import logging
import os
import sys

import tensorflow as tf

from easy_rec.python.utils import config_util
from easy_rec.python.utils import io_util
from easy_rec.python.utils.hive_utils import HiveUtils

if tf.__version__ >= '2.0':
Expand Down Expand Up @@ -139,4 +141,5 @@ def main(argv):


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
tf.app.run()
3 changes: 3 additions & 0 deletions easy_rec/python/tools/faiss_index_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

import logging
import os
import sys

import faiss
import numpy as np
import tensorflow as tf
from easy_rec.python.utils import io_util

logging.basicConfig(
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
Expand Down Expand Up @@ -109,4 +111,5 @@ def main(argv):


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
tf.app.run()
3 changes: 3 additions & 0 deletions easy_rec/python/tools/feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import json
import os
import sys
from collections import OrderedDict

import numpy as np
Expand All @@ -11,6 +12,7 @@
from tensorflow.python.framework.meta_graph import read_meta_graph_file

from easy_rec.python.utils import config_util
from easy_rec.python.utils import io_util

if tf.__version__ >= '2.0':
tf = tf.compat.v1
Expand Down Expand Up @@ -299,6 +301,7 @@ def _visualize_feature_importance(self, feature_importance, group_name):


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
if FLAGS.model_type == 'variational_dropout':
fs = VariationalDropoutFS(
FLAGS.config_path,
Expand Down
3 changes: 3 additions & 0 deletions easy_rec/python/tools/hit_rate_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
import json
import logging
import os
import sys

import graphlearn as gl
import tensorflow as tf

from easy_rec.python.protos.dataset_pb2 import DatasetConfig
from easy_rec.python.utils import config_util
from easy_rec.python.utils import io_util
from easy_rec.python.utils.config_util import process_multi_file_input_path
from easy_rec.python.utils.hit_rate_utils import compute_hitrate_batch
from easy_rec.python.utils.hit_rate_utils import load_graph
Expand Down Expand Up @@ -217,4 +219,5 @@ def main():


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
main()
3 changes: 3 additions & 0 deletions easy_rec/python/tools/hit_rate_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from __future__ import division
from __future__ import print_function

import sys
import tensorflow as tf

from easy_rec.python.utils import io_util
from easy_rec.python.utils.hit_rate_utils import compute_hitrate_batch
from easy_rec.python.utils.hit_rate_utils import load_graph
from easy_rec.python.utils.hit_rate_utils import reduce_hitrate
Expand Down Expand Up @@ -131,4 +133,5 @@ def main():


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
main()
3 changes: 3 additions & 0 deletions easy_rec/python/tools/pre_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import json
import logging
import os
import sys

import tensorflow as tf

from easy_rec.python.input.input import Input
from easy_rec.python.utils import config_util
from easy_rec.python.utils import fg_util
from easy_rec.python.utils import io_util
from easy_rec.python.utils.check_utils import check_env_and_input_path
from easy_rec.python.utils.check_utils import check_sequence

Expand Down Expand Up @@ -114,4 +116,5 @@ def main(argv):


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
tf.app.run()
3 changes: 3 additions & 0 deletions easy_rec/python/tools/split_model_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import logging
import os
import sys

import tensorflow as tf
from tensorflow.core.framework import graph_pb2
Expand All @@ -11,6 +12,7 @@
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as tf_saver
from easy_rec.python.utils import io_util

if tf.__version__ >= '2.0':
tf = tf.compat.v1
Expand Down Expand Up @@ -282,4 +284,5 @@ def main(argv):


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
tf.app.run()
3 changes: 3 additions & 0 deletions easy_rec/python/tools/split_pdn_model_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import logging
import os
import sys

import tensorflow as tf
from tensorflow.core.framework import graph_pb2
Expand All @@ -12,6 +13,7 @@
from tensorflow.python.saved_model.utils_impl import get_variables_path
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as tf_saver
from easy_rec.python.utils import io_util

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('model_dir', '', '')
Expand Down Expand Up @@ -265,4 +267,5 @@ def main(argv):


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
tf.app.run()
95 changes: 95 additions & 0 deletions easy_rec/python/utils/io_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,98 @@ def read_data_from_json_path(json_path):
else:
logging.info('json_path not exists, return None')
return None


def convert_tf_flags_to_argparse(flags):
"""Convert tf.app.flags.FLAGS to argparse.ArgumentParser.

Args:
flags: tf.app.flags.FLAGS
Returns:
argparse.ArgumentParser: configurate ArgumentParser object
"""
import argparse
import ast
parser = argparse.ArgumentParser()

args = {}
for flag in flags._flags().values():
flag_name = flag.name
if flag_name in args:
args[flag_name][0] = True
continue
default = flag.value
flag_type = type(default)
help_str = flag.help or ''
args[flag_name] = [
False, flag_type, default, help_str,
flag.choices if hasattr(flag, 'choices') else None
]

def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')

for flag_name, (multi, flag_type, default, help_str, choices) in args.items():
if flag_type == bool:
parser.add_argument(
'--' + flag_name,
type=str2bool,
nargs='?',
const=True,
default=False,
help=help_str)
elif flag_type == str:
if choices:
parser.add_argument(
'--' + flag_name,
type=str,
choices=choices,
default=default,
help=help_str)
elif multi:
parser.add_argument(
'--' + flag_name,
type=str,
action='append',
default=default,
help=help_str)
else:
parser.add_argument(
'--' + flag_name, type=str, default=default, help=help_str)
elif flag_type in (list, dict):
parser.add_argument(
'--' + flag_name,
type=lambda s: ast.literal_eval(s),
default=default,
help=help_str)
elif flag_type in (int, float):
parser.add_argument(
'--' + flag_name, type=flag_type, default=default, help=help_str)
else:
parser.add_argument(
'--' + flag_name, type=str, default=default, help=help_str)
return parser


def filter_unknown_args(flags, args):
"""Filter unknown args."""
known_args = [args[0]]
parser = convert_tf_flags_to_argparse(flags)
args, unknown = parser.parse_known_args(args)
if len(unknown) > 1:
logging.info('undefined arguments: %s', ', '.join(unknown[1:]))
for key, value in vars(args).items():
if value is None:
continue
if type(value) in (list, dict) and not value:
continue
known_args.append('--' + key + '=' + str(value))
logging.info('defined arguments: %s', ', '.join(known_args[1:]))
return known_args