Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ load_full_state_path: ""
# If enable_checkpointing is true, an asynchronous checkpointer will be used if
# async_checkpointing is true, else a synchronous one is used. If you have
# problems with the checkpointer we recommend trying the synchronous one.
enable_checkpointing: true
enable_checkpointing: false
save_checkpoint_on_completion: true
async_checkpointing: true
checkpoint_period: 10_000
Expand Down Expand Up @@ -838,9 +838,7 @@ tpu_num_sparse_cores_to_trace: 2
# - upload xplane profiling, if it is enabled.
# - upload training metrics, at the defined log_period interval.
managed_mldiagnostics: false # Whether to enable the managed diagnostics
managed_mldiagnostics_on_demand_profiling: true # Enable on-demand profiling server by default
managed_mldiagnostics_run_group: "" # Optional. Used to group multiple runs.
managed_mldiagnostics_region: "" # Optional. GCP region for managed mldiagnostics. If empty, it will be auto-detected by the SDK.

# Dump HLO and jaxpr options
dump_hlo: false
Expand Down Expand Up @@ -1120,6 +1118,7 @@ remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision en
image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config
image_path: "" # Local image path used for decoding, can be multiple paths separated by comma, e.g. "/path/image1.jpg,/path/image2.jpg"
video_path: "" # Local video path used for decoding, can be multiple paths separated by comma, e.g. "/path/video1.mp4,/path/video2.mp4"
video_directory: "" # Local video directory used for SFT training, e.g. "/mounted/LLaVA-Video-178K"
audio_path: "" # Local audio path used for decoding, can be multiple paths separated by comma, e.g. "/path/audio1.wav,/path/audio2.wav"
image_placeholder: "<|image|>"
video_placeholder: "<|video|>"
Expand Down
38 changes: 38 additions & 0 deletions src/maxtext/configs/post_train/sft-vision-llava-video-178k.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

base_config: "base.yml"

use_sft: true
use_tunix_gradient_accumulation: true
use_multimodal: true
sft_train_on_completion_only: true
packing: false # packing is not supported yet
freeze_vision_encoder_params: true
learning_rate: 2.e-5

# -------------- Model --------------
model_name: "qwen3-omni-30b-a3b"
tokenizer_path: "Qwen/Qwen3-Omni-30B-A3B-Instruct"

# -------------- HF pipeline --------------
dataset_type: "hf"
hf_path: "parquet"
hf_train_files: "gs://hengtaoguo-maxtext-logs/datasets/LLaVA-Video-178K/0_30_s_academic_v0_1/*.parquet"
train_split: "train"
train_data_columns: ["query", "label"]
train_image_column: "video"

# Local SSD path for videos on the TPU VM
video_directory: "/mounted/LLaVA-Video-178K/0_30_s_academic_v0_1"
1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1867,6 +1867,7 @@ class MultimodalGeneral(BaseModel):
description="Maximum number of images per example for training with image lists. -1 means no limit.",
)
video_path: PathStr = Field("", description="Path to a video for decoding.")
video_directory: PathStr = Field("", description="Local directory path containing video files for SFT.")
audio_path: PathStr = Field("", description="Path to an audio file for decoding.")
video_placeholder: str = Field("<|video|>", description="Placeholder string for video in text prompts.")
audio_placeholder: str = Field("<|audio|>", description="Placeholder string for audio in text prompts.")
Expand Down
17 changes: 14 additions & 3 deletions src/maxtext/input_pipeline/input_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,10 @@ def reformat_prompt(example, column, image_placeholder, model_name):

def reformat_response(example, column, model_name):
  """Reformat the response column of `example` for multimodal SFT.

  The column may hold either the response itself or a non-empty list/tuple
  whose first element is the response; in the latter case only the first
  element is used.

  Args:
    example: Mapping for one dataset row; modified in place.
    column: Key of the response column inside `example`.
    model_name: Model identifier forwarded to `mm_processor.reformat_response`.

  Returns:
    The same `example` object with `example[column]` replaced by the
    reformatted response.
  """
  response = example[column]
  # Unwrap only when there is something to unwrap: indexing a bare string with
  # [0] would silently truncate it to its first character, and an empty
  # sequence has no first element.
  if isinstance(response, (list, tuple)) and response:
    response = response[0]
  example[column] = mm_processor.reformat_response(response, model_name)
  return example


Expand All @@ -120,9 +123,17 @@ def merge_image_columns(example, image_columns, max_num_images_per_example):


def pre_process_image_sft(example, image_column, config):
"""pre-process image for multimodal SFT"""
"""pre-process image or video for multimodal SFT"""

def _process_image_fn(image):
if isinstance(image, str):
import os

video_directory = getattr(config, "video_directory", "")
if video_directory:
image = os.path.join(video_directory, image)
return mm_processor.preprocess_image_for_training(image, config)

