如何在Python中使用NumPy argmax()函数 要在Python中使用NumPy的argmax()函数,可以按照以下步骤进行操作: 1. 首先,确保已经安装了NumPy库。可以使用pip install numpy命令来安装它。 2. 导入NumPy库,可以使用import numpy语句。 3. 创建一个NumPy数组,可以使用numpy.array()函数。例如,可以使用以下代码创建一个包含一些数字的数组: import numpy as np arr = np.array([5, 2, 8, 3, 9]) 4. 使用argmax()函数来查找数组中的最大值的索引。可以使用以下代码: max_index = np.argmax(arr) 5. 最后,可以打印出最大值的索引。例如,可以使用以下代码: print(“最大值的索引为:”, max_index) 使用argmax()函数可以轻松地找到NumPy数组中的最大值的索引。希望这个简单的教程能够帮助你开始使用NumPy的argmax()函数。

在本教程中,您将学习如何使用numpy argmax()函数来找到数组中最大元素的索引。

numpy是python中用于科学计算的强大库;它提供了比python列表更高效的n维数组。当使用numpy数组时,您经常会执行的一个常见操作是查找数组中的最大值。然而,有时您可能想要找到最大值出现的索引

argmax()函数帮助您在一维和多维数组中找到最大值的索引。现在让我们继续学习它是如何工作的。

如何在numpy数组中找到最大元素的索引

要按照本教程进行操作,您需要安装python和numpy。您可以通过启动python repl或启动jupyter笔记本来编写代码。

首先,让我们导入常用别名np的numpy。

import numpy as np

您可以使用numpy的max()函数来获取数组中的最大值(可选择沿特定轴)。

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))

# 输出
10

在这种情况下,np.max(array_1)返回10,这是正确的。

假设您想要找到最大值出现的索引。您可以采取以下两步骤:

  1. 找到最大元素。
  2. 找到最大元素的索引。

array_1中,最大值10出现在索引4处,遵循零索引。第一个元素位于索引0;第二个元素位于索引1,依此类推。

要找到最大值出现的索引,您可以使用numpy where()函数。np.where(condition)返回一个数组,其中包含conditiontrue的所有索引。

您需要进入数组并访问第一个索引处的项目。为了找到最大值的出现位置,我们将condition设置为array_1==10;请记住10是array_1中的最大值。

print(int(np.where(array_1==10)[0]))

# 输出
4

我们使用np.where()仅使用了条件,但这不是推荐使用此函数的方法。

📑 注意:numpy where()函数
np.where(condition,x,y)返回:
– 当条件为true时,返回来自x的元素,和
– 当条件为false时,返回来自y的元素。

因此,通过链接np.max()np.where()函数,我们可以找到最大元素,然后找到它出现的索引。

与上述的两步骤过程不同,您可以使用numpy的argmax()函数来获取数组中最大元素的索引。

numpy argmax()函数的语法

使用numpy argmax()函数的一般语法如下:

np.argmax(array,axis,out)
# 我们已经将numpy导入为别名np

在上述语法中:

  • array 是任何有效的numpy数组。
  • axis 是一个可选参数。在使用多维数组时,可以使用axis参数来找到沿特定轴的最大索引。
  • out 是另一个可选参数。您可以将out参数设置为numpy数组,以存储argmax()函数的输出。

注意:从numpy版本1.22.0开始,还有一个额外的keepdims参数。当我们在argmax()函数调用中指定axis参数时,数组会沿该轴进行缩减。但是将keepdims参数设置为true可以确保返回的输出与输入数组的形状相同。

使用numpy的argmax()函数找到最大元素的索引

#1. 让我们使用numpy的argmax()函数找到array_1中最大元素的索引。

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.argmax(array_1))

# 输出
4

argmax()函数返回的结果是4,这是正确的!✅

#2. 如果我们重新定义array_1,使得10出现两次,那么argmax()函数只返回第一次出现的索引。

array_1 = np.array([1,5,7,2,10,10,8,4])
print(np.argmax(array_1))

# 输出
4

对于接下来的示例,我们将使用在示例#1中定义的array_1的元素。

使用numpy的argmax()函数在二维数组中找到最大元素的索引

让我们将numpy数组array_1重新整形为一个包含两行四列的二维数组。

array_2 = array_1.reshape(2,4)
print(array_2)

# 输出
[[ 1  5  7  2]
 [10  9  8  4]]

对于二维数组,轴0表示行,轴1表示列。numpy数组遵循零索引。所以numpy数组array_2的行和列的索引如下所示:

现在,让我们在二维数组array_2上调用argmax()函数。

