This document is relevant for: Inf2, Trn1, Trn2
nki.isa.nc_match_replace8
- nki.isa.nc_match_replace8(*, data, vals, imm, mask=None, dtype=None, **kwargs)
Replace the first occurrence of each value in `vals` with `imm` in `data` using the Vector engine. This is an in-place modification of the `data` tensor.

This instruction reads the input `data` and replaces the first occurrence of each of the given values (from the `vals` tensor) with the specified immediate constant. All other values are written out unchanged.

The `data` tensor can be up to 5-dimensional, while the `vals` tensor can be up to 3-dimensional. The `vals` tensor must have exactly 8 elements per partition. The `data` tensor must have no more than 16,384 elements per partition. The output has the same shape as the input `data` tensor. `data` and `vals` must have the same number of partitions. Both input tensors can come from SBUF or PSUM.

Behavior is undefined if the `vals` tensor contains values that are not in the `data` tensor.
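To make the "first occurrence" rule concrete, here is a minimal pure-Python sketch of the semantics on a 2-D tile (one list per partition). It is not the NKI API: `match_replace8_reference` is a hypothetical name, it returns a new tile instead of modifying `data` in place, and the handling of a value repeated within `vals` (here it consumes the next unreplaced occurrence) is an assumption, not documented behavior.

```python
def match_replace8_reference(data, vals, imm):
    """Hypothetical pure-Python model of nc_match_replace8 on a 2-D tile.

    data: list of partitions (rows); vals: list of partitions, 8 values each.
    Returns a new tile; the real instruction modifies data in place.
    """
    result = []
    for row, row_vals in zip(data, vals):
        out = list(row)
        replaced = [False] * len(row)
        for v in row_vals:
            # Find the first occurrence of v that has not been replaced yet.
            # Assumption: a value repeated in vals consumes the next occurrence.
            for i, x in enumerate(row):
                if not replaced[i] and x == v:
                    out[i] = imm
                    replaced[i] = True
                    break
            # If v is absent, this model leaves the row alone;
            # on hardware the behavior is undefined.
        result.append(out)
    return result
```

For example, with the row `[1, 5, 3, 5, 2, 7, 0, 4, 6, 8]` and `vals` `[8, 7, 6, 5, 4, 3, 2, 1]`, only the first `5` is replaced; the second `5` and the `0` are written out unchanged.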
If provided, a mask is applied to the data tensor.
Estimated instruction cost: N engine cycles, where N is the number of elements per partition in the data tensor.
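The shape limits listed earlier and the cost estimate can be checked together on the host before launching a kernel. The sketch below assumes dimension 0 of each shape is the partition dimension and the remaining dimensions are the free dimensions; `estimate_match_replace8_cycles` is a hypothetical helper, not part of `nki.isa`.

```python
from math import prod

MAX_DATA_ELEMS_PER_PARTITION = 16_384  # documented limit for data
VALS_PER_PARTITION = 8                 # vals must hold exactly 8 per partition

def estimate_match_replace8_cycles(data_shape, vals_shape):
    """Validate shapes against the documented limits and return N.

    Assumption (hypothetical helper): shape[0] is the partition dimension
    and the remaining dimensions are the free (per-partition) dimensions.
    """
    if len(data_shape) > 5:
        raise ValueError("data can be at most 5-dimensional")
    if len(vals_shape) > 3:
        raise ValueError("vals can be at most 3-dimensional")
    if data_shape[0] != vals_shape[0]:
        raise ValueError("data and vals must have the same number of partitions")
    if prod(vals_shape[1:]) != VALS_PER_PARTITION:
        raise ValueError("vals must have exactly 8 elements per partition")
    n = prod(data_shape[1:])  # elements per partition in data
    if n > MAX_DATA_ELEMS_PER_PARTITION:
        raise ValueError(f"data has {n} elements per partition; the limit is 16,384")
    return n  # estimated cost: N engine cycles
```

For a `data` tile of shape `(128, 4, 512)`, each partition holds 4 × 512 = 2,048 elements, which is within the 16,384 limit, so the estimate is roughly 2,048 engine cycles.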
- Parameters:
data – the data tensor to modify
vals – tensor containing the 8 values per partition to replace
imm – float32 constant to replace matched values with
mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)
dtype – (optional) data type to cast the output to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tile.
- Returns:
the modified data tensor
Example:
```python
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
from neuronxcc.nki.typing import tensor

##################################################################
# Example 1: Generate a tile of random floating-point values,
# get the 8 largest values in each row, then replace their first
# occurrences with -inf:
##################################################################
N = 4
M = 16
data_tile = nl.rand((N, M))          # N partitions x M elements, uniform between 0 and 1
max_vals = nisa.max8(src=data_tile)  # 8 largest values per partition
result = nisa.nc_match_replace8(data=data_tile[:, :], vals=max_vals, imm=float('-inf'))

# Store the modified tile to a shared HBM tensor.
result_tensor = nl.ndarray([N, M], dtype=nl.float32, buffer=nl.shared_hbm)
nl.store(result_tensor, value=result)
```
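After running Example 1, the stored result can be sanity-checked on the host against the original random values. The sketch below is a hypothetical helper, not part of NKI: it assumes both tiles have been copied back as Python lists of rows, and it only checks that the multiset of replaced values equals the row's 8 largest values, so it does not verify which occurrence was replaced when values tie.

```python
import heapq
import math
from collections import Counter

def check_example1_row(original, result):
    """Hypothetical host-side check for one partition (row) of Example 1.

    Returns True if exactly the row's 8 largest values (counted with
    multiplicity) became -inf and every other element is unchanged.
    """
    if len(original) != len(result):
        return False
    replaced = []
    for x, y in zip(original, result):
        if y == -math.inf and x != -math.inf:
            replaced.append(x)  # this position was overwritten with imm
        elif x != y:
            return False        # any other change is unexpected
    return Counter(replaced) == Counter(heapq.nlargest(8, original))
```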