if isinstance(image, list):
image = [np.array(mm_utils.convert_to_RGB(img)) for img in image]
else:
Expand All @@ -131,7 +142,7 @@ def _process_image_fn(image):
image = mm_processor.preprocess_image_for_training(image, config)
return image

example[image_column] = _process_image_fn(example[image_column])
example[image_column] = _process_image_fn(example[image_column]) if example.get(image_column) is not None else None
return example


Expand Down
8 changes: 6 additions & 2 deletions src/maxtext/multimodal/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,13 @@ def preprocess_image_for_training(image, config):

return preprocess_mm_data_llama4(image)
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training # pylint: disable=import-outside-toplevel
from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training, preprocess_mm_data_qwen3_omni_for_training_video # pylint: disable=import-outside-toplevel

return preprocess_mm_data_qwen3_omni_for_training(image, config)
if isinstance(image, str):
use_audio_in_video = getattr(config, "use_audio_in_video", False)
return preprocess_mm_data_qwen3_omni_for_training_video(image, use_audio_in_video=use_audio_in_video)
else:
return preprocess_mm_data_qwen3_omni_for_training(image, config)
else:
raise ValueError(f"Model {config.model_name} not supported for image preprocessing.")

Expand Down
65 changes: 65 additions & 0 deletions src/maxtext/multimodal/processor_qwen3_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,71 @@ def preprocess_mm_data_qwen3_omni_for_training(images, config):
)


def preprocess_mm_data_qwen3_omni_for_training_video(video_path, use_audio_in_video=False):
  """Preprocess one video (and optionally its audio track) for Qwen3-Omni SFT.

  The video is decoded with `_read_video_decord`, patchified with
  `preprocess_video` using the fixed ViT settings in `_DefaultConfig`, and
  reshaped to a single-video batch. Loading is best-effort: if `video_path`
  does not exist, or decoding/preprocessing raises, a warning is logged and the
  bundled fallback asset is used instead so that one bad sample does not abort
  the run.

  Args:
    video_path: Path to the video file.
    use_audio_in_video: When True, also extract the audio track from the video
      that was actually loaded and attach audio features. If extraction fails,
      all-zero placeholder audio with length 0 is attached instead.

  Returns:
    A `Qwen3OmniPreprocessorOutput` with `num_videos`, `video_values`,
    `video_grid_thw` and `video_second_per_grid` set, plus `audio_values`,
    `audio_mask` and `audio_lengths` when `use_audio_in_video` is True.

  Raises:
    Exception: Whatever `_read_video_decord` / `preprocess_video` raise if the
      fallback asset itself cannot be loaded.
  """
  import logging  # pylint: disable=import-outside-toplevel
  import os  # pylint: disable=import-outside-toplevel

  class _DefaultConfig:
    # Fixed ViT patching parameters used in place of the run config.
    # NOTE(review): these are hard-coded rather than read from the training
    # config — confirm they match the model config in use.
    patch_size_for_vit = 16
    spatial_merge_size_for_vit = 2
    temporal_patch_size_for_vit = QWEN3_TEMPORAL_PATCH_SIZE
    num_channels_for_vit = 3

  def _load_and_patchify(path):
    """Decode `path` and return `(processed_video, grid_thw)`."""
    frames, _ = _read_video_decord(path)
    return preprocess_video(frames, _DefaultConfig())

  fallback_path = "tests/assets/test_video.mp4"
  if not os.path.exists(video_path):
    # Substituting data silently would hide dataset/mount problems; make it visible.
    logging.warning("Video %s not found. Using fallback %s", video_path, fallback_path)
    video_path = fallback_path

  try:
    video_processed, video_grid_thw = _load_and_patchify(video_path)
  except Exception as e:  # pylint: disable=broad-exception-caught
    logging.warning("Failed to load or preprocess video %s: %s. Using fallback %s", video_path, e, fallback_path)
    video_path = fallback_path
    # A failure here propagates: there is nothing further to fall back to.
    video_processed, video_grid_thw = _load_and_patchify(video_path)

  video_values = np.reshape(
      video_processed,
      (
          1,
          _DefaultConfig.num_channels_for_vit,
          _DefaultConfig.temporal_patch_size_for_vit * video_grid_thw[0, 0],
          _DefaultConfig.patch_size_for_vit * video_grid_thw[0, 1],
          _DefaultConfig.patch_size_for_vit * video_grid_thw[0, 2],
      ),
  )

  processor_outputs = Qwen3OmniPreprocessorOutput(
      num_videos=1,
      video_values=video_values,
      video_grid_thw=video_grid_thw,
      # NOTE(review): the temporal patch size is used as the seconds-per-grid
      # value, which presumes roughly one sampled frame per second — confirm.
      video_second_per_grid=np.asarray([_DefaultConfig.temporal_patch_size_for_vit], dtype=np.float32),
  )

  if use_audio_in_video:
    try:
      mt_audio = mm_utils.load_audio(video_path, sample_rate=SAMPLE_RATE)
      mt_audio, mt_audio_mask = pre_process_audio_qwen3_omni(mt_audio)
      processor_outputs.audio_values = mt_audio
      processor_outputs.audio_mask = mt_audio_mask
      audio_mask_sum = np.sum(mt_audio_mask, axis=-1)
      audio_lengths = _get_feat_extract_output_lengths(audio_mask_sum)
      processor_outputs.audio_lengths = np.array(audio_lengths, dtype=np.int32)
    except Exception as e:  # pylint: disable=broad-exception-caught
      logging.warning("Audio extraction failed for %s: %s. Using dummy audio.", video_path, e)
      # All-zero features with an all-zero mask and length 0 stand in for the missing audio.
      processor_outputs.audio_values = np.zeros((1, 128, 3000), dtype=np.float32)
      processor_outputs.audio_mask = np.zeros((1, 3000), dtype=np.int32)
      processor_outputs.audio_lengths = np.array([0], dtype=np.int32)

  return processor_outputs


