Generating content, summarizing, and analyzing content are just some of the tasks you can accomplish with Gemma open models. This tutorial shows you how to get started running Gemma using Keras, including generating text content with text and image input. Keras provides implementations for running Gemma and other models using JAX, PyTorch, and TensorFlow. If you're new to Keras, you might want to read Getting started with Keras before you begin.
Gemma 3 and later models support text and image input, with the exception of the Gemma 3 1B model, which is text-only. Earlier Gemma versions support only text input, except for some variants such as PaliGemma.
Setup
Before starting this tutorial, make sure you have completed the following steps:
- Get access to Gemma on kaggle.com.
- Select a Colab runtime with sufficient resources to run the Gemma model size you plan to use.
- Generate and configure a Kaggle username and API key.
If you need help completing these steps, see the Gemma setup instructions. After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.
Set environment variables
Set environment variables for KAGGLE_USERNAME and KAGGLE_KEY.
import os
from google.colab import userdata
# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
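If you want the same notebook to run both inside and outside Colab, you can fall back to credentials already exported in your shell. The following is a minimal sketch: load_kaggle_credentials() is a hypothetical helper, not part of Keras or Kaggle. It uses Colab secrets when google.colab is importable and otherwise checks that both variables are already set.

```python
import os


def load_kaggle_credentials():
    """Make sure KAGGLE_USERNAME and KAGGLE_KEY are present in os.environ.

    Hypothetical helper: reads Colab secrets when running in Colab,
    otherwise expects the variables to be exported in the shell.
    """
    try:
        from google.colab import userdata  # Only importable inside Colab.
    except ImportError:
        userdata = None

    for name in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
        if userdata is not None:
            # Copy the Colab secret into the process environment.
            os.environ[name] = userdata.get(name)
        elif not os.environ.get(name):
            # Outside Colab, fail early with a clear message.
            raise RuntimeError(
                f"{name} is not set; export it before running this notebook."
            )
```

Calling load_kaggle_credentials() in place of the two userdata.get() lines gives the same result in Colab and a clear error elsewhere if a variable is missing.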
Install Keras packages
Install the Keras and KerasHub Python packages.
pip install -q -U keras-hub
pip install -q -U keras
Select a backend
Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Keras 3 lets you choose the backend: TensorFlow, JAX, or PyTorch. All three work for this tutorial. For this tutorial, configure the JAX backend, because it typically provides the best performance.
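# Keras reads KERAS_BACKEND when it is first imported, so set it before `import keras`.
os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
# Let JAX preallocate all GPU memory, which helps avoid memory fragmentation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"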
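Because Keras 3 picks its backend at import time, changing KERAS_BACKEND after `import keras` has no effect in the running process. The following sketch uses a hypothetical set_keras_backend() helper that validates the backend name and refuses to run once Keras is already imported, so a misordered cell fails loudly instead of silently using the wrong backend.

```python
import os
import sys

VALID_BACKENDS = ("jax", "tensorflow", "torch")


def set_keras_backend(backend):
    """Select the Keras 3 backend; must be called before `import keras`."""
    if backend not in VALID_BACKENDS:
        raise ValueError(f"Unknown backend {backend!r}; choose from {VALID_BACKENDS}")
    if "keras" in sys.modules:
        # Keras has already read KERAS_BACKEND, so a change now would be ignored.
        raise RuntimeError("Keras is already imported; restart the runtime first.")
    os.environ["KERAS_BACKEND"] = backend
```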
Import packages
Import the Keras and KerasHub packages.
import keras
import keras_hub
Load model
Keras provides implementations of many popular model architectures. Download and configure a Gemma model using the Gemma3CausalLM class, which provides an end-to-end causal language modeling implementation for Gemma 3 models. Create the model using the from_preset() method, as shown in the following code example:
gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset(
"gemma3_instruct_4b",
dtype="bfloat16",
)
The Gemma3CausalLM.from_preset() method instantiates the model from a preset architecture and weights. In the code above, the string "gemma3_instruct_4b" specifies the preset: the Gemma 3 version, the instruction-tuned variant, and the 4-billion-parameter size. You can find the preset names for Gemma models in their Model Variation listings on Kaggle.
After the model is downloaded, use the summary() method to get more information about the model:
gemma_lm.summary()
The output of the summary shows the model's total number of trainable parameters. For the purposes of naming the model, the embedding layer is not counted in the number of parameters.
Generate text with text
Generate text from a text prompt using the generate() method of the Gemma model object you configured in the previous steps. The optional max_length argument specifies the maximum length of the generated sequence. The following code examples show a few ways to prompt the model.
gemma_lm.generate("what is keras in 3 bullet points?", max_length=64)
You can also provide batched prompts using a list as input:
gemma_lm.generate(
["what is keras in 3 bullet points?",
"The universe is"],
max_length=64)
If you're running on the JAX or TensorFlow backend, you should notice that repeated generate() calls return answers more quickly. This is because generate() is compiled with XLA for each combination of batch size and max_length: the first call with a given combination is expensive, but later calls with the same combination reuse the compiled function and run faster.
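To observe the compilation cost yourself, you can time two calls that use the same batch size and max_length. The helper below is a minimal, hypothetical utility built on time.perf_counter(). The commented lines assume the gemma_lm object loaded earlier and are not run as part of the helper.

```python
import time


def timed(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


# Example usage with the model from the previous steps (same prompt shape twice):
# _, first = timed(gemma_lm.generate, "The universe is", max_length=64)
# _, second = timed(gemma_lm.generate, "The universe is", max_length=64)
# print(f"first call: {first:.1f}s, second call: {second:.1f}s")
```

On JAX or TensorFlow, the first call includes XLA compilation, so expect it to be noticeably slower than the second.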
Use a prompt template
When building more complex requests or multi-turn chat interactions, use a prompt template to structure your request. The following code creates a standard template for Gemma prompts:
PROMPT_TEMPLATE = """<start_of_turn>user
{question}
<end_of_turn>
<start_of_turn>model
"""
The following code shows how to use the template to format a simple request:
question = """what is keras in 3 bullet points?"""
prompt = PROMPT_TEMPLATE.format(question=question)
gemma_lm.generate(prompt)
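For multi-turn chat, the same idea extends to a whole conversation history. The following sketch uses a hypothetical build_chat_prompt() helper that formats (role, text) pairs in Gemma's turn format, placing <end_of_turn> directly after each turn's text, and ends with an open model turn for Gemma to complete. It does not add a beginning-of-sequence token, on the assumption that the KerasHub preprocessor adds one for you.

```python
def build_chat_prompt(turns):
    """Format (role, text) turns as a Gemma chat prompt.

    `turns` is a list of ("user" | "model", text) pairs, normally ending
    with a user turn. The result ends with an open model turn.
    """
    parts = []
    for role, text in turns:
        if role not in ("user", "model"):
            raise ValueError(f"Unsupported role: {role!r}")
        parts.append(f"<start_of_turn>{role}\n{text}<end_of_turn>\n")
    # Leave the model turn open so generation continues from here.
    parts.append("<start_of_turn>model\n")
    return "".join(parts)
```

You could then pass the result directly to the model, for example gemma_lm.generate(build_chat_prompt(history), max_length=256), and append the model's reply to history before the next user turn.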
Optional: Try a different sampler
You can control the generation strategy for the model object by setting the sampler argument on compile(). By default, "greedy" sampling is used. As an experiment, try setting a "top_k" strategy:
gemma_lm.compile(sampler="top_k")
gemma_lm.generate("The universe is", max_length=64)
While the default greedy algorithm always picks the token with the highest probability, the top-K algorithm randomly samples the next token from the K most probable tokens. You don't have to specify a sampler, and you can skip the previous code snippet if it isn't helpful for your use case. If you'd like to learn more about the available samplers, see Samplers.
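To make the difference between the two strategies concrete, the following NumPy sketch applies each one to a single vector of hypothetical next-token scores (logits). It illustrates the algorithms only and is not the KerasHub sampler implementation.

```python
import numpy as np


def greedy_sample(logits):
    """Greedy decoding: always pick the highest-scoring token id."""
    return int(np.argmax(logits))


def top_k_sample(logits, k, rng):
    """Top-K sampling: draw a token id from the K highest-scoring tokens only."""
    logits = np.asarray(logits, dtype=np.float64)
    top_ids = np.argsort(logits)[-k:]  # Indices of the K largest logits.
    top_logits = logits[top_ids]
    # Softmax over the K survivors (subtract the max for numerical stability).
    probs = np.exp(top_logits - top_logits.max())
    probs /= probs.sum()
    return int(rng.choice(top_ids, p=probs))


rng = np.random.default_rng(0)
toy_logits = [2.0, 0.5, 1.5, -1.0, 0.0]  # Hypothetical scores for a 5-token vocabulary.
print(greedy_sample(toy_logits))  # → 0
print([top_k_sample(toy_logits, k=2, rng=rng) for _ in range(10)])  # Mix of 0 and 2.
```

Greedy decoding returns the same token every time, while top-K with k=2 varies between the two most likely tokens, which is why the "top_k" sampler produces more varied text.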
Generate text with image data
With Gemma 3 and later models, you can use images as part of a prompt to generate output. This capability allows you to use Gemma to interpret visual content or use images as data for content generation.
Create image loader function
The following function downloads an image file from a URL and converts it to a NumPy array for use in a Gemma prompt:
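import numpy as np
import PIL.Image  # Import the submodule explicitly; `import PIL` alone may not load it.


def read_image(url):
    """Downloads an image from a URL and returns it as a NumPy array."""
    image_path = keras.utils.get_file(origin=url)  # Downloads and caches the file.
    image = PIL.Image.open(image_path)
    image = np.array(image)
    return image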
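The read_image() function returns the pixels as stored in the file, so a PNG with transparency becomes a four-channel RGBA array and a grayscale image becomes a two-dimensional array. If your images might be in those formats, a hypothetical helper like ensure_rgb() can normalize them to three channels, assuming the model expects RGB input. Alternatively, you can call PIL.Image.open(image_path).convert("RGB") before converting to NumPy.

```python
import numpy as np


def ensure_rgb(image):
    """Return an (H, W, 3) array from a grayscale, RGB, or RGBA image array."""
    image = np.asarray(image)
    if image.ndim == 2:
        # Grayscale (H, W): repeat the single channel three times.
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[-1] == 4:
        # RGBA (H, W, 4): drop the alpha (transparency) channel.
        image = image[..., :3]
    if image.ndim != 3 or image.shape[-1] != 3:
        raise ValueError(f"Unsupported image shape: {image.shape}")
    return image
```

Usage: image = ensure_rgb(read_image(url)).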
Load image for a prompt
Load the image and format the data so the model can process it. Use the read_image() function defined in the previous section, as shown in the following example code:
from matplotlib import pyplot as plt
image = read_image(
"https://ai.google.dev/gemma/docs/images/thali-indian-plate.jpg"
)
plt.imshow(image)
Figure 1. Image of Thali Indian food on a metal plate.
Run request with an image
When prompting the Gemma model with image content, include the special string <start_of_image> in your prompt to include the image as part of the prompt. Use a prompt template, such as the PROMPT_TEMPLATE string defined previously, to format your request as shown in the following code:
question = """Which cuisine is this: <start_of_image>? \
Identify the food items present. Which macros is the meal \
high and low on? Keep your answer short.\
"""
gemma_lm.generate(
{
"images": image,
"prompts": PROMPT_TEMPLATE.format(question=question),
},
)
If you're using a smaller GPU and encounter out-of-memory (OOM) errors, you can set max_images_per_prompt and sequence_length to smaller values. The following code limits each prompt to two images and reduces the sequence length to 768 tokens.
gemma_lm.preprocessor.max_images_per_prompt = 2
gemma_lm.preprocessor.sequence_length = 768
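Reducing sequence_length too far can leave little room for text once images are included. The following back-of-the-envelope helper estimates how many tokens remain for text when every image slot is used. It is a hypothetical sketch that assumes each image occupies 256 tokens, the figure reported for Gemma 3's vision encoder; verify the value for your preset, and note that special tokens around each image add a little more.

```python
# Assumption: each image takes about 256 tokens in a Gemma 3 prompt.
# Verify this for your preset before relying on the estimate.
TOKENS_PER_IMAGE = 256


def text_token_budget(sequence_length, max_images_per_prompt,
                      tokens_per_image=TOKENS_PER_IMAGE):
    """Rough number of tokens left for text after reserving image tokens."""
    budget = sequence_length - max_images_per_prompt * tokens_per_image
    if budget <= 0:
        raise ValueError("sequence_length is too small for that many images")
    return budget


print(text_token_budget(768, 2))  # → 256
```

Under this assumption, the settings above leave roughly 256 tokens for the question and the model's answer combined.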
Run requests with multiple images
When using more than one image in a prompt, include one <start_of_image> token for each image you provide, as shown in the following example:
# Replace these placeholder URLs with the URLs of your own images.
dog_a = read_image("http://localhost/images/dog-a.jpg")
dog_b = read_image("http://localhost/images/dog-b.jpg")
question = """I have two images:
Dog A: <start_of_image>
Dog B: <start_of_image>
Which breeds are they? Tell me a bit about them. \
Keep it short.\
"""
gemma_lm.generate(
{
"images": [dog_a, dog_b],
"prompts": PROMPT_TEMPLATE.format(question=question),
},
)
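Because each <start_of_image> token pairs with one entry in the "images" list, a mismatch between the two is an easy mistake to make when editing prompts. The following hypothetical check_image_placeholders() helper counts the placeholders and raises an error before you call generate() if the counts disagree.

```python
def check_image_placeholders(prompt, images, token="<start_of_image>"):
    """Raise ValueError if the prompt's image placeholders don't match the images."""
    expected = len(images)
    found = prompt.count(token)
    if found != expected:
        raise ValueError(
            f"Prompt has {found} {token} token(s) but {expected} image(s) were given."
        )
```

For the example above, you would call check_image_placeholders(PROMPT_TEMPLATE.format(question=question), [dog_a, dog_b]) before calling generate().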
What's next
In this tutorial, you learned how to generate text using Keras and Gemma. Here are a few suggestions for what to learn next:
- Learn how to fine-tune a Gemma model.
- Learn how to perform distributed fine-tuning and inference on a Gemma model.
- Learn about Gemma integration with Vertex AI.
- Learn how to use Gemma models with Vertex AI.