Kā lietot funkciju NumPy argmax() programmā Python

Šajā apmācībā jūs uzzināsit, kā izmantot funkciju NumPy argmax(), lai atrastu maksimālā elementa indeksu masīvos.

NumPy ir jaudīga bibliotēka zinātniskai skaitļošanai programmā Python; tas nodrošina N-dimensiju masīvus, kas ir efektīvāki nekā Python saraksti. Viena no izplatītākajām darbībām, ko veiksit, strādājot ar NumPy masīviem, ir masīva maksimālās vērtības atrašana. Tomēr dažreiz var vēlēties atrast indeksu, kurā tiek sasniegta maksimālā vērtība.

Funkcija argmax() palīdz atrast maksimuma indeksu gan viendimensiju, gan daudzdimensiju masīvos. Turpināsim, lai uzzinātu, kā tas darbojas.

Kā atrast maksimālā elementa indeksu NumPy masīvā

Lai sekotu šai apmācībai, jums ir jābūt instalētam Python un NumPy. Varat kodēt, startējot Python REPL vai palaižot Jupyter piezīmju grāmatiņu.

Pirmkārt, importēsim NumPy ar parasto aizstājvārdu np.

import numpy as np

Varat izmantot funkciju NumPy max(), lai iegūtu maksimālo vērtību masīvā (pēc izvēles pa noteiktu asi).

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))

# Output
10

Šajā gadījumā np.max(masīvs_1) atgriež 10, kas ir pareizi.