def preprocess_mm_data_qwen3_omni(config):
"""Placeholder for multimodal data preprocessing."""
processor_outputs = Qwen3OmniPreprocessorOutput()
Expand Down
69 changes: 48 additions & 21 deletions src/maxtext/multimodal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,25 +767,52 @@ def window_function(


def load_audio(data_path: str, sample_rate: int = 16000) -> np.ndarray:
  """Load a mono waveform from an audio file or from a video's audio track.

  Inputs whose extension marks them as video are decoded by running the
  `ffmpeg` executable to write a temporary mono 16-bit PCM WAV at
  `sample_rate`, which is then read with `soundfile`. Any other input is read
  directly with `soundfile`, down-mixed to mono, and resampled with `librosa`
  when its native rate differs from `sample_rate`.

  Args:
    data_path: Path to the audio or video file.
    sample_rate: Target sample rate in Hz.

  Returns:
    1-D float32 waveform at `sample_rate` (the same layout `librosa.load`
    yields with its default settings).

  Raises:
    FileNotFoundError: If `data_path` is not an existing file.
    ImportError: If `soundfile` is not installed.
    RuntimeError: If extraction, decoding or resampling fails; the underlying
      exception is chained as `__cause__`.
  """
  if not os.path.isfile(data_path):
    raise FileNotFoundError(f"Audio file not found at path {data_path}.")

  import subprocess  # pylint: disable=import-outside-toplevel
  import tempfile  # pylint: disable=import-outside-toplevel

  import soundfile as sf  # pylint: disable=import-outside-toplevel

  is_video = data_path.lower().endswith((".mp4", ".mkv", ".avi", ".mov", ".flv", ".webm"))

  if is_video:
    try:
      # A temporary directory removes the WAV on exit, including on failure.
      with tempfile.TemporaryDirectory() as tmp_dir:
        temp_wav_path = os.path.join(tmp_dir, "audio.wav")
        cmd = [
            "ffmpeg",
            "-y",
            "-i",
            data_path,
            "-vn",  # drop the video stream
            "-acodec",
            "pcm_s16le",
            "-ar",
            str(sample_rate),
            "-ac",
            "1",  # mono
            temp_wav_path,
        ]
        subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
        audio, sr = sf.read(temp_wav_path, dtype="float32")
      # Explicit check rather than `assert`, which is stripped under `python -O`.
      if sr != sample_rate:
        raise ValueError(f"Sample rate mismatch: expected {sample_rate}, got {sr}")
      return audio
    except Exception as e:
      raise RuntimeError(f"Failed to extract and load audio from video {data_path}: {e}") from e

  try:
    audio, sr = sf.read(data_path, dtype="float32")
    # soundfile returns (frames, channels) for multi-channel files; average the
    # channels so resampling runs along time and callers get a 1-D signal.
    if audio.ndim > 1:
      audio = audio.mean(axis=1)
    if sr != sample_rate:
      if librosa is None:
        raise RuntimeError(f"Audio sample rate {sr} does not match target {sample_rate} and librosa is not installed.")
      audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate)
    return audio
  except Exception as e:
    raise RuntimeError(f"Failed to load audio from {data_path}: {e}") from e
6 changes: 4 additions & 2 deletions src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,12 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr
# decimate proportion of data when per_device_batch_size<1
if is_train:
for k, v in data.items():
data[k] = v[: config.micro_batch_size_to_train_on, :]
if v is not None:
data[k] = v[: config.micro_batch_size_to_train_on, :]
else:
for k, v in data.items():
data[k] = v[: config.micro_batch_size_to_eval_on, :]
if v is not None:
data[k] = v[: config.micro_batch_size_to_eval_on, :]
mutable_collections = ["intermediates"]
if config.mtp_num_layers > 0 and is_train:
# The single model.apply call now triggers the entire chain if MTP is enabled:
Expand Down
Loading