This document is relevant for: Inf2, Trn1, Trn2

nki.isa.nc_find_index8

nki.isa.nc_find_index8(*, data, vals, mask=None, dtype=None, **kwargs)

Find indices of the 8 given vals in each partition of the data tensor.

This instruction first loads the 8 values, then loads the data tensor and outputs the indices (starting at 0) of the first occurrence of each value in the data tensor, for each partition.
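The per-partition lookup described above can be modeled on the host in plain Python. This is only an illustrative sketch, not the hardware implementation: the helper name find_index8_reference is hypothetical, each partition is assumed to be already flattened into a single row, and data and vals are assumed to have the same number of partitions. The sketch raises an error when a value is missing, whereas the real instruction's behavior in that case is undefined.

```python
def find_index8_reference(data, vals):
    """Host-side model of the lookup (hypothetical helper, not part of NKI).

    data: list of rows, one row of elements per partition (assumed flattened).
    vals: list of rows, each holding exactly 8 values to look up.
    Returns one row of 8 zero-based indices per partition.
    """
    if len(data) != len(vals):
        raise ValueError("data and vals must have the same number of partitions")
    out = []
    for row, targets in zip(data, vals):
        if len(targets) != 8:
            raise ValueError("each partition of vals must hold exactly 8 values")
        # list.index returns the position of the FIRST match and raises
        # ValueError if the value is absent (a choice made by this sketch).
        out.append([row.index(v) for v in targets])
    return out


data = [
    [5, 9, 1, 9, 7, 3, 8, 2, 6, 4],  # 9 appears twice; the first hit is index 1
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
]
vals = [
    [9, 8, 7, 6, 5, 4, 3, 2],
    [9, 8, 7, 6, 5, 4, 3, 2],
]
indices = find_index8_reference(data, vals)
print(indices)
```

In the first partition the value 9 occurs at positions 1 and 3, and the model reports 1, matching the "first occurrence" rule.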

The data tensor can be up to 5-dimensional, and the vals tensor can be up to 3-dimensional. The data tensor must have between 8 and 16,384 elements per partition. The vals tensor must have exactly 8 elements per partition. The output contains exactly 8 elements per partition and has type uint16 or uint32; the default output type is uint32.
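The size and type limits above can be checked on the host before a kernel is launched. The following is a minimal sketch under stated assumptions: check_find_index8_args is a hypothetical helper, not part of NKI, and it takes already-computed per-partition element counts and a dtype name as a string rather than real tensors.

```python
MIN_DATA_ELEMS = 8        # minimum elements per partition in data
MAX_DATA_ELEMS = 16384    # maximum elements per partition in data
VALS_ELEMS = 8            # vals must hold exactly this many per partition
ALLOWED_DTYPES = ("uint16", "uint32")


def check_find_index8_args(data_elems, vals_elems, dtype="uint32"):
    """Return a list of human-readable problems; empty means the limits hold.

    Hypothetical host-side check; arguments are per-partition element counts.
    """
    problems = []
    if not MIN_DATA_ELEMS <= data_elems <= MAX_DATA_ELEMS:
        problems.append(
            f"data has {data_elems} elements per partition; "
            f"expected {MIN_DATA_ELEMS} to {MAX_DATA_ELEMS}"
        )
    if vals_elems != VALS_ELEMS:
        problems.append(
            f"vals has {vals_elems} elements per partition; expected {VALS_ELEMS}"
        )
    if dtype not in ALLOWED_DTYPES:
        problems.append(f"dtype {dtype!r} is not one of {ALLOWED_DTYPES}")
    return problems


ok = check_find_index8_args(128, 8)
too_small = check_find_index8_args(4, 8)
bad_dtype = check_find_index8_args(128, 8, dtype="int32")
print(ok, too_small, bad_dtype, sep="\n")
```

Because the largest possible index is 16,383, both uint16 and uint32 can represent every index the instruction may return.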

Behavior is undefined if vals tensor contains values that are not in the data tensor.

If provided, a mask is applied only to the data tensor.

Estimated instruction cost:

N engine cycles, where:

  • N is the number of elements per partition in the data tensor

Parameters:
  • data – the data tensor to find indices from

  • vals – tensor containing the 8 values per partition whose indices will be found

  • mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)

  • dtype – (optional) data type of the output indices, either uint16 or uint32; defaults to uint32

Returns:

a 2D tile containing indices (uint16 or uint32) of the 8 values in each partition with shape [par_dim, 8]

Example:

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
from neuronxcc.nki.typing import tensor

##################################################################
# Example 1: Generate a 32 x 128 tile of random floating-point values,
# find the 8 largest values in each row, then find their indices:
##################################################################
# Generate random data
data = nl.rand((32, 128))

# Find max 8 values per row
max_vals = nisa.max8(src=data)

# Create output tensor for indices
indices_tensor = nl.ndarray([32, 8], dtype=nl.uint32, buffer=nl.shared_hbm)

# Find indices of max values
indices = nisa.nc_find_index8(data=data, vals=max_vals)

# Store results
nl.store(indices_tensor, value=indices)
