
where(condition, x=None, y=None)

Returns elements chosen from x or y depending on the boolean tensor condition or, if x and y are both omitted, the indices of the entries where condition is True.

The input tensors condition, x, and y must all be broadcastable to the same shape.

Parameters

  • condition (tensor_like[bool]) – A boolean tensor. Where it is True, elements are taken from x; elsewhere, they are taken from y. If x and y are both None, the indices where condition is True are returned instead.

  • x (tensor_like) – Values to choose from where condition is True.

  • y (tensor_like) – Values to choose from where condition is False.


Returns

If x and y are both None, a tensor or tuple of tensors containing the indices where condition is True. Otherwise, a tensor with elements from x where condition is True and from y elsewhere; in this case, the output has the shape that condition, x, and y broadcast to.
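As a minimal sketch of the broadcasting rule for the three-argument form, the example below calls numpy.where directly rather than the PennyLane function, assuming that NumPy inputs follow NumPy's own where semantics. A scalar x and a 1-D y are broadcast against a 2-D condition, so the result takes the condition's 2-D shape:

```python
import numpy as np

# condition has shape (2, 3); x is a scalar and y has shape (3,).
# All three broadcast to the common shape (2, 3).
condition = np.array([[True, False, True],
                      [False, True, False]])
x = 0.0
y = np.array([10.0, 20.0, 30.0])

# Take x where condition is True, otherwise the broadcast entry of y.
out = np.where(condition, x, y)
print(out.shape)  # (2, 3)
print(out)
```

Each row of y is reused for both rows of condition, which is ordinary NumPy broadcasting.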

Return type

tensor_like or tuple[tensor_like]

Example with three arguments

>>> import torch
>>> from pennylane import math
>>> a = torch.tensor([0.6, 0.23, 0.7, 1.5, 1.7], requires_grad=True)
>>> b = torch.tensor([-1., -2., -3., -4., -5.], requires_grad=True)
>>> math.where(a < 1, a, b)
tensor([ 0.6000,  0.2300,  0.7000, -4.0000, -5.0000], grad_fn=<SWhereBackward>)
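Elementwise, the three-argument form is a selection: each output entry depends only on the corresponding entry of x or of y, which is why the result above carries a grad_fn. Writing [c_i] for 1 if c_i is True and 0 otherwise, the selection and its partial derivatives are:

```latex
\operatorname{where}(c, x, y)_i =
\begin{cases}
  x_i & \text{if } c_i \text{ is True} \\
  y_i & \text{if } c_i \text{ is False}
\end{cases},
\qquad
\frac{\partial \operatorname{where}(c, x, y)_i}{\partial x_i} = [c_i],
\qquad
\frac{\partial \operatorname{where}(c, x, y)_i}{\partial y_i} = 1 - [c_i]
```

For the example above, where a < 1 evaluates to [True, True, True, False, False], the gradient of the summed output is therefore [1, 1, 1, 0, 0] with respect to a and [0, 0, 0, 1, 1] with respect to b.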


The output format for x=None and y=None follows the respective interface and differs between TensorFlow and all other interfaces. For TensorFlow, the output is a single tensor of shape (len(condition.shape), num_true), where num_true is the number of entries in condition that are True. For all other interfaces, the output is a tuple of len(condition.shape) tensor-like objects, where the jth object holds the jth coordinate of every such index. See also the examples below.

Example with single argument

For Torch, Autograd, JAX and NumPy, the output formatting is as follows:

>>> a = [[0.6, 0.23, 1.7],[1.5, 0.7, -0.2]]
>>> math.where(torch.tensor(a) < 1)
(tensor([0, 0, 1, 1]), tensor([0, 1, 1, 2]))

This is a tuple rather than a single tensor-like object, but it corresponds to the shape (2, 4). For TensorFlow, on the other hand:

>>> import tensorflow as tf
>>> math.where(tf.constant(a) < 1)
<tf.Tensor: shape=(2, 4), dtype=int64, numpy=
array([[0, 0, 1, 1],
       [0, 1, 1, 2]])>

Note that the TensorFlow output is always two-dimensional, regardless of the number of dimensions of condition.
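The two layouts hold the same information. The sketch below uses NumPy directly (numpy.where, numpy.stack and numpy.argwhere, not the PennyLane wrapper) to show how the tuple layout relates to a (ndim, num_true) array like the TensorFlow result described above; it assumes, as that description states, that the TensorFlow output equals the stacked tuple.

```python
import numpy as np

a = np.array([[0.6, 0.23, 1.7], [1.5, 0.7, -0.2]])

# Tuple layout: one index array per dimension of the condition.
rows, cols = np.where(a < 1)

# Stacking the tuple gives a (ndim, num_true) array, here (2, 4).
stacked = np.stack((rows, cols))
print(stacked.shape)  # (2, 4)

# Transposing gives one (row, col) pair per True entry, matching np.argwhere.
same = np.array_equal(stacked.T, np.argwhere(a < 1))
print(same)  # True

# Both layouts can index the original array; the stacked form needs tuple().
selected = a[tuple(stacked)]
print(np.array_equal(selected, a[rows, cols]))  # True
```

Converting back from the stacked form is just tuple(stacked), which yields one index array per dimension again.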

