Skip to content

Commit

Permalink
add custom op demo
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxudong committed Oct 20, 2023
1 parent bc9f2f5 commit 1fb8ac0
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions easy_rec/python/layers/keras/custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@

import easy_rec

LIB_PATH = tf.sysconfig.get_link_flags()[0][2:]
LD_LIBRARY_PATH = os.getenv('LD_LIBRARY_PATH')
if LIB_PATH not in LD_LIBRARY_PATH:
os.environ['LD_LIBRARY_PATH'] = ':'.join([LIB_PATH, LD_LIBRARY_PATH])
logging.info('set LD_LIBRARY_PATH=%s' % os.getenv('LD_LIBRARY_PATH'))

# LIB_PATH = tf.sysconfig.get_link_flags()[0][2:]
# LD_LIBRARY_PATH = os.getenv('LD_LIBRARY_PATH')
# if LIB_PATH not in LD_LIBRARY_PATH:
# os.environ['LD_LIBRARY_PATH'] = ':'.join([LIB_PATH, LD_LIBRARY_PATH])
# logging.info('set LD_LIBRARY_PATH=%s' % os.getenv('LD_LIBRARY_PATH'))

if tf.__version__ >= '2.0':
tf = tf.compat.v1
Expand Down Expand Up @@ -46,7 +45,12 @@ def __init__(self, params, name='edit_distance', reuse=None, **kwargs):
def call(self, inputs, training=None, **kwargs):
input1, input2 = inputs[:2]
with ops.device('/CPU:0'):
dist = self.edit_distance(input1, input2, normalize=False, dtype=tf.int32, encoding=self.txt_encoding)
dist = self.edit_distance(
input1,
input2,
normalize=False,
dtype=tf.int32,
encoding=self.txt_encoding)
ids = tf.clip_by_value(dist, 0, self.emb_size - 1)
embed = tf.nn.embedding_lookup(self.embedding_table, ids)
return embed

0 comments on commit 1fb8ac0

Please sign in to comment.