JetStream MaxText inference on v6e TPU VMs

This tutorial shows how to use JetStream to serve MaxText models on TPU v6e. JetStream is a throughput- and memory-optimized engine for large language model (LLM) inference on XLA devices (TPUs). In this tutorial, you run an inference benchmark for the Llama 2 7B model.

Before you begin

Prepare to provision a TPU v6e with 4 chips:

  1. Follow the Set up the Cloud TPU environment guide to set up a Google Cloud project, configure the Google Cloud CLI, enable the Cloud TPU API, and verify that you have access to use Cloud TPUs.

  2. Authenticate with Google Cloud and configure the default project and zone for the Google Cloud CLI:

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE
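
Optionally, confirm that the CLI picked up these settings and that the Cloud TPU API is enabled:

# The project and zone shown should match the values you set.
gcloud config list

# tpu.googleapis.com should appear among the enabled services.
gcloud services list --enabled | grep tpu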

Secure capacity

When you are ready to secure TPU capacity, see Cloud TPU quotas for more information. If you have additional questions about securing capacity, contact your Cloud TPU sales or account team.

Provision the Cloud TPU environment

You can provision TPU VMs with GKE, with GKE and XPK, or as queued resources.

Prerequisites

  • Verify that your project has enough TPUS_PER_TPU_FAMILY quota, which specifies the maximum number of chips you can access within your Google Cloud project.
  • Verify that your project has enough quota for:
    • TPU VMs
    • IP addresses
    • Hyperdisk Balanced
  • Verify that you have the required user permissions in your project.

Create environment variables

In Cloud Shell, create the following environment variables:
export PROJECT_ID=your-project-id
export TPU_NAME=your-tpu-name
export ZONE=us-east5-b
export ACCELERATOR_TYPE=v6e-4
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id
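
Before continuing, you can guard against typos with a quick bash check that every variable is non-empty (a convenience sketch, not a required step):

# Fail loudly if any required variable is unset or empty.
for VAR in PROJECT_ID TPU_NAME ZONE ACCELERATOR_TYPE RUNTIME_VERSION \
           SERVICE_ACCOUNT QUEUED_RESOURCE_ID; do
  if [[ -z "${!VAR}" ]]; then
    echo "ERROR: ${VAR} is not set"
  fi
done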

Environment variable descriptions

Variable             Description
PROJECT_ID           Your Google Cloud project name. Use an existing project or create a new one.
TPU_NAME             The name of the TPU.
ZONE                 The zone in which to create the TPU. See the TPU regions and zones document for the supported zones.
ACCELERATOR_TYPE     The version and size of the Cloud TPU to create. For more information about supported accelerator types for each TPU version, see TPU versions.
RUNTIME_VERSION      The Cloud TPU software version.
SERVICE_ACCOUNT      The email address for your service account. You can find it on the Service Accounts page in the Google Cloud console. For example: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID   The user-assigned text ID of the queued resource request.

Provision a TPU v6e

Use the following command to provision a TPU v6e:

gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id=${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --accelerator-type=${ACCELERATOR_TYPE} \
    --runtime-version=${RUNTIME_VERSION} \
    --service-account=${SERVICE_ACCOUNT}

Use the list or describe commands to query the status of your queued resource.

gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
    --project ${PROJECT_ID} --zone ${ZONE}

For more information about queued resource request statuses, see Manage queued resources.
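
Provisioning can take several minutes. If you prefer to block until the TPU is ready, the following polling sketch works, assuming the describe output exposes the status under the state.state field:

# Poll every 30 seconds until the queued resource becomes ACTIVE.
while true; do
  STATE=$(gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} --zone ${ZONE} \
      --format="value(state.state)")
  echo "Queued resource state: ${STATE:-UNKNOWN}"
  if [[ "${STATE}" == "ACTIVE" ]]; then break; fi
  sleep 30
done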

Connect to the TPU using SSH

gcloud compute tpus tpu-vm ssh ${TPU_NAME}

Once you have connected to the TPU, you can run the inference benchmark.
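
If your gcloud defaults differ from the values you exported earlier, pass the project and zone explicitly:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project ${PROJECT_ID} \
    --zone ${ZONE}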

Set up your TPU VM environment

  1. Create a directory for running the inference benchmark:

    export MAIN_DIR=your-main-directory
    mkdir -p ${MAIN_DIR}
  2. Set up a Python virtual environment:

    cd ${MAIN_DIR}
    sudo apt update
    sudo apt install python3.10 python3.10-venv
    python3.10 -m venv venv
    source venv/bin/activate
  3. Install Git Large File Storage (LFS) (for OpenOrca data):

    sudo apt-get install git-lfs
    git lfs install
  4. Clone and install JetStream:

    cd $MAIN_DIR
    git clone https://github.com/google/JetStream.git
    cd JetStream
    git checkout main
    pip install -e .
    cd benchmarks
    pip install -r requirements.in
  5. Set up MaxText:

    cd $MAIN_DIR
    git clone https://github.com/google/maxtext.git
    cd maxtext
    git checkout main
    bash setup.sh
    pip install torch --index-url https://download.pytorch.org/whl/cpu
  6. Request Access to Llama Models to get a download key from Meta for the Llama 2 model.

  7. Clone the Llama repository:

    cd $MAIN_DIR
    git clone https://github.com/meta-llama/llama
    cd llama
  8. Run bash download.sh. When prompted, provide your download key. This script creates a llama-2-7b directory inside your llama directory.

    bash download.sh
  9. Create storage buckets:

    export CHKPT_BUCKET=gs://your-checkpoint-bucket
    export BASE_OUTPUT_DIRECTORY=gs://your-output-dir
    export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints
    export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data
    gcloud storage buckets create ${CHKPT_BUCKET}
    gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY}
    gcloud storage buckets create ${CONVERTED_CHECKPOINT_PATH}
    gcloud storage buckets create ${MAXTEXT_BUCKET_UNSCANNED}
    gcloud storage cp --recursive llama-2-7b/* ${CHKPT_BUCKET}
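
Before moving on, you can optionally sanity-check the setup. This sketch assumes JetStream installs under the jetstream package name; adjust if your checkout differs:

# Confirm the virtual environment resolves the installed packages.
python3 -c "import jetstream; print('JetStream import OK')"    # assumes the package name is jetstream
python3 -c "import jax; print('JAX sees', jax.device_count(), 'TPU devices')"

# Confirm the raw Llama 2 checkpoint landed in the bucket.
gcloud storage ls ${CHKPT_BUCKET}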

Perform checkpoint conversion

  1. Perform conversion to scanned checkpoints:

    cd $MAIN_DIR/maxtext
    python3 -m MaxText.llama_or_mistral_ckpt \
        --base-model-path $MAIN_DIR/llama/llama-2-7b \
        --model-size llama2-7b \
        --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH}
  2. Convert to unscanned checkpoints:

    export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items
    export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint
    python3 -m MaxText.generate_param_only_checkpoint \
        MaxText/configs/base.yml \
        base_output_directory=${MAXTEXT_BUCKET_UNSCANNED} \
        load_parameters_path=${CONVERTED_CHECKPOINT} \
        run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} \
        model_name='llama2-7b' \
        force_unroll=true
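
To confirm that both conversions wrote their outputs before running inference, you can list the checkpoint paths; the unscanned path below mirrors the UNSCANNED_CKPT_PATH used in the next section:

# Scanned checkpoint written by MaxText.llama_or_mistral_ckpt.
gcloud storage ls ${CONVERTED_CHECKPOINT_PATH}/0/items

# Unscanned (parameter-only) checkpoint written by MaxText.generate_param_only_checkpoint.
gcloud storage ls ${MAXTEXT_BUCKET_UNSCANNED}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items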

Perform inference

  1. Run a validation test:

    export UNSCANNED_CKPT_PATH=${MAXTEXT_BUCKET_UNSCANNED}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items
    python3 -m MaxText.decode \
        MaxText/configs/base.yml \
        load_parameters_path=${UNSCANNED_CKPT_PATH} \
        run_name=runner_decode_unscanned \
        base_output_directory=${BASE_OUTPUT_DIRECTORY} \
        per_device_batch_size=1 \
        model_name='llama2-7b' \
        ici_autoregressive_parallelism=4 \
        max_prefill_predict_length=4 \
        max_target_length=16 \
        prompt="I love to" \
        attention=dot_product \
        scan_layers=false
  2. Run the server in your current terminal:

    export TOKENIZER_PATH=assets/tokenizer.llama2
    export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
    export MAX_PREFILL_PREDICT_LENGTH=1024
    export MAX_TARGET_LENGTH=2048
    export MODEL_NAME=llama2-7b
    export ICI_FSDP_PARALLELISM=1
    export ICI_AUTOREGRESSIVE_PARALLELISM=1
    export ICI_TENSOR_PARALLELISM=-1
    export SCAN_LAYERS=false
    export WEIGHT_DTYPE=bfloat16
    export PER_DEVICE_BATCH_SIZE=11
    
    cd $MAIN_DIR/maxtext
    python3 -m MaxText.maxengine_server \
        MaxText/configs/base.yml \
        tokenizer_path=${TOKENIZER_PATH} \
        load_parameters_path=${LOAD_PARAMETERS_PATH} \
        max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
        max_target_length=${MAX_TARGET_LENGTH} \
        model_name=${MODEL_NAME} \
        ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
        ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
        ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
        scan_layers=${SCAN_LAYERS} \
        weight_dtype=${WEIGHT_DTYPE} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
  3. Open a new terminal window, connect to the TPU, and activate the same virtual environment you used in the first terminal window:

    export MAIN_DIR=your-main-directory
    cd ${MAIN_DIR}
    source venv/bin/activate
    
  4. Run the JetStream benchmark:

    python JetStream/benchmarks/benchmark_serving.py \
        --tokenizer $MAIN_DIR/maxtext/assets/tokenizer.llama2 \
        --warmup-mode sampled \
        --save-result \
        --save-request-outputs \
        --request-outputs-file-path outputs.json \
        --num-prompts 1000 \
        --max-output-length 1024 \
        --dataset openorca \
        --dataset-path $MAIN_DIR/JetStream/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl
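
When the benchmark finishes, the per-request outputs are saved to outputs.json in the directory you ran the benchmark from. A minimal inspection sketch, assuming the file is a JSON array with one record per request:

# Count the saved request records (assumes outputs.json is a JSON array).
python3 -c "import json; print(len(json.load(open('outputs.json'))), 'request outputs saved')"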

Results

The following output was generated by running the benchmark on a v6e-8. Results will vary based on hardware, software, model, and networking.

Mean output size: 929.5959798994975
Median output size: 1026.0
P99 output size: 1026.0
Successful requests: 995
Benchmark duration: 195.533269 s
Total input tokens: 217011
Total generated tokens: 924948
Request throughput: 5.09 requests/s
Input token throughput: 1109.84 tokens/s
Output token throughput: 4730.39 tokens/s
Overall token throughput: 5840.23 tokens/s
Mean ttft: 538.49 ms
Median ttft: 95.66 ms
P99 ttft: 13937.86 ms
Mean ttst: 1218.72 ms
Median ttst: 152.57 ms
P99 ttst: 14241.30 ms
Mean TPOT: 91.83 ms
Median TPOT: 16.63 ms
P99 TPOT: 363.37 ms
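
As a consistency check, the overall token throughput follows from the totals above: (217,011 input tokens + 924,948 generated tokens) / 195.53 s ≈ 5,840 tokens/s, which matches the reported value.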

Clean up

  1. Disconnect from the TPU:

    (vm)$ exit
  2. Delete the TPU:

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
        --project ${PROJECT_ID} \
        --zone ${ZONE} \
        --force \
        --async
  3. Empty the buckets, then delete them:

    export CHKPT_BUCKET=gs://your-checkpoint-bucket
    export BASE_OUTPUT_DIRECTORY=gs://your-output-dir
    export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints
    export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data
    gcloud storage rm -r ${CHKPT_BUCKET}/*
    gcloud storage rm -r ${BASE_OUTPUT_DIRECTORY}/*
    gcloud storage rm -r ${CONVERTED_CHECKPOINT_PATH}/*
    gcloud storage rm -r ${MAXTEXT_BUCKET_UNSCANNED}/*
    gcloud storage buckets delete ${CHKPT_BUCKET}
    gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
    gcloud storage buckets delete ${CONVERTED_CHECKPOINT_PATH}
    gcloud storage buckets delete ${MAXTEXT_BUCKET_UNSCANNED}
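
To confirm that cleanup finished, list what remains (the queued resource may take a few minutes to disappear because the delete runs with --async):

# The deleted queued resource should no longer be listed.
gcloud compute tpus queued-resources list --project ${PROJECT_ID} --zone ${ZONE}

# The four tutorial buckets should be gone.
gcloud storage ls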