Source code for tensorflow.python.ops.map_fn

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Functional operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["map_fn"])
def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
           swap_memory=False, infer_shape=True, name=None):
  """map on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `map_fn` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the
  tensors unpacked from `elems`. `dtype` is the data type of the return
  value of `fn`. Users must provide `dtype` if it is different from
  the data type of `elems`.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[values.shape[0]] + fn(values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The signature of `fn` may
  match the structure of `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Furthermore, `fn` may emit a different structure than its input. For
  example, `fn` may look like: `fn = lambda t1: (t1 + 1, t1 - 1)`. In this
  case, the `dtype` parameter is not optional: `dtype` must be a type or
  (possibly nested) tuple of types matching the output of `fn`.

  To apply a functional operation to the nonzero elements of a SparseTensor
  one of the following methods is recommended. First, if the function is
  expressible as TensorFlow ops, use

  ```python
    result = SparseTensor(input.indices, fn(input.values), input.dense_shape)
  ```

  If, however, the function is not expressible as a TensorFlow op, then use

  ```python
  result = SparseTensor(
      input.indices, map_fn(fn, input.values), input.dense_shape)
  ```

  instead.

  When executing eagerly, map_fn does not execute in parallel even if
  `parallel_iterations` is set to a value > 1. You can still get the
  performance benefits of running a function in parallel by using the
  `tf.function` decorator:

  ```python
  # Assume the function being used in map_fn is fn.
  # To ensure map_fn calls fn in parallel, use the tf.function decorator.
  @tf.function
  def func(tensor):
    return tf.map_fn(fn, tensor)
  ```

  Note that if you use the `tf.function` decorator, any non-TensorFlow Python
  code that you may have written in your function won't get executed. See
  [`tf.function`](https://www.tensorflow.org/api_docs/python/tf/function) for
  more details. The recommendation would be to debug without `tf.function`
  but switch to it to get performance benefits of running `map_fn` in
  parallel.

  Args:
    fn: The callable to be performed. It accepts one argument, which will have
      the same (possibly nested) structure as `elems`. Its output must have
      the same structure as `dtype` if one is provided, otherwise it must
      have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension. The nested sequence of
      the resulting slices will be applied to `fn`.
    dtype: (optional) The output type(s) of `fn`. If `fn` returns a structure
      of Tensors differing from the structure of `elems`, then `dtype` is not
      optional and must have the same structure as the output of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel. When graph building, the default value is 10. While
      executing eagerly, the default value is set to 1.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, from first to last.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `dtype` do not match, or if elems is a SparseTensor.
    ValueError: if the lengths of the output of `fn` and `dtype` do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    squares = map_fn(lambda x: x * x, elems)
    # squares == [1, 4, 9, 16, 25, 36]
    ```

    ```python
    elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
    # alternate == [-1, 2, -3]
    ```

    ```python
    elems = np.array([1, 2, 3])
    alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
    # alternates[0] == [1, 2, 3]
    # alternates[1] == [-1, -2, -3]
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  if isinstance(elems, sparse_tensor.SparseTensor):
    raise TypeError(
        "To perform a map on the values of a sparse tensor use either "
        " SparseTensor(input.indices, fn(input.values), input.dense_shape) or "
        " SparseTensor(input.indices, map_fn(fn, input.values), "
        "input.dense_shape)")

  in_graph_mode = not context.executing_eagerly()
  # Set the default number of parallel_iterations depending on graph/eager
  # mode.
  if in_graph_mode and not parallel_iterations:
    parallel_iterations = 10
  elif not in_graph_mode and not parallel_iterations:
    parallel_iterations = 1

  if not in_graph_mode and parallel_iterations > 1:
    logging.log_first_n(logging.WARN, "Setting parallel_iterations > 1 has no "
                        "effect when executing eagerly. Consider calling map_fn"
                        " with tf.function to execute fn in "
                        "parallel.", 1)
    parallel_iterations = 1

  input_is_sequence = nest.is_sequence(elems)
  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]

  def input_pack(x):
    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

  if dtype is None:
    output_is_sequence = input_is_sequence
    output_flatten = input_flatten
    output_pack = input_pack
  else:
    output_is_sequence = nest.is_sequence(dtype)
    output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]

    def output_pack(x):
      return (nest.pack_sequence_as(dtype, x)
              if output_is_sequence else x[0])

  elems_flat = input_flatten(elems)

  with ops.name_scope(name, "map", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in elems_flat]

    dtype = dtype or input_pack([elem.dtype for elem in elems_flat])
    dtype_flat = output_flatten(dtype)

    # Convert elems to tensor array. n may be known statically.
    static_shape = elems_flat[0].shape
    if static_shape.ndims is not None and static_shape.ndims < 1:
      if len(elems_flat) == 1:
        raise ValueError("elems must be a 1+ dimensional Tensor, not a scalar")
      else:
        raise ValueError(
            "elements in elems must be 1+ dimensional Tensors, not scalars")
    n = (tensor_shape.dimension_value(static_shape[0])
         or array_ops.shape(elems_flat[0])[0])

    # TensorArrays are always flat
    elems_ta = [
        tensor_array_ops.TensorArray(dtype=elem.dtype,
                                     size=n,
                                     dynamic_size=False,
                                     infer_shape=True) for elem in elems_flat]
    # Unpack elements
    elems_ta = [
        elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)]

    i = constant_op.constant(0)

    accs_ta = [
        tensor_array_ops.TensorArray(dtype=dt,
                                     size=n,
                                     dynamic_size=False,
                                     infer_shape=infer_shape)
        for dt in dtype_flat]

    def compute(i, tas):
      """The loop body of map_fn.

      Args:
        i: the loop counter
        tas: the flat TensorArray accumulator list

      Returns:
        (i + 1, tas): the updated counter + updated TensorArrays

      Raises:
        TypeError: if dtype and packed_fn_values structure do not match
        ValueError: if dtype and packed_fn_values lengths do not match
      """
      packed_values = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
      packed_fn_values = fn(packed_values)
      nest.assert_same_structure(dtype or elems, packed_fn_values)
      flat_fn_values = output_flatten(packed_fn_values)
      tas = [ta.write(i, value) for (ta, value)
             in zip(tas, flat_fn_values)]
      return (i + 1, tas)

    _, r_a = control_flow_ops.while_loop(
        lambda i, _: i < n, compute, (i, accs_ta),
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)
    results_flat = [r.stack() for r in r_a]

    n_static = tensor_shape.Dimension(tensor_shape.dimension_value(
        elems_flat[0].get_shape().with_rank_at_least(1)[0]))
    for elem in elems_flat[1:]:
      n_static.merge_with(tensor_shape.Dimension(tensor_shape.dimension_value(
          elem.get_shape().with_rank_at_least(1)[0])))
    for r in results_flat:
      r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
          r.get_shape()[1:]))

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return output_pack(results_flat)
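
# A minimal sketch (not part of the original module) of the mechanism map_fn
# implements above, for a single tensor input: unstack `elems` into a
# TensorArray, apply `fn` to one slice per while_loop iteration, and stack
# the accumulated results back into one tensor. The helper name
# `_map_fn_sketch` is hypothetical, and the sketch assumes `fn` preserves the
# input dtype (the real map_fn above handles arbitrary `dtype` structures).
def _map_fn_sketch(fn, elems):
  elems = ops.convert_to_tensor(elems, name="elems")
  n = array_ops.shape(elems)[0]
  # Unstack the input along dimension 0 so iteration i can read slice i.
  elems_ta = tensor_array_ops.TensorArray(
      dtype=elems.dtype, size=n, dynamic_size=False).unstack(elems)
  # Accumulator for the per-slice results of fn.
  acc_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n)

  def body(i, acc):
    # Read slice i, apply fn, and write the result at the same index.
    return i + 1, acc.write(i, fn(elems_ta.read(i)))

  _, acc_ta = control_flow_ops.while_loop(
      lambda i, _: i < n, body, (constant_op.constant(0), acc_ta))
  # Stack the n results back into a single [n, ...] tensor.
  return acc_ta.stack()
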
@tf_export("map_fn", v1=[]) @deprecation.deprecated_arg_values( None, """back_prop=False is deprecated. Consider using tf.stop_gradient instead. Instead of: results = tf.map_fn(fn, elems, back_prop=False) Use: results = tf.nest.map_structure(tf.stop_gradient, tf.map_fn(fn, elems))""", warn_once=True, back_prop=False) def map_fn_v2(fn, elems, dtype=None, parallel_iterations=None, back_prop=True, swap_memory=False, infer_shape=True, name=None): """map on the list of tensors unpacked from `elems` on dimension 0. The simplest version of `map_fn` repeatedly applies the callable `fn` to a sequence of elements from first to last. The elements are made of the tensors unpacked from `elems`. `dtype` is the data type of the return value of `fn`. Users must provide `dtype` if it is different from the data type of `elems`. Suppose that `elems` is unpacked into `values`, a list of tensors. The shape of the result tensor is `[values.shape[0]] + fn(values[0]).shape`. This method also allows multi-arity `elems` and output of `fn`. If `elems` is a (possibly nested) list or tuple of tensors, then each of these tensors must have a matching first (unpack) dimension. The signature of `fn` may match the structure of `elems`. That is, if `elems` is `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is: `fn = lambda (t1, [t2, t3, [t4, t5]]):`. Furthermore, `fn` may emit a different structure than its input. For example, `fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`. In this case, the `dtype` parameter is not optional: `dtype` must be a type or (possibly nested) tuple of types matching the output of `fn`. To apply a functional operation to the nonzero elements of a SparseTensor one of the following methods is recommended. First, if the function is expressible as TensorFlow ops, use ```python result = SparseTensor(input.indices, fn(input.values), input.dense_shape) ``` If, however, the function is not expressible as a TensorFlow op, then use ```python result = SparseTensor( input.indices, map_fn(fn, input.values), input.dense_shape) ``` instead. When executing eagerly, map_fn does not execute in parallel even if `parallel_iterations` is set to a value > 1. You can still get the performance benefits of running a function in parallel by using the `tf.function` decorator, ```python # Assume the function being used in map_fn is fn. # To ensure map_fn calls fn in parallel, use the tf.function decorator. @tf.function def func(tensor): return tf.map_fn(fn, tensor) ``` Note that if you use the `tf.function` decorator, any non-TensorFlow Python code that you may have written in your function won't get executed. See [`tf.function`](https://www.tensorflow.org/api_docs/python/tf/function) for more details. The recommendation would be to debug without `tf.function` but switch to it to get performance benefits of running `map_fn` in parallel. Args: fn: The callable to be performed. It accepts one argument, which will have the same (possibly nested) structure as `elems`. Its output must have the same structure as `dtype` if one is provided, otherwise it must have the same structure as `elems`. elems: A tensor or (possibly nested) sequence of tensors, each of which will be unpacked along their first dimension. The nested sequence of the resulting slices will be applied to `fn`. dtype: (optional) The output type(s) of `fn`. If `fn` returns a structure of Tensors differing from the structure of `elems`, then `dtype` is not optional and must have the same structure as the output of `fn`. 
parallel_iterations: (optional) The number of iterations allowed to run in parallel. When graph building, the default value is 10. While executing eagerly, the default value is set to 1. back_prop: (optional) Deprecated. False disables support for back propagation. Prefer using `tf.stop_gradient` instead. swap_memory: (optional) True enables GPU-CPU memory swapping. infer_shape: (optional) False disables tests for consistent output shapes. name: (optional) Name prefix for the returned tensors. Returns: A tensor or (possibly nested) sequence of tensors. Each tensor packs the results of applying `fn` to tensors unpacked from `elems` along the first dimension, from first to last. Raises: TypeError: if `fn` is not callable or the structure of the output of `fn` and `dtype` do not match, or if elems is a SparseTensor. ValueError: if the lengths of the output of `fn` and `dtype` do not match. Examples: ```python elems = np.array([1, 2, 3, 4, 5, 6]) squares = map_fn(lambda x: x * x, elems) # squares == [1, 4, 9, 16, 25, 36] ``` ```python elems = (np.array([1, 2, 3]), np.array([-1, 1, -1])) alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64) # alternate == [-1, 2, -3] ``` ```python elems = np.array([1, 2, 3]) alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64)) # alternates[0] == [1, 2, 3] # alternates[1] == [-1, -2, -3] ``` """ return map_fn( fn=fn, elems=elems, dtype=dtype, parallel_iterations=parallel_iterations, back_prop=back_prop, swap_memory=swap_memory, infer_shape=infer_shape, name=name)
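
# Usage sketch (not part of the original module), assuming a TF 2.x-style
# `tf` namespace where map_fn_v2 is exported as `tf.map_fn`. It shows the
# `tf.stop_gradient` pattern that the deprecation notice above recommends
# in place of back_prop=False.
#
#   import tensorflow as tf
#
#   elems = tf.constant([1.0, 2.0, 3.0])
#   # Instead of: tf.map_fn(tf.square, elems, back_prop=False)
#   results = tf.nest.map_structure(
#       tf.stop_gradient, tf.map_fn(tf.square, elems))
#   # results == [1.0, 4.0, 9.0], with no gradient flowing back to elems.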