NumPy Where: Understanding np.where()

Ben Cook • Posted 2021-03-03 • Last updated 2021-10-19

The NumPy where() function is like a vectorized switch that you can use to combine two arrays. For example, let’s say you have a pandas Series of data called df.revenue and you want to create a new array with 1 wherever an element of df.revenue is more than one standard deviation from the mean and -1 for all other elements.

This is a perfect use case for np.where(). First, create a boolean array from your condition, and then pass it to np.where() along with the two values to choose between:

import numpy as np
import pandas as pd

df = pd.read_csv("https://jbencook.s3.amazonaws.com/data/dummy-sales.csv")
condition = np.abs(df.revenue - df.revenue.mean()) > df.revenue.std()

np.where(condition, 1, -1)

# Expected result
# array([ 1, -1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1, -1,
#        -1, -1,  1, -1, -1, -1, -1,  1,  1,  1,  1,  1,  1])
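
If you’d like to try the same pattern without downloading the CSV, here’s a minimal sketch using a small NumPy array of made-up revenue figures (the numbers are invented purely for illustration). It also shows the plain-Python loop that np.where() replaces. One detail to keep in mind: NumPy’s .std() defaults to ddof=0 (population standard deviation) while pandas’ .std() defaults to ddof=1, so the two can disagree for values right at the threshold.

```python
import numpy as np

# Made-up revenue figures, purely for illustration
revenue = np.array([100.0, 102.0, 98.0, 150.0, 101.0, 40.0])

# True where an element is more than one standard deviation from the mean
# (NumPy's .std() uses ddof=0 by default, unlike pandas)
condition = np.abs(revenue - revenue.mean()) > revenue.std()

# Vectorized switch: 1 where condition is True, -1 elsewhere
flags = np.where(condition, 1, -1)

# The equivalent (slower) element-by-element Python version
flags_loop = np.array([1 if c else -1 for c in condition])

print(flags)  # [-1 -1 -1  1 -1  1]
```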

The arguments to np.where() are:

  • condition: an array-like whose elements evaluate to True or False
  • x: optional values (an array-like or a scalar) to return where condition is True
  • y: optional values (an array-like or a scalar) to return where condition is False
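
To make the element-wise selection concrete, here’s a small sketch with hand-picked arrays: at each position, np.where() takes the element of x if condition is True there and the element of y otherwise.

```python
import numpy as np

condition = np.array([True, False, True, False])
x = np.array([10, 20, 30, 40])   # used where condition is True
y = np.array([-1, -2, -3, -4])   # used where condition is False

result = np.where(condition, x, y)
print(result)  # [10 -2 30 -4]
```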

The elements of condition don’t actually need to have a boolean type as long as they can be coerced to a boolean (e.g. non-zero integers are interpreted as True). Also, x and y are optional, but if you provide one, you need to provide both. Finally, the inputs can have any shape as long as condition, x, and y can be broadcast to a common shape, so you can use this as a multi-dimensional switch.
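
The sketch below illustrates those three points with small, hand-picked arrays: an integer condition coerced to booleans, a 2-D condition with x and y broadcast against it, and the ValueError NumPy raises when only one of x and y is supplied.

```python
import numpy as np

# 1. Non-boolean condition: non-zero integers count as True
int_condition = np.array([0, 3, -1, 0])
coerced = np.where(int_condition, 1, 0)  # [0 1 1 0]

# 2. Multi-dimensional switch with broadcasting:
#    condition has shape (2, 3), x has shape (3,), y has shape (2, 1)
condition = np.array([[True, False, True],
                      [False, True, False]])
x = np.array([1, 2, 3])        # broadcast down the rows
y = np.array([[10], [20]])     # broadcast across the columns
switched = np.where(condition, x, y)
# [[ 1 10  3]
#  [20  2 20]]

# 3. Supplying x without y is an error
try:
    np.where(condition, x)
    raised = False
except ValueError:
    raised = True
```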

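One thing to watch out for: the return value takes a different form if you don’t supply x and y. In that case, np.where() returns a tuple of index arrays, one per axis, giving the positions of the true elements. For a 1-D input this is a 1-tuple holding the indices of the true elements; for higher-dimensional inputs each array holds the indices along one axis. This is equivalent to np.argwhere(), except that np.argwhere() returns a single array with one row per true element, while np.where() splits the indices into separate arrays by axis.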
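
Here’s a short sketch of that return value using small made-up arrays. Because the result is a tuple with one index array per axis, it can be passed straight back into square brackets to pull out the matching elements.

```python
import numpy as np

a = np.array([5, 0, 7, 0, 9])
idx = np.where(a > 4)
# A 1-tuple even for 1-D input: (array([0, 2, 4]),)

m = np.array([[1, 0],
              [0, 2],
              [3, 0]])
rows, cols = np.where(m)  # non-zero entries count as True
# rows -> [0 1 2], cols -> [0 1 0]

# The tuple works directly as an index, matching boolean-mask indexing
picked = m[np.where(m)]   # [1 2 3]
same_as_mask = np.array_equal(picked, m[m != 0])
```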

You can see how this works by stacking the per-axis index arrays from np.where() into columns with np.stack() and comparing the result to np.argwhere():

x = np.eye(4)
np.stack(np.where(x), -1) == np.argwhere(x)

# Expected result
# array([[ True,  True],
#        [ True,  True],
#        [ True,  True],
#        [ True,  True]])
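
Because np.eye(4) is symmetric, its row and column indices happen to be identical, so it’s a fairly weak test: swapping the axes would still pass. Here’s a sketch with a non-symmetric, made-up boolean matrix, using np.array_equal() to get a single True/False instead of an element-wise comparison:

```python
import numpy as np

m = np.array([[False, True, False],
              [True, False, True]])

rows, cols = np.where(m)   # rows -> [0 1 1], cols -> [1 0 2]
coords = np.argwhere(m)    # one (row, col) pair per True element

stacked = np.stack((rows, cols), -1)
print(np.array_equal(stacked, coords))                    # True
print(np.array_equal(np.transpose(np.where(m)), coords))  # True
```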

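In fact, np.where() without the x and y inputs is equivalent to calling the .nonzero() method on the condition array (NumPy’s documentation recommends using nonzero directly for this case), so the same comparison holds: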

np.stack(x.nonzero(), -1) == np.argwhere(x)

# Expected result
# array([[ True,  True],
#        [ True,  True],
#        [ True,  True],
#        [ True,  True]])
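
The same equivalence holds for higher-dimensional and non-boolean conditions. A quick sketch with a random 3-D integer array (the shape and seed are arbitrary choices) compares np.where(c) with c.nonzero() axis by axis:

```python
import numpy as np

rng = np.random.default_rng(0)
c = rng.integers(0, 3, size=(2, 3, 4))  # 0s are False, 1s and 2s are True

w = np.where(c)    # tuple of 3 index arrays, one per axis
nz = c.nonzero()   # the same indices via the method

same = len(w) == len(nz) == 3 and all(np.array_equal(a, b) for a, b in zip(w, nz))
print(same)  # True
```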

Multi-dimensional binary cross-entropy

Now that we know how the API works, let’s look at another example: multi-dimensional binary cross-entropy. Say we have a 3-D array yhat of predicted probabilities (the probability that each element belongs to the positive class) and a 3-D array y of binary labels with the same shape. The one-liner formula for binary cross-entropy is the following:

-(y * np.log(yhat) + (1 - y) * np.log(1 - yhat)).mean()

This does work in the multi-dimensional case because NumPy arithmetic is element-wise. Multiplying the log terms by y and 1 - y makes them act like switches: when y == 1 only the first term is included, and when y == 0 only the second term is included:

np.random.seed(1)

yhat = np.random.uniform(size=(3, 3, 3))
y = np.random.randint(0, 2, size=(3, 3, 3))

-(y * np.log(yhat) + (1 - y) * np.log(1 - yhat)).mean()

# Expected result
# 1.221865004504288
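
To see the switching behavior term by term, here’s a tiny hand-checkable sketch with two made-up elements: one with label 1 and predicted probability 0.8, and one with label 0 and predicted probability 0.2. In both cases the surviving term is log(0.8), so the loss is -log(0.8), roughly 0.223.

```python
import numpy as np

y = np.array([1, 0])
yhat = np.array([0.8, 0.2])

first = y * np.log(yhat)             # [log(0.8), 0]: kept only where y == 1
second = (1 - y) * np.log(1 - yhat)  # [0, log(0.8)]: kept only where y == 0
per_element = first + second         # [log(0.8), log(0.8)]

loss = -per_element.mean()
print(loss)  # about 0.223, i.e. -log(0.8)
```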

But we can accomplish the same thing with np.where():

-np.where(y, np.log(yhat), np.log(1 - yhat)).mean()

# Expected result
# 1.221865004504288

Pretty cool! This is not necessarily a better implementation in any important way, but it does make the purpose of the y and 1 - y terms very clear.
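
One edge case where the two versions do behave differently: np.where() is not lazy, so both np.log(yhat) and np.log(1 - yhat) are computed for every element, but only the selected value ends up in the result. In the multiplication version, a predicted probability of exactly 0.0 or 1.0 produces 0 * -inf, which is nan, even when the prediction is correct. The sketch below uses made-up values to show this, with np.errstate() silencing the divide-by-zero and invalid-value warnings.

```python
import numpy as np

# Made-up edge case: label 0 with a predicted probability of exactly 0.0
y = np.array([0, 1])
yhat = np.array([0.0, 0.5])

with np.errstate(divide="ignore", invalid="ignore"):
    # Multiplication version: 1 * log(0.0) = 1 * -inf, then 0 * -inf = nan
    mult_loss = -(y * np.log(yhat) + (1 - y) * np.log(1 - yhat)).mean()
    # np.where version: log(0.0) is still computed but never selected
    where_loss = -np.where(y, np.log(yhat), np.log(1 - yhat)).mean()

print(mult_loss)   # nan
print(where_loss)  # -log(0.5) / 2, about 0.347
```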