Pieņemsim, ka vēlaties atrast indeksu, kurā masīvā ir maksimālā vērtība. Varat izmantot šādu divpakāpju pieeju:

  • Atrodiet maksimālo elementu.
  • Atrodiet maksimālā elementa indeksu.
  • Masīvā_1 maksimālā vērtība 10 ir indeksā 4 pēc nulles indeksēšanas. Pirmais elements atrodas indeksā 0; otrais elements atrodas indeksā 1 un tā tālāk.

    Lai atrastu indeksu, kurā tiek sasniegts maksimums, varat izmantot funkciju NumPy where(). np.where(condition) atgriež visu indeksu masīvu, kur nosacījums ir True.

    Jums būs jāpieskaras masīvam un jāpiekļūst vienumam pirmajā rādītājā. Lai noskaidrotu, kur rodas maksimālā vērtība, mēs uzstādām nosacījumu uz masīvs_1==10; atcerieties, ka 10 ir maksimālā vērtība masīvā_1.

    print(int(np.where(array_1==10)[0]))
    
    # Output
    4

    Mēs esam izmantojuši np.where() tikai ar nosacījumu, taču šī nav ieteicamā šīs funkcijas izmantošanas metode.

    📑 Piezīme: NumPy kur() funkcija:
    np.where(condition,x,y) atgriež:

    – elementi no x, ja nosacījums ir Patiess, un
    – Elementi no y, ja nosacījums ir False.

      Ko nozīmē WDYM?

    Tāpēc, savienojot funkcijas np.max() un np.where(), mēs varam atrast maksimālo elementu, kam seko indekss, kurā tas notiek.

    Iepriekš minētā divpakāpju procesa vietā varat izmantot funkciju NumPy argmax(), lai iegūtu masīva maksimālā elementa indeksu.

    Funkcijas NumPy argmax() sintakse

    Funkcijas NumPy argmax() vispārīgā sintakse ir šāda:

    np.argmax(array,axis,out)
    # we've imported numpy under the alias np

    Iepriekš minētajā sintaksē:

    • masīvs ir jebkurš derīgs NumPy masīvs.
    • ass ir izvēles parametrs. Strādājot ar daudzdimensiju masīviem, varat izmantot asis parametru, lai atrastu maksimuma indeksu pa noteiktu asi.
    • out ir vēl viens izvēles parametrs. Varat iestatīt izejas parametru uz NumPy masīvu, lai saglabātu funkcijas argmax () izvadi.

    Piezīme. No NumPy versijas 1.22.0 ir papildu Keepdims parametrs. Kad funkcijas argmax() izsaukumā norādām ass parametru, masīvs tiek samazināts gar šo asi. Bet iestatot Keepdims parametru uz True, tiek nodrošināts, ka atgrieztajai izvadei ir tāda pati forma kā ievades masīvam.

    Izmantojot NumPy argmax(), lai atrastu maksimālā elementa indeksu

    #1. Izmantosim funkciju NumPy argmax(), lai atrastu maksimālā elementa indeksu masīvā_1.

    array_1 = np.array([1,5,7,2,10,9,8,4])
    print(np.argmax(array_1))
    
    # Output
    4

    Funkcija argmax() atgriež 4, kas ir pareizi! ✅

    #2. Ja mēs atkārtoti definējam masīvu_1 tā, lai 10 notiktu divas reizes, funkcija argmax() atgriež tikai pirmās gadījuma indeksu.

    array_1 = np.array([1,5,7,2,10,10,8,4])
    print(np.argmax(array_1))
    
    # Output
    4

    Pārējos piemēros mēs izmantosim masīva_1 elementus, ko definējām 1. piemērā.

    Izmantojot NumPy argmax(), lai atrastu maksimālā elementa indeksu 2D masīvā

    Pārveidosim NumPy masīvu masīvu_1 par divdimensiju masīvu ar divām rindām un četrām kolonnām.

    array_2 = array_1.reshape(2,4)
    print(array_2)
    
    # Output
    [[ 1  5  7  2]
     [10  9  8  4]]

    Divdimensiju masīvam ass 0 apzīmē rindas un 1. ass apzīmē kolonnas. NumPy masīvi seko nulles indeksācijai. Tātad NumPy masīva masīva_2 rindu un kolonnu indeksi ir šādi:

    Tagad izsauksim funkciju argmax() divdimensiju masīvā, masīvs_2.

    print(np.argmax(array_2))
    
    # Output
    4

    Pat ja mēs izsaucām argmax() divdimensiju masīvā, tas joprojām atgriež 4. Tas ir identisks iepriekšējās sadaļas viendimensijas masīva izvadei masīvs_1.

    Kāpēc tas notiek? 🤔

    Tas ir tāpēc, ka mēs neesam norādījuši nekādu vērtību ass parametram. Ja šis ass parametrs nav iestatīts, pēc noklusējuma funkcija argmax() atgriež maksimālā elementa indeksu saplacinātajā masīvā.

      Vai TextNow ir anonīms?

    Kas ir saplacināts masīvs? Ja ir N-dimensionāls masīvs ar formu d1 x d2 x … x dN, kur d1, d2, līdz dN ir masīva izmēri pa N izmēriem, tad saplacinātais masīvs ir garš viendimensijas masīvs ar izmēru. d1 * d2 * … * dN.

    Lai pārbaudītu, kā izskatās saplacināts masīvs masīvam_2, varat izsaukt flatten() metodi, kā parādīts tālāk:

    array_2.flatten()
    
    # Output
    array([ 1,  5,  7,  2, 10,  9,  8,  4])

    Maksimālā elementa rādītājs gar rindām (ass = 0)

    Turpināsim atrast maksimālā elementa indeksu pa rindām (ass = 0).

    np.argmax(array_2,axis=0)
    
    # Output
    array([1, 1, 1, 1])

    Šo rezultātu var būt nedaudz grūti saprast, taču mēs sapratīsim, kā tas darbojas.

    Mēs esam iestatījuši ass parametru uz nulli (ass = 0), jo mēs vēlamies atrast maksimālā elementa indeksu gar rindām. Tāpēc funkcija argmax() atgriež rindas numuru, kurā ir maksimālais elements — katrai no trim kolonnām.

    Vizualizēsim to, lai labāk izprastu.

    No iepriekš minētās diagrammas un argmax () izvades mums ir šāds:

    • Pirmajā kolonnā ar indeksu 0 maksimālā vērtība 10 ir otrajā rindā ar indeksu = 1.
    • Otrajai kolonnai indeksā 1 maksimālā vērtība 9 ir otrajā rindā ar indeksu = 1.
    • Trešajai un ceturtajai slejai ar indeksu 2 un 3 maksimālās vērtības 8 un 4 atrodas otrajā rindā ar indeksu = 1.

    Tieši tāpēc mums ir izvades masīvs ([1, 1, 1, 1]), jo maksimālais elements gar rindām ir otrajā rindā (visām kolonnām).

    Maksimālā elementa rādītājs gar kolonnām (ass = 1)

    Tālāk izmantosim funkciju argmax(), lai kolonnās atrastu maksimālā elementa indeksu.

    Palaidiet šādu koda fragmentu un novērojiet izvadi.

    np.argmax(array_2,axis=1)
    array([2, 0])

    Vai varat parsēt izvadi?

    Mēs esam iestatījuši ass = 1, lai aprēķinātu maksimālā elementa indeksu gar kolonnām.

    Funkcija argmax() katrai rindai atgriež kolonnas numuru, kurā ir maksimālā vērtība.

    Šeit ir vizuāls skaidrojums:

    No iepriekš minētās diagrammas un argmax () izvades mums ir šāds:

    • Pirmajā rindā ar indeksu 0 maksimālā vērtība 7 ir trešajā kolonnā ar indeksu = 2.
    • Otrajai rindai ar indeksu 1 maksimālā vērtība 10 ir pirmajā kolonnā ar indeksu = 0.

    Es ceru, ka tagad saprotat, kāda ir izvade, array ([2, 0]) nozīmē.

      Kā lietot īstu GameCube kontrolieri vai Wiimote programmā Dolphin

    Neobligātā izejas parametra izmantošana programmā NumPy argmax ()

    Varat izmantot neobligāto izejas parametru funkcijā NumPy argmax(), lai saglabātu izvadi NumPy masīvā.

    Inicializēsim nulles masīvu, lai saglabātu iepriekšējās funkcijas argmax() izsaukuma izvadi – lai atrastu maksimuma indeksu pa kolonnām (ass= 1).

    out_arr = np.zeros((2,))
    print(out_arr)
    [0. 0.]

    Tagad vēlreiz apskatīsim piemēru, kā atrast maksimālo elementu indeksu gar kolonnām (ass = 1) un iestatīsim out_arr, ko esam definējuši iepriekš.

    np.argmax(array_2,axis=1,out=out_arr)

    Mēs redzam, ka Python tulks rada TypeError, jo out_arr pēc noklusējuma tika inicializēts uz pludiņu masīvu.

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
         56     try:
    ---> 57         return bound(*args, **kwds)
         58     except TypeError:
    
    TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

    Tāpēc, iestatot izejas parametru izvades masīvam, ir svarīgi nodrošināt, lai izvades masīvam būtu pareiza forma un datu tips. Tā kā masīva indeksi vienmēr ir veseli skaitļi, definējot izvades masīvu, parametram dtype jāiestata uz int.

    out_arr = np.zeros((2,),dtype=int)
    print(out_arr)
    
    # Output
    [0 0]

    Tagad mēs varam iet uz priekšu un izsaukt funkciju argmax () gan ar ass, gan parametriem out, un šoreiz tā darbojas bez kļūdām.

    np.argmax(array_2,axis=1,out=out_arr)

    Funkcijas argmax() izvadei tagad var piekļūt masīvā out_arr.

    print(out_arr)
    # Output
    [2 0]

    Secinājums

    Es ceru, ka šī apmācība palīdzēja jums saprast, kā izmantot NumPy argmax() funkciju. Kodu piemērus var palaist Jupyter piezīmju grāmatiņā.

    Pārskatīsim to, ko esam iemācījušies.

    • Funkcija NumPy argmax() atgriež masīva maksimālā elementa indeksu. Ja maksimālais elements masīvā a ir sastopams vairāk nekā vienu reizi, tad np.argmax(a) atgriež elementa pirmās parādīšanās indeksu.
    • Strādājot ar daudzdimensiju masīviem, varat izmantot izvēles asis parametru, lai iegūtu maksimālā elementa indeksu pa noteiktu asi. Piemēram, divdimensiju masīvā: iestatot ass = 0 un ass = 1, jūs varat iegūt maksimālā elementa indeksu attiecīgi pa rindām un kolonnām.
    • Ja vēlaties saglabāt atgriezto vērtību citā masīvā, izvades masīvam varat iestatīt neobligāto out parametru. Tomēr izvades masīvam jābūt saderīgam ar formu un datu tipu.

    Pēc tam iepazīstieties ar padziļināto ceļvedi par Python komplektiem.