diff --git a/easy_rec/python/tools/add_feature_info_to_config.py b/easy_rec/python/tools/add_feature_info_to_config.py index 62717e6f5..20274987e 100644 --- a/easy_rec/python/tools/add_feature_info_to_config.py +++ b/easy_rec/python/tools/add_feature_info_to_config.py @@ -31,17 +31,19 @@ def main(argv): FLAGS.template_config_path) reader = common_io.table.TableReader( - FLAGS.config_table, selected_cols='feature,feature_info') + FLAGS.config_table, selected_cols='feature,feature_info,message') feature_info_map = {} + drop_feature_names = [] while True: try: record = reader.read() feature_name = record[0][0] feature_info_map[feature_name] = json.loads(record[0][1]) + if 'DROP IT' in record[0][2]: + drop_feature_names.append(feature_name) except common_io.exception.OutOfRangeException: reader.close() break - for feature_config in config_util.get_compatible_feature_configs( pipeline_config): feature_name = feature_config.input_names[0] @@ -77,6 +79,17 @@ def main(argv): 'decay_steps'] logging.info('modify decay_steps to %s' % learning_rate.decay_steps) + for feature_group in pipeline_config.model_config.feature_groups: + feature_names = feature_group.feature_names + reserved_features = [] + for feature_name in feature_names: + if feature_name not in drop_feature_names: + reserved_features.append(feature_name) + else: + logging.info('drop feature: %s' % feature_name) + feature_group.ClearField('feature_names') + feature_group.feature_names.extend(reserved_features) + config_dir, config_name = os.path.split(FLAGS.output_config_path) config_util.save_pipeline_config(pipeline_config, config_dir, config_name)