This document is relevant for: Inf2, Trn1, Trn2
JAX NeuronX Environment Variables
Environment variables let you change JAX NeuronX behavior without changing the user script. To avoid inadvertently changing behavior for other scripts, set them in code or just before invoking the Python process, for example:
NEURON_RT_VISIBLE_CORES=8 python3 <script>
Environment variables specific to JAX NeuronX are:
NEURON_CC_FLAGS
Compiler options. The full set of compiler options is described in mixed-precision-casting-options.
XLA_FLAGS
When set to "--xla_dump_hlo_snapshots --xla_dump_to=<dir>", this environment variable enables dumping snapshots in the <dir> directory. See the Snapshotting With Torch-Neuronx 2.1 section for more information; the snapshotting interface is identical for JAX and PyTorch.
When set to "--xla_dump_hlo_as_text --xla_dump_hlo_as_proto --xla_dump_to=<dir> --xla_dump_hlo_pass_re='.*'", this environment variable enables dumping HLOs in proto and text formats after each XLA pass. The dumped *.hlo.pb files are in HloProto format.
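A minimal sketch, assuming the JAX script is launched as a child process: the helper below (hypothetical, not part of JAX or the Neuron SDK) returns a copy of the current environment with XLA_FLAGS set to one of the two values described above, so the setting does not leak into other scripts. The flag strings are copied verbatim from this section.

```python
import os


def xla_dump_env(dump_dir: str, mode: str = "snapshots") -> dict:
    """Return a copy of os.environ with XLA_FLAGS set for HLO dumping.

    mode="snapshots": dump HLO snapshots to dump_dir.
    mode="passes":    dump HLOs as text and proto after every XLA pass.
    """
    if mode == "snapshots":
        flags = f"--xla_dump_hlo_snapshots --xla_dump_to={dump_dir}"
    elif mode == "passes":
        flags = (
            "--xla_dump_hlo_as_text --xla_dump_hlo_as_proto "
            f"--xla_dump_to={dump_dir} --xla_dump_hlo_pass_re='.*'"
        )
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    # Copy so the parent process environment is left unchanged.
    env = dict(os.environ)
    env["XLA_FLAGS"] = flags
    return env
```

The returned dict can be passed as env= to subprocess.run or subprocess.Popen when starting the script (the script name and dump directory are up to you), which keeps the dump flags scoped to that one process.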
NEURON_FORCE_PJRT_PLUGIN_REGISTRATION
When NEURON_FORCE_PJRT_PLUGIN_REGISTRATION=1, the Neuron PJRT plugin will be registered in JAX regardless of the instance type.
NEURON_RUN_TRIVIAL_COMPUTATION_ON_CPU
When NEURON_RUN_TRIVIAL_COMPUTATION_ON_CPU=1, the Neuron PJRT plugin will compile and execute “trivial” computations on CPU instead of NeuronCores. A “trivial” computation is defined as an HLO program that does not contain any collective-compute instructions. The HLO program is compiled by the XLA CPU compiler, and the outputs of the computation are allocated on NeuronCores. The following HLO instructions are considered collective-compute instructions:
all-gather
all-gather-done
all-gather-start
all-reduce-done
all-reduce-start
all-to-all
collective-permute
partition-id
replica-id
recv
recv-done
reduce-scatter
send
send-done
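To make the definition of a “trivial” computation concrete, the sketch below scans HLO text (for example, a text HLO dump produced with the XLA_FLAGS setting above) for the collective-compute opcodes listed. The function name is hypothetical, and this is a text-based approximation for illustration, not the check the Neuron PJRT plugin performs internally.

```python
import re

# Collective-compute opcodes, copied from the list above.
COLLECTIVE_OPCODES = (
    "all-gather", "all-gather-done", "all-gather-start",
    "all-reduce-done", "all-reduce-start", "all-to-all",
    "collective-permute", "partition-id", "replica-id",
    "recv", "recv-done", "reduce-scatter", "send", "send-done",
)

# In HLO text an opcode is immediately followed by "(", for example
#   %ag = f32[8]{0} all-gather(f32[2]{0} %p), dimensions={0}
# The lookbehind keeps a match from starting inside a longer name, and the
# trailing "(" means instruction names such as "%recv.1" are not counted.
_OPCODE_RE = re.compile(
    r"(?<![\w.-])(" + "|".join(map(re.escape, COLLECTIVE_OPCODES)) + r")\("
)


def is_trivial_hlo(hlo_text: str) -> bool:
    """Return True if the HLO text contains none of the listed opcodes."""
    return _OPCODE_RE.search(hlo_text) is None
```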
NEURON_PJRT_PROCESSES_NUM_DEVICES
Should be set to a comma-separated list stating the number of NeuronCores used by each worker process. It is used to construct a global device array whose size equals the sum of the list, and this is reported to the XLA PJRT runtime when requested. Must be set for multi-process executions. It can be used in conjunction with NEURON_RT_VISIBLE_CORES to expose a limited number of NeuronCores to each worker process. If NEURON_RT_VISIBLE_CORES is not set, each process's entry should equal the number of NeuronCores available on its host. Each process's entry in NEURON_PJRT_PROCESSES_NUM_DEVICES must be less than or equal to the number of NeuronCores that NEURON_RT_VISIBLE_CORES exposes to that process.
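A minimal launcher sketch, assuming a single host where each worker gets its own consecutive, non-overlapping block of NeuronCores: given the number of NeuronCores per worker, the hypothetical helper below derives NEURON_PJRT_PROCESSES_NUM_DEVICES (identical for all workers), a per-worker NEURON_RT_VISIBLE_CORES range, and NEURON_PJRT_PROCESS_INDEX (described in the next section).

```python
import os


def worker_envs(cores_per_worker: list[int], first_core: int = 0) -> list[dict]:
    """Build one environment dict per worker process on a single host.

    Worker i is given the next cores_per_worker[i] consecutive NeuronCores.
    """
    if not cores_per_worker or any(n <= 0 for n in cores_per_worker):
        raise ValueError("each worker needs at least one NeuronCore")
    # Same value for every worker; the global device array size is the sum.
    num_devices = ",".join(str(n) for n in cores_per_worker)
    envs = []
    start = first_core
    for index, n in enumerate(cores_per_worker):
        env = dict(os.environ)
        env["NEURON_PJRT_PROCESSES_NUM_DEVICES"] = num_devices
        env["NEURON_PJRT_PROCESS_INDEX"] = str(index)
        # This worker's entry n equals the size of its visible-core range.
        last = start + n - 1
        env["NEURON_RT_VISIBLE_CORES"] = str(start) if n == 1 else f"{start}-{last}"
        envs.append(env)
        start += n
    return envs
```

Each dict can then be passed as env= to subprocess.Popen to start the corresponding worker. Giving each worker a disjoint range keeps every entry in the list equal to, and therefore no larger than, the number of cores that worker can see.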
NEURON_PJRT_PROCESS_INDEX
An integer stating the index (or rank) of the current worker process. This is required in multi-process environments, where every worker needs information about all participating processes. Must be set for multi-process executions. The value should be between 0 and N - 1, where N is the number of worker processes (the number of entries in NEURON_PJRT_PROCESSES_NUM_DEVICES).
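Because the valid range of NEURON_PJRT_PROCESS_INDEX follows from the number of entries in NEURON_PJRT_PROCESSES_NUM_DEVICES, a worker can check both at startup, before initializing JAX. A hypothetical sketch of such a check:

```python
import os


def read_process_topology(environ=os.environ) -> tuple[int, list[int]]:
    """Return (process_index, cores_per_process) after validating both variables."""
    counts = [int(x) for x in environ["NEURON_PJRT_PROCESSES_NUM_DEVICES"].split(",")]
    index = int(environ["NEURON_PJRT_PROCESS_INDEX"])
    # With N comma-separated entries there are N workers: indices 0 .. N - 1.
    if not 0 <= index <= len(counts) - 1:
        raise ValueError(
            f"NEURON_PJRT_PROCESS_INDEX={index} is outside [0, {len(counts) - 1}]"
        )
    return index, counts
```

For example, NEURON_PJRT_PROCESSES_NUM_DEVICES=2,2 describes two workers and a global device array of 2 + 2 = 4 devices, so only indices 0 and 1 are accepted.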
NEURON_RT_STOCHASTIC_ROUNDING_EN
[Neuron Runtime]
When NEURON_RT_STOCHASTIC_ROUNDING_EN=1, JAX NeuronX uses stochastic rounding instead of round-nearest-even for all internal rounding operations when casting from FP32 to a reduced-precision data type (FP16, BF16, FP8, TF32). This feature has been shown to improve training convergence for reduced-precision training jobs. To switch to round-nearest-even mode, set NEURON_RT_STOCHASTIC_ROUNDING_EN=0.
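To illustrate the difference between the two modes, the sketch below emulates both roundings for an FP32 to BF16 cast in pure Python. It is a textbook software model, not the hardware implementation used by the Neuron Runtime; the seeded random.Random generator stands in for the runtime's seeded generator (see NEURON_RT_STOCHASTIC_ROUNDING_SEED), and NaN/infinity inputs are not handled.

```python
import random
import struct


def f32_bits(x: float) -> int:
    """Bit pattern of x as an IEEE-754 float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bf16_bits_to_float(b: int) -> float:
    """BF16 is the upper 16 bits of a float32, so shift back and reinterpret."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


def to_bf16_rne(x: float) -> float:
    """Round-nearest-even FP32 -> BF16 (finite inputs only)."""
    bits = f32_bits(x)
    lsb = (bits >> 16) & 1  # ties go to the even BF16 value
    return bf16_bits_to_float((bits + 0x7FFF + lsb) >> 16)


def to_bf16_stochastic(x: float, rng: random.Random) -> float:
    """Stochastic rounding FP32 -> BF16 (finite inputs only).

    Adding a uniform random 16-bit value before truncating rounds the
    magnitude up with probability (discarded bits) / 2**16, so the result
    equals x in expectation.
    """
    bits = f32_bits(x)
    return bf16_bits_to_float((bits + rng.getrandbits(16)) >> 16)
```

For x = 1.0 + 2**-9, which lies a quarter of the way from 1.0 to the next BF16 value 1.0 + 2**-7, round-nearest-even always returns 1.0, while stochastic rounding returns 1.0 + 2**-7 about 25% of the time, so the average over many casts stays close to x. Small updates that round-nearest-even would discard are thus preserved in expectation, and reusing the same seed reproduces the same sequence of rounding decisions.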
NEURON_RT_STOCHASTIC_ROUNDING_SEED
[Neuron Runtime]
Sets the seed for the random number generator used in stochastic rounding (see previous section). If this environment variable is not set, the seed defaults to 0. Set NEURON_RT_STOCHASTIC_ROUNDING_SEED to a fixed value to ensure reproducibility between runs.
NEURON_RT_VISIBLE_CORES
[Neuron Runtime]
Integer range of specific NeuronCores needed by the process (for example, 0-3 specifies NeuronCores 0, 1, 2, and 3). Use this environment variable when launching processes to limit the launched process to specific consecutive NeuronCores.
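A minimal sketch of setting this variable in code rather than on the command line, assuming it is set before jax is imported so that it is in place before the Neuron PJRT plugin initializes. The helper name is hypothetical; it expands a value such as 0-3 into the NeuronCore IDs it names.

```python
import os


def visible_core_ids(value: str) -> list[int]:
    """Expand a NEURON_RT_VISIBLE_CORES value such as "0-3" or "8" into core IDs."""
    value = value.strip()
    if "-" in value:
        first, last = (int(part) for part in value.split("-", 1))
    else:
        first = last = int(value)
    if first < 0 or last < first:
        raise ValueError(f"invalid NeuronCore range: {value!r}")
    return list(range(first, last + 1))


# setdefault keeps any value already given on the command line, for example
# NEURON_RT_VISIBLE_CORES=8 python3 <script>; otherwise cores 0-3 are used.
os.environ.setdefault("NEURON_RT_VISIBLE_CORES", "0-3")
print(visible_core_ids(os.environ["NEURON_RT_VISIBLE_CORES"]))
# import jax  # import only after the environment is configured
```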
Additional Neuron runtime environment variables are described in NeuronX Runtime Configuration.