Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add_dynamic_weight_for_muti_label #469

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
52 changes: 52 additions & 0 deletions docs/source/models/loss.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,58 @@ EasyRec支持两种损失函数配置方式:1)使用单个损失函数;2
- loss_weight_strategy: Random
- 表示损失函数的权重设定为归一化的随机数

### Loss动态权重

在多目标学习任务中,我们经常遇到给与不同目标设置不同的权重,甚至不同目标的loss权重会随着样本而动态变化。EasyRec支持用户为不同目标设置动态的权重。

- 1.首先在dataset中配置动态权重字段的名称,同时增加这些字段的input_config,如下示例

```protobuf
data_config {
batch_size: 4096
label_fields: "clk"
label_fields: "buy"
label_dynamic_weight: "clk_weight"
label_dynamic_weight: "buy_weight"
prefetch_size: 32
input_type: CSVInput
input_fields {
input_name: "clk"
input_type: INT32
}
input_fields {
input_name: "buy"
input_type: INT32
}
input_fields {
input_name: "clk_weight"
input_type: double
}
input_fields {
input_name: "buy_weight"
input_type: double
}
}
```

- 2.需要在对应的任务tower中设置对应的权重列,如下示例

```protobuf
task_towers {
tower_name: "ctr"
label_name: "clk"
dnn {
hidden_units: [256, 192, 128, 64]
}
num_class: 1
dynamic_weight: "clk_weight"
loss_type: CLASSIFICATION
metrics_set: {
auc {}
}
}
```

### 参考论文:

- 《 Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics 》
Expand Down
18 changes: 17 additions & 1 deletion easy_rec/python/input/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(self,
x.default_val for x in data_config.input_fields
]
self._label_fields = list(data_config.label_fields)
self._label_dynamic_weight = list(data_config.label_dynamic_weight)
self._feature_fields = list(data_config.feature_fields)
self._label_sep = list(data_config.label_sep)
self._label_dim = list(data_config.label_dim)
Expand Down Expand Up @@ -139,6 +140,8 @@ def __init__(self,
# add sample weight to effective fields
if self._data_config.HasField('sample_weight'):
self._effective_fields.append(self._data_config.sample_weight)
if len(self._label_dynamic_weight) > 0:
self._effective_fields.extend(self._label_dynamic_weight)

# add uid_field of GAUC and session_fields of SessionAUC
if self._pipeline_config is not None:
Expand Down Expand Up @@ -234,6 +237,7 @@ def get_feature_input_fields(self):
return [
x for x in self._input_fields
if x not in self._label_fields and x != self._data_config.sample_weight
and x not in self._label_dynamic_weight
]

def should_stop(self, curr_epoch):
Expand Down Expand Up @@ -269,13 +273,14 @@ def create_multi_placeholders(self, export_config):
effective_fids = [
fid for fid in range(len(self._input_fields))
if self._input_fields[fid] not in self._label_fields and
self._input_fields[fid] not in self._label_dynamic_weight and
self._input_fields[fid] != sample_weight_field
]

inputs = {}
for fid in effective_fids:
input_name = self._input_fields[fid]
if input_name == sample_weight_field:
if input_name == sample_weight_field or input_name in self._label_dynamic_weight:
continue
if placeholder_named_by_input:
placeholder_name = input_name
Expand Down Expand Up @@ -318,6 +323,7 @@ def create_placeholders(self, export_config):
effective_fids = [
fid for fid in range(len(self._input_fields))
if self._input_fields[fid] not in self._label_fields and
self._input_fields[fid] not in self._label_dynamic_weight and
self._input_fields[fid] != sample_weight_field
]
logging.info(
Expand All @@ -330,6 +336,8 @@ def create_placeholders(self, export_config):
ftype = self._input_field_types[fid]
tf_type = get_tf_type(ftype)
input_name = self._input_fields[fid]
if input_name in self._label_dynamic_weight:
continue
if tf_type in [tf.float32, tf.double, tf.int32, tf.int64]:
features[input_name] = tf.string_to_number(
input_vals[:, tmp_id],
Expand Down Expand Up @@ -925,6 +933,14 @@ def _preprocess(self, field_dict):
if self._mode != tf.estimator.ModeKeys.PREDICT:
parsed_dict[constant.SAMPLE_WEIGHT] = field_dict[
self._data_config.sample_weight]
if len(self._label_dynamic_weight
) > 0 and self._mode != tf.estimator.ModeKeys.PREDICT:
for label_weight in self._label_dynamic_weight:
if field_dict[label_weight].dtype == tf.float32:
parsed_dict[label_weight] = field_dict[label_weight]
else:
parsed_dict[label_weight] = tf.cast(
field_dict[label_weight], dtype=tf.float64)

if Input.DATA_OFFSET in field_dict:
parsed_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]
Expand Down
2 changes: 2 additions & 0 deletions easy_rec/python/model/multi_task_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ def build_loss_graph(self):
for task_tower_cfg in self._task_towers:
tower_name = task_tower_cfg.tower_name
loss_weight = task_tower_cfg.weight
if task_tower_cfg.HasField('dynamic_weight'):
loss_weight *= self._feature_dict[task_tower_cfg.dynamic_weight]
if task_tower_cfg.use_sample_weight:
loss_weight *= self._sample_weight

Expand Down
2 changes: 2 additions & 0 deletions easy_rec/python/protos/dataset.proto
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ message DatasetConfig {

// input field for sample weight
optional string sample_weight = 22;
// input field for label dynimic weight
repeated string label_dynamic_weight = 27;
// the compression type of tfrecord
optional string data_compression_type = 23 [default = ''];

Expand Down
5 changes: 5 additions & 0 deletions easy_rec/python/protos/tower.proto
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ message TaskTower {
optional DNN dnn = 6;
// training loss weights
optional float weight = 7 [default = 1.0];
// training loss label dynamic weights
optional string dynamic_weight = 8;
// label name for indicating the sample space for the task tower
optional string task_space_indicator_label = 10;
// the loss weight for sample in the task space
Expand Down Expand Up @@ -76,4 +78,7 @@ message BayesTaskTower {
optional bool use_ait_module = 17 [default = false];
// set this when the dimensions of last layer of towers are not equal
optional uint32 ait_project_dim = 18;
// training loss label dynamic weights
optional string dynamic_weight = 19;

};
14 changes: 13 additions & 1 deletion easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import threading
import time
import unittest
from distutils.version import LooseVersion

import numpy as np
import six
import tensorflow as tf
from distutils.version import LooseVersion
from tensorflow.python.platform import gfile

from easy_rec.python.main import predict
Expand Down Expand Up @@ -942,12 +942,24 @@ def test_sequence_esmm(self):
self._test_dir)
self.assertTrue(self._success)

def test_label_dynamic_weight_esmm(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/esmm_on_label_dynamic_weight_feature_taobao.config',
self._test_dir)
self.assertTrue(self._success)

def test_sequence_mmoe(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/mmoe_on_sequence_feature_taobao.config',
self._test_dir)
self.assertTrue(self._success)

def test_label_dynamic_weight_sequence_mmoe(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/mmoe_on_label_dynamic_weight_sequence_feature_taobao.config',
self._test_dir)
self.assertTrue(self._success)

def test_sequence_ple(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/ple_on_sequence_feature_taobao.config',
Expand Down
Loading
Loading