# Source code for tensorflow.python.ops.tensor_array_ops

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorArray: a dynamically sized array of Tensors."""
# Mixture of pep8 and non-pep8 names, so disable pylint bad-name
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

import numpy as np
import traceback
import weakref

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export


# _GraphTensorArray accesses many of the hidden generated ops, but is in
# fact built to wrap these methods.
# pylint: disable=protected-access
class _GraphTensorArray(object):
  """Graph-mode implementation of TensorArray.
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a graph mode TensorArray.

    Either wraps an existing TensorArray (when `handle` and `flow` are given)
    or creates a new `tensor_array_v3` op (when `size` is given).

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: (optional) Python string: the name of the TensorArray.
        This is used when creating the TensorArray handle.  If this value is
        set, handle should be None.
      handle: (optional) A `Tensor` handle to an existing TensorArray.  If this
        is set, tensor_array_name should be None. Only supported in graph mode.
      flow: (optional) A float `Tensor` scalar coming from an existing
        `TensorArray.flow`. Only supported in graph mode.  Required when
        handle is provided.
      infer_shape: (optional, default: True) If True, shape inference
        is enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray.
        Need not be fully defined.
      colocate_with_first_write_call: If `True`, the TensorArray will be
        colocated on the same device as the Tensor used on its first write
        (write operations include `write`, `unstack`, and `split`).  If `False`,
        the TensorArray will be placed on the device determined by the
        device context available during its initialization.
      name: A name for the operation (optional).

    Raises:
      ValueError: if both handle and tensor_array_name are provided; if
        neither handle nor size is provided; if handle is combined with any
        of size, element_shape, dynamic_size or clear_after_read; or if
        handle is provided without flow.
      TypeError: if handle is provided but is not a Tensor.
    """
    # Argument validation.  The order of these checks determines which error
    # a caller sees when several constraints are violated at once.
    if handle is not None and tensor_array_name:
      raise ValueError(
          "Cannot construct with both handle and tensor_array_name")
    if handle is not None and not isinstance(handle, ops.Tensor):
      raise TypeError("Handle must be a Tensor")
    if handle is None and size is None:
      raise ValueError("Size must be provided if handle is not provided")
    if handle is not None and size is not None:
      raise ValueError("Cannot provide both a handle and size "
                       "at the same time")
    if handle is not None and element_shape is not None:
      raise ValueError("Cannot provide both a handle and element_shape "
                       "at the same time")
    if handle is not None and dynamic_size is not None:
      raise ValueError("Cannot provide both a handle and dynamic_size "
                       "at the same time")
    if handle is not None and clear_after_read is not None:
      raise ValueError("Cannot provide both a handle and clear_after_read "
                       "at the same time")

    # Defaults are applied only after validation so that "not provided"
    # (None) can be told apart from an explicit value above.
    if clear_after_read is None:
      clear_after_read = True
    self._dynamic_size = dynamic_size or False
    self._dtype = dtypes.as_dtype(dtype).base_dtype

    # Used to keep track of what tensors the TensorArray should be
    # colocated with.  We choose to colocate the TensorArray with the
    # first tensor written to it.
    self._colocate_with_first_write_call = colocate_with_first_write_call
    if colocate_with_first_write_call:
      self._colocate_with = []
    else:
      self._colocate_with = None

    # Record the current static shape for the array elements. The element
    # shape is defined either by `element_shape` or the shape of the tensor
    # of the first write. If `infer_shape` is true, all writes checks for
    # shape equality.
    # The shape is kept in a one-element list so it can be updated in place
    # (see `_check_element_shape`) and shared by reference (see `grad`).
    self._element_shape = [tensor_shape.as_shape(element_shape)]
    self._infer_shape = infer_shape
    self._size = size
    with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope:
      if handle is not None:
        # Wrap an existing TensorArray; a flow is mandatory in this case.
        self._handle = handle
        if flow is None:
          raise ValueError("flow must not be None if handle is not None.")
        self._flow = flow
      else:
        # Construct the TensorArray with an empty device.  The first
        # write into the TensorArray from a Tensor with a set device
        # will retroactively set the device value of this op.
        def create():
          """Create the TensorArray op."""
          return gen_data_flow_ops.tensor_array_v3(
              dtype=dtype,
              size=size,
              element_shape=element_shape,
              identical_element_shapes=infer_shape,
              dynamic_size=self._dynamic_size,
              clear_after_read=clear_after_read,
              tensor_array_name=tensor_array_name,
              name=scope)
        if colocate_with_first_write_call:
          # Clear any device/colocation context so the op starts unplaced.
          with ops.device(None), ops.colocate_with(None, ignore_existing=True):
            self._handle, self._flow = create()
        else:
          self._handle, self._flow = create()

  @property
  def flow(self):
    """The flow value stored for this TensorArray (`self._flow`)."""
    return self._flow

  @property
  def dtype(self):
    """The element dtype, stored as a base dtype by the constructor."""
    return self._dtype

  @property
  def handle(self):
    """The TensorArray handle, either passed in or created by the constructor."""
    return self._handle

  @property
  def element_shape(self):
    """The current element shape: the single entry of `self._element_shape`."""
    return self._element_shape[0]

  def _check_element_shape(self, shape):
    """Validates `shape` against the element shape and optionally merges it.

    When shape inference is enabled, the recorded element shape is replaced
    in place by the result of merging it with `shape`.

    Args:
      shape: A `TensorShape` object to check and merge.

    Raises:
      ValueError: if `shape` is not compatible with the current element shape
          of the `TensorArray`.
    """
    known_shape = self.element_shape
    if not shape.is_compatible_with(known_shape):
      raise ValueError(
          "Inconsistent shapes: saw %s but expected %s " % (shape, known_shape))
    if not self._infer_shape:
      return
    # Update the list slot in place so that every holder of the list sees the
    # refined shape.
    self._element_shape[0] = known_shape.merge_with(shape)

  @contextlib.contextmanager
  def _maybe_colocate_with(self, value):
    """Context manager that colocates ops with the first written value.

    If colocation with the first write is disabled, this is a plain
    pass-through context.  Otherwise, if no colocation target has been
    recorded yet, `value` is recorded as the target; the body then runs
    under `ops.colocate_with` for the recorded target.

    Args:
      value: `Tensor`, the tensor to try to colocate with.

    Yields:
      Nothing; the context itself carries the colocation constraint.
    """
    if self._colocate_with_first_write_call:
      targets = self._colocate_with
      if not targets:
        # First write: remember this tensor as the colocation anchor.
        targets.append(value)
      with ops.colocate_with(targets[0]):
        yield
    else:
      yield

  def identity(self):
    """See TensorArray.

    Passes the current flow through an identity op and rebuilds the array
    around the resulting flow.
    """
    return build_ta_with_new_flow(self, array_ops.identity(self._flow))

  def grad(self, source, flow=None, name=None):
    """See TensorArray.

    Creates the gradient TensorArray op for `source` and wraps it in a new
    `TensorArray` whose element-shape record is shared with this array.
    """
    # tensor_array_grad requires a flow input when forward TensorArrays are
    # dynamically sized: the gradient array must only be created once the
    # final forward array's size is fixed.  Fall back to our own flow.
    flow_in = self.flow if flow is None else flow
    with ops.name_scope(name, "TensorArrayGrad", [self._handle]):
      with ops.colocate_with(self._handle):
        grad_handle, _ = gen_data_flow_ops.tensor_array_grad_v3(
            handle=self._handle, source=source, flow_in=flow_in, name=name)
        # Make the returned flow depend on the creation of the grad handle.
        with ops.control_dependencies([grad_handle]):
          grad_flow = array_ops.identity(flow_in, name="gradient_flow")
        grad_array = TensorArray(
            dtype=self._dtype,
            handle=grad_handle,
            flow=grad_flow,
            infer_shape=self._infer_shape,
            colocate_with_first_write_call=False)
        # Share (not copy) the element-shape list with the gradient array.
        # pylint: disable=protected-access
        grad_array._implementation._element_shape = self._element_shape
        # pylint: enable=protected-access
        return grad_array

  def read(self, index, name=None):
    """See TensorArray.

    Emits a read op for `index` and stamps the recorded element shape onto
    the result.
    """
    element = gen_data_flow_ops.tensor_array_read_v3(
        handle=self._handle,
        index=index,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name)
    if not self._element_shape:
      return element
    element.set_shape(self._element_shape[0].dims)
    return element

  def write(self, index, value, name=None):
    """See TensorArray.

    Converts `value` to a tensor, checks its dtype and shape against this
    array, emits the write op (colocated per `_maybe_colocate_with`), and
    returns a TensorArray built around the resulting flow.
    """
    with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      tensor = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(tensor, self._dtype)
      self._check_element_shape(tensor.shape)
      with self._maybe_colocate_with(tensor):
        new_flow = gen_data_flow_ops.tensor_array_write_v3(
            handle=self._handle,
            index=index,
            value=tensor,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, new_flow)

  def stack(self, name=None):
    """See TensorArray.

    Gathers indices `0 .. size() - 1`.  When the array is not dynamically
    sized and a size was recorded, the static shape of the result is set from
    the constant value of that size plus the element shape.
    """
    with ops.colocate_with(self._handle):
      with ops.name_scope(name, "TensorArrayStack", [self._handle]):
        every_index = math_ops.range(0, self.size())
        stacked = self.gather(every_index, name=name)
        fixed_size = not self._dynamic_size and self._size is not None
        if self.element_shape and fixed_size:
          leading_dim = tensor_util.constant_value(self._size)
          stacked.set_shape([leading_dim] + self.element_shape.dims)
        return stacked

  def gather(self, indices, name=None):
    """See TensorArray.

    Emits a gather op for `indices`, passing the recorded element shape (or
    an unknown shape if none is recorded) to the op, and sets the static
    shape of the result to `[None] + element_shape` when that shape is known.
    """
    op_element_shape = (
        self._element_shape[0]
        if self._element_shape else tensor_shape.unknown_shape(None))
    gathered = gen_data_flow_ops.tensor_array_gather_v3(
        handle=self._handle,
        indices=indices,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name,
        element_shape=op_element_shape)
    static_shape = self.element_shape
    if static_shape:
      gathered.set_shape([None] + static_shape.dims)
    return gathered

  def concat(self, name=None):
    """See TensorArray.

    Emits a concat op and returns its first output.  The first dimension of
    the element shape is dropped when describing the elements to the op and
    is left as `None` in the static shape of the result.
    """
    elem_shape = self.element_shape
    concatenated, _ = gen_data_flow_ops.tensor_array_concat_v3(
        handle=self._handle,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name,
        element_shape_except0=elem_shape[1:])
    if elem_shape:
      concatenated.set_shape([None] + elem_shape.dims[1:])
    return concatenated

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """See TensorArray.

    Implemented as a scatter of `value` to indices `0 .. shape(value)[0] - 1`.
    """
    with ops.name_scope(name, "TensorArrayUnstack", [self._handle, value]):
      leading_dim = array_ops.shape(value)[0]
      row_indices = math_ops.range(0, leading_dim)
      return self.scatter(indices=row_indices, value=value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """See TensorArray.

    Converts `value` to a tensor, checks its dtype, checks the per-element
    shape (`value.shape[1:]`) when not executing eagerly, emits the scatter
    op, and returns a TensorArray built around the resulting flow.
    """
    with ops.name_scope(name, "TensorArrayScatter",
                        [self._handle, value, indices]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      tensor = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(tensor, self._dtype)
      graph_mode = not context.executing_eagerly()
      if graph_mode:
        self._check_element_shape(tensor.shape[1:])
      with self._maybe_colocate_with(tensor):
        new_flow = gen_data_flow_ops.tensor_array_scatter_v3(
            handle=self._handle,
            indices=indices,
            value=tensor,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, new_flow)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """See TensorArray.

    Converts `value` to this array's dtype and emits the split op with
    `lengths` cast to int64.  When not executing eagerly, and `lengths` is a
    statically known, non-scalar array whose entries are all equal, the
    implied element shape is checked against the recorded one.
    """
    with ops.name_scope(name, "TensorArraySplit",
                        [self._handle, value, lengths]):
      tensor = ops.convert_to_tensor(value, dtype=self._dtype, name="value")
      with self._maybe_colocate_with(tensor):
        lengths_i64 = math_ops.cast(lengths, dtypes.int64)
        if not context.executing_eagerly():
          static_lengths = tensor_util.constant_value(lengths_i64)
          # Only uniform, statically known lengths pin down a single
          # per-element shape.
          if (tensor.shape.dims is not None and static_lengths is not None and
              static_lengths.shape and
              static_lengths.max() == static_lengths.min()):
            piece_shape = tensor_shape.TensorShape(
                [static_lengths[0]]).concatenate(tensor.shape[1:])
            self._check_element_shape(piece_shape)
        new_flow = gen_data_flow_ops.tensor_array_split_v3(
            handle=self._handle,
            value=tensor,
            lengths=lengths_i64,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, new_flow)

  def size(self, name=None):
    """See TensorArray.

    Returns the recorded size as an int32 tensor when the array is not
    dynamically sized and a size was recorded; otherwise emits a size op.
    """
    if self._dynamic_size or self._size is None:
      return gen_data_flow_ops.tensor_array_size_v3(
          handle=self._handle, flow_in=self.flow, name=name)
    return ops.convert_to_tensor(self._size, dtype=dtypes.int32)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """See TensorArray.

    Emits a `tensor_array_close_v3` op on this array's handle and returns the
    result of that call.
    """
    return gen_data_flow_ops.tensor_array_close_v3(
        handle=self._handle, name=name)


class _GraphTensorArrayV2(object):
  """Graph-mode implementation of TensorArray backed by TensorLists.

  The backing tensor of this TensorArray is a TensorList variant tensor which is
  stored in the `flow`. The `handle` is always none here. The reason we use the
  `flow` field and not the `handle` field is to ensure backwards compatibility
  with legacy control flow.
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a graph mode TensorArray.

    Either wraps an existing TensorList variant tensor (when `flow` is given)
    or reserves a new TensorList of `size` elements.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if flow is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: (optional) unused. Not supported in TensorLists.
      tensor_array_name: (optional) unused.
      handle: (optional) Must always be None.
      flow: (optional) A variant `Tensor` scalar for a TensorList.
      infer_shape: (optional, default: True) If True, shape inference is
        enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray. Need
        not be fully defined.
      colocate_with_first_write_call: (optional). unused.
      name: (optional) A name for the operation.

    Raises:
      AssertionError: if handle is not None (checked with `assert`).
      TypeError: if flow is provided but is not a variant `Tensor`.
      ValueError: if neither flow nor size is provided, or if flow is
        provided together with size or element_shape.
    """
    assert handle is None
    # These arguments exist only for signature compatibility.
    del handle
    del clear_after_read
    del tensor_array_name
    del colocate_with_first_write_call

    # NOTE: unlike the other implementations in this file, `dynamic_size` is
    # stored as given (it may be None).
    self._dynamic_size = dynamic_size
    self._size = size

    if (flow is not None and
        (not isinstance(flow, ops.Tensor) or flow.dtype != dtypes.variant)):
      raise TypeError("flow must be a variant tensor")
    if flow is None and size is None:
      raise ValueError("Size must be provided if flow is not provided")
    if flow is not None and size is not None:
      raise ValueError("Cannot provide both a flow and size "
                       "at the same time")
    if flow is not None and element_shape is not None:
      raise ValueError("Cannot provide both a flow and element_shape "
                       "at the same time")

    self._dtype = dtypes.as_dtype(dtype).base_dtype

    # Record the current static shape for the array elements. The element
    # shape is defined either by `element_shape` or the shape of the tensor
    # of the first write. If `infer_shape` is true, all writes checks for
    # shape equality.
    # Kept in a one-element list so it can be updated in place.
    self._element_shape = [tensor_shape.as_shape(element_shape)]
    self._infer_shape = infer_shape
    with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope:
      if flow is None:
        # No existing list: reserve a fresh TensorList of `size` elements.
        self._flow = list_ops.tensor_list_reserve(
            element_shape=element_shape,
            num_elements=size,
            element_dtype=dtype,
            name=scope)
      else:
        self._flow = flow

    # For backwards compatibility.
    self._colocate_with_first_write_call = None
    self._colocate_with = None

  @property
  def flow(self):
    """The TensorList tensor backing this array (`self._flow`)."""
    return self._flow

  @property
  def dtype(self):
    """The element dtype, stored as a base dtype by the constructor."""
    return self._dtype

  @property
  def element_shape(self):
    """The current element shape: the single entry of `self._element_shape`."""
    return self._element_shape[0]

  @property
  def handle(self):
    """Always `None`; this implementation keeps its state in `flow`."""
    # We intentionally do not raise an error so that legacy while_loop does not
    # complain.
    return None

  def _check_element_shape(self, shape):
    """Validates `shape` against the element shape and optionally merges it.

    When shape inference is enabled, the recorded element shape is replaced
    in place by the result of merging it with `shape`.

    Args:
      shape: A `TensorShape` object to check and merge.

    Raises:
      ValueError: if `shape` is not compatible with the current element shape
          of the `TensorArray`.
    """
    known_shape = self.element_shape
    if not shape.is_compatible_with(known_shape):
      raise ValueError(
          "Inconsistent shapes: saw %s but expected %s " % (shape, known_shape))
    if not self._infer_shape:
      return
    # Update the list slot in place so the refined shape is visible to every
    # holder of the list.
    self._element_shape[0] = known_shape.merge_with(shape)

  def identity(self):
    """See TensorArray.

    Passes the current flow through an identity op and rebuilds the array
    around the resulting flow.
    """
    return build_ta_with_new_flow(self, array_ops.identity(self._flow))

  def grad(self, source, flow=None, name=None):
    """Not supported.

    Raises:
      NotImplementedError: always; the arguments are ignored.
    """
    raise NotImplementedError()

  def read(self, index, name=None):
    """See TensorArray.

    Fetches the item at `index` from the backing TensorList, passing the
    recorded element dtype and shape to the op.
    """
    with ops.name_scope(name, "TensorArrayV2Read", [self._flow, index]):
      return list_ops.tensor_list_get_item(
          input_handle=self._flow,
          index=index,
          element_dtype=self._dtype,
          element_shape=self.element_shape,
          name=name)

  def write(self, index, value, name=None):
    """See TensorArray.

    Converts `value` to a tensor, checks its dtype and shape against this
    array, sets the item at `index` in the backing TensorList, and returns a
    TensorArray built around the resulting list.
    """
    with ops.name_scope(name, "TensorArrayV2Write", [self._flow, index, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      item = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(item, self._dtype)
      self._check_element_shape(item.shape)
      updated_list = list_ops.tensor_list_set_item(
          input_handle=self._flow,
          index=index,
          item=item,
          resize_if_index_out_of_bounds=self._dynamic_size,
          name=name)
      return build_ta_with_new_flow(self, updated_list)

  def stack(self, name=None):
    """See TensorArray.

    Stacks the backing TensorList.  `num_elements` is the constant value of
    the recorded size when the array is not dynamically sized and a size was
    recorded, and -1 otherwise.
    """
    with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]):
      # TODO(b/139941163): remove constant_value after changing num_elements to regular input
      fixed_size = not self._dynamic_size and self._size is not None
      num_elements = (
          tensor_util.constant_value(self._size) if fixed_size else -1)
      return list_ops.tensor_list_stack(
          input_handle=self._flow,
          element_dtype=self._dtype,
          num_elements=num_elements,
          element_shape=self.element_shape)

  def gather(self, indices, name=None):
    """See TensorArray.

    Gathers the items at `indices` from the backing TensorList, passing the
    recorded element dtype and shape to the op.
    """
    return list_ops.tensor_list_gather(
        input_handle=self._flow,
        indices=indices,
        element_dtype=self._dtype,
        element_shape=self.element_shape,
        name=name)

  def concat(self, name=None):
    """See TensorArray.

    Concatenates the backing TensorList.  When the element shape is known,
    the op is told the element shape with its first dimension replaced by
    `None`; otherwise no element shape is passed.
    """
    known_shape = self.element_shape
    op_element_shape = [None] + known_shape.dims[1:] if known_shape else None
    return list_ops.tensor_list_concat(
        input_handle=self._flow,
        element_dtype=self._dtype,
        element_shape=op_element_shape,
        name=name)

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """See TensorArray.

    Converts `value` to a tensor, checks its dtype and per-element shape
    (`value.shape[1:]`), builds a new TensorList from it, and returns a
    TensorArray built around that list.
    """
    with ops.name_scope(name, "TensorArrayUnstack", [self._flow, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      tensor = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(tensor, self._dtype)
      self._check_element_shape(tensor.shape[1:])
      new_list = list_ops.tensor_list_from_tensor(
          tensor=tensor, element_shape=tensor.shape[1:])
      return build_ta_with_new_flow(self, new_list)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """See TensorArray.

    Converts `value` to a tensor, checks its dtype and per-element shape
    (`value.shape[1:]`), scatters it into the backing TensorList at
    `indices`, and returns a TensorArray built around the resulting list.
    """
    with ops.name_scope(name, "TensorArrayScatter",
                        [self._flow, value, indices]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      tensor = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(tensor, self._dtype)
      self._check_element_shape(tensor.shape[1:])
      updated_list = list_ops.tensor_list_scatter(
          tensor=tensor,
          indices=indices,
          element_shape=self.element_shape,
          input_handle=self._flow)
      return build_ta_with_new_flow(self, updated_list)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """See TensorArray.

    Converts `value` to a tensor, checks its dtype, and splits it into a new
    TensorList using `lengths` cast to int64.  When not executing eagerly,
    and `lengths` is a statically known, non-scalar array whose entries are
    all equal, the implied element shape is checked against the recorded one.
    """
    with ops.name_scope(name, "TensorArraySplit", [self._flow, value, lengths]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      tensor = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(tensor, self._dtype)
      lengths_i64 = math_ops.cast(lengths, dtypes.int64)
      if not context.executing_eagerly():
        static_lengths = tensor_util.constant_value(lengths_i64)
        # Only uniform, statically known lengths pin down a single
        # per-element shape.
        if (tensor.shape.dims is not None and static_lengths is not None and
            static_lengths.shape and
            static_lengths.max() == static_lengths.min()):
          piece_shape = tensor_shape.TensorShape(
              [static_lengths[0]]).concatenate(tensor.shape[1:])
          self._check_element_shape(piece_shape)
      new_list = list_ops.tensor_list_split(
          tensor=tensor,
          lengths=lengths_i64,
          element_shape=self.element_shape,
          name=name)
      return build_ta_with_new_flow(self, new_list)

  def size(self, name=None):
    """See TensorArray.

    Returns the recorded size as an int32 tensor when the array is not
    dynamically sized and a size was recorded; otherwise emits a list-length
    op on the backing TensorList.
    """
    if self._dynamic_size or self._size is None:
      return list_ops.tensor_list_length(input_handle=self._flow, name=name)
    return ops.convert_to_tensor(self._size, dtype=dtypes.int32)

  def close(self, name=None):
    """See TensorArray.

    Returns a no-op created with the given `name`; the backing list itself is
    not touched here.
    """
    return gen_control_flow_ops.no_op(name=name)

# pylint: enable=protected-access


class _EagerTensorArray(object):
  """Eager-compatible implementation of TensorArray.
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a TensorArray compatible with eager execution.

    The array is backed by a plain Python list with `size` slots, each
    initially `None`.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (required here) Python int or eager int `Tensor`: the initial
        number of slots.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: unused; discarded.
      handle: unsupported; must be None.
      flow: unused; discarded.
      infer_shape: used for error checking, same semantics as TensorArray.
      element_shape: used for error checking, same semantics as TensorArray.
      colocate_with_first_write_call: stored for compatibility only.
      name: unused; discarded.

    Raises:
      ValueError: if handle is supplied, or if size is not supplied.
    """
    del (flow, tensor_array_name, name)  # Unused.

    if handle is not None:
      raise ValueError("TensorArray handles are not supported when eager "
                       "execution is enabled.")
    if size is None:
      raise ValueError("Size must be declared for TensorArrays when eager "
                       "execution is enabled.")

    # Handle, flow and colocation settings carry no meaning under eager
    # execution, but some library functions (e.g., those in
    # control_flow_ops.py) read them when creating new tensor arrays, so they
    # are defined for compatibility.  The flow is a dummy tensor in case
    # other code assumes it to be a Tensor.
    self._handle = None
    self._flow = constant_op.constant(0, dtype=dtypes.int32)
    self._infer_shape = infer_shape
    self._element_shape = tensor_shape.as_shape(element_shape)
    self._colocate_with_first_write_call = colocate_with_first_write_call

    self._dtype = dtypes.as_dtype(dtype).base_dtype
    self._dynamic_size = dynamic_size or False
    if clear_after_read is None:
      self._clear_after_read = True
    else:
      self._clear_after_read = clear_after_read
    self._previously_read_indices = []

    # An eager tensor size is unwrapped to its numpy value before use.
    num_slots = size.numpy() if isinstance(size, ops.EagerTensor) else size
    self._tensor_array = [None for _ in range(num_slots)]

  @property
  def flow(self):
    """For compatibility; flows are not meaningful when eager is enabled.

    Returns the dummy int32 constant created by the constructor.
    """
    return self._flow

  @property
  def dtype(self):
    """The element dtype, stored as a base dtype by the constructor."""
    return self._dtype

  @property
  def handle(self):
    """For compatibility; handles are not meaningful when eager is enabled.

    The constructor sets the stored handle to `None`.
    """
    return self._handle

  @property
  def element_shape(self):
    """The current element shape (stored directly, not inside a list)."""
    return self._element_shape

  def identity(self):
    """See TensorArray.

    Returns whatever `self.parent()` produces.  `parent` is not assigned in
    this part of the file; presumably it is a weak reference to the owning
    `TensorArray` set by the wrapper (TODO confirm).
    """
    return self.parent()

  def grad(self, source, flow=None, name=None):
    """Not supported when executing eagerly.

    Raises:
      NotImplementedError: always; the arguments are ignored.
    """
    raise NotImplementedError(
        "TensorArray.grad is not supported when executing eagerly; eager's "
        "gradient implementation does not use/need this function to compute "
        "gradients of operations that use TensorArrays.")

  def read(self, index, name=None):
    """See TensorArray.

    Returns the tensor stored at `index`.  A slot that was never written
    reads as zeros; a slot that was already read and cleared raises.  When
    `clear_after_read` is set, the slot is emptied and the index recorded.

    Raises:
      errors_impl.OutOfRangeError: `index` is negative or past the end.
      errors_impl.InvalidArgumentError: `index` was cleared by an earlier
        read.
    """
    del name  # not meaningful when executing eagerly.

    if isinstance(index, ops.EagerTensor):
      index = index.numpy()

    if index < 0:
      raise errors_impl.OutOfRangeError(
          None, None,
          "Reading from negative indices (index %d) is not allowed." % index)

    array_len = len(self._tensor_array)
    if index >= array_len:
      raise errors_impl.OutOfRangeError(
          None, None, "Tried to read from index %d but array size is: %d" %
          (index, array_len))

    element = self._tensor_array[index]
    if element is None:
      # An empty slot is either "never written" (reads as zeros) or "cleared
      # by a previous read" (an error).
      if index in self._previously_read_indices:
        raise errors_impl.InvalidArgumentError(
            None, None,
            "Could not read index %d twice because it was cleared after "
            "a previous read (perhaps try setting clear_after_read = false?)" %
            index)
      element = self._maybe_zero(index)

    if self._clear_after_read:
      self._tensor_array[index] = None
      self._previously_read_indices.append(index)
    return element

  def _write(self, index, value):
    """Stores `value` in the slot at `index`, growing the array if allowed.

    Args:
      index: 0-D.  int32 scalar with the index to write to.
      value: N-D.  Tensor of type `dtype`.  The `Tensor` to write to `index`.

    Raises:
      errors_impl.InvalidArgumentError: `value` dtype does not match dtype.
      errors_impl.OutOfRangeError: `index` is out of bounds.
      ValueError: shape of `value` is not consistent with inferred shape.
    """
    if isinstance(index, ops.EagerTensor):
      index = index.numpy()

    if index < 0:
      raise errors_impl.OutOfRangeError(
          None, None,
          "Writing to negative indices (index %d) is not allowed." % index)

    current_size = len(self._tensor_array)
    if index >= current_size:
      if not self._dynamic_size:
        raise errors_impl.OutOfRangeError(
            None, None,
            "Tried to write to index %d but array is not resizeable and size "
            "is: %d" % (index, current_size))
      # Pad with empty slots up to and including `index`.
      missing = index - current_size + 1
      self._tensor_array.extend(None for _ in range(missing))

    if not isinstance(value, ops.EagerTensor):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")

    value_dtype = value.dtype
    if self._dtype != value_dtype:
      raise errors_impl.InvalidArgumentError(
          None, None,
          "TensorArray dtype is %s but Op is trying to write dtype %s" %
          (self._dtype.name, value_dtype.name))

    value_shape = value.shape
    if not self._element_shape.is_compatible_with(value_shape):
      raise ValueError("Incompatible shape for value (%s), expected (%s)" %
                       (value_shape, self._element_shape))

    if self._infer_shape:
      self._element_shape = self._element_shape.merge_with(value_shape)

    self._tensor_array[index] = value

  def write(self, index, value, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    self._write(index, value)
    return self.parent()

  def _maybe_zero(self, ix):
    """Returns the element at `ix`, filling an unset (`None`) slot with zeros.

    A slot holding `None` is replaced in place by a zeros tensor built from
    the array's element shape and dtype; that tensor is then returned.
    """
    current = self._tensor_array[ix]
    if current is not None:
      return current
    filler = array_ops.zeros(shape=self._element_shape, dtype=self._dtype)
    self._tensor_array[ix] = filler
    return filler

  def stack(self, name=None):
    """Converts the stored elements into a single tensor.

    Unset slots are zero-filled first.  When the array holds no elements and
    the element shape is fully defined, a tensor with a leading dimension of
    zero is produced instead.

    Args:
      name: Name forwarded to `ops.convert_to_tensor`.

    Returns:
      The result of `ops.convert_to_tensor` with this array's dtype.
    """
    elements = self._tensor_array
    # Iterating an empty list is a no-op, so no emptiness guard is needed.
    for position in range(len(elements)):
      self._maybe_zero(position)
    if elements or not self._element_shape.is_fully_defined():
      return ops.convert_to_tensor(elements, name=name, dtype=self._dtype)
    return ops.convert_to_tensor(
        np.ndarray([0] + self._element_shape), name=name, dtype=self._dtype)

  def gather(self, indices, name=None):
    """Stacks the elements found at `indices` into one tensor.

    Args:
      indices: Iterable of positions; an `EagerTensor` is first converted with
        `.numpy()`.
      name: Ignored.

    Returns:
      The result of `array_ops.stack` over the selected elements, with unset
      slots zero-filled.
    """
    # Operation names carry no meaning when executing eagerly.
    del name
    if isinstance(indices, ops.EagerTensor):
      indices = indices.numpy()
    selected = []
    for position in indices:
      selected.append(self._maybe_zero(position))
    return array_ops.stack(selected)

  def concat(self, name=None):
    """Concatenates all stored elements along axis 0.

    Args:
      name: Name forwarded to `array_ops.concat`.

    Returns:
      The result of `array_ops.concat` over all elements, with unset slots
      zero-filled.

    Raises:
      errors_impl.InvalidArgumentError: if the concat op fails and some
        element has rank 0.
      errors_impl.OpError: re-raised unchanged when no element has rank 0.
    """
    try:
      pieces = [
          self._maybe_zero(position)
          for position in range(len(self._tensor_array))
      ]
      return array_ops.concat(pieces, 0, name=name)
    except errors_impl.OpError:
      # Reproduce a subset of the error-handling for graph-mode TensorArrays.
      ranks = [element.shape.ndims for element in self._tensor_array]
      if 0 not in ranks:
        raise
      scalar_at = ranks.index(0)
      raise errors_impl.InvalidArgumentError(
          None, None, "Concat saw a scalar shape at index %d but requires "
          "at least vectors." % scalar_at)

  def unstack(self, value, name=None):
    """Replaces the stored elements with the unstacked slices of `value`.

    Args:
      value: Tensor to unstack via `array_ops.unstack`.
      name: Name forwarded to `array_ops.unstack`.

    Returns:
      Whatever `self.parent()` returns.

    Raises:
      ValueError: if the array is not dynamically sized and `value` yields
        more slices than the array currently holds.
    """
    pieces = array_ops.unstack(value, name=name)
    capacity = len(self._tensor_array)
    if not self._dynamic_size and len(pieces) > capacity:
      raise ValueError(
          "Cannot unstack %d tensors into a TensorArray of static size %d" %
          (len(pieces), capacity))
    self._tensor_array = pieces
    return self.parent()

  def scatter(self, indices, value, name=None):
    """Writes the unstacked slices of `value` at the given `indices`.

    Args:
      indices: Iterable of target positions; an `EagerTensor` is first
        converted with `.numpy()`.
      value: Tensor whose slices (from `array_ops.unstack`) are paired with
        `indices` in order.
      name: Ignored.

    Returns:
      Whatever `self.parent()` returns.
    """
    # Operation names carry no meaning when executing eagerly.
    del name
    if isinstance(indices, ops.EagerTensor):
      indices = indices.numpy()
    slices = array_ops.unstack(value)
    for position, element in zip(indices, slices):
      self._write(position, element)  # pylint: disable=protected-access
    return self.parent()

  def split(self, value, lengths, name=None):
    """Splits `value` along its first dimension and stores the pieces.

    Args:
      value: Tensor (or convertible value) to split.
      lengths: Vector of piece sizes along the first dimension of `value`.
      name: Name forwarded to `array_ops.split`.

    Returns:
      Whatever `self.parent()` returns.

    Raises:
      errors_impl.InvalidArgumentError: if `lengths` is not a vector, `value`
        is a scalar, the lengths do not sum to `value.shape[0]`, or the array
        is not dynamically sized and its size differs from the number of
        lengths.
    """
    # TODO(b/129870929): Fix after all callers provide proper init dtype.
    value = ops.convert_to_tensor(
        value, preferred_dtype=self._dtype, name="value")
    _check_dtypes(value, self._dtype)
    lengths = ops.convert_to_tensor(lengths)
    sum_lengths = math_ops.reduce_sum(lengths)

    if lengths.shape.ndims != 1:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected lengths to be a vector, received shape: %s" %
          lengths.shape.as_list())

    if value.shape.ndims == 0:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected value to be at least a vector, "
          "but received shape: %s" % value.shape.as_list())

    total = sum_lengths.numpy()
    if total != value.shape.as_list()[0]:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected sum of lengths to be equal to "
          "values.shape[0], but sum of lengths is %d and "
          "value's shape is: %s " % (total, value.shape.as_list()))

    if not self._dynamic_size and lengths.shape[0] != len(self._tensor_array):
      raise errors_impl.InvalidArgumentError(
          None, None, "TensorArray's size is not equal to the size of "
          "lengths (%d vs. %d), and the TensorArray is not marked as "
          "dynamically resizeable" % (len(self._tensor_array),
                                      lengths.shape[0]))

    self._tensor_array = array_ops.split(value, lengths, name=name)
    return self.parent()

  def size(self, name=None):
    """Returns the current number of slots as a constant tensor.

    Args:
      name: Ignored.
    """
    # Operation names carry no meaning when executing eagerly.
    del name
    slot_count = len(self._tensor_array)
    return constant_op.constant(slot_count)

  def close(self, name=None):
    """Empties the backing list in place, dropping every stored element.

    Args:
      name: Ignored.
    """
    # Slice deletion keeps the same list object while removing its contents.
    del self._tensor_array[:]
    # Operation names carry no meaning when executing eagerly.
    del name


# TensorArray is designed to hide an underlying implementation object
# and as such accesses many of that object's hidden fields.
# pylint: disable=protected-access
# pylint:disable=line-too-long
@tf_export("TensorArray")
class TensorArray(object):
  """Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays.

  This class is meant to be used with dynamic iteration primitives such as
  `while_loop` and `map_fn`.  It supports gradient back-propagation via special
  "flow" control flow dependencies.

  Example 1: Plain reading and writing.

  >>> ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True, clear_after_read=False)
  >>> ta = ta.write(0, 10)
  >>> ta = ta.write(1, 20)
  >>> ta = ta.write(2, 30)
  >>>
  >>> ta.read(0)
  <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
  >>> ta.read(1)
  <tf.Tensor: shape=(), dtype=float32, numpy=20.0>
  >>> ta.read(2)
  <tf.Tensor: shape=(), dtype=float32, numpy=30.0>
  >>> ta.stack()
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([10., 20., 30.],
  dtype=float32)>

  Example 2: Fibonacci sequence algorithm that writes in a loop then returns.

  >>> @tf.function
  ... def fibonacci(n):
  ...   ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  ...   ta = ta.unstack([0., 1.])
  ...
  ...   for i in range(2, n):
  ...     ta = ta.write(i, ta.read(i - 1) + ta.read(i - 2))
  ...
  ...   return ta.stack()
  >>>
  >>> fibonacci(7)
  <tf.Tensor: shape=(7,), dtype=float32,
  numpy=array([0., 1., 1., 2., 3., 5., 8.], dtype=float32)>

  Example 3: A simple loop interacting with a `tf.Variable`.

  >>> v = tf.Variable(1)
  >>>
  >>> @tf.function
  ... def f(x):
  ...   ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
  ...
  ...   for i in tf.range(x):
  ...     v.assign_add(i)
  ...     ta = ta.write(i, v)
  ...
  ...   return ta.stack()
  >>>
  >>> f(5)
  <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 1,  2,  4,  7, 11],
  dtype=int32)>
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Construct a new TensorArray or wrap an existing TensorArray handle.

    A note about the parameter `name`:

    The name of the `TensorArray` (even if passed in) is uniquified: each time
    a new `TensorArray` is created at runtime it is assigned its own name for
    the duration of the run.  This avoids name collisions if a `TensorArray`
    is created within a `while_loop`.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: (optional) Python string: the name of the TensorArray.
        This is used when creating the TensorArray handle.  If this value is
        set, handle should be None.
      handle: (optional) A `Tensor` handle to an existing TensorArray.  If this
        is set, tensor_array_name should be None. Only supported in graph mode.
      flow: (optional) A float `Tensor` scalar coming from an existing
        `TensorArray.flow`. Only supported in graph mode.
      infer_shape: (optional, default: True) If True, shape inference
        is enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray.
        Need not be fully defined.
      colocate_with_first_write_call: If `True`, the TensorArray will be
        colocated on the same device as the Tensor used on its first write
        (write operations include `write`, `unstack`, and `split`).  If `False`,
        the TensorArray will be placed on the device determined by the
        device context available during its initialization.
      name: A name for the operation (optional).

    Raises:
      ValueError: if both handle and tensor_array_name are provided.
      TypeError: if handle is provided but is not a Tensor.
    """
    if (context.executing_eagerly() and
        (flow is None or flow.dtype != dtypes.variant)):
      # It is possible to create a Variant-style TensorArray even in eager mode,
      # and this is fine but can have performance implications in eager.
      # An example of when this happens is if a tf.function returns a
      # TensorArray in its output; its flow variant object is returned to Eager.
      # This can be wrapped back up in a Variant-style TensorArray.
      implementation = _EagerTensorArray
    elif (flow is not None and flow.dtype == dtypes.variant or
          control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
      implementation = _GraphTensorArrayV2
    else:
      implementation = _GraphTensorArray
    self._implementation = implementation(
        dtype,
        size=size,
        dynamic_size=dynamic_size,
        clear_after_read=clear_after_read,
        tensor_array_name=tensor_array_name,
        handle=handle,
        flow=flow,
        infer_shape=infer_shape,
        element_shape=element_shape,
        colocate_with_first_write_call=colocate_with_first_write_call,
        name=name)

    # The implementation reaches back to this wrapper through a weak
    # reference, which it calls (`parent()`) to obtain the wrapper.
    self._implementation.parent = weakref.ref(self)

  @property
  def flow(self):
    """The flow `Tensor` forcing ops leading to this TensorArray state."""
    return self._implementation._flow

  @property
  def dtype(self):
    """The data type of this TensorArray."""
    return self._implementation._dtype

  @property
  def handle(self):
    """The reference to the TensorArray."""
    return self._implementation.handle

  @property
  def element_shape(self):
    """The `tf.TensorShape` of elements in this TensorArray."""
    return self._implementation.element_shape

  @property
  def dynamic_size(self):
    """Python bool; if `True` the TensorArray can grow dynamically."""
    return self._implementation._dynamic_size

  @property
  def _infer_shape(self):
    # TODO(slebedev): consider making public or changing TensorArrayStructure
    # to access _implementation directly. Note that dynamic_size is also
    # only used by TensorArrayStructure.
    return self._implementation._infer_shape

  def identity(self):
    """Returns a TensorArray with the same content and properties.

    Returns:
      A new TensorArray object with flow that ensures the control dependencies
      from the contexts will become control dependencies for writes, reads, etc.
      Use this object for all subsequent operations.
    """
    return self._implementation.identity()

  def grad(self, source, flow=None, name=None):
    """Delegates to the `grad` method of the underlying implementation.

    Args:
      source: Forwarded unchanged to the implementation.
      flow: (optional) Forwarded unchanged to the implementation.
      name: A name for the operation (optional).

    Returns:
      Whatever the implementation's `grad` returns.
    """
    return self._implementation.grad(source, flow=flow, name=name)

  def read(self, index, name=None):
    """Read the value at location `index` in the TensorArray.

    Args:
      index: 0-D.  int32 tensor with the index to read from.
      name: A name for the operation (optional).

    Returns:
      The tensor at index `index`.
    """
    return self._implementation.read(index, name=name)

  @tf_should_use.should_use_result(warn_in_eager=True)
  def write(self, index, value, name=None):
    """Write `value` into index `index` of the TensorArray.

    Args:
      index: 0-D.  int32 scalar with the index to write to.
      value: N-D.  Tensor of type `dtype`.  The Tensor to write to this index.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the write occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if there are more writers than specified.
    """
    return self._implementation.write(index, value, name=name)

  def stack(self, name=None):
    """Return the values in the TensorArray as a stacked `Tensor`.

    All of the values must have been written and their shapes must all match.
    If input shapes have rank-`R`, then output shape will have rank-`(R+1)`.

    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray stacked into one tensor.
    """
    return self._implementation.stack(name=name)

  def gather(self, indices, name=None):
    """Return selected values in the TensorArray as a packed `Tensor`.

    All of selected values must have been written and their shapes
    must all match.

    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If
        the `TensorArray` is not dynamic, `max_value=size()`.
      name: A name for the operation (optional).

    Returns:
      The tensors in the `TensorArray` selected by `indices`, packed into one
      tensor.
    """
    return self._implementation.gather(indices, name=name)

  def concat(self, name=None):
    """Return the values in the TensorArray as a concatenated `Tensor`.

    All of the values must have been written, their ranks must match, and
    their shapes must all match for all dimensions except the first.

    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray concatenated into one tensor.
    """
    return self._implementation.concat(name=name)

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """Unstack the values of a `Tensor` in the TensorArray.

    If input value shapes have rank-`R`, then the output TensorArray will
    contain elements whose shapes are rank-`(R-1)`.

    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unstack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the unstack occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.unstack(value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """Scatter the values of a `Tensor` in specific indices of a `TensorArray`.

    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If
        the `TensorArray` is not dynamic, `max_value=size()`.
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unpack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the scatter occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.scatter(indices, value, name=name)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """Split the values of a `Tensor` into the TensorArray.

    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to split.
      lengths: 1-D.  int32 vector with the lengths to use when splitting
        `value` along its first dimension.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the split occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.split(value, lengths, name=name)

  def size(self, name=None):
    """Return the size of the TensorArray."""
    return self._implementation.size(name=name)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """Close the current TensorArray."""
    return self._implementation.close(name=name)
def build_ta_with_new_flow(old_ta, flow):
  """Builds a TensorArray with a new `flow` tensor.

  Args:
    old_ta: Either a `TensorArray` wrapper or one of its implementation
      objects.
    flow: The flow tensor the new `TensorArray` is constructed with.

  Returns:
    A new `TensorArray` that reuses the handle, dtype, size, dynamic-size
    flag, colocation state and element shape of `old_ta`.

  Raises:
    NotImplementedError: when not executing eagerly, control flow v2 is
      enabled for the default graph, and `old_ta` is not backed by a
      `_GraphTensorArrayV2`.
  """
  # Sometimes we get old_ta as the implementation, sometimes it's the
  # TensorArray wrapper object.
  impl = (old_ta._implementation if isinstance(old_ta, TensorArray)
          else old_ta)

  if not context.executing_eagerly():
    if (not isinstance(impl, _GraphTensorArrayV2) and
        control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
      raise NotImplementedError("Attempting to build a graph-mode TF2-style "
                                "TensorArray from either an eager-mode "
                                "TensorArray or a TF1-style TensorArray. "
                                "This is not currently supported. You may be "
                                "attempting to capture a TensorArray "
                                "inside a tf.function or tf.data map function. "
                                "Instead, construct a new TensorArray inside "
                                "the function.")
  new_ta = TensorArray(
      dtype=impl.dtype,
      handle=impl.handle,
      flow=flow,
      infer_shape=impl._infer_shape,
      colocate_with_first_write_call=impl._colocate_with_first_write_call)
  # Copy the remaining state directly onto the new implementation object.
  new_impl = new_ta._implementation
  new_impl._dynamic_size = impl._dynamic_size
  new_impl._size = impl._size
  new_impl._colocate_with = impl._colocate_with
  new_impl._element_shape = impl._element_shape  # Share _element_shape.
  return new_ta

# pylint: enable=protected-access


def _check_dtypes(value, dtype):
  """Logs an error, with a stack trace, if `value.dtype` differs from `dtype`.

  A mismatch is only logged; no exception is raised.

  Args:
    value: Object exposing a `dtype` attribute.
    dtype: The expected dtype.
  """
  if value.dtype != dtype:
    logging.error(
        "Error: Input value {} has dtype {}, but expected dtype {}. "
        "This leads to undefined behavior and will be an error "
        "in future versions of TensorFlow. Traceback:\n{}".format(
            value, str(value.dtype), str(dtype),
            "".join(traceback.format_stack())))
@tf_export("TensorArraySpec")
class TensorArraySpec(type_spec.TypeSpec):
  """Type specification for a `tf.TensorArray`."""

  __slots__ = ["_element_shape", "_dtype", "_dynamic_size", "_infer_shape"]

  value_type = property(lambda self: TensorArray)

  def __init__(self,
               element_shape=None,
               dtype=dtypes.float32,
               dynamic_size=False,
               infer_shape=True):
    """Constructs a type specification for a `tf.TensorArray`.

    Args:
      element_shape: The shape of each element in the `TensorArray`.
      dtype: Data type of the `TensorArray`.
      dynamic_size: Whether the `TensorArray` can grow past its initial size.
      infer_shape: Whether shape inference is enabled.
    """
    self._element_shape = tensor_shape.as_shape(element_shape)
    self._dtype = dtypes.as_dtype(dtype)
    self._dynamic_size = dynamic_size
    self._infer_shape = infer_shape

  def is_compatible_with(self, other):
    """Returns whether `other` is compatible with this spec.

    Args:
      other: A `TypeSpec`, or a value that is first converted to one with
        `type_spec.type_spec_from_value`.

    Returns:
      True when `other` is a `TensorArraySpec` with a compatible dtype, a
      compatible element shape and an equal `dynamic_size`.
    """
    # pylint: disable=protected-access
    if not isinstance(other, type_spec.TypeSpec):
      other = type_spec.type_spec_from_value(other)

    # Note: we intentionally exclude infer_shape in this check.
    return (isinstance(other, TensorArraySpec) and
            self._dtype.is_compatible_with(other._dtype) and
            self._element_shape.is_compatible_with(other._element_shape) and
            self._dynamic_size == other._dynamic_size)

  def most_specific_compatible_type(self, other):
    """Returns the most specific spec compatible with `self` and `other`.

    Args:
      other: The spec to combine with this one.

    Returns:
      A `TensorArraySpec` whose element shape is the most specific shape
      compatible with both inputs, and whose `infer_shape` is set only when
      both inputs set it.

    Raises:
      ValueError: if the two specs are not compatible.
    """
    # pylint: disable=protected-access
    if not self.is_compatible_with(other):
      raise ValueError("Types are not compatible")
    infer_shape = self._infer_shape and other._infer_shape
    return TensorArraySpec(
        self._element_shape.most_specific_compatible_shape(
            other._element_shape),
        self._dtype, self._dynamic_size, infer_shape)

  def _serialize(self):
    """Returns the constructor arguments as a tuple, in positional order."""
    return (self._element_shape, self._dtype, self._dynamic_size,
            self._infer_shape)

  @property
  def _component_specs(self):
    """A single scalar variant `TensorSpec`."""
    return [tensor_spec.TensorSpec([], dtypes.variant)]

  def _to_components(self, value):
    """Encodes a `TensorArray` as a one-element list holding a variant flow.

    Args:
      value: The `TensorArray` to encode.

    Returns:
      `[value.flow]` when the flow is already a variant tensor; otherwise a
      one-element list holding a tensor list built from `value.stack()`.

    Raises:
      TypeError: if `value` is not a `TensorArray`.
    """
    if not isinstance(value, TensorArray):
      raise TypeError("value must be a TensorArray, but saw: {}"
                      .format(type(value)))
    if value.flow is not None and value.flow.dtype == dtypes.variant:
      return [value.flow]
    else:
      # Convert to a TF2-style TensorArray.
      # TODO(ebrevdo): Add an "_as_variant" method to TensorArray class, or
      # "implementation / as_variant" arg to TensorArray constructor.
      with ops.name_scope("convert_tensor_array"):
        flow = list_ops.tensor_list_from_tensor(
            tensor=value.stack(), element_shape=value.element_shape)
      return [flow]

  def _from_components(self, tensor_list):
    """Rebuilds a `TensorArray` from the flow held in `tensor_list[0]`."""
    # This will return a TF2 Graph-style TensorArray because tensor_list[0] is
    # a variant object.  size == -1 implies unknown size.
    ret = TensorArray(
        dtype=self._dtype,
        flow=tensor_list[0],
        dynamic_size=self._dynamic_size,
        infer_shape=self._infer_shape)
    ret._implementation._element_shape = [self._element_shape]  # pylint: disable=protected-access
    return ret

  @staticmethod
  def from_value(value):
    """Builds the `TensorArraySpec` describing `value`.

    Args:
      value: A `TensorArray`.

    Returns:
      A `TensorArraySpec` carrying the dtype, element shape, dynamic-size flag
      and infer-shape flag of `value`.

    Raises:
      TypeError: if `value` is not a `TensorArray`.
    """
    if not isinstance(value, TensorArray):
      raise TypeError("Expected value to be a TensorArray, but saw: {}".
                      format(type(value)))

    return TensorArraySpec(
        dtype=value.dtype,
        element_shape=value.element_shape,
        dynamic_size=value.dynamic_size,
        infer_shape=value._infer_shape)  # pylint: disable=protected-access

  def _to_legacy_output_types(self):
    """Returns the dtype of this spec."""
    return self._dtype

  def _to_legacy_output_shapes(self):
    """Returns the element shape prefixed by the two boolean flags."""
    # Sneak the dynamic_size and infer_shape values into the legacy shape.
    return (tensor_shape.TensorShape([self._dynamic_size, self._infer_shape
                                     ]).concatenate(self._element_shape))

  def _to_legacy_output_classes(self):
    """Returns the `TensorArray` class."""
    return TensorArray
# Register the TypeSpec for TensorArray.  If TensorArray is updated to be a
# CompositeTensor, then this registration can be deleted.
type_spec.register_type_spec_from_value_converter(
    TensorArray, TensorArraySpec.from_value, allow_subclass=True)