如何在Python中使用NumPy argmax()函数 要在Python中使用NumPy的argmax()函数,可以按照以下步骤进行操作: 1. 首先,确保已经安装了NumPy库。可以使用pip install numpy命令来安装它。 2. 导入NumPy库,可以使用import numpy语句。 3. 创建一个NumPy数组,可以使用numpy.array()函数。例如,可以使用以下代码创建一个包含一些数字的数组: import numpy as np arr = np.array([5, 2, 8, 3, 9]) 4. 使用argmax()函数来查找数组中的最大值的索引。可以使用以下代码: max_index = np.argmax(arr) 5. 最后,可以打印出最大值的索引。例如,可以使用以下代码: print(“最大值的索引为:”, max_index) 使用argmax()函数可以轻松地找到NumPy数组中的最大值的索引。希望这个简单的教程能够帮助你开始使用NumPy的argmax()函数。
在本教程中,您将学习如何使用numpy argmax()函数来找到数组中最大元素的索引。
numpy是python中用于科学计算的强大库;它提供了比python列表更高效的n维数组。当使用numpy数组时,您经常会执行的一个常见操作是查找数组中的最大值。然而,有时您可能想要找到最大值出现的索引。
argmax()
函数帮助您在一维和多维数组中找到最大值的索引。现在让我们继续学习它是如何工作的。
如何在numpy数组中找到最大元素的索引
要按照本教程进行操作,您需要安装python和numpy。您可以通过启动python repl或启动jupyter笔记本来编写代码。
首先,让我们导入常用别名np
的numpy。
import numpy as np
您可以使用numpy的max()
函数来获取数组中的最大值(可选择沿特定轴)。
array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))
# 输出
10
在这种情况下,np.max(array_1)
返回10,这是正确的。
假设您想要找到最大值出现的索引。您可以采取以下两步骤:
- 找到最大元素。
- 找到最大元素的索引。
在array_1
中,最大值10出现在索引4处,遵循零索引。第一个元素位于索引0;第二个元素位于索引1,依此类推。
要找到最大值出现的索引,您可以使用numpy where()函数。np.where(condition)
返回一个数组,其中包含condition
为true
的所有索引。
您需要进入数组并访问第一个索引处的项目。为了找到最大值的出现位置,我们将condition
设置为array_1==10
;请记住10是array_1
中的最大值。
print(int(np.where(array_1==10)[0]))
# 输出
4
我们使用np.where()
仅使用了条件,但这不是推荐使用此函数的方法。
📑 注意:numpy where()函数:
np.where(condition,x,y)
返回:
– 当条件为true
时,返回来自x
的元素,和
– 当条件为false
时,返回来自y
的元素。
因此,通过链接np.max()
和np.where()
函数,我们可以找到最大元素,然后找到它出现的索引。
与上述的两步骤过程不同,您可以使用numpy的argmax()函数来获取数组中最大元素的索引。
numpy argmax()函数的语法
使用numpy argmax()函数的一般语法如下:
np.argmax(array,axis,out)
# 我们已经将numpy导入为别名np
在上述语法中:
- array 是任何有效的numpy数组。
- axis 是一个可选参数。在使用多维数组时,可以使用axis参数来找到沿特定轴的最大索引。
- out 是另一个可选参数。您可以将
out
参数设置为numpy数组,以存储argmax()
函数的输出。
注意:从numpy版本1.22.0开始,还有一个额外的
keepdims
参数。当我们在argmax()
函数调用中指定axis
参数时,数组会沿该轴进行缩减。但是将keepdims
参数设置为true
可以确保返回的输出与输入数组的形状相同。
使用numpy的argmax()函数找到最大元素的索引
#1. 让我们使用numpy的argmax()函数找到array_1
中最大元素的索引。
array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.argmax(array_1))
# 输出
4
argmax()函数返回的结果是4,这是正确的!✅
#2. 如果我们重新定义array_1
,使得10出现两次,那么argmax()函数只返回第一次出现的索引。
array_1 = np.array([1,5,7,2,10,10,8,4])
print(np.argmax(array_1))
# 输出
4
对于接下来的示例,我们将使用在示例#1中定义的array_1
的元素。
使用numpy的argmax()函数在二维数组中找到最大元素的索引
让我们将numpy数组array_1
重新整形为一个包含两行四列的二维数组。
array_2 = array_1.reshape(2,4)
print(array_2)
# 输出
[[ 1 5 7 2]
[10 9 8 4]]
对于二维数组,轴0表示行,轴1表示列。numpy数组遵循零索引。所以numpy数组array_2
的行和列的索引如下所示:
现在,让我们在二维数组array_2
上调用argmax()函数。
print(np.argmax(array_2))
# 输出
4
即使我们在二维数组上调用了argmax()函数,它仍然返回4。这与前一节中一维数组array_1
的输出相同。
为什么会这样? 🤔
这是因为我们没有为axis参数指定任何值。当未设置此axis参数时,默认情况下,argmax()
函数会返回沿扁平化数组的最大元素的索引。
什么是扁平化数组? 如果存在一个形状为d1 x d2 x … x dn的n维数组,其中d1,d2,直到dn是n个维度上的数组大小,则扁平化数组是一个大小为 d1 * d2 * … * dn 的长一维数组。
要检查array_2
的扁平化数组的样子,可以调用flatten()
方法,如下所示:
array_2.flatten()
# 输出
array([ 1, 5, 7, 2, 10, 9, 8, 4])
最大元素的行索引(轴 = 0)
让我们继续找到最大元素的行索引(轴 = 0)。
np.argmax(array_2,axis=0)
# 输出
array([1, 1, 1, 1])
这个输出可能有点难理解,但我们将会理解它的工作原理。
我们将axis
参数设置为零(axis = 0
),因为我们想找到每列中最大元素出现的行索引。因此,argmax()
函数返回每个列中最大元素出现的行号。
让我们通过可视化来更好地理解。
根据上图和argmax()
的输出,我们得到以下结果:
- 对于索引为0的第一列,最大值10出现在第二行,索引为1。
- 对于索引为1的第二列,最大值9出现在第二行,索引为1。
- 对于索引为2和3的第三列和第四列,最大值8和4都出现在第二行,索引为1。
这就是为什么我们有输出array([1, 1, 1, 1])
,因为最大元素沿着行出现在第二行(对于所有列)。
最大元素的列索引(轴 = 1)
接下来,让我们使用argmax()
函数来找到最大元素的列索引。
运行以下代码片段并观察输出。
np.argmax(array_2,axis=1)
array([2, 0])
你能理解输出吗?
我们设置axis = 1
来计算最大元素沿列的索引。
argmax()
函数返回每行中最大值所在的列号。
这是一个可视化解释:
根据上图和argmax()
的输出,我们得到以下结果:
- 对于索引为0的第一行,最大值7出现在第三列,索引为2。
- 对于索引为1的第二行,最大值10出现在第一列,索引为0。
希望你现在明白了输出array([2, 0])
的含义。
在numpy argmax()中使用可选的out参数
你可以在numpy argmax()函数中使用可选的out
参数将输出存储在一个numpy数组中。
让我们初始化一个全零数组来存储上面argmax()
函数调用的输出结果 – 找到沿列的最大元素的索引(axis = 1
)。
out_arr = np.zeros((2,))
print(out_arr)
[0. 0.]
现在,让我们重新访问找到最大元素的索引(axis = 1
)的示例,并将out
设置为我们上面定义的out_arr
。
np.argmax(array_2,axis=1,out=out_arr)
我们可以看到python解释器抛出了一个typeerror
,因为out_arr
默认初始化为一个浮点数数组。
---------------------------------------------------------------------------
typeerror traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
56 try:
---> 57 return bound(*args, **kwds)
58 except typeerror:
typeerror: cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
因此,在将out
参数设置为输出数组时,确保输出数组的形状和数据类型是正确的非常重要。由于数组索引始终是整数,因此在定义输出数组时,我们应将dtype
参数设置为int
。
out_arr = np.zeros((2,),dtype=int)
print(out_arr)
# 输出
[0 0]
现在我们可以使用argmax()
函数调用axis
和out
参数,这次可以成功运行。
np.argmax(array_2,axis=1,out=out_arr)
argmax()
函数的输出结果现在可以在数组out_arr
中访问。
print(out_arr)
# 输出
[2 0]
结论
希望本教程能帮助您了解如何使用numpy的argmax()函数。您可以在jupyter notebook中运行示例代码。
让我们回顾一下我们学到的内容。
- numpy的argmax()函数返回数组中最大元素的索引。如果最大元素在数组a中出现多次,则
np.argmax(a)
返回元素第一次出现的索引。 - 在使用多维数组时,可以使用可选的axis参数获取沿特定轴的最大元素的索引。例如,在二维数组中:通过设置axis = 0和axis = 1,分别可以获取行和列的最大元素的索引。
- 如果您希望将返回的值存储在另一个数组中,可以将可选的out参数设置为输出数组。但是,输出数组的形状和数据类型应与原数组兼容。
接下来,请查看关于python集合的深入指南。还可以学习如何使用python sleep函数在代码中添加延迟。