[feature] update mind to support multiple behavior sequences #323

Open
wants to merge 7 commits into base: master
7 changes: 5 additions & 2 deletions easy_rec/python/input/datahub_input.py
@@ -22,6 +22,7 @@
from datahub.exceptions import DatahubException
from datahub.models import RecordType
from datahub.models import CursorType
from datahub.models.shard import ShardState
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
logging.getLogger('datahub.account').setLevel(logging.INFO)
@@ -70,12 +71,14 @@ def __init__(self,
if datahub_config:
shard_result = self._datahub.list_shard(self._datahub_config.project,
self._datahub_config.topic)
shards = shard_result.shards
shards = [x for x in shard_result.shards if x.state == ShardState.ACTIVE]
self._all_shards = shards
self._shards = [
shards[i] for i in range(len(shards)) if (i % task_num) == task_index
]
logging.info('all shards: %s' % str(self._shards))
logging.info('all_shards[len=%d]: %s task_shards[len=%d]: %s' %
(len(self._all_shards), str(
self._all_shards), len(self._shards), str(self._shards)))

offset_type = datahub_config.WhichOneof('offset')
if offset_type == 'offset_time':
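The change above keeps only shards whose state is ACTIVE and splits them across workers round-robin by task index. A minimal standalone sketch of that partitioning logic; the Shard tuple below is a stand-in for the DataHub SDK shard object, not the real class:

from collections import namedtuple

Shard = namedtuple('Shard', ['shard_id', 'state'])

def select_task_shards(all_shards, task_num, task_index):
  # keep only shards that can still be read from
  active = [s for s in all_shards if s.state == 'ACTIVE']
  # round-robin split: worker k reads shards k, k + task_num, k + 2 * task_num, ...
  return [active[i] for i in range(len(active)) if i % task_num == task_index]

shards = [Shard('0', 'ACTIVE'), Shard('1', 'CLOSED'),
          Shard('2', 'ACTIVE'), Shard('3', 'ACTIVE')]
print(select_task_shards(shards, task_num=2, task_index=0))  # shards '0' and '3'
print(select_task_shards(shards, task_num=2, task_index=1))  # shard '2'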
17 changes: 11 additions & 6 deletions easy_rec/python/layers/input_layer.py
@@ -35,9 +35,12 @@ def __init__(self,
embedding_regularizer=None,
kernel_regularizer=None,
is_training=False):
self._feature_groups = {
x.group_name: FeatureGroup(x) for x in feature_groups_config
}
self._feature_groups = {}
for x in feature_groups_config:
assert x.group_name not in self._feature_groups, 'feature_group name(%s) is repeated'\
% x.group_name
self._feature_groups[x.group_name] = FeatureGroup(x)

self.sequence_feature_layer = sequence_feature_layer.SequenceFeatureLayer(
feature_configs, feature_groups_config, ev_params,
embedding_regularizer, kernel_regularizer, is_training)
@@ -79,7 +82,8 @@ def __call__(self, features, group_name, is_combine=True, is_dict=False):
is_combine: True
features: all features concatenate together
group_features: list of features
feature_name_to_output_tensors: dict, feature_name to feature_value, only present when is_dict is True
feature_name_to_output_tensors: dict, feature_name to feature_value,
only present when is_dict is True
is_combine: False
seq_features: list of sequence features, each element is a tuple:
3 dimension embedding tensor (batch_size, max_seq_len, embedding_dimension),
@@ -127,8 +131,9 @@ def __call__(self, features, group_name, is_combine=True, is_dict=False):
seq_features = []
embedding_reg_lst = []
for fc in group_seq_columns:
with variable_scope.variable_scope('input_layer/' +
fc.categorical_column.name):
with variable_scope.variable_scope(
group_name + '/' + fc.categorical_column.name,
reuse=variable_scope.AUTO_REUSE):
tmp_embedding, tmp_seq_len = fc._get_sequence_dense_tensor(builder)
if fc.max_seq_length > 0:
tmp_embedding, tmp_seq_len = shape_utils.truncate_sequence(
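The sequence-feature variables are now scoped per feature group (group_name + '/' + column name) and the scope is opened with AUTO_REUSE, so each group gets its own embedding variables and a repeated call for the same group reuses them instead of raising a "variable already exists" error. A small TF 1.x-style sketch of that reuse behavior, independent of EasyRec (under TF 2.x the same calls live under tf.compat.v1):

import tensorflow as tf

def seq_embedding(group_name, column_name, shape):
  with tf.variable_scope(group_name + '/' + column_name, reuse=tf.AUTO_REUSE):
    return tf.get_variable('embedding_weights', shape=shape,
                           initializer=tf.zeros_initializer())

a = seq_embedding('hist_clk', 'item_id', [1000, 16])
b = seq_embedding('hist_clk', 'item_id', [1000, 16])
assert a.name == b.name  # the second call reuses the first variable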
4 changes: 3 additions & 1 deletion easy_rec/python/model/dssm.py
@@ -89,7 +89,9 @@ def build_predict_graph(self):
self._prediction_dict['logits'] = y_pred
self._prediction_dict['probs'] = tf.nn.sigmoid(y_pred)
elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
y_pred = self._mask_in_batch(y_pred)
if self._labels is not None:
y_pred = self._mask_in_batch(y_pred)
y_pred = self._mask_hist_seq(y_pred)
self._prediction_dict['logits'] = y_pred
self._prediction_dict['probs'] = tf.nn.softmax(y_pred)
else:
56 changes: 50 additions & 6 deletions easy_rec/python/model/match_model.py
@@ -2,7 +2,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import array_ops

from easy_rec.python.builders import loss_builder
from easy_rec.python.model.easy_rec_model import EasyRecModel
@@ -66,6 +68,43 @@ def _mask_in_batch(self, logits):
else:
return logits

def _mask_hist_seq(self, logits):
if hasattr(self._model_config, 'hist_seq_masks') and \
len(self._model_config.hist_seq_masks) > 0 and \
self._item_ids is not None:
hist_seq_masks = self._model_config.hist_seq_masks
batch_size = tf.shape(logits)[0]
all_hist_seqs = [self._feature_dict[x] for x in hist_seq_masks]
all_hist_indices = [x.indices for x in all_hist_seqs]
all_hist_values = [x.values for x in all_hist_seqs]
all_hist_indices = array_ops.concat([x[:, 0] for x in all_hist_indices],
axis=0)
all_hist_values = array_ops.concat(all_hist_values, axis=0)

def _gen_hist_mask(all_hist_indices, all_hist_values, batch_size,
item_ids):
mask = np.zeros([batch_size, len(item_ids)], dtype=np.float32)
batch_hists = [{} for x in range(batch_size)]
for batch_id, value in zip(all_hist_indices, all_hist_values):
batch_hists[batch_id][value] = 1
for batch_idx in range(batch_size):
for item_idx, item_id in enumerate(item_ids):
if batch_idx == item_idx:
continue
if item_id in batch_hists[batch_idx]:
mask[batch_idx][item_idx] = 1
return mask

mask = tf.py_func(
_gen_hist_mask,
[all_hist_indices, all_hist_values, batch_size, self._item_ids],
Tout=tf.float32)
mask.set_shape(logits.get_shape())
logits = logits - mask * 1e32
return logits
else:
return logits

def _list_wise_sim(self, user_emb, item_emb):
batch_size = tf.shape(user_emb)[0]
hard_neg_indices = self._feature_dict.get('hard_neg_indices', None)
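Review note on _mask_hist_seq above: the py_func builds a per-row 0/1 mask over the in-batch candidate items, marking candidates that already appear in that row's history sequences (hence are actually positives), and those logits are then pushed towards -inf by subtracting mask * 1e32. A NumPy-only sketch of the mask construction, using hypothetical ids:

import numpy as np

def gen_hist_mask(hist_row_ids, hist_item_ids, batch_size, batch_item_ids):
  # hist_row_ids[k] is the row (user) that history item hist_item_ids[k] belongs to;
  # batch_item_ids are the candidate item ids scored against every row.
  mask = np.zeros([batch_size, len(batch_item_ids)], dtype=np.float32)
  hists = [set() for _ in range(batch_size)]
  for row, item in zip(hist_row_ids, hist_item_ids):
    hists[row].add(item)
  for row in range(batch_size):
    for col, item in enumerate(batch_item_ids):
      if row != col and item in hists[row]:  # keep the diagonal (true target)
        mask[row][col] = 1.0
  return mask

# row 0 clicked item 'b' before, so candidate 'b' is masked out for row 0
print(gen_hist_mask([0, 1], ['b', 'a'], 2, ['a', 'b']))
# [[0. 1.]
#  [1. 0.]]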
@@ -153,12 +192,17 @@ def _build_list_wise_loss_graph(self):
tf.log(hit_prob + 1e-12) * tf.squeeze(self._sample_weight))
logging.info('softmax cross entropy loss is used')

user_features = self._prediction_dict['user_tower_emb']
pos_item_emb = self._prediction_dict['item_tower_emb'][:batch_size]
pos_simi = tf.reduce_sum(user_features * pos_item_emb, axis=1)
# if pos_simi < 0, produce loss
reg_pos_loss = tf.nn.relu(-pos_simi)
self._loss_dict['reg_pos_loss'] = tf.reduce_mean(reg_pos_loss)
if hasattr(self._model_config,
'simi_pos_reg') and self._model_config.simi_pos_reg > 0:
logging.info('will add regularization(%.3f) to constrain similarities'
' between user and target items to be positive.' %
self._model_config.simi_pos_reg)
user_features = self._prediction_dict['user_tower_emb']
pos_item_emb = self._prediction_dict['item_tower_emb'][:batch_size]
pos_simi = tf.reduce_sum(user_features * pos_item_emb, axis=1)
# if pos_simi < 0, produce loss
reg_pos_loss = tf.nn.relu(-pos_simi * self._model_config.simi_pos_reg)
self._loss_dict['reg_pos_loss'] = tf.reduce_mean(reg_pos_loss)
else:
raise ValueError('invalid loss type: %s' % str(self._loss_type))
return self._loss_dict
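The reworked positive-similarity regularizer only activates when simi_pos_reg > 0 and penalizes rows where the user embedding and its target item embedding have a negative dot product, i.e. roughly mean(relu(-simi_pos_reg * <u, i+>)). A NumPy sketch of that term; the array values are made up:

import numpy as np

def reg_pos_loss(user_emb, pos_item_emb, simi_pos_reg=0.1):
  pos_simi = np.sum(user_emb * pos_item_emb, axis=1)       # <u, i+> per row
  return np.mean(np.maximum(-pos_simi * simi_pos_reg, 0))  # hinge on negative similarity

u = np.array([[1.0, 0.0], [0.0, 1.0]])
i = np.array([[1.0, 0.0], [0.0, -1.0]])    # second pair points in opposite directions
print(reg_pos_loss(u, i))                  # 0.05: only the second row is penalized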
136 changes: 90 additions & 46 deletions easy_rec/python/model/mind.py
@@ -1,8 +1,13 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import math

import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope

from easy_rec.python.compat import regularizers
from easy_rec.python.layers import dnn
@@ -32,8 +37,8 @@ def __init__(self,
'invalid model config: %s' % self._model_config.WhichOneof('model')
self._model_config = self._model_config.mind

self._hist_seq_features = self._input_layer(
self._feature_dict, 'hist', is_combine=False)
self._init_seq_fea()

self._user_features, _ = self._input_layer(self._feature_dict, 'user')
self._item_features, _ = self._input_layer(self._feature_dict, 'item')

@@ -47,72 +52,110 @@ def __init__(self,
self._l2_reg = regularizers.l2_regularizer(
self._model_config.l2_regularization)

def build_predict_graph(self):
capsule_layer = CapsuleLayer(self._model_config.capsule_config,
self._is_training)

if self._model_config.time_id_fea:
time_id_fea = [
x[0]
for x in self._hist_seq_features
if self._model_config.time_id_fea in x[0].name
]
logging.info('time_id_fea is set(%s), find num: %d' %
(self._model_config.time_id_fea, len(time_id_fea)))
else:
time_id_fea = []
time_id_fea = time_id_fea[0] if len(time_id_fea) > 0 else None

if time_id_fea is not None:
hist_seq_feas = [
x[0]
for x in self._hist_seq_features
if self._model_config.time_id_fea not in x[0].name
]
def _init_seq_fea(self):
mind_seq_groups = list(self._model_config.seq_group_names)
if len(mind_seq_groups) <= 1:
group_name = 'hist' if len(self._model_config.seq_group_names) == 0 else \
self._model_config.seq_group_names[0]
hist_seq_feas = self._input_layer(
self._feature_dict, group_name, is_combine=False)
self._hist_seq_len = hist_seq_feas[0][1]
if len(self._model_config.time_id_fea) > 0:
time_fea_name = self._model_config.time_id_fea[0]
self._time_fea = [
x[0] for x in hist_seq_feas if time_fea_name in x[0].name
][0]
hist_seq_feas = [
x for x in hist_seq_feas if time_fea_name not in x[0].name
]
else:
self._time_fea = None
self._hist_seq_fea = self._combine_multi_seq(hist_seq_feas)
else:
hist_seq_feas = [x[0] for x in self._hist_seq_features]

# it is assumed that all hist have the same length
hist_seq_len = self._hist_seq_features[0][1]

logging.info('mind_seq_groups[num=%d]:%s' %
(len(mind_seq_groups), ','.join(mind_seq_groups)))
mind_seq_type_embed_dim = self._model_config.seq_type_embed_dim
with variable_scope.variable_scope('mind_user_seq_type'):
seq_type_var = variable_scope.get_variable(
name='embedding_weights',
shape=[len(mind_seq_groups), self._model_config.seq_type_embed_dim],
dtype=dtypes.float32,
initializer=init_ops.truncated_normal_initializer(
mean=0, stddev=1e-2 / math.sqrt(mind_seq_type_embed_dim)),
trainable=self._is_training)

# multiple sequences
all_hist_seqs = []
all_time_feas = []
all_hist_seq_lens = []
for group_id, group_name in enumerate(self._model_config.seq_group_names):
hist_seq_feas = self._input_layer(
self._feature_dict, group_name, is_combine=False)
hist_seq_len = hist_seq_feas[0][1]
# batch_size, seq_len, embedding_dim
batch_size = array_ops.shape(hist_seq_feas[0][0])[0]
batch_seq_len = array_ops.shape(hist_seq_feas[0][0])[1]
if len(self._model_config.time_id_fea) > group_id:
time_fea_name = self._model_config.time_id_fea[group_id]
all_time_feas.append(
[x[0] for x in hist_seq_feas if time_fea_name in x[0].name][0])
hist_seq_feas = [
x for x in hist_seq_feas if time_fea_name not in x[0].name
]
seq_type = array_ops.tile(seq_type_var[group_id, :][None, None, :],
[batch_size, batch_seq_len, 1])
hist_seq_feas.append([seq_type, hist_seq_len])
all_hist_seqs.append(self._combine_multi_seq(hist_seq_feas))
all_hist_seq_lens.append(hist_seq_len)
self._hist_seq_fea = array_ops.concat(all_hist_seqs, axis=1)
self._hist_seq_len = tf.add_n(all_hist_seq_lens)
if len(all_time_feas) > 0:
self._time_fea = array_ops.concat(all_time_feas, axis=1)
else:
self._time_fea = None
return True

def _combine_multi_seq(self, hist_seq_feas):
hist_seq_feas = [x[0] for x in hist_seq_feas]
if self._model_config.user_seq_combine == MINDConfig.SUM:
# sum pooling over the features
hist_embed_dims = [x.get_shape()[-1] for x in hist_seq_feas]
for i in range(1, len(hist_embed_dims)):
assert hist_embed_dims[i] == hist_embed_dims[0], \
'all hist seq must have the same embedding shape, but: %s' \
% str(hist_embed_dims)
hist_seq_feas = tf.add_n(hist_seq_feas) / len(hist_seq_feas)
return tf.add_n(hist_seq_feas) / len(hist_seq_feas)
else:
hist_seq_feas = tf.concat(hist_seq_feas, axis=2)
return tf.concat(hist_seq_feas, axis=2)

def build_predict_graph(self):
capsule_layer = CapsuleLayer(self._model_config.capsule_config,
self._is_training)

if self._model_config.HasField('pre_capsule_dnn') and \
len(self._model_config.pre_capsule_dnn.hidden_units) > 0:
pre_dnn_layer = dnn.DNN(self._model_config.pre_capsule_dnn, self._l2_reg,
'pre_capsule_dnn', self._is_training)
hist_seq_feas = pre_dnn_layer(hist_seq_feas)
hist_seq_feas = pre_dnn_layer(self._hist_seq_fea)
else:
hist_seq_feas = self._hist_seq_fea

if time_id_fea is not None:
assert time_id_fea.get_shape(
if self._time_fea is not None:
assert self._time_fea.get_shape(
)[-1] == 1, 'time_id must have only embedding_size of 1'
time_id_mask = tf.sequence_mask(hist_seq_len, tf.shape(time_id_fea)[1])
time_id_mask = (tf.cast(time_id_mask, tf.float32) * 2 - 1) * 1e32
time_id_fea = tf.minimum(time_id_fea, time_id_mask[:, :, None])
hist_seq_feas = hist_seq_feas * tf.nn.softmax(time_id_fea, axis=1)
time_mask = tf.sequence_mask(self._hist_seq_len,
tf.shape(self._time_fea)[1])
time_mask = (tf.cast(time_mask, tf.float32) * 2 - 1) * 1e32
time_fea = tf.minimum(self._time_fea, time_mask[:, :, None])
hist_seq_feas = hist_seq_feas * tf.nn.softmax(time_fea, axis=1)

tf.summary.histogram('hist_seq_len', hist_seq_len)
tf.summary.histogram('hist_seq_len', self._hist_seq_len)

# batch_size x max_k x high_capsule_dim
high_capsules, num_high_capsules = capsule_layer(hist_seq_feas,
hist_seq_len)
self._hist_seq_len)

tf.summary.histogram('num_high_capsules', num_high_capsules)

# high_capsules = tf.layers.batch_normalization(
# high_capsules, training=self._is_training,
# trainable=True, name='capsule_bn')
# high_capsules = high_capsules * 0.1

tf.summary.scalar('high_capsules_norm',
tf.reduce_mean(tf.norm(high_capsules, axis=-1)))
tf.summary.scalar('num_high_capsules',
@@ -224,6 +267,7 @@ def build_predict_graph(self):
self._prediction_dict['probs'] = tf.nn.sigmoid(y_pred)
elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
y_pred = self._mask_in_batch(y_pred)
y_pred = self._mask_hist_seq(y_pred)
self._prediction_dict['logits'] = y_pred
self._prediction_dict['probs'] = tf.nn.softmax(y_pred)
else:
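The heart of _init_seq_fea for multiple groups: each behavior sequence gets a learnable sequence-type embedding tiled over its time steps and appended as one more sequence feature, the per-group sequences are concatenated along the time axis, and the per-group lengths are summed. A shape-level NumPy sketch of that combination; the dimensions and group names are made up, and the SUM/CONCAT pooling of _combine_multi_seq is reduced to a simple concat:

import numpy as np

batch, emb_dim, type_dim = 2, 8, 4
clk_seq = np.random.randn(batch, 20, emb_dim)   # e.g. click sequence, 20 steps
buy_seq = np.random.randn(batch, 5, emb_dim)    # e.g. buy sequence, 5 steps
clk_len = np.array([20, 13])                    # valid lengths per row
buy_len = np.array([5, 2])
seq_type_table = np.random.randn(2, type_dim)   # one type embedding per group

def add_seq_type(seq, group_id):
  # tile the group's type embedding over batch and time, then append it
  type_emb = np.tile(seq_type_table[group_id][None, None, :],
                     (seq.shape[0], seq.shape[1], 1))
  return np.concatenate([seq, type_emb], axis=-1)

hist_seq = np.concatenate([add_seq_type(clk_seq, 0), add_seq_type(buy_seq, 1)],
                          axis=1)               # concat along the time axis
hist_len = clk_len + buy_len                    # total history length per row
print(hist_seq.shape, hist_len)                 # (2, 25, 12) [25 15]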
4 changes: 4 additions & 0 deletions easy_rec/python/protos/dssm.proto
@@ -20,4 +20,8 @@ message DSSM {
optional bool scale_simi = 5 [default = true];
optional string item_id = 9;
required bool ignore_in_batch_neg_sam = 10 [default = false];

// mask negatively sampled items that appear in the user's hist sequence,
// since those items are actually positives
repeated string hist_seq_masks = 11;
}
25 changes: 22 additions & 3 deletions easy_rec/python/protos/mind.proto
@@ -31,6 +31,12 @@ message MIND {
CONCAT = 0;
SUM = 1;
}

// to support multiple behavior sequences:
// the sequences will be concatenated together;
// if not specified, the default sequence group is 'hist'
repeated string seq_group_names = 100;

// preprocessing dnn before entering capsule layer
optional DNN pre_capsule_dnn = 101;

@@ -60,14 +66,27 @@ message MIND {

required float l2_regularization = 7 [default = 1e-4];

optional string time_id_fea = 8;
repeated string time_id_fea = 8;

optional string item_id = 9;

optional bool ignore_in_batch_neg_sam = 10 [default = false];
// used if multiple sequence groups are used:
// the sequence type embedding differentiates
// the behavior sequences, e.g. clk, buy, ...
optional int32 seq_type_embed_dim = 10 [default=4];

optional bool ignore_in_batch_neg_sam = 11 [default = false];

// if smaller than 1.0, a loss will be added to
// limit the maximal interest similarities; however,
// in experiments such a loss leads to a lower hitrate.
optional float max_interests_simi = 11 [default = 1.0];
optional float max_interests_simi = 12 [default = 1.0];

// require similarities between user embeddings and
// target items to be positive
optional float simi_pos_reg = 13 [default=0.0];

// mask negatively sampled items that appear in the user's hist sequence,
// since those items are actually positives
repeated string hist_seq_masks = 14;
}
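Taken together, the new MIND fields form per-group parallel lists: seq_group_names[i] names the i-th behavior sequence feature group and time_id_fea[i], when present, names that group's time-id feature. A hedged sketch of setting the new fields through the generated Python message; the compiled module path (mind_pb2) and every group/feature name below are assumptions for illustration, not values from this PR, and the required MIND fields (DNNs, capsule config, ...) are omitted:

from easy_rec.python.protos.mind_pb2 import MIND  # assumed compiled proto module

mind_config = MIND()
mind_config.seq_group_names.extend(['clk_seq', 'buy_seq'])   # one feature group per behavior
mind_config.time_id_fea.extend(['clk_time', 'buy_time'])     # per-group time-id features
mind_config.seq_type_embed_dim = 4                           # size of the per-group type embedding
mind_config.simi_pos_reg = 0.1                               # keep <user, target item> similarity positive
mind_config.hist_seq_masks.append('clk_seq_item_id')         # mask sampled negatives already in history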