[temporary]: add sok #400

Closed
wants to merge 4 commits
127 changes: 126 additions & 1 deletion easy_rec/python/compat/feature_column/feature_column.py
@@ -135,6 +135,7 @@
import abc
import collections
import math
import os

import numpy as np
import six
@@ -168,6 +169,16 @@
from easy_rec.python.compat import embedding_ops as ev_embedding_ops
from easy_rec.python.compat.feature_column import utils as fc_utils

try:
from sparse_operation_kit import experiment as sok
except Exception:
sok = None

try:
import horovod.tensorflow as hvd
except Exception:
hvd = None


def _internal_input_layer(features,
feature_columns,
@@ -222,6 +233,117 @@ def _get_logits(): # pylint: disable=missing-docstring
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)

def _get_logits_with_sok(): # pylint: disable=missing-docstring
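# behaves like _get_logits(), but embedding columns are served by SOK:
# each worker builds a shard of every embedding table and all lookups are
# fused into batched sok.lookup_sparse calls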
assert sok is not None, 'sok is not installed'
assert hvd is not None, 'horovod is not installed'
builder = _LazyBuilder(features)
output_tensors = []
ordered_columns = []

lookup_embeddings = []
lookup_indices = []
lookup_combiners = []
lookup_cols = []
lookup_output_ids = []

lookup_embeddings_with_wgt = []
lookup_indices_with_wgt = []
lookup_wgts = []
lookup_cols_with_wgt = []
lookup_combiners_with_wgt = []
lookup_output_ids_with_wgt = []
shared_weights = {}
for column in sorted(feature_columns, key=lambda x: x.name):
ordered_columns.append(column)
with variable_scope.variable_scope(
None, default_name=column._var_scope_name): # pylint: disable=protected-access
if 'Embedding' not in str(type(column)):
output_tensors.append(
column._get_dense_tensor(
builder, weight_collections, trainable=trainable))
continue
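# shard the table rows evenly across horovod workers; adding hvd.size() - 1
# before the integer division makes it a ceiling division, so every bucket id
# is covered even when the vocabulary size is not divisible by hvd.size()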
num_buckets = column.categorical_column.num_buckets + hvd.size() - 1
per_worker_buckets = num_buckets // hvd.size()
embedding_shape = (per_worker_buckets, column.dimension)
if 'SharedEmbedding' in str(type(column)):
shared_name = column.shared_embedding_collection_name
if shared_name in shared_weights:
embedding_weights = shared_weights[shared_name]
else:
embedding_weights = variable_scope.get_variable(
name='embedding_weights',
shape=embedding_shape,
dtype=dtypes.float32,
initializer=column.initializer,
trainable=column.trainable and trainable,
partitioner=column.partitioner,
collections=weight_collections)
shared_weights[shared_name] = embedding_weights
else:
embedding_weights = variable_scope.get_variable(
name='embedding_weights',
shape=embedding_shape,
dtype=dtypes.float32,
initializer=column.initializer,
trainable=column.trainable and trainable,
partitioner=column.partitioner,
collections=weight_collections)
# required by sok
embedding_weights.target_gpu = -1
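# (assumption: target_gpu = -1 marks the table as distributed across all
# GPUs for sok.lookup_sparse rather than pinned to a single device)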
sparse_tensors = column.categorical_column._get_sparse_tensors(
builder, weight_collections=weight_collections, trainable=trainable)
output_id = len(output_tensors)
output_tensors.append(None)
if sparse_tensors.weight_tensor is not None:
lookup_embeddings_with_wgt.append(embedding_weights)
lookup_indices_with_wgt.append(sparse_tensors.id_tensor)
lookup_wgts.append(sparse_tensors.weight_tensor)
lookup_output_ids_with_wgt.append(output_id)
lookup_combiners_with_wgt.append(column.combiner)
lookup_cols_with_wgt.append(column)
else:
lookup_embeddings.append(embedding_weights)
lookup_indices.append(sparse_tensors.id_tensor)
lookup_output_ids.append(output_id)
lookup_combiners.append(column.combiner)
lookup_cols.append(column)
if cols_to_vars is not None:
cols_to_vars[column] = ops.get_collection(
ops.GraphKeys.GLOBAL_VARIABLES,
scope=variable_scope.get_variable_scope().name)

# do sok lookup: fuse all embedding lookups into at most two batched calls,
# one for plain id features and one for weighted id features
if len(lookup_output_ids) > 0:
outputs = sok.lookup_sparse(
lookup_embeddings, lookup_indices, combiners=lookup_combiners)
for output, output_id, col in zip(outputs, lookup_output_ids,
lookup_cols):
output_tensors[output_id] = output
if cols_to_output_tensors is not None:
cols_to_output_tensors[col] = output
if feature_name_to_output_tensors is not None:
feature_name_to_output_tensors[col.raw_name] = output
if len(lookup_output_ids_with_wgt) > 0:
outputs = sok.lookup_sparse(
lookup_embeddings_with_wgt,
lookup_indices_with_wgt,
lookup_wgts,
combiners=lookup_combiners_with_wgt)
for output, output_id, col in zip(outputs, lookup_output_ids_with_wgt,
lookup_cols_with_wgt):
output_tensors[output_id] = output
if cols_to_output_tensors is not None:
cols_to_output_tensors[col] = output
if feature_name_to_output_tensors is not None:
feature_name_to_output_tensors[col.raw_name] = output

if feature_name_to_output_tensors is not None:
for column, output_tensor in zip(
sorted(feature_columns, key=lambda x: x.name), output_tensors):
feature_name_to_output_tensors[column.raw_name] = output_tensor
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)

# If we're constructing from the `make_template`, that by default adds a
# variable scope with the name of the layer. In that case, we dont want to
# add another `variable_scope` as that would break checkpoints.
@@ -230,7 +352,10 @@ def _get_logits(): # pylint: disable=missing-docstring
else:
with variable_scope.variable_scope(
scope, default_name='input_layer', values=features.values()):
return _get_logits()
if 'ENABLE_SOK' in os.environ:
return _get_logits_with_sok()
else:
return _get_logits()


def input_layer(features,
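For orientation, a minimal sketch of the fused lookup pattern used in _get_logits_with_sok above (illustrative only: emb_a/emb_b and ids_a/ids_b are hypothetical names, sok.init() and horovod are assumed to be initialized, and the call shape simply mirrors the one in the diff):

    # emb_a, emb_b: per-worker shards of two embedding tables (target_gpu = -1)
    # ids_a, ids_b: tf.SparseTensor id tensors from the categorical columns
    pooled_a, pooled_b = sok.lookup_sparse(
        [emb_a, emb_b],
        [ids_a, ids_b],
        combiners=['mean', 'sum'])
    # each result is a dense [batch_size, dimension] tensor; the cross-worker
    # exchange happens inside lookup_sparse, so the outputs can be concatenated
    # like any other feature column output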
25 changes: 24 additions & 1 deletion easy_rec/python/compat/optimizers.py
@@ -40,6 +40,11 @@
from easy_rec.python.ops.incr_record import set_sparse_indices
from easy_rec.python.utils import estimator_utils

try:
import horovod.tensorflow as hvd
except Exception:
hvd = None

OPTIMIZER_CLS_NAMES = {
'Adagrad':
train.AdagradOptimizer,
@@ -254,6 +259,24 @@ def optimize_loss(loss,
variables,
colocate_gradients_with_ops=colocate_gradients_with_ops)

if estimator_utils.has_hvd():
if not estimator_utils.has_sok():
reduced_grads = []
for g, v in gradients:
reduced_grads.append((hvd.allreduce(
g, op=hvd.Average,
compression=hvd.compression.NoneCompressor), v))
gradients = reduced_grads
else:
# with sok the embedding tables are model parallel (each rank owns its
# own shard), so their gradients must not be averaged across ranks;
# only the dense, non-embedding gradients are allreduced
reduced_grads = []
for g, v in gradients:
if '/embedding' not in v.name:
reduced_grads.append((hvd.allreduce(
g, op=hvd.Average,
compression=hvd.compression.NoneCompressor), v))
else:
reduced_grads.append((g, v))
gradients = reduced_grads

# Optionally add gradient noise.
if gradient_noise_scale is not None:
gradients = _add_scaled_noise_to_gradients(gradients,
@@ -304,7 +327,7 @@ def optimize_loss(loss,
summary.scalar('global_norm/clipped_gradient_norm',
clip_ops.global_norm(list(zip(*gradients))[0]))

task_index, _ = estimator_utils.get_task_index_and_num()
# task_index, _ = estimator_utils.get_task_index_and_num()

# Create gradient updates.
def _apply_grad():
15 changes: 8 additions & 7 deletions easy_rec/python/main.py
@@ -104,13 +104,14 @@ def _create_estimator(pipeline_config, distribution=None, params={}):
gpu_options = GPUOptions(allow_growth=False)

if hvd is not None:
gpus = estimator_utils.get_available_gpus()
if len(gpus) > 0:
local_rnk = hvd.local_rank()
num_gpus_per_worker = pipeline_config.train_config.num_gpus_per_worker
sid = local_rnk * num_gpus_per_worker
eid = sid + num_gpus_per_worker
gpu_options.visible_device_list = ','.join(gpus[sid:eid])
# gpus = estimator_utils.get_available_gpus()
# if len(gpus) > 0:
local_rnk = hvd.local_rank()
# num_gpus_per_worker = pipeline_config.train_config.num_gpus_per_worker
# sid = local_rnk * num_gpus_per_worker
# eid = sid + num_gpus_per_worker
logging.info('local_rnk=%d' % local_rnk)
gpu_options.visible_device_list = str(local_rnk) # ','.join(gpus[sid:eid])
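# each horovod/sok process now drives exactly one GPU, selected by its local rank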

session_config = ConfigProto(
gpu_options=gpu_options,
17 changes: 8 additions & 9 deletions easy_rec/python/model/easy_rec_estimator.py
@@ -35,6 +35,7 @@
from easy_rec.python.protos.train_pb2 import DistributionStrategy
from easy_rec.python.utils import constant
from easy_rec.python.utils import estimator_utils
from easy_rec.python.utils import hvd_utils
from easy_rec.python.utils import pai_util
from easy_rec.python.utils.multi_optimizer import MultiOptimizer

@@ -209,9 +210,7 @@ def _train_model_fn(self, features, labels, run_config):
if estimator_utils.has_hvd():
assert not self.train_config.sync_replicas, \
'sync_replicas should not be set when using horovod'
optimizer = hvd.DistributedOptimizer(
optimizer, backward_passes_per_step=1)
bcast_hook = hvd.BroadcastGlobalVariablesHook(0)
bcast_hook = hvd_utils.BroadcastGlobalVariablesHook(0)
hooks.append(bcast_hook)

# for distributed and synced training
@@ -384,9 +383,9 @@ def _train_model_fn(self, features, labels, run_config):
early_stop_var = find_early_stop_var(var_list)
var_list = [x for x in var_list if x != early_stop_var]

initialize_var_list = [
x for x in var_list if 'WorkQueue' not in str(type(x))
]
# initialize_var_list = [
# x for x in var_list if 'WorkQueue' not in str(type(x))
# ]

# incompatible shape restore will not be saved in checkpoint
# but must be able to restore from checkpoint
@@ -407,9 +406,9 @@ def _train_model_fn(self, features, labels, run_config):
sharded=True,
max_to_keep=self.train_config.keep_checkpoint_max,
save_relative_paths=True),
local_init_op=tf.group(local_init_ops),
ready_for_local_init_op=tf.report_uninitialized_variables(
var_list=initialize_var_list))
local_init_op=tf.group(local_init_ops))
# ready_for_local_init_op=tf.report_uninitialized_variables(
# var_list=initialize_var_list))
# saver hook
saver_hook = estimator_utils.CheckpointSaverHook(
checkpoint_dir=self.model_dir,
1 change: 1 addition & 0 deletions easy_rec/python/protos/train.proto
@@ -21,6 +21,7 @@ enum DistributionStrategy {
MultiWorkerMirroredStrategy = 5;
// use horovod strategy
HorovodStrategy = 6;
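// use sok (sparse_operation_kit) strategy: horovod plus sok embeddings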
SokStrategy = 7;
}

message IncrementSaveConfig {
3 changes: 3 additions & 0 deletions easy_rec/python/train_eval.py
@@ -150,6 +150,9 @@

if pipeline_config.train_config.train_distribute == DistributionStrategy.HorovodStrategy:
estimator_utils.init_hvd()
elif pipeline_config.train_config.train_distribute == DistributionStrategy.SokStrategy:
estimator_utils.init_hvd()
estimator_utils.init_sok()

if args.hpo_param_path:
with gfile.GFile(args.hpo_param_path, 'r') as fin:
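For reference, a minimal pipeline-config fragment that would take the new branch above (a sketch; only the relevant field is shown, and the enum value comes from train.proto in this PR):

    train_config {
      train_distribute: SokStrategy
    }

With this setting both init_hvd() and init_sok() run, and init_sok() sets ENABLE_SOK so the sok feature-column and optimizer paths are used.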
20 changes: 20 additions & 0 deletions easy_rec/python/utils/estimator_utils.py
@@ -38,6 +38,11 @@
except Exception:
hvd = None

try:
from sparse_operation_kit import experiment as sok
except Exception:
sok = None

try:
from kafka import KafkaProducer, KafkaAdminClient
from kafka.admin import NewTopic
@@ -614,6 +619,7 @@ def after_run(self, run_context, run_values):
def _save(self, session, step):
"""Saves the latest checkpoint, returns should_stop."""
logging.info('Saving checkpoints for %d into %s.', step, self._save_path)
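# NOTE: returning here skips the listener callbacks and the save below,
# so this hook no longer writes checkpoints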
return False

for l in self._listeners: # noqa: E741
l.before_save(session, step)
@@ -990,6 +996,10 @@ def has_hvd():
return hvd is not None and 'HOROVOD_RANK' in os.environ


def has_sok():
return sok is not None and 'ENABLE_SOK' in os.environ


def init_hvd():
if hvd is None:
logging.error(
@@ -1001,6 +1011,16 @@ def init_hvd():
os.environ['HOROVOD_RANK'] = str(hvd.rank())


def init_sok():
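# expected to run after init_hvd() (see train_eval.py); on success it sets
# ENABLE_SOK, which switches feature_column and optimizers to the sok paths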
try:
sok.init()
os.environ['ENABLE_SOK'] = '1'
return True
except Exception:
logging.error('sok is not installed')
return False


def get_available_gpus():
local_device_protos = device_lib.list_local_devices()
return [x.name for x in local_device_protos if x.device_type == 'GPU']
50 changes: 50 additions & 0 deletions easy_rec/python/utils/hvd_utils.py
@@ -0,0 +1,50 @@
# -*- encoding: utf-8 -*-
import logging

import tensorflow as tf
from tensorflow.python.training import session_run_hook

# from horovod.tensorflow.compression import Compression
try:
from horovod.tensorflow.functions import broadcast_variables
except Exception:
pass

if tf.__version__ >= '2.0':
tf = tf.compat.v1


class BroadcastGlobalVariablesHook(session_run_hook.SessionRunHook):
"""SessionRunHook that will broadcast all global variables from root rank to all other processes during initialization.

This is necessary to ensure consistent initialization of all workers when
training is started with random weights or restored from a checkpoint.
""" # noqa: E501

def __init__(self, root_rank, device=''):
"""Construct a new BroadcastGlobalVariablesHook that will broadcast all global variables from root rank to all other processes during initialization.

Args:
root_rank:
Rank that will send data, other ranks will receive data.
device:
Device to be used for broadcasting. Uses GPU by default
if Horovod was built with HOROVOD_GPU_OPERATIONS.
""" # noqa: E501
super(BroadcastGlobalVariablesHook, self).__init__()
self.root_rank = root_rank
self.bcast_op = None
self.device = device

def begin(self):
bcast_vars = []
for x in tf.global_variables():
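# embedding tables are skipped: with sok each rank owns its own shard,
# so rank 0's values must not overwrite the other ranks' shards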
if '/embedding' not in x.name:
bcast_vars.append(x)
logging.info('will broadcast variable: %s' % x.name)
if not self.bcast_op or self.bcast_op.graph != tf.get_default_graph():
with tf.device(self.device):
self.bcast_op = broadcast_variables(bcast_vars, self.root_rank)

def after_create_session(self, session, coord):
session.run(self.bcast_op)