[feat]: Add sequence encoder to easy_rec_model object #352

Open
wants to merge 63 commits into base: master
Changes from all commits

Commits (63)
bd86e2f
[feat]: add sequence encoding module
yangxudong Mar 10, 2023
2b8f2e7
[feat]: add sequence encoding module
yangxudong Mar 10, 2023
e666f41
[feat]: add sequence encoding module
yangxudong Mar 10, 2023
778e70e
[feat]: add sequence encoding module
yangxudong Mar 12, 2023
0254902
[feat]: add pairwise logistic loss
yangxudong Mar 18, 2023
6b54fe7
[feat]: add pairwise logistic loss
yangxudong Mar 20, 2023
d65ece3
[feat]: add pairwise logistic loss
yangxudong Mar 21, 2023
d2793df
[feat]: add pairwise logistic loss
yangxudong Mar 21, 2023
31e2502
[feat]: add pairwise logistic loss
yangxudong Mar 21, 2023
7eb9c5c
[feat]: add pairwise logistic loss
yangxudong Mar 21, 2023
547c807
[feat]: add jrc loss
yangxudong Apr 4, 2023
c7476bb
[feat]: add jrc loss
yangxudong Apr 4, 2023
7f6ee53
[feat]: add jrc loss
yangxudong Apr 4, 2023
98f9ec4
[feat]: add jrc loss
yangxudong Apr 7, 2023
9586212
[feat]: add jrc loss
yangxudong Apr 7, 2023
381c62b
[feat]: add more logit
yangxudong Apr 23, 2023
9fdd3bb
Merge branch 'master' of https://github.com/alibaba/EasyRec into feat…
yangxudong May 1, 2023
e27b121
[feat]: add attention normalizer for din
yangxudong May 1, 2023
e834050
[feat]: add dice activation
yangxudong May 4, 2023
05d0e64
[feat]: add dice activation for dnn layer
yangxudong May 5, 2023
23962b2
[feat]: add FSCD layer
yangxudong May 5, 2023
5dfb29f
[feat]: add dice activation for dnn layer
yangxudong May 5, 2023
39b4c9a
Merge branch 'master' of https://github.com/alibaba/EasyRec into feat…
yangxudong May 5, 2023
51428ce
[feat]: add dice activation for dnn layer
yangxudong May 8, 2023
8509174
[feat]: add const feature column
yangxudong May 8, 2023
eba4219
[feat]: add feature selection tool
yangxudong May 9, 2023
1d26d13
Merge branch 'master' of https://github.com/alibaba/EasyRec into feat…
yangxudong May 9, 2023
ee52de3
Merge branch 'master' of https://github.com/alibaba/EasyRec into feat…
yangxudong May 9, 2023
c27c1d8
[feat]: add feature selection tool
yangxudong May 11, 2023
524ce67
[feat]: add feature selection tool
yangxudong May 15, 2023
3a8d732
[feat]: add fibinet & masknet
yangxudong May 25, 2023
90bdc50
Merge branch 'master' of https://github.com/alibaba/EasyRec into feat…
yangxudong May 25, 2023
4d91ff1
Merge branch 'master' of https://github.com/alibaba/EasyRec into feat…
yangxudong Jun 1, 2023
48601c7
[feat]: add backbone network
yangxudong Jun 9, 2023
5a47eb8
[feat]: add backbone network
yangxudong Jun 12, 2023
0260333
Merge branch 'master' of https://github.com/alibaba/EasyRec into feat…
yangxudong Jun 12, 2023
b1cb609
[feat]: add backbone network
yangxudong Jun 12, 2023
383cbed
[feat]: add test config for backbone network
yangxudong Jun 12, 2023
1114aab
[feat]: add more backbone blocks
yangxudong Jun 14, 2023
96d502e
[feat]: add more backbone blocks
yangxudong Jun 16, 2023
5cf7d8f
[feat]: add more backbone blocks
yangxudong Jun 18, 2023
9234140
[feat]: add more backbone blocks
yangxudong Jun 18, 2023
7d0e350
[feat]: add more backbone blocks
yangxudong Jun 19, 2023
136cf37
[feat]: format backbone code, add recurrent and sequential layer
yangxudong Jun 19, 2023
e795f00
[feat]: format backbone code, add recurrent and sequential layer
yangxudong Jun 19, 2023
c4f5ea9
[feat]: format backbone code, add recurrent and sequential layer
yangxudong Jun 20, 2023
1b504a8
[feat]: add repeat block
yangxudong Jun 20, 2023
32ff01c
fix bug of no is_predicting argument
yangxudong Jun 20, 2023
0c087d9
fix bug of no is_predicting argument
yangxudong Jun 20, 2023
af871b3
fix deepfm distribute eval test case
yangxudong Jun 20, 2023
5813c0e
modify
yangxudong Jun 22, 2023
d1d16ac
Merge branch 'master' of https://github.com/alibaba/EasyRec into feat…
yangxudong Jun 23, 2023
0c85dd2
add gate layer
yangxudong Jun 24, 2023
3ed293a
add gate layer
yangxudong Jun 24, 2023
a9aff75
add gate layer
yangxudong Jun 25, 2023
5cc6efe
Merge branch 'master' of https://github.com/alibaba/EasyRec into feat…
yangxudong Jun 27, 2023
2847347
add block package for reuse sub network
yangxudong Jun 28, 2023
ee49dbe
add block package for reuse sub network
yangxudong Jun 28, 2023
2c591fc
fix a bug
yangxudong Jun 30, 2023
a1f0b8a
fix bug of input layer block
yangxudong Jun 30, 2023
ef5f9cd
fix bug of input layer block
yangxudong Jun 30, 2023
a3944f9
fix bug of input layer block
yangxudong Jul 12, 2023
faf8ddf
upgrade to new version
yangxudong Jul 12, 2023
14 changes: 10 additions & 4 deletions easy_rec/python/builders/loss_builder.py
@@ -41,12 +41,18 @@ def build(loss_type,
     return tf.losses.mean_squared_error(
         labels=label, predictions=pred, weights=loss_weight, **kwargs)
   elif loss_type == LossType.JRC_LOSS:
-    alpha = 0.5 if loss_param is None else loss_param.alpha
-    auto_weight = False if loss_param is None else not loss_param.HasField(
-        'alpha')
     session = kwargs.get('session_ids', None)
+    if loss_param is None:
+      return jrc_loss(label, pred, session, name=loss_name)
     return jrc_loss(
-        label, pred, session, alpha, auto_weight=auto_weight, name=loss_name)
+        label,
+        pred,
+        session,
+        loss_param.alpha,
+        loss_weight_strategy=loss_param.loss_weight_strategy,
+        sample_weights=loss_weight,
+        same_label_loss=loss_param.same_label_loss,
+        name=loss_name)
   elif loss_type == LossType.PAIR_WISE_LOSS:
     session = kwargs.get('session_ids', None)
     margin = 0 if loss_param is None else loss_param.margin
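
A minimal sketch of the default branch above (loss_param is None); the jrc_loss import path is an assumption, while the call signature comes straight from this hunk:

import tensorflow as tf
from easy_rec.python.loss.jrc_loss import jrc_loss  # assumed module path

# Binary labels, two-logit predictions, and session ids that group
# examples for the listwise part of the JRC loss.
labels = tf.constant([1, 0, 1], dtype=tf.int32)
logits = tf.constant([[0.2, 0.8], [0.9, 0.1], [0.4, 0.6]])
session_ids = tf.constant([7, 7, 9], dtype=tf.int32)

loss = jrc_loss(labels, logits, session_ids, name='jrc_loss')
with tf.Session() as sess:
  print(sess.run(loss))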
229 changes: 229 additions & 0 deletions easy_rec/python/compat/array_ops.py
@@ -0,0 +1,229 @@
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_math_ops


def convert_to_int_tensor(tensor, name, dtype=tf.int32):
"""Converts the given value to an integer Tensor."""
tensor = ops.convert_to_tensor(tensor, name=name, preferred_dtype=dtype)
if tensor.dtype.is_integer:
tensor = gen_math_ops.cast(tensor, dtype)
else:
raise TypeError('%s must be an integer tensor; dtype=%s' %
(name, tensor.dtype))
return tensor


def _with_nonzero_rank(data):
"""If `data` is scalar, then add a dimension; otherwise return as-is."""
if data.shape.ndims is not None:
if data.shape.ndims == 0:
return tf.stack([data])
else:
return data
else:
data_shape = tf.shape(data)
data_ndims = tf.rank(data)
return tf.reshape(data, tf.concat([[1], data_shape], axis=0)[-data_ndims:])


def get_positive_axis(axis, ndims):
"""Validate an `axis` parameter, and normalize it to be positive.

If `ndims` is known (i.e., not `None`), then check that `axis` is in the
range `-ndims <= axis < ndims`, and return `axis` (if `axis >= 0`) or
`axis + ndims` (otherwise).
If `ndims` is not known, and `axis` is positive, then return it as-is.
If `ndims` is not known, and `axis` is negative, then report an error.

Args:
axis: An integer constant
ndims: An integer constant, or `None`

Returns:
The normalized `axis` value.

Raises:
ValueError: If `axis` is out-of-bounds, or if `axis` is negative and
`ndims is None`.
"""
if not isinstance(axis, int):
raise TypeError('axis must be an int; got %s' % type(axis).__name__)
if ndims is not None:
if 0 <= axis < ndims:
return axis
elif -ndims <= axis < 0:
return axis + ndims
else:
raise ValueError('axis=%s out of bounds: expected %s<=axis<%s' %
(axis, -ndims, ndims))
elif axis < 0:
raise ValueError('axis may only be negative if ndims is statically known.')
return axis


def tile_one_dimension(data, axis, multiple):
"""Tiles a single dimension of a tensor."""
# Assumes axis is a nonnegative int.
if data.shape.ndims is not None:
multiples = [1] * data.shape.ndims
multiples[axis] = multiple
else:
ones_value = tf.ones(tf.rank(data), tf.int32)
multiples = tf.concat(
[ones_value[:axis], [multiple], ones_value[axis + 1:]], axis=0)
return tf.tile(data, multiples)


def _all_dimensions(x):
"""Returns a 1D-tensor listing all dimensions in x."""
# Fast path: avoid creating Rank and Range ops if ndims is known.
if isinstance(x, ops.Tensor) and x.get_shape().ndims is not None:
return constant_op.constant(np.arange(x.get_shape().ndims), dtype=tf.int32)
if (isinstance(x, sparse_tensor.SparseTensor) and
x.dense_shape.get_shape().is_fully_defined()):
r = x.dense_shape.get_shape().dims[0].value # sparse.dense_shape is 1-D.
return constant_op.constant(np.arange(r), dtype=tf.int32)

# Otherwise, we rely on `range` and `rank` to do the right thing at runtime.
return gen_math_ops._range(0, tf.rank(x), 1)


# This op is intended to exactly match the semantics of numpy.repeat, with
# one exception: numpy.repeat has special (and somewhat non-intuitive) behavior
# when axis is not specified. Rather than implement that special behavior, we
# simply make `axis` be a required argument.
#
# External (OSS) `tf.repeat` feature request:
# https://github.com/tensorflow/tensorflow/issues/8246
def repeat_with_axis(data, repeats, axis, name=None):
"""Repeats elements of `data`.

Args:
data: An `N`-dimensional tensor.
repeats: A 1-D integer tensor specifying how many times each element in
`axis` should be repeated. `len(repeats)` must equal `data.shape[axis]`.
Supports broadcasting from a scalar value.
axis: `int`. The axis along which to repeat values. Must be less than
`max(N, 1)`.
name: A name for the operation.

Returns:
A tensor with `max(N, 1)` dimensions. Has the same shape as `data`,
except that dimension `axis` has size `sum(repeats)`.
#### Examples:
```python
>>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
['a', 'a', 'a', 'c', 'c']
>>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
[[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
>>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
[[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]]
```
"""
if not isinstance(axis, int):
raise TypeError('axis must be an int; got %s' % type(axis).__name__)

with ops.name_scope(name, 'Repeat', [data, repeats]):
data = ops.convert_to_tensor(data, name='data')
repeats = convert_to_int_tensor(repeats, name='repeats')
repeats.shape.with_rank_at_most(1)

# If `data` is a scalar, then upgrade it to a vector.
data = _with_nonzero_rank(data)
data_shape = tf.shape(data)

# If `axis` is negative, then convert it to a positive value.
axis = get_positive_axis(axis, data.shape.ndims)

# Check data Tensor shapes.
if repeats.shape.ndims == 1:
data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])

# If we know that `repeats` is a scalar, then we can just tile & reshape.
if repeats.shape.ndims == 0:
expanded = tf.expand_dims(data, axis + 1)
tiled = tile_one_dimension(expanded, axis + 1, repeats)
result_shape = tf.concat([data_shape[:axis], [-1], data_shape[axis + 1:]],
axis=0)
return tf.reshape(tiled, result_shape)

# Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
if repeats.shape.ndims != axis + 1:
repeats_shape = tf.shape(repeats)
repeats_ndims = tf.rank(repeats)
broadcast_shape = tf.concat(
[data_shape[:axis + 1 - repeats_ndims], repeats_shape], axis=0)
repeats = tf.broadcast_to(repeats, broadcast_shape)
repeats.set_shape([None] * (axis + 1))

# Create a "sequence mask" based on `repeats`, where slices across `axis`
# contain one `True` value for each repetition. E.g., if
# `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`.
max_repeat = gen_math_ops.maximum(
0, gen_math_ops._max(repeats, _all_dimensions(repeats)))
mask = tf.sequence_mask(repeats, max_repeat)

# Add a new dimension around each value that needs to be repeated, and
# then tile that new dimension to match the maximum number of repetitions.
expanded = tf.expand_dims(data, axis + 1)
tiled = tile_one_dimension(expanded, axis + 1, max_repeat)

# Use `boolean_mask` to discard the extra repeated values. This also
# flattens all dimensions up through `axis`.
masked = tf.boolean_mask(tiled, mask)

# Reshape the output tensor to add the outer dimensions back.
if axis == 0:
result = masked
else:
result_shape = tf.concat([data_shape[:axis], [-1], data_shape[axis + 1:]],
axis=0)
result = tf.reshape(masked, result_shape)

# Preserve shape information.
if data.shape.ndims is not None:
new_axis_size = 0 if repeats.shape[0] == 0 else None
result.set_shape(data.shape[:axis].concatenate(
[new_axis_size]).concatenate(data.shape[axis + 1:]))

return result


def repeat(input, repeats, axis=None, name=None): # pylint: disable=redefined-builtin
"""Repeat elements of `input`.

Args:
input: An `N`-dimensional Tensor.
    repeats: A 1-D `int` Tensor. The number of repetitions for each element.
      `repeats` is broadcast to fit the shape of the given axis. `len(repeats)`
      must equal `input.shape[axis]` if axis is not None.
axis: An int. The axis along which to repeat values. By default (axis=None),
use the flattened input array, and return a flat output array.
name: A name for the operation.

Returns:
A Tensor which has the same shape as `input`, except along the given axis.
If axis is None then the output array is flattened to match the flattened
input array.
#### Examples:
```python
>>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
['a', 'a', 'a', 'c', 'c']
>>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
[[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
>>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
[[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]]
>>> repeat(3, repeats=4)
[3, 3, 3, 3]
>>> repeat([[1,2], [3,4]], repeats=2)
[1, 1, 2, 2, 3, 3, 4, 4]
```
"""
if axis is None:
input = tf.reshape(input, [-1])
axis = 0
return repeat_with_axis(input, repeats, axis, name)
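
Because this file is a compat backport of tf.repeat for older TensorFlow, a quick graph-mode smoke test (a sketch, assuming TF 1.x sessions as used elsewhere in EasyRec) exercises both the per-element and the scalar/flattened paths:

import tensorflow as tf
from easy_rec.python.compat.array_ops import repeat

with tf.Session() as sess:
  # Per-element repeat counts along axis 0, as in the docstring example.
  a = repeat(tf.constant(['a', 'b', 'c']), repeats=[3, 0, 2], axis=0)
  # Scalar repeats with axis=None: flatten first, then repeat each element.
  b = repeat(tf.constant([[1, 2], [3, 4]]), repeats=2)
  print(sess.run(a))  # [b'a' b'a' b'a' b'c' b'c']
  print(sess.run(b))  # [1 1 2 2 3 3 4 4]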
61 changes: 53 additions & 8 deletions easy_rec/python/compat/feature_column/feature_column.py
@@ -177,7 +177,8 @@ def _internal_input_layer(features,
                           scope=None,
                           cols_to_output_tensors=None,
                           from_template=False,
-                          feature_name_to_output_tensors=None):
+                          feature_name_to_output_tensors=None,
+                          sort_feature_columns_by_name=True):
   """See input_layer, `scope` is a name or variable scope to use."""
   feature_columns = _normalize_feature_columns(feature_columns)
   for column in feature_columns:
@@ -195,9 +196,11 @@ def _internal_input_layer(features,
   def _get_logits():  # pylint: disable=missing-docstring
     builder = _LazyBuilder(features)
     output_tensors = []
-    ordered_columns = []
-    for column in sorted(feature_columns, key=lambda x: x.name):
-      ordered_columns.append(column)
+    if sort_feature_columns_by_name:
+      ordered_columns = sorted(feature_columns, key=lambda x: x.name)
+    else:
+      ordered_columns = feature_columns
+    for column in ordered_columns:
       with variable_scope.variable_scope(
           None, default_name=column._var_scope_name):  # pylint: disable=protected-access
         tensor = column._get_dense_tensor(  # pylint: disable=protected-access
@@ -239,7 +242,8 @@ def input_layer(features,
                 trainable=True,
                 cols_to_vars=None,
                 cols_to_output_tensors=None,
-                feature_name_to_output_tensors=None):
+                feature_name_to_output_tensors=None,
+                sort_feature_columns_by_name=True):
  """Returns a dense `Tensor` as input layer based on given `feature_columns`.

  Generally a single example in training data is described with FeatureColumns.
@@ -287,6 +291,7 @@ def input_layer(features,
    cols_to_output_tensors: If not `None`, must be a dictionary that will be
      filled with a mapping from '_FeatureColumn' to the associated
      output `Tensor`s.
+    sort_feature_columns_by_name: whether to sort the feature columns by name.

  Returns:
    A `Tensor` which represents input layer of a model. Its shape
@@ -303,7 +308,8 @@ def input_layer(features,
      trainable=trainable,
      cols_to_vars=cols_to_vars,
      cols_to_output_tensors=cols_to_output_tensors,
-      feature_name_to_output_tensors=feature_name_to_output_tensors)
+      feature_name_to_output_tensors=feature_name_to_output_tensors,
+      sort_feature_columns_by_name=sort_feature_columns_by_name)


# TODO(akshayka): InputLayer should be a subclass of Layer, and it
@@ -2530,7 +2536,46 @@ def name(self):

   @property
   def raw_name(self):
-    return self.categorical_column.name
+    return self.categorical_column.raw_name
+
+  @property
+  def cardinality(self):
+    from easy_rec.python.compat.feature_column.feature_column_v2 import HashedCategoricalColumn, \
+        BucketizedColumn, WeightedCategoricalColumn, SequenceWeightedCategoricalColumn, \
+        CrossedColumn, IdentityCategoricalColumn, VocabularyListCategoricalColumn, \
+        VocabularyFileCategoricalColumn
+
+    fc = self.categorical_column
+    if isinstance(fc, HashedCategoricalColumn) or isinstance(fc, CrossedColumn):
+      return fc.hash_bucket_size
+
+    if isinstance(fc, IdentityCategoricalColumn):
+      return fc.num_buckets
+
+    if isinstance(fc, BucketizedColumn):
+      return len(fc.boundaries) + 1
+
+    if isinstance(fc, VocabularyListCategoricalColumn):
+      return len(fc.vocabulary_list) + fc.num_oov_buckets
+
+    if isinstance(fc, VocabularyFileCategoricalColumn):
+      return fc.vocabulary_size + fc.num_oov_buckets
+
+    if isinstance(fc, WeightedCategoricalColumn) or isinstance(
+        fc, SequenceWeightedCategoricalColumn):
+      sub_fc = fc.categorical_column
+      if isinstance(sub_fc, HashedCategoricalColumn) or isinstance(
+          sub_fc, CrossedColumn):
+        return sub_fc.hash_bucket_size
+      if isinstance(sub_fc, IdentityCategoricalColumn):
+        return sub_fc.num_buckets
+      if isinstance(sub_fc, VocabularyListCategoricalColumn):
+        return len(sub_fc.vocabulary_list) + sub_fc.num_oov_buckets
+      if isinstance(sub_fc, VocabularyFileCategoricalColumn):
+        return sub_fc.vocabulary_size + sub_fc.num_oov_buckets
+      if isinstance(sub_fc, BucketizedColumn):
+        return len(sub_fc.boundaries) + 1
+    return 1

   @property
   def _var_scope_name(self):
@@ -2605,7 +2650,7 @@ def _get_dense_tensor_internal(self,
      # get zero embedding
      import os
      if os.environ.get('tf.estimator.mode', '') != \
-         os.environ.get('tf.estimator.ModeKeys.TRAIN', 'train'):
+          os.environ.get('tf.estimator.ModeKeys.TRAIN', 'train'):
        initializer = init_ops.zeros_initializer()
      else:
        initializer = self.initializer
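
A hedged sketch of what the new sort_feature_columns_by_name flag changes; the hypothetical numeric columns below assume the compat input_layer accepts standard TF v1 feature columns (inside EasyRec the columns normally come from its own feature column parser):

import tensorflow as tf
from easy_rec.python.compat.feature_column.feature_column import input_layer

# Hypothetical features: 'z_price' is declared before 'a_age'.
features = {
    'z_price': tf.constant([[9.9]]),
    'a_age': tf.constant([[30.0]]),
}
columns = [
    tf.feature_column.numeric_column('z_price'),
    tf.feature_column.numeric_column('a_age'),
]

# Default (True) keeps the historical behavior: outputs are concatenated
# in name order (a_age, z_price). With False, declaration order is kept
# (z_price, a_age), so downstream blocks can rely on a deliberate layout.
net = input_layer(features, columns, sort_feature_columns_by_name=False)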