diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index 8cec47b489..18a476fbd6 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -1120,6 +1120,7 @@ remat_policy_for_vit: "minimal" # Remat policy for multimodal model's vision en
 image_size_for_vit: 896 # Default for Gemma3, and should be overwritten by model's config
 image_path: "" # Local image path used for decoding, can be multiple paths separated by comma, exp "/path/image1.jpg,/path/image2.jpg"
 video_path: "" # Local video path used for decoding, can be multiple paths separated by comma, exp "/path/video1.mp4,/path/video2.mp4"
+video_directory: "" # Local video directory used for SFT training, e.g. "/mounted/LLaVA-Video-178K"
 audio_path: "" # Local audio path used for decoding, can be multiple paths separated by comma, exp "/path/audio1.wav,/path/audio2.wav"
 image_placeholder: "<|image|>"
 video_placeholder: "<|video|>"
diff --git a/src/maxtext/configs/post_train/sft-vision-llava-video-178k.yml b/src/maxtext/configs/post_train/sft-vision-llava-video-178k.yml
new file mode 100644
index 0000000000..bc7cb0f368
--- /dev/null
+++ b/src/maxtext/configs/post_train/sft-vision-llava-video-178k.yml
@@ -0,0 +1,38 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+base_config: "base.yml"
+
+use_sft: true
+use_tunix_gradient_accumulation: true
+use_multimodal: true
+sft_train_on_completion_only: true
+packing: false # packing is not supported yet
+freeze_vision_encoder_params: true
+learning_rate: 2.e-5
+
+# -------------- Model --------------
+model_name: "qwen3-omni-30b-a3b"
+tokenizer_path: "Qwen/Qwen3-Omni-30B-A3B-Instruct"
+
+# -------------- HF pipeline --------------
+dataset_type: "hf"
+hf_path: "parquet"
+hf_train_files: "gs://hengtaoguo-maxtext-logs/datasets/LLaVA-Video-178K/0_30_s_academic_v0_1/*.parquet"
+train_split: "train"
+train_data_columns: ["query", "label"]
+train_image_column: "video"
+
+# Local SSD path for videos on the TPU VM
+video_directory: "/mounted/LLaVA-Video-178K/0_30_s_academic_v0_1"
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index cb1987eb77..dec7b73be4 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -1867,6 +1867,7 @@ class MultimodalGeneral(BaseModel):
       description="Maximum number of images per example for training with image lists. -1 means no limit.",
   )
   video_path: PathStr = Field("", description="Path to a video for decoding.")
+  video_directory: PathStr = Field("", description="Local directory path containing video files for SFT.")
   audio_path: PathStr = Field("", description="Path to an audio file for decoding.")
   video_placeholder: str = Field("<|video|>", description="Placeholder string for video in text prompts.")
   audio_placeholder: str = Field("<|audio|>", description="Placeholder string for audio in text prompts.")
diff --git a/src/maxtext/input_pipeline/input_pipeline_utils.py b/src/maxtext/input_pipeline/input_pipeline_utils.py
index 621b79bb47..bd6185760f 100644
--- a/src/maxtext/input_pipeline/input_pipeline_utils.py
+++ b/src/maxtext/input_pipeline/input_pipeline_utils.py
@@ -102,7 +102,10 @@ def reformat_prompt(example, column, image_placeholder, model_name):
 
 def reformat_response(example, column, model_name):
   """reformat response for multimodal SFT"""
-  example[column] = mm_processor.reformat_response(example[column][0], model_name)
+  val = example[column]
+  if isinstance(val, (list, tuple)) and len(val) > 0:
+    val = val[0]
+  example[column] = mm_processor.reformat_response(val, model_name)
   return example
 
 
@@ -120,9 +123,17 @@ def merge_image_columns(example, image_columns, max_num_images_per_example):
 
 
 def pre_process_image_sft(example, image_column, config):
-  """pre-process image for multimodal SFT"""
+  """pre-process image or video for multimodal SFT"""
 
   def _process_image_fn(image):
+    if isinstance(image, str):
+      import os
+
+      video_directory = getattr(config, "video_directory", "")
+      if video_directory:
+        image = os.path.join(video_directory, image)
+      return mm_processor.preprocess_image_for_training(image, config)
+
     if isinstance(image, list):
       image = [np.array(mm_utils.convert_to_RGB(img)) for img in image]
     else:
@@ -131,7 +142,7 @@ def _process_image_fn(image):
     image = mm_processor.preprocess_image_for_training(image, config)
     return image
 
-  example[image_column] = _process_image_fn(example[image_column])
+  example[image_column] = _process_image_fn(example[image_column]) if example.get(image_column) is not None else None
   return example
 
 
diff --git a/src/maxtext/multimodal/processor.py b/src/maxtext/multimodal/processor.py
index 7c99800f2a..d8e3111e0b 100644
--- a/src/maxtext/multimodal/processor.py
+++ b/src/maxtext/multimodal/processor.py
@@ -69,9 +69,13 @@ def preprocess_image_for_training(image, config):
 
     return preprocess_mm_data_llama4(image)
   elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
