This document is relevant for: Inf2, Trn1, Trn2
nki.isa.nc_find_index8
- nki.isa.nc_find_index8(*, data, vals, mask=None, dtype=None, **kwargs)
Find indices of the 8 given vals in each partition of the data tensor.
This instruction first loads the 8 values, then loads the data tensor and outputs the indices (starting at 0) of the first occurrence of each value in the data tensor, for each partition.
The data tensor can have up to 5 dimensions and the vals tensor up to 3. The data tensor must have between 8 and 16,384 elements per partition, and the vals tensor must have exactly 8 elements per partition. The output contains exactly 8 elements per partition and has type uint16 or uint32; the default is uint32.
Behavior is undefined if the vals tensor contains values that are not in the data tensor.
If provided, a mask is applied only to the data tensor.
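The per-partition semantics described above can be modeled on the host with a few lines of plain Python. This is only a sketch for checking expected results, not the hardware implementation: the name `find_index8_reference` is hypothetical, the inputs are assumed to be 2D (one row per partition), and raising `ValueError` for a missing value is a choice made by this sketch, since the instruction itself leaves that case undefined.

```python
from typing import List, Sequence


def find_index8_reference(data: Sequence[Sequence[float]],
                          vals: Sequence[Sequence[float]]) -> List[List[int]]:
    """Hypothetical host-side model of the documented semantics.

    For each partition (row), return the 0-based index of the first
    occurrence of each of the 8 values in that row of `data`.
    """
    if len(data) != len(vals):
        raise ValueError("data and vals must have the same number of partitions")
    out = []
    for row, targets in zip(data, vals):
        if len(targets) != 8:
            raise ValueError("vals must hold exactly 8 elements per partition")
        row = list(row)
        # list.index returns the first match and raises ValueError if the
        # value is absent (the real instruction has undefined behavior there).
        out.append([row.index(v) for v in targets])
    return out


# Two partitions with 10 elements each; the second row contains 20 twice.
data = [
    [5, 9, 1, 7, 3, 8, 2, 6, 4, 0],
    [10, 20, 30, 40, 50, 60, 70, 80, 90, 20],
]
vals = [
    [9, 8, 7, 6, 5, 4, 3, 2],
    [20, 90, 80, 70, 60, 50, 40, 30],
]
indices = find_index8_reference(data, vals)
```

In the second partition the value 20 appears at positions 1 and 9, and the model reports 1, matching the "first occurrence" wording of the description.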
Estimated instruction cost:
N engine cycles, where N is the number of elements per partition in the data tensor.
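As a rough illustration of the cost estimate and the size limits, the sketch below derives N from a tensor shape. It rests on assumptions not stated on this page: that the first axis of the shape is the partition dimension, that the 5-dimension limit counts the partition axis, and that N is the product of the remaining axes. The helper name `find_index8_cycles` is hypothetical.

```python
from math import prod
from typing import Sequence

MIN_ELEMS_PER_PARTITION = 8       # lower bound stated for the data tensor
MAX_ELEMS_PER_PARTITION = 16384   # upper bound stated for the data tensor


def find_index8_cycles(data_shape: Sequence[int]) -> int:
    """Hypothetical helper: validate a data shape and return the estimated cycles.

    Assumes data_shape[0] is the partition dimension, so the number of
    elements per partition is the product of the remaining axes.
    """
    if len(data_shape) > 5:
        raise ValueError("data tensor can have at most 5 dimensions")
    n = prod(data_shape[1:])
    if not MIN_ELEMS_PER_PARTITION <= n <= MAX_ELEMS_PER_PARTITION:
        raise ValueError(
            f"data tensor needs 8 to 16,384 elements per partition, got {n}"
        )
    return n  # estimated cost is N engine cycles


cycles_2d = find_index8_cycles((32, 128))     # 128 elements per partition
cycles_3d = find_index8_cycles((128, 4, 64))  # 4 * 64 = 256 elements per partition
```

Because the largest possible index is 16,383, every result fits in uint16, so the smaller output type loses no information.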
- Parameters:
data – the data tensor to find indices from
vals – tensor containing the 8 values per partition whose indices will be found
mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)
dtype – (optional) data type of the output indices, uint16 or uint32 (default: uint32)
- Returns:
a 2D tile containing indices (uint16 or uint32) of the 8 values in each partition with shape [par_dim, 8]
Example:
    import neuronxcc.nki.isa as nisa
    import neuronxcc.nki.language as nl
    from neuronxcc.nki.typing import tensor

    ##################################################################
    # Example 1: Generate a 32 * 128 tile of random floating point values,
    # find the 8 largest values in each row, then find their indices:
    ##################################################################
    # Generate random data
    data = nl.rand((32, 128))

    # Find max 8 values per row
    max_vals = nisa.max8(src=data)

    # Create output tensor for indices
    indices_tensor = nl.ndarray([32, 8], dtype=nl.uint32, buffer=nl.shared_hbm)

    # Find indices of max values
    indices = nisa.nc_find_index8(data=data, vals=max_vals)

    # Store results
    nl.store(indices_tensor, value=indices)