diff --git a/docs/source/models/loss.md b/docs/source/models/loss.md index e5d81bae8..881794e6a 100644 --- a/docs/source/models/loss.md +++ b/docs/source/models/loss.md @@ -156,6 +156,58 @@ EasyRec支持两种损失函数配置方式:1)使用单个损失函数;2 - loss_weight_strategy: Random - 表示损失函数的权重设定为归一化的随机数 +### Loss动态权重 + +在多目标学习任务中,我们经常遇到给与不同目标设置不同的权重,甚至不同目标的loss权重会随着样本而动态变化。EasyRec支持用户为不同目标设置动态的权重。 + +- 1.首先在dataset中配置动态权重字段的名称,同时增加这些字段的input_config,如下示例 + +```protobuf +data_config { + batch_size: 4096 + label_fields: "clk" + label_fields: "buy" + label_dynamic_weight: "clk_weight" + label_dynamic_weight: "buy_weight" + prefetch_size: 32 + input_type: CSVInput + input_fields { + input_name: "clk" + input_type: INT32 + } + input_fields { + input_name: "buy" + input_type: INT32 + } + input_fields { + input_name: "clk_weight" + input_type: double + } + input_fields { + input_name: "buy_weight" + input_type: double + } +} +``` + +- 2.需要在对应的任务tower中设置对应的权重列,如下示例 + +```protobuf +task_towers { + tower_name: "ctr" + label_name: "clk" + dnn { + hidden_units: [256, 192, 128, 64] + } + num_class: 1 + dynamic_weight: "clk_weight" + loss_type: CLASSIFICATION + metrics_set: { + auc {} + } +} +``` + ### 参考论文: - 《 Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics 》 diff --git a/easy_rec/python/input/input.py b/easy_rec/python/input/input.py index d94b1de13..35c4eab32 100644 --- a/easy_rec/python/input/input.py +++ b/easy_rec/python/input/input.py @@ -78,6 +78,7 @@ def __init__(self, x.default_val for x in data_config.input_fields ] self._label_fields = list(data_config.label_fields) + self._label_dynamic_weight = list(data_config.label_dynamic_weight) self._feature_fields = list(data_config.feature_fields) self._label_sep = list(data_config.label_sep) self._label_dim = list(data_config.label_dim) @@ -139,6 +140,8 @@ def __init__(self, # add sample weight to effective fields if self._data_config.HasField('sample_weight'): self._effective_fields.append(self._data_config.sample_weight) + if len(self._label_dynamic_weight) > 0: + self._effective_fields.extend(self._label_dynamic_weight) # add uid_field of GAUC and session_fields of SessionAUC if self._pipeline_config is not None: @@ -234,6 +237,7 @@ def get_feature_input_fields(self): return [ x for x in self._input_fields if x not in self._label_fields and x != self._data_config.sample_weight + and x not in self._label_dynamic_weight ] def should_stop(self, curr_epoch): @@ -269,13 +273,14 @@ def create_multi_placeholders(self, export_config): effective_fids = [ fid for fid in range(len(self._input_fields)) if self._input_fields[fid] not in self._label_fields and + self._input_fields[fid] not in self._label_dynamic_weight and self._input_fields[fid] != sample_weight_field ] inputs = {} for fid in effective_fids: input_name = self._input_fields[fid] - if input_name == sample_weight_field: + if input_name == sample_weight_field or input_name in self._label_dynamic_weight: continue if placeholder_named_by_input: placeholder_name = input_name @@ -318,6 +323,7 @@ def create_placeholders(self, export_config): effective_fids = [ fid for fid in range(len(self._input_fields)) if self._input_fields[fid] not in self._label_fields and + self._input_fields[fid] not in self._label_dynamic_weight and self._input_fields[fid] != sample_weight_field ] logging.info( @@ -330,6 +336,8 @@ def create_placeholders(self, export_config): ftype = self._input_field_types[fid] tf_type = get_tf_type(ftype) input_name = self._input_fields[fid] + if input_name in self._label_dynamic_weight: + continue if tf_type in [tf.float32, tf.double, tf.int32, tf.int64]: features[input_name] = tf.string_to_number( input_vals[:, tmp_id], @@ -925,6 +933,14 @@ def _preprocess(self, field_dict): if self._mode != tf.estimator.ModeKeys.PREDICT: parsed_dict[constant.SAMPLE_WEIGHT] = field_dict[ self._data_config.sample_weight] + if len(self._label_dynamic_weight + ) > 0 and self._mode != tf.estimator.ModeKeys.PREDICT: + for label_weight in self._label_dynamic_weight: + if field_dict[label_weight].dtype == tf.float32: + parsed_dict[label_weight] = field_dict[label_weight] + else: + parsed_dict[label_weight] = tf.cast( + field_dict[label_weight], dtype=tf.float64) if Input.DATA_OFFSET in field_dict: parsed_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET] diff --git a/easy_rec/python/model/multi_task_model.py b/easy_rec/python/model/multi_task_model.py index fa6ce8948..f38d825a1 100644 --- a/easy_rec/python/model/multi_task_model.py +++ b/easy_rec/python/model/multi_task_model.py @@ -224,6 +224,8 @@ def build_loss_graph(self): for task_tower_cfg in self._task_towers: tower_name = task_tower_cfg.tower_name loss_weight = task_tower_cfg.weight + if task_tower_cfg.HasField('dynamic_weight'): + loss_weight *= self._feature_dict[task_tower_cfg.dynamic_weight] if task_tower_cfg.use_sample_weight: loss_weight *= self._sample_weight diff --git a/easy_rec/python/protos/dataset.proto b/easy_rec/python/protos/dataset.proto index 5ffefd064..2b64da65d 100644 --- a/easy_rec/python/protos/dataset.proto +++ b/easy_rec/python/protos/dataset.proto @@ -292,6 +292,8 @@ message DatasetConfig { // input field for sample weight optional string sample_weight = 22; + // input field for label dynimic weight + repeated string label_dynamic_weight = 27; // the compression type of tfrecord optional string data_compression_type = 23 [default = '']; diff --git a/easy_rec/python/protos/tower.proto b/easy_rec/python/protos/tower.proto index 14cf64c63..d52c44f5c 100644 --- a/easy_rec/python/protos/tower.proto +++ b/easy_rec/python/protos/tower.proto @@ -26,6 +26,8 @@ message TaskTower { optional DNN dnn = 6; // training loss weights optional float weight = 7 [default = 1.0]; + // training loss label dynamic weights + optional string dynamic_weight = 8; // label name for indicating the sample space for the task tower optional string task_space_indicator_label = 10; // the loss weight for sample in the task space @@ -76,4 +78,7 @@ message BayesTaskTower { optional bool use_ait_module = 17 [default = false]; // set this when the dimensions of last layer of towers are not equal optional uint32 ait_project_dim = 18; + // training loss label dynamic weights + optional string dynamic_weight = 19; + }; diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py index 68d0b8656..a682e91bc 100644 --- a/easy_rec/python/test/train_eval_test.py +++ b/easy_rec/python/test/train_eval_test.py @@ -7,11 +7,11 @@ import threading import time import unittest +from distutils.version import LooseVersion import numpy as np import six import tensorflow as tf -from distutils.version import LooseVersion from tensorflow.python.platform import gfile from easy_rec.python.main import predict @@ -942,12 +942,24 @@ def test_sequence_esmm(self): self._test_dir) self.assertTrue(self._success) + def test_label_dynamic_weight_esmm(self): + self._success = test_utils.test_single_train_eval( + 'samples/model_config/esmm_on_label_dynamic_weight_feature_taobao.config', + self._test_dir) + self.assertTrue(self._success) + def test_sequence_mmoe(self): self._success = test_utils.test_single_train_eval( 'samples/model_config/mmoe_on_sequence_feature_taobao.config', self._test_dir) self.assertTrue(self._success) + def test_label_dynamic_weight_sequence_mmoe(self): + self._success = test_utils.test_single_train_eval( + 'samples/model_config/mmoe_on_label_dynamic_weight_sequence_feature_taobao.config', + self._test_dir) + self.assertTrue(self._success) + def test_sequence_ple(self): self._success = test_utils.test_single_train_eval( 'samples/model_config/ple_on_sequence_feature_taobao.config', diff --git a/samples/model_config/esmm_on_label_dynamic_weight_feature_taobao.config b/samples/model_config/esmm_on_label_dynamic_weight_feature_taobao.config new file mode 100644 index 000000000..be13e0d95 --- /dev/null +++ b/samples/model_config/esmm_on_label_dynamic_weight_feature_taobao.config @@ -0,0 +1,319 @@ +train_input_path: "data/test/tb_data/taobao_train_data_label_dynamic_weight" +eval_input_path: "data/test/tb_data/taobao_test_data_label_dynamic_weight" +model_dir: "experiments/esmm_taobao_ckpt" + +train_config { + log_step_count_steps: 100 + optimizer_config: { + adam_optimizer: { + learning_rate: { + exponential_decay_learning_rate { + initial_learning_rate: 0.001 + decay_steps: 1000 + decay_factor: 0.5 + min_learning_rate: 0.0000001 + } + } + } + use_moving_average: false + } + save_checkpoints_steps: 100 + sync_replicas: True + num_steps: 100 +} + +eval_config { + metrics_set: { + auc {} + } +} + +data_config { + input_fields { + input_name:'clk' + input_type: INT32 + } + input_fields { + input_name:'buy' + input_type: INT32 + } + input_fields { + input_name: 'pid' + input_type: STRING + } + input_fields { + input_name: 'adgroup_id' + input_type: STRING + } + input_fields { + input_name: 'cate_id' + input_type: STRING + } + input_fields { + input_name: 'campaign_id' + input_type: STRING + } + input_fields { + input_name: 'customer' + input_type: STRING + } + input_fields { + input_name: 'brand' + input_type: STRING + } + input_fields { + input_name: 'user_id' + input_type: STRING + } + input_fields { + input_name: 'cms_segid' + input_type: STRING + } + input_fields { + input_name: 'cms_group_id' + input_type: STRING + } + input_fields { + input_name: 'final_gender_code' + input_type: STRING + } + input_fields { + input_name: 'age_level' + input_type: STRING + } + input_fields { + input_name: 'pvalue_level' + input_type: STRING + } + input_fields { + input_name: 'shopping_level' + input_type: STRING + } + input_fields { + input_name: 'occupation' + input_type: STRING + } + input_fields { + input_name: 'new_user_class_level' + input_type: STRING + } + input_fields { + input_name: 'tag_category_list' + input_type: STRING + } + input_fields { + input_name: 'tag_brand_list' + input_type: STRING + } + input_fields { + input_name: 'price' + input_type: INT32 + } + input_fields { + input_name: "clk_w" + input_type: INT32 + } + input_fields { + input_name: "buy_w" + input_type: FLOAT + } + label_fields: 'buy' + label_fields: 'clk' + label_dynamic_weight: "clk_w" + label_dynamic_weight: "buy_w" + batch_size: 4096 + num_epochs: 10000 + prefetch_size: 32 + input_type: CSVInput +} + +feature_configs : { + input_names: 'pid' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs : { + input_names: 'adgroup_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs : { + input_names: 'cate_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10000 +} +feature_configs : { + input_names: 'campaign_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs : { + input_names: 'customer' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs : { + input_names: 'brand' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs : { + input_names: 'user_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs : { + input_names: 'cms_segid' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 +} +feature_configs : { + input_names: 'cms_group_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 +} +feature_configs : { + input_names: 'final_gender_code' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs : { + input_names: 'age_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs : { + input_names: 'pvalue_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs : { + input_names: 'shopping_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs : { + input_names: 'occupation' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs : { + input_names: 'new_user_class_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs : { + input_names: 'tag_category_list' + feature_type: SequenceFeature + separator: '|' + hash_bucket_size: 100000 + embedding_dim: 16 +} +feature_configs : { + input_names: 'tag_brand_list' + feature_type: SequenceFeature + separator: '|' + hash_bucket_size: 100000 + embedding_dim: 16 +} +feature_configs : { + input_names: 'price' + feature_type: IdFeature + embedding_dim: 16 + num_buckets: 50 +} +model_config: { + model_class: 'ESMM' + feature_groups: { + group_name: 'user' + feature_names: 'user_id' + feature_names: 'cms_segid' + feature_names: 'cms_group_id' + feature_names: 'age_level' + feature_names: 'pvalue_level' + feature_names: 'shopping_level' + feature_names: 'occupation' + feature_names: 'new_user_class_level' + wide_deep: DEEP + sequence_features: { + group_name: "seq_fea" + tf_summary: false + allow_key_search:true + seq_att_map: { + key: "brand" + key: "cate_id" + hist_seq: "tag_brand_list" + hist_seq: "tag_category_list" + } + } + } + feature_groups: { + group_name: 'item' + feature_names: 'adgroup_id' + feature_names: 'cate_id' + feature_names: 'campaign_id' + feature_names: 'customer' + feature_names: 'brand' + feature_names: 'price' + wide_deep: DEEP + } + esmm { + groups { + input: "user" + dnn { + hidden_units: [256, 128, 96, 64] + } + } + groups { + input: "item" + dnn { + hidden_units: [256, 128, 96, 64] + } + } + cvr_tower { + tower_name: "cvr" + label_name: "buy" + dynamic_weight: "buy_w" + dnn { + hidden_units: [128, 96, 64, 32, 16] + } + num_class: 1 + weight: 1.0 + loss_type: CLASSIFICATION + metrics_set: { + auc {} + } + } + ctr_tower { + tower_name: "ctr" + label_name: "clk" + dynamic_weight: "clk_w" + dnn { + hidden_units: [128, 96, 64, 32, 16] + } + num_class: 1 + weight: 1.0 + loss_type: CLASSIFICATION + metrics_set: { + auc {} + } + } + l2_regularization: 1e-6 + } + embedding_regularization: 5e-5 +} diff --git a/samples/model_config/mmoe_on_label_dynamic_weight_sequence_feature_taobao.config b/samples/model_config/mmoe_on_label_dynamic_weight_sequence_feature_taobao.config new file mode 100644 index 000000000..9226cddb7 --- /dev/null +++ b/samples/model_config/mmoe_on_label_dynamic_weight_sequence_feature_taobao.config @@ -0,0 +1,304 @@ +train_input_path: "data/test/tb_data/taobao_train_data_label_dynamic_weight" +eval_input_path: "data/test/tb_data/taobao_test_data_label_dynamic_weight" +model_dir: "experiments/mmoe_taobao_ckpt" + +train_config { + optimizer_config { + adam_optimizer { + learning_rate { + exponential_decay_learning_rate { + initial_learning_rate: 0.001 + decay_steps: 1000 + decay_factor: 0.5 + min_learning_rate: 1e-07 + } + } + } + use_moving_average: false + } + num_steps: 5000 + sync_replicas: true + save_checkpoints_steps: 100 + log_step_count_steps: 100 +} +eval_config { + metrics_set { + auc { + } + } +} +data_config { + batch_size: 4096 + label_fields: "clk" + label_fields: "buy" + label_dynamic_weight: "clk_w" + label_dynamic_weight: "buy_w" + prefetch_size: 32 + input_type: CSVInput + input_fields { + input_name: "clk" + input_type: INT32 + } + input_fields { + input_name: "buy" + input_type: INT32 + } + input_fields { + input_name: "pid" + input_type: STRING + } + input_fields { + input_name: "adgroup_id" + input_type: STRING + } + input_fields { + input_name: "cate_id" + input_type: STRING + } + input_fields { + input_name: "campaign_id" + input_type: STRING + } + input_fields { + input_name: "customer" + input_type: STRING + } + input_fields { + input_name: "brand" + input_type: STRING + } + input_fields { + input_name: "user_id" + input_type: STRING + } + input_fields { + input_name: "cms_segid" + input_type: STRING + } + input_fields { + input_name: "cms_group_id" + input_type: STRING + } + input_fields { + input_name: "final_gender_code" + input_type: STRING + } + input_fields { + input_name: "age_level" + input_type: STRING + } + input_fields { + input_name: "pvalue_level" + input_type: STRING + } + input_fields { + input_name: "shopping_level" + input_type: STRING + } + input_fields { + input_name: "occupation" + input_type: STRING + } + input_fields { + input_name: "new_user_class_level" + input_type: STRING + } + input_fields { + input_name: "tag_category_list" + input_type: STRING + } + input_fields { + input_name: "tag_brand_list" + input_type: STRING + } + input_fields { + input_name: "price" + input_type: INT32 + } + input_fields { + input_name: "clk_w" + input_type: INT32 + } + input_fields { + input_name: "buy_w" + input_type: FLOAT + } +} +feature_configs { + input_names: "pid" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs { + input_names: "adgroup_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs { + input_names: "cate_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10000 +} +feature_configs { + input_names: "campaign_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs { + input_names: "customer" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs { + input_names: "brand" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs { + input_names: "user_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs { + input_names: "cms_segid" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 +} +feature_configs { + input_names: "cms_group_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 +} +feature_configs { + input_names: "final_gender_code" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs { + input_names: "age_level" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs { + input_names: "pvalue_level" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs { + input_names: "shopping_level" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs { + input_names: "occupation" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs { + input_names: "new_user_class_level" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs { + input_names: "tag_category_list" + feature_type: SequenceFeature + embedding_dim: 16 + hash_bucket_size: 100000 + separator: "|" +} +feature_configs { + input_names: "tag_brand_list" + feature_type: SequenceFeature + embedding_dim: 16 + hash_bucket_size: 100000 + separator: "|" +} +feature_configs { + input_names: "price" + feature_type: IdFeature + embedding_dim: 16 + num_buckets: 50 +} +model_config { + model_class: "MMoE" + feature_groups { + group_name: "all" + feature_names: "user_id" + feature_names: "cms_segid" + feature_names: "cms_group_id" + feature_names: "age_level" + feature_names: "pvalue_level" + feature_names: "shopping_level" + feature_names: "occupation" + feature_names: "new_user_class_level" + feature_names: "adgroup_id" + feature_names: "cate_id" + feature_names: "campaign_id" + feature_names: "customer" + feature_names: "brand" + feature_names: "price" + feature_names: "pid" + wide_deep: DEEP + sequence_features: { + group_name: "seq_fea" + tf_summary: false + seq_att_map: { + key: "brand" + key: "cate_id" + hist_seq: "tag_brand_list" + hist_seq: "tag_category_list" + } + } + } + mmoe { + expert_dnn { + hidden_units: [256, 192, 128, 64] + } + num_expert: 4 + task_towers { + tower_name: "ctr" + label_name: "clk" + dynamic_weight: "clk_w" + dnn { + hidden_units: [256, 192, 128, 64] + } + num_class: 1 + weight: 1.0 + loss_type: CLASSIFICATION + metrics_set: { + auc {} + } + } + task_towers { + tower_name: "cvr" + label_name: "buy" + dynamic_weight: "buy_w" + dnn { + hidden_units: [256, 192, 128, 64] + } + num_class: 1 + weight: 1.0 + loss_type: CLASSIFICATION + metrics_set: { + auc {} + } + } + l2_regularization: 1e-06 + } + embedding_regularization: 5e-05 +}