-    from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training  # pylint: disable=import-outside-toplevel
+    from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training, preprocess_mm_data_qwen3_omni_for_training_video  # pylint: disable=import-outside-toplevel
 
-    return preprocess_mm_data_qwen3_omni_for_training(image, config)
+    if isinstance(image, str):
+      use_audio_in_video = getattr(config, "use_audio_in_video", False)
+      return preprocess_mm_data_qwen3_omni_for_training_video(image, use_audio_in_video=use_audio_in_video)
+    else:
+      return preprocess_mm_data_qwen3_omni_for_training(image, config)
   else:
     raise ValueError(f"Model {config.model_name} not supported for image preprocessing.")
 
diff --git a/src/maxtext/multimodal/processor_qwen3_omni.py b/src/maxtext/multimodal/processor_qwen3_omni.py
index b29b8acc84..92f5307dc2 100644
--- a/src/maxtext/multimodal/processor_qwen3_omni.py
+++ b/src/maxtext/multimodal/processor_qwen3_omni.py
@@ -554,6 +554,84 @@ def preprocess_mm_data_qwen3_omni_for_training(images, config):
   )
 
 
+def preprocess_mm_data_qwen3_omni_for_training_video(video_path, use_audio_in_video=False):
+  """Preprocesses one video (and optionally its audio track) for Qwen3-Omni SFT training.
+
+  Loading is best-effort: if `video_path` is missing or cannot be decoded, a warning is logged and the
+  fallback clip is used instead. An error is raised only when the fallback clip itself cannot be decoded.
+
+  Args:
+    video_path: Path to the video file.
+    use_audio_in_video: Whether to also extract audio features from the same file.
+
+  Returns:
+    A Qwen3OmniPreprocessorOutput with the video fields set, plus the audio fields if requested.
+  """
+  import logging  # pylint: disable=import-outside-toplevel
+  import os  # pylint: disable=import-outside-toplevel
+
+  class _DefaultConfig:
+    patch_size_for_vit = 16
+    spatial_merge_size_for_vit = 2
+    temporal_patch_size_for_vit = QWEN3_TEMPORAL_PATCH_SIZE
+    num_channels_for_vit = 3
+
+  def _load_and_preprocess(path):
+    video_array, _ = _read_video_decord(path)
+    return preprocess_video(video_array, _DefaultConfig())
+
+  fallback_path = "tests/assets/test_video.mp4"
+  if not os.path.exists(video_path):
+    # Swapping in another clip changes the training sample, so never do it silently.
+    logging.warning("Video %s not found. Using fallback %s", video_path, fallback_path)
+    video_path = fallback_path
+
+  try:
+    video_processed, video_grid_thw = _load_and_preprocess(video_path)
+  except Exception as e:  # pylint: disable=broad-exception-caught
+    if video_path == fallback_path:
+      raise  # The fallback itself is unusable; decoding it a second time cannot help.
+    logging.warning("Failed to load or preprocess video %s: %s. Using fallback %s", video_path, e, fallback_path)
+    video_path = fallback_path
+    video_processed, video_grid_thw = _load_and_preprocess(video_path)
+  video_values = np.reshape(
+      video_processed,
+      (
+          1,
+          _DefaultConfig.num_channels_for_vit,
+          _DefaultConfig.temporal_patch_size_for_vit * video_grid_thw[0, 0],
+          _DefaultConfig.patch_size_for_vit * video_grid_thw[0, 1],
+          _DefaultConfig.patch_size_for_vit * video_grid_thw[0, 2],
+      ),
+  )
+
+  processor_outputs = Qwen3OmniPreprocessorOutput(
+      num_videos=1,
+      video_values=video_values,
+      video_grid_thw=video_grid_thw,
+      video_second_per_grid=np.asarray([_DefaultConfig.temporal_patch_size_for_vit], dtype=np.float32),
+  )
+
+  if use_audio_in_video:
+    try:
+      mt_audio = mm_utils.load_audio(video_path, sample_rate=SAMPLE_RATE)
+      mt_audio, mt_audio_mask = pre_process_audio_qwen3_omni(mt_audio)
+      processor_outputs.audio_values = mt_audio
+      processor_outputs.audio_mask = mt_audio_mask
+      audio_mask_sum = np.sum(mt_audio_mask, axis=-1)
+      audio_lengths = _get_feat_extract_output_lengths(audio_mask_sum)
+      processor_outputs.audio_lengths = np.array(audio_lengths, dtype=np.int32)
+    except Exception as e:  # pylint: disable=broad-exception-caught
+      logging.warning("Audio extraction failed for %s: %s. Using dummy audio.", video_path, e)
+      dummy_audio = np.zeros((1, 128, 3000), dtype=np.float32)
+      dummy_mask = np.zeros((1, 3000), dtype=np.int32)
+      processor_outputs.audio_values = dummy_audio
+      processor_outputs.audio_mask = dummy_mask
+      processor_outputs.audio_lengths = np.array([0], dtype=np.int32)
+
+  return processor_outputs
+
+
 def preprocess_mm_data_qwen3_omni(config):
   """Placeholder for multimodal data preprocessing."""
   processor_outputs = Qwen3OmniPreprocessorOutput()
