From 53d9ba2da7d0e14db198dab06041b502c4313e74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E9=AB=98=E9=A3=9E?= <1052924341@qq.com> Date: Tue, 28 May 2024 17:35:42 +0800 Subject: [PATCH 1/7] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E7=9B=AE=E6=A0=87?= =?UTF-8?q?=E7=9A=84=E5=8A=A8=E6=80=81=E6=9D=83=E9=87=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/quick_start/mc_tutorial.md | 2 +- easy_rec/python/input/input.py | 18 +- easy_rec/python/model/multi_task_model.py | 2 + easy_rec/python/protos/dataset.proto | 2 + easy_rec/python/protos/tower.proto | 4 + easy_rec/python/test/train_eval_test.py | 12 + ...label_dynamic_weight_feature_taobao.config | 319 ++++++++++++++++++ ...amic_weight_sequence_feature_taobao.config | 304 +++++++++++++++++ 8 files changed, 661 insertions(+), 2 deletions(-) create mode 100644 samples/model_config/esmm_on_label_dynamic_weight_feature_taobao.config create mode 100644 samples/model_config/mmoe_on_label_dynamic_weight_sequence_feature_taobao.config diff --git a/docs/source/quick_start/mc_tutorial.md b/docs/source/quick_start/mc_tutorial.md index 16761d2db..ea6a65723 100644 --- a/docs/source/quick_start/mc_tutorial.md +++ b/docs/source/quick_start/mc_tutorial.md @@ -39,7 +39,7 @@ pai -name easy_rec_ext -project algo_public - -Dtables: 定义其他依赖表(可选),如负采样的表 - -Dcluster: 定义PS的数目和worker的数目。具体见:[PAI-TF任务参数介绍](https://help.aliyun.com/document_detail/154186.html?spm=a2c4g.11186623.4.3.e56f1adb7AJ9T5) - -Deval_method: 评估方法 -- separate: 用worker(task_id=1)做评估 +- separate: 用worker(task_id=1)做评估。点击训练的logview中worker#1_0的stderr,出现类似字段"Saving dict for global step 3949: auc = 0.7643898, global_step = 3949, loss = 0.38898173, loss/loss/cross_entropy_loss = 0.38898173, loss/loss/total_loss = 0.38898173"即是评估指标 - none: 不需要评估 - master: 在master(task_id=0)上做评估 - -Dfine_tune_checkpoint: 可选,从checkpoint restore参数,进行finetune diff --git a/easy_rec/python/input/input.py b/easy_rec/python/input/input.py index 
d94b1de13..35c4eab32 100644 --- a/easy_rec/python/input/input.py +++ b/easy_rec/python/input/input.py @@ -78,6 +78,7 @@ def __init__(self, x.default_val for x in data_config.input_fields ] self._label_fields = list(data_config.label_fields) + self._label_dynamic_weight = list(data_config.label_dynamic_weight) self._feature_fields = list(data_config.feature_fields) self._label_sep = list(data_config.label_sep) self._label_dim = list(data_config.label_dim) @@ -139,6 +140,8 @@ def __init__(self, # add sample weight to effective fields if self._data_config.HasField('sample_weight'): self._effective_fields.append(self._data_config.sample_weight) + if len(self._label_dynamic_weight) > 0: + self._effective_fields.extend(self._label_dynamic_weight) # add uid_field of GAUC and session_fields of SessionAUC if self._pipeline_config is not None: @@ -234,6 +237,7 @@ def get_feature_input_fields(self): return [ x for x in self._input_fields if x not in self._label_fields and x != self._data_config.sample_weight + and x not in self._label_dynamic_weight ] def should_stop(self, curr_epoch): @@ -269,13 +273,14 @@ def create_multi_placeholders(self, export_config): effective_fids = [ fid for fid in range(len(self._input_fields)) if self._input_fields[fid] not in self._label_fields and + self._input_fields[fid] not in self._label_dynamic_weight and self._input_fields[fid] != sample_weight_field ] inputs = {} for fid in effective_fids: input_name = self._input_fields[fid] - if input_name == sample_weight_field: + if input_name == sample_weight_field or input_name in self._label_dynamic_weight: continue if placeholder_named_by_input: placeholder_name = input_name @@ -318,6 +323,7 @@ def create_placeholders(self, export_config): effective_fids = [ fid for fid in range(len(self._input_fields)) if self._input_fields[fid] not in self._label_fields and + self._input_fields[fid] not in self._label_dynamic_weight and self._input_fields[fid] != sample_weight_field ] logging.info( @@ -330,6 
+336,8 @@ def create_placeholders(self, export_config): ftype = self._input_field_types[fid] tf_type = get_tf_type(ftype) input_name = self._input_fields[fid] + if input_name in self._label_dynamic_weight: + continue if tf_type in [tf.float32, tf.double, tf.int32, tf.int64]: features[input_name] = tf.string_to_number( input_vals[:, tmp_id], @@ -925,6 +933,14 @@ def _preprocess(self, field_dict): if self._mode != tf.estimator.ModeKeys.PREDICT: parsed_dict[constant.SAMPLE_WEIGHT] = field_dict[ self._data_config.sample_weight] + if len(self._label_dynamic_weight + ) > 0 and self._mode != tf.estimator.ModeKeys.PREDICT: + for label_weight in self._label_dynamic_weight: + if field_dict[label_weight].dtype == tf.float32: + parsed_dict[label_weight] = field_dict[label_weight] + else: + parsed_dict[label_weight] = tf.cast( + field_dict[label_weight], dtype=tf.float64) if Input.DATA_OFFSET in field_dict: parsed_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET] diff --git a/easy_rec/python/model/multi_task_model.py b/easy_rec/python/model/multi_task_model.py index cff58e079..f35148a65 100644 --- a/easy_rec/python/model/multi_task_model.py +++ b/easy_rec/python/model/multi_task_model.py @@ -198,6 +198,8 @@ def build_loss_graph(self): for task_tower_cfg in self._task_towers: tower_name = task_tower_cfg.tower_name loss_weight = task_tower_cfg.weight + if task_tower_cfg.HasField('dynamic_weight'): + loss_weight *= self._feature_dict[task_tower_cfg.dynamic_weight] if task_tower_cfg.use_sample_weight: loss_weight *= self._sample_weight diff --git a/easy_rec/python/protos/dataset.proto b/easy_rec/python/protos/dataset.proto index 5ffefd064..2b64da65d 100644 --- a/easy_rec/python/protos/dataset.proto +++ b/easy_rec/python/protos/dataset.proto @@ -292,6 +292,8 @@ message DatasetConfig { // input field for sample weight optional string sample_weight = 22; + // input field for label dynamic weight + repeated string label_dynamic_weight = 27; // the compression type of tfrecord 
optional string data_compression_type = 23 [default = '']; diff --git a/easy_rec/python/protos/tower.proto b/easy_rec/python/protos/tower.proto index 580708825..3cd6f6253 100644 --- a/easy_rec/python/protos/tower.proto +++ b/easy_rec/python/protos/tower.proto @@ -26,6 +26,8 @@ message TaskTower { optional DNN dnn = 6; // training loss weights optional float weight = 7 [default = 1.0]; + // training loss label dynamic weights + optional string dynamic_weight = 8; // label name for indicating the sample space for the task tower optional string task_space_indicator_label = 10; // the loss weight for sample in the task space @@ -72,4 +74,6 @@ message BayesTaskTower { repeated Loss losses = 15; // whether to use sample weight in this tower required bool use_sample_weight = 16 [default = true]; + // training loss label dynamic weights + optional string dynamic_weight = 17; }; diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py index 73f05836d..f689dcd01 100644 --- a/easy_rec/python/test/train_eval_test.py +++ b/easy_rec/python/test/train_eval_test.py @@ -937,12 +937,24 @@ def test_sequence_esmm(self): self._test_dir) self.assertTrue(self._success) + def test_label_dynamic_weight_esmm(self): + self._success = test_utils.test_single_train_eval( + 'samples/model_config/esmm_on_label_dynamic_weight_feature_taobao.config', + self._test_dir) + self.assertTrue(self._success) + def test_sequence_mmoe(self): self._success = test_utils.test_single_train_eval( 'samples/model_config/mmoe_on_sequence_feature_taobao.config', self._test_dir) self.assertTrue(self._success) + def test_label_dynamic_weight_sequence_mmoe(self): + self._success = test_utils.test_single_train_eval( + 'samples/model_config/mmoe_on_label_dynamic_weight_sequence_feature_taobao.config', + self._test_dir) + self.assertTrue(self._success) + def test_sequence_ple(self): self._success = test_utils.test_single_train_eval( 
'samples/model_config/ple_on_sequence_feature_taobao.config', diff --git a/samples/model_config/esmm_on_label_dynamic_weight_feature_taobao.config b/samples/model_config/esmm_on_label_dynamic_weight_feature_taobao.config new file mode 100644 index 000000000..be13e0d95 --- /dev/null +++ b/samples/model_config/esmm_on_label_dynamic_weight_feature_taobao.config @@ -0,0 +1,319 @@ +train_input_path: "data/test/tb_data/taobao_train_data_label_dynamic_weight" +eval_input_path: "data/test/tb_data/taobao_test_data_label_dynamic_weight" +model_dir: "experiments/esmm_taobao_ckpt" + +train_config { + log_step_count_steps: 100 + optimizer_config: { + adam_optimizer: { + learning_rate: { + exponential_decay_learning_rate { + initial_learning_rate: 0.001 + decay_steps: 1000 + decay_factor: 0.5 + min_learning_rate: 0.0000001 + } + } + } + use_moving_average: false + } + save_checkpoints_steps: 100 + sync_replicas: True + num_steps: 100 +} + +eval_config { + metrics_set: { + auc {} + } +} + +data_config { + input_fields { + input_name:'clk' + input_type: INT32 + } + input_fields { + input_name:'buy' + input_type: INT32 + } + input_fields { + input_name: 'pid' + input_type: STRING + } + input_fields { + input_name: 'adgroup_id' + input_type: STRING + } + input_fields { + input_name: 'cate_id' + input_type: STRING + } + input_fields { + input_name: 'campaign_id' + input_type: STRING + } + input_fields { + input_name: 'customer' + input_type: STRING + } + input_fields { + input_name: 'brand' + input_type: STRING + } + input_fields { + input_name: 'user_id' + input_type: STRING + } + input_fields { + input_name: 'cms_segid' + input_type: STRING + } + input_fields { + input_name: 'cms_group_id' + input_type: STRING + } + input_fields { + input_name: 'final_gender_code' + input_type: STRING + } + input_fields { + input_name: 'age_level' + input_type: STRING + } + input_fields { + input_name: 'pvalue_level' + input_type: STRING + } + input_fields { + input_name: 'shopping_level' + 
input_type: STRING + } + input_fields { + input_name: 'occupation' + input_type: STRING + } + input_fields { + input_name: 'new_user_class_level' + input_type: STRING + } + input_fields { + input_name: 'tag_category_list' + input_type: STRING + } + input_fields { + input_name: 'tag_brand_list' + input_type: STRING + } + input_fields { + input_name: 'price' + input_type: INT32 + } + input_fields { + input_name: "clk_w" + input_type: INT32 + } + input_fields { + input_name: "buy_w" + input_type: FLOAT + } + label_fields: 'buy' + label_fields: 'clk' + label_dynamic_weight: "clk_w" + label_dynamic_weight: "buy_w" + batch_size: 4096 + num_epochs: 10000 + prefetch_size: 32 + input_type: CSVInput +} + +feature_configs : { + input_names: 'pid' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs : { + input_names: 'adgroup_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs : { + input_names: 'cate_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10000 +} +feature_configs : { + input_names: 'campaign_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs : { + input_names: 'customer' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs : { + input_names: 'brand' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs : { + input_names: 'user_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs : { + input_names: 'cms_segid' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 +} +feature_configs : { + input_names: 'cms_group_id' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 +} +feature_configs : { + input_names: 'final_gender_code' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs : { + input_names: 'age_level' + feature_type: 
IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs : { + input_names: 'pvalue_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs : { + input_names: 'shopping_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs : { + input_names: 'occupation' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs : { + input_names: 'new_user_class_level' + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs : { + input_names: 'tag_category_list' + feature_type: SequenceFeature + separator: '|' + hash_bucket_size: 100000 + embedding_dim: 16 +} +feature_configs : { + input_names: 'tag_brand_list' + feature_type: SequenceFeature + separator: '|' + hash_bucket_size: 100000 + embedding_dim: 16 +} +feature_configs : { + input_names: 'price' + feature_type: IdFeature + embedding_dim: 16 + num_buckets: 50 +} +model_config: { + model_class: 'ESMM' + feature_groups: { + group_name: 'user' + feature_names: 'user_id' + feature_names: 'cms_segid' + feature_names: 'cms_group_id' + feature_names: 'age_level' + feature_names: 'pvalue_level' + feature_names: 'shopping_level' + feature_names: 'occupation' + feature_names: 'new_user_class_level' + wide_deep: DEEP + sequence_features: { + group_name: "seq_fea" + tf_summary: false + allow_key_search:true + seq_att_map: { + key: "brand" + key: "cate_id" + hist_seq: "tag_brand_list" + hist_seq: "tag_category_list" + } + } + } + feature_groups: { + group_name: 'item' + feature_names: 'adgroup_id' + feature_names: 'cate_id' + feature_names: 'campaign_id' + feature_names: 'customer' + feature_names: 'brand' + feature_names: 'price' + wide_deep: DEEP + } + esmm { + groups { + input: "user" + dnn { + hidden_units: [256, 128, 96, 64] + } + } + groups { + input: "item" + dnn { + hidden_units: [256, 128, 96, 64] + } + } + cvr_tower { + tower_name: "cvr" + label_name: "buy" + 
dynamic_weight: "buy_w" + dnn { + hidden_units: [128, 96, 64, 32, 16] + } + num_class: 1 + weight: 1.0 + loss_type: CLASSIFICATION + metrics_set: { + auc {} + } + } + ctr_tower { + tower_name: "ctr" + label_name: "clk" + dynamic_weight: "clk_w" + dnn { + hidden_units: [128, 96, 64, 32, 16] + } + num_class: 1 + weight: 1.0 + loss_type: CLASSIFICATION + metrics_set: { + auc {} + } + } + l2_regularization: 1e-6 + } + embedding_regularization: 5e-5 +} diff --git a/samples/model_config/mmoe_on_label_dynamic_weight_sequence_feature_taobao.config b/samples/model_config/mmoe_on_label_dynamic_weight_sequence_feature_taobao.config new file mode 100644 index 000000000..9226cddb7 --- /dev/null +++ b/samples/model_config/mmoe_on_label_dynamic_weight_sequence_feature_taobao.config @@ -0,0 +1,304 @@ +train_input_path: "data/test/tb_data/taobao_train_data_label_dynamic_weight" +eval_input_path: "data/test/tb_data/taobao_test_data_label_dynamic_weight" +model_dir: "experiments/mmoe_taobao_ckpt" + +train_config { + optimizer_config { + adam_optimizer { + learning_rate { + exponential_decay_learning_rate { + initial_learning_rate: 0.001 + decay_steps: 1000 + decay_factor: 0.5 + min_learning_rate: 1e-07 + } + } + } + use_moving_average: false + } + num_steps: 5000 + sync_replicas: true + save_checkpoints_steps: 100 + log_step_count_steps: 100 +} +eval_config { + metrics_set { + auc { + } + } +} +data_config { + batch_size: 4096 + label_fields: "clk" + label_fields: "buy" + label_dynamic_weight: "clk_w" + label_dynamic_weight: "buy_w" + prefetch_size: 32 + input_type: CSVInput + input_fields { + input_name: "clk" + input_type: INT32 + } + input_fields { + input_name: "buy" + input_type: INT32 + } + input_fields { + input_name: "pid" + input_type: STRING + } + input_fields { + input_name: "adgroup_id" + input_type: STRING + } + input_fields { + input_name: "cate_id" + input_type: STRING + } + input_fields { + input_name: "campaign_id" + input_type: STRING + } + input_fields { + 
input_name: "customer" + input_type: STRING + } + input_fields { + input_name: "brand" + input_type: STRING + } + input_fields { + input_name: "user_id" + input_type: STRING + } + input_fields { + input_name: "cms_segid" + input_type: STRING + } + input_fields { + input_name: "cms_group_id" + input_type: STRING + } + input_fields { + input_name: "final_gender_code" + input_type: STRING + } + input_fields { + input_name: "age_level" + input_type: STRING + } + input_fields { + input_name: "pvalue_level" + input_type: STRING + } + input_fields { + input_name: "shopping_level" + input_type: STRING + } + input_fields { + input_name: "occupation" + input_type: STRING + } + input_fields { + input_name: "new_user_class_level" + input_type: STRING + } + input_fields { + input_name: "tag_category_list" + input_type: STRING + } + input_fields { + input_name: "tag_brand_list" + input_type: STRING + } + input_fields { + input_name: "price" + input_type: INT32 + } + input_fields { + input_name: "clk_w" + input_type: INT32 + } + input_fields { + input_name: "buy_w" + input_type: FLOAT + } +} +feature_configs { + input_names: "pid" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs { + input_names: "adgroup_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs { + input_names: "cate_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10000 +} +feature_configs { + input_names: "campaign_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs { + input_names: "customer" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs { + input_names: "brand" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs { + input_names: "user_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100000 +} +feature_configs { + input_names: "cms_segid" + feature_type: 
IdFeature + embedding_dim: 16 + hash_bucket_size: 100 +} +feature_configs { + input_names: "cms_group_id" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 100 +} +feature_configs { + input_names: "final_gender_code" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs { + input_names: "age_level" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs { + input_names: "pvalue_level" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs { + input_names: "shopping_level" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs { + input_names: "occupation" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs { + input_names: "new_user_class_level" + feature_type: IdFeature + embedding_dim: 16 + hash_bucket_size: 10 +} +feature_configs { + input_names: "tag_category_list" + feature_type: SequenceFeature + embedding_dim: 16 + hash_bucket_size: 100000 + separator: "|" +} +feature_configs { + input_names: "tag_brand_list" + feature_type: SequenceFeature + embedding_dim: 16 + hash_bucket_size: 100000 + separator: "|" +} +feature_configs { + input_names: "price" + feature_type: IdFeature + embedding_dim: 16 + num_buckets: 50 +} +model_config { + model_class: "MMoE" + feature_groups { + group_name: "all" + feature_names: "user_id" + feature_names: "cms_segid" + feature_names: "cms_group_id" + feature_names: "age_level" + feature_names: "pvalue_level" + feature_names: "shopping_level" + feature_names: "occupation" + feature_names: "new_user_class_level" + feature_names: "adgroup_id" + feature_names: "cate_id" + feature_names: "campaign_id" + feature_names: "customer" + feature_names: "brand" + feature_names: "price" + feature_names: "pid" + wide_deep: DEEP + sequence_features: { + group_name: "seq_fea" + tf_summary: false + seq_att_map: { + key: "brand" + key: "cate_id" + hist_seq: 
"tag_brand_list" + hist_seq: "tag_category_list" + } + } + } + mmoe { + expert_dnn { + hidden_units: [256, 192, 128, 64] + } + num_expert: 4 + task_towers { + tower_name: "ctr" + label_name: "clk" + dynamic_weight: "clk_w" + dnn { + hidden_units: [256, 192, 128, 64] + } + num_class: 1 + weight: 1.0 + loss_type: CLASSIFICATION + metrics_set: { + auc {} + } + } + task_towers { + tower_name: "cvr" + label_name: "buy" + dynamic_weight: "buy_w" + dnn { + hidden_units: [256, 192, 128, 64] + } + num_class: 1 + weight: 1.0 + loss_type: CLASSIFICATION + metrics_set: { + auc {} + } + } + l2_regularization: 1e-06 + } + embedding_regularization: 5e-05 +} From f2c04832d626c17ab25a8f4b9e22a15c3ef7447c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E9=AB=98=E9=A3=9E?= <1052924341@qq.com> Date: Thu, 5 Sep 2024 16:47:15 +0800 Subject: [PATCH 2/7] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E7=9B=AE=E6=A0=87loss?= =?UTF-8?q?=E5=8A=A8=E6=80=81=E6=9D=83=E9=87=8D=E5=AF=B9=E5=BA=94=E6=96=87?= =?UTF-8?q?=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/models/loss.md | 52 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/docs/source/models/loss.md b/docs/source/models/loss.md index 1fd13e6ab..f1246299f 100644 --- a/docs/source/models/loss.md +++ b/docs/source/models/loss.md @@ -155,6 +155,58 @@ EasyRec支持两种损失函数配置方式:1)使用单个损失函数;2 - loss_weight_strategy: Random - 表示损失函数的权重设定为归一化的随机数 +### Loss动态权重 + +在多目标学习任务中,我们经常遇到给与不同目标设置不同的权重,甚至不同目标的loss权重会随着样本而动态变化。EasyRec支持用户为不同目标设置动态的权重。 + +- 1.首先在dataset中配置动态权重字段的名称,同时增加这些字段的input_config,如下示例 + +```protobuf +data_config { + batch_size: 4096 + label_fields: "clk" + label_fields: "buy" + label_dynamic_weight: "clk_weight" + label_dynamic_weight: "buy_weight" + prefetch_size: 32 + input_type: CSVInput + input_fields { + input_name: "clk" + input_type: INT32 + } + input_fields { + input_name: "buy" + input_type: INT32 + } + input_fields { + input_name: "clk_weight" 
+ input_type: DOUBLE + } + input_fields { + input_name: "buy_weight" + input_type: DOUBLE + } +} +``` + +- 2.需要在对应的任务tower中设置对应的权重列,如下示例 + +```protobuf +task_towers { + tower_name: "ctr" + label_name: "clk" + dnn { + hidden_units: [256, 192, 128, 64] + } + num_class: 1 + dynamic_weight: "clk_weight" + loss_type: CLASSIFICATION + metrics_set: { + auc {} + } +} +``` + ### 参考论文: - 《 Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics 》 From 0605025effaecd21c3eafeff1f234c82141d0e11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E9=AB=98=E9=A3=9E?= Date: Thu, 5 Sep 2024 17:57:50 +0800 Subject: [PATCH 3/7] =?UTF-8?q?=E8=A7=A3=E5=86=B2=E5=86=B2=E7=AA=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/quick_start/mc_tutorial.md | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/docs/source/quick_start/mc_tutorial.md b/docs/source/quick_start/mc_tutorial.md index ea6a65723..0f6065f0c 100644 --- a/docs/source/quick_start/mc_tutorial.md +++ b/docs/source/quick_start/mc_tutorial.md @@ -4,14 +4,18 @@ 针对阿里集团内部用户,请参考[mc_tutorial_inner](mc_tutorial_inner.md)。 +有技术问题可加钉钉群:37930014162 + ### 输入数据: -输入一般是odps表: +输入一般是MaxCompute表: - train: pai_online_project.dwd_avazu_ctr_deepmodel_train -- test: pai_online_project.dwd_avazu_ctr_deepmodel_test +- test: pai_online_project.dwd_avazu_ctr_deepmodel_test + +说明:原则上这两张表是自己odps的表,为了方便,以上提供case的两张表可在国内用户的MaxCompute项目空间中访问。 -说明:原则上这两张表是自己odps的表,为了方便,以上提供case的两张表在任何地方都可以访问。两个表可以带分区,也可以不带分区。 +两个表可以带分区,也可以不带分区。带分区的方式:odps://xyz_project/table1/dt=20240101 ### 训练: @@ -24,7 +28,7 @@ pai -name easy_rec_ext -project algo_public -Dconfig=oss://easyrec/config/MultiTower/dwd_avazu_ctr_deepmodel_ext.config -Dtrain_tables='odps://pai_online_project/tables/dwd_avazu_ctr_deepmodel_train' -Deval_tables='odps://pai_online_project/tables/dwd_avazu_ctr_deepmodel_test' --Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : 
{"count":3, "cpu":1000, "gpu":100, "memory":40000}}' +-Dcluster='{"ps":{"count":1, "cpu":1000}, "worker" : {"count":3, "cpu":1000, "gpu":0, "memory":40000}}' -Deval_method=separate -Dmodel_dir=oss://easyrec/ckpt/MultiTower -Darn=acs:ram::xxx:role/xxx @@ -32,26 +36,26 @@ pai -name easy_rec_ext -project algo_public -DossHost=oss-cn-beijing-internal.aliyuncs.com; ``` -- -Dcmd: train 模型训练 +- -Dcmd: train 表示模型训练 - -Dconfig: 训练用的配置文件 - -Dtrain_tables: 定义训练表 -- -Deval_tables: 定义测试表 +- -Deval_tables: 定义评估表 - -Dtables: 定义其他依赖表(可选),如负采样的表 - -Dcluster: 定义PS的数目和worker的数目。具体见:[PAI-TF任务参数介绍](https://help.aliyun.com/document_detail/154186.html?spm=a2c4g.11186623.4.3.e56f1adb7AJ9T5) - -Deval_method: 评估方法 -- separate: 用worker(task_id=1)做评估。点击训练的logview中worker#1_0的stderr,出现类似字段"Saving dict for global step 3949: auc = 0.7643898, global_step = 3949, loss = 0.38898173, loss/loss/cross_entropy_loss = 0.38898173, loss/loss/total_loss = 0.38898173"即是评估指标 +- separate: 用worker(task_id=1)做评估。找到MaxCompute训练任务的logview,打开logview之后在worker1机器的stderr日志中查看评估指标数据。 - none: 不需要评估 - master: 在master(task_id=0)上做评估 - -Dfine_tune_checkpoint: 可选,从checkpoint restore参数,进行finetune - 可以指定directory,将使用directory里面的最新的checkpoint. 
- -Dmodel_dir: 如果指定了model_dir将会覆盖config里面的model_dir,一般在周期性调度的时候使用。 -- -Darn: rolearn 注意这个的arn要替换成客户自己的。可以从dataworks的设置中查看arn。 +- -Darn: rolearn 注意这个的arn要替换成客户自己的。可以从dataworks的设置中查看arn;或者阿里云控制台人工智能平台PAI,左侧菜单"开通和授权",找到全部云产品依赖->Designer->OSS->查看授权信息。 - -Dbuckets: config所在的bucket和保存模型的bucket; 如果有多个bucket,逗号分割 - -DossHost: ossHost地址 ### 注意: -- dataworks和pai的project 一样,案例都是pai_online_project,用户需要根据自己的环境修改。如果需要使用gpu,PAI的project需要设置开通GPU。链接:[https://pai.data.aliyun.com/console?projectId=®ionId=cn-beijing#/visual](https://pai.data.aliyun.com/console?projectId=%C2%AEionId=cn-beijing#/visual) ,其中regionId可能不一致。 +- dataworks和PAI的project一样,案例都是pai_online_project,用户需要根据自己的环境修改。如果需要使用gpu,PAI的project需要设置开通GPU。链接:[https://pai.data.aliyun.com/console?projectId=®ionId=cn-beijing#/visual](https://pai.data.aliyun.com/console?projectId=%C2%AEionId=cn-beijing#/visual) ,其中regionId可能不一致。 ![mc_gpu](../../images/quick_start/mc_gpu.png) @@ -68,7 +72,7 @@ pai -name easy_rec_ext -project algo_public -Dcmd=evaluate -Dconfig=oss://easyrec/config/MultiTower/dwd_avazu_ctr_deepmodel_ext.config -Deval_tables='odps://pai_online_project/tables/dwd_avazu_ctr_deepmodel_test' --Dcluster='{"worker" : {"count":1, "cpu":1000, "gpu":100, "memory":40000}}' +-Dcluster='{"worker" : {"count":1, "cpu":1000, "gpu":0, "memory":40000}}' -Dmodel_dir=oss://easyrec/ckpt/MultiTower -Darn=acs:ram::xxx:role/xxx -Dbuckets=oss://easyrec/ From 1c08c303b09f2de2e91f05d63ca125d0aaa1b27a Mon Sep 17 00:00:00 2001 From: chengaofei <52209156+chengaofei@users.noreply.github.com> Date: Thu, 5 Sep 2024 18:47:49 +0800 Subject: [PATCH 4/7] Delete setup.cfg --- setup.cfg | 34 ---------------------------------- 1 file changed, 34 deletions(-) delete mode 100644 setup.cfg diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 337833a0f..000000000 --- a/setup.cfg +++ /dev/null @@ -1,34 +0,0 @@ -[easy_install] -index_url = http://mirrors.aliyun.com/pypi/simple/ - -[bdist_wheel] -universal = 1 - -[isort] -line_length = 79 
-multi_line_output = 7 -force_single_line = true -known_standard_library = setuptools -known_first_party = easy_rec -known_third_party = absl,common_io,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml -no_lines_before = LOCALFOLDER -default_section = THIRDPARTY -skip = easy_rec/python/protos - -[yapf] -BASED_ON_STYLE = yapf -ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT = true - -[flake8] -select = B,C,D,E,F,P,T4,W,B9 -max-line-length = 120 -ignore = - E111,E114,E125,E129,W291,W503,W504, - # docstring missing error should be used when all docstrings are completed - D100,D101,D102,D103,D104,D105,D106,D107 -per-file-ignores = - __init__.py: F401 - easy_rec/python/utils/test_utils.py: E402 - easy_rec/python/utils/io_util.py: E402 -exclude = docs/src,scripts,easy_rec/python/protos,*.pyi,.git -docstring-convention = google From 1d42a81c2c25cb8936f6026e9e33fc91e1bf49cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E9=AB=98=E9=A3=9E?= Date: Thu, 5 Sep 2024 18:55:44 +0800 Subject: [PATCH 5/7] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E5=86=B2=E7=AA=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 setup.cfg diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 000000000..337833a0f --- /dev/null +++ b/setup.cfg @@ -0,0 +1,34 @@ +[easy_install] +index_url = http://mirrors.aliyun.com/pypi/simple/ + +[bdist_wheel] +universal = 1 + +[isort] +line_length = 79 +multi_line_output = 7 +force_single_line = true +known_standard_library = setuptools +known_first_party = easy_rec +known_third_party = 
absl,common_io,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml +no_lines_before = LOCALFOLDER +default_section = THIRDPARTY +skip = easy_rec/python/protos + +[yapf] +BASED_ON_STYLE = yapf +ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT = true + +[flake8] +select = B,C,D,E,F,P,T4,W,B9 +max-line-length = 120 +ignore = + E111,E114,E125,E129,W291,W503,W504, + # docstring missing error should be used when all docstrings are completed + D100,D101,D102,D103,D104,D105,D106,D107 +per-file-ignores = + __init__.py: F401 + easy_rec/python/utils/test_utils.py: E402 + easy_rec/python/utils/io_util.py: E402 +exclude = docs/src,scripts,easy_rec/python/protos,*.pyi,.git +docstring-convention = google From 1773f8f8132bf947a52e73688c8ceeead15f5454 Mon Sep 17 00:00:00 2001 From: chengaofei <52209156+chengaofei@users.noreply.github.com> Date: Thu, 5 Sep 2024 18:57:42 +0800 Subject: [PATCH 6/7] Update setup.cfg MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 代码冲突异常 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 337833a0f..b43211827 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,7 +10,7 @@ multi_line_output = 7 force_single_line = true known_standard_library = setuptools known_first_party = easy_rec -known_third_party = absl,common_io,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml +known_third_party = absl,common_io,distutils,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml no_lines_before = LOCALFOLDER default_section = THIRDPARTY skip = easy_rec/python/protos From 
2bad4c807bb0f9b5ad3b37e40846d2420288ab17 Mon Sep 17 00:00:00 2001 From: chengaofei <52209156+chengaofei@users.noreply.github.com> Date: Thu, 5 Sep 2024 18:58:32 +0800 Subject: [PATCH 7/7] Update early_stopping.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 解决代码冲突引发的异常 --- easy_rec/python/compat/early_stopping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/easy_rec/python/compat/early_stopping.py b/easy_rec/python/compat/early_stopping.py index fe4c12132..fc850fb62 100644 --- a/easy_rec/python/compat/early_stopping.py +++ b/easy_rec/python/compat/early_stopping.py @@ -21,9 +21,9 @@ import os import threading import time -from distutils.version import LooseVersion import tensorflow as tf +from distutils.version import LooseVersion from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import init_ops