From b10178dec5e3800eaa83d6623dffc181458ec604 Mon Sep 17 00:00:00 2001 From: SurbhiJainUSC Date: Thu, 11 Jun 2026 23:49:40 +0000 Subject: [PATCH] Update Gemma3 multimodal SFT Jupyter notebook --- .github/workflows/run_jupyter_notebooks.yml | 5 + docs/tutorials/posttraining/multimodal.md | 2 +- .../examples/multimodal_gemma3_demo.ipynb | 211 ---------- src/maxtext/examples/rl_llama3_demo.ipynb | 6 + .../examples/sft_llama3_demo_tpu.ipynb | 6 + .../examples/sft_multimodal_gemma3_demo.ipynb | 384 ++++++++++++++++++ .../post_train/sft/train_sft_native.py | 1 + 7 files changed, 403 insertions(+), 212 deletions(-) delete mode 100644 src/maxtext/examples/multimodal_gemma3_demo.ipynb create mode 100644 src/maxtext/examples/sft_multimodal_gemma3_demo.ipynb diff --git a/.github/workflows/run_jupyter_notebooks.yml b/.github/workflows/run_jupyter_notebooks.yml index cd243f3039..6b25a2042c 100644 --- a/.github/workflows/run_jupyter_notebooks.yml +++ b/.github/workflows/run_jupyter_notebooks.yml @@ -123,6 +123,11 @@ jobs: echo "------------------------------------------------------" $PAPERMILL_EXE "$notebook" "$output_name" -k maxtext_venv + + # Clean up any checkpoint directories created by the notebook to avoid filling up disk space + echo "Post-notebook disk cleanup for $filename ..." + rm -rf "$MAXTEXT_PKG_DIR"/sft_*_output "$MAXTEXT_PKG_DIR"/rl_*_output + rm -rf "$HOME/.cache/huggingface/hub" done - name: Upload Outputs if: always() diff --git a/docs/tutorials/posttraining/multimodal.md b/docs/tutorials/posttraining/multimodal.md index c00ba5a993..1b8363bd14 100644 --- a/docs/tutorials/posttraining/multimodal.md +++ b/docs/tutorials/posttraining/multimodal.md @@ -6,7 +6,7 @@ This document provides a guide to use the multimodal functionalities in MaxText - **Multimodal Decode**: Inference with text+images as input. - **Supervised Fine-Tuning (SFT)**: Apply SFT to the model using a visual-question-answering dataset. -We also provide a [colab](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/multimodal_gemma3_demo.ipynb) for multimodal features demonstration. The following table provides a list of models and modalities we currently support: +We also provide a [colab](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_multimodal_gemma3_demo.ipynb) for multimodal features demonstration. The following table provides a list of models and modalities we currently support: | Models | Input Modalities | Output Modalities | | :--------------------------------------------- | :--------------- | :---------------- | diff --git a/src/maxtext/examples/multimodal_gemma3_demo.ipynb b/src/maxtext/examples/multimodal_gemma3_demo.ipynb deleted file mode 100644 index cc07975d1d..0000000000 --- a/src/maxtext/examples/multimodal_gemma3_demo.ipynb +++ /dev/null @@ -1,211 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/multimodal_gemma3_demo.ipynb)\n", - "\n", - "# Gemma3 Multimodal Inference/Training Demo" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Overview\n", - "\n", - "This notebook demonstrates MaxText's multimodal features, using Gemma3-4B as an example:\n", - "- Convert an orbax checkpoint from HuggingFace.\n", - "- Apply decoding on a single image input.\n", - "- Apply SFT to the converted checkpoint on ChartQA dataset.\n", - "\n", - "Given the relative small size of Gemma3-4B, you can run this colab on a v4-8, v5p-8 or v6e-4 TPU VM. You can also use [XPK](https://github.com/AI-Hypercomputer/maxtext/blob/64d6d9b425e78dde94c37a82bb13ba5606e74b1b/docs/guides/run_maxtext_via_xpk.md) to run training workloads on a TPU cluster." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Get Your Hugging Face Token\n", - "\n", - "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n", - "\n", - "**Follow these steps to get your token:**\n", - "\n", - "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n", - " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", - "\n", - "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", - "\n", - "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", - "\n", - "4. **Copy the generated token**. You will need to paste it in `HF_TOKEN`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "5KPyOE8e9WbO" - }, - "outputs": [], - "source": [ - "## Installation: MaxText and Post training Dependencies\n", - "\n", - "Create an virtual environment and install dependencies outside of the notebook using the commands in [MaxText installation and dependency setup guide](../../../docs/guides/run_python_notebook.md#step-4-install-maxtext-and-dependencies) before proceeding. And then run the notebook using that virtual environment.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import maxtext\n", - "\n", - "# Get the root directory of the MaxText\n", - "MAXTEXT_PKG_DIR = os.path.dirname(maxtext.__file__)\n", - "MAXTEXT_REPO_ROOT = os.path.dirname(os.path.dirname(MAXTEXT_PKG_DIR))\n", - "MAXTEXT_ASSETS_ROOT = os.path.join(MAXTEXT_REPO_ROOT, \"src\", \"maxtext\", \"assets\")\n", - "\n", - "\n", - "# Define model name\n", - "MODEL_NAME = \"gemma3-4b\"\n", - "\n", - "# Use either a GCS path or a local path for the model checkpoint\n", - "MODEL_CHECKPOINT_PATH = f\"gs://your-gcs-bucket/{MODEL_NAME}\"\n", - "\n", - "# Replace with your actual Hugging Face token\n", - "HF_TOKEN = \"your_huggingface_token_here\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Convert Checkpoint from HuggingFace" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!python3 -m maxtext.checkpoint_conversion.to_maxtext \\\n", - " $MAXTEXT_CONFIGS_DIR/base.yml \\\n", - " model_name=$MODEL_NAME \\\n", - " hf_access_token=$HF_TOKEN \\\n", - " base_output_directory=$MODEL_CHECKPOINT_PATH \\\n", - " use_multimodal=true \\\n", - " scan_layers=false" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Decode on One Image" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!python -m maxtext.inference.decode \\\n", - " $MAXTEXT_CONFIGS_DIR/base.yml \\\n", - " model_name=$MODEL_NAME \\\n", - " tokenizer_path=$MAXTEXT_ASSETS_ROOT/tokenizers/tokenizer.gemma3 \\\n", - " load_parameters_path=$MODEL_CHECKPOINT_PATH/0/items \\\n", - " per_device_batch_size=1 \\\n", - " run_name=ht_test max_prefill_predict_length=272 \\\n", - " max_target_length=300 \\\n", - " steps=1 \\\n", - " async_checkpointing=false \\\n", - " scan_layers=false \\\n", - " use_multimodal=true \\\n", - " prompt='Describe image ' \\\n", - " image_path=$MAXTEXT_PKG_DIR/tests/assets/test_image.jpg \\\n", - " attention='dot_product'" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Supervised Finetuning (SFT)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Running the cell below will trigger a 10-step SFT on your TPU VM (v4-8, v5p-8, or v6e-4). However, we recommend using [XPK](https://github.com/AI-Hypercomputer/maxtext/blob/64d6d9b425e78dde94c37a82bb13ba5606e74b1b/docs/guides/run_maxtext_via_xpk.md) to schedule a training workload on a TPU cluster for better performance. After the SFT, the result checkpoint will be saved to `BASE_OUTPUT_DIRECTORY`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Define SFT output directory\n", - "BASE_OUTPUT_DIRECTORY=f\"gs://your-gcs-bucket/{MODEL_NAME}-sft\"\n", - "PRE_TRAINED_MODEL_TOKENIZER=\"google/gemma-3-4b-it\"\n", - "WORKLOAD_NAME=f\"{MODEL_NAME}-chartqa-sft\"\n", - "STEPS=10\n", - "PER_DEVICE_BATCH_SIZE=1\n", - "\n", - "!python -m maxtext.trainers.post_train.sft.train_sft_native \\\n", - " $MAXTEXT_CONFIGS_DIR/sft-vision-chartqa.yml \\\n", - " run_name=$WORKLOAD_NAME \\\n", - " model_name=$MODEL_NAME \\\n", - " tokenizer_path=$PRE_TRAINED_MODEL_TOKENIZER \\\n", - " hf_access_token=$HF_TOKEN \\\n", - " load_parameters_path=$MODEL_CHECKPOINT_PATH/0/items \\\n", - " base_output_directory=$BASE_OUTPUT_DIRECTORY \\\n", - " per_device_batch_size=$PER_DEVICE_BATCH_SIZE \\\n", - " steps=$STEPS \\\n", - " max_prefill_predict_length=1024 \\\n", - " max_target_length=2048 \\\n", - " checkpoint_period=1000 \\\n", - " scan_layers=False \\\n", - " async_checkpointing=True \\\n", - " enable_checkpointing=True \\\n", - " attention=dot_product \\\n", - " max_num_images_per_example=1 \\\n", - " dataset_type=hf profiler=xplane" - ] - } - ], - "metadata": { - "accelerator": "TPU", - "colab": { - "gpuType": "V5E1", - "provenance": [] - }, - "kernelspec": { - "display_name": "maxtext_venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/src/maxtext/examples/rl_llama3_demo.ipynb b/src/maxtext/examples/rl_llama3_demo.ipynb index ebc3991c9c..777534efe5 100644 --- a/src/maxtext/examples/rl_llama3_demo.ipynb +++ b/src/maxtext/examples/rl_llama3_demo.ipynb @@ -222,6 +222,12 @@ " check=True,\n", " env=env\n", " )\n", + "\n", + " # The HF model cache is no longer needed after conversion to MaxText format.\n", + " import shutil\n", + " hf_cache = epath.Path(os.path.expanduser(\"~\")) / \".cache\" / \"huggingface\" / \"hub\"\n", + " if hf_cache.exists():\n", + " shutil.rmtree(str(hf_cache))\n", " \n", " MODEL_CHECKPOINT_PATH = os.path.join(MODEL_CHECKPOINT_PATH, \"0/items\")\n", "else:\n", diff --git a/src/maxtext/examples/sft_llama3_demo_tpu.ipynb b/src/maxtext/examples/sft_llama3_demo_tpu.ipynb index 431db029b1..7f8ede0f4f 100644 --- a/src/maxtext/examples/sft_llama3_demo_tpu.ipynb +++ b/src/maxtext/examples/sft_llama3_demo_tpu.ipynb @@ -236,6 +236,12 @@ " env=env\n", " )\n", "\n", + " # The HF model cache is no longer needed after conversion to MaxText format.\n", + " import shutil\n", + " hf_cache = epath.Path(os.path.expanduser(\"~\")) / \".cache\" / \"huggingface\" / \"hub\"\n", + " if hf_cache.exists():\n", + " shutil.rmtree(str(hf_cache))\n", + "\n", " MODEL_CHECKPOINT_PATH = os.path.join(MODEL_CHECKPOINT_PATH, \"0/items\")\n", "else:\n", " print(f\"Model checkpoint exists at {MODEL_CHECKPOINT_PATH}\")" diff --git a/src/maxtext/examples/sft_multimodal_gemma3_demo.ipynb b/src/maxtext/examples/sft_multimodal_gemma3_demo.ipynb new file mode 100644 index 0000000000..2e4ba2802c --- /dev/null +++ b/src/maxtext/examples/sft_multimodal_gemma3_demo.ipynb @@ -0,0 +1,384 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_multimodal_gemma3_demo.ipynb)\n", + "\n", + "# Gemma3 Multimodal Supervised Fine Tuning (SFT) Demo" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview\n", + "\n", + "This notebook demonstrates multimodal Supervised Fine-Tuning (SFT) in MaxText using Gemma3-4B as an example. By the end, you will have a fine-tuned model checkpoint in HuggingFace format.\n", + "\n", + "**Workflow:**\n", + "1. **Download & convert** the Gemma3-4B checkpoint from HuggingFace to MaxText format.\n", + "2. **Fine-tune** the model with SFT on the [ChartQA](https://huggingface.co/datasets/ahmed-masry/ChartQA) dataset β€” a visual question-answering benchmark that requires understanding charts and figures.\n", + "3. **Export** the trained MaxText checkpoint back to HuggingFace format.\n", + "\n", + "**Hardware:** Given the relatively small size of Gemma3-4B, this notebook runs on a **v4-8**, **v5p-8**, or **v6e-4** TPU VM." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "\n", + "Before running this notebook, make sure your environment is set up for the method you are using. Follow the [Run MaxText Python Notebooks on TPUs](https://maxtext.readthedocs.io/en/latest/guides/run_python_notebook.html) guide and complete all steps for your chosen method (Google Colab, VS Code, or Local Jupyter Lab) before proceeding.\n", + "\n", + "If you run into issues, refer to the [Common Pitfalls & Debugging](https://maxtext.readthedocs.io/en/latest/guides/run_python_notebook.html#common-pitfalls-debugging) section of the guide." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " import google.colab\n", + " print(\"Running the notebook on Google Colab\")\n", + " IN_COLAB = True\n", + "except ImportError:\n", + " print(\"Running the notebook on Visual Studio or JupyterLab\")\n", + " IN_COLAB = False" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5KPyOE8e9WbO" + }, + "source": [ + "## Installation: MaxText and Post training Dependencies\n", + "\n", + "**Running the notebook on Visual Studio or JupyterLab**: Before proceeding, create a virtual environment and install the required post-training dependencies by following Option 3: Installing [tpu-post-train] in the [MaxText installation guide](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-source). Once the environment is set up, ensure the notebook is running within it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if IN_COLAB:\n", + " # Clone the MaxText repository\n", + " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n", + " %cd maxtext\n", + "\n", + " # Install uv, a fast Python package installer\n", + " !pip install uv\n", + " \n", + " # Install MaxText and post-training dependencies\n", + " import os\n", + " os.environ[\"UV_TORCH_BACKEND\"]=\"cpu\"\n", + " !uv pip install -e .[tpu-post-train] --resolution=lowest\n", + " !install_tpu_post_train_extra_deps" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Session restart Instructions for Colab:\n", + "\n", + "1. Navigate to the menu at the top of the screen.\n", + "2. Click on **Runtime**.\n", + "3. Select **Restart session** from the dropdown menu.\n", + "\n", + "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Environment Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import datetime\n", + "import jax\n", + "import os\n", + "import subprocess\n", + "import sys\n", + "from maxtext.configs import pyconfig\n", + "from maxtext.utils.globals import MAXTEXT_PKG_DIR\n", + "from maxtext.trainers.post_train.sft import train_sft_native\n", + "from etils import epath\n", + "\n", + "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if not jax.distributed.is_initialized():\n", + " jax.distributed.initialize()\n", + "print(f\"JAX version: {jax.__version__}\")\n", + "print(f\"JAX devices: {jax.devices()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if IN_COLAB:\n", + " from huggingface_hub import notebook_login\n", + " notebook_login()\n", + "else:\n", + " from huggingface_hub import login\n", + " login()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_NAME = \"gemma3-4b\"\n", + "TOKENIZER_NAME = \"google/gemma-3-4b-it\"\n", + "\n", + "BASE_OUTPUT_DIRECTORY = f\"{MAXTEXT_PKG_DIR}/sft_multimodal_gemma3_output\"\n", + "\n", + "# set the path to the model checkpoint (including `/0/items`) or leave empty to download from HuggingFace\n", + "MODEL_CHECKPOINT_PATH = \"\"\n", + "if not MODEL_CHECKPOINT_PATH:\n", + " MODEL_CHECKPOINT_PATH = f\"{BASE_OUTPUT_DIRECTORY}/gemma3_checkpoint\"\n", + " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n", + " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n", + "\n", + "RUN_NAME = datetime.datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download Gemma3-4B Model Checkpoint from Hugging Face" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if not epath.Path(MODEL_CHECKPOINT_PATH).exists():\n", + " print(\"Converting checkpoint from HuggingFace...\")\n", + " env = os.environ.copy()\n", + " env[\"JAX_PLATFORMS\"] = \"cpu\"\n", + "\n", + " subprocess.run(\n", + " [\n", + " sys.executable,\n", + " \"-m\", \"maxtext.checkpoint_conversion.to_maxtext\",\n", + " f\"{MAXTEXT_PKG_DIR}/configs/base.yml\",\n", + " f\"model_name={MODEL_NAME}\",\n", + " f\"base_output_directory={MODEL_CHECKPOINT_PATH}\",\n", + " \"use_multimodal=True\",\n", + " \"scan_layers=True\",\n", + " \"skip_jax_distributed_system=True\",\n", + " \"--eager_load_method=transformers\",\n", + " \"--lazy_load_tensors=False\",\n", + " ],\n", + " check=True,\n", + " env=env\n", + " )\n", + "\n", + " # The HF model cache is no longer needed after conversion to MaxText format.\n", + " import shutil\n", + " hf_cache = epath.Path(os.path.expanduser(\"~\")) / \".cache\" / \"huggingface\" / \"hub\"\n", + " if hf_cache.exists():\n", + " shutil.rmtree(str(hf_cache))\n", + "\n", + " MODEL_CHECKPOINT_PATH = os.path.join(MODEL_CHECKPOINT_PATH, \"0/items\")\n", + "else:\n", + " print(f\"Model checkpoint exists at {MODEL_CHECKPOINT_PATH}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MaxText configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load configuration for SFT training\n", + "config_argv = [\n", + " \"\",\n", + " f\"{MAXTEXT_PKG_DIR}/configs/post_train/sft-vision-chartqa.yml\",\n", + " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n", + " f\"model_name={MODEL_NAME}\",\n", + " f\"tokenizer_path={TOKENIZER_NAME}\",\n", + " \"steps=10\",\n", + " \"attention=dot_product\",\n", + " \"per_device_batch_size=1\",\n", + " \"max_prefill_predict_length=1024\",\n", + " \"max_target_length=2048\",\n", + " \"max_num_images_per_example=1\",\n", + " \"scan_layers=True\",\n", + " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n", + " f\"run_name={RUN_NAME}\",\n", + " \"save_checkpoint_on_completion=True\",\n", + "]\n", + "\n", + "config = pyconfig.initialize_pydantic(config_argv)\n", + "\n", + "print(\"βœ“ SFT configuration loaded:\")\n", + "print(f\" Model: {config.model_name}\")\n", + "print(f\" Training Steps: {config.steps}\")\n", + "print(\"Model Checkpoint Path: \", config.load_parameters_path)\n", + "print(f\" Output Directory: {config.base_output_directory}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SFT Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from absl import flags\n", + "if not flags.FLAGS.is_parsed():\n", + " flags.FLAGS.mark_as_parsed()\n", + "\n", + "import traceback\n", + "\n", + "print(\"=\" * 60)\n", + "print(\"πŸš€ Starting SFT Training...\")\n", + "print(\"=\" * 60)\n", + "\n", + "try:\n", + " _ = train_sft_native.train_loop(config, recorder=None)\n", + " print(\"\\n\" + \"=\" * 60)\n", + " print(\"βœ… Training Completed Successfully!\")\n", + " print(\"=\" * 60)\n", + " print(f\"πŸ“ Checkpoints saved to: {config.checkpoint_dir}\")\n", + "except Exception:\n", + " print(\"\\n\" + \"=\" * 60)\n", + " print(\"❌Training Failed!\")\n", + " print(\"=\" * 60)\n", + " traceback.print_exc()\n", + " print(\"\\nFor troubleshooting, refer to the Common Pitfalls & Debugging section:\")\n", + " print(\"https://maxtext.readthedocs.io/en/latest/guides/run_python_notebook.html#common-pitfalls-debugging\")\n", + " sys.exit(1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Convert MaxText Checkpoint to Hugging Face Format" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the output directory for the Hugging Face checkpoint\n", + "hf_output_directory = epath.Path(BASE_OUTPUT_DIRECTORY) / \"hf_checkpoint\"\n", + "\n", + "# Find the latest MaxText checkpoint\n", + "checkpoint_dir = epath.Path(config.checkpoint_dir)\n", + "step_dirs = [d.name for d in checkpoint_dir.iterdir() if d.name.isdigit() and d.is_dir()]\n", + "if not step_dirs:\n", + " print(f\"No checkpoint found in {checkpoint_dir}\")\n", + "else:\n", + " latest_step = max(step_dirs, key=int)\n", + " maxtext_checkpoint_path = checkpoint_dir / latest_step / \"items\"\n", + "\n", + " print(f\"Converting MaxText checkpoint from: {maxtext_checkpoint_path}\")\n", + " print(f\"Saving Hugging Face checkpoint to: {hf_output_directory}\")\n", + "\n", + " # Run the conversion script\n", + " env = os.environ.copy()\n", + " env[\"JAX_PLATFORMS\"] = \"cpu\"\n", + "\n", + " subprocess.run(\n", + " [\n", + " sys.executable,\n", + " \"-m\", \"maxtext.checkpoint_conversion.to_huggingface\",\n", + " f\"{MAXTEXT_PKG_DIR}/configs/base.yml\",\n", + " f\"model_name={MODEL_NAME}\",\n", + " f\"load_parameters_path={str(maxtext_checkpoint_path)}\",\n", + " f\"base_output_directory={str(hf_output_directory)}\",\n", + " f\"scan_layers={config.scan_layers}\",\n", + " \"use_multimodal=true\",\n", + " \"skip_jax_distributed_system=True\",\n", + " \"weight_dtype=bfloat16\",\n", + " ],\n", + " check=True,\n", + " env=env\n", + " )\n", + "\n", + " print(\"βœ“ Conversion completed successfully!\")" + ] + } + ], + "metadata": { + "accelerator": "TPU", + "colab": { + "gpuType": "V5E1", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python (3.12.11)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/src/maxtext/trainers/post_train/sft/train_sft_native.py b/src/maxtext/trainers/post_train/sft/train_sft_native.py index 54596618ee..1094429d1d 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft_native.py +++ b/src/maxtext/trainers/post_train/sft/train_sft_native.py @@ -86,6 +86,7 @@ def train_loop(config, recorder, state=None): max_utils.print_compiled_memory_stats(compiled_stats) start_step = get_first_step(model, state) # this is the start_step for training + train_utils.validate_completed_steps(start_step, config.steps) prof = profiler.Profiler(config, offset_step=start_step) data_loader = DataLoader(config, mesh, data_iterator, recorder) metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)