diff --git a/src/maxtext/multimodal/utils.py b/src/maxtext/multimodal/utils.py
index 65b5670fc1..bb47f0f819 100644
--- a/src/maxtext/multimodal/utils.py
+++ b/src/maxtext/multimodal/utils.py
@@ -767,25 +767,75 @@ def window_function(
 
 
 def load_audio(data_path: str, sample_rate: int = 16000) -> np.ndarray:
-  """Load audio from a file path.
-
-  Args:
-    data_path (str): The path to the audio file or video file.
-    sample_rate (int): The target sample rate in Hz. Default is 16000.
-
-  Returns:
-    np.ndarray: The loaded audio waveform.
-
-  Raises:
-    FileNotFoundError: If the audio file does not exist.
-    RuntimeError: If the audio file cannot be loaded.
-  """
+  """Load a mono waveform from an audio file or from the audio track of a video file.
+
+  Video files (.mp4, .mkv, .avi, .mov, .flv, .webm) are decoded with the `ffmpeg` executable; other files are read with `soundfile`.
+
+  Args:
+    data_path (str): The path to the audio file or video file.
+    sample_rate (int): The target sample rate in Hz. Default is 16000.
+
+  Returns:
+    np.ndarray: The loaded mono float32 waveform.
+
+  Raises:
+    FileNotFoundError: If the file does not exist.
+    RuntimeError: If the audio cannot be extracted, loaded or resampled.
+  """
   if not os.path.isfile(data_path):
-    raise FileNotFoundError(f"Audio file not found at path {data_path}. Please specify a valid audio file path")
-  if librosa is None:
-    raise ImportError("librosa is required for audio processing but not installed.")
-  try:
-    audio = librosa.load(data_path, sr=sample_rate)[0]
-    return audio
-  except Exception as e:
-    raise RuntimeError(f"Failed to load audio from {data_path}: {e}") from e
+    raise FileNotFoundError(f"Audio file not found at path {data_path}.")
+
+  import soundfile as sf  # pylint: disable=import-outside-toplevel
+  import subprocess  # pylint: disable=import-outside-toplevel
+  import tempfile  # pylint: disable=import-outside-toplevel
+
+  def _to_mono(audio):
+    # soundfile returns (frames, channels) for multi-channel files; average the channels as librosa.load did.
+    return audio.mean(axis=1) if audio.ndim > 1 else audio
+
+  is_video = data_path.lower().endswith((".mp4", ".mkv", ".avi", ".mov", ".flv", ".webm"))
+
+  if is_video:
+    # ffmpeg resamples and downmixes into a scratch WAV; the directory is removed when the block exits.
+    with tempfile.TemporaryDirectory() as temp_dir:
+      temp_wav_path = os.path.join(temp_dir, "audio.wav")
+      cmd = [
+          "ffmpeg",
+          "-y",
+          "-i",
+          data_path,
+          "-vn",
+          "-acodec",
+          "pcm_s16le",
+          "-ar",
+          str(sample_rate),
+          "-ac",
+          "1",
+          temp_wav_path,
+      ]
+      try:
+        subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
+        audio, sr = sf.read(temp_wav_path, dtype="float32")
+      except Exception as e:
+        raise RuntimeError(f"Failed to extract and load audio from video {data_path}: {e}") from e
+    if sr != sample_rate:
+      # Explicit check instead of `assert`, which is stripped under `python -O`.
+      raise RuntimeError(
+          f"Failed to extract and load audio from video {data_path}: "
+          f"Sample rate mismatch: expected {sample_rate}, got {sr}"
+      )
+    return _to_mono(audio)
+
+  try:
+    audio, sr = sf.read(data_path, dtype="float32")
+  except Exception as e:
+    raise RuntimeError(f"Failed to load audio from {data_path}: {e}") from e
+  audio = _to_mono(audio)
+  if sr != sample_rate:
+    if librosa is None:
+      raise RuntimeError(f"Audio sample rate {sr} does not match target {sample_rate} and librosa is not installed.")
+    try:
+      audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate)
+    except Exception as e:
+      raise RuntimeError(f"Failed to load audio from {data_path}: {e}") from e
+  return audio
diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py
index 3be6baff8c..1288aa1bb4 100644
--- a/src/maxtext/trainers/pre_train/train.py
+++ b/src/maxtext/trainers/pre_train/train.py
@@ -103,10 +103,12 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr
   # decimate proportion of data when per_device_batch_size<1
   if is_train:
     for k, v in data.items():
-      data[k] = v[: config.micro_batch_size_to_train_on, :]
+      if v is not None:
+        data[k] = v[: config.micro_batch_size_to_train_on, :]
   else:
     for k, v in data.items():
-      data[k] = v[: config.micro_batch_size_to_eval_on, :]
+      if v is not None:
+        data[k] = v[: config.micro_batch_size_to_eval_on, :]
   mutable_collections = ["intermediates"]
   if config.mtp_num_layers > 0 and is_train:
     # The single model.apply call now triggers the entire chain if MTP is enabled: