JetStream MaxText inference on v6e TPU VMs
This tutorial shows how to use JetStream to serve MaxText models on TPU v6e. JetStream is a throughput- and memory-optimized engine for large language model (LLM) inference on XLA devices (TPUs). In this tutorial, you run the inference benchmark for the Llama2-7B model.
Before you begin
Prepare to provision a TPU v6e with 4 chips:
Follow the Set up the Cloud TPU environment guide to set up a Google Cloud project, configure the Google Cloud CLI, enable the Cloud TPU API, and ensure you have access to use Cloud TPUs.
Authenticate with Google Cloud and configure the default project and zone for Google Cloud CLI.
gcloud auth login
gcloud config set project PROJECT_ID
gcloud config set compute/zone ZONE
Secure capacity
When you are ready to secure TPU capacity, see Cloud TPU Quotas for more information about Cloud TPU quotas. If you have additional questions about securing capacity, contact your Cloud TPU sales or account team.
Provision the Cloud TPU environment
You can provision TPU VMs with GKE, with GKE and XPK, or as queued resources. This tutorial provisions the TPU as a queued resource.
Prerequisites
- Verify that your project has enough TPUS_PER_TPU_FAMILY quota, which specifies the maximum number of chips you can access within your Google Cloud project.
- Verify that your project has enough TPU quota for:
  - TPU VM quota
  - IP address quota
  - Hyperdisk Balanced quota
- User project permissions
  - If you are using GKE with XPK, see Cloud Console Permissions on the user or service account for the permissions needed to run XPK.
Create environment variables
In a Cloud Shell, create the following environment variables:

export PROJECT_ID=your-project-id
export TPU_NAME=your-tpu-name
export ZONE=us-east5-b
export ACCELERATOR_TYPE=v6e-4
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id
Variable | Description
PROJECT_ID | Google Cloud project name. Use an existing project or create a new one.
TPU_NAME | The name of the TPU.
ZONE | The zone where you plan to create your Cloud TPU. See the TPU regions and zones document for the supported zones.
ACCELERATOR_TYPE | The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
RUNTIME_VERSION | The Cloud TPU software version.
SERVICE_ACCOUNT | The email address for your service account. You can find it on the Service Accounts page in the Google Cloud console. For example: your-service-account@your-project-id.iam.gserviceaccount.com
QUEUED_RESOURCE_ID | The user-assigned text ID of the queued resource request.
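For reference, a filled-in set of variables might look like the following. All of these values are hypothetical placeholders; substitute your own project, service account, and resource names.

export PROJECT_ID=my-tpu-project
export TPU_NAME=jetstream-v6e-tpu
export ZONE=us-east5-b
export ACCELERATOR_TYPE=v6e-4
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=my-tpu-sa@my-tpu-project.iam.gserviceaccount.com
export QUEUED_RESOURCE_ID=jetstream-v6e-qr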
Provision a TPU v6e
Use the following command to provision a TPU v6e:
gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
  --node-id=${TPU_NAME} \
  --project=${PROJECT_ID} \
  --zone=${ZONE} \
  --accelerator-type=${ACCELERATOR_TYPE} \
  --runtime-version=${RUNTIME_VERSION} \
  --service-account=${SERVICE_ACCOUNT}
Use the list or describe commands to query the status of your queued resource.
gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE}
For more information about queued resource request statuses, see Manage queued resources.
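If you don't want to re-run the describe command by hand, a minimal polling sketch such as the following waits until the request becomes ACTIVE. It assumes the state is exposed at the state.state field of the describe output and that a 30-second polling interval is acceptable.

# Poll the queued resource until it reports ACTIVE.
while true; do
  STATE=$(gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
    --project ${PROJECT_ID} --zone ${ZONE} --format='value(state.state)')
  echo "Queued resource state: ${STATE}"
  if [[ "${STATE}" == "ACTIVE" ]]; then
    break
  fi
  sleep 30
done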
Connect to the TPU using SSH
gcloud compute tpus tpu-vm ssh ${TPU_NAME}
Once you have connected to the TPU, you can run the inference benchmark.
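If the default project and zone aren't configured in the shell you're using, you can pass them explicitly when you connect:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --project=${PROJECT_ID} \
  --zone=${ZONE}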
Set up your TPU VM environment
Create a directory for running the inference benchmark:
export MAIN_DIR=your-main-directory
mkdir -p ${MAIN_DIR}
Set up a Python virtual environment:
cd ${MAIN_DIR}
sudo apt update
sudo apt install python3.10 python3.10-venv
python3.10 -m venv venv
source venv/bin/activate
Install Git Large File Storage (LFS) (for OpenOrca data):
sudo apt-get install git-lfs
git lfs install
Clone and install JetStream:
cd $MAIN_DIR
git clone https://github.com/google/JetStream.git
cd JetStream
git checkout main
pip install -e .
cd benchmarks
pip install -r requirements.in
Set up MaxText:
cd $MAIN_DIR
git clone https://github.com/google/maxtext.git
cd maxtext
git checkout main
bash setup.sh
pip install torch --index-url https://download.pytorch.org/whl/cpu
Request Access to Llama Models to get a download key from Meta for the Llama 2 model.
Clone the Llama repository:
cd $MAIN_DIR
git clone https://github.com/meta-llama/llama
cd llama
Run bash download.sh. When prompted, provide your download key. This script creates a llama-2-7b directory inside your llama directory.

bash download.sh
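As a quick sanity check, you can list the downloaded directory to confirm that the model weights are present:

ls -lh llama-2-7b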
Create storage buckets:
export CHKPT_BUCKET=gs://your-checkpoint-bucket
export BASE_OUTPUT_DIRECTORY=gs://your-output-dir
export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints
export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data
gcloud storage buckets create ${CHKPT_BUCKET}
gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets create ${CONVERTED_CHECKPOINT_PATH}
gcloud storage buckets create ${MAXTEXT_BUCKET_UNSCANNED}
gcloud storage cp --recursive llama-2-7b/* ${CHKPT_BUCKET}
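Optionally, confirm that the checkpoint files were copied to the bucket:

gcloud storage ls ${CHKPT_BUCKET}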
Perform checkpoint conversion
Convert the Llama 2 checkpoint to a scanned MaxText checkpoint:
cd $MAIN_DIR/maxtext
python3 -m MaxText.llama_or_mistral_ckpt \
  --base-model-path $MAIN_DIR/llama/llama-2-7b \
  --model-size llama2-7b \
  --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH}
Convert the scanned checkpoint to an unscanned checkpoint, which is used for inference in the next steps:
export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items
export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint
python3 -m MaxText.generate_param_only_checkpoint \
  MaxText/configs/base.yml \
  base_output_directory=${MAXTEXT_BUCKET_UNSCANNED} \
  load_parameters_path=${CONVERTED_CHECKPOINT} \
  run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} \
  model_name='llama2-7b' \
  force_unroll=true
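The unscanned checkpoint is written under the run name set above. Optionally, verify that it exists before running inference; the path below matches the one used in the next step.

gcloud storage ls ${MAXTEXT_BUCKET_UNSCANNED}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items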
Perform inference
Run a validation test:
export UNSCANNED_CKPT_PATH=${MAXTEXT_BUCKET_UNSCANNED}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items
python3 -m MaxText.decode \
  MaxText/configs/base.yml \
  load_parameters_path=${UNSCANNED_CKPT_PATH} \
  run_name=runner_decode_unscanned_${idx} \
  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
  per_device_batch_size=1 \
  model_name='llama2-7b' \
  ici_autoregressive_parallelism=4 \
  max_prefill_predict_length=4 \
  max_target_length=16 \
  prompt="I love to" \
  attention=dot_product \
  scan_layers=false
Run the model server in your current terminal and leave it running; the benchmark in a later step sends requests to it:
export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=1
export ICI_TENSOR_PARALLELISM=-1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

cd $MAIN_DIR/maxtext
python3 -m MaxText.maxengine_server \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
Open a new terminal window, connect to the TPU using SSH, change to your main directory, and activate the same virtual environment you used in the first terminal window:

export MAIN_DIR=your-main-directory
cd ${MAIN_DIR}
source venv/bin/activate
Run the following command to start the JetStream benchmark:
python JetStream/benchmarks/benchmark_serving.py \
  --tokenizer $MAIN_DIR/maxtext/assets/tokenizer.llama2 \
  --warmup-mode sampled \
  --save-result \
  --save-request-outputs \
  --request-outputs-file-path outputs.json \
  --num-prompts 1000 \
  --max-output-length 1024 \
  --dataset openorca \
  --dataset-path $MAIN_DIR/JetStream/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl
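Because the benchmark saves per-request outputs, you can spot-check the results file after the run finishes. This is a minimal sketch that assumes outputs.json was written as a single JSON document by the benchmark.

# Confirm the benchmark wrote the per-request outputs and that the file parses as JSON.
ls -lh outputs.json
python3 -m json.tool outputs.json > /dev/null && echo "outputs.json is valid JSON"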
Results
The following output was generated when running the benchmark using v6e-8. Results will vary based on hardware, software, model, and networking.
Mean output size: 929.5959798994975
Median output size: 1026.0
P99 output size: 1026.0
Successful requests: 995
Benchmark duration: 195.533269 s
Total input tokens: 217011
Total generated tokens: 924948
Request throughput: 5.09 requests/s
Input token throughput: 1109.84 tokens/s
Output token throughput: 4730.39 tokens/s
Overall token throughput: 5840.23 tokens/s
Mean ttft: 538.49 ms
Median ttft: 95.66 ms
P99 ttft: 13937.86 ms
Mean ttst: 1218.72 ms
Median ttst: 152.57 ms
P99 ttst: 14241.30 ms
Mean TPOT: 91.83 ms
Median TPOT: 16.63 ms
P99 TPOT: 363.37 ms
Clean up
Disconnect from the TPU:
(vm)$ exit
Delete the TPU:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --force \
  --async
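Optionally, list the queued resources in the zone to confirm that the deletion went through (with --async, the deletion may take a few minutes to complete):

gcloud compute tpus queued-resources list \
  --project ${PROJECT_ID} \
  --zone ${ZONE}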
Delete the buckets and their contents:
export CHKPT_BUCKET=gs://your-checkpoint-bucket
export BASE_OUTPUT_DIRECTORY=gs://your-output-dir
export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints
export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data
gcloud storage rm -r ${CHKPT_BUCKET}
gcloud storage rm -r ${BASE_OUTPUT_DIRECTORY}
gcloud storage rm -r ${CONVERTED_CHECKPOINT_PATH}
gcloud storage rm -r ${MAXTEXT_BUCKET_UNSCANNED}
gcloud storage buckets delete ${CHKPT_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${CONVERTED_CHECKPOINT_PATH}
gcloud storage buckets delete ${MAXTEXT_BUCKET_UNSCANNED}