print(np.argmax(array_2))

# 输出
4

即使我们在二维数组上调用了argmax()函数,它仍然返回4。这与前一节中一维数组array_1的输出相同。

为什么会这样? 🤔

这是因为我们没有为axis参数指定任何值。当未设置此axis参数时,默认情况下,argmax()函数会返回沿扁平化数组的最大元素的索引。

什么是扁平化数组? 如果存在一个形状为d1 x d2 x … x dn的n维数组,其中d1,d2,直到dn是n个维度上的数组大小,则扁平化数组是一个大小为  d1 * d2 * … * dn 的长一维数组。

要检查array_2的扁平化数组的样子,可以调用flatten()方法,如下所示:

array_2.flatten()

# 输出
array([ 1,  5,  7,  2, 10,  9,  8,  4])

最大元素的行索引(轴 = 0)

让我们继续找到最大元素的行索引(轴 = 0)。

np.argmax(array_2,axis=0)

# 输出
array([1, 1, 1, 1])

这个输出可能有点难理解,但我们将会理解它的工作原理。

我们将axis参数设置为零(axis = 0),因为我们想找到每列中最大元素出现的行索引。因此,argmax()函数返回每个列中最大元素出现的行号。

让我们通过可视化来更好地理解。

根据上图和argmax()的输出,我们得到以下结果:

  • 对于索引为0的第一列,最大值10出现在第二行,索引为1。
  • 对于索引为1的第二列,最大值9出现在第二行,索引为1。
  • 对于索引为2和3的第三列和第四列,最大值84都出现在第二行,索引为1。

这就是为什么我们有输出array([1, 1, 1, 1]),因为最大元素沿着行出现在第二行(对于所有列)。

最大元素的列索引(轴 = 1)

接下来,让我们使用argmax()函数来找到最大元素的列索引。

运行以下代码片段并观察输出。

np.argmax(array_2,axis=1)
array([2, 0])

你能理解输出吗?

我们设置axis = 1来计算最大元素沿列的索引。

argmax()函数返回每行中最大值所在的列号。

这是一个可视化解释:

根据上图和argmax()的输出,我们得到以下结果:

  • 对于索引为0的第一行,最大值7出现在第三列,索引为2。
  • 对于索引为1的第二行,最大值10出现在第一列,索引为0。

希望你现在明白了输出array([2, 0])的含义。

在numpy argmax()中使用可选的out参数

你可以在numpy argmax()函数中使用可选的out参数将输出存储在一个numpy数组中。

让我们初始化一个全零数组来存储上面argmax()函数调用的输出结果 – 找到沿列的最大元素的索引(axis = 1)。

out_arr = np.zeros((2,))
print(out_arr)
[0. 0.]

现在,让我们重新访问找到最大元素的索引(axis = 1)的示例,并将out设置为我们上面定义的out_arr

np.argmax(array_2,axis=1,out=out_arr)

我们可以看到python解释器抛出了一个typeerror,因为out_arr默认初始化为一个浮点数数组。

---------------------------------------------------------------------------
typeerror                                 traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
     56     try:
---> 57         return bound(*args, **kwds)
     58     except typeerror:

typeerror: cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

因此,在将out参数设置为输出数组时,确保输出数组的形状和数据类型是正确的非常重要。由于数组索引始终是整数,因此在定义输出数组时,我们应将dtype参数设置为int

out_arr = np.zeros((2,),dtype=int)
print(out_arr)

# 输出
[0 0]

现在我们可以使用argmax()函数调用axisout参数,这次可以成功运行。

np.argmax(array_2,axis=1,out=out_arr)

argmax()函数的输出结果现在可以在数组out_arr中访问。

print(out_arr)
# 输出
[2 0]

结论

希望本教程能帮助您了解如何使用numpy的argmax()函数。您可以在jupyter notebook中运行示例代码。

让我们回顾一下我们学到的内容。

  • numpy的argmax()函数返回数组中最大元素的索引。如果最大元素在数组a中出现多次,则np.argmax(a)返回元素第一次出现的索引。
  • 在使用多维数组时,可以使用可选的axis参数获取沿特定轴的最大元素的索引。例如,在二维数组中:通过设置axis = 0axis = 1,分别可以获取行和列的最大元素的索引。
  • 如果您希望将返回的值存储在另一个数组中,可以将可选的out参数设置为输出数组。但是,输出数组的形状和数据类型应与原数组兼容。

接下来,请查看关于python集合的深入指南。还可以学习如何使用python sleep函数在代码中添加延迟。

类似文章