# Source code for tensorflow.python.ops.parallel_for.control_flow_ops

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""for_loop and pfor ops."""
# pylint: disable=g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.parallel_for.pfor import PFor
from tensorflow.python.ops.parallel_for.pfor import PForConfig
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
  """Calls `loop_fn` for each index in `[0, iters)` and stacks the results.

  The iterations are driven by a while loop. Every non-None output of
  `loop_fn` is written into a per-output TensorArray, and the arrays are
  concatenated once the loop has finished.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object representing
      the iteration number, and returns a possibly nested structure of tensor
      objects. The shape of these outputs should not depend on the input.
    loop_fn_dtypes: dtypes for the outputs of loop_fn.
    iters: Number of iterations for which to run loop_fn.
    parallel_iterations: The number of iterations that can be dispatched in
      parallel. This knob can be used to control the total memory usage.

  Returns:
    A nested structure of stacked output tensor objects, packed with the
    structure of `loop_fn_dtypes`. Entries for which `loop_fn` returned None
    are None. Returns None if the loop body was never invoked.

  Raises:
    ValueError: If the number of flattened outputs of `loop_fn` differs from
      the number of flattened entries in `loop_fn_dtypes`.
  """
  expected_dtypes = nest.flatten(loop_fn_dtypes)
  # Filled in by the loop body: one flag per flattened output of the most
  # recent `loop_fn` call, telling whether that output was None.
  none_mask = []

  def body(step, *arrays):
    """Runs one iteration and records its outputs in the TensorArrays."""
    results = nest.flatten(loop_fn(step))
    if len(results) != len(expected_dtypes):
      raise ValueError(
          "Number of expected outputs, %d, does not match the number of "
          "actual outputs, %d, from loop_fn" % (len(expected_dtypes),
                                                len(results)))
    none_mask[:] = [result is None for result in results]
    # TODO(agarwal): support returning Operation objects from loop_fn.
    # Each written element gets a leading dimension of size 1 so that the
    # final concat() joins the per-iteration values along axis 0.
    updated = [
        array if result is None
        else array.write(step, array_ops.expand_dims(result, 0))
        for result, array in zip(results, arrays)
    ]
    return tuple([step + 1] + updated)

  loop_kwargs = {}
  if parallel_iterations is not None:
    loop_kwargs["parallel_iterations"] = parallel_iterations

  initial_vars = [0] + [
      tensor_array_ops.TensorArray(dtype.base_dtype, iters)
      for dtype in expected_dtypes
  ]
  loop_results = control_flow_ops.while_loop(
      lambda step, *arrays: step < iters, body, initial_vars, **loop_kwargs)
  final_arrays = loop_results[1:]

  # TODO(rachelim): enable this for sparse tensors

  stacked = []
  for array, was_none in zip(final_arrays, none_mask):
    stacked.append(None if was_none else array.concat())
  assert len(stacked) in (0, len(expected_dtypes))
  if stacked:
    return nest.pack_sequence_as(loop_fn_dtypes, stacked)
  # This may happen for the case where iters == 0.
  return None


def _flatten_first_two_dims(x):
  """Reshapes `x` so that its two leading dimensions are merged into one."""
  shape = array_ops.shape(x)
  merged_dim = shape[0] * shape[1]
  target_shape = array_ops.concat([[merged_dim], shape[2:]], axis=0)
  return array_ops.reshape(x, target_shape)


# Name of the optional keyword argument through which a `PForConfig` object is
# passed to `loop_fn`.
PFOR_CONFIG_ARG = "pfor_config"


def _is_under_xla_context():
  """Returns True if any enclosing control flow context is an XLA context.

  Starting from the default graph, walks each graph's chain of control flow
  contexts and then moves on to the graph's `outer_graph` (if it has one).
  """
  graph = ops.get_default_graph()
  while graph is not None:
    ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
    while ctxt is not None:
      if ctxt.IsXLAContext():
        return True
      ctxt = ctxt.outer_context
    # Only graphs that expose `outer_graph` (e.g. a FuncGraph) have a parent to
    # continue with; otherwise the walk ends here.
    graph = getattr(graph, "outer_graph", None)
  return False


def pfor(loop_fn, iters, parallel_iterations=None):
  """Equivalent to running `loop_fn` `iters` times and stacking the outputs.

  `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn` `iters`
  times, with input from 0 to `iters - 1`, and stacking corresponding output of
  each iteration. However the implementation does not use a tf.while_loop.
  Instead it adds new operations to the graph that collectively compute the same
  value as what running `loop_fn` in a loop would compute.


  This is an experimental feature and currently has a lot of limitations:
    - There should be no data dependency between the different iterations. For
      example, a future iteration should not depend on a value or side-effect of
      a previous iteration.
    - Stateful kernels may mostly not be supported since these often imply a
      data dependency or ordering of the iterations. We do support a limited set
      of such stateful kernels though (like RandomFoo, Variable operations like
      reads, etc).
    - Conversion works only on a limited set of kernels for which a converter
      has been registered.
    - loop_fn has limited support for control flow operations. tf.cond in
      particular is not supported.
    - `loop_fn` should return nested structure of Tensors or Operations. However
      if an Operation is returned, it should have zero outputs.
    - The shape and dtype of `loop_fn` outputs should not depend on the input
      to loop_fn.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object representing
      the iteration number, and optionally a keyword argument `pfor_config` set
      to a PForConfig object. It returns a possibly nested structure of Tensor
      or Operation objects. Note that if setting `parallel_iterations` argument
      to something other than None, `loop_fn` may be called more than once
      during graph construction. So it may need to avoid mutating global state.
    iters: Number of iterations for which to run loop_fn.
    parallel_iterations: A knob to control how many iterations are vectorized
      and dispatched in parallel. The default value of None corresponds to
      vectorizing all the iterations.  If `parallel_iterations` is smaller than
      `iters`, then chunks of at most that many iterations are dispatched in
      sequence. This knob can be used to control the total memory usage.

  Returns:
    Returns a nested structure of stacked tensor objects with the same nested
    structure as the output of `loop_fn`.

  Raises:
    ValueError: If parallel_iterations is not None and not an integer > 1.
  """
  def f():
    return _pfor_impl(loop_fn, iters, parallel_iterations=parallel_iterations)
  # Note that we wrap into a tf.function if in eager execution mode or under
  # XLA compilation. The latter is so that we don't compile operations like
  # tf.placeholder that are created by the loop body.
  #
  # `functions_run_eagerly` stays None when no wrapping happens; otherwise it
  # holds the caller's "run functions eagerly" setting so it can be restored.
  functions_run_eagerly = None
  if context.executing_eagerly() or _is_under_xla_context():
    functions_run_eagerly = def_function.functions_run_eagerly()
    if functions_run_eagerly:
      logging.warning(
          "It looks like tf.function behavior was disabled, perhaps using "
          "tf.config.experimental_run_functions_eagerly. Vectorization "
          "primitives (e.g. tf.vectorized_map) require tf.function to work. "
          "These primitives will override the disable.")
      def_function.run_functions_eagerly(False)
    f = def_function.function(f)
  try:
    outputs = f()
  finally:
    # Restore the saved setting even when tracing or conversion raises, so a
    # failed pfor call does not leave the process-wide flag overridden.
    if functions_run_eagerly is not None:
      def_function.run_functions_eagerly(functions_run_eagerly)
  return outputs


def _loop_fn_has_config(loop_fn):
  """Reports whether `loop_fn` declares a `pfor_config` argument.

  Three kinds of callables are handled: plain functions, `functools.partial`
  objects (inspected through their wrapped function), and any other object,
  which is inspected through the `__call__` of its unwrapped target.

  Args:
    loop_fn: The callable to inspect.

  Returns:
    True if `pfor_config` is among the declared argument names and, for a
    `functools.partial`, has not already been bound as a keyword.

  Raises:
    ValueError: If `loop_fn` is neither a function nor a partial and its
      unwrapped target has no `__call__` attribute.
  """
  if tf_inspect.isfunction(loop_fn):
    return PFOR_CONFIG_ARG in tf_inspect.getargspec(loop_fn).args

  if isinstance(loop_fn, functools.partial):
    declared_args = tf_inspect.getargspec(loop_fn.func).args
    if PFOR_CONFIG_ARG not in declared_args:
      return False
    # An argument already bound by the partial does not count as available.
    return PFOR_CONFIG_ARG not in loop_fn.keywords

  target = tf_decorator.unwrap(loop_fn)[1]
  if not hasattr(target, "__call__"):
    raise ValueError("loop_fn object did not have a __call__ method")
  return PFOR_CONFIG_ARG in tf_inspect.getargspec(target.__call__).args


def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None):
  """Implementation of pfor.

  Traces `loop_fn` once with a placeholder loop variable, records the ops that
  the trace added to the default graph, and hands those ops to a `PFor`
  converter to obtain vectorized outputs. When `parallel_iterations` is set
  (and not made redundant by a statically known, smaller `iters`), the
  iterations are split into a vectorized remainder plus equally sized tiles
  that are run sequentially through `for_loop`.

  Must be called while not executing eagerly (asserted).

  Args:
    loop_fn: Callable taking the loop index tensor, and optionally a
      `pfor_config` keyword argument.
    iters: Number of iterations; converted to a tensor after tracing.
    parallel_iterations: None to vectorize all iterations at once, otherwise an
      integer greater than 1 giving the tile size.
    pfor_config: Optional config object forwarded to `loop_fn` and to `PFor`.
      Created here if `loop_fn` accepts it and none was supplied.

  Returns:
    The converted outputs, packed with the nested structure of the outputs of
    `loop_fn`.

  Raises:
    ValueError: If `parallel_iterations` is less than 1 or equal to 1, or if
      tiling is requested while `pfor_config` has reductions.
  """
  assert not context.executing_eagerly()
  loop_fn_has_config = _loop_fn_has_config(loop_fn)
  # Snapshot of the graph's ops before tracing; used below to isolate the ops
  # that belong to the loop body.
  existing_ops = set(ops.get_default_graph().get_operations())
  # Run the loop body
  with ops.name_scope("loop_body"):
    # Stand-in for the iteration index while tracing `loop_fn`.
    loop_var = array_ops.placeholder_with_default(0, shape=[])
    if loop_fn_has_config:
      if pfor_config is None:
        pfor_config = PForConfig()
        pfor_config._set_iters(iters)  # pylint: disable=protected-access
      loop_fn_outputs = loop_fn(loop_var, **{PFOR_CONFIG_ARG: pfor_config})
    else:
      assert pfor_config is None
      loop_fn_outputs = loop_fn(loop_var)

  # Convert outputs to Tensor if needed.
  tmp_loop_fn_outputs = []
  for loop_fn_output in nest.flatten(loop_fn_outputs):
    # None, Operation, Tensor and SparseTensor outputs are passed through
    # unchanged; anything else goes through convert_to_tensor.
    if (loop_fn_output is not None and not isinstance(
        loop_fn_output,
        (ops.Operation, ops.Tensor, sparse_tensor.SparseTensor))):
      if isinstance(loop_fn_output, indexed_slices.IndexedSlices):
        logging.warn("Converting %s to a dense representation may make it slow."
                     " Alternatively, output the indices and values of the"
                     " IndexedSlices separately, and handle the vectorized"
                     " outputs directly." % loop_fn_output)
      loop_fn_output = ops.convert_to_tensor(loop_fn_output)
    tmp_loop_fn_outputs.append(loop_fn_output)
  loop_fn_outputs = nest.pack_sequence_as(loop_fn_outputs, tmp_loop_fn_outputs)

  # Taken after the conversions above, so ops they created are included too.
  new_ops = set(ops.get_default_graph().get_operations()) - existing_ops
  iters = ops.convert_to_tensor(iters)
  if parallel_iterations is not None:
    if parallel_iterations < 1:
      raise ValueError("parallel_iterations must be None or a positive integer")
    if parallel_iterations == 1:
      raise ValueError("Found parallel_iterations == 1. Use for_loop instead.")
    iters_value = tensor_util.constant_value(iters)
    # With a statically known iteration count below the tile size there is
    # nothing to tile, so fall back to vectorizing everything at once.
    if iters_value is not None and iters_value < parallel_iterations:
      parallel_iterations = None
  if parallel_iterations is None:
    # Vectorize all iterations in one go.
    with ops.name_scope("pfor"):
      converter = PFor(loop_var, iters, new_ops, pfor_config=pfor_config)
      outputs = []
      for loop_fn_output in nest.flatten(loop_fn_outputs):
        outputs.append(converter.convert(loop_fn_output))
      return nest.pack_sequence_as(loop_fn_outputs, outputs)
  else:
    if pfor_config is not None and pfor_config._has_reductions():  # pylint: disable=protected-access
      raise ValueError("Setting parallel_iterations currently unsupported if"
                       " reductions across iterations are performed.")
    # Split into `num_tiled_iterations` full tiles of `parallel_iterations`
    # iterations each, plus a remainder that is smaller than one tile.
    num_tiled_iterations = iters // parallel_iterations
    num_remaining_iterations = iters % parallel_iterations
    # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
    # a tf.function and extract the graph from there to vectorize it.
    with ops.name_scope("pfor_untiled"):
      # The remainder covers the first `num_remaining_iterations` indices and
      # reuses the ops traced above.
      converter = PFor(loop_var, num_remaining_iterations, new_ops,
                       pfor_config=pfor_config)
      remaining_outputs = []
      flattened_loop_fn_outputs = nest.flatten(loop_fn_outputs)
      for loop_fn_output in flattened_loop_fn_outputs:
        remaining_outputs.append(converter.convert(loop_fn_output))

    with ops.name_scope("pfor_tiled"):
      loop_fn_dtypes = [ops.convert_to_tensor(x).dtype
                        for x in flattened_loop_fn_outputs]

      def tiled_loop_body(j):
        # Index of the first iteration of tile `j`; tiles start after the
        # remainder iterations.
        offset = j * parallel_iterations + num_remaining_iterations

        def tiled_loop_fn(i, pfor_config=None):
          if loop_fn_has_config:
            return nest.flatten(loop_fn(i + offset, pfor_config=pfor_config))
          else:
            return nest.flatten(loop_fn(i + offset))

        # Recursive call with `parallel_iterations` as the iteration count and
        # no tiling, so each tile is vectorized as a whole. This invokes
        # `loop_fn` again.
        return _pfor_impl(
            tiled_loop_fn, parallel_iterations, pfor_config=pfor_config)

      # Tiles are executed one after another.
      tiled_outputs = for_loop(tiled_loop_body, loop_fn_dtypes,
                               num_tiled_iterations, parallel_iterations=1)
      # Merge the tile dimension and the within-tile dimension into one.
      tiled_outputs = [_flatten_first_two_dims(y) for y in tiled_outputs]

    with ops.name_scope("pfor"):
      iters_value = tensor_util.constant_value(iters)
      if iters_value is None or iters_value % parallel_iterations:
        # A remainder may exist (or is only known at run time): prepend the
        # remainder outputs to the tiled outputs unless the remainder is empty.
        outputs = control_flow_ops.cond(
            math_ops.equal(num_remaining_iterations, 0),
            lambda: tiled_outputs,
            lambda: [array_ops.concat([x, y], axis=0)
                     for x, y in zip(remaining_outputs, tiled_outputs)])
      else:
        # Statically known to divide evenly, so there is no remainder.
        outputs = tiled_outputs
      return nest.pack_sequence_as(loop_fn_outputs, nest.flatten(outputs))


@tf_export("vectorized_map")
def vectorized_map(fn, elems):
  """Parallel map on the list of tensors unpacked from `elems` on dimension 0.

  This method works similar to tf.map_fn but is optimized to run much faster,
  possibly with a much larger memory footprint. The speedups are obtained by
  vectorization (see https://arxiv.org/pdf/1903.04243.pdf). The idea behind
  vectorization is to semantically launch all the invocations of `fn` in
  parallel and fuse corresponding operations across all these invocations. This
  fusion is done statically at graph generation time and the generated code is
  often similar in performance to a manually fused version.

  Because `tf.vectorized_map` fully parallelizes the batch, this method will
  generally be significantly faster than using `tf.map_fn`, especially in eager
  mode. However this is an experimental feature and currently has a lot of
  limitations:
    - There should be no data dependency between the different semantic
      invocations of `fn`, i.e. it should be safe to map the elements of the
      inputs in any order.
    - Stateful kernels may mostly not be supported since these often imply a
      data dependency. We do support a limited set of such stateful kernels
      though (like RandomFoo, Variable operations like reads, etc).
    - `fn` has limited support for control flow operations. `tf.cond` in
      particular is not supported.
    - `fn` should return nested structure of Tensors or Operations. However
      if an Operation is returned, it should have zero outputs.
    - The shape and dtype of any intermediate or output tensors in the
      computation of `fn` should not depend on the input to `fn`.

  Examples:
  ```python
  def outer_product(a):
    return tf.tensordot(a, a, 0)

  batch_size = 100
  a = tf.ones((batch_size, 32, 32))
  c = tf.vectorized_map(outer_product, a)
  assert c.shape == (batch_size, 32, 32, 32, 32)
  ```

  ```python
  # Computing per-example gradients

  batch_size = 10
  num_features = 32
  layer = tf.keras.layers.Dense(1)

  def model_fn(arg):
    with tf.GradientTape() as g:
      inp, label = arg
      inp = tf.expand_dims(inp, 0)
      label = tf.expand_dims(label, 0)
      prediction = layer(inp)
      loss = tf.nn.l2_loss(label - prediction)
    return g.gradient(loss, (layer.kernel, layer.bias))

  inputs = tf.random.uniform([batch_size, num_features])
  labels = tf.random.uniform([batch_size, 1])
  per_example_gradients = tf.vectorized_map(model_fn, (inputs, labels))
  assert per_example_gradients[0].shape == (batch_size, num_features, 1)
  assert per_example_gradients[1].shape == (batch_size, 1)
  ```

  Args:
    fn: The callable to be performed. It accepts one argument, which will have
      the same (possibly nested) structure as `elems`, and returns a possibly
      nested structure of Tensors and Operations, which may be different than
      the structure of `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be mapped over by `fn`.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
    results of applying fn to tensors unpacked from elems along the first
    dimension, from first to last.
  """
  def loop_fn(i):
    # Slice index `i` out of every tensor in `elems`, keeping the structure.
    gathered_elems = nest.map_structure(lambda x: array_ops.gather(x, i), elems)
    return fn(gathered_elems)

  # The batch size is the leading dimension of the first flattened element.
  # Use the static value when it is known, otherwise the dynamic shape.
  batch_size = None
  first_elem = ops.convert_to_tensor(nest.flatten(elems)[0])
  if first_elem.shape.rank is not None:
    batch_size = first_elem.shape.as_list()[0]
  if batch_size is None:
    batch_size = array_ops.shape(first_elem)[0]
  return pfor(loop_fn, batch_size)