Skip to content

Commit

Permalink
Merge pull request NVlabs#9 from xivh/master
Browse files Browse the repository at this point in the history
fixed encoder
  • Loading branch information
pbaylies authored Jun 12, 2020
2 parents 358a046 + 938ac88 commit ffc1deb
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
15 changes: 11 additions & 4 deletions dnnlib/tflib/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def run(self,
minibatch_size: int = None,
num_gpus: int = 1,
assume_frozen: bool = False,
custom_inputs: Any = None,
**dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
"""Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
Expand All @@ -404,6 +405,7 @@ def run(self,
minibatch_size: Maximum minibatch size to use, None = disable batching.
num_gpus: Number of GPUs to use.
assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain changed between calls.
custom_inputs: Allow to use another tensor as input instead of default placeholders.
dynamic_kwargs: Additional keyword arguments to be passed into the network build function.
"""
assert len(in_arrays) == self.num_inputs
Expand All @@ -427,10 +429,15 @@ def unwind_key(obj):

# Build graph.
if key not in self._run_cache:
with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
with tf.device("/cpu:0"):
in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
if custom_inputs is not None:
with tf.device("/gpu:0"):
in_expr = [input_builder(name) for input_builder, name in zip(custom_inputs, self.input_names)]
in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
else:
with tf.device("/cpu:0"):
in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))

out_split = []
for gpu in range(num_gpus):
Expand Down
1 change: 1 addition & 0 deletions encode_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pretrained_networks
from encoder.generator_model import Generator
from encoder.perceptual_model import PerceptualModel
from encoder.perceptual_model import load_images
from keras.models import load_model
from keras.applications.resnet50 import preprocess_input

Expand Down

0 comments on commit ffc1deb

Please sign in to comment.