diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..f2854ab --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,49 @@ +--- +name: ๐Ÿ› Bug Report +about: Report a bug to help us improve the project +title: '[BUG] ' +labels: 'bug' +assignees: '' + +--- + +## Bug Description + + + +## Steps to Reproduce + +1. +2. +3. + +**Code snippet (if applicable):** +```python +# Your code here +``` + +## Expected Behavior + + + +## Actual Behavior + + + +## Environment + +- **Component:** [API | Core functionality | Any] +- **OS:** [Windows | MacOS | Linux] +- **Python Version:** + +## Screenshots/Logs + + + +## Additional Context + + + +## Possible Solution + + diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..9a27642 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,8 @@ +blank_issues_enabled: false +contact_links: + - name: ๐Ÿ’ฌ Discussion + url: https://github.com/VectorInstitute/aieng-template-uv/discussions + about: Ask questions or discuss ideas with the community + - name: ๐Ÿ“– Documentation + url: https://github.com/VectorInstitute/aieng-template-uv#readme + about: Check the documentation for setup and usage guides diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..8ff6625 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,52 @@ +--- +name: โœจ Feature Request +about: Suggest a new feature or enhancement +title: '[FEATURE] ' +labels: 'enhancement' +assignees: '' + +--- + +## Problem Statement + + + +## Proposed Solution + + + +## Alternative Solutions + + + +## Use Cases + + +- +- +- + +## Implementation Ideas + + + +## Component Impact + + +- [ ] API +- [ ] Core functionality +- [ ] Docker/Infrastructure +- [ ] Documentation +- [ ] Any other part of the system + +## Additional Context + + + +## 
Priority + + +- [ ] Nice to have +- [ ] Would be helpful +- [ ] Important for my use case +- [ ] Critical/blocking diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..ea8c4ac --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,17 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" # Location of package manifests + schedule: + interval: "daily" + + # Keep uv dependencies (uv.lock) up to date + - package-ecosystem: "uv" + directory: "/" # where pyproject.toml and uv.lock live + schedule: + interval: "daily" diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..8c195a3 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,53 @@ +## Summary + + + +Clickup Ticket(s): Link(s) if applicable. 
+ +## Type of Change + +- [ ] ๐Ÿ› Bug fix (non-breaking change that fixes an issue) +- [ ] โœจ New feature (non-breaking change that adds functionality) +- [ ] ๐Ÿ’ฅ Breaking change (fix or feature that would cause existing functionality to not work as expected) +- [ ] ๐Ÿ“ Documentation update +- [ ] ๐Ÿ”ง Refactoring (no functional changes) +- [ ] โšก Performance improvement +- [ ] ๐Ÿงช Test improvements +- [ ] ๐Ÿ”’ Security fix + +## Changes Made + + +- +- +- + +## Testing + + +- [ ] Tests pass locally (`uv run pytest tests/`) +- [ ] Type checking passes (`uv run mypy `) +- [ ] Linting passes (`uv run ruff check src_dir/`) +- [ ] Manual testing performed (describe below) + +**Manual testing details:** + + +## Screenshots/Recordings + + + +## Related Issues + + + +## Deployment Notes + + + +## Checklist + +- [ ] Code follows the project's style guidelines +- [ ] Self-review of code completed +- [ ] Documentation updated (if applicable) +- [ ] No sensitive information (API keys, credentials) exposed diff --git a/.github/workflows/code_checks.yml b/.github/workflows/code_checks.yml new file mode 100644 index 0000000..b911ce9 --- /dev/null +++ b/.github/workflows/code_checks.yml @@ -0,0 +1,59 @@ +name: code checks +permissions: + contents: read + pull-requests: write + +on: + push: + branches: + - main + paths: + - .pre-commit-config.yaml + - .github/workflows/code_checks.yml + - '**.py' + - uv.lock + - pyproject.toml + - '**.ipynb' + pull_request: + branches: + - main + paths: + - .pre-commit-config.yaml + - .github/workflows/code_checks.yml + - '**.py' + - uv.lock + - pyproject.toml + - '**.ipynb' + +jobs: + run-code-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6.0.1 + + - name: Install uv + uses: astral-sh/setup-uv@v7.2.0 + with: + # Install a specific version of uv. 
+ version: "0.9.11" + enable-cache: true + + - name: "Set up Python" + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 + with: + python-version-file: ".python-version" + + - name: Install the project + run: uv sync --all-extras --dev + + - name: Install dependencies and check code + run: | + source .venv/bin/activate + pre-commit run --all-files + + - name: pip-audit (gh-action-pip-audit) + uses: pypa/gh-action-pip-audit@v1.1.0 + with: + virtual-environment: .venv/ + ignore-vulns: | + GHSA-4xh5-x5gv-qwph diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000..a36463e --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,100 @@ +name: docs +permissions: + contents: write + pull-requests: write + +on: + push: + branches: + - main + paths: + - .pre-commit-config.yaml + - .github/workflows/docs.yml + - '**.py' + - '**.ipynb' + - '**.html' + - '**.js' + - '**.md' + - uv.lock + - pyproject.toml + - mkdocs.yml + - '**.png' + - '**.svg' + pull_request: + branches: + - main + paths: + - .pre-commit-config.yaml + - .github/workflows/docs.yml + - '**.py' + - '**.ipynb' + - '**.js' + - '**.html' + - uv.lock + - pyproject.toml + - '**.md' + - mkdocs.yml + - '**.png' + - '**.svg' + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v6.0.1 + + - name: Install uv + uses: astral-sh/setup-uv@v7.2.0 + with: + version: "0.9.11" + enable-cache: true + + - name: Set up Python + uses: actions/setup-python@v6.1.0 + with: + python-version-file: ".python-version" + + - name: Install the project + run: uv sync --all-extras --group docs + + - name: Build docs + run: uv run mkdocs build + + - name: Create .nojekyll file + run: touch site/.nojekyll + + - name: Upload artifact + uses: actions/upload-artifact@v6 + with: + name: docs-site + path: site/ + retention-days: 1 + + deploy: + needs: build + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + runs-on: 
ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v6.0.1 + + - name: Configure Git Credentials + run: | + git config user.name github-actions[bot] + git config user.email 41898282+github-actions[bot]@users.noreply.github.com + + - name: Download artifact + uses: actions/download-artifact@v7 + with: + name: docs-site + path: site + + - name: Ensure .nojekyll exists + run: touch site/.nojekyll + + - name: Deploy to Github pages + uses: JamesIves/github-pages-deploy-action@v4.8.0 + with: + branch: gh-pages + folder: site diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml new file mode 100644 index 0000000..31e7825 --- /dev/null +++ b/.github/workflows/integration_tests.yml @@ -0,0 +1,70 @@ +name: integration tests +permissions: + contents: read + pull-requests: write + +on: + push: + branches: + - main + paths: + - .pre-commit-config.yaml + - .github/workflows/code_checks.yml + - .github/workflows/docs.yml + - .github/workflows/unit_tests.yml + - .github/workflows/integration_tests.yml + - '**.py' + - '**.ipynb' + - uv.lock + - pyproject.toml + - '**.rst' + - '**.md' + pull_request: + branches: + - main + paths: + - .pre-commit-config.yaml + - .github/workflows/code_checks.yml + - .github/workflows/docs.yml + - .github/workflows/unit_tests.yml + - .github/workflows/integration_tests.yml + - '**.py' + - '**.ipynb' + - uv.lock + - pyproject.toml + - '**.rst' + - '**.md' + +jobs: + integration-tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6.0.1 + + - name: Install uv + uses: astral-sh/setup-uv@v7.2.0 + with: + # Install a specific version of uv. 
+ version: "0.9.11" + enable-cache: true + + - name: "Set up Python" + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 + with: + python-version-file: ".python-version" + + - name: Install the project + run: uv sync --all-extras --dev + + - name: Install dependencies and check code + run: | + uv run pytest -m "integration_test" --cov src/aieng_template_uv --cov-report=xml tests + + # Uncomment this once this repo is configured on Codecov + - name: Upload coverage to Codecov + uses: codecov/codecov-action@671740ac38dd9b0130fbe1cec585b89eea48d3de + with: + token: ${{ secrets.CODECOV_TOKEN }} + slug: VectorInstitute/aieng-template-uv + fail_ci_if_error: false + verbose: true diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..bdfa60a --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,50 @@ +name: publish package +permissions: + contents: read + pull-requests: write + +on: + push: + tags: + - "v*" + +jobs: + deploy: + runs-on: ubuntu-latest + steps: + - name: Install apt dependencies + run: | + sudo apt-get update + sudo apt-get install libcurl4-openssl-dev libssl-dev + - uses: actions/checkout@v6.0.1 + + - name: Install uv + uses: astral-sh/setup-uv@v7.2.0 + with: + # Install a specific version of uv. 
+ version: "0.9.11" + enable-cache: true + + - name: "Set up Python" + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 + with: + python-version-file: ".python-version" + + - name: Install the project + run: uv sync --all-extras --dev + + - name: Build package + run: uv build + + - name: Publish package + uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} + + - name: Create GitHub Release + id: create_release + uses: ncipollo/release-action@b7eabc95ff50cbeeedec83973935c8f306dfcd0b # v1.20.0 + with: + artifacts: "dist/*" + generateReleaseNotes: true diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml new file mode 100644 index 0000000..624f694 --- /dev/null +++ b/.github/workflows/unit_tests.yml @@ -0,0 +1,70 @@ +name: unit tests +permissions: + contents: read + pull-requests: write + +on: + push: + branches: + - main + paths: + - .pre-commit-config.yaml + - .github/workflows/code_checks.yml + - .github/workflows/docs.yml + - .github/workflows/unit_tests.yml + - .github/workflows/integration_tests.yml + - '**.py' + - '**.ipynb' + - uv.lock + - pyproject.toml + - '**.rst' + - '**.md' + pull_request: + branches: + - main + paths: + - .pre-commit-config.yaml + - .github/workflows/code_checks.yml + - .github/workflows/docs.yml + - .github/workflows/unit_tests.yml + - .github/workflows/integration_tests.yml + - '**.py' + - '**.ipynb' + - uv.lock + - pyproject.toml + - '**.rst' + - '**.md' + +jobs: + unit-tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6.0.1 + + - name: Install uv + uses: astral-sh/setup-uv@v7.2.0 + with: + # Install a specific version of uv. 
+ version: "0.9.11" + enable-cache: true + + - name: "Set up Python" + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 + with: + python-version-file: ".python-version" + + - name: Install the project + run: uv sync --all-extras --dev + + - name: Install dependencies and check code + run: | + uv run pytest -m "not integration_test" --cov src/aieng_template_uv --cov-report=xml tests + + # Uncomment this once this repo is configured on Codecov + - name: Upload coverage to Codecov + uses: codecov/codecov-action@671740ac38dd9b0130fbe1cec585b89eea48d3de + with: + token: ${{ secrets.CODECOV_TOKEN }} + slug: VectorInstitute/aieng-template-uv + fail_ci_if_error: false + verbose: true diff --git a/.gitignore b/.gitignore index 4a61067..24d03cb 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ magnet/.env *.lock **.sh **.png +!**/docs/** # Ignore everything under external_repos/ **/external_repos/** @@ -21,7 +22,7 @@ magnet/.env # Large files and datasets **/vqa_backup_before_hf_fix/ **/vqa_backup_reviewed/ -**/vqa/** +**/vqa/** **/plots/**/ **/optionalFiles/VideoAudioDemographicsAnalysis **/optionalFiles/DemographicsAnalysis @@ -82,7 +83,7 @@ test_*.py **/plots/ **/video_analysis/ **/videos_Unfiltered/ -**/optionalFiles/** +**/optionalFiles/** **/repo.py sonic-o1/optionalFiles/ uv.lock @@ -92,3 +93,7 @@ sonic-o1/05_evaluation_inference/metrics/demographics_metrics.py sonic-o1/05_evaluation_inference/models/backup_*.txt sonic-o1/05_evaluation_inference/visualitzation/ **/results/ +.cache/ +**.pt +**/logs/** +docs/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..8fcfb9d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,85 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v6.0.0 # Use the ref you want to point at + hooks: + - id: trailing-whitespace + - id: check-ast + - id: check-builtin-literals + - id: check-docstring-first + - id: check-executables-have-shebangs + - 
id: debug-statements + - id: end-of-file-fixer + - id: mixed-line-ending + args: [--fix=lf] + - id: fix-byte-order-marker + - id: check-merge-conflict + - id: check-symlinks + - id: detect-private-key + - id: check-yaml + args: [--unsafe] + - id: check-toml + + - repo: https://github.com/astral-sh/uv-pre-commit + rev: 0.9.24 + hooks: + - id: uv-lock + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: 'v0.14.11' + hooks: + - id: ruff-check + args: [--fix, --exit-non-zero-on-fix] + types_or: [python, jupyter] + - id: ruff-format + types_or: [python, jupyter] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.19.1 + hooks: + - id: mypy + entry: python3 -m mypy --config-file pyproject.toml + language: system + types: [python] + exclude: "tests" + + - repo: https://github.com/crate-ci/typos + rev: v1 + hooks: + - id: typos + args: [] + + - repo: https://github.com/nbQA-dev/nbQA + rev: 1.9.1 + hooks: + - id: nbqa-ruff + args: [--fix, --exit-non-zero-on-fix] + + - repo: local + hooks: + - id: doctest + name: doctest + entry: python3 -m doctest -o NORMALIZE_WHITESPACE + files: "^sonic-o1/" + exclude: "sonic-o1/huggingface_review_template/" + language: system + + - repo: local + hooks: + - id: pytest + name: pytest + entry: python3 -m pytest -m "not integration_test" + language: system + pass_filenames: false + always_run: true + +ci: + autofix_commit_msg: | + [pre-commit.ci] Add auto fixes from pre-commit.com hooks + + for more information, see https://pre-commit.ci + autofix_prs: true + autoupdate_branch: '' + autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' + autoupdate_schedule: weekly + skip: [pytest,doctest,mypy] + submodules: false diff --git a/README.md b/README.md index b4ae1d1..6083610 100644 --- a/README.md +++ b/README.md @@ -337,4 +337,3 @@ Common issues: * Open an issue on GitHub: [https://github.com/VectorInstitute/sonic-o1/issues](https://github.com/VectorInstitute/sonic-o1/issues) * Check individual stage README files 
for detailed troubleshooting * Review stage-specific configuration examples in `config/` directories - diff --git a/docs/assets/AIXPERTLogo.png b/docs/assets/AIXPERTLogo.png new file mode 100644 index 0000000..141949e Binary files /dev/null and b/docs/assets/AIXPERTLogo.png differ diff --git a/docs/assets/GitHub_Logo.png b/docs/assets/GitHub_Logo.png new file mode 100644 index 0000000..e03d8dd Binary files /dev/null and b/docs/assets/GitHub_Logo.png differ diff --git a/docs/assets/Sunburst_Topics.png b/docs/assets/Sunburst_Topics.png new file mode 100644 index 0000000..ceab199 Binary files /dev/null and b/docs/assets/Sunburst_Topics.png differ diff --git a/docs/assets/Teaser_Figure.png b/docs/assets/Teaser_Figure.png new file mode 100644 index 0000000..bdc900c Binary files /dev/null and b/docs/assets/Teaser_Figure.png differ diff --git a/docs/assets/VectorLogo.png b/docs/assets/VectorLogo.png new file mode 100644 index 0000000..4e0ed74 Binary files /dev/null and b/docs/assets/VectorLogo.png differ diff --git a/docs/assets/VectorLogo_Black.png b/docs/assets/VectorLogo_Black.png new file mode 100644 index 0000000..a45b2e8 Binary files /dev/null and b/docs/assets/VectorLogo_Black.png differ diff --git a/docs/assets/arxiv-logo.png b/docs/assets/arxiv-logo.png new file mode 100644 index 0000000..ec67468 Binary files /dev/null and b/docs/assets/arxiv-logo.png differ diff --git a/docs/assets/full_pipeline.drawio.png b/docs/assets/full_pipeline.drawio.png new file mode 100644 index 0000000..72f71d4 Binary files /dev/null and b/docs/assets/full_pipeline.drawio.png differ diff --git a/docs/assets/hf-logo.svg b/docs/assets/hf-logo.svg new file mode 100644 index 0000000..ab959d1 --- /dev/null +++ b/docs/assets/hf-logo.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/docs/assets/plot1_summarization.png b/docs/assets/plot1_summarization.png new file mode 100644 index 0000000..6a3911c Binary files /dev/null and b/docs/assets/plot1_summarization.png differ diff --git 
a/docs/assets/plot2_mcq.png b/docs/assets/plot2_mcq.png new file mode 100644 index 0000000..a55ddb4 Binary files /dev/null and b/docs/assets/plot2_mcq.png differ diff --git a/docs/assets/plot3_temporal.png b/docs/assets/plot3_temporal.png new file mode 100644 index 0000000..738c895 Binary files /dev/null and b/docs/assets/plot3_temporal.png differ diff --git a/docs/assets/plot_category_task_distribution_topics.png b/docs/assets/plot_category_task_distribution_topics.png new file mode 100644 index 0000000..07d8eda Binary files /dev/null and b/docs/assets/plot_category_task_distribution_topics.png differ diff --git a/docs/assets/plot_duration_category_by_topic_videos.png b/docs/assets/plot_duration_category_by_topic_videos.png new file mode 100644 index 0000000..1d81a04 Binary files /dev/null and b/docs/assets/plot_duration_category_by_topic_videos.png differ diff --git a/docs/assets/spider_chart.png b/docs/assets/spider_chart.png new file mode 100644 index 0000000..a149a44 Binary files /dev/null and b/docs/assets/spider_chart.png differ diff --git a/docs/index.html b/docs/index.html new file mode 100644 index 0000000..1e529cc --- /dev/null +++ b/docs/index.html @@ -0,0 +1,1069 @@ + + + + + + + SONIC-O1: A Real-World Benchmark for Evaluating Multimodal Large Language Models on Audio-Video Understanding + + + + + + + + + +
+ +
+ +
+
+
+ +
+

SONIC-O1: A Real-World Benchmark for Evaluating Multimodal Large Language Models on Audio-Video Understanding

+ +

+ Ahmed Y. Radwan1, + Christos Emmanouilidis2, + Hina Tabassum3, + Deval Pandya1, + Shaina Raza1 +

+ +

+ 1Vector Institute, Canada    + 2University of Groningen, The Netherlands    + 3York University, Canada    +

+ + + + +

+ SONIC-O1 is the first open-source evaluation suite for real-world audio-video understanding across diverse video durations (short, medium, and long-form interactions), + targeting global comprehension, fine-grained reasoning, and + temporal grounding — with demographic metadata for comprehensive fairness analysis of multimodal AI systems. +

+ + +
+ 231 videos + ~60 hours + 4,958 human-verified QAs + 13 conversational domains + 3 evaluation tasks +
+
+ + +
+
+ SONIC-O1 multimodal AI benchmark overview showing three video understanding evaluation tasks: video summarization, multiple-choice question answering, and temporal localization with demographic fairness metadata for machine learning model assessment +
+ Figure 1. SONIC-O1 benchmark overview. Three example tasks: (Top) Video Summarization requiring + global comprehension; (Middle) Multiple-Choice Question (MCQ) with evidence-grounded reasoning; (Bottom) Temporal Localization + with precise event timing. Demographic metadata (race, gender, age) shown beneath each video segment enables fairness-aware evaluation + across 13 conversational domains. +
+
+
+
+
+
+ +
+ +
+

Abstract

+
+

+ Multimodal Large Language Models (MLLMs) have become a major focus in recent AI research. However, most existing + work still centers on static image understanding, while their ability to process sequential audio-video data remains + underexplored. This gap highlights the need for a high-quality benchmark to systematically evaluate MLLM performance + in dynamic, temporally grounded settings. +

+

+ We introduce SONIC-O1, a comprehensive, fully human-verified benchmark spanning 13 + real-world conversational domains with 4,958 annotations and demographic metadata. SONIC-O1 evaluates + MLLMs on key tasks, including open-ended summarization, multiple-choice question answering, and temporal localization + with supporting rationales (reasoning). +

+

+ Experiments across both commercial closed-source and open-source models reveal important limitations. While the gap + in MCQ accuracy is relatively small, we observe a substantial 22.9% performance difference in temporal + localization between the best commercial and best open-source systems. Performance further degrades across demographic + groups, indicating persistent disparities in model behavior. +

+

+ Overall, SONIC-O1 provides an open evaluation suite for temporally grounded and socially robust multimodal understanding. +

+
+
+ +
+

Overview

+
+
+
231
+
Videos
+
+
+
~60h
+
Total Duration
+
+
+
4,958
+
Human-Verified QAs
+
+
+
3
+
Evaluation Tasks
+
+
+ +
+
+

What SONIC-O1 Evaluates

+
    +
  • Global comprehension — Video summarization of interactions across diverse durations (short, medium, long-form)
  • +
  • Fine-grained reasoning — Multiple-choice question answering with evidence grounding
  • +
  • Temporal grounding — Event localization with precise timestamps in video timelines
  • +
  • Demographic robustness — AI fairness performance across race, age, and gender demographics
  • +
+

+ Unlike prior AI benchmarks that focus on short clips or single modalities, SONIC-O1 evaluates + omnimodal understanding (audio + video) and video question answering capabilities on realistic interactions across diverse durations (short, medium, long-form) + from high-stakes conversational domains. +

+
+ +
+
+ + +
+

Dataset

+ +
+
+ Sunburst visualization chart displaying distribution of 13 conversational video topics across 5 real-world domains for multimodal AI training and evaluation +
+ Figure 2. Video categories. SONIC-O1 covers 5 key domains and 13 sub-class video types spanning + professional, educational, legal/civic, service-oriented, and community/public health interactions. +
+
+
+

13 Conversational Topics

+
+
+

Professional

+
    +
  • Job Interviews
  • +
  • Workplace Team Meetings
  • +
+
+
+

Educational

+
    +
  • Parent-Teacher Conferences
  • +
+
+
+

Legal / Civic

+
    +
  • Courtroom Proceedings
  • +
  • Community Town Halls
  • +
+
+
+

Service-Oriented

+
    +
  • Customer Service
  • +
  • Restaurant Service
  • +
  • Housing/Apartment Tours
  • +
+
+
+

Community / Public Health

+
    +
  • Medical (Patient-Doctor)
  • +
  • Emergency Response
  • +
  • Public Transportation Conflicts
  • +
  • Mental Health Counseling
  • +
  • Olympics (Sports)
  • +
+
+
+
+
+ +
+
+ Bar chart visualization showing video duration distribution across conversational topics with short, medium, and long-form videos for comprehensive AI model testing +
+ Figure 3. Video duration distribution by topic. SONIC-O1 spans short (<5 min), medium (5-20 min), + and long (20-60 min) videos across all conversational domains. +
+
+
+ Bar chart showing distribution of video summarization, multiple-choice QA, and temporal localization questions across conversational topics for machine learning model benchmarking +
+ Figure 4. Question type distribution over topics. Each domain includes annotations for all three + evaluation tasks: summarization, multiple-choice QA, and temporal localization. +
+
+
+ +
+

Demographic Coverage

+

+ SONIC-O1 includes demographic annotations across: +

+
    +
  • Race/Ethnicity: White, Black, Asian, Hispanic, Indigenous, Arab
  • +
  • Gender: Male, Female
  • +
  • Age: 18-24, 25-39, 40+
  • +
+

+ All demographic labels are annotated from observable characteristics via AI-assisted human verification, + enabling systematic fairness evaluation and bias detection in multimodal AI systems across demographic groups. +

+
+
+ + +
+

Evaluation Tasks

+

+ SONIC-O1 evaluates three complementary capabilities: global comprehension (summarization), + fine-grained reasoning (MCQ), and temporal grounding (localization). +

+ +
+
+
+
Task 1
+

Video Summarization

+
+

+ Generate narrative summaries capturing key events, actions, and outcomes across full videos + (up to 60 minutes). Tests global comprehension and ability to synthesize information across + long temporal spans. +

+
+ Metrics: +
    +
  • LLM-as-Judge score (0–10)
  • +
  • ROUGE-L
  • +
  • Cosine similarity
  • +
+
+
+ 231 instances +
+
+ +
+
+
Task 2
+

Multiple-Choice QA

+
+

+ Answer questions about 3-minute video segments with four answer choices plus "Not enough evidence" option. + Requires fine-grained comprehension and evidence-grounded reasoning across audio-visual modalities. +

+
+ Metrics: +
    +
  • Accuracy (%)
  • +
  • Rationale quality (LLM judge)
  • +
+
+
+ 1,335 instances +
+
+ +
+
+
Task 3
+

Temporal Localization

+
+

+ Localize events in time with start/end timestamps and temporal relations (before/during/after). + Tests whether models can identify not just what happens but when it occurs. +

+
+ Metrics: +
    +
  • Recall@IoU (R@0.3, R@0.5, R@0.7)
  • +
  • Mean IoU (mIoU)
  • +
  • Mean Absolute Error (MAE)
  • +
+
+
+ 3,392 instances +
+
+
+
+ + + +
+

Results

+

+ We evaluate 6 state-of-the-art multimodal models on SONIC-O1. Closed-source models (Gemini 3.0 Pro) + consistently outperform open-source alternatives, though temporal localization remains challenging for all systems. +

+ +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ModelLLM ParamsSummarization
Score (0–10) ↑
MCQ
Accuracy (%) ↑
Temporal
R@0.5 (%) ↑
Gemini 3.0 Pro †—7.0796.425.4
Qwen3-Omni30B5.7293.62.8
UniMoE-2.033B4.7188.21.0
MiniCPM-o-2.69B3.3487.40.7
VITA 1.58B2.7781.61.2
VideoLLaMA27B1.5354.30.4
+
+ +

+ † Denotes closed-source model. + All metrics are macro-averaged across 13 conversational topics. + See the live leaderboard + for full per-task breakdowns and latest submissions. +

+ +
+

Key Findings

+
    +
  • + Accuracy-temporal grounding disconnect: Gemini 3.0 Pro achieves 96.4% MCQ accuracy + but only 25.4% R@0.5 for temporal localization, revealing that models can identify what happens + but struggle to pinpoint when. +
  • +
  • + Temporal localization is the hardest task for video understanding AI: Open-source models achieve <3% R@0.5, + with a 22% performance gap to Gemini 3.0 Pro, indicating fundamental limitations in temporal reasoning and event timing prediction. +
  • +
  • + Model scale matters: Larger models (Qwen3-Omni 30B, UniMoE-2.0 33B) significantly + outperform smaller variants (7-9B), though even at scale, temporal grounding remains challenging. +
  • +
+
+ + +
+ + +
+

Per-Topic Performance

+ +
+
+

+ SONIC-O1 spans 13 real-world conversational domains across professional, civic, service-oriented, + and community interactions with diverse video durations. Multimodal AI performance varies significantly across topics, revealing domain-specific + strengths and weaknesses in audio-video understanding capabilities. +

+ +

Key Observations

+
    +
  • + Structured interactions are easier: Models perform best on formal settings like + medical consultations, job interviews, and courtroom proceedings where interactions follow predictable patterns. +
  • +
  • + High-stakes scenarios are harder: Emergency response and mental health counseling + show lower scores across all models, likely due to increased perceptual complexity and emotional nuance. +
  • +
  • + Gemini 3.0 Pro leads consistently: Achieves the highest scores across nearly all + topics, with particularly strong performance on professional and educational domains. +
  • +
  • + Open-source models show uneven robustness: Smaller models (VideoLLaMA2, VITA 1.5) + struggle more on complex topics, while larger open-source models (Qwen3-Omni) show more stable + cross-topic performance. +
  • +
+
+ +
+ Radar chart visualization comparing multimodal AI model performance across 13 conversational domains for video understanding and audio-video analysis benchmarking +
+ Figure 6. Performance comparison across 13 conversational domains. We evaluate six MLLMs + on video summarization using LLM-judge scores (0-10 scale, higher is better). Gemini 3.0 Pro consistently + outperforms open-source models, while high-stakes scenarios (Emergency Response, Mental Health) prove more + challenging than structured interactions (Medical, Job Interviews) across all models. +
+
+
+
+ + +
+

Fairness Analysis

+

+ SONIC-O1 includes comprehensive demographic annotations (race, gender, age) to enable systematic AI fairness evaluation and bias detection in video understanding models. + Results reveal significant performance disparities across demographic groups, with + Black and Indigenous participants showing consistently lower scores across multimodal AI systems. +

+ +

Summarization Fairness (LLM-as-Judge Score, 0–10)

+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ModelWhiteBlackAsianHispanicIndigenousArabGap ↓
Gemini 3.0 Pro †6.686.027.056.416.706.901.03
Qwen3-Omni5.284.395.714.994.135.951.82
UniMoE-2.0-Omni4.293.454.623.704.355.001.55
MiniCPM-o-2.63.262.923.263.043.613.570.69
VITA 1.52.502.312.652.211.652.761.11
VideoLLaMA21.451.381.631.231.041.000.63
+
+ +

MCQ Fairness (Accuracy, %)

+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ModelWhiteBlackAsianHispanicIndigenousArabGap ↓
Gemini 3.0 Pro †96.996.497.896.194.398.44.1
Qwen3-Omni93.392.096.192.877.196.919.8
UniMoE-2.0-Omni88.987.489.285.580.095.315.3
MiniCPM-o-2.687.586.388.781.682.992.711.1
VITA 1.582.082.284.679.162.993.230.3
VideoLLaMA255.155.057.951.065.766.015.0
+
+ +

Temporal Localization Fairness (Recall@0.5, %)

+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ModelWhiteBlackAsianHispanicIndigenousArabGap ↓
Gemini 3.0 Pro †23.019.530.723.840.921.121.4
Qwen3-Omni2.61.82.92.30.01.62.9
UniMoE-2.01.20.60.60.11.30.21.2
MiniCPM-o-2.60.90.30.80.20.02.62.6
VITA 1.51.41.41.40.81.31.20.6
VideoLLaMA20.50.30.40.01.30.01.3
+
+

+ † Denotes closed-source model. + ■ Best performing group per model. + ■ Worst performing group per model. + Gap = difference between best and worst performing groups (lower is better). +

+ +
+

Fairness Observations

+
    +
  • + Black and Indigenous groups show systematically lower performance: Across most models + and tasks, these groups consistently score below other demographic slices, indicating training data imbalances. +
  • +
+ Temporal localization shows the largest disparities: Gemini 3.0 Pro achieves 40.9% R@0.5 + for Indigenous participants but only 19.5% for Black participants—a 21.4-point gap, the largest across all tasks. +
  • +
  • + Closed-source models are more robust: Gemini 3.0 Pro shows smaller demographic gaps compared + to open-source alternatives, likely due to more diverse training data and safety alignment. +
  • +
  • + Gender and age show smaller but consistent gaps: Female participants and older adults (40+) + consistently score slightly higher across all models, suggesting bias toward more formal interaction styles. +
  • +
+
+
+ + +
+

Conclusion

+
+

+ We introduced SONIC-O1, a comprehensive real-world audio-video understanding benchmark for evaluating Multimodal Large Language Models through the lens + of fairness and AI safety. SONIC-O1 includes human-verified annotations across diverse video durations and three evaluation tasks: + video summarization, multiple-choice question answering, and temporal localization with rationales, to assess both semantic + understanding and precise temporal grounding capabilities. +

+
+
🔍
+
+

Key Finding: Audio Dominates, Temporal Grounding Remains Hard

+

+ Our findings show that audio or transcripts often provide the strongest cues for + comprehension, while temporal localization remains the most challenging + setting. Gemini 3.0 Pro achieves 96.4% MCQ accuracy and generates + high-quality rationales, yet attains only 25.4% recall at IoU≥0.5 for + temporal localization—a 22.9-percentage-point gap separates the best commercial + and best open-source systems on this task. +

+
+
+

+ We also observe persistent demographic disparities across models, emphasizing + the need for more equitable multimodal evaluation. Performance is uneven across demographic groups, + with Black and Indigenous participants showing systematically lower scores across + most models and tasks. Temporal localization exhibits the most severe disparities: some models + collapse to 0.0% R@0.5 for Indigenous participants, whereas Gemini 3.0 Pro reaches 40.9% for the same group. +

+

+ Overall, SONIC-O1 offers a practical testbed for measuring robustness and fairness in realistic + audio-video scenarios, and we hope it will guide future work on: +

+
    +
  • Stronger temporal reasoning beyond frame-level understanding
  • +
  • Broader benchmark coverage across languages, domains, and modalities
  • +
  • Fairness-aware training to address demographic performance gaps
  • +
  • Native audio-video integration rather than treating audio as optional
  • +
+ +
+
+ + + + + + +
+

Quick Start

+

+ The SONIC-O1 open-source repository provides an end-to-end pipeline for video data curation, annotation generation, + and multimodal AI model evaluation. Dataset and annotations are hosted on Hugging Face for easy access. +

+ +
+

Installation

+
+ +
# Clone repository (note: nested structure)
+            git clone https://github.com/VectorInstitute/sonic-o1.git
+            cd sonic-o1/sonic-o1
+
+            # Download dataset + annotations from Hugging Face
+            pip install huggingface_hub
+            huggingface-cli download vector-institute/sonic-o1 --repo-type dataset --local-dir ./
+
+            # Setup Python environment
+            python -m venv .venv
+            source .venv/bin/activate  # On Windows: .venv\Scripts\activate
+            pip install -r requirements_venv.txt
+
+
+ +
+

Load Dataset

+
+ +
from datasets import load_dataset
+
+            # Load individual tasks
+            ds_summ = load_dataset("vector-institute/sonic-o1", "task1_summarization")
+            ds_mcq = load_dataset("vector-institute/sonic-o1", "task2_mcq")
+            ds_temporal = load_dataset("vector-institute/sonic-o1", "task3_temporal_localization")
+
+            # Each sample includes:
+            # - video metadata (ID, topic, duration, demographics)
+            # - question/prompt
+            # - ground truth answer
+            # - rationale (for MCQ and temporal tasks)
+
+
+ +
+

Pipeline Stages

+

+ The repository includes modular scripts for each stage of the pipeline: +

+
    +
  1. 01_data_curation/ — Video search, filtering, and metadata extraction
  2. +
  3. 02_caption_generation/ — Whisper-based caption generation for videos without captions
  4. +
  5. 03_demographics_annotation/ — AI-assisted demographic labeling with human verification
  6. +
  7. 04_vqa_generation/ — Multi-task annotation generation (summarization, MCQ, temporal)
  8. +
  9. 05_evaluation_inference/ — Model evaluation scripts and metric computation
  10. +
+

+ See the GitHub repository + for detailed documentation and usage examples. +

+
+
+ + +
+

BibTeX

+

+ If you use SONIC-O1 in your research, please cite our paper: +

+ +
+ + +
@article{radwan2026sonico1,
+          title={SONIC-O1: A Real-World Benchmark for Evaluating Multimodal Large Language Models on Audio-Video Understanding},
+          author={Radwan, Ahmed Y and Emmanouilidis, Christos and Tabassum, Hina and Pandya, Deval and Raza, Shaina},
+          journal={arXiv preprint arXiv:2601.21666},
+          year={2026}
+        } 
+ +

+
+
+ + +
+

Acknowledgements

+
+

+ Resources used in preparing this research were provided, in part, by the Province of Ontario, + the Government of Canada through CIFAR, and companies sponsoring the Vector Institute + (http://www.vectorinstitute.ai/#partners). +

+

+ This research was funded by the European Union's Horizon Europe research and innovation programme + under the AIXPERT project (Grant Agreement No. 101214389), which aims to develop an agentic, + multi-layered, GenAI-powered framework for creating explainable, accountable, and transparent AI systems. +

+ +
+ Vector Institute + AIXPERT +
+
+
+ + + +
+ + diff --git a/docs/main.js b/docs/main.js new file mode 100644 index 0000000..5c28140 --- /dev/null +++ b/docs/main.js @@ -0,0 +1,78 @@ +// docs/main.js + +document.addEventListener("DOMContentLoaded", () => { + const nav = document.getElementById("nav"); + const navToggle = document.getElementById("navToggle"); + + // Mobile nav toggle + if (navToggle && nav) { + navToggle.addEventListener("click", () => { + const isOpen = nav.classList.toggle("open"); + navToggle.setAttribute("aria-expanded", String(isOpen)); + }); + + // Close nav on link click (mobile) + nav.querySelectorAll("a").forEach((a) => { + a.addEventListener("click", () => { + if (nav.classList.contains("open")) { + nav.classList.remove("open"); + navToggle.setAttribute("aria-expanded", "false"); + } + }); + }); + } + + // Active section highlight + const links = Array.from(document.querySelectorAll(".nav a")); + const sections = links + .map((a) => document.querySelector(a.getAttribute("href"))) + .filter(Boolean); + + if ("IntersectionObserver" in window && sections.length) { + const obs = new IntersectionObserver( + (entries) => { + // pick the most visible intersecting entry + const visible = entries + .filter((e) => e.isIntersecting) + .sort((a, b) => b.intersectionRatio - a.intersectionRatio)[0]; + + if (!visible) return; + + const id = "#" + visible.target.id; + links.forEach((a) => a.classList.toggle("active", a.getAttribute("href") === id)); + }, + { threshold: [0.2, 0.35, 0.5, 0.65] } + ); + + sections.forEach((s) => obs.observe(s)); + } + + // Copy BibTeX + const copyBtn = document.getElementById("copyBibtex"); + const bibtexBlock = document.getElementById("bibtexBlock"); + const status = document.getElementById("copyStatus"); + + if (copyBtn && bibtexBlock) { + copyBtn.addEventListener("click", async () => { + const text = bibtexBlock.innerText.trim(); + try { + await navigator.clipboard.writeText(text); + if (status) status.textContent = "Copied!"; + copyBtn.textContent = "Copied"; + 
setTimeout(() => { + copyBtn.textContent = "Copy BibTeX"; + if (status) status.textContent = ""; + }, 1200); + } catch (e) { + // fallback: select text + const range = document.createRange(); + range.selectNodeContents(bibtexBlock); + const sel = window.getSelection(); + sel.removeAllRanges(); + sel.addRange(range); + if (status) status.textContent = "Select + copy (Ctrl/Cmd+C)."; + } + }); + } + }); + \ No newline at end of file diff --git a/docs/style.css b/docs/style.css new file mode 100644 index 0000000..51cc92d --- /dev/null +++ b/docs/style.css @@ -0,0 +1,1285 @@ +/* docs/style.css */ + +:root { + --bg: #ffffff; + --bg-soft: #f6f7fb; + --bg-light: #fafbfd; + --text: #1c1f2a; + --text-muted: #5b6172; + --border: rgba(20, 22, 30, 0.12); + --border-light: rgba(20, 22, 30, 0.06); + --shadow: 0 12px 30px rgba(18, 20, 30, 0.08); + --shadow-sm: 0 8px 18px rgba(18, 20, 30, 0.06); + --brand: #0b2e6f; + --brand-2: #2156c7; + --primary: #2156c7; + --pill: rgba(33, 86, 199, 0.08); + --success: #16a34a; + --warning: #dc2626; + --max: 1120px; +} + +* { + box-sizing: border-box; +} + +html { + scroll-behavior: smooth; +} + +body { + margin: 0; + font-family: ui-sans-serif, -apple-system, BlinkMacSystemFont, "Segoe UI", + Roboto, Helvetica, Arial, "Apple Color Emoji", "Segoe UI Emoji"; + color: var(--text); + background: var(--bg); + line-height: 1.6; + -webkit-font-smoothing: antialiased; +} + +a { + color: var(--brand-2); + text-decoration: none; + transition: color 0.15s ease; +} +a:hover { + text-decoration: underline; +} + +.container { + width: min(var(--max), calc(100% - 2rem)); + margin: 0 auto; +} + +/* ==================== Topbar ==================== */ +.topbar { + position: sticky; + top: 0; + z-index: 50; + background: rgba(255, 255, 255, 0.92); + backdrop-filter: blur(10px); + border-bottom: 1px solid var(--border); +} + +.topbar-inner { + width: min(var(--max), calc(100% - 2rem)); + margin: 0 auto; + padding: 0.75rem 0; + display: flex; + align-items: center; 
+ gap: 1rem; +} + +.brand { + display: flex; + align-items: center; + gap: 0.6rem; + min-width: 160px; +} + +.brand-mark { + color: var(--brand-2); + font-weight: 900; + font-size: 1.5rem; + line-height: 1; +} + +.brand-name { + font-weight: 800; + letter-spacing: 0.2px; + color: var(--text); + font-size: 1.1rem; +} +/* Topbar logos */ +.topbar-logos { + display: flex; + align-items: center; + gap: 0.75rem; + margin-right: 1rem; +} + +.topbar-logos .logo-img { + height: 44px; + width: auto; + object-fit: contain; +} + +.nav { + margin-left: auto; + display: flex; + flex-wrap: wrap; + gap: 0.85rem; + align-items: center; +} + +.nav a { + color: var(--text-muted); + font-weight: 600; + font-size: 0.92rem; + padding: 0.4rem 0.2rem; + border-bottom: 2px solid transparent; + transition: color 0.15s ease, border-color 0.15s ease; +} + +.nav a:hover { + color: var(--text); + text-decoration: none; +} + +.nav a.active { + color: var(--text); + border-bottom-color: var(--brand-2); +} + +.nav-toggle { + margin-left: auto; + display: none; + width: 40px; + height: 36px; + border: 1px solid var(--border); + background: #fff; + border-radius: 10px; + padding: 8px; + cursor: pointer; +} +.nav-toggle span { + display: block; + height: 2px; + margin: 4px 0; + background: var(--text); + border-radius: 2px; +} + +/* ==================== Hero ==================== */ +.hero { + background: linear-gradient(180deg, var(--bg-soft), #fff); + border-bottom: 1px solid var(--border); + padding: 3rem 0 2.5rem; +} + +.hero-content { + max-width: 100%; +} + +/* Hero text - centered like HumaniBench */ +.hero-text { + text-align: center; + max-width: 900px; + margin: 0 auto; +} + +/* Logos row - centered */ +.logo-row { + display: flex; + gap: 1.5rem; + align-items: center; + justify-content: center; + margin-bottom: 1.5rem; +} + +.logo { + width: auto; + object-fit: contain; + filter: none; +} + +.vector-logo { + height: 70px; +} +.aixpert-logo { + height: 60px; + opacity: 0.95; +} + +/* Title 
- larger and centered */ +.title { + margin: 0; + font-size: clamp(2rem, 5vw, 3.2rem); + font-weight: 700; + letter-spacing: -0.03em; + line-height: 1.15; +} + +/* Authors - styled with links */ +.authors { + margin: 1.5rem 0 0.5rem; + font-size: 1rem; + line-height: 1.8; + color: var(--text); +} + +.authors a { + color: var(--brand-2); + font-weight: 500; +} + +.authors sup { + font-size: 0.75em; + font-weight: 700; +} + +.affiliations { + margin: 0.25rem 0 0; + font-size: 0.95rem; + color: var(--text-muted); +} + +.affiliations sup { + font-size: 0.75em; + font-weight: 700; + margin-right: 0.15rem; +} + +/* Lead paragraph */ +.lead { + margin: 1.5rem auto 0; + font-size: 1.1rem; + color: var(--text); + line-height: 1.7; + max-width: 750px; +} + +.lead strong { + color: var(--brand); + font-weight: 700; +} + +/* Buttons */ +.button-row { + display: flex; + flex-wrap: wrap; + gap: 0.7rem; + margin-top: 1.5rem; + justify-content: center; +} + +.btn { + display: inline-flex; + align-items: center; + justify-content: center; + gap: 0.5rem; + padding: 0.7rem 1.2rem; + border-radius: 10px; + font-weight: 700; + font-size: 0.95rem; + border: 1px solid var(--border); + background: #fff; + color: var(--text); + box-shadow: var(--shadow-sm); + transition: all 0.15s ease; + cursor: pointer; +} + +.btn:hover { + text-decoration: none; + transform: translateY(-2px); + box-shadow: 0 12px 24px rgba(20, 22, 30, 0.12); + border-color: var(--brand-2); +} + +.btn:active { + transform: translateY(0); +} + +.btn-primary { + background: var(--brand-2); + color: #fff; + border-color: var(--brand-2); +} + +.btn-primary:hover { + background: var(--brand); + border-color: var(--brand); +} + +.btn-icon { + width: 18px; + height: 18px; + object-fit: contain; +} + +.btn-small { + padding: 0.5rem 0.9rem; + font-size: 0.9rem; +} + +.btn-copy { + gap: 0.4rem; +} + +.btn-copy svg { + width: 14px; + height: 14px; +} + +/* Badges */ +.badge-row { + margin-top: 1.5rem; + display: flex; + flex-wrap: 
wrap; + gap: 0.6rem; + justify-content: center; +} + +.pill { + display: inline-flex; + align-items: center; + padding: 0.45rem 0.85rem; + border-radius: 999px; + background: var(--pill); + color: var(--brand); + font-weight: 700; + font-size: 0.9rem; + border: 1px solid rgba(33, 86, 199, 0.14); +} + +.pill-small { + padding: 0.35rem 0.65rem; + font-size: 0.85rem; +} + +/* Hero figure - large and prominent */ +.hero-figure { + margin-top: 3rem; + max-width: 1000px; + margin-left: auto; + margin-right: auto; +} + +.hero-figure figure { + margin: 0; + background: #fff; + border: 1px solid var(--border); + border-radius: 20px; + padding: 16px; + box-shadow: var(--shadow); +} + +.hero-figure img { + width: 100%; + height: auto; + display: block; + border-radius: 14px; +} + +.hero-figure figcaption { + margin-top: 1rem; + color: var(--text-muted); + font-size: 0.95rem; + line-height: 1.6; +} + +.hero-figure figcaption strong { + color: var(--text); + font-weight: 700; +} + +/* ==================== Sections ==================== */ +.section { + padding: 3.5rem 0; +} + +h2 { + margin: 0 0 1.5rem; + font-size: 2.2rem; + font-weight: 700; + letter-spacing: -0.02em; +} + +.h3 { + margin: 0 0 0.75rem; + font-size: 1.3rem; + font-weight: 700; +} + +h4 { + margin: 0 0 0.5rem; + font-size: 1.1rem; + font-weight: 700; +} + +.section-intro { + font-size: 1.05rem; + color: var(--text-muted); + margin-bottom: 1.5rem; + line-height: 1.7; +} + +.muted { + color: var(--text-muted); +} +.small { + font-size: 0.92rem; +} + +/* Abstract box */ +.abstract-box { + background: #fff; + border: 1px solid var(--border); + border-radius: 18px; + padding: 1.75rem 2rem; + box-shadow: var(--shadow-sm); + line-height: 1.8; + font-size: 1.05rem; +} + +.abstract-box p { + margin: 0 0 1rem; +} + +.abstract-box p:last-child { + margin-bottom: 0; +} + +.abstract-box em { + font-style: italic; + color: var(--brand); +} + +.abstract-box strong { + font-weight: 700; + color: var(--text); +} + +/* Info boxes 
*/ +.info-box { + background: var(--bg-light); + border: 1px solid var(--border-light); + border-left: 4px solid var(--brand-2); + border-radius: 12px; + padding: 1.25rem 1.5rem; + margin-top: 1.5rem; +} + +.info-box .h3 { + margin-top: 0; + color: var(--brand); +} + +/* Key findings */ +.key-findings { + background: var(--bg-light); + border: 1px solid var(--border-light); + border-radius: 18px; + padding: 1.5rem 1.75rem; + margin-top: 2rem; +} + +.key-findings .h3 { + margin-top: 0; + color: var(--brand); +} + +/* ==================== Stats Grid ==================== */ +.grid-4 { + display: grid; + grid-template-columns: repeat(4, 1fr); + gap: 1rem; + margin: 1.5rem 0; +} + +.stat { + background: #fff; + border: 1px solid var(--border); + border-radius: 16px; + padding: 1.25rem; + box-shadow: var(--shadow-sm); + text-align: center; +} + +.stat-value { + font-size: 2rem; + font-weight: 900; + color: var(--brand); + line-height: 1; +} + +.stat-label { + color: var(--text-muted); + font-weight: 600; + font-size: 0.9rem; + margin-top: 0.5rem; +} + +/* ==================== Layout Helpers ==================== */ +.two-col { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 1.5rem; + align-items: start; + margin-top: 1.5rem; +} + +.two-col-figure { + display: grid; + grid-template-columns: 1fr 1.2fr; + gap: 2rem; + align-items: start; + margin-top: 1.5rem; +} + +.col-text { + padding-right: 1rem; +} + +.col-figure { + position: sticky; + top: 100px; +} + +.card { + background: #fff; + border: 1px solid var(--border); + border-radius: 18px; + padding: 1.25rem 1.5rem; + box-shadow: var(--shadow-sm); +} + +.link-list { + display: flex; + flex-direction: column; + gap: 0.6rem; +} + +.link-list a { + display: flex; + align-items: center; + gap: 0.5rem; + padding: 0.5rem 0; + font-weight: 600; +} + +.link-icon { + width: 20px; + height: 20px; + object-fit: contain; + flex-shrink: 0; +} + +.bullets { + margin: 0.75rem 0 0; + padding-left: 1.5rem; +} + +.bullets li { + 
margin: 0.6rem 0; + line-height: 1.6; +} + +.bullets strong { + font-weight: 700; + color: var(--text); +} + +.compact { + margin: 0.5rem 0 0; + padding-left: 1.5rem; +} + +.compact li { + margin: 0.35rem 0; +} + +/* ==================== Topic Categories ==================== */ +.topic-list { + display: flex; + flex-direction: column; + gap: 1rem; +} + +.topic-category h4 { + color: var(--brand); + font-size: 1rem; + margin-bottom: 0.5rem; +} + +/* ==================== Figures ==================== */ +.figure { + margin: 0; + background: #fff; + border: 1px solid var(--border); + border-radius: 18px; + padding: 14px; + box-shadow: var(--shadow-sm); +} + +.figure img { + width: 100%; + height: auto; + display: block; + border-radius: 12px; +} + +.figure figcaption { + margin-top: 0.75rem; + color: var(--text-muted); + font-size: 0.9rem; + line-height: 1.6; +} + +.figure figcaption strong { + color: var(--text); + font-weight: 700; +} + +.figure-large { + margin-top: 1.5rem; +} + +.figure-grid-2 { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 1.5rem; + margin-top: 1.5rem; +} + +.figure-grid-3 { + display: grid; + grid-template-columns: repeat(3, 1fr); + gap: 1.5rem; + margin-top: 1.5rem; +} + +/* ==================== Tasks ==================== */ +.task-grid { + display: grid; + grid-template-columns: repeat(3, 1fr); + gap: 1.25rem; + margin-top: 1.5rem; +} + +.task-card { + background: #fff; + border: 1px solid var(--border); + border-radius: 18px; + padding: 1.5rem; + box-shadow: var(--shadow-sm); + display: flex; + flex-direction: column; +} + +.task-header { + display: flex; + align-items: center; + gap: 0.75rem; + margin-bottom: 0.75rem; +} + +.task-number { + min-width: 48px; + width: 64px; + height: 48px; + border-radius: 50%; + background: var(--brand-2); + color: #fff; + font-weight: 700; + font-size: 1.1rem; + display: flex; + align-items: center; + justify-content: center; + flex-shrink: 0; +} + +.task-header h3 { + margin: 0; + font-size: 
1.2rem; +} + +.task-description { + flex: 1; + color: var(--text-muted); + line-height: 1.6; + margin-bottom: 1rem; +} + +.task-metrics { + margin-top: auto; + padding-top: 1rem; + border-top: 1px solid var(--border-light); +} + +.task-metrics strong { + color: var(--brand); + font-size: 0.9rem; +} + +.task-metrics ul { + margin: 0.5rem 0 0; + padding-left: 1.2rem; +} + +.task-metrics li { + font-size: 0.9rem; + color: var(--text-muted); + margin: 0.3rem 0; +} + +.task-stats { + margin-top: 0.75rem; +} + +/* ==================== Pipeline ==================== */ +.pipeline-stages { + display: grid; + grid-template-columns: repeat(5, 1fr); + gap: 1rem; + margin-top: 2rem; +} + +.stage { + background: #fff; + border: 1px solid var(--border); + border-radius: 16px; + padding: 1.25rem; + box-shadow: var(--shadow-sm); + text-align: center; +} + +.stage-number { + width: 56px; + height: 56px; + border-radius: 50%; + background: var(--brand-2); + color: #fff; + font-weight: 700; + font-size: 1.5rem; + display: flex; + align-items: center; + justify-content: center; + margin: 0 auto 0.75rem; +} + +.stage h4 { + font-size: 1rem; + margin-bottom: 0.5rem; +} + +.stage p { + font-size: 0.88rem; + color: var(--text-muted); + line-height: 1.5; + margin: 0; +} + +/* ==================== Tables ==================== */ +.table-wrap { + overflow: auto; + border-radius: 16px; + border: 1px solid var(--border); + background: #fff; + box-shadow: var(--shadow-sm); + margin: 1.5rem 0; +} + +.table { + width: 100%; + border-collapse: collapse; + min-width: 800px; +} + +.table th, +.table td { + padding: 1rem 1.25rem; + border-bottom: 1px solid rgba(20, 22, 30, 0.08); + text-align: left; + font-size: 0.95rem; +} + +.table th { + background: var(--bg-soft); + font-weight: 700; + color: var(--text); + position: sticky; + top: 0; + z-index: 10; +} + +.table .center { + text-align: center; +} +.table .right { + text-align: right; +} + +.table tr.highlight { + background: rgba(33, 86, 199, 
0.04); +} + +.table tr.highlight td { + font-weight: 700; +} + +.table tbody tr:hover { + background: rgba(33, 86, 199, 0.02); +} + +.table tbody tr:last-child td { + border-bottom: none; +} + +.table a { + font-weight: 600; +} + +/* Table color indicators for fairness */ +.table .best { + color: var(--success); + font-weight: 700; +} + +.table .worst { + color: var(--warning); + font-weight: 700; +} + +.table-note { + margin-top: 1rem; + font-size: 0.9rem; + color: var(--text-muted); + line-height: 1.6; +} + +.table-note strong { + font-weight: 700; + color: var(--text); +} + +.table-note .best { + display: inline-block; + width: 12px; + height: 12px; + background: var(--success); + border-radius: 2px; + vertical-align: middle; + margin-right: 0.25rem; +} + +.table-note .worst { + display: inline-block; + width: 12px; + height: 12px; + background: var(--warning); + border-radius: 2px; + vertical-align: middle; + margin-right: 0.25rem; +} + +.metric-desc { + font-weight: 500; + font-size: 0.85rem; + color: var(--text-muted); +} + +/* ==================== Conclusion Section ==================== */ +.section-conclusion { + text-align: center; + padding: 4rem 0; + /* Break out of container */ + margin-left: calc(-50vw + 50%); + margin-right: calc(-50vw + 50%); + padding-left: calc(50vw - 50%); + padding-right: calc(50vw - 50%); + background: var(--bg-soft); +} + +.section-conclusion h2 { + font-size: 2.5rem; + margin-bottom: 2rem; +} + +.conclusion-content { + text-align: left; + max-width: 900px; + margin: 0 auto; + width: 100%; + padding: 0 2rem; +} + +.conclusion-content > p { + font-size: 1.05rem; + line-height: 1.8; + color: var(--text-muted); + margin-bottom: 1.5rem; +} + +.highlight-box { + background: linear-gradient(135deg, rgba(33, 86, 199, 0.08), rgba(33, 86, 199, 0.04)); + border: 2px solid var(--brand-2); + border-radius: 18px; + padding: 2rem; + margin: 2rem 0; + display: flex; + gap: 1.5rem; + align-items: start; +} + +.highlight-icon { + font-size: 
2.5rem; + flex-shrink: 0; + line-height: 1; +} + +.highlight-content h3 { + margin: 0 0 0.75rem; + font-size: 1.3rem; + color: var(--brand); +} + +.highlight-content p { + margin: 0; + font-size: 1.05rem; + line-height: 1.7; + color: var(--text); +} + +.highlight-content strong { + font-weight: 700; + color: var(--brand); +} + +.conclusion-footer { + margin-top: 2rem; + padding-top: 2rem; + border-top: 1px solid var(--border-light); + font-size: 1rem; + color: var(--text-muted); + line-height: 1.7; +} + +.conclusion-footer strong { + font-weight: 700; + color: var(--text); +} + +/* ==================== Code Blocks ==================== */ +.code-section { + margin-top: 2rem; +} + +.code-card { + position: relative; + background: #0f1526; + border-radius: 16px; + padding: 1.25rem; + box-shadow: var(--shadow); + overflow: auto; + margin-top: 0.75rem; +} + +.code-card pre { + margin: 0; +} + +.code-card code { + color: #e9ecf4; + font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono"; + font-size: 0.9rem; + line-height: 1.6; +} + +.code-copy { + position: absolute; + top: 1rem; + right: 1rem; + background: rgba(255, 255, 255, 0.1); + border: 1px solid rgba(255, 255, 255, 0.2); + color: #fff; + padding: 0.5rem 0.75rem; + border-radius: 8px; + font-size: 0.85rem; + cursor: pointer; + display: flex; + align-items: center; + gap: 0.4rem; + transition: all 0.15s ease; +} + +.code-copy:hover { + background: rgba(255, 255, 255, 0.15); + border-color: rgba(255, 255, 255, 0.3); +} + +.code-copy svg { + width: 14px; + height: 14px; +} + +/* ==================== BibTeX ==================== */ +.bibtex-section { + position: relative; + background: #fff; + border: 1px solid var(--border); + border-radius: 18px; + padding: 1.5rem; + box-shadow: var(--shadow-sm); + margin-top: 1rem; +} + +.bibtex-section .btn-copy { + position: absolute; + top: 1.5rem; + right: 1.5rem; +} + +.bibtex { + margin: 0; + background: #0f1526; + border-radius: 12px; + padding: 
1.25rem; + overflow: auto; +} + +.bibtex code { + color: #e9ecf4; + font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono"; + font-size: 0.9rem; + line-height: 1.6; +} + +.copy-status { + margin-top: 0.75rem; + font-size: 0.9rem; + color: var(--success); + text-align: center; +} + +/* ==================== Acknowledgements ==================== */ +#acknowledgements { + background: var(--bg-soft); + border-radius: 0; + padding: 3rem 0; + margin-top: 2rem; + /* Break out of container */ + margin-left: calc(-50vw + 50%); + margin-right: calc(-50vw + 50%); + padding-left: calc(50vw - 50%); + padding-right: calc(50vw - 50%); +} + +.ack-content { + max-width: 900px; + margin: 0 auto; + padding: 0 2rem; +} + +.ack-content p { + line-height: 1.7; + margin-bottom: 1rem; +} + +.ack-logos { + margin-top: 2rem; + display: flex; + gap: 2rem; + align-items: center; + justify-content: center; + flex-wrap: wrap; +} + +.ack-logos img { + height: 70px; + width: auto; + object-fit: contain; +} + +/* ==================== Footer ==================== */ +.footer { + margin-top: 0; + padding: 2.5rem 0; + border-top: 1px solid var(--border); + background: var(--bg-soft); + /* Break out of container */ + margin-left: calc(-50vw + 50%); + margin-right: calc(-50vw + 50%); + padding-left: calc(50vw - 50%); + padding-right: calc(50vw - 50%); +} + +.footer-content { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 2rem; + margin-bottom: 2rem; + max-width: 1120px; + margin-left: auto; + margin-right: auto; + padding: 0 1rem; +} + +.footer-section h3 { + margin: 0 0 1rem; + font-size: 1.1rem; +} + +.footer-section p { + margin: 0 0 0.75rem; + line-height: 1.6; + color: var(--text-muted); +} + +.footer-links { + list-style: none; + padding: 0; + margin: 0; +} + +.footer-links li { + margin: 0.5rem 0; +} + +.footer-links a { + color: var(--text-muted); + font-weight: 600; + transition: color 0.15s ease; +} + +.footer-links a:hover { + color: var(--brand-2); + 
text-decoration: none; +} + +.footer-bottom { + padding-top: 1.5rem; + border-top: 1px solid var(--border-light); + display: flex; + justify-content: space-between; + align-items: center; + color: var(--text-muted); + font-size: 0.9rem; + max-width: 1120px; + margin: 0 auto; + padding-left: 1rem; + padding-right: 1rem; +} + +.footer-bottom a { + color: var(--text-muted); + font-weight: 600; +} + +.footer-bottom a:hover { + color: var(--brand-2); +} + +/* ==================== Responsive ==================== */ +@media (max-width: 980px) { + .grid-4 { + grid-template-columns: repeat(2, 1fr); + } + .two-col, + .two-col-figure { + grid-template-columns: 1fr; + } + .task-grid { + grid-template-columns: 1fr; + } + .figure-grid-2, + .figure-grid-3 { + grid-template-columns: 1fr; + } + .footer-content { + grid-template-columns: 1fr; + } + .pipeline-stages { + grid-template-columns: 1fr; + } + .col-figure { + position: static; + } + .footer-bottom { + flex-direction: column; + gap: 0.5rem; + text-align: center; + } + + .nav-toggle { + display: block; + } + .nav { + display: none; + width: 100%; + flex-direction: column; + align-items: flex-start; + gap: 0.5rem; + margin-left: 0; + padding: 0.75rem 0 0.5rem; + } + .nav.open { + display: flex; + } + .topbar-inner { + flex-wrap: wrap; + } + + .hero { + padding: 2rem 0 1.5rem; + } + + .title { + font-size: clamp(1.75rem, 6vw, 2.5rem); + } + + .highlight-box { + flex-direction: column; + padding: 1.5rem; + } + .topbar-logos { + order: -1; /* Keeps logos on the left */ + } + +} + +@media (max-width: 640px) { + .vector-logo { + height: 64px; + } + .aixpert-logo { + height: 64px; + } + + .section { + padding: 2.5rem 0; + } + + h2 { + font-size: 1.75rem; + } + + .button-row { + flex-direction: column; + align-items: stretch; + } + + .btn { + justify-content: center; + } + + .code-copy { + position: static; + margin-bottom: 0.75rem; + } + + .bibtex-section .btn-copy { + position: static; + margin-bottom: 1rem; + } + + .topbar-logos 
.logo-img { + height: 24px; /* Smaller on mobile */ + } + + .brand { + min-width: auto; + } +} + +@media (prefers-reduced-motion: reduce) { + html { + scroll-behavior: auto; + } + .btn, + .nav a, + a { + transition: none; + } +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 2e90943..fd0101a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ dev = [ "jupyter>=1.0.0", "ipython>=8.12.0", "ipykernel>=6.0.0", + "mypy>=1.19.1", ] [tool.hatch.metadata] @@ -105,3 +106,88 @@ pytorchvideo = { git = "https://github.com/facebookresearch/pytorchvideo.git" } [tool.hatch.build.targets.wheel] packages = ["sonic-o1"] + +[tool.mypy] +follow_imports = "normal" +ignore_missing_imports = false +install_types = true +pretty = true +non_interactive = true +allow_untyped_defs = false +no_implicit_optional = true +check_untyped_defs = true +namespace_packages = true +explicit_package_bases = true +warn_unused_configs = true +allow_subclassing_any = false +allow_untyped_calls = false +allow_incomplete_defs = false +allow_untyped_decorators = false +warn_redundant_casts = true +warn_unused_ignores = true +implicit_reexport = false +strict_equality = true +extra_checks = true +mypy_path = "sonic-o1" + +[tool.ruff] +include = ["*.py", "pyproject.toml", "*.ipynb"] +exclude = [] +line-length = 88 + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +docstring-code-format = true + +[tool.ruff.lint] +select = [ + "A", # flake8-builtins + "B", # flake8-bugbear + "COM", # flake8-commas + "C4", # flake8-comprehensions + "RET", # flake8-return + "SIM", # flake8-simplify + "ICN", # flake8-import-conventions + "Q", # flake8-quotes + "RSE", # flake8-raise + "D", # pydocstyle + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "W", # pycodestyle + "N", # pep8-naming + "ERA", # eradicate + "PL", # pylint +] + +fixable = ["A", "B", "COM", "C4", "RET", "SIM", "ICN", "Q", "RSE", "D", "E", "F", "I", "W", "N", "ERA", "PL"] +ignore = [ + 
"B905", # `zip()` without an explicit `strict=` parameter + "E501", # line too long + "D203", # 1 blank line required before class docstring + "D213", # Multi-line docstring summary should start at the second line + "PLR2004", # Replace magic number with named constant + "PLR0913", # Too many arguments + "COM812", # Missing trailing comma + "ERA001", # Found commented-out code (too many false positives with math comments) + "A001", # Ignore variable `input` is shadowing a Python builtin (common for torch) + "A002", # Ignore variable `input` is shadowing a Python builtin in function (common for torch) + "D301", # r-strings for docstrings with backslashes +] + +# Ignore import violations in all `__init__.py` files. +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["E402", "F401", "F403", "F811"] + +[tool.ruff.lint.pep8-naming] +ignore-names = ["X*", "setUp"] + +[tool.ruff.lint.isort] +lines-after-imports = 2 + +[tool.ruff.lint.pydocstyle] +convention = "numpy" + +[tool.ruff.lint.pycodestyle] +max-doc-length = 88 diff --git a/sonic-o1/01_data_curation/README.md b/sonic-o1/01_data_curation/README.md index 0e4c0f0..daca2a8 100644 --- a/sonic-o1/01_data_curation/README.md +++ b/sonic-o1/01_data_curation/README.md @@ -3,14 +3,134 @@ This directory contains the YouTube video metadata scraping and parsing pipeline for the **sonic-o1** dataset. ## Overview -- The data curation process has two main steps: - 1. **YouTube Metadata Scraping** โ€” Collect video metadata from YouTube based on topics and demographics. - 2. **Topic Parsing** โ€” Process and filter metadata to create **quality-annotated** datasets, then download videos/audio/captions. -- **Execution note (important):** Step 2 requires a **manual quality review** step before you can run `parse_topic.py`. 
+ +### Directory Structure + +``` +01_data_curation/ +โ”œโ”€โ”€ youtube_metadata_scraper.py # Step 1: Scrapes YouTube metadata +โ”œโ”€โ”€ parse_topic.py # Step 3: Parses and downloads videos +โ”œโ”€โ”€ config.yaml # Configuration file for both scripts +โ”œโ”€โ”€ .env # API keys (create this file) +โ”œโ”€โ”€ README.md # This file +โ”‚ +โ”œโ”€โ”€ videos_Unfiltered/ # Output from Step 1 (auto-generated) +โ”‚ โ”œโ”€โ”€ 01_Patient-Doctor_Consultations/ +โ”‚ โ”œโ”€โ”€ 02_Job_Interviews/ +โ”‚ โ””โ”€โ”€ ... +โ”‚ +โ”œโ”€โ”€ videos_QualityAnnotated/ # Output from Step 2 (manual review) +โ”‚ โ”œโ”€โ”€ 01_Patient-Doctor_Consultations/ +โ”‚ โ”œโ”€โ”€ 02_Job_Interviews/ +โ”‚ โ””โ”€โ”€ ... +โ”‚ +โ””โ”€โ”€ dataset/ # Output from Step 3 (auto-generated) + โ”œโ”€โ”€ videos/ + โ”œโ”€โ”€ audios/ + โ””โ”€โ”€ captions/ +``` + +### Pipeline Workflow + +The data curation pipeline consists of four main steps: + +``` +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Step 1: YouTube Metadata Scraping โ”‚ +โ”‚ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ โ”‚ +โ”‚ Input: config.yaml, YouTube Data API โ”‚ +โ”‚ Output: videos_Unfiltered/ โ”‚ +โ”‚ โ”œโ”€โ”€ 01_Patient-Doctor_Consultations/ โ”‚ +โ”‚ โ”‚ โ”œโ”€โ”€ *_metadata.json โ”‚ +โ”‚ โ”‚ โ”œโ”€โ”€ *_metadata.csv โ”‚ +โ”‚ โ”‚ โ””โ”€โ”€ *_summary.json โ”‚ +โ”‚ โ””โ”€โ”€ all_topics_combined.csv โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ†“ 
+โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Step 2: Manual Quality Review (Required) โ”‚ +โ”‚ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ โ”‚ +โ”‚ Input: videos_Unfiltered/ โ”‚ +โ”‚ Output: videos_QualityAnnotated/ โ”‚ +โ”‚ โ”œโ”€โ”€ 01_Patient-Doctor_Consultations/ โ”‚ +โ”‚ โ”‚ โ””โ”€โ”€ *_metadata.json (+ Qualitylabel field) โ”‚ +โ”‚ โ””โ”€โ”€ ... โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ†“ +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Step 3: Topic Parsing and Download โ”‚ +โ”‚ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ โ”‚ +โ”‚ Input: videos_QualityAnnotated/ โ”‚ +โ”‚ Output: dataset/ โ”‚ +โ”‚ โ”œโ”€โ”€ videos/01_Patient-Doctor_Consultations/ โ”‚ +โ”‚ โ”‚ โ”œโ”€โ”€ video_001.mp4 โ”‚ +โ”‚ โ”‚ โ””โ”€โ”€ metadata.json โ”‚ +โ”‚ โ”œโ”€โ”€ audios/01_Patient-Doctor_Consultations/ โ”‚ +โ”‚ โ”‚ โ”œโ”€โ”€ audio_001.m4a โ”‚ +โ”‚ โ”‚ โ””โ”€โ”€ metadata.json โ”‚ +โ”‚ โ””โ”€โ”€ captions/01_Patient-Doctor_Consultations/ โ”‚ +โ”‚ โ”œโ”€โ”€ caption_001.srt โ”‚ +โ”‚ โ”œโ”€โ”€ needs_whisper.txt โ”‚ +โ”‚ โ””โ”€โ”€ metadata.json โ”‚ 
+โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ†“ +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Step 4: Next Steps โ”‚ +โ”‚ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ โ”‚ +โ”‚ โ€ข Check needs_whisper.txt for videos requiring transcription โ”‚ +โ”‚ โ€ข Proceed to 02_* directory for transcription workflow โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ +``` + +### Topics Covered + +The pipeline processes 13 diverse topics: +1. Patient-Doctor Consultations +2. Job Interviews +3. Parent-Teacher Conferences +4. Customer Service Interactions +5. Courtroom Proceedings +6. Emergency Response Scenarios +7. Public Transportation Conflicts +8. Workplace Team Meetings +9. Housing/Apartment Tours +10. Restaurant Service Encounters +11. Mental Health Counseling +12. Community Town Halls +13. 
Olympics + +### Quality Filtering + +The scraper applies research-based quality filters: +- **Duration**: 30 seconds to 60 minutes (configurable) +- **Engagement**: Minimum views, like/comment ratios +- **Clickbait Detection**: Filters extreme clickbait patterns +- **Spam Detection**: Removes spam and low-quality content +- **License**: Filters by Creative Commons license (default) + +### Diversity Considerations + +The pipeline ensures demographic and content diversity: +- **Multi-dimensional search queries**: race, gender, age, language +- **Balanced selection**: across demographics +- **Duration distribution**: 40% short, 40% medium, 20% long + +### Features + +- Collects comprehensive metadata (views, likes, captions, duration, etc.) +- Generates demographically diverse search queries +- Supports incremental collection (reruns add new videos without duplicates) +- Downloads videos, extracts audio, and downloads captions +- Creates diverse selections across demographics and durations + +--- ## Prerequisites -1. **Required packages/environment:** All required Python packages are included in the projectโ€™s `requirements_venv.txt` (see `../../requirements_venv.txt`). -2. **Additional requirements:** `ffmpeg` (required for audio extraction; install separately) + +1. **Install Python packages** + - All required Python packages are included in `../../requirements_venv.txt` + +2. **Install ffmpeg** (required for audio extraction) ```bash # Linux sudo apt-get install ffmpeg @@ -20,209 +140,145 @@ This directory contains the YouTube video metadata scraping and parsing pipeline # Conda conda install -c conda-forge ffmpeg + ``` -3. **API setup:** Get a YouTube Data API v3 key (Google Cloud Console) and create a `.env` file in this directory. - +3. **Set up YouTube Data API v3 key** + - Get an API key from Google Cloud Console + - Create a `.env` file in the `01_data_curation` directory: + ```bash + YT_SCRAP_API=your_youtube_api_key_here + ``` + +4. 
**Configure settings** + - Edit `config.yaml` to customize: + - API rate limits and search parameters + - Directory paths + - Video filtering criteria (duration, quality, demographics) + - Collection targets (videos per topic/query) + +## Step 1: YouTube Metadata Scraping + +1. **Choose topics** (optional) + - Default: processes all 13 topics + - To process specific topics, edit the loop in `youtube_metadata_scraper.py` (around lines 995-1000): + ```python + for topic_id in range(1, 13): + ``` + - For free-tier API usage, 3 topics with 25 videos per run (one run per day) is optimal + +2. **Configure parameters in `config.yaml`** (optional) + - `videos_per_topic`: target videos per topic (default: 100) + - `videos_per_query`: videos per search query (default: 15) + - `video_duration`: "short" | "medium" | "long" | "any" + - `years_back`: how many years back to search (default: 5) + - `video_license`: "creativeCommon" | "any" + +3. **Run the scraper** ```bash - # Use the environment variable name expected by your codebase. - YT_SCRAP_API=your_youtube_api_key_here + python youtube_metadata_scraper.py ``` -4. **Configuration:** Edit `config.yaml` to customize: - * API rate limits and search parameters - * Directory paths - * Video filtering criteria (duration, quality, demographics) - * Collection targets (videos per topic/query) +4. **Output** + ``` + videos_Unfiltered/ + ├── 01_Patient-Doctor_Consultations/ + │ ├── Patient-Doctor_Consultations_metadata.json + │ ├── Patient-Doctor_Consultations_metadata.csv + │ └── Patient-Doctor_Consultations_summary.json + ├── 02_Job_Interviews/ + │ └── ... + └── all_topics_combined.csv + ``` -## Execution (recommended workflow) +## Step 2: Manual Quality Review (Required) -1. **Configure** +1. 
**Create quality-annotated directory** + - Use tools in `../huggingface_review_template/` + - Create `videos_QualityAnnotated/` following the template structure - * Edit `config.yaml` with your preferences. -2. **Run Step 1: scrape metadata** +2. **Review and annotate videos** + - Manually review Step 1 outputs from `videos_Unfiltered/` + - Add a `Qualitylabel` field to each video in the per-topic metadata JSON + - Mark high-quality videos as `Qualitylabel: "Good"` - ```bash - python youtube_metadata_scraper.py +3. **Output** + ``` + videos_QualityAnnotated/ + โ”œโ”€โ”€ 01_Patient-Doctor_Consultations/ + โ”‚ โ””โ”€โ”€ Patient-Doctor_Consultations_metadata.json # With Qualitylabel field + โ”œโ”€โ”€ 02_Job_Interviews/ + โ”‚ โ””โ”€โ”€ Job_Interviews_metadata.json + โ””โ”€โ”€ ... ``` - * Output: `videos_Unfiltered/` containing per-topic metadata files + `all_topics_combined.csv`. -3. **Manual quality review (required before Step 2)** +## Step 3: Topic Parsing and Download - * Use tools in `../huggingface_review_template/`. - * Create `videos_QualityAnnotated/` following the template structure. - * Add a `Qualitylabel` field to each video in the per-topic metadata JSON. - * The parsing step filters for `Qualitylabel == "Good"`. -4. **Run Step 2: parse + download** +1. **Configure maximum videos per topic** (optional) + - Edit `MAX_COUNT` in `parse_topic.py` (around line 541): + ```python + MAX_COUNT = 25 + ``` +2. **Run the parser** ```bash python parse_topic.py ``` - * Output: `dataset/` containing `videos/`, `audios/`, and `captions/`. -5. **Next steps** - - * Videos missing captions are listed in `dataset/captions/*/needs_whisper.txt`. - * Proceed to the transcription workflow (see `02_` directory). - -## Step 1 โ€” YouTube Metadata Scraping (`youtube_metadata_scraper.py`) - -* Purpose: Collect video metadata from YouTube across 13 topics with demographic diversity. 
-* Features: - - * Searches across multiple topics (Patient-Doctor Consultations, Job Interviews, etc.) - * Generates demographically diverse queries (race, gender, age, language) - * Filters for quality using engagement metrics and clickbait detection - * Collects comprehensive metadata (views, likes, captions, duration, etc.) - * Supports incremental collection (adds new videos without duplicates) -* Run: - - ```bash - python youtube_metadata_scraper.py - ``` -* Key `config.yaml` settings: - - * `videos_per_topic`: target videos per topic (default: 100) - * `videos_per_query`: videos per search query (default: 15) - * `video_duration`: "short" | "medium" | "long" | "any" - * `years_back`: how many years back to search (default: 5) - * `video_license`: "creativeCommon" | "any" -* Output structure: - - ```text - videos_Unfiltered/ - โ”œโ”€โ”€ 01_Patient-Doctor_Consultations/ - โ”‚ โ”œโ”€โ”€ Patient-Doctor_Consultations_metadata.json - โ”‚ โ”œโ”€โ”€ Patient-Doctor_Consultations_metadata.csv - โ”‚ โ””โ”€โ”€ Patient-Doctor_Consultations_summary.json - โ”œโ”€โ”€ 02_Job_Interviews/ - โ”‚ โ””โ”€โ”€ ... - โ””โ”€โ”€ all_topics_combined.csv - ``` -* Batch processing: - - * The main function processes topics in ranges (edit the loop near the bottom of the script, e.g. 
lines ~995โ€“1000): - - ```python - for topic_id in range(1, 13): - ``` - -## Step 2 โ€” Topic Parsing (`parse_topic.py`) - -* **Prerequisite reminder:** before running, you must: - - * Create `videos_QualityAnnotated/` following `../huggingface_review_template/` - * Manually review/annotate Step 1 outputs - * Add quality labels in the JSON (`Qualitylabel: "Good"` for selected videos) -* What it does: - - * Loads quality-annotated metadata from `videos_QualityAnnotated/` - * Filters videos by: - - * `Qualitylabel == "Good"` - * `copyright_notice == "creativeCommon"` - * Language: English audio or default language - * Downloads videos, extracts audio, downloads captions - * Creates diverse selections across demographics and durations - * Supports incremental additions (default max 25 videos/topic) -* Required input structure: - - ```text - videos_QualityAnnotated/ - โ”œโ”€โ”€ 01_Patient-Doctor_Consultations/ - โ”‚ โ””โ”€โ”€ Patient-Doctor_Consultations_metadata.json # With Qualitylabel field added - โ”œโ”€โ”€ 02_Job_Interviews/ - โ”‚ โ””โ”€โ”€ Job_Interviews_metadata.json - โ””โ”€โ”€ ... - ``` -* Run: - - ```bash - python parse_topic.py - ``` -* Max videos/topic: - - * Edit `MAX_COUNT` in `parse_topic.py` (around line ~541): - - ```python - MAX_COUNT = 25 - ``` -* Output structure: - - ```text - dataset/ - โ”œโ”€โ”€ videos/ - โ”‚ โ””โ”€โ”€ 01_Patient-Doctor_Consultations/ - โ”‚ โ”œโ”€โ”€ video_001.mp4 - โ”‚ โ”œโ”€โ”€ video_002.mp4 - โ”‚ โ””โ”€โ”€ metadata.json - โ”œโ”€โ”€ audios/ - โ”‚ โ””โ”€โ”€ 01_Patient-Doctor_Consultations/ - โ”‚ โ”œโ”€โ”€ audio_001.m4a - โ”‚ โ”œโ”€โ”€ audio_002.m4a - โ”‚ โ””โ”€โ”€ metadata.json - โ””โ”€โ”€ captions/ - โ””โ”€โ”€ 01_Patient-Doctor_Consultations/ - โ”œโ”€โ”€ caption_001.srt - โ”œโ”€โ”€ caption_002.srt - โ”œโ”€โ”€ needs_whisper.txt - โ””โ”€โ”€ metadata.json - ``` - -## Topics covered - -1. Patient-Doctor Consultations -2. Job Interviews -3. Parent-Teacher Conferences -4. Customer Service Interactions -5. Courtroom Proceedings -6. 
Emergency Response Scenarios -7. Public Transportation Conflicts -8. Workplace Team Meetings -9. Housing/Apartment Tours -10. Restaurant Service Encounters -11. Mental Health Counseling -12. Community Town Halls -13. Olympics +3. **Output** + ``` + dataset/ + โ”œโ”€โ”€ videos/ + โ”‚ โ””โ”€โ”€ 01_Patient-Doctor_Consultations/ + โ”‚ โ”œโ”€โ”€ video_001.mp4 + โ”‚ โ”œโ”€โ”€ video_002.mp4 + โ”‚ โ””โ”€โ”€ metadata.json + โ”œโ”€โ”€ audios/ + โ”‚ โ””โ”€โ”€ 01_Patient-Doctor_Consultations/ + โ”‚ โ”œโ”€โ”€ audio_001.m4a + โ”‚ โ”œโ”€โ”€ audio_002.m4a + โ”‚ โ””โ”€โ”€ metadata.json + โ””โ”€โ”€ captions/ + โ””โ”€โ”€ 01_Patient-Doctor_Consultations/ + โ”œโ”€โ”€ caption_001.srt + โ”œโ”€โ”€ caption_002.srt + โ”œโ”€โ”€ needs_whisper.txt + โ””โ”€โ”€ metadata.json + ``` -## Quality filtering +## Step 4: Next Steps -* Duration: 30 seconds to 60 minutes (configurable) -* Engagement: minimum views, like/comment ratios -* Clickbait detection: filters extreme patterns -* Spam detection: removes spam/low-quality content -* License: Creative Commons by default +1. **Check for videos needing transcription** + - Videos missing captions are listed in `dataset/captions/*/needs_whisper.txt` -## Diversity considerations +2. 
**Proceed to transcription workflow** + - See `02_` directory for transcription steps -* Multi-dimensional search queries (race, gender, age, language) -* Balanced selection across demographics -* Duration distribution target (40% short, 40% medium, 20% long) +--- ## Troubleshooting -* “No videos meet the criteria” - - * Ensure annotated metadata has `Qualitylabel: "Good"` - * Verify `copyright_notice == "creativeCommon"` - * Check language fields (`default_language` / `default_audio_language` include `"en"`) -* API rate limiting +### "No videos meet the criteria" +- Ensure annotated metadata has `Qualitylabel: "Good"` +- Verify `copyright_notice == "creativeCommon"` +- Check language fields (`default_language` / `default_audio_language` include `"en"`) - * Increase `rate_limit_delay` in `config.yaml` - * Process topics in smaller batches -* ffmpeg not found +### API rate limiting +- Increase `rate_limit_delay` in `config.yaml` +- Process topics in smaller batches - * Install via the commands listed under **Prerequisites → Additional requirements** +### ffmpeg not found +- Install via the commands listed under the Prerequisites section ## Files -* `youtube_metadata_scraper.py` — main YouTube metadata collection -* `parse_topic.py` — download + processing -* `config.yaml` — configuration for both scripts -* `.env` — environment variables (API keys) +- `youtube_metadata_scraper.py` — main YouTube metadata collection +- `parse_topic.py` — download and processing +- `config.yaml` — configuration for both scripts +- `.env` — environment variables (API keys) ## Notes -* Incremental collection is supported (reruns add new videos without duplicates). -* Quality filtering is intended to support research-grade dataset integrity. -* Respect YouTube Terms of Service and copyright laws. -* Video processing is storage/compute intensive; set `MAX_COUNT` accordingly. 
+- Incremental collection is supported (reruns add new videos without duplicates) +- Quality filtering is intended to support research-grade dataset integrity +- Respect YouTube Terms of Service and copyright laws +- Video processing is storage/compute intensive; set `MAX_COUNT` accordingly diff --git a/sonic-o1/01_data_curation/config.yaml b/sonic-o1/01_data_curation/config.yaml index df4e756..fefa194 100644 --- a/sonic-o1/01_data_curation/config.yaml +++ b/sonic-o1/01_data_curation/config.yaml @@ -20,7 +20,7 @@ search_settings: include_caption_text: true remove_duplicates: true years_back: 5 - video_license: 'creativeCommon' + video_license: 'creativeCommon' # Duration categories reference duration_categories: @@ -49,4 +49,153 @@ demographics: - Hindi - Arabic - Spanish - - Chinese \ No newline at end of file + - Chinese + +# Topics for video scraping +# Each topic has search terms that can be tuned/modified +topics: + 1: + name: "Patient-Doctor Consultations" + search_terms: + - "doctor patient conversation full session" + - "clinic consultation recording" + - "telehealth visit recording" + - "primary care consultation unedited" + - "medical intake interview full visit" + focus: "Medical communication, empathy, diagnosis discussions" + channel_id: null + + 2: + name: "Job Interviews" + search_terms: + - "panel interview full interview" + - "candidate interview recording" + - "on site interview hiring manager" + - "technical interview session full" + - "HR screening call recording" + focus: "Professional interaction, emotion detection, body language" + channel_id: null + + 3: + name: "Parent-Teacher Conferences" + search_terms: + - "parent teacher conference recording" + - "PTC meeting full session" + - "student progress meeting recording" + - "IEP meeting full" + - "teacher parent meeting unedited" + focus: "Educational settings, conflict resolution, child advocacy" + channel_id: null + + 4: + name: "Customer Service Interactions" + search_terms: + - "front desk 
dispute" + - "restaurant customer service" + - "Customer Service SNL" + - "employee vs customer -karen -compilation" + - "store manager customer service" + - "angry customer store footage -staged" + focus: "Complaint handling, emotion regulation, problem-solving" + channel_id: null + + 5: + name: "Courtroom Proceedings" + search_terms: + - "oral argument" + - '"sentencing hearing" "full recording" courtroom' + - "municipal court arraignment calendar session full" + - '"sentencing hearing" full recording courtroom' + - "small claims court full hearing official recording" + - "mock trial full" + focus: "Legal settings, testimony analysis, fairness assessment" + channel_id: null + + 6: + name: "Emergency Response Scenarios" + search_terms: + - 'firefighters highway incident "full response" -shorts -news' + - "bodycam police" + - "bodycam police rescue" + - "Law&Crime BodyCam" + - "real emergency calls paramedic" + focus: "Crisis management, first aid, triage decisions" + channel_id: null + + 7: + name: "Public Transportation Conflicts" + search_terms: + - "bus passenger fight driver -news -compilation" + - "bar fight" + - "train passenger confrontation cctv" + - "airport security passenger meltdown bodycam" + - "grocery store argument customer" + - "parking lot dispute road rage" + focus: "Social etiquette, accessibility, conflict de-escalation" + channel_id: null + + 8: + name: "Workplace Team Meetings" + search_terms: + - '"team meeting" recording Zoom -webinar -tutorial -class' + - "daily standup meeting" + - "scrum meeting real team -demo -example" + - "sprint review meeting" + - 'workplace "meeting" recording' + focus: "Collaboration, leadership dynamics, idea contribution" + channel_id: null + + 9: + name: "Housing/Apartment Tours" + search_terms: + - "open house walkthrough agent client" + - "apartment tour with agent full" + - "rental inspection landlord tenant recording" + - "home showing buyer walkthrough" + - "accessible apartment tour elevator ramp" + 
focus: "Real estate interactions, accessibility features" + channel_id: null + + 10: + name: "Restaurant Service Encounters" + search_terms: + - "restaurant vlog" + - "waitress day in the life" + - "restaurant behind the scenes" + - "food service worker" + - "restaurant review visit" + focus: "Service quality, complaint handling, accessibility" + channel_id: null + + 11: + name: "Mental Health Counseling" + search_terms: + - "counseling session demonstration" + - "therapy role play training" + - "mock therapy session psychology" + - "counseling techniques demonstration video" + - "therapeutic communication examples" + focus: "Therapeutic alliance, emotional support, crisis intervention" + channel_id: null + + 12: + name: "Community Town Halls" + search_terms: + - '"town hall" "full recording" community Q&A' + - "town hall meeting complete" + - '"city council" meeting "livestream archive" -highlights -clips' + - "community meeting local government" + - '"Islamic center" community forum full -news -compilation' + focus: "Civic engagement, diverse viewpoints, accessibility" + channel_id: null + + 13: + name: "Olympics" + search_terms: + - "olympic games highlights" + - "summer olympics events" + - "winter olympics full coverage" + - "olympic moments compilation" + - "olympics replay full event" + focus: "Sports videos" + channel_id: null diff --git a/sonic-o1/01_data_curation/parse_topic.py b/sonic-o1/01_data_curation/parse_topic.py index b9899e7..23a164a 100644 --- a/sonic-o1/01_data_curation/parse_topic.py +++ b/sonic-o1/01_data_curation/parse_topic.py @@ -1,30 +1,16 @@ -"""Topic Parsing and Video Download Script. - -IMPORTANT PREREQUISITE: -Before running this script, you MUST: -1. Create a 'videos_QualityAnnotated' directory in your base directory -2. Structure it following the template in ../huggingface_review_template/ -3. Manually review and quality-annotate metadata from - youtube_metadata_scraper.py output -4. 
Ensure each video in the metadata has a 'Qualitylabel' field set to - "Good" for videos that should be included in the final dataset - -This script filters videos based on: -- Qualitylabel == "Good" -- copyright_notice == "creativeCommon" -- Language: English (default_language or default_audio_language) - -The script will: -- Load metadata from videos_QualityAnnotated//_metadata.json -- Download filtered videos, extract audio, and download captions -- Create a balanced dataset across demographics and durations -- Track videos needing Whisper transcription in needs_whisper.txt files +"""parse_topic.py. + +Processes quality-annotated video metadata by filtering, downloading videos, +extracting audio, and creating a balanced dataset across demographics and durations. + +Author: SONIC-O1 Team """ + import json import os -import random import subprocess import sys +import traceback from collections import defaultdict from datetime import datetime from pathlib import Path @@ -62,15 +48,14 @@ def create_directories(self, topic_name: str): topic_dir.mkdir(parents=True, exist_ok=True) print(f"โœ“ Created directories for {topic_name}") - def get_existing_video_info( - self, topic_name: str - ) -> tuple[int, Set[str], int]: + def get_existing_video_info(self, topic_name: str) -> tuple[int, Set[str], int]: """Get information about existing videos in a topic. Args: topic_name: Name of the topic. - Returns: + Returns + ------- Tuple of (current_count, existing_video_ids, next_video_number). 
""" metadata_path = self.videos_dir / topic_name / "metadata.json" @@ -79,26 +64,21 @@ def get_existing_video_info( return 0, set(), 1 try: - with open(metadata_path, 'r') as f: + with open(metadata_path, "r") as f: existing_metadata = json.load(f) current_count = len(existing_metadata) - existing_video_ids = { - video['video_id'] for video in existing_metadata - } + existing_video_ids = {video["video_id"] for video in existing_metadata} video_numbers = [] for video in existing_metadata: - video_num_str = video.get('video_number', '000') + video_num_str = video.get("video_number", "000") try: video_numbers.append(int(video_num_str)) except ValueError: continue - next_number = ( - max(video_numbers) + 1 if video_numbers - else current_count + 1 - ) + next_number = max(video_numbers) + 1 if video_numbers else current_count + 1 print( f" Found {current_count} existing videos, " @@ -116,10 +96,11 @@ def load_metadata(self, metadata_path: str) -> List[Dict]: Args: metadata_path: Path to metadata file. - Returns: + Returns + ------- List of video metadata dictionaries. """ - with open(metadata_path, 'r') as f: + with open(metadata_path, "r") as f: return json.load(f) def filter_videos( @@ -131,168 +112,234 @@ def filter_videos( metadata: List of video metadata. existing_video_ids: Set of already-downloaded video IDs. - Returns: + Returns + ------- Filtered list of video metadata. 
""" filtered = [] excluded_existing = 0 for video in metadata: - if video.get('video_id') in existing_video_ids: + if video.get("video_id") in existing_video_ids: excluded_existing += 1 continue - if video.get('Qualitylabel') != 'Good': + if video.get("Qualitylabel") != "Good": continue - if video.get('copyright_notice') != 'creativeCommon': + if video.get("copyright_notice") != "creativeCommon": continue - default_lang = video.get('default_language', '').lower() - audio_lang = video.get('default_audio_language', '').lower() - if default_lang != 'en' and audio_lang != 'en': + default_lang = video.get("default_language", "").lower() + audio_lang = video.get("default_audio_language", "").lower() + if default_lang != "en" and audio_lang != "en": continue filtered.append(video) - print( - f"โœ“ Filtered {len(filtered)} new videos from " - f"{len(metadata)} total" - ) + print(f"โœ“ Filtered {len(filtered)} new videos from {len(metadata)} total") print(f" (Excluded {excluded_existing} already-downloaded videos)") return filtered - def select_videos( - self, filtered_videos: List[Dict], needed_count: int, - start_number: int - ) -> List[Dict]: - """Select videos with maximum diversity. + def _calculate_duration_distribution( + self, filtered_videos: List[Dict] + ) -> Dict[str, Dict[str, List[Dict]]]: + """Group videos by duration category and demographic label. Args: filtered_videos: List of filtered video metadata. + + Returns + ------- + Nested dict: duration_category -> demographic_label -> list of videos. 
+ """ + duration_demo_groups = defaultdict(lambda: defaultdict(list)) + + for video in filtered_videos: + duration_cat = video.get("duration_category", "unknown") + demo_label = video.get("demographic_label", "general") + duration_demo_groups[duration_cat][demo_label].append(video) + + return duration_demo_groups + + def _calculate_target_counts( + self, needed_count: int, duration_demo_groups: Dict[str, Dict[str, List[Dict]]] + ) -> Dict[str, int]: + """Calculate initial target distribution across duration categories. + + Args: needed_count: Number of videos needed. - start_number: Starting video number. + duration_demo_groups: Grouped videos by duration and demographics. - Returns: - Selected videos with diversity across demographics and duration. + Returns + ------- + Dict mapping duration categories to target counts. """ - if len(filtered_videos) == 0: - print("โœ“ No new videos to select") - return [] + target_distribution = { + "short": int(needed_count * 0.4), + "medium": int(needed_count * 0.4), + "long": int(needed_count * 0.2), + } - if len(filtered_videos) <= needed_count: - selected = filtered_videos - print(f"โœ“ Selected all {len(selected)} available videos") - else: - print( - f" Selecting {needed_count} from {len(filtered_videos)} " - f"videos with diversity optimization..." 
- ) + available_counts = { + cat: sum(len(demos) for demos in groups.values()) + for cat, groups in duration_demo_groups.items() + } - duration_demo_groups = defaultdict(lambda: defaultdict(list)) + print( + f" Available: Short={available_counts.get('short', 0)}, " + f"Medium={available_counts.get('medium', 0)}, " + f"Long={available_counts.get('long', 0)}" + ) - for video in filtered_videos: - duration_cat = video.get('duration_category', 'unknown') - demo_label = video.get('demographic_label', 'general') - duration_demo_groups[duration_cat][demo_label].append(video) + adjusted_targets = {} + for cat in ["short", "medium", "long"]: + available = available_counts.get(cat, 0) + target = target_distribution.get(cat, 0) + adjusted_targets[cat] = min(target, available) - target_distribution = { - 'short': int(needed_count * 0.4), - 'medium': int(needed_count * 0.4), - 'long': int(needed_count * 0.2), - } + return adjusted_targets - available_counts = { - cat: sum(len(demos) for demos in groups.values()) - for cat, groups in duration_demo_groups.items() - } + def _adjust_targets_for_availability( + self, + adjusted_targets: Dict[str, int], + needed_count: int, + duration_demo_groups: Dict[str, Dict[str, List[Dict]]], + ) -> Dict[str, int]: + """Redistribute remaining slots if targets couldn't be met. - print( - f" Available: Short={available_counts.get('short', 0)}, " - f"Medium={available_counts.get('medium', 0)}, " - f"Long={available_counts.get('long', 0)}" - ) + Args: + adjusted_targets: Initial adjusted target counts. + needed_count: Total number of videos needed. + duration_demo_groups: Grouped videos by duration and demographics. + + Returns + ------- + Final adjusted target counts after redistribution. 
+ """ + available_counts = { + cat: sum(len(demos) for demos in groups.values()) + for cat, groups in duration_demo_groups.items() + } + + total_assigned = sum(adjusted_targets.values()) + remaining = needed_count - total_assigned - adjusted_targets = {} - for cat in ['short', 'medium', 'long']: + if remaining > 0: + for cat in ["medium", "short", "long"]: available = available_counts.get(cat, 0) - target = target_distribution.get(cat, 0) - adjusted_targets[cat] = min(target, available) - - total_assigned = sum(adjusted_targets.values()) - remaining = needed_count - total_assigned - - if remaining > 0: - for cat in ['medium', 'short', 'long']: - available = available_counts.get(cat, 0) - current = adjusted_targets.get(cat, 0) - can_add = min(remaining, available - current) - if can_add > 0: - adjusted_targets[cat] = ( - adjusted_targets.get(cat, 0) + can_add - ) - remaining -= can_add - if remaining == 0: - break + current = adjusted_targets.get(cat, 0) + can_add = min(remaining, available - current) + if can_add > 0: + adjusted_targets[cat] = adjusted_targets.get(cat, 0) + can_add + remaining -= can_add + if remaining == 0: + break - print( - f" Target: Short={adjusted_targets.get('short', 0)}, " - f"Medium={adjusted_targets.get('medium', 0)}, " - f"Long={adjusted_targets.get('long', 0)}" - ) + print( + f" Target: Short={adjusted_targets.get('short', 0)}, " + f"Medium={adjusted_targets.get('medium', 0)}, " + f"Long={adjusted_targets.get('long', 0)}" + ) - selected = [] + return adjusted_targets - for duration_cat in ['short', 'medium', 'long']: - target_count = adjusted_targets.get(duration_cat, 0) - if target_count == 0: - continue + def _select_videos_round_robin( + self, + duration_demo_groups: Dict[str, Dict[str, List[Dict]]], + adjusted_targets: Dict[str, int], + ) -> List[Dict]: + """Select videos using round-robin across demographic groups. 
- demo_groups = duration_demo_groups.get(duration_cat, {}) - if not demo_groups: - continue + Args: + duration_demo_groups: Grouped videos by duration and demographics. + adjusted_targets: Target counts for each duration category. - demo_labels = list(demo_groups.keys()) + Returns + ------- + List of selected videos. + """ + selected = [] - videos_selected = 0 - demo_index = 0 + for duration_cat in ["short", "medium", "long"]: + target_count = adjusted_targets.get(duration_cat, 0) + if target_count == 0: + continue - while videos_selected < target_count: - demo_label = demo_labels[demo_index % len(demo_labels)] + demo_groups = duration_demo_groups.get(duration_cat, {}) + if not demo_groups: + continue - if demo_groups[demo_label]: - video = demo_groups[demo_label].pop(0) - selected.append(video) - videos_selected += 1 + demo_labels = list(demo_groups.keys()) - demo_index += 1 + videos_selected = 0 + demo_index = 0 - if all( - len(videos) == 0 - for videos in demo_groups.values() - ): - break + while videos_selected < target_count: + demo_label = demo_labels[demo_index % len(demo_labels)] - if len(selected) < needed_count: - remaining_videos = [] - for duration_cat in duration_demo_groups.values(): - for demo_videos in duration_cat.values(): - remaining_videos.extend(demo_videos) + if demo_groups[demo_label]: + video = demo_groups[demo_label].pop(0) + selected.append(video) + videos_selected += 1 - needed_more = needed_count - len(selected) - selected.extend(remaining_videos[:needed_more]) + demo_index += 1 - print(f"โœ“ Selected {len(selected)} videos with diversity") + if all(len(videos) == 0 for videos in demo_groups.values()): + break + + return selected + + def _fill_remaining_slots( + self, + selected: List[Dict], + needed_count: int, + duration_demo_groups: Dict[str, Dict[str, List[Dict]]], + ) -> List[Dict]: + """Fill any remaining slots with available videos. + + Args: + selected: Currently selected videos. + needed_count: Total number of videos needed. 
+ duration_demo_groups: Grouped videos by duration and demographics. + Returns + ------- + Updated list of selected videos. + """ + if len(selected) < needed_count: + remaining_videos = [] + for duration_cat in duration_demo_groups.values(): + for demo_videos in duration_cat.values(): + remaining_videos.extend(demo_videos) + + needed_more = needed_count - len(selected) + selected.extend(remaining_videos[:needed_more]) + + return selected + + def _assign_numbers_and_stats( + self, selected: List[Dict], start_number: int + ) -> List[Dict]: + """Assign video numbers and print final statistics. + + Args: + selected: List of selected videos. + start_number: Starting video number. + + Returns + ------- + Updated list of selected videos with numbers assigned. + """ for idx, video in enumerate(selected): - video['video_number'] = f"{start_number + idx:03d}" + video["video_number"] = f"{start_number + idx:03d}" duration_counts = defaultdict(int) demo_counts = defaultdict(int) for video in selected: - duration_counts[video.get('duration_category', 'unknown')] += 1 - demo_counts[video.get('demographic_label', 'general')] += 1 + duration_counts[video.get("duration_category", "unknown")] += 1 + demo_counts[video.get("demographic_label", "general")] += 1 print(" Final distribution:") print(f" Duration: {dict(duration_counts)}") @@ -300,6 +347,55 @@ def select_videos( return selected + def select_videos( + self, filtered_videos: List[Dict], needed_count: int, start_number: int + ) -> List[Dict]: + """Select videos with maximum diversity. + + Args: + filtered_videos: List of filtered video metadata. + needed_count: Number of videos needed. + start_number: Starting video number. + + Returns + ------- + Selected videos with diversity across demographics and duration. 
+ """ + if len(filtered_videos) == 0: + print("โœ“ No new videos to select") + return [] + + if len(filtered_videos) <= needed_count: + selected = filtered_videos + print(f"โœ“ Selected all {len(selected)} available videos") + else: + print( + f" Selecting {needed_count} from {len(filtered_videos)} " + f"videos with diversity optimization..." + ) + + duration_demo_groups = self._calculate_duration_distribution( + filtered_videos + ) + adjusted_targets = self._calculate_target_counts( + needed_count, duration_demo_groups + ) + adjusted_targets = self._adjust_targets_for_availability( + adjusted_targets, needed_count, duration_demo_groups + ) + + selected = self._select_videos_round_robin( + duration_demo_groups, adjusted_targets + ) + selected = self._fill_remaining_slots( + selected, needed_count, duration_demo_groups + ) + + print(f"โœ“ Selected {len(selected)} videos with diversity") + + return self._assign_numbers_and_stats(selected, start_number) + + def download_video(self, video_id: str, output_path: str) -> bool: """Download video using yt-dlp (max 1080p). @@ -307,26 +403,23 @@ def download_video(self, video_id: str, output_path: str) -> bool: video_id: YouTube video ID. output_path: Path to save video. - Returns: + Returns + ------- True if successful, False otherwise. 
""" url = f"https://www.youtube.com/watch?v={video_id}" cookies_path = self.base_dir / "cookies.txt" ydl_opts = { - 'format': ( - 'bestvideo[height<=1080][ext=mp4]+' - 'bestaudio[ext=m4a]/best[height<=1080][ext=mp4]' + "format": ( + "bestvideo[height<=1080][ext=mp4]+" + "bestaudio[ext=m4a]/best[height<=1080][ext=mp4]" ), - 'outtmpl': output_path, - 'merge_output_format': 'mp4', - 'quiet': True, - 'no_warnings': True, - 'cookiefile': str(cookies_path), - 'extractor_args': { - 'youtube': { - 'player_client': ['default', '-tv'] - } - }, + "outtmpl": output_path, + "merge_output_format": "mp4", + "quiet": True, + "no_warnings": True, + "cookiefile": str(cookies_path), + "extractor_args": {"youtube": {"player_client": ["default", "-tv"]}}, } try: @@ -344,19 +437,25 @@ def extract_audio(self, video_path: str, audio_path: str) -> bool: video_path: Path to video file. audio_path: Path to save audio. - Returns: + Returns + ------- True if successful, False otherwise. """ cmd = [ - 'ffmpeg', - '-i', video_path, - '-vn', - '-acodec', 'aac', - '-ar', '48000', - '-ac', '2', - '-ab', '192k', - '-y', - audio_path + "ffmpeg", + "-i", + video_path, + "-vn", + "-acodec", + "aac", + "-ar", + "48000", + "-ac", + "2", + "-ab", + "192k", + "-y", + audio_path, ] try: @@ -379,27 +478,24 @@ def download_captions(self, video_id: str, output_path: str) -> bool: video_id: YouTube video ID. output_path: Path to save captions. - Returns: + Returns + ------- True if successful, False otherwise. 
""" url = f"https://www.youtube.com/watch?v={video_id}" cookies_path = self.base_dir / "cookies.txt" - output_base = output_path.replace('.srt', '') + output_base = output_path.replace(".srt", "") ydl_opts = { - 'writesubtitles': True, - 'subtitleslangs': ['en'], - 'subtitlesformat': 'srt', - 'skip_download': True, - 'outtmpl': output_base, - 'quiet': True, - 'no_warnings': True, - 'cookiefile': str(cookies_path), - 'extractor_args': { - 'youtube': { - 'player_client': ['default', '-tv'] - } - }, + "writesubtitles": True, + "subtitleslangs": ["en"], + "subtitlesformat": "srt", + "skip_download": True, + "outtmpl": output_base, + "quiet": True, + "no_warnings": True, + "cookiefile": str(cookies_path), + "extractor_args": {"youtube": {"player_client": ["default", "-tv"]}}, } try: with yt_dlp.YoutubeDL(ydl_opts) as ydl: @@ -420,22 +516,21 @@ def download_captions(self, video_id: str, output_path: str) -> bool: except Exception: return False - def merge_metadata( - self, topic_name: str, new_videos: List[Dict] - ) -> List[Dict]: + def merge_metadata(self, topic_name: str, new_videos: List[Dict]) -> List[Dict]: """Merge new videos with existing metadata. Args: topic_name: Name of the topic. new_videos: List of new video metadata. - Returns: + Returns + ------- Combined metadata list. """ metadata_path = self.videos_dir / topic_name / "metadata.json" if metadata_path.exists(): - with open(metadata_path, 'r') as f: + with open(metadata_path, "r") as f: existing_metadata = json.load(f) combined = existing_metadata + new_videos @@ -444,121 +539,193 @@ def merge_metadata( f"{len(new_videos)} new = {len(combined)} total videos" ) return combined - else: - return new_videos + return new_videos - def process_topic(self, topic_metadata_path: str, force: bool = False): - """Main processing function for a topic. + def _validate_and_prepare_topic( + self, topic_metadata_path: str, force: bool = False + ) -> tuple: + """Validate topic can be processed and prepare initial data. 
Args: topic_metadata_path: Path to topic metadata file. force: If True, process even if at max count. + + Returns + ------- + Tuple of (topic_name, current_count, existing_video_ids, next_number, + needed_count) or (None, 0, set(), 0, 0) if should skip. """ topic_path = Path(topic_metadata_path).parent topic_name = topic_path.name - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"Processing Topic: {topic_name}") - print(f"{'='*60}\n") + print(f"{'=' * 60}\n") - (current_count, existing_video_ids, - next_number) = self.get_existing_video_info(topic_name) + (current_count, existing_video_ids, next_number) = self.get_existing_video_info( + topic_name + ) if current_count >= self.max_count and not force: print( - f"Topic already has {current_count}/{self.max_count} " - f"videos - SKIPPING" + f"Topic already has {current_count}/{self.max_count} videos - SKIPPING" ) - return + return (None, 0, set(), 0, 0) needed_count = self.max_count - current_count print(f" Current: {current_count}/{self.max_count} videos") print(f" Need to add: {needed_count} videos\n") - self.create_directories(topic_name) + return ( + topic_name, + current_count, + existing_video_ids, + next_number, + needed_count, + ) - metadata = self.load_metadata(topic_metadata_path) - filtered = self.filter_videos(metadata, existing_video_ids) + def _load_and_filter_videos( + self, topic_metadata_path: str, existing_video_ids: Set[str] + ) -> List[Dict]: + """Load topic metadata and filter to get candidate videos. - if len(filtered) == 0: - print("No new videos meet the criteria!") - return + Args: + topic_metadata_path: Path to topic metadata file. + existing_video_ids: Set of already-downloaded video IDs. - selected = self.select_videos(filtered, needed_count, next_number) + Returns + ------- + List of filtered video metadata. 
+ """ + metadata = self.load_metadata(topic_metadata_path) + return self.filter_videos(metadata, existing_video_ids) - if len(selected) == 0: - print("No videos were selected!") - return + def _download_video_assets(self, video: Dict, topic_name: str) -> str | None: + """Download video, extract audio, and download captions for a single video. - needs_whisper = [] + Args: + video: Video metadata dictionary. + topic_name: Name of the topic. - whisper_file = self.captions_dir / topic_name / "needs_whisper.txt" - if whisper_file.exists(): - with open(whisper_file, 'r') as f: - needs_whisper = [line.strip() for line in f.readlines()] + Returns + ------- + Audio filename if needs Whisper transcription, None otherwise. + """ + video_id = video["video_id"] + video_num = video["video_number"] - for video in selected: - video_id = video['video_id'] - video_num = video['video_number'] + print(f"\n--- Processing video {video_num}/{self.max_count}: {video_id} ---") - print( - f"\n--- Processing video {video_num}/{self.max_count}: " - f"{video_id} ---" - ) + video_path = self.videos_dir / topic_name / f"video_{video_num}.mp4" + audio_path = self.audios_dir / topic_name / f"audio_{video_num}.m4a" + caption_path = self.captions_dir / topic_name / f"caption_{video_num}.srt" - video_path = ( - self.videos_dir / topic_name / f"video_{video_num}.mp4" - ) - audio_path = ( - self.audios_dir / topic_name / f"audio_{video_num}.m4a" - ) - caption_path = ( - self.captions_dir / topic_name / f"caption_{video_num}.srt" - ) + print("Downloading video...") + if self.download_video(video_id, str(video_path)): + print(f"Video downloaded: {video_path.name}") - print("Downloading video...") - if self.download_video(video_id, str(video_path)): - print(f"Video downloaded: {video_path.name}") - - print("Extracting audio...") - if self.extract_audio(str(video_path), str(audio_path)): - print(f"Audio extracted: {audio_path.name}") - else: - print("Audio extraction failed") - - print(" Downloading 
captions...") - if self.download_captions(video_id, str(caption_path)): - print(f"Captions downloaded: {caption_path.name}") - else: - print("No captions available, adding to Whisper queue") - needs_whisper.append(f"audio_{video_num}.m4a") + print("Extracting audio...") + if self.extract_audio(str(video_path), str(audio_path)): + print(f"Audio extracted: {audio_path.name}") else: - print("Video download failed, skipping") + print("Audio extraction failed") + + print(" Downloading captions...") + if self.download_captions(video_id, str(caption_path)): + print(f"Captions downloaded: {caption_path.name}") + return None + print("No captions available, adding to Whisper queue") + return f"audio_{video_num}.m4a" + print("Video download failed, skipping") + return None + + def _save_topic_metadata_and_summary( + self, + topic_name: str, + merged_metadata: List[Dict], + needs_whisper: List[str], + selected_count: int, + ): + """Save all metadata files and print final summary. - merged_metadata = self.merge_metadata(topic_name, selected) + Args: + topic_name: Name of the topic. + merged_metadata: Combined metadata list. + needs_whisper: List of audio filenames needing Whisper. + selected_count: Number of newly selected videos. 
+ """ + whisper_file = self.captions_dir / topic_name / "needs_whisper.txt" - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print("Saving metadata files...") - for base_dir in [ - self.videos_dir, self.audios_dir, self.captions_dir - ]: + for base_dir in [self.videos_dir, self.audios_dir, self.captions_dir]: metadata_path = base_dir / topic_name / "metadata.json" - with open(metadata_path, 'w') as f: + with open(metadata_path, "w") as f: json.dump(merged_metadata, f, indent=2) print(f"โœ“ Saved: {metadata_path}") if needs_whisper: - with open(whisper_file, 'w') as f: + with open(whisper_file, "w") as f: for audio_file in needs_whisper: f.write(f"{audio_file}\n") print(f"โœ“ Saved Whisper queue: {whisper_file}") - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"โœ“ Topic {topic_name} processing complete!") print(f" Total videos now: {len(merged_metadata)}/{self.max_count}") - print(f" New videos added: {len(selected)}") + print(f" New videos added: {selected_count}") print(f" Needs Whisper: {len(needs_whisper)}") - print(f"{'='*60}\n") + print(f"{'=' * 60}\n") + + def process_topic(self, topic_metadata_path: str, force: bool = False): + """Process a topic by filtering and selecting quality videos. + + Args: + topic_metadata_path: Path to topic metadata file. + force: If True, process even if at max count. 
+ """ + # Validate and prepare topic + (topic_name, current_count, existing_video_ids, next_number, needed_count) = ( + self._validate_and_prepare_topic(topic_metadata_path, force) + ) + + if topic_name is None: + return + + self.create_directories(topic_name) + + # Load and filter videos + filtered = self._load_and_filter_videos(topic_metadata_path, existing_video_ids) + + if len(filtered) == 0: + print("No new videos meet the criteria!") + return + + # Select videos + selected = self.select_videos(filtered, needed_count, next_number) + + if len(selected) == 0: + print("No videos were selected!") + return + + # Load existing Whisper queue + needs_whisper = [] + whisper_file = self.captions_dir / topic_name / "needs_whisper.txt" + if whisper_file.exists(): + with open(whisper_file, "r") as f: + needs_whisper = [line.strip() for line in f.readlines()] + + # Download video assets + for video in selected: + audio_filename = self._download_video_assets(video, topic_name) + if audio_filename: + needs_whisper.append(audio_filename) + + # Merge and save metadata + merged_metadata = self.merge_metadata(topic_name, selected) + self._save_topic_metadata_and_summary( + topic_name, merged_metadata, needs_whisper, len(selected) + ) def generate_summary_from_metadata(self, topic_name: str): """Generate summary.json from existing metadata.json file. @@ -566,7 +733,8 @@ def generate_summary_from_metadata(self, topic_name: str): Args: topic_name: Name of the topic. - Returns: + Returns + ------- Summary dictionary or None if error. 
""" print(f"\nGenerating summary for: {topic_name}") @@ -577,7 +745,7 @@ def generate_summary_from_metadata(self, topic_name: str): print(f"โœ— Metadata file not found: {metadata_path}") return None - with open(metadata_path, 'r') as f: + with open(metadata_path, "r") as f: selected_videos = json.load(f) print(f"โœ“ Loaded {len(selected_videos)} videos from metadata") @@ -585,9 +753,9 @@ def generate_summary_from_metadata(self, topic_name: str): whisper_file = self.captions_dir / topic_name / "needs_whisper.txt" needs_whisper = [] if whisper_file.exists(): - with open(whisper_file, 'r') as f: + with open(whisper_file, "r") as f: needs_whisper = [ - line.strip().replace('audio_', '').replace('.m4a', '') + line.strip().replace("audio_", "").replace(".m4a", "") for line in f.readlines() ] @@ -596,13 +764,12 @@ def generate_summary_from_metadata(self, topic_name: str): total_duration_seconds = 0 for video in selected_videos: - duration_counts[video.get('duration_category', 'unknown')] += 1 - demo_counts[video.get('demographic_label', 'general')] += 1 - total_duration_seconds += video.get('duration_seconds', 0) + duration_counts[video.get("duration_category", "unknown")] += 1 + demo_counts[video.get("demographic_label", "general")] += 1 + total_duration_seconds += video.get("duration_seconds", 0) avg_duration_seconds = ( - total_duration_seconds / len(selected_videos) - if selected_videos else 0 + total_duration_seconds / len(selected_videos) if selected_videos else 0 ) summary = { @@ -610,22 +777,16 @@ def generate_summary_from_metadata(self, topic_name: str): "processing_timestamp": datetime.now().isoformat(), "statistics": { "selected_videos_count": len(selected_videos), - "videos_with_captions": ( - len(selected_videos) - len(needs_whisper) - ), + "videos_with_captions": (len(selected_videos) - len(needs_whisper)), "videos_needing_whisper": len(needs_whisper), "total_duration_seconds": total_duration_seconds, - "total_duration_minutes": round( - 
total_duration_seconds / 60, 2 - ), + "total_duration_minutes": round(total_duration_seconds / 60, 2), "average_duration_seconds": round(avg_duration_seconds, 2), - "average_duration_minutes": round( - avg_duration_seconds / 60, 2 - ) + "average_duration_minutes": round(avg_duration_seconds / 60, 2), }, "distribution": { "by_duration": dict(duration_counts), - "by_demographics": dict(demo_counts) + "by_demographics": dict(demo_counts), }, "duration_percentages": { cat: round((count / len(selected_videos)) * 100, 1) @@ -636,14 +797,12 @@ def generate_summary_from_metadata(self, topic_name: str): for demo, count in demo_counts.items() }, "needs_whisper_list": needs_whisper, - "video_ids": [video['video_id'] for video in selected_videos] + "video_ids": [video["video_id"] for video in selected_videos], } - for base_dir in [ - self.videos_dir, self.audios_dir, self.captions_dir - ]: + for base_dir in [self.videos_dir, self.audios_dir, self.captions_dir]: summary_path = base_dir / topic_name / "summary.json" - with open(summary_path, 'w') as f: + with open(summary_path, "w") as f: json.dump(summary, f, indent=2) print(f"โœ“ Saved: {summary_path}") @@ -653,26 +812,24 @@ def generate_summary_from_metadata(self, topic_name: str): if __name__ == "__main__": config_path = os.path.join(os.path.dirname(__file__), "config.yaml") - with open(config_path, 'r') as f: + with open(config_path, "r") as f: config = yaml.safe_load(f) - base_dir = config['directories']['base_dir'] + base_dir = config["directories"]["base_dir"] if not os.path.isabs(base_dir): base_dir = os.path.join(os.path.dirname(__file__), base_dir) - MAX_COUNT = config['max_videos_per_topic']['max_videos_per_topic'] + MAX_COUNT = config["max_videos_per_topic"]["max_videos_per_topic"] processor = VideoDatasetProcessor(base_dir, max_count=MAX_COUNT) - quality_annotated_dir = os.path.join( - base_dir, "videos_QualityAnnotated" - ) + quality_annotated_dir = os.path.join(base_dir, "videos_QualityAnnotated") if not 
os.path.exists(quality_annotated_dir): - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print("ERROR: videos_QualityAnnotated directory not found!") - print(f"{'='*60}") + print(f"{'=' * 60}") print(f"\nExpected location: {quality_annotated_dir}") print("\nBEFORE RUNNING THIS SCRIPT, YOU MUST:") print("1. Create the 'videos_QualityAnnotated' directory") @@ -680,20 +837,21 @@ def generate_summary_from_metadata(self, topic_name: str): print("3. Add 'Qualitylabel' field to videos you want to include") print("4. Structure it following ../huggingface_review_template/") print("\nSee README.md for detailed instructions.") - print(f"{'='*60}\n") + print(f"{'=' * 60}\n") sys.exit(1) - topic_dirs = sorted([ - d for d in os.listdir(quality_annotated_dir) - if os.path.isdir(os.path.join(quality_annotated_dir, d)) - ]) - - print(f"\n{'='*60}") - print( - f"Found {len(topic_dirs)} topics in videos_QualityAnnotated" + topic_dirs = sorted( + [ + d + for d in os.listdir(quality_annotated_dir) + if os.path.isdir(os.path.join(quality_annotated_dir, d)) + ] ) + + print(f"\n{'=' * 60}") + print(f"Found {len(topic_dirs)} topics in videos_QualityAnnotated") print(f"Max videos per topic: {MAX_COUNT}") - print(f"{'='*60}\n") + print(f"{'=' * 60}\n") topics_processed = 0 topics_skipped_full = 0 @@ -713,8 +871,7 @@ def generate_summary_from_metadata(self, topic_name: str): topic_path = os.path.join(quality_annotated_dir, topic_dir) metadata_files = [ - f for f in os.listdir(topic_path) - if f.endswith('_metadata.json') + f for f in os.listdir(topic_path) if f.endswith("_metadata.json") ] if not metadata_files: @@ -726,10 +883,7 @@ def generate_summary_from_metadata(self, topic_name: str): try: if current_count > 0: - print( - f"\nEXTENDING {topic_dir} " - f"(has {current_count}/{MAX_COUNT})" - ) + print(f"\nEXTENDING {topic_dir} (has {current_count}/{MAX_COUNT})") topics_extended += 1 else: print(f"\nCREATING {topic_dir}") @@ -740,16 +894,15 @@ def generate_summary_from_metadata(self, 
topic_name: str): except Exception as e: print(f"\nโœ— Error processing {topic_dir}: {e}") - import traceback traceback.print_exc() continue - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print("PROCESSING COMPLETE!") - print(f"{'='*60}") + print(f"{'=' * 60}") print(f"Topics processed: {topics_processed}") print(f" - New topics created: {topics_processed - topics_extended}") print(f" - Existing topics extended: {topics_extended}") print(f"Topics skipped (already full): {topics_skipped_full}") print(f"Topics skipped (no metadata): {topics_skipped_no_metadata}") - print(f"{'='*60}\n") \ No newline at end of file + print(f"{'=' * 60}\n") diff --git a/sonic-o1/01_data_curation/youtube_metadata_scraper.py b/sonic-o1/01_data_curation/youtube_metadata_scraper.py index 16066f1..94f3fee 100644 --- a/sonic-o1/01_data_curation/youtube_metadata_scraper.py +++ b/sonic-o1/01_data_curation/youtube_metadata_scraper.py @@ -1,8 +1,11 @@ -"""YouTube Metadata Scraper for Fairness Analysis. +"""youtube_metadata_scraper.py. -Collects metadata for videos across 13 topics with various demographic -dimensions. +Scrapes YouTube video metadata across multiple topics with demographic variations +and applies quality filtering based on engagement and content analysis. + +Author: SONIC-O1 Team """ + import json import os import re @@ -19,1066 +22,1058 @@ TranscriptsDisabled, ) + +try: + import yaml + + YAML_AVAILABLE = True +except ImportError: + YAML_AVAILABLE = False + + load_dotenv() # YouTube API Configuration -YOUTUBE_API_SERVICE_NAME = 'youtube' -YOUTUBE_API_VERSION = 'v3' +YOUTUBE_API_SERVICE_NAME = "youtube" +YOUTUBE_API_VERSION = "v3" def load_config(config_path: str = None) -> Dict: """Load configuration from JSON or YAML file. - + Args: config_path: Path to config file. If None, searches for config.yaml or config.json in current directory. - - Returns: + + Returns + ------- Configuration dictionary, or None if file not found. 
""" if config_path is None: - if os.path.exists('config.yaml'): - config_path = 'config.yaml' - elif os.path.exists('config.json'): - config_path = 'config.json' + if os.path.exists("config.yaml"): + config_path = "config.yaml" + elif os.path.exists("config.json"): + config_path = "config.json" else: print("ERROR: No configuration file found!") - print( - "Please create either config.yaml or config.json " - "with your settings." - ) + print("Please create either config.yaml or config.json with your settings.") return None - + if not os.path.exists(config_path): print(f"ERROR: Configuration file not found at {config_path}") return None - - if config_path.endswith('.yaml') or config_path.endswith('.yml'): - try: - import yaml - with open(config_path, 'r') as f: - config = yaml.safe_load(f) - except ImportError: - print( - "ERROR: PyYAML not installed. " - "Install with: pip install pyyaml" - ) + + if config_path.endswith(".yaml") or config_path.endswith(".yml"): + if not YAML_AVAILABLE: + print("ERROR: PyYAML not installed. 
Install with: pip install pyyaml") return None + with open(config_path, "r") as f: + config = yaml.safe_load(f) else: - with open(config_path, 'r') as f: + with open(config_path, "r") as f: config = json.load(f) - + print(f"โœ“ Loaded configuration from {config_path}") return config CONFIG = load_config() if CONFIG: - API_KEY = os.environ['YT_SCRAP_API'] - BASE_DIR = CONFIG['directories']['base_dir'] - VIDEOS_DIR = os.path.join( - BASE_DIR, - CONFIG['directories']['videos_dir'] - ) + API_KEY = os.environ["YT_SCRAP_API"] + BASE_DIR = CONFIG["directories"]["base_dir"] + VIDEOS_DIR = os.path.join(BASE_DIR, CONFIG["directories"]["videos_dir"]) + # Load topics from config (can be tuned/modified in config.yaml) + TOPICS = CONFIG.get("topics", {}) else: API_KEY = None BASE_DIR = None VIDEOS_DIR = None + TOPICS = {} -# Topics with their focus areas and fairness dimensions -TOPICS = { - 1: { - "name": "Patient-Doctor Consultations", - "search_terms": [ - "doctor patient conversation full session", - "clinic consultation recording", - "telehealth visit recording", - "primary care consultation unedited", - "medical intake interview full visit" - ], - "focus": "Medical communication, empathy, diagnosis discussions", - "channel_id": None - }, - 2: { - "name": "Job Interviews", - "search_terms": [ - "panel interview full interview", - "candidate interview recording", - "on site interview hiring manager", - "technical interview session full", - "HR screening call recording" - ], - "focus": ( - "Professional interaction, emotion detection, body language" - ), - "channel_id": None - }, - 3: { - "name": "Parent-Teacher Conferences", - "search_terms": [ - "parent teacher conference recording", - "PTC meeting full session", - "student progress meeting recording", - "IEP meeting full", - "teacher parent meeting unedited" - ], - "focus": ( - "Educational settings, conflict resolution, child advocacy" - ), - "channel_id": None - }, - 4: { - "name": "Customer Service Interactions", - 
"search_terms": [ - "front desk dispute", - "restaurant customer service", - "Customer Service SNL", - "employee vs customer -karen -compilation", - "store manager customer service", - "angry customer store footage -staged" - ], - "focus": ( - "Complaint handling, emotion regulation, problem-solving" - ), - "channel_id": None - }, - 5: { - "name": "Courtroom Proceedings", - "search_terms": [ - "oral argument", - "\"sentencing hearing\" \"full recording\" courtroom", - "municipal court arraignment calendar session full", - "\"sentencing hearing\" full recording courtroom", - "small claims court full hearing official recording", - "mock trial full" - ], - "focus": ( - "Legal settings, testimony analysis, fairness assessment" - ), - "channel_id": None - }, - 6: { - "name": "Emergency Response Scenarios", - "search_terms": [ - ( - "firefighters highway incident \"full response\" " - "-shorts -news" - ), - "bodycam police", - "bodycam police rescue", - "Law&Crime BodyCam", - "real emergency calls paramedic" - ], - "focus": "Crisis management, first aid, triage decisions", - "channel_id": None - }, - 7: { - "name": "Public Transportation Conflicts", - "search_terms": [ - "bus passenger fight driver -news -compilation", - "bar fight", - "train passenger confrontation cctv", - "airport security passenger meltdown bodycam", - "grocery store argument customer", - "parking lot dispute road rage" - ], - "focus": ( - "Social etiquette, accessibility, conflict de-escalation" - ), - "channel_id": None - }, - 8: { - "name": "Workplace Team Meetings", - "search_terms": [ - ( - "\"team meeting\" recording Zoom -webinar " - "-tutorial -class" - ), - "daily standup meeting", - "scrum meeting real team -demo -example", - "sprint review meeting", - "workplace \"meeting\" recording", - ], - "focus": "Collaboration, leadership dynamics, idea contribution", - "channel_id": None - }, - 9: { - "name": "Housing/Apartment Tours", - "search_terms": [ - "open house walkthrough agent client", - 
"apartment tour with agent full", - "rental inspection landlord tenant recording", - "home showing buyer walkthrough", - "accessible apartment tour elevator ramp" - ], - "focus": "Real estate interactions, accessibility features", - "channel_id": None - }, - 10: { - "name": "Restaurant Service Encounters", - "search_terms": [ - "restaurant vlog", - "waitress day in the life", - "restaurant behind the scenes", - "food service worker", - "restaurant review visit" - ], - "focus": ( - "Service quality, complaint handling, accessibility" - ), - "channel_id": None - }, - 11: { - "name": "Mental Health Counseling", - "search_terms": [ - "counseling session demonstration", - "therapy role play training", - "mock therapy session psychology", - "counseling techniques demonstration video", - "therapeutic communication examples" - ], - "focus": ( - "Therapeutic alliance, emotional support, " - "crisis intervention" - ), - "channel_id": None - }, - 12: { - "name": "Community Town Halls", - "search_terms": [ - "\"town hall\" \"full recording\" community Q&A", - "town hall meeting complete", - ( - "\"city council\" meeting \"livestream archive\" " - "-highlights -clips" - ), - "community meeting local government", - ( - "\"Islamic center\" community forum full " - "-news -compilation" - ) - ], - "focus": ( - "Civic engagement, diverse viewpoints, accessibility" - ), - "channel_id": None - }, - 13: { - "name": "Olympics", - "search_terms": [ - "olympic games highlights", - "summer olympics events", - "winter olympics full coverage", - "olympic moments compilation", - "olympics replay full event" - ], - "focus": "Sports videos", - "channel_id": None - } -} - -DEMOGRAPHICS = CONFIG['demographics'] +# Load demographics from config +DEMOGRAPHICS = CONFIG.get("demographics", {}) if CONFIG else {} def categorize_duration(duration_seconds: int) -> str: """Categorize video duration into short/medium/long. - + Args: duration_seconds: Duration in seconds. 
- - Returns: + + Returns + ------- Category string: 'short', 'medium', 'long', or 'other'. """ duration_minutes = duration_seconds / 60 - + if 0.5 <= duration_minutes < 5: - return 'short' - elif 5 <= duration_minutes < 20: - return 'medium' - elif 20 <= duration_minutes <= 60: - return 'long' - else: - return 'other' + return "short" + if 5 <= duration_minutes < 20: + return "medium" + if 20 <= duration_minutes <= 60: + return "long" + return "other" class YouTubeMetadataScraper: """Scraper for YouTube video metadata with quality filtering.""" - + def __init__(self, api_key: str, config: Dict = None): """Initialize the scraper. - + Args: api_key: YouTube Data API v3 key. config: Configuration dictionary. """ self.youtube = build( - YOUTUBE_API_SERVICE_NAME, - YOUTUBE_API_VERSION, - developerKey=api_key + YOUTUBE_API_SERVICE_NAME, YOUTUBE_API_VERSION, developerKey=api_key ) self.config = config or CONFIG - self.rate_limit_delay = ( - self.config['api_settings'].get('rate_limit_delay', 1) - ) - self.max_results = ( - self.config['api_settings'] - .get('max_results_per_query', 50) + self.rate_limit_delay = self.config["api_settings"].get("rate_limit_delay", 1) + self.max_results = self.config["api_settings"].get("max_results_per_query", 50) + self.caption_text_limit = self.config["collection_settings"].get( + "caption_text_limit", 5000 ) - self.caption_text_limit = ( - self.config['collection_settings'] - .get('caption_text_limit', 5000) - ) - self.video_duration = ( - self.config['collection_settings'] - .get('video_duration', 'medium') - ) - years_back = ( - self.config['search_settings'].get('years_back', 5) + self.video_duration = self.config["collection_settings"].get( + "video_duration", "medium" ) + years_back = self.config["search_settings"].get("years_back", 5) self.published_after = ( - (datetime.now() - timedelta(days=years_back*365)) - .isoformat() + 'Z' - ) - self.video_license = ( - self.config['search_settings'] - .get('video_license', 'any') - ) - + 
datetime.now() - timedelta(days=years_back * 365) + ).isoformat() + "Z" + self.video_license = self.config["search_settings"].get("video_license", "any") + def search_videos( self, query: str, max_results: int = None, channel_id: str = None, - topic_id: int = None + topic_id: int = None, ) -> List[str]: """Search for videos and return video IDs sorted by view count. - + Args: query: Search query string. max_results: Maximum number of results to return. channel_id: Optional channel ID to filter by. topic_id: Optional topic ID for special handling. - - Returns: + + Returns + ------- List of video IDs. """ if max_results is None: max_results = self.max_results - if topic_id in [4, 6, 7, 8]: - video_duration = 'any' - else: - video_duration = self.video_duration - + video_duration = "any" if topic_id in [4, 6, 7, 8] else self.video_duration + try: search_params = { - 'q': query, - 'part': 'id', - 'maxResults': min(max_results or self.max_results, 80), - 'type': 'video', - 'order': 'relevance', - 'relevanceLanguage': 'en', - 'videoCaption': 'any', - 'videoDefinition': 'high', - 'videoDuration': video_duration, - 'videoLicense': self.video_license, - 'publishedAfter': self.published_after + "q": query, + "part": "id", + "maxResults": min(max_results or self.max_results, 80), + "type": "video", + "order": "relevance", + "relevanceLanguage": "en", + "videoCaption": "any", + "videoDefinition": "high", + "videoDuration": video_duration, + "videoLicense": self.video_license, + "publishedAfter": self.published_after, } - + if channel_id: - search_params['channelId'] = channel_id - - search_response = ( - self.youtube.search().list(**search_params).execute() - ) - + search_params["channelId"] = channel_id + + search_response = self.youtube.search().list(**search_params).execute() + video_ids = [ - item['id']['videoId'] - for item in search_response.get('items', []) + item["id"]["videoId"] for item in search_response.get("items", []) ] time.sleep(self.rate_limit_delay) return 
video_ids - + except Exception as e: print(f"Error searching videos for query '{query}': {e}") return [] - + def get_video_details(self, video_ids: List[str]) -> List[Dict]: """Get detailed metadata for a list of video IDs. - + Args: video_ids: List of YouTube video IDs. - - Returns: + + Returns + ------- List of dictionaries containing video metadata. """ if not video_ids: return [] - + try: - video_response = self.youtube.videos().list( - part=( - 'snippet,contentDetails,statistics,' - 'status,topicDetails' - ), - id=','.join(video_ids) - ).execute() - + video_response = ( + self.youtube.videos() + .list( + part=("snippet,contentDetails,statistics,status,topicDetails"), + id=",".join(video_ids), + ) + .execute() + ) + videos_data = [] - for item in video_response.get('items', []): + for item in video_response.get("items", []): video_data = self._parse_video_item(item) videos_data.append(video_data) - + time.sleep(self.rate_limit_delay) return videos_data - + except Exception as e: print(f"Error getting video details: {e}") return [] - + def _parse_video_item(self, item: Dict) -> Dict: """Parse video item from API response into structured metadata. - + Args: item: Video item dictionary from YouTube API. - - Returns: + + Returns + ------- Structured video metadata dictionary. 
""" - snippet = item.get('snippet', {}) - content_details = item.get('contentDetails', {}) - statistics = item.get('statistics', {}) - status = item.get('status', {}) - - video_id = item['id'] - - duration_str = content_details.get('duration', 'PT0S') + snippet = item.get("snippet", {}) + content_details = item.get("contentDetails", {}) + statistics = item.get("statistics", {}) + status = item.get("status", {}) + + video_id = item["id"] + + duration_str = content_details.get("duration", "PT0S") duration_seconds = self._parse_duration(duration_str) - - has_captions = content_details.get('caption') == 'true' - caption_text = ( - self._get_caption_text(video_id) if has_captions else None - ) - + + has_captions = content_details.get("caption") == "true" + caption_text = self._get_caption_text(video_id) if has_captions else None + return { - 'video_id': video_id, - 'url': f'https://www.youtube.com/watch?v={video_id}', - 'title': snippet.get('title', ''), - 'channel_title': snippet.get('channelTitle', ''), - 'channel_id': snippet.get('channelId', ''), - 'published_at': snippet.get('publishedAt', ''), - 'duration_seconds': duration_seconds, - 'duration_formatted': duration_str, - 'duration_category': categorize_duration(duration_seconds), - 'view_count': int(statistics.get('viewCount', 0)), - 'like_count': int(statistics.get('likeCount', 0)), - 'comment_count': int(statistics.get('commentCount', 0)), - 'tags': ','.join(snippet.get('tags', [])), - 'category_id': snippet.get('categoryId', ''), - 'default_language': snippet.get('defaultLanguage', ''), - 'default_audio_language': ( - snippet.get('defaultAudioLanguage', '') - ), - 'has_captions': has_captions, - 'caption_text': caption_text, - 'is_licensed_content': ( - content_details.get('licensedContent', False) - ), - 'copyright_notice': status.get('license', ''), - 'privacy_status': status.get('privacyStatus', ''), - 'embeddable': status.get('embeddable', False), - 'public_stats_viewable': ( - 
status.get('publicStatsViewable', True) + "video_id": video_id, + "url": f"https://www.youtube.com/watch?v={video_id}", + "title": snippet.get("title", ""), + "channel_title": snippet.get("channelTitle", ""), + "channel_id": snippet.get("channelId", ""), + "published_at": snippet.get("publishedAt", ""), + "duration_seconds": duration_seconds, + "duration_formatted": duration_str, + "duration_category": categorize_duration(duration_seconds), + "view_count": int(statistics.get("viewCount", 0)), + "like_count": int(statistics.get("likeCount", 0)), + "comment_count": int(statistics.get("commentCount", 0)), + "tags": ",".join(snippet.get("tags", [])), + "category_id": snippet.get("categoryId", ""), + "default_language": snippet.get("defaultLanguage", ""), + "default_audio_language": (snippet.get("defaultAudioLanguage", "")), + "has_captions": has_captions, + "caption_text": caption_text, + "is_licensed_content": (content_details.get("licensedContent", False)), + "copyright_notice": status.get("license", ""), + "privacy_status": status.get("privacyStatus", ""), + "embeddable": status.get("embeddable", False), + "public_stats_viewable": (status.get("publicStatsViewable", True)), + "made_for_kids": status.get("madeForKids", False), + "topic_categories": ",".join( + item.get("topicDetails", {}).get("topicCategories", []) ), - 'made_for_kids': status.get('madeForKids', False), - 'topic_categories': ','.join( - item.get('topicDetails', {}) - .get('topicCategories', []) - ) } - + def _parse_duration(self, duration_str: str) -> int: """Convert ISO 8601 duration to seconds. - + Args: duration_str: Duration string (e.g., 'PT15M51S'). - - Returns: + + Returns + ------- Duration in seconds. 
""" - pattern = re.compile(r'PT(?:(\d+)H)?(?:(\d+)M)?(?:(\d+)S)?') + pattern = re.compile(r"PT(?:(\d+)H)?(?:(\d+)M)?(?:(\d+)S)?") match = pattern.match(duration_str) - + if not match: return 0 - + hours = int(match.group(1) or 0) minutes = int(match.group(2) or 0) seconds = int(match.group(3) or 0) - + return hours * 3600 + minutes * 60 + seconds - + def _get_caption_text(self, video_id: str) -> Optional[str]: """Get caption/transcript text for a video. - + Args: video_id: YouTube video ID. - - Returns: + + Returns + ------- Caption text (truncated to limit) or None if unavailable. """ try: - transcript_list = YouTubeTranscriptApi.list_transcripts( - video_id - ) - + transcript_list = YouTubeTranscriptApi.list_transcripts(video_id) + try: - transcript = transcript_list.find_transcript(['en']) - except: - transcript = ( - transcript_list.find_generated_transcript(['en']) - ) - + transcript = transcript_list.find_transcript(["en"]) + except (NoTranscriptFound, Exception): + transcript = transcript_list.find_generated_transcript(["en"]) + caption_data = transcript.fetch() - caption_text = ' '.join( - [entry['text'] for entry in caption_data] - ) - return caption_text[:self.caption_text_limit] - + caption_text = " ".join([entry["text"] for entry in caption_data]) + return caption_text[: self.caption_text_limit] + except (TranscriptsDisabled, NoTranscriptFound, Exception): return None - - def generate_search_queries( - self, topic_info: Dict - ) -> List[tuple]: + + def generate_search_queries(self, topic_info: Dict) -> List[tuple]: """Generate demographically diverse queries. - + Args: topic_info: Dictionary containing topic search terms. - - Returns: + + Returns + ------- List of (query, demographic_label) tuples. 
""" queries = [] base_terms = topic_info.get("search_terms", []) - + for term in base_terms: queries.append((term, "general")) - - def make_natural_query( - demographic: str, term: str, dim_key: str - ) -> str: + + def make_natural_query(demographic: str, term: str, dim_key: str) -> str: """Create natural-sounding queries by demographic type.""" - if dim_key == "race": return f"{demographic} {term}" - - elif dim_key == "gender": - gender_map = { - "Male": "man", - "Female": "woman" - } - natural_gender = gender_map.get( - demographic, demographic.lower() - ) + + if dim_key == "gender": + gender_map = {"Male": "man", "Female": "woman"} + natural_gender = gender_map.get(demographic, demographic.lower()) return f"{natural_gender} {term}" - - elif dim_key == "age": + + if dim_key == "age": age_map = { "Young (18-24)": "young adult", "Middle (25-39)": "middle aged", - "Older adults (40+)": "older adult" + "Older adults (40+)": "older adult", } natural_age = age_map.get(demographic, demographic) return f"{natural_age} {term}" - - elif dim_key == "language": + + if dim_key == "language": return f"{term} {demographic}" - - else: - return f"{demographic} {term}" - - for dim_key, demographic_values in ( - self.config['demographics'].items() - ): + + return f"{demographic} {term}" + + for dim_key, demographic_values in self.config["demographics"].items(): for demographic in demographic_values: variations = [demographic] - + if demographic == "Arab": variations = [ - "Arab", "Middle Eastern", "Arabic", - "MENA", "Arab American" + "Arab", + "Middle Eastern", + "Arabic", + "MENA", + "Arab American", ] elif demographic == "Indigenous": variations = [ - "Indigenous", "Native American", - "First Nations", "Aboriginal", "tribal" + "Indigenous", + "Native American", + "First Nations", + "Aboriginal", + "tribal", ] - + for term in base_terms[:2]: for variation in variations: - query = make_natural_query( - variation, term, dim_key - ) - queries.append( - (query, 
f"{dim_key}:{demographic}") - ) - + query = make_natural_query(variation, term, dim_key) + queries.append((query, f"{dim_key}:{demographic}")) + return queries - - def filter_quality_videos( + + def _validate_basic_video_metrics( + self, video: Dict, is_scarce_topic: bool, is_topic7: bool + ) -> bool: + """Validate video has minimum required metrics. + + Args: + video: Video metadata dictionary. + is_scarce_topic: Whether this is a scarce topic. + is_topic7: Whether this is topic 7. + + Returns + ------- + True if video passes basic validation, False otherwise. + """ + views = video.get("view_count", 0) + duration_seconds = video.get("duration_seconds", 0) + + if views == 0 or duration_seconds == 0: + return False + + min_duration = 15 if is_topic7 else 30 + max_duration = 5400 if is_topic7 else 3600 + + if not (min_duration <= duration_seconds <= max_duration): + return False + + min_views = 100 if is_scarce_topic else 500 + return not views < min_views + + def _detect_clickbait_patterns(self, title: str) -> tuple: + """Detect clickbait patterns in title. + + Args: + title: Video title. + + Returns + ------- + Tuple of (strong_count, moderate_count, is_extreme_clickbait). 
+ """ + title_lower = title.lower() + + strong_clickbait_patterns = [ + r"\byou won\'?t believe\b", + r"\bshocking truth\b", + r"\bdoctors hate\b", + r"\bone weird trick\b", + r"\bwhat happens next\b", + r"\bmind[- ]?blowing\b", + r"\bthis is why\b", + r"\bthe truth about\b.*\bthey don\'?t want\b", + ] + + moderate_clickbait_patterns = [ + r"\bgone wrong\b", + r"\bnumber \d+ will\b", + r"\byou need to see\b", + r"\bwait for it\b", + r"\bwatch till the end\b", + ] + + strong_clickbait_count = sum( + 1 + for pattern in strong_clickbait_patterns + if re.search(pattern, title_lower) + ) + + moderate_clickbait_count = sum( + 1 + for pattern in moderate_clickbait_patterns + if re.search(pattern, title_lower) + ) + + is_extreme_clickbait = strong_clickbait_count >= 2 or ( + strong_clickbait_count >= 1 and moderate_clickbait_count >= 2 + ) + + return strong_clickbait_count, moderate_clickbait_count, is_extreme_clickbait + + def _detect_title_quality_issues( + self, title: str, duration_seconds: int, is_topic7: bool + ) -> Dict: + """Detect quality issues in title. + + Args: + title: Video title. + duration_seconds: Video duration in seconds. + is_topic7: Whether this is topic 7. + + Returns + ------- + Dictionary of quality issue flags. 
+ """ + title_lower = title.lower() + + if len(title) > 0: + caps_ratio = sum(1 for c in title if c.isupper()) / len(title) + excessive_caps = caps_ratio > 0.55 + else: + excessive_caps = False + + excessive_punctuation = ( + title.count("!") > 5 + or title.count("?") > 5 + or len(re.findall(r"[!?]{3,}", title)) > 0 + or len(re.findall(r"\.{3,}", title)) > 2 + ) + + emoji_pattern = re.compile( + "[" + "\U0001f600-\U0001f64f" + "\U0001f300-\U0001f5ff" + "\U0001f680-\U0001f6ff" + "\U0001f1e0-\U0001f1ff" + "]+", + flags=re.UNICODE, + ) + emoji_count = len(emoji_pattern.findall(title)) + excessive_emojis = emoji_count > 5 + + title_words = title.split() + very_short_title = ( + len(title_words) < 3 and duration_seconds > 300 and not is_topic7 + ) + + generic_only_title = all( + word in ["video", "clip", "footage", "content", "new", "best", "top"] + for word in title_lower.split() + if len(word) > 3 + ) + + return { + "excessive_caps": excessive_caps, + "excessive_punctuation": excessive_punctuation, + "excessive_emojis": excessive_emojis, + "very_short_title": very_short_title, + "generic_only_title": generic_only_title, + } + + def _detect_spam_and_suspicious_patterns( + self, + title: str, + channel_title: str, + views: int, + likes: int, + comments: int, + engagement_rate: float, + ) -> Dict: + """Detect spam and suspicious engagement patterns. + + Args: + title: Video title. + channel_title: Channel title. + views: View count. + likes: Like count. + comments: Comment count. + engagement_rate: Engagement rate. + + Returns + ------- + Dictionary of spam/suspicious pattern flags. 
+ """ + title_lower = title.lower() + + spam_patterns = [ + r"\bfree money\b", + r"\bget rich quick\b", + r"\bclick here now\b", + r"\bfree download\b.*\bcrack\b", + r"\bfree robux\b", + r"\bfree vbucks\b", + r"\b100% working\b", + ] + + has_spam = any(re.search(pattern, title_lower) for pattern in spam_patterns) + + has_valid_channel = ( + len(channel_title) >= 3 + and not channel_title.replace(" ", "").isdigit() + and channel_title.strip() != "" + ) + + suspicious_engagement = views > 10000 and likes == 0 and comments == 0 + + very_low_engagement = views > 5000 and engagement_rate < 0.0001 + + return { + "has_spam": has_spam, + "has_valid_channel": has_valid_channel, + "suspicious_engagement": suspicious_engagement, + "very_low_engagement": very_low_engagement, + } + + def _calculate_engagement_metrics( + self, views: int, likes: int, comments: int + ) -> Dict: + """Calculate engagement metrics. + + Args: + views: View count. + likes: Like count. + comments: Comment count. + + Returns + ------- + Dictionary of engagement metrics. + """ + engagement_rate = (likes + comments) / views if views > 0 else 0 + like_ratio = likes / views if views > 0 else 0 + comment_ratio = comments / views if views > 0 else 0 + + return { + "engagement_rate": engagement_rate, + "like_ratio": like_ratio, + "comment_ratio": comment_ratio, + } + + def _calculate_quality_score( + self, + duration_seconds: int, + engagement_rate: float, + views: int, + like_ratio: float, + has_description: bool, + has_valid_channel: bool, + moderate_clickbait_count: int, + strong_clickbait_count: int, + title_issues: Dict, + very_low_engagement: bool, + ) -> int: + """Calculate numeric quality score based on multiple factors. + + Args: + duration_seconds: Video duration in seconds. + engagement_rate: Engagement rate. + views: View count. + like_ratio: Like ratio. + has_description: Whether video has description. + has_valid_channel: Whether channel is valid. 
+ moderate_clickbait_count: Count of moderate clickbait patterns. + strong_clickbait_count: Count of strong clickbait patterns. + title_issues: Dictionary of title quality issues. + very_low_engagement: Whether engagement is very low. + + Returns + ------- + Quality score (integer). + """ + quality_score = 50 + + if 120 <= duration_seconds <= 900: + quality_score += 10 + elif 60 <= duration_seconds <= 1800: + quality_score += 5 + + if engagement_rate >= 0.02: + quality_score += 15 + elif engagement_rate >= 0.01: + quality_score += 10 + elif engagement_rate >= 0.005: + quality_score += 5 + + if views >= 100000: + quality_score += 10 + elif views >= 10000: + quality_score += 7 + elif views >= 5000: + quality_score += 5 + elif views >= 1000: + quality_score += 3 + + if like_ratio >= 0.02: + quality_score += 8 + elif like_ratio >= 0.01: + quality_score += 5 + elif like_ratio >= 0.005: + quality_score += 3 + + if has_description: + quality_score += 5 + if has_valid_channel: + quality_score += 5 + + if moderate_clickbait_count > 0: + quality_score -= 5 + if strong_clickbait_count > 0: + quality_score -= 10 + if title_issues["excessive_caps"]: + quality_score -= 10 + if title_issues["excessive_punctuation"]: + quality_score -= 10 + if title_issues["excessive_emojis"]: + quality_score -= 8 + if title_issues["very_short_title"]: + quality_score -= 8 + if title_issues["generic_only_title"]: + quality_score -= 12 + if very_low_engagement: + quality_score -= 15 + + return quality_score + + def _should_hard_reject( self, - videos_data: List[Dict], - topic_id: int | None = None + is_extreme_clickbait: bool, + spam_flags: Dict, + title_issues: Dict, + ) -> bool: + """Determine if video should be immediately rejected. + + Args: + is_extreme_clickbait: Whether title has extreme clickbait. + spam_flags: Dictionary of spam/suspicious flags. + title_issues: Dictionary of title quality issues. + + Returns + ------- + True if video should be hard rejected, False otherwise. 
+ """ + return ( + is_extreme_clickbait + or spam_flags["has_spam"] + or spam_flags["suspicious_engagement"] + or not spam_flags["has_valid_channel"] + or ( + title_issues["excessive_caps"] and title_issues["excessive_punctuation"] + ) + ) + + + def filter_quality_videos( + self, videos_data: List[Dict], topic_id: int | None = None ) -> List[Dict]: """Research-based video quality filtering system. - + Args: videos_data: List of video metadata dictionaries. topic_id: Optional topic ID for special handling. - - Returns: + + Returns + ------- Filtered list of quality videos. """ filtered = [] is_scarce_topic = topic_id in [4, 5, 6, 7, 8, 12] + is_topic7 = topic_id == 7 min_quality_threshold = 30 if is_scarce_topic else 35 - + for video in videos_data: - views = video.get('view_count', 0) - likes = video.get('like_count', 0) - comments = video.get('comment_count', 0) - duration_seconds = video.get('duration_seconds', 0) - title = video.get('title', '') - description = video.get('description', '') - channel_title = video.get('channel_title', '') - - if views == 0 or duration_seconds == 0: - continue - - is_topic7 = (topic_id == 7) - min_duration = 15 if is_topic7 else 30 - max_duration = 5400 if is_topic7 else 3600 - - if not (min_duration <= duration_seconds <= max_duration): - continue - - engagement_rate = ( - (likes + comments) / views if views > 0 else 0 - ) - like_ratio = likes / views if views > 0 else 0 - comment_ratio = comments / views if views > 0 else 0 - - min_views = 100 if is_scarce_topic else 500 - if views < min_views: + views = video.get("view_count", 0) + likes = video.get("like_count", 0) + comments = video.get("comment_count", 0) + duration_seconds = video.get("duration_seconds", 0) + title = video.get("title", "") + description = video.get("description", "") + channel_title = video.get("channel_title", "") + + # Step 1: Validate basic metrics + if not self._validate_basic_video_metrics( + video, is_scarce_topic, is_topic7 + ): continue - - 
title_lower = title.lower() - - strong_clickbait_patterns = [ - r'\byou won\'?t believe\b', - r'\bshocking truth\b', - r'\bdoctors hate\b', - r'\bone weird trick\b', - r'\bwhat happens next\b', - r'\bmind[- ]?blowing\b', - r'\bthis is why\b', - r'\bthe truth about\b.*\bthey don\'?t want\b', - ] - - moderate_clickbait_patterns = [ - r'\bgone wrong\b', - r'\bnumber \d+ will\b', - r'\byou need to see\b', - r'\bwait for it\b', - r'\bwatch till the end\b', - ] - - strong_clickbait_count = sum( - 1 for pattern in strong_clickbait_patterns - if re.search(pattern, title_lower) - ) - - moderate_clickbait_count = sum( - 1 for pattern in moderate_clickbait_patterns - if re.search(pattern, title_lower) - ) - - is_extreme_clickbait = ( - strong_clickbait_count >= 2 or - ( - strong_clickbait_count >= 1 and - moderate_clickbait_count >= 2 - ) - ) - - if len(title) > 0: - caps_ratio = ( - sum(1 for c in title if c.isupper()) / len(title) - ) - excessive_caps = caps_ratio > 0.55 - else: - excessive_caps = False - - excessive_punctuation = ( - title.count('!') > 5 or - title.count('?') > 5 or - len(re.findall(r'[!?]{3,}', title)) > 0 or - len(re.findall(r'\.{3,}', title)) > 2 - ) - - emoji_pattern = re.compile( - "[" - "\U0001F600-\U0001F64F" - "\U0001F300-\U0001F5FF" - "\U0001F680-\U0001F6FF" - "\U0001F1E0-\U0001F1FF" - "]+", - flags=re.UNICODE - ) - emoji_count = len(emoji_pattern.findall(title)) - excessive_emojis = emoji_count > 5 - - spam_patterns = [ - r'\bfree money\b', - r'\bget rich quick\b', - r'\bclick here now\b', - r'\bfree download\b.*\bcrack\b', - r'\bfree robux\b', - r'\bfree vbucks\b', - r'\b100% working\b', - ] - - has_spam = any( - re.search(pattern, title_lower) - for pattern in spam_patterns - ) - - title_words = title.split() - very_short_title = ( - len(title_words) < 3 and - duration_seconds > 300 and - not is_topic7 + + # Step 2: Calculate engagement metrics + engagement_metrics = self._calculate_engagement_metrics( + views, likes, comments ) - - 
has_valid_channel = ( - len(channel_title) >= 3 and - not channel_title.replace(' ', '').isdigit() and - channel_title.strip() != '' + engagement_rate = engagement_metrics["engagement_rate"] + like_ratio = engagement_metrics["like_ratio"] + comment_ratio = engagement_metrics["comment_ratio"] + + # Step 3: Detect clickbait patterns + strong_clickbait_count, moderate_clickbait_count, is_extreme_clickbait = ( + self._detect_clickbait_patterns(title) ) - - suspicious_engagement = ( - views > 10000 and - likes == 0 and - comments == 0 + + # Step 4: Detect title quality issues + title_issues = self._detect_title_quality_issues( + title, duration_seconds, is_topic7 ) - - very_low_engagement = ( - views > 5000 and - engagement_rate < 0.0001 + + # Step 5: Detect spam and suspicious patterns + spam_flags = self._detect_spam_and_suspicious_patterns( + title, channel_title, views, likes, comments, engagement_rate ) - + + # Step 6: Check for hard rejection criteria + if self._should_hard_reject(is_extreme_clickbait, spam_flags, title_issues): + continue + + # Step 7: Calculate quality score has_description = len(description) > 50 - - generic_only_title = all( - word in [ - 'video', 'clip', 'footage', 'content', - 'new', 'best', 'top' - ] - for word in title_lower.split() if len(word) > 3 - ) - - quality_score = 50 - - if 120 <= duration_seconds <= 900: - quality_score += 10 - elif 60 <= duration_seconds <= 1800: - quality_score += 5 - - if engagement_rate >= 0.02: - quality_score += 15 - elif engagement_rate >= 0.01: - quality_score += 10 - elif engagement_rate >= 0.005: - quality_score += 5 - - if views >= 100000: - quality_score += 10 - elif views >= 10000: - quality_score += 7 - elif views >= 5000: - quality_score += 5 - elif views >= 1000: - quality_score += 3 - - if like_ratio >= 0.02: - quality_score += 8 - elif like_ratio >= 0.01: - quality_score += 5 - elif like_ratio >= 0.005: - quality_score += 3 - - if has_description: - quality_score += 5 - if has_valid_channel: - 
quality_score += 5 - - if moderate_clickbait_count > 0: - quality_score -= 5 - if strong_clickbait_count > 0: - quality_score -= 10 - if excessive_caps: - quality_score -= 10 - if excessive_punctuation: - quality_score -= 10 - if excessive_emojis: - quality_score -= 8 - if very_short_title: - quality_score -= 8 - if generic_only_title: - quality_score -= 12 - if very_low_engagement: - quality_score -= 15 - - hard_reject = ( - is_extreme_clickbait or - has_spam or - suspicious_engagement or - not has_valid_channel or - (excessive_caps and excessive_punctuation) + quality_score = self._calculate_quality_score( + duration_seconds, + engagement_rate, + views, + like_ratio, + has_description, + spam_flags["has_valid_channel"], + moderate_clickbait_count, + strong_clickbait_count, + title_issues, + spam_flags["very_low_engagement"], ) - - min_quality_threshold = 30 if is_scarce_topic else 35 - - if not hard_reject and quality_score >= min_quality_threshold: - video['quality_score'] = quality_score - video['engagement_rate'] = engagement_rate - video['like_ratio'] = like_ratio - video['comment_ratio'] = comment_ratio - video['clickbait_score'] = ( + + # Step 8: Apply quality threshold and add to filtered list + if quality_score >= min_quality_threshold: + video["quality_score"] = quality_score + video["engagement_rate"] = engagement_rate + video["like_ratio"] = like_ratio + video["comment_ratio"] = comment_ratio + video["clickbait_score"] = ( strong_clickbait_count * 2 + moderate_clickbait_count ) - + filtered.append(video) - + return filtered - - def scrape_topic( + + def _initialize_topic_scraping( self, topic_id: int, videos_per_query: int = None - ) -> pd.DataFrame: - """Scrape videos for a specific topic with demographic variations. - + ) -> tuple: + """Initialize topic scraping by loading info and generating queries. + Args: topic_id: Integer ID of the topic to scrape. videos_per_query: Optional override for videos per query. 
- - Returns: - DataFrame containing all videos for the topic. + + Returns + ------- + Tuple of (topic_info, topic_name, existing_video_ids, search_queries, + videos_per_query, target_videos). """ if videos_per_query is None: - videos_per_query = ( - self.config['collection_settings'] - .get('videos_per_query', 5) + videos_per_query = self.config["collection_settings"].get( + "videos_per_query", 5 ) - + topic_info = TOPICS[topic_id] topic_name = topic_info["name"] - - print(f"\n{'='*60}") + + print(f"\n{'=' * 60}") print(f"Scraping Topic {topic_id}: {topic_name}") - print(f"{'='*60}") - + print(f"{'=' * 60}") + existing_video_ids = self._load_existing_video_ids(topic_id) - print( - f"Found {len(existing_video_ids)} existing videos " - f"for this topic" - ) - + print(f"Found {len(existing_video_ids)} existing videos for this topic") + channel_id = topic_info.get("channel_id") if channel_id: print(f"Filtering to channel: {channel_id}") - - all_videos = [] + search_queries = self.generate_search_queries(topic_info) - - target_videos = ( - self.config['collection_settings'] - .get('videos_per_topic', 60) - ) - + target_videos = self.config["collection_settings"].get("videos_per_topic", 60) videos_per_query = max(7, target_videos // len(search_queries)) - - new_videos_count = 0 - - for query, demographic_label in search_queries: - print( - f"\nSearching: {query} " - f"(demographic: {demographic_label})" + + return ( + topic_info, + topic_name, + existing_video_ids, + search_queries, + videos_per_query, + target_videos, + ) + + def _search_and_filter_query( + self, + query: str, + demographic_label: str, + topic_id: int, + topic_info: Dict, + topic_name: str, + existing_video_ids: set, + videos_per_query: int, + ) -> List[Dict]: + """Search videos for one query and filter for quality. + + Args: + query: Search query string. + demographic_label: Demographic label for this query. + topic_id: Integer ID of the topic. + topic_info: Dictionary containing topic information. 
+ topic_name: Name of the topic. + existing_video_ids: Set of existing video IDs to skip. + videos_per_query: Number of videos to retrieve per query. + + Returns + ------- + List of filtered videos for this query. + """ + print(f"\nSearching: {query} (demographic: {demographic_label})") + + channel_id = topic_info.get("channel_id") + video_ids = self.search_videos( + query, + max_results=videos_per_query, + channel_id=channel_id, + topic_id=topic_id, + ) + print(f" Found {len(video_ids)} video IDs") + + new_video_ids = [vid for vid in video_ids if vid not in existing_video_ids] + print(f" New videos (not in existing data): {len(new_video_ids)}") + + if not new_video_ids: + return [] + + video_details = self.get_video_details(new_video_ids) + print(f" Retrieved details for {len(video_details)} new videos") + + filtered_videos = self.filter_quality_videos(video_details, topic_id=topic_id) + print(f" After quality filtering: {len(filtered_videos)} videos") + + for video in filtered_videos: + video["topic_id"] = topic_id + video["topic_name"] = topic_name + video["search_query"] = query + video["demographic_label"] = demographic_label + video["focus_areas"] = topic_info["focus"] + + return filtered_videos + + def _load_existing_topic_data(self, topic_id: int) -> Optional[pd.DataFrame]: + """Load existing topic data from JSON file. + + Args: + topic_id: Integer ID of the topic. + + Returns + ------- + DataFrame if file exists, None otherwise. 
+ """ + topic_name = TOPICS[topic_id]["name"] + safe_name = re.sub(r"[^\w\s-]", "", topic_name).strip().replace(" ", "_") + + topic_dir = os.path.join(VIDEOS_DIR, f"{topic_id:02d}_{safe_name}") + json_path = os.path.join(topic_dir, f"{safe_name}_metadata.json") + + if os.path.exists(json_path): + try: + df = pd.read_json(json_path, orient="records") + print(f"Loaded {len(df)} existing videos from {json_path}") + return df + except Exception as e: + print(f"Warning: Error loading existing data: {e}") + return None + + return None + + def _merge_with_existing_data( + self, new_df: pd.DataFrame, topic_id: int + ) -> pd.DataFrame: + """Merge new videos with existing topic data and remove duplicates. + + Args: + new_df: DataFrame containing newly scraped videos. + topic_id: Integer ID of the topic. + + Returns + ------- + Merged DataFrame with duplicates removed. + """ + if self.config["search_settings"].get("remove_duplicates", True): + new_df = new_df.drop_duplicates(subset=["video_id"], keep="first") + + print(f"\nNew unique videos collected: {len(new_df)}") + + existing_df = self._load_existing_topic_data(topic_id) + + if existing_df is not None and not existing_df.empty: + combined_df = pd.concat([existing_df, new_df], ignore_index=True) + print(f"Total videos after merge: {len(combined_df)}") + else: + combined_df = new_df + print(f"No existing data, starting fresh with {len(combined_df)} videos") + + return combined_df + + def _sort_and_finalize_dataset(self, df: pd.DataFrame) -> pd.DataFrame: + """Sort by quality score and print final statistics. + + Args: + df: DataFrame to sort and finalize. + + Returns + ------- + Sorted DataFrame. 
+ """ + if "quality_score" in df.columns: + df = df.sort_values( + ["quality_score", "view_count"], ascending=[False, False] ) - - video_ids = self.search_videos( + else: + df = df.sort_values("view_count", ascending=False) + + print(f"Final dataset size: {len(df)} videos") + return df + + def scrape_topic(self, topic_id: int, videos_per_query: int = None) -> pd.DataFrame: + """Scrape videos for a specific topic with demographic variations. + + Args: + topic_id: Integer ID of the topic to scrape. + videos_per_query: Optional override for videos per query. + + Returns + ------- + DataFrame containing all videos for the topic. + """ + ( + topic_info, + topic_name, + existing_video_ids, + search_queries, + videos_per_query, + target_videos, + ) = self._initialize_topic_scraping(topic_id, videos_per_query) + + all_videos = [] + + for query, demographic_label in search_queries: + filtered_videos = self._search_and_filter_query( query, - max_results=videos_per_query, - channel_id=channel_id, - topic_id=topic_id - ) - print(f" Found {len(video_ids)} video IDs") - - new_video_ids = [ - vid for vid in video_ids - if vid not in existing_video_ids - ] - print( - f" New videos (not in existing data): " - f"{len(new_video_ids)}" + demographic_label, + topic_id, + topic_info, + topic_name, + existing_video_ids, + videos_per_query, ) - - if new_video_ids: - video_details = self.get_video_details(new_video_ids) - print( - f" Retrieved details for {len(video_details)} " - f"new videos" - ) - - filtered_videos = self.filter_quality_videos( - video_details, topic_id=topic_id - ) - print( - f" After quality filtering: " - f"{len(filtered_videos)} videos" - ) - - for video in filtered_videos: - video['topic_id'] = topic_id - video['topic_name'] = topic_name - video['search_query'] = query - video['demographic_label'] = demographic_label - video['focus_areas'] = topic_info["focus"] - - all_videos.extend(filtered_videos) - new_videos_count += len(filtered_videos) - + 
all_videos.extend(filtered_videos) + new_df = pd.DataFrame(all_videos) - + if not new_df.empty: - if self.config['search_settings'].get( - 'remove_duplicates', True - ): - new_df = new_df.drop_duplicates( - subset=['video_id'], keep='first' - ) - - print(f"\nNew unique videos collected: {len(new_df)}") - - existing_df = self._load_existing_topic_data(topic_id) - - if existing_df is not None and not existing_df.empty: - combined_df = pd.concat( - [existing_df, new_df], ignore_index=True - ) - print(f"Total videos after merge: {len(combined_df)}") - else: - combined_df = new_df - print( - f"No existing data, starting fresh with " - f"{len(combined_df)} videos" - ) - - if 'quality_score' in combined_df.columns: - combined_df = combined_df.sort_values( - ['quality_score', 'view_count'], - ascending=[False, False] - ) - else: - combined_df = combined_df.sort_values( - 'view_count', ascending=False - ) - - print(f"Final dataset size: {len(combined_df)} videos") - return combined_df - else: - print(f"\nNo new videos found for {topic_name}") - existing_df = self._load_existing_topic_data(topic_id) - return ( - existing_df if existing_df is not None - else pd.DataFrame() - ) - + combined_df = self._merge_with_existing_data(new_df, topic_id) + return self._sort_and_finalize_dataset(combined_df) + + print(f"\nNo new videos found for {topic_name}") + existing_df = self._load_existing_topic_data(topic_id) + return existing_df if existing_df is not None else pd.DataFrame() + def _load_existing_video_ids(self, topic_id: int) -> set: """Load existing video IDs from JSON file for a topic. - + Args: topic_id: Integer ID of the topic. - - Returns: + + Returns + ------- Set of video IDs that already exist in the dataset. 
""" topic_name = TOPICS[topic_id]["name"] - safe_name = re.sub( - r'[^\w\s-]', '', topic_name - ).strip().replace(' ', '_') - - topic_dir = os.path.join( - VIDEOS_DIR, f"{topic_id:02d}_{safe_name}" - ) - json_path = os.path.join( - topic_dir, f"{safe_name}_metadata.json" - ) - + safe_name = re.sub(r"[^\w\s-]", "", topic_name).strip().replace(" ", "_") + + topic_dir = os.path.join(VIDEOS_DIR, f"{topic_id:02d}_{safe_name}") + json_path = os.path.join(topic_dir, f"{safe_name}_metadata.json") + if os.path.exists(json_path): try: - with open(json_path, 'r', encoding='utf-8') as f: + with open(json_path, "r", encoding="utf-8") as f: existing_data = json.load(f) - video_ids = { - video['video_id'] for video in existing_data - if 'video_id' in video + return { + video["video_id"] + for video in existing_data + if "video_id" in video } - return video_ids except Exception as e: print(f"Warning: Error loading existing video IDs: {e}") return set() - + return set() - - def _load_existing_topic_data( - self, topic_id: int - ) -> Optional[pd.DataFrame]: - """Load existing topic data from JSON file. - - Args: - topic_id: Integer ID of the topic. - - Returns: - DataFrame if file exists, None otherwise. - """ - topic_name = TOPICS[topic_id]["name"] - safe_name = re.sub( - r'[^\w\s-]', '', topic_name - ).strip().replace(' ', '_') - - topic_dir = os.path.join( - VIDEOS_DIR, f"{topic_id:02d}_{safe_name}" - ) - json_path = os.path.join( - topic_dir, f"{safe_name}_metadata.json" - ) - - if os.path.exists(json_path): - try: - df = pd.read_json(json_path, orient='records') - print( - f"Loaded {len(df)} existing videos " - f"from {json_path}" - ) - return df - except Exception as e: - print(f"Warning: Error loading existing data: {e}") - return None - - return None - + def save_topic_data(self, df: pd.DataFrame, topic_id: int): """Save topic data as both CSV and JSON. - + Args: df: DataFrame containing video metadata. topic_id: Integer ID of the topic. 
@@ -1086,78 +1081,57 @@ def save_topic_data(self, df: pd.DataFrame, topic_id: int): if df.empty: print("Warning: No data to save (empty DataFrame)") return - + topic_name = TOPICS[topic_id]["name"] - safe_name = re.sub( - r'[^\w\s-]', '', topic_name - ).strip().replace(' ', '_') - - topic_dir = os.path.join( - VIDEOS_DIR, f"{topic_id:02d}_{safe_name}" - ) + safe_name = re.sub(r"[^\w\s-]", "", topic_name).strip().replace(" ", "_") + + topic_dir = os.path.join(VIDEOS_DIR, f"{topic_id:02d}_{safe_name}") os.makedirs(topic_dir, exist_ok=True) - - csv_path = os.path.join( - topic_dir, f"{safe_name}_metadata.csv" - ) - df.to_csv(csv_path, index=False, encoding='utf-8') + + csv_path = os.path.join(topic_dir, f"{safe_name}_metadata.csv") + df.to_csv(csv_path, index=False, encoding="utf-8") print(f"Saved CSV: {csv_path}") - - json_path = os.path.join( - topic_dir, f"{safe_name}_metadata.json" - ) - df.to_json( - json_path, orient='records', indent=2, force_ascii=False - ) + + json_path = os.path.join(topic_dir, f"{safe_name}_metadata.json") + df.to_json(json_path, orient="records", indent=2, force_ascii=False) print(f"Saved JSON: {json_path}") - + summary = { - 'topic_id': topic_id, - 'topic_name': topic_name, - 'total_videos': len(df), - 'total_views': int(df['view_count'].sum()), - 'avg_views': int(df['view_count'].mean()), - 'median_duration_seconds': int( - df['duration_seconds'].median() - ), - 'videos_with_captions': int(df['has_captions'].sum()), - 'caption_percentage': ( + "topic_id": topic_id, + "topic_name": topic_name, + "total_videos": len(df), + "total_views": int(df["view_count"].sum()), + "avg_views": int(df["view_count"].mean()), + "median_duration_seconds": int(df["duration_seconds"].median()), + "videos_with_captions": int(df["has_captions"].sum()), + "caption_percentage": ( f"{(df['has_captions'].sum() / len(df) * 100):.1f}%" ), - 'demographic_distribution': ( - df['demographic_label'].value_counts().to_dict() + "demographic_distribution": ( + 
df["demographic_label"].value_counts().to_dict() ), - 'last_updated': datetime.now().isoformat() + "last_updated": datetime.now().isoformat(), } - - if 'quality_score' in df.columns: - summary['avg_quality_score'] = float( - df['quality_score'].mean() - ) - summary['quality_score_distribution'] = ( - df['quality_score'].value_counts() - .sort_index().to_dict() - ) - - if 'engagement_rate' in df.columns: - summary['avg_engagement_rate'] = float( - df['engagement_rate'].mean() + + if "quality_score" in df.columns: + summary["avg_quality_score"] = float(df["quality_score"].mean()) + summary["quality_score_distribution"] = ( + df["quality_score"].value_counts().sort_index().to_dict() ) - - summary_path = os.path.join( - topic_dir, f"{safe_name}_summary.json" - ) - with open(summary_path, 'w') as f: + + if "engagement_rate" in df.columns: + summary["avg_engagement_rate"] = float(df["engagement_rate"].mean()) + + summary_path = os.path.join(topic_dir, f"{safe_name}_summary.json") + with open(summary_path, "w") as f: json.dump(summary, f, indent=2) print(f"Saved summary: {summary_path}") - - print( - f"\nSuccessfully saved {len(df)} videos for {topic_name}" - ) + + print(f"\nSuccessfully saved {len(df)} videos for {topic_name}") def main(): - """Main execution function.""" + """Execute the YouTube metadata scraper.""" if CONFIG is None: print("\nERROR: Could not load configuration file.") print( @@ -1165,93 +1139,73 @@ def main(): "in the same directory as this script." ) return - - if not API_KEY: - print( - "ERROR: Please set your YouTube Data API key " - "in your config file" - ) + + if not API_KEY: + print("ERROR: Please set your YouTube Data API key in your config file") print("\nTo get an API key:") print("1. Go to https://console.cloud.google.com/") print("2. Create a new project or select existing one") print("3. Enable YouTube Data API v3") print("4. Create credentials (API key)") print( - "5. 
Update the 'api_key' field in config.yaml " - "or config.json with your key" + "5. Update the 'api_key' field in config.yaml or config.json with your key" ) return - + os.makedirs(VIDEOS_DIR, exist_ok=True) - + scraper = YouTubeMetadataScraper(API_KEY, CONFIG) - + all_topics_data = [] - + for topic_id in range(1, 13): try: df = scraper.scrape_topic(topic_id) - + if not df.empty: scraper.save_topic_data(df, topic_id) all_topics_data.append(df) else: - print( - f"Warning: No data collected for topic {topic_id}" - ) - + print(f"Warning: No data collected for topic {topic_id}") + time.sleep(2) - + except Exception as e: print(f"Error processing topic {topic_id}: {e}") continue - + if all_topics_data: combined_df = pd.concat(all_topics_data, ignore_index=True) - combined_path = os.path.join( - VIDEOS_DIR, "all_topics_combined.csv" - ) - + combined_path = os.path.join(VIDEOS_DIR, "all_topics_combined.csv") + if os.path.exists(combined_path): existing_df = pd.read_csv(combined_path) - combined_df = pd.concat( - [existing_df, combined_df], ignore_index=True - ) - combined_df = combined_df.drop_duplicates( - subset=['video_id'], keep='first' - ) - - combined_df.to_csv(combined_path, index=False, encoding='utf-8') - print(f"\n{'='*60}") + combined_df = pd.concat([existing_df, combined_df], ignore_index=True) + combined_df = combined_df.drop_duplicates(subset=["video_id"], keep="first") + + combined_df.to_csv(combined_path, index=False, encoding="utf-8") + print(f"\n{'=' * 60}") print(f"Combined dataset saved: {combined_path}") print(f"Total videos across all topics: {len(combined_df)}") - print(f"{'='*60}") - + print(f"{'=' * 60}") + overall_summary = { - 'total_videos': len(combined_df), - 'total_topics': 13, - 'videos_per_topic': ( - combined_df.groupby('topic_name').size().to_dict() - ), - 'total_views': int(combined_df['view_count'].sum()), - 'videos_with_captions': int( - combined_df['has_captions'].sum() - ), - 'caption_percentage': ( + "total_videos": 
len(combined_df), + "total_topics": 13, + "videos_per_topic": (combined_df.groupby("topic_name").size().to_dict()), + "total_views": int(combined_df["view_count"].sum()), + "videos_with_captions": int(combined_df["has_captions"].sum()), + "caption_percentage": ( f"{(combined_df['has_captions'].sum() / len(combined_df) * 100):.1f}%" ), - 'avg_duration_seconds': int( - combined_df['duration_seconds'].mean() - ) + "avg_duration_seconds": int(combined_df["duration_seconds"].mean()), } - - summary_path = os.path.join( - VIDEOS_DIR, "overall_summary.json" - ) - with open(summary_path, 'w') as f: + + summary_path = os.path.join(VIDEOS_DIR, "overall_summary.json") + with open(summary_path, "w") as f: json.dump(overall_summary, f, indent=2) print(f"Overall summary saved: {summary_path}") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/sonic-o1/02_caption_generation/README.md b/sonic-o1/02_caption_generation/README.md index 03716a0..fced83c 100644 --- a/sonic-o1/02_caption_generation/README.md +++ b/sonic-o1/02_caption_generation/README.md @@ -1,7 +1,43 @@ # Caption Generation with WhisperX +## Overview + This directory handles automatic caption generation for videos that don't have captions available from YouTube. It uses WhisperX for high-quality transcription with word-level timestamps. 
+### Output Format + +The script generates: +- **SRT files**: `caption_XXX.srt` - YouTube-style captions +- **JSON files**: `caption_XXX.json` - Full transcription details with word-level timestamps + +#### SRT Format Example +``` +1 +00:00:04,720 --> 00:00:10,720 +Hello folks I'm delighted today to be joined by +Dr John Mckeown head of GP teaching and Dr Naomi + +2 +00:00:10,720 --> 00:00:15,720 +Dow who is a GP and Senior clinical lecturer both +from the University of Aberdeen +``` + +## GPU Requirements + +**Minimum Requirements:** +- NVIDIA GPU with CUDA support (compute capability 6.0+) +- 8GB VRAM minimum (for base/small models) +- 16GB+ VRAM recommended (for large-v2/large-v3 models) +- CUDA 12.1+ toolkit + +**Recommended Setup:** +- NVIDIA A40/A100 or equivalent +- 32GB+ system RAM +- CUDA 12.1+ with cuDNN support + +**Note**: CPU-only processing is possible but significantly slower (5-15x) and not recommended for production use. + ## Prerequisites Before running this step, you must have completed the data curation step (see [01_data_curation](../01_data_curation/)): @@ -32,11 +68,6 @@ source ~/.bashrc ### 2. 
Set cache directories to scratch (avoid disk quota issues) ```bash -# Set all cache directories to scratch -export UV_CACHE_DIR=~/scratch/.uv_cache -export HF_HOME=~/scratch/.huggingface -export TORCH_HOME=~/scratch/.torch -export NLTK_DATA=~/scratch/nltk_data # Create directories mkdir -p ~/scratch/.uv_cache ~/scratch/.huggingface ~/scratch/.torch ~/scratch/nltk_data @@ -69,10 +100,11 @@ uv pip install faster-whisper pyannote-audio ctranslate2 onnxruntime nltk uv pip install nvidia-cudnn-cu12 # Set cuDNN library path + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$(python -c "import nvidia.cudnn; print(nvidia.cudnn.__path__[0])")/lib echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$(python -c "import nvidia.cudnn 2>/dev/null && nvidia.cudnn.__path__[0]" 2>/dev/null)/lib' >> ~/.bashrc -# Download NLTK data +## NLTK DATA DOWNLOAD python << 'NLTK_EOF' import nltk import os @@ -162,6 +194,16 @@ python whisper_captionGen.py python whisper_captionGen.py --config my_config.yaml ``` +## Expected Processing Time + +- **GPU (NVIDIA A40)**: + - ~0.5-2 minutes per video (with large-v2) + - ~0.1-0.5 minutes per video (with base) + +- **CPU**: + - ~5-15 minutes per video (with base) + - Not recommended for large models + ## Model Size Comparison | Model | Parameters | Speed | Accuracy | Use Case | @@ -175,35 +217,6 @@ python whisper_captionGen.py --config my_config.yaml *Speed is relative to large-v2 on GPU* -## Output Format - -The script generates: -- **SRT files**: `caption_XXX.srt` - YouTube-style captions -- **JSON files**: `caption_XXX.json` - Full transcription details with word-level timestamps - -### SRT Format Example -``` -1 -00:00:04,720 --> 00:00:10,720 -Hello folks I'm delighted today to be joined by -Dr John Mckeown head of GP teaching and Dr Naomi - -2 -00:00:10,720 --> 00:00:15,720 -Dow who is a GP and Senior clinical lecturer both -from the University of Aberdeen -``` - -## Expected Processing Time - -- **GPU (NVIDIA A40)**: - - ~0.5-2 minutes per video (with 
large-v2) - - ~0.1-0.5 minutes per video (with base) - -- **CPU**: - - ~5-15 minutes per video (with base) - - Not recommended for large models - ## Troubleshooting ### 1. Disk Quota Exceeded @@ -278,35 +291,3 @@ nvidia-smi # Verify PyTorch sees GPU python -c "import torch; print(torch.cuda.is_available())" ``` - -## Quality Verification - -After processing, verify the generated captions: -```bash -# View generated caption -cat dataset/captions/01_Patient-Doctor_Consultations/caption_001.srt - -# Check how many captions were generated -ls dataset/captions/01_Patient-Doctor_Consultations/caption_*.srt | wc -l -``` - -The script automatically skips videos that already have captions (controlled by `skip_existing` in config). - -## Environment Variables Summary - -Add these to your `~/.bashrc` for permanent setup: -```bash -# FFmpeg -export PKG_CONFIG_PATH=../.local/lib/pkgconfig:$PKG_CONFIG_PATH -export LD_LIBRARY_PATH=../.local/lib:$LD_LIBRARY_PATH - -# Cache directories (avoid disk quota) -export UV_CACHE_DIR=~/scratch/.uv_cache -export HF_HOME=~/scratch/.huggingface -export TORCH_HOME=~/scratch/.torch -export NLTK_DATA=~/scratch/nltk_data -export TMPDIR=~/scratch - -# cuDNN -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$(python -c "import nvidia.cudnn 2>/dev/null && print(nvidia.cudnn.__path__[0])" 2>/dev/null)/lib -``` diff --git a/sonic-o1/02_caption_generation/config_whisper.yaml b/sonic-o1/02_caption_generation/config_whisper.yaml index f0d8903..16d3026 100644 --- a/sonic-o1/02_caption_generation/config_whisper.yaml +++ b/sonic-o1/02_caption_generation/config_whisper.yaml @@ -26,4 +26,4 @@ output: # Processing options processing: skip_existing: true # Skip videos that already have captions - verbose: true # Print detailed progress \ No newline at end of file + verbose: true # Print detailed progress diff --git a/sonic-o1/02_caption_generation/whisper_captionGen.py b/sonic-o1/02_caption_generation/whisper_captionGen.py index ec4c66c..3c282bf 100644 --- 
a/sonic-o1/02_caption_generation/whisper_captionGen.py +++ b/sonic-o1/02_caption_generation/whisper_captionGen.py @@ -1,25 +1,33 @@ #!/usr/bin/env python3 +"""whisper_captionGen.py. + +Process audio files that need captions using WhisperX. + +Generates YouTube-style SRT captions for videos without existing captions. + +Author: SONIC-O1 Team """ -Process audio files that need captions using WhisperX -Generates YouTube-style SRT captions for videos without existing captions -""" -import whisperx + +import argparse import gc -from pathlib import Path import json -import yaml -from typing import List, Dict, Optional +import traceback +from pathlib import Path +from typing import Dict, List + import torch +import whisperx +import yaml def load_config(config_path: str = "config_whisper.yaml") -> Dict: - """Load configuration from YAML file""" - with open(config_path, 'r') as f: + """Load configuration from YAML file.""" + with open(config_path, "r") as f: return yaml.safe_load(f) def format_timestamp(seconds: float) -> str: - """Convert seconds to SRT timestamp format (HH:MM:SS,mmm)""" + """Convert seconds to SRT timestamp format (HH:MM:SS,mmm).""" hours = int(seconds // 3600) minutes = int((seconds % 3600) // 60) secs = int(seconds % 60) @@ -29,209 +37,200 @@ def format_timestamp(seconds: float) -> str: def segments_to_srt(segments: List[Dict], max_chars_per_line: int = 42) -> str: """ - Convert WhisperX segments to SRT format - Similar to YouTube caption style with line breaking + Convert WhisperX segments to SRT format. + + Similar to YouTube caption style with line breaking. 
""" srt_lines = [] - + for i, segment in enumerate(segments, 1): - start = segment['start'] - end = segment['end'] - text = segment['text'].strip() - + start = segment["start"] + end = segment["end"] + text = segment["text"].strip() + # Format timestamps start_time = format_timestamp(start) end_time = format_timestamp(end) - + # Break long lines (YouTube typically uses ~42 chars per line) if len(text) > max_chars_per_line: words = text.split() lines = [] current_line = [] current_length = 0 - + for word in words: word_length = len(word) + 1 # +1 for space if current_length + word_length > max_chars_per_line and current_line: - lines.append(' '.join(current_line)) + lines.append(" ".join(current_line)) current_line = [word] current_length = word_length else: current_line.append(word) current_length += word_length - + if current_line: - lines.append(' '.join(current_line)) - - text = '\n'.join(lines) - + lines.append(" ".join(current_line)) + + text = "\n".join(lines) + # Add SRT entry srt_lines.append(f"{i}") srt_lines.append(f"{start_time} --> {end_time}") srt_lines.append(text) srt_lines.append("") # Blank line between entries - - return '\n'.join(srt_lines) + return "\n".join(srt_lines) -def transcribe_audio( - audio_path: str, - config: Dict -) -> Dict: + +def transcribe_audio(audio_path: str, config: Dict) -> Dict: """ - Transcribe audio file using WhisperX with alignment - + Transcribe audio file using WhisperX with alignment. + Args: audio_path: Path to audio file (.m4a) config: Configuration dictionary - - Returns: + + Returns + ------- Dictionary with aligned segments """ - model_cfg = config['model'] - device = model_cfg['device'] - language = model_cfg['language'] - + model_cfg = config["model"] + device = model_cfg["device"] + language = model_cfg["language"] + print(f"Loading audio: {audio_path}") audio = whisperx.load_audio(audio_path) - + # 1. 
Transcribe with Whisper print(f"Loading Whisper model: {model_cfg['name']}") model = whisperx.load_model( - model_cfg['name'], + model_cfg["name"], device, - compute_type=model_cfg['compute_type'], - language=language # Add language here at model load time + compute_type=model_cfg["compute_type"], + language=language, # Add language here at model load time ) - + print(f"Transcribing (language: {language})...") - + # Transcribe - language already set in model - result = model.transcribe(audio, batch_size=model_cfg['batch_size']) - + result = model.transcribe(audio, batch_size=model_cfg["batch_size"]) + # Delete model to free memory del model gc.collect() if device == "cuda": torch.cuda.empty_cache() - + # 2. Align whisper output for better timestamps print(f"Aligning transcription for language: {language}") - model_a, metadata = whisperx.load_align_model( - language_code=language, - device=device - ) - + model_a, metadata = whisperx.load_align_model(language_code=language, device=device) + result = whisperx.align( result["segments"], model_a, metadata, audio, device, - return_char_alignments=False + return_char_alignments=False, ) - + # Delete alignment model del model_a gc.collect() if device == "cuda": torch.cuda.empty_cache() - + return result -def process_topic( - topic_path: Path, - config: Dict -): + +def process_topic(topic_path: Path, config: Dict): """ - Process all videos in a topic that need captions - + Process all videos in a topic that need captions. 
+ Args: topic_path: Path to topic directory (captions/TOPIC_NAME) config: Configuration dictionary """ needs_whisper_file = topic_path / "needs_whisper.txt" - + if not needs_whisper_file.exists(): - if config['processing']['verbose']: + if config["processing"]["verbose"]: print(f"No needs_whisper.txt found in {topic_path.name}") return - + # Read audio files that need captions - with open(needs_whisper_file, 'r') as f: + with open(needs_whisper_file, "r") as f: audio_files = [line.strip() for line in f if line.strip()] - + if not audio_files: - if config['processing']['verbose']: + if config["processing"]["verbose"]: print(f"No audio files need captions in {topic_path.name}") return - - print(f"\n{'='*60}") + + print(f"\n{'=' * 60}") print(f"Processing topic: {topic_path.name}") print(f"Audio files to process: {len(audio_files)}") - print(f"{'='*60}\n") - + print(f"{'=' * 60}\n") + # Update paths to match your structure dataset_root = topic_path.parent.parent # Go up from captions/TOPIC to dataset/ audios_dir = dataset_root / "audios" / topic_path.name captions_dir = topic_path # Already in captions/TOPIC - + for audio_filename in audio_files: # Extract video ID from audio filename (e.g., audio_015.m4a -> 015) - video_id = audio_filename.replace('audio_', '').replace('.m4a', '') - + video_id = audio_filename.replace("audio_", "").replace(".m4a", "") + audio_file = audios_dir / audio_filename caption_file = captions_dir / f"caption_{video_id}.srt" - + if not audio_file.exists(): print(f"[WARNING] Audio file not found: {audio_file}") continue - - if caption_file.exists() and config['processing']['skip_existing']: - if config['processing']['verbose']: + + if caption_file.exists() and config["processing"]["skip_existing"]: + if config["processing"]["verbose"]: print(f"[SKIP] Caption already exists: {caption_file}") continue - + print(f"\n[PROCESSING] Video: {video_id}") - if config['processing']['verbose']: + if config["processing"]["verbose"]: print(f" Audio: 
{audio_file}") - + try: # Transcribe result = transcribe_audio(str(audio_file), config) - + # Convert to SRT srt_content = segments_to_srt( result["segments"], - max_chars_per_line=config['output']['max_chars_per_line'] + max_chars_per_line=config["output"]["max_chars_per_line"], ) - + # Save SRT file - with open(caption_file, 'w', encoding='utf-8') as f: + with open(caption_file, "w", encoding="utf-8") as f: f.write(srt_content) - + print(f"[SUCCESS] Caption saved: {caption_file}") - + # Optionally save JSON with full details - if config['output']['save_json']: + if config["output"]["save_json"]: json_file = captions_dir / f"caption_{video_id}.json" - with open(json_file, 'w', encoding='utf-8') as f: + with open(json_file, "w", encoding="utf-8") as f: json.dump(result, f, indent=2, ensure_ascii=False) - - if config['processing']['verbose']: + + if config["processing"]["verbose"]: print(f" JSON saved: {json_file}") - + except Exception as e: print(f"[ERROR] Failed to process {video_id}: {e}") - if config['processing']['verbose']: - import traceback + if config["processing"]["verbose"]: traceback.print_exc() def main(): - """Main processing function""" - import argparse - + """Run main processing function.""" parser = argparse.ArgumentParser( description="Generate captions for videos using WhisperX" ) @@ -239,23 +238,23 @@ def main(): "--config", type=str, default="config_whisper.yaml", - help="Path to configuration file" + help="Path to configuration file", ) - + args = parser.parse_args() - + # Load configuration print(f"Loading configuration from: {args.config}") config = load_config(args.config) - - dataset_root = Path(config['dataset']['root']) - + + dataset_root = Path(config["dataset"]["root"]) + if not dataset_root.exists(): print(f"[ERROR] Dataset root not found: {dataset_root}") return - + # Get topics to process from captions directory - topics = config['dataset']['topics'] + topics = config["dataset"]["topics"] if topics: topic_dirs = [dataset_root / 
"captions" / topic for topic in topics] topic_dirs = [t for t in topic_dirs if t.exists()] @@ -263,31 +262,30 @@ def main(): # Process all topics captions_dir = dataset_root / "captions" topic_dirs = sorted([d for d in captions_dir.iterdir() if d.is_dir()]) - + if not topic_dirs: print("[ERROR] No topics found to process") return - - print(f"\nStarting WhisperX caption generation") + + print("\nStarting WhisperX caption generation") print(f" Device: {config['model']['device']}") print(f" Model: {config['model']['name']}") print(f" Language: {config['model']['language'] or 'auto-detect'}") print(f" Topics: {len(topic_dirs)}") - + # Process each topic for topic_dir in topic_dirs: try: process_topic(topic_dir, config) except Exception as e: print(f"[ERROR] Failed to process topic {topic_dir.name}: {e}") - if config['processing']['verbose']: - import traceback + if config["processing"]["verbose"]: traceback.print_exc() - - print(f"\n{'='*60}") + + print(f"\n{'=' * 60}") print("Processing complete!") - print(f"{'='*60}") + print(f"{'=' * 60}") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/sonic-o1/03_demographics_annotation/README.md b/sonic-o1/03_demographics_annotation/README.md index 2ccfe4f..3e5ebd8 100644 --- a/sonic-o1/03_demographics_annotation/README.md +++ b/sonic-o1/03_demographics_annotation/README.md @@ -1,12 +1,81 @@ # Demographics Annotation with Gemini +## Overview + This directory handles automatic demographics annotation for videos using Google's Gemini multimodal model. It analyzes videos, audio, and captions to extract demographic information (race, gender, age, language) of people appearing in the videos. 
+### Directory Structure + +``` +03_demographics_annotation/ +├── run_annotation.py # Main annotation pipeline script +├── config_loader.py # Configuration loader (YAML + .env) +├── config.yaml # Configuration file +├── model.py # Gemini API wrapper +├── prompts.py # Prompt templates +├── README.md # This file +│ +└── dataset/ # Output directory (from 01_data_curation) + └── videos// + ├── video_001.mp4 + ├── metadata.json + └── metadata_enhanced.json # Generated output +``` + +### Pipeline Workflow + +``` +┌───────────────────────────────────────────────────────────────────┐ +│ Step 1: Demographics Annotation │ +│ ────────────────────────────────────────────────────────────────── │ +│ Input: dataset/videos// │ +│ dataset/audios// │ +│ dataset/captions// │ +│ Output: dataset/videos//metadata_enhanced.json │ +│ ├── demographics_detailed: {race, gender, age, language} │ +│ ├── demographics_confidence: confidence scores │ +│ └── demographics_annotation: metadata │ +└───────────────────────────────────────────────────────────────────┘ + ↓ +┌───────────────────────────────────────────────────────────────────┐ +│ Step 2: Next Steps │ +│
────────────────────────────────────────────────────────────────── │ +│ • Check metadata_enhanced.json for demographic annotations │ +│ • Use --retry-failed to reprocess videos with empty demographics │ +│ • Proceed to 04_* directory for VQA generation │ +└───────────────────────────────────────────────────────────────────┘ +``` + +### Features + +- **Multimodal Analysis**: Combines video, audio, and transcript data +- **Automatic Segmentation**: Handles long videos (>25 min) by splitting +- **Checkpoint/Resume**: Saves progress every N videos +- **Retry Failed**: Reprocess videos with empty demographics +- **Rate Limiting**: Configurable delays to avoid API limits +- **Error Handling**: Retries, validation, and graceful degradation + +### Quality Control + +The pipeline includes built-in quality assurance mechanisms: + +- **Validation**: Ensures demographics match allowed categories from config +- **Confidence Filtering**: Filters low-confidence annotations based on `min_confidence` threshold +- **Retry Logic**: Automatically retries failed API calls (up to `retry_attempts`) +- **Checkpointing**: Saves progress periodically to prevent data loss on interruption +- **Comprehensive Logging**: Detailed logs for debugging and quality monitoring +- **Error Recovery**: Graceful handling of API errors, timeouts, and invalid responses + ## Prerequisites Before running this step, you must have completed: -1. **Data Curation** (see [01_data_curation](../01_data_curation/)) - Downloaded videos and audio -2. **Caption Generation** (see [02_caption_generation](../02_caption_generation/)) - Generated captions for all videos + +1.
**Data Curation** (see [01_data_curation](../01_data_curation/)) + - Downloaded videos and audio files + - Generated metadata.json files + +2. **Caption Generation** (see [02_caption_generation](../02_caption_generation/)) + - Generated captions for all videos (SRT format) Your `dataset/` directory should have this structure: ``` @@ -20,31 +89,29 @@ dataset/ └── caption_001.srt ``` -## Required Packages +3. **API Setup** + 1. **Get Gemini API Key** + - Go to [Google AI Studio](https://makersuite.google.com/app/apikey) + - Create or select a project + - Generate an API key -All required Python packages are already included in the project's [requirements_venv.txt](../../requirements_venv.txt): -- `google-generativeai` - Gemini API -- `python-dotenv` - Environment variable management -- `pyyaml` - Configuration file parsing -- `tqdm` - Progress bars + 2. **Set API Key** + Create a `.env` file in this directory: + ```bash + GEMINI_API_KEY=your_gemini_api_key_here + ``` -## API Setup + Or export it as an environment variable: + ```bash + export GEMINI_API_KEY=your_gemini_api_key_here + ``` -### Get Gemini API Key -1. Go to [Google AI Studio](https://makersuite.google.com/app/apikey) -2. Create or select a project -3.
Generate an API key +## Installation -### Set API Key -Create a `.env` file in this directory: -```bash -GEMINI_API_KEY=your_gemini_api_key_here -``` +### Required Packages -Or export it as an environment variable: -```bash -export GEMINI_API_KEY=your_gemini_api_key_here -``` +All required Python packages are already included in the project's +[requirements_venv.txt](../../requirements_venv.txt): ## Configuration @@ -63,7 +130,7 @@ model: ### Dataset Settings ```yaml dataset: - base_path: "dataset" # Path to dataset directory + base_path: "../dataset" # Path to dataset directory topics: # Topics to process (or leave empty for all) - "01_Patient-Doctor_Consultations" - "02_Job_Interviews" @@ -76,6 +143,11 @@ processing: save_interval: 10 # Save checkpoint every N videos max_video_duration: 1500 # Max duration before segmentation (25 min) enable_segmentation: true # Auto-segment long videos + prefer_video_with_audio: false # Send both video AND audio + + # Output settings + save_raw_responses: true # Save raw API responses for debugging + create_backup: true # Create backup before overwriting metadata ``` ### Rate Limiting @@ -83,12 +155,13 @@ processing: rate_limit: delay_between_videos: 15 # Seconds between videos delay_after_long_video: 60 # Extra delay after long videos - long_video_threshold: 1800 # Threshold for "long" video (30 min) + long_video_threshold: 1800 # Threshold for "long" video (30 min) ``` ## Usage -**IMPORTANT**: Always run the annotation script from the project root (sonic-o1/sonic-o1 directory) so relative paths work correctly. +**IMPORTANT**: Always run the annotation script from the project root +(sonic-o1/sonic-o1 directory) so relative paths work correctly. 
### Process All Topics @@ -115,18 +188,17 @@ Then run: python 03_demographics_annotation/run_annotation.py ``` -### Test Single Video - -To test on a single video before processing everything: +### Process Single Topic ```bash -# Edit test_single_video.py to set topic and video number -# Lines 25-26: -# topic = "01_Patient-Doctor_Consultations" -# video_number = "015" +python 03_demographics_annotation/run_annotation.py --topic "01_Patient-Doctor_Consultations" +``` -# Run test from project root -python 03_demographics_annotation/test_single_video.py +### Retry Failed Videos + +Reprocess videos with empty demographics: +```bash +python 03_demographics_annotation/run_annotation.py --retry-failed ``` ### Use Custom Configuration @@ -135,9 +207,39 @@ python 03_demographics_annotation/test_single_video.py python 03_demographics_annotation/run_annotation.py --config path/to/custom_config.yaml ``` +### Command-Line Arguments + +The script supports several command-line arguments: + +| Argument | Description | Example | +|----------|-------------|---------| +| `--config` | Path to configuration file (default: `config.yaml`) | `--config my_config.yaml` | +| `--topic` | Process specific topic only | `--topic "01_Patient-Doctor_Consultations"` | +| `--api-key` | Override Gemini API key from config/env | `--api-key "your_key_here"` | +| `--no-cache` | Reprocess all videos even if already done | `--no-cache` | +| `--retry-failed` | Only reprocess videos with empty demographics | `--retry-failed` | + +**Examples:** + +```bash +# Process single topic with custom config +python 03_demographics_annotation/run_annotation.py \ + --topic "01_Patient-Doctor_Consultations" \ + --config custom_config.yaml + +# Retry failed videos with API key override +python 03_demographics_annotation/run_annotation.py \ + --retry-failed \ + --api-key "new_api_key" + +# Reprocess all videos (ignore checkpoints) +python 03_demographics_annotation/run_annotation.py --no-cache +``` + ## Output -The script 
creates `metadata_enhanced.json` files in each topic directory with demographic annotations: +The script creates `metadata_enhanced.json` files in each topic directory +with demographic annotations: ```json { @@ -149,8 +251,19 @@ The script creates `metadata_enhanced.json` files in each topic directory with d "age": ["Middle (25-39)"], "language": ["English"] }, - "raw_response": "...", - "processing_timestamp": "2024-01-14T12:00:00" + "demographics_confidence": { + "race": {"Asian": 0.9, "White": 0.85}, + "gender": {"Male": 0.95, "Female": 0.9}, + "age": {"Middle (25-39)": 0.8}, + "language": {"English": 1.0} + }, + "demographics_annotation": { + "model": "gemini-2.5-flash", + "annotated_at": "2024-01-14 12:00:00", + "individuals_count": 2, + "modalities_used": ["video", "audio", "transcript"], + "explanation": "Video shows 2 individuals having a conversation..." + } } ``` @@ -159,93 +272,37 @@ The script creates `metadata_enhanced.json` files in each topic directory with d dataset/ โ””โ”€โ”€ videos// โ”œโ”€โ”€ metadata.json # Original metadata - โ”œโ”€โ”€ metadata_enhanced.json # With demographics annotations - โ””โ”€โ”€ metadata_enhanced_checkpoint.json # Checkpoint for resume + โ”œโ”€โ”€ metadata_enhanced.json # With demographics + โ”œโ”€โ”€ metadata_enhanced_checkpoint.json # Checkpoint (auto-deleted) + โ””โ”€โ”€ raw_responses/ # Raw API responses (optional) + โ””โ”€โ”€ video_001_response.json ``` ## Checkpoint and Resume -The script automatically saves checkpoints every N videos (configured by `save_interval`). If the script is interrupted: +The script automatically saves checkpoints every N videos (configured by +`save_interval`). If the script is interrupted: -1. **Resume automatically** - Just run the script again, it will detect the checkpoint and continue where it left off -2. **Start fresh** - Delete the `metadata_enhanced_checkpoint.json` file in the topic directory +1. 
**Resume automatically** - Just run the script again, it will detect + the checkpoint and continue where it left off -```bash -# To start fresh on a specific topic -rm dataset/videos/01_Patient-Doctor_Consultations/metadata_enhanced_checkpoint.json -``` +2. **Start fresh** - To restart processing from the beginning: + ```bash + # Delete checkpoint file for a specific topic + rm dataset/videos/01_Patient-Doctor_Consultations/metadata_enhanced_checkpoint.json -## Examples + # Or delete all checkpoints + find dataset/videos -name "metadata_enhanced_checkpoint.json" -delete -### Example 1: Process All Topics + # Then run the script normally + python 03_demographics_annotation/run_annotation.py + ``` -```bash -# 1. Navigate to working directory -cd /path/to/sonic-o1/sonic-o1 + **Note**: Starting fresh will reprocess all videos in the topic. If you want + to keep existing annotations and only process new videos, don't delete the + checkpoint - the script will automatically skip already-processed videos. -# 2. Set API key (if not in .env) -export GEMINI_API_KEY=your_key_here -# 3. 
Run annotation -python 03_demographics_annotation/run_annotation.py -``` - -### Example 2: Process Only New Topics - -Edit [config.yaml](config.yaml): -```yaml -dataset: - topics: - - "11_Mental_Health_Counseling" - - "12_Community_Town_Halls" -``` - -Run: -```bash -python 03_demographics_annotation/run_annotation.py -``` - -### Example 3: Test Single Video First - -```bash -# Edit test_single_video.py to set your test video -# topic = "02_Job_Interviews" -# video_number = "005" - -# Run test -python 03_demographics_annotation/test_single_video.py -``` - -Expected output: -``` -================================================================================ -MULTIMODAL DEMOGRAPHICS ANNOTATION TEST -================================================================================ -Topic: 02_Job_Interviews -Item Number: 005 -Model: gemini-2.5-flash -================================================================================ - -File Paths: - Video: dataset/videos/02_Job_Interviews/video_005.mp4 - Audio: dataset/audios/02_Job_Interviews/audio_005.m4a - Caption: dataset/captions/02_Job_Interviews/caption_005.srt - -Processing... -[Success] Demographics extracted -``` - -### Example 4: Resume After Interruption - -```bash -# If script was interrupted, just run again -python 03_demographics_annotation/run_annotation.py - -# Output will show: -# "Found checkpoint file: metadata_enhanced_checkpoint.json" -# "Loaded 15 processed videos from checkpoint" -# "Resuming from video 16/25" -``` ## Processing Time @@ -298,7 +355,12 @@ rate_limit: **Problem**: Some videos have empty `demographics_detailed` -**Solution**: The script automatically retries failed videos. 
Check the log file: +**Solution**: Use the retry flag to reprocess failed videos: +```bash +python 03_demographics_annotation/run_annotation.py --retry-failed +``` + +Check the log file: ```bash cat 03_demographics_annotation/demographics_annotation.log ``` @@ -325,41 +387,55 @@ model: timeout: 120 # Increase from 60 seconds ``` -## Quality Control - -The script includes built-in quality checks: - -1. **Validation**: Ensures demographics match allowed categories -2. **Retry Logic**: Automatically retries failed videos -3. **Checkpointing**: Saves progress to prevent data loss -4. **Logging**: Detailed logs for debugging - ### Check Annotation Quality +You can verify the quality and completeness of annotations using these commands: + +**Count videos with successful annotations:** ```bash -# Count videos with demographics +# Count how many videos have a non-null demographics_detailed entry jq '[.[] | select(.demographics_detailed != null)] | length' \ dataset/videos/01_Patient-Doctor_Consultations/metadata_enhanced.json +``` +This counts videos where `demographics_detailed` exists and is not null. A video +is considered successfully annotated if it has at least one demographic category +(race, gender, age, or language) with non-empty values, so entries whose categories are all empty lists are counted here too; use the empty-demographics check below to find those.
-# View specific annotation +**View a specific video's annotation:** +```bash +# View full annotation for video_001 jq '.[] | select(.video_number == "001")' \ dataset/videos/01_Patient-Doctor_Consultations/metadata_enhanced.json ``` +This displays the complete annotation for a specific video, including: +- `demographics_detailed`: Lists of detected demographics per category +- `demographics_confidence`: Confidence scores (0.0-1.0) for each detection +- `demographics_annotation`: Metadata including model used, timestamp, individual count, and explanation + +**Check for videos with empty demographics:** +```bash +# Count videos that failed annotation (all demographic categories empty) +jq '[.[] | select(.demographics_detailed.race == [] and .demographics_detailed.gender == [] and .demographics_detailed.age == [] and .demographics_detailed.language == [])] | length' \ + dataset/videos/01_Patient-Doctor_Consultations/metadata_enhanced.json +``` +This reports how many videos need reprocessing with `--retry-failed`; drop the trailing `| length` to list the matching entries instead.
-## File Descriptions +## Files -- [run_annotation.py](run_annotation.py) - Main annotation pipeline script -- [test_single_video.py](test_single_video.py) - Test script for single video -- [config.yaml](config.yaml) - Configuration file -- [config_loader.py](config_loader.py) - Configuration loader -- [model.py](model.py) - Gemini API wrapper and demographics extraction -- [prompts.py](prompts.py) - System and user prompts for the model -- [.env](.env) - Environment variables (API keys) +- `run_annotation.py` - Main annotation pipeline script +- `config.yaml` - Configuration file +- `config_loader.py` - Configuration loader with .env support +- `model.py` - Gemini API wrapper and demographics extraction +- `prompts.py` - System and user prompts for the model +- `.env` - Environment variables (API keys) - create this file ## Notes - The script processes videos in order by video number - Already processed videos are skipped automatically -- Raw model responses are saved for debugging if `save_raw_responses: true` +- Raw model responses are saved for debugging if + `save_raw_responses: true` - Backups are created before overwriting if `create_backup: true` -- All paths are relative to the project root, so always run from the sonic-o1/sonic-o1 directory (the inner sonic-o1 directory that contains the pipeline code) +- All paths are relative to the project root, so always run from the + sonic-o1/sonic-o1 directory (the inner sonic-o1 directory that + contains the pipeline code) diff --git a/sonic-o1/03_demographics_annotation/config.yaml b/sonic-o1/03_demographics_annotation/config.yaml index f2b7000..1f8cc43 100644 --- a/sonic-o1/03_demographics_annotation/config.yaml +++ b/sonic-o1/03_demographics_annotation/config.yaml @@ -1,7 +1,7 @@ # Demographics Annotation Configuration model: - name: "gemini-2.5-flash" + name: "gemini-2.5-flash" api_key: ${GEMINI_API_KEY} # Will read from environment variable temperature: 0.3 max_output_tokens: 1024 @@ -27,6 +27,13 @@ dataset: 
- "13_Olympics" demographics: + # Category keys used in demographics_detailed and demographics_confidence + categories: + - "race" + - "gender" + - "age" + - "language" + races: - "White" - "Black" @@ -34,16 +41,16 @@ demographics: - "Indigenous" - "Arab" - "Hispanic" - + genders: - "Male" - "Female" - + age_groups: - "Young (18-24)" - "Middle (25-39)" - "Older adults (40+)" - + languages: - "English" - "Hindi" @@ -55,27 +62,36 @@ processing: batch_size: 5 save_interval: 10 use_cache: true - + # Video processing limits file_processing_timeout: 7200 # Max time (seconds) to wait for Google File API processing (2 hours) max_video_duration: 1500 # Max video duration (seconds) before auto-segmentation (55 minutes) max_transcript_length: 25000 # Max transcript characters (to avoid token overflow) - + # Segmentation settings (for videos longer than max_video_duration) segment_overlap: 60 # Overlap between segments in seconds (1 minute) enable_segmentation: true # Enable automatic segmentation for long videos - + log_level: "INFO" output_format: "metadata_enhanced.json" - + # File patterns video_pattern: "video_{number}.mp4" audio_pattern: "audio_{number}.m4a" caption_pattern: "caption_{number}.srt" - + + # Supported video file extensions (for determining if file is a video) + video_extensions: + - ".mp4" + - ".avi" + - ".mov" + - ".webm" + - ".mkv" + - ".m4v" + # Multimodal processing prefer_video_with_audio: false # Send both video AND separate audio (all files together) - + # Output settings save_raw_responses: true create_backup: true @@ -95,4 +111,4 @@ quality: rate_limit: delay_between_videos: 15 # seconds to wait between processing videos delay_after_long_video: 60 # extra seconds to wait after videos > 30 minutes - long_video_threshold: 1800 # seconds (30 minutes) \ No newline at end of file + long_video_threshold: 1800 # seconds (30 minutes) diff --git a/sonic-o1/03_demographics_annotation/config_loader.py b/sonic-o1/03_demographics_annotation/config_loader.py 
index 92f26b9..d2daec7 100644 --- a/sonic-o1/03_demographics_annotation/config_loader.py +++ b/sonic-o1/03_demographics_annotation/config_loader.py @@ -1,188 +1,219 @@ +"""config_loader.py. + +Configuration loader for YAML-based config with .env support. + +Author: SONIC-O1 Team """ -Configuration loader for YAML-based config with .env support -""" -import yaml + +import logging import os -from pathlib import Path -from typing import Dict, Any, List from dataclasses import dataclass -import logging +from pathlib import Path +from typing import Any, Dict + +import yaml from dotenv import load_dotenv + # Load .env file from the same directory as this script -load_dotenv(Path(__file__).parent / '.env') +load_dotenv(Path(__file__).parent / ".env") logger = logging.getLogger(__name__) + class ConfigLoader: - """Load and validate configuration from YAML file with environment variable support""" - + """Load and validate configuration from YAML file. + + Supports environment variable substitution. + """ + def __init__(self, config_path: str = "config.yaml"): + """Initialize configuration loader. + + Args: + config_path: Path to YAML configuration file. 
+ """ self.config_path = Path(config_path) if not self.config_path.is_absolute(): self.config_path = Path(__file__).parent / config_path self.config = self.load_config() self._resolve_environment_variables() self._resolve_paths() - + def load_config(self) -> Dict[str, Any]: - """Load YAML configuration file""" + """Load YAML configuration file.""" if not self.config_path.exists(): raise FileNotFoundError(f"Configuration file not found: {self.config_path}") - - with open(self.config_path, 'r') as f: + + with open(self.config_path, "r") as f: config = yaml.safe_load(f) - + logger.info(f"Loaded configuration from {self.config_path}") return config - + def _resolve_environment_variables(self): - """Resolve environment variables in config""" + """Resolve environment variables in config.""" # Handle API key - check multiple sources in order of priority - if 'model' in self.config: + if "model" in self.config: # Priority 1: Environment variable from .env or system - api_key_from_env = os.getenv('GEMINI_API_KEY') - + api_key_from_env = os.getenv("GEMINI_API_KEY") + # Priority 2: Config file (if not using ${} syntax) - api_key_from_config = self.config['model'].get('api_key', '') - + api_key_from_config = self.config["model"].get("api_key", "") + # If config uses ${VAR} syntax, replace with env var - if api_key_from_config.startswith('${') and api_key_from_config.endswith('}'): + if api_key_from_config.startswith("${") and api_key_from_config.endswith( + "}" + ): env_var = api_key_from_config[2:-1] - self.config['model']['api_key'] = os.getenv(env_var, '') + self.config["model"]["api_key"] = os.getenv(env_var, "") elif api_key_from_env: # Use environment variable if available - self.config['model']['api_key'] = api_key_from_env - elif api_key_from_config and not api_key_from_config.startswith('${'): + self.config["model"]["api_key"] = api_key_from_env + elif api_key_from_config and not api_key_from_config.startswith("${"): # Use config value if it's not a variable reference 
- self.config['model']['api_key'] = api_key_from_config + self.config["model"]["api_key"] = api_key_from_config else: # No API key found - self.config['model']['api_key'] = '' + self.config["model"]["api_key"] = "" logger.warning("No API key found in environment or config") - + def _resolve_paths(self): - """Resolve relative paths to absolute""" - if 'dataset' in self.config and 'base_path' in self.config['dataset']: - base_path = Path(self.config['dataset']['base_path']) + """Resolve relative paths to absolute.""" + if "dataset" in self.config and "base_path" in self.config["dataset"]: + base_path = Path(self.config["dataset"]["base_path"]) if not base_path.is_absolute(): # Make relative to config file location base_path = (self.config_path.parent / base_path).resolve() - self.config['dataset']['base_path'] = str(base_path) - + self.config["dataset"]["base_path"] = str(base_path) + def get_model_config(self) -> Dict[str, Any]: - """Get model configuration""" - return self.config.get('model', {}) - + """Get model configuration.""" + return self.config.get("model", {}) + def get_dataset_config(self) -> Dict[str, Any]: - """Get dataset configuration""" - return self.config.get('dataset', {}) - + """Get dataset configuration.""" + return self.config.get("dataset", {}) + def get_processing_config(self) -> Dict[str, Any]: - """Get processing configuration""" - return self.config.get('processing', {}) - + """Get processing configuration.""" + return self.config.get("processing", {}) + def get_demographics_config(self) -> Dict[str, Any]: - """Get demographics configuration""" - return self.config.get('demographics', {}) + """Get demographics configuration.""" + return self.config.get("demographics", {}) + @dataclass class Config: - """Unified configuration class with .env support""" - + """Unified configuration class with .env support.""" + def __init__(self, config_path: str = "config.yaml"): + """Initialize configuration from YAML file. 
+ + Args: + config_path: Path to YAML configuration file. + """ loader = ConfigLoader(config_path) # Store the raw config from loader self.raw_config = loader.config - + # Model settings model_cfg = loader.get_model_config() - self.model_name = model_cfg.get('name', 'gemini-2.5-flash') - self.api_key = model_cfg.get('api_key', '') - + self.model_name = model_cfg.get("name", "gemini-2.5-flash") + self.api_key = model_cfg.get("api_key", "") + # Log API key status (not the key itself) if self.api_key: logger.info(f"API key loaded (length: {len(self.api_key)})") else: logger.warning("No API key loaded") - - self.temperature = model_cfg.get('temperature', 0.3) - self.max_output_tokens = model_cfg.get('max_output_tokens', 1024) - self.timeout = model_cfg.get('timeout', 60) - self.retry_attempts = model_cfg.get('retry_attempts', 3) - self.retry_delay = model_cfg.get('retry_delay', 5) - + + self.temperature = model_cfg.get("temperature", 0.3) + self.max_output_tokens = model_cfg.get("max_output_tokens", 1024) + self.timeout = model_cfg.get("timeout", 60) + self.retry_attempts = model_cfg.get("retry_attempts", 3) + self.retry_delay = model_cfg.get("retry_delay", 5) + # Dataset settings dataset_cfg = loader.get_dataset_config() - self.base_path = Path(dataset_cfg.get('base_path', '../dataset')) - self.topics = dataset_cfg.get('topics', []) - + self.base_path = Path(dataset_cfg.get("base_path", "../dataset")) + self.topics = dataset_cfg.get("topics", []) + # Demographics settings demo_cfg = loader.get_demographics_config() - self.races = demo_cfg.get('races', []) - self.genders = demo_cfg.get('genders', []) - self.age_groups = demo_cfg.get('age_groups', []) - self.languages = demo_cfg.get('languages', []) - + self.demographic_categories = demo_cfg.get( + "categories", ["race", "gender", "age", "language"] + ) + self.races = demo_cfg.get("races", []) + self.genders = demo_cfg.get("genders", []) + self.age_groups = demo_cfg.get("age_groups", []) + self.languages = 
demo_cfg.get("languages", []) + # Processing settings proc_cfg = loader.get_processing_config() - self.batch_size = proc_cfg.get('batch_size', 5) - self.save_interval = proc_cfg.get('save_interval', 10) - self.use_cache = proc_cfg.get('use_cache', True) - self.output_format = proc_cfg.get('output_format', 'metadata_enhanced.json') - + self.batch_size = proc_cfg.get("batch_size", 5) + self.save_interval = proc_cfg.get("save_interval", 10) + self.use_cache = proc_cfg.get("use_cache", True) + self.output_format = proc_cfg.get("output_format", "metadata_enhanced.json") + # File patterns - self.video_pattern = proc_cfg.get('video_pattern', 'video_{number}.mp4') - self.audio_pattern = proc_cfg.get('audio_pattern', 'audio_{number}.m4a') - self.caption_pattern = proc_cfg.get('caption_pattern', 'caption_{number}.srt') - + self.video_pattern = proc_cfg.get("video_pattern", "video_{number}.mp4") + self.audio_pattern = proc_cfg.get("audio_pattern", "audio_{number}.m4a") + self.caption_pattern = proc_cfg.get("caption_pattern", "caption_{number}.srt") + + # Supported video file extensions + self.video_extensions = proc_cfg.get( + "video_extensions", [".mp4", ".avi", ".mov", ".webm", ".mkv", ".m4v"] + ) + # Multimodal processing settings - self.file_processing_timeout = proc_cfg.get('file_processing_timeout', 7200) - self.max_video_duration = proc_cfg.get('max_video_duration', 3300) - self.max_transcript_length = proc_cfg.get('max_transcript_length', 50000) - self.prefer_video_with_audio = proc_cfg.get('prefer_video_with_audio', True) - + self.file_processing_timeout = proc_cfg.get("file_processing_timeout", 7200) + self.max_video_duration = proc_cfg.get("max_video_duration", 3300) + self.max_transcript_length = proc_cfg.get("max_transcript_length", 50000) + self.prefer_video_with_audio = proc_cfg.get("prefer_video_with_audio", True) + # Output settings - self.save_raw_responses = proc_cfg.get('save_raw_responses', True) - self.create_backup = proc_cfg.get('create_backup', True) 
- - rate_limit_cfg = self.raw_config.get('rate_limit', {}) - self.delay_between_videos = rate_limit_cfg.get('delay_between_videos', 10) - self.delay_after_long_video = rate_limit_cfg.get('delay_after_long_video', 30) - self.long_video_threshold = rate_limit_cfg.get('long_video_threshold', 1800) + self.save_raw_responses = proc_cfg.get("save_raw_responses", True) + self.create_backup = proc_cfg.get("create_backup", True) + rate_limit_cfg = self.raw_config.get("rate_limit", {}) + self.delay_between_videos = rate_limit_cfg.get("delay_between_videos", 10) + self.delay_after_long_video = rate_limit_cfg.get("delay_after_long_video", 30) + self.long_video_threshold = rate_limit_cfg.get("long_video_threshold", 1800) # Quality settings - quality_cfg = self.raw_config.get('quality', {}) - self.min_confidence = quality_cfg.get('min_confidence', 0.5) - self.require_explanation = quality_cfg.get('require_explanation', True) - self.validate_json = quality_cfg.get('validate_json', True) - + quality_cfg = self.raw_config.get("quality", {}) + self.min_confidence = quality_cfg.get("min_confidence", 0.5) + self.require_explanation = quality_cfg.get("require_explanation", True) + self.validate_json = quality_cfg.get("validate_json", True) + # Logging settings - log_cfg = self.raw_config.get('logging', {}) - self.log_file = log_cfg.get('log_file', 'demographics_annotation.log') - self.log_format = log_cfg.get('format', '%(asctime)s - %(name)s - %(levelname)s - %(message)s') - self.console_output = log_cfg.get('console_output', True) - self.file_output = log_cfg.get('file_output', True) - + log_cfg = self.raw_config.get("logging", {}) + self.log_file = log_cfg.get("log_file", "demographics_annotation.log") + self.log_format = log_cfg.get( + "format", "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + self.console_output = log_cfg.get("console_output", True) + self.file_output = log_cfg.get("file_output", True) + def get_topic_paths(self, topic: str) -> Dict[str, Path]: - 
"""Get all relevant paths for a topic""" + """Get all relevant paths for a topic.""" return { "videos": self.base_path / "videos" / topic, "audios": self.base_path / "audios" / topic, "captions": self.base_path / "captions" / topic, - "metadata": self.base_path / "videos" / topic / "metadata.json" + "metadata": (self.base_path / "videos" / topic / "metadata.json"), } - + def get_file_path(self, topic: str, file_type: str, number: str) -> Path: - """Get specific file path based on pattern""" + """Get specific file path based on pattern.""" paths = self.get_topic_paths(topic) - + if file_type == "video": return paths["videos"] / self.video_pattern.format(number=number) - elif file_type == "audio": + if file_type == "audio": return paths["audios"] / self.audio_pattern.format(number=number) - elif file_type == "caption": + if file_type == "caption": return paths["captions"] / self.caption_pattern.format(number=number) - else: - raise ValueError(f"Unknown file type: {file_type}") \ No newline at end of file + raise ValueError(f"Unknown file type: {file_type}") diff --git a/sonic-o1/03_demographics_annotation/model.py b/sonic-o1/03_demographics_annotation/model.py index 36fd464..960dda2 100644 --- a/sonic-o1/03_demographics_annotation/model.py +++ b/sonic-o1/03_demographics_annotation/model.py @@ -1,51 +1,66 @@ +"""model.py. + +Model interface for demographics annotation using Gemini. 
+ +Author: SONIC-O1 Team """ -Model interface for demographics annotation using Gemini -""" -from google import genai -from google.genai import types -import os + import json +import logging +import os +import re +import shutil +import subprocess +import tempfile import time from datetime import datetime -import logging from pathlib import Path -from typing import Dict, Optional, Any, List +from typing import Any, Dict, List, Optional + +from google import genai +from google.genai import types +from prompts import MAIN_PROMPT_TEMPLATE + logger = logging.getLogger(__name__) + class DemographicsAnnotator: - """Handle Gemini API interactions for demographics annotation""" - + """Handle Gemini API interactions for demographics annotation.""" + def __init__(self, config): """ Initialize the Gemini client with configuration. - + Args: config: Configuration object with model settings """ self.config = config - self.file_processing_timeout = getattr(config, 'file_processing_timeout', 7200) - self.max_video_duration = getattr(config, 'max_video_duration', 3300) - + self.file_processing_timeout = getattr(config, "file_processing_timeout", 7200) + self.max_video_duration = getattr(config, "max_video_duration", 3300) + self.setup_client() - + def setup_client(self): - """Initialize the Gemini client""" - os.environ['GEMINI_API_KEY'] = self.config.api_key + """Initialize the Gemini client.""" + os.environ["GEMINI_API_KEY"] = self.config.api_key self.client = genai.Client() logger.info(f"Initialized Gemini client with model: {self.config.model_name}") - - def process_media(self, - video_path: Optional[Path], - audio_path: Optional[Path], - transcript_path: Optional[Path], - metadata: Dict[str, Any], - config, - _is_segment: bool = False) -> Dict[str, Any]: + + def process_media( + self, + video_path: Optional[Path], + audio_path: Optional[Path], + transcript_path: Optional[Path], + metadata: Dict[str, Any], + config, + _is_segment: bool = False, + ) -> Dict[str, Any]: """ 
Process media files (video, audio, transcript) for demographics annotation. + Automatically handles videos longer than the limit by segmenting. - + Args: video_path: Path to video file (optional) audio_path: Path to audio file (optional) @@ -55,232 +70,300 @@ def process_media(self, _is_segment: Internal flag to prevent re-segmentation of segments """ try: - duration = metadata.get('duration_seconds', 0) - + duration = metadata.get("duration_seconds", 0) + # Validate input if not video_path and not audio_path: error_msg = "Must provide either video_path or audio_path" logger.error(error_msg) return self._get_error_response(error_msg) - - # Check if video needs segmentation (but NOT if this is already a segment) + + # Check if video needs segmentation + # (but NOT if this is already a segment) primary_media = video_path if video_path else audio_path - is_video = primary_media.suffix.lower() in ['.mp4', '.avi', '.mov', '.webm', '.mkv', '.m4v'] - - # Only segment if: it's a video, it's too long, AND it's not already a segment + is_video = primary_media.suffix.lower() in config.video_extensions + + # Only segment if: it's a video, it's too long, + # AND it's not already a segment if is_video and duration > self.max_video_duration and not _is_segment: logger.warning( - f"Video duration ({duration}s) exceeds limit ({self.max_video_duration}s). " + f"Video duration ({duration}s) exceeds limit " + f"({self.max_video_duration}s). " f"Will segment and process in chunks." 
) return self._process_long_video_segmented( video_path, audio_path, transcript_path, metadata, config ) - + # Normal processing for videos within limit (or segments) logger.info(f"Processing media: {primary_media.name}") - + # Load transcript if available transcript_text = self._load_transcript(transcript_path, config) - + # Prepare the prompt with transcript included prompt = self._build_prompt(metadata, transcript_text) - + # Collect all media files to process media_files = [] if video_path: - media_files.append(('video', video_path)) + media_files.append(("video", video_path)) if audio_path: - media_files.append(('audio', audio_path)) - + media_files.append(("audio", audio_path)) + # Determine processing method based on file sizes total_size = sum(os.path.getsize(path) for _, path in media_files) use_file_api = total_size > 20 * 1024 * 1024 # 20MB threshold - + if use_file_api: - logger.info(f"Using File API for large media (total size: {total_size / (1024*1024):.2f}MB)") - response_text = self._process_large_media_multimodal(media_files, prompt) + logger.info( + f"Using File API for large media (total size: " + f"{total_size / (1024 * 1024):.2f}MB)" + ) + response_text = self._process_large_media_multimodal( + media_files, prompt + ) else: - logger.info(f"Using inline processing for small media (total size: {total_size / (1024*1024):.2f}MB)") - response_text = self._process_small_media_multimodal(media_files, prompt) - + logger.info( + f"Using inline processing for small media " + f"(total size: {total_size / (1024 * 1024):.2f}MB)" + ) + response_text = self._process_small_media_multimodal( + media_files, prompt + ) + # Parse JSON response demographics_data = self._parse_response(response_text) - + # Add raw response if configured if config.save_raw_responses: - demographics_data['raw_response'] = response_text - + demographics_data["raw_response"] = response_text + return demographics_data - + except Exception as e: logger.error(f"Error processing media: 
{e}", exc_info=True) return self._get_error_response(str(e)) - - def _process_long_video_segmented(self, - video_path: Optional[Path], - audio_path: Optional[Path], - transcript_path: Optional[Path], - metadata: Dict[str, Any], - config) -> Dict[str, Any]: + + def _process_long_video_segmented( + self, + video_path: Optional[Path], + audio_path: Optional[Path], + transcript_path: Optional[Path], + metadata: Dict[str, Any], + config, + ) -> Dict[str, Any]: """ Process videos longer than the limit by segmenting into chunks. - + Strategy: - 1. Split video into overlapping segments (e.g., 50min segments with 5min overlap) + 1. Split video into overlapping segments + (e.g., 50min segments with 5min overlap) 2. Process each segment 3. Aggregate results with deduplication """ try: - import subprocess - import tempfile - import shutil - - duration = metadata.get('duration_seconds', 0) - segment_duration = self.max_video_duration - 300 # 50 minutes (leaving 5min buffer) - overlap = 60 # 1 minute overlap to catch people across boundaries - + duration = metadata.get("duration_seconds", 0) + segment_duration = ( + self.max_video_duration - 300 + ) # 50 minutes (leaving 5min buffer) + # 1 minute overlap to catch people across boundaries + overlap = 60 + num_segments = int(duration / segment_duration) + 1 - logger.info(f"Splitting {duration}s video into {num_segments} segments of ~{segment_duration}s each") - + logger.info( + f"Splitting {duration}s video into {num_segments} segments " + f"of ~{segment_duration}s each" + ) + temp_dir = Path(tempfile.mkdtemp(prefix="video_segments_")) all_demographics = [] - + try: for i in range(num_segments): - start_time = max(0, i * segment_duration - (overlap if i > 0 else 0)) - segment_duration_actual = min(segment_duration + overlap, duration - start_time) - - logger.info(f"Processing segment {i+1}/{num_segments}: {start_time}s to {start_time + segment_duration_actual}s") - + start_time = max( + 0, i * segment_duration - (overlap if i > 0 
else 0) + ) + segment_duration_actual = min( + segment_duration + overlap, duration - start_time + ) + + logger.info( + f"Processing segment {i + 1}/{num_segments}: " + f"{start_time}s to " + f"{start_time + segment_duration_actual}s" + ) + # Create segment filename segment_video_path = None if video_path: - segment_video_path = temp_dir / f"segment_{i:03d}{video_path.suffix}" - + segment_video_path = ( + temp_dir / f"segment_{i:03d}{video_path.suffix}" + ) + # Use ffmpeg to extract segment cmd = [ - 'ffmpeg', '-y', - '-ss', str(start_time), - '-i', str(video_path), - '-t', str(segment_duration_actual), - '-c', 'copy', # Fast: just copy streams without re-encoding - '-avoid_negative_ts', '1', - str(segment_video_path) + "ffmpeg", + "-y", + "-ss", + str(start_time), + "-i", + str(video_path), + "-t", + str(segment_duration_actual), + "-c", + # Fast: just copy streams without re-encoding + "copy", + "-avoid_negative_ts", + "1", + str(segment_video_path), ] - - result = subprocess.run(cmd, capture_output=True, text=True) + + result = subprocess.run( + cmd, capture_output=True, text=True, check=False + ) if result.returncode != 0: logger.error(f"FFmpeg error: {result.stderr}") raise Exception(f"Failed to create video segment {i}") - + # Create audio segment if separate audio exists segment_audio_path = None if audio_path: - segment_audio_path = temp_dir / f"segment_{i:03d}{audio_path.suffix}" - + segment_audio_path = ( + temp_dir / f"segment_{i:03d}{audio_path.suffix}" + ) + cmd = [ - 'ffmpeg', '-y', - '-ss', str(start_time), - '-i', str(audio_path), - '-t', str(segment_duration_actual), - '-c', 'copy', - str(segment_audio_path) + "ffmpeg", + "-y", + "-ss", + str(start_time), + "-i", + str(audio_path), + "-t", + str(segment_duration_actual), + "-c", + "copy", + str(segment_audio_path), ] - - result = subprocess.run(cmd, capture_output=True, text=True) + + result = subprocess.run( + cmd, capture_output=True, text=True, check=False + ) if result.returncode != 0: - 
logger.warning(f"Could not create audio segment: {result.stderr}") + logger.warning( + f"Could not create audio segment: {result.stderr}" + ) segment_audio_path = None - + # Extract relevant transcript section segment_transcript_path = None if transcript_path and transcript_path.exists(): segment_transcript = self._extract_transcript_segment( - transcript_path, start_time, start_time + segment_duration_actual + transcript_path, + start_time, + start_time + segment_duration_actual, ) if segment_transcript: segment_transcript_path = temp_dir / f"segment_{i:03d}.srt" - with open(segment_transcript_path, 'w', encoding='utf-8') as f: + with open( + segment_transcript_path, "w", encoding="utf-8" + ) as f: f.write(segment_transcript) - - # Process this segment (with _is_segment=True to prevent re-segmentation) + + # Process this segment (with _is_segment=True to prevent + # re-segmentation) segment_metadata = metadata.copy() - segment_metadata['duration_seconds'] = segment_duration_actual - segment_metadata['segment_info'] = { - 'segment_number': i + 1, - 'total_segments': num_segments, - 'start_time': start_time, - 'end_time': start_time + segment_duration_actual + segment_metadata["duration_seconds"] = segment_duration_actual + segment_metadata["segment_info"] = { + "segment_number": i + 1, + "total_segments": num_segments, + "start_time": start_time, + "end_time": start_time + segment_duration_actual, } - + segment_demographics = self.process_media( video_path=segment_video_path, audio_path=segment_audio_path, transcript_path=segment_transcript_path, metadata=segment_metadata, config=config, - _is_segment=True # THIS IS THE KEY FIX - prevents re-segmentation + # THIS IS THE KEY FIX - prevents re-segmentation + _is_segment=True, ) - + all_demographics.append(segment_demographics) - + # Aggregate results - aggregated = self._aggregate_segment_demographics(all_demographics, num_segments) - + aggregated = self._aggregate_segment_demographics( + all_demographics, 
num_segments + ) + # Add segmentation info - aggregated['demographics_annotation']['segmented'] = True - aggregated['demographics_annotation']['num_segments'] = num_segments - aggregated['demographics_annotation']['original_duration'] = duration - + aggregated["demographics_annotation"]["segmented"] = True + aggregated["demographics_annotation"]["num_segments"] = num_segments + aggregated["demographics_annotation"]["original_duration"] = duration + return aggregated - + finally: # Cleanup temporary files if temp_dir.exists(): shutil.rmtree(temp_dir) - logger.info(f"Cleaned up temporary segments") - + logger.info("Cleaned up temporary segments") + except Exception as e: logger.error(f"Error in segmented processing: {e}", exc_info=True) return self._get_error_response(f"Segmentation error: {e}") - - def _extract_transcript_segment(self, transcript_path: Path, start_time: float, end_time: float) -> str: - """Extract portion of SRT transcript for a time segment""" + + def _extract_transcript_segment( + self, transcript_path: Path, start_time: float, end_time: float + ) -> str: + """Extract portion of SRT transcript for a time segment.""" try: - with open(transcript_path, 'r', encoding='utf-8') as f: + with open(transcript_path, "r", encoding="utf-8") as f: content = f.read() - + # Simple SRT parser - import re - segments = content.strip().split('\n\n') + + segments = content.strip().split("\n\n") extracted = [] - + for segment in segments: - lines = segment.split('\n') + lines = segment.split("\n") if len(lines) < 3: continue - - # Parse timestamp line (format: 00:00:10,500 --> 00:00:13,000) - timestamp_match = re.search(r'(\d{2}):(\d{2}):(\d{2}),(\d{3})\s*-->\s*(\d{2}):(\d{2}):(\d{2}),(\d{3})', lines[1]) + + # Parse timestamp line + # Format: 00:00:10,500 --> 00:00:13,000 + timestamp_match = re.search( + r"(\d{2}):(\d{2}):(\d{2}),(\d{3})\s*-->\s*" + r"(\d{2}):(\d{2}):(\d{2}),(\d{3})", + lines[1], + ) if timestamp_match: - h1, m1, s1, ms1, h2, m2, s2, ms2 = map(int, 
timestamp_match.groups()) - seg_start = h1*3600 + m1*60 + s1 + ms1/1000 - seg_end = h2*3600 + m2*60 + s2 + ms2/1000 - + h1, m1, s1, ms1, h2, m2, s2, ms2 = map( + int, timestamp_match.groups() + ) + seg_start = h1 * 3600 + m1 * 60 + s1 + ms1 / 1000 + seg_end = h2 * 3600 + m2 * 60 + s2 + ms2 / 1000 + # Check if this segment overlaps with our time range if seg_start < end_time and seg_end > start_time: extracted.append(segment) - - return '\n\n'.join(extracted) - + + return "\n\n".join(extracted) + except Exception as e: logger.warning(f"Could not extract transcript segment: {e}") return "" - - def _aggregate_segment_demographics(self, segment_results: List[Dict], num_segments: int) -> Dict[str, Any]: + + def _aggregate_segment_demographics( + self, segment_results: List[Dict], num_segments: int + ) -> Dict[str, Any]: """ Aggregate demographics from multiple segments with deduplication. + Uses voting/confidence averaging to merge results. """ # Collect all demographics and confidences @@ -290,79 +373,102 @@ def _aggregate_segment_demographics(self, segment_results: List[Dict], num_segme all_languages = {} total_individuals = 0 explanations = [] - + for seg_result in segment_results: - if 'error' in seg_result.get('demographics_annotation', {}): + if "error" in seg_result.get("demographics_annotation", {}): continue - + # Aggregate confidence scores (use maximum confidence seen) - for race, conf in seg_result.get('demographics_confidence', {}).get('race', {}).items(): + for race, conf in ( + seg_result.get("demographics_confidence", {}).get("race", {}).items() + ): all_races[race] = max(all_races.get(race, 0), conf) - - for gender, conf in seg_result.get('demographics_confidence', {}).get('gender', {}).items(): + + for gender, conf in ( + seg_result.get("demographics_confidence", {}).get("gender", {}).items() + ): all_genders[gender] = max(all_genders.get(gender, 0), conf) - - for age, conf in seg_result.get('demographics_confidence', {}).get('age', {}).items(): + + 
for age, conf in ( + seg_result.get("demographics_confidence", {}).get("age", {}).items() + ): all_ages[age] = max(all_ages.get(age, 0), conf) - - for lang, conf in seg_result.get('demographics_confidence', {}).get('language', {}).items(): + + for lang, conf in ( + seg_result.get("demographics_confidence", {}) + .get("language", {}) + .items() + ): all_languages[lang] = max(all_languages.get(lang, 0), conf) - - # Track max individuals seen in any segment - FIX: ensure it's an integer - individuals_count = seg_result.get('demographics_annotation', {}).get('individuals_count', 0) + + # Track max individuals seen in any segment + # FIX: ensure it's an integer + individuals_count = seg_result.get("demographics_annotation", {}).get( + "individuals_count", 0 + ) # Convert to int if it's a string if isinstance(individuals_count, str): try: individuals_count = int(individuals_count) except (ValueError, TypeError): individuals_count = 0 - + total_individuals = max(total_individuals, individuals_count) - - explanations.append(seg_result.get('demographics_annotation', {}).get('explanation', '')) - + + explanations.append( + seg_result.get("demographics_annotation", {}).get("explanation", "") + ) + # Build aggregated result return { "demographics_detailed": { "race": list(all_races.keys()), "gender": list(all_genders.keys()), "age": list(all_ages.keys()), - "language": list(all_languages.keys()) + "language": list(all_languages.keys()), }, "demographics_confidence": { "race": all_races, "gender": all_genders, "age": all_ages, - "language": all_languages + "language": all_languages, }, "demographics_annotation": { "model": self.config.model_name, "annotated_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "individuals_count": total_individuals, - "explanation": f"Aggregated from {num_segments} segments. " + " | ".join(explanations[:3]) - } + "explanation": f"Aggregated from {num_segments} segments. 
" + + " | ".join(explanations[:3]), + }, } + def _load_transcript(self, transcript_path: Optional[Path], config) -> str: - """Load and truncate transcript if needed""" + """Load and truncate transcript if needed.""" transcript_text = "" if transcript_path and transcript_path.exists(): - with open(transcript_path, 'r', encoding='utf-8') as f: + with open(transcript_path, "r", encoding="utf-8") as f: transcript_text = f.read() - max_length = getattr(config, 'max_transcript_length', 50000) + max_length = getattr(config, "max_transcript_length", 50000) if len(transcript_text) > max_length: - logger.info(f"Truncating transcript from {len(transcript_text)} to {max_length} chars") + logger.info( + f"Truncating transcript from " + f"{len(transcript_text)} to {max_length} chars" + ) transcript_text = transcript_text[:max_length] + "\n...[truncated]" return transcript_text - - def _process_large_media_multimodal(self, media_files: List[tuple], prompt: str) -> str: + + def _process_large_media_multimodal( + self, media_files: List[tuple], prompt: str + ) -> str: """ Process large media files using Gemini File API with multimodal support. 
- + Args: media_files: List of tuples (media_type, Path) prompt: Analysis prompt - - Returns: + + Returns + ------- Generated response text """ uploaded_files = [] @@ -372,11 +478,11 @@ def _process_large_media_multimodal(self, media_files: List[tuple], prompt: str) uploaded_file = self.client.files.upload(file=str(media_path)) logger.info(f"Uploaded {media_type} file: {uploaded_file.name}") uploaded_files.append(uploaded_file) - + # Wait for all files to process - max_wait = self.file_processing_timeout + max_wait = self.file_processing_timeout wait_time = 0 - + all_processed = False while not all_processed and wait_time < max_wait: all_processed = True @@ -385,25 +491,25 @@ def _process_large_media_multimodal(self, media_files: List[tuple], prompt: str) if uploaded_files[i].state == "PROCESSING": all_processed = False elif uploaded_files[i].state == "FAILED": - raise Exception(f"File processing failed: {getattr(uploaded_files[i], 'error', 'Unknown error')}") - + error_msg = getattr(uploaded_files[i], "error", "Unknown error") + raise Exception(f"File processing failed: {error_msg}") + if not all_processed: time.sleep(10) wait_time += 10 logger.debug(f"Waiting for file processing... 
({wait_time}s)") - + if not all_processed: raise Exception(f"File processing timeout after {max_wait} seconds") - + # Generate content with all uploaded files + prompt for attempt in range(self.config.retry_attempts): try: # Build content list: [file1, file2, ..., prompt] content_parts = uploaded_files + [prompt] - + response = self.client.models.generate_content( - model=self.config.model_name, - contents=content_parts + model=self.config.model_name, contents=content_parts ) return response.text except Exception as e: @@ -412,7 +518,7 @@ def _process_large_media_multimodal(self, media_files: List[tuple], prompt: str) time.sleep(self.config.retry_delay) else: raise - + finally: # Clean up all uploaded files for uploaded_file in uploaded_files: @@ -421,48 +527,46 @@ def _process_large_media_multimodal(self, media_files: List[tuple], prompt: str) logger.info(f"Deleted uploaded file: {uploaded_file.name}") except Exception as e: logger.warning(f"Failed to delete uploaded file: {e}") - - - def _process_small_media_multimodal(self, media_files: List[tuple], prompt: str) -> str: + + def _process_small_media_multimodal( + self, media_files: List[tuple], prompt: str + ) -> str: """ Process small media files using inline data with multimodal support. 
- + Args: media_files: List of tuples (media_type, Path) prompt: Analysis prompt - - Returns: + + Returns + ------- Generated response text """ # Build content parts parts = [] - + # Add all media files as inline data for media_type, media_path in media_files: - with open(media_path, 'rb') as media_file: + with open(media_path, "rb") as media_file: media_bytes = media_file.read() - + mime_type = self._get_media_mime_type(media_path) - + parts.append( types.Part( - inline_data=types.Blob( - data=media_bytes, - mime_type=mime_type - ) + inline_data=types.Blob(data=media_bytes, mime_type=mime_type) ) ) logger.info(f"Added {media_type} ({mime_type}) to inline content") - + # Add prompt as text parts.append(types.Part(text=prompt)) - + # Generate content with retries for attempt in range(self.config.retry_attempts): try: response = self.client.models.generate_content( - model=self.config.model_name, - contents=types.Content(parts=parts) + model=self.config.model_name, contents=types.Content(parts=parts) ) return response.text except Exception as e: @@ -471,68 +575,67 @@ def _process_small_media_multimodal(self, media_files: List[tuple], prompt: str) time.sleep(self.config.retry_delay) else: raise - + return None + def _get_media_mime_type(self, media_path: Path) -> str: - """Get MIME type for media file""" + """Get MIME type for media file.""" extension_map = { # Video types - '.mp4': 'video/mp4', - '.avi': 'video/x-msvideo', - '.mov': 'video/quicktime', - '.wmv': 'video/x-ms-wmv', - '.webm': 'video/webm', - '.mkv': 'video/x-matroska', - '.m4v': 'video/x-m4v', - + ".mp4": "video/mp4", + ".avi": "video/x-msvideo", + ".mov": "video/quicktime", + ".wmv": "video/x-ms-wmv", + ".webm": "video/webm", + ".mkv": "video/x-matroska", + ".m4v": "video/x-m4v", # Audio types - '.m4a': 'audio/m4a', - '.mp3': 'audio/mpeg', - '.wav': 'audio/wav', - '.ogg': 'audio/ogg', - '.flac': 'audio/flac', - '.aac': 'audio/aac', + ".m4a": "audio/m4a", + ".mp3": "audio/mpeg", + ".wav": "audio/wav", + 
".ogg": "audio/ogg", + ".flac": "audio/flac", + ".aac": "audio/aac", } - + extension = media_path.suffix.lower() - return extension_map.get(extension, 'application/octet-stream') - + return extension_map.get(extension, "application/octet-stream") + def _build_prompt(self, metadata: Dict[str, Any], transcript_text: str) -> str: - """Build the analysis prompt with transcript embedded""" - from prompts import MAIN_PROMPT_TEMPLATE - + """Build the analysis prompt with transcript embedded.""" # Prepare transcript section if transcript_text: transcript_preview = f"TRANSCRIPT/CAPTIONS:\n{transcript_text}" else: - transcript_preview = "No transcript available. Analyze based on visual and audio content only." - - prompt = MAIN_PROMPT_TEMPLATE.format( - title=metadata.get('title', 'Unknown'), - duration_seconds=metadata.get('duration_seconds', 0), - topic_name=metadata.get('topic_name', 'Unknown'), + transcript_preview = ( + "No transcript available. Analyze based on visual and " + "audio content only." 
+ ) + + return MAIN_PROMPT_TEMPLATE.format( + title=metadata.get("title", "Unknown"), + duration_seconds=metadata.get("duration_seconds", 0), + topic_name=metadata.get("topic_name", "Unknown"), transcript_preview=transcript_preview, model_name=self.config.model_name, - timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S") + timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), ) - - return prompt - + def _parse_response(self, response_text: str) -> Dict[str, Any]: - """Parse and validate JSON response""" + """Parse and validate JSON response.""" try: # Handle None response if response_text is None: logger.error("Received None response from API") return self._get_error_response("API returned no response") - + # Clean response text response_text = response_text.strip() - + # Check for empty response if not response_text: logger.error("Received empty response from API") return self._get_error_response("API returned empty response") - + # Remove markdown code blocks if present if "```json" in response_text: start = response_text.find("```json") + 7 @@ -544,79 +647,90 @@ def _parse_response(self, response_text: str) -> Dict[str, Any]: end = response_text.rfind("```") if end > start: response_text = response_text[start:end] - + # Parse JSON data = json.loads(response_text.strip()) - + # Validate required fields - required_fields = ["demographics_detailed", "demographics_confidence", "demographics_annotation"] + required_fields = [ + "demographics_detailed", + "demographics_confidence", + "demographics_annotation", + ] for field in required_fields: if field not in data: logger.warning(f"Missing required field: {field}") # Add default structure if missing if field == "demographics_detailed": - data[field] = {"race": [], "gender": [], "age": [], "language": []} + data[field] = { + c: [] for c in self.config.demographic_categories + } elif field == "demographics_confidence": - data[field] = {"race": {}, "gender": {}, "age": {}, "language": {}} + data[field] = { + c: {} 
for c in self.config.demographic_categories + } elif field == "demographics_annotation": data[field] = { "model": self.config.model_name, - "annotated_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "annotated_at": datetime.now().strftime( + "%Y-%m-%d %H:%M:%S" + ), "individuals_count": 0, - "explanation": "Partial response" + "explanation": "Partial response", } - + # Filter by minimum confidence if configured - if hasattr(self.config, 'min_confidence'): + if hasattr(self.config, "min_confidence"): data = self._filter_by_confidence(data, self.config.min_confidence) - + return data - + except json.JSONDecodeError as e: logger.error(f"Failed to parse JSON: {e}") - logger.debug(f"Response text: {response_text[:500] if response_text else 'None'}...") + logger.debug( + f"Response text: {response_text[:500] if response_text else 'None'}..." + ) return self._get_error_response(f"JSON parsing error: {e}") - - def _filter_by_confidence(self, data: Dict[str, Any], min_confidence: float) -> Dict[str, Any]: - """Filter demographics by minimum confidence threshold""" + + def _filter_by_confidence( + self, data: Dict[str, Any], min_confidence: float + ) -> Dict[str, Any]: + """Filter demographics by minimum confidence threshold.""" if "demographics_confidence" not in data or "demographics_detailed" not in data: return data - + filtered_data = data.copy() - - for category in ["race", "gender", "age", "language"]: + + for category in self.config.demographic_categories: if category in data["demographics_confidence"]: # Filter confidence scores - filtered_conf = {k: v for k, v in data["demographics_confidence"][category].items() - if v >= min_confidence} + filtered_conf = { + k: v + for k, v in (data["demographics_confidence"][category].items()) + if v >= min_confidence + } filtered_data["demographics_confidence"][category] = filtered_conf - + # Update detailed list to match filtered confidence if category in data["demographics_detailed"]: - 
filtered_data["demographics_detailed"][category] = list(filtered_conf.keys()) - + filtered_data["demographics_detailed"][category] = list( + filtered_conf.keys() + ) + return filtered_data - + def _get_error_response(self, error_msg: str) -> Dict[str, Any]: - """Return error response structure""" + """Return error response structure.""" + empty_list = {c: [] for c in self.config.demographic_categories} + empty_dict = {c: {} for c in self.config.demographic_categories} return { - "demographics_detailed": { - "race": [], - "gender": [], - "age": [], - "language": [] - }, - "demographics_confidence": { - "race": {}, - "gender": {}, - "age": {}, - "language": {} - }, + "demographics_detailed": empty_list, + "demographics_confidence": empty_dict, "demographics_annotation": { "model": self.config.model_name, "annotated_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "individuals_count": 0, "explanation": f"Error: {error_msg}", - "error": True - } - } \ No newline at end of file + "error": True, + }, + } diff --git a/sonic-o1/03_demographics_annotation/prompts.py b/sonic-o1/03_demographics_annotation/prompts.py index f013155..6236b95 100644 --- a/sonic-o1/03_demographics_annotation/prompts.py +++ b/sonic-o1/03_demographics_annotation/prompts.py @@ -1,14 +1,17 @@ -""" -Prompt templates for demographics annotation +"""prompts.py. + +Prompt templates for demographics annotation. + +Author: SONIC-O1 Team """ -SYSTEM_PROMPT = """You are a demographics annotation specialist for academic research. -Your task is to analyze multimodal media content (video, audio, and text captions/transcripts) +SYSTEM_PROMPT = """You are a demographics annotation specialist for academic research. +Your task is to analyze multimodal media content (video, audio, and text captions/transcripts) and identify the demographic characteristics of ALL individuals who appear visually or speak. You must be objective, accurate, and avoid making assumptions based on stereotypes. 
-CRITICAL: Return ONLY valid JSON that can be directly parsed. No explanations outside +CRITICAL: Return ONLY valid JSON that can be directly parsed. No explanations outside the JSON structure.""" MAIN_PROMPT_TEMPLATE = """Analyze this MULTIMODAL media to identify demographics of ALL people who appear visually or speak. @@ -25,7 +28,7 @@ 3. TRANSCRIPT/CAPTIONS: Text representation of spoken content {transcript_preview} -IMPORTANT: Use ALL available modalities together for the most accurate analysis. +IMPORTANT: Use ALL available modalities together for the most accurate analysis. Cross-reference visual, audio, and text cues to identify and confirm demographics. --- @@ -62,24 +65,24 @@ 1. RACE/ETHNICITY (select all that apply): - White: European descent appearance - - Black: African descent appearance + - Black: African descent appearance - Asian: East/Southeast/South Asian appearance - Indigenous: Native American/Aboriginal appearance - Arab: Middle Eastern/North African appearance - Hispanic: Latin American appearance - **Note:** Primarily visual assessment. Audio-only inference should have LOW confidence unless very strong accent indicators. - + 2. GENDER (select all that apply): - Male: Masculine presenting individuals OR deep vocal pitch - Female: Feminine presenting individuals OR high vocal pitch - **Use visual cues first, audio cues second** - + 3. AGE GROUPS (select all that apply): - Young (18-24): Visual appearance OR youthful voice - Middle (25-39): Visual appearance OR mature voice - Older adults (40+): Visual appearance OR older voice characteristics - **Combine visual and audio cues for best accuracy** - + 4. LANGUAGE (select all spoken): - Identify all languages and distinct accents heard - Use audio AND transcript to confirm @@ -139,11 +142,11 @@ 1. **Individual Identification:** - List each person: "Person 1 (visible + speaking), Person 2 (voice only), Person 3 (visible only)" - + 2. 
**Per-Individual Demographics:** - Person 1: Race [visual], Gender [visual+audio], Age [visual+audio], Language [audio+transcript] - Person 2: Gender [audio], Age [audio], Language [audio+transcript] - + 3. **Aggregate Demographics:** - Unique races: [from all individuals] - Unique genders: [from all individuals] @@ -190,8 +193,8 @@ def get_validation_prompt() -> str: - """Get prompt for validating/fixing JSON output""" - return """The following text should be valid JSON but may have formatting issues. + """Get prompt for validating/fixing JSON output.""" + return """The following text should be valid JSON but may have formatting issues. Please return ONLY the corrected valid JSON with no additional text: - - {response}""" \ No newline at end of file + + {response}""" diff --git a/sonic-o1/03_demographics_annotation/run_annotation.py b/sonic-o1/03_demographics_annotation/run_annotation.py index 5e3e1b1..0cbb42d 100644 --- a/sonic-o1/03_demographics_annotation/run_annotation.py +++ b/sonic-o1/03_demographics_annotation/run_annotation.py @@ -1,62 +1,88 @@ +"""run_annotation.py. + +Main script to run demographics annotation pipeline. 
+ +Author: SONIC-O1 Team """ -Main script to run demographics annotation pipeline -""" + +import argparse import json import logging -from pathlib import Path -from typing import Dict, List, Any, Optional -import argparse -from datetime import datetime -from tqdm import tqdm -import sys import shutil +import sys import time +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional + # Add parent directory to path for imports sys.path.append(str(Path(__file__).parent)) from config_loader import Config from model import DemographicsAnnotator -from prompts import SYSTEM_PROMPT, MAIN_PROMPT_TEMPLATE + class AnnotationPipeline: - """Main pipeline for processing videos""" - + """Main pipeline for processing videos.""" + def __init__(self, config_path: str = "config.yaml"): + """Initialize the annotation pipeline. + + Args: + config_path: Path to configuration YAML file. + """ self.config = Config(config_path) self.setup_logging() self.annotator = DemographicsAnnotator(self.config) self.logger = logging.getLogger(__name__) - self.checkpoint_file = None - + self.checkpoint_file = None + + def setup_logging(self): + """Set up logging based on configuration.""" + handlers = [] + + if self.config.console_output: + handlers.append(logging.StreamHandler()) + + if self.config.file_output: + log_path = Path(__file__).parent / self.config.log_file + handlers.append(logging.FileHandler(log_path)) + + logging.basicConfig( + level=logging.INFO, format=self.config.log_format, handlers=handlers + ) + logging.getLogger("httpx").setLevel(logging.WARNING) + logging.getLogger("google.generativeai").setLevel(logging.WARNING) + def _get_checkpoint_path(self, output_dir: Path) -> Path: - """Get consistent checkpoint file path""" + """Get consistent checkpoint file path.""" return output_dir / "metadata_enhanced_checkpoint.json" - + def _load_checkpoint(self, output_dir: Path) -> Optional[List[Dict]]: - """Load from checkpoint if exists""" + """Load 
from checkpoint if exists.""" checkpoint_path = self._get_checkpoint_path(output_dir) if checkpoint_path.exists(): self.logger.info(f"Found checkpoint file: {checkpoint_path}") try: - with open(checkpoint_path, 'r') as f: + with open(checkpoint_path, "r") as f: data = json.load(f) self.logger.info(f"Loaded {len(data)} processed videos from checkpoint") return data except Exception as e: self.logger.error(f"Failed to load checkpoint: {e}") return None - + def _save_checkpoint(self, output_dir: Path, metadata: List[Dict]): - """Save checkpoint (atomic write)""" + """Save checkpoint (atomic write).""" checkpoint_path = self._get_checkpoint_path(output_dir) - temp_path = checkpoint_path.with_suffix('.tmp') - + temp_path = checkpoint_path.with_suffix(".tmp") + try: # Write to temp file first - with open(temp_path, 'w') as f: + with open(temp_path, "w") as f: json.dump(metadata, f, indent=2) - + # Atomic rename temp_path.replace(checkpoint_path) self.logger.info(f"Checkpoint saved: {len(metadata)} videos") @@ -64,151 +90,151 @@ def _save_checkpoint(self, output_dir: Path, metadata: List[Dict]): self.logger.error(f"Failed to save checkpoint: {e}") if temp_path.exists(): temp_path.unlink() - + def _has_empty_demographics(self, video_metadata: Dict) -> bool: - """Check if video has empty or missing detailed demographics""" + """Check if video has empty or missing detailed demographics.""" # Check if demographics_detailed exists and has content - demographics_detailed = video_metadata.get('demographics_detailed', {}) - + demographics_detailed = video_metadata.get("demographics_detailed", {}) + if not demographics_detailed: return True - + # Check if it has the expected fields with actual values - required_fields = ['race', 'gender', 'age', 'language'] - for field in required_fields: + for field in self.config.demographic_categories: value = demographics_detailed.get(field) # If field exists and has non-empty list, demographics exist if value and isinstance(value, list) and 
len(value) > 0: return False # Found at least one valid field - + return True # All fields are empty or missing - - def _get_failed_video_indices(self, metadata_list: List[Dict], enhanced_metadata: List[Dict]) -> List[int]: - """Get indices of videos that need reprocessing""" + + def _get_failed_video_indices( + self, metadata_list: List[Dict], enhanced_metadata: List[Dict] + ) -> List[int]: + """Get indices of videos that need reprocessing.""" failed_indices = [] - + for idx, video_meta in enumerate(enhanced_metadata): if self._has_empty_demographics(video_meta): failed_indices.append(idx) - + return failed_indices - - def setup_logging(self): - """Setup logging based on configuration""" - handlers = [] - - if self.config.console_output: - handlers.append(logging.StreamHandler()) - - if self.config.file_output: - log_path = Path(__file__).parent / self.config.log_file - handlers.append(logging.FileHandler(log_path)) - - logging.basicConfig( - level=logging.INFO, - format=self.config.log_format, - handlers=handlers - ) - logging.getLogger("httpx").setLevel(logging.WARNING) - logging.getLogger("google.generativeai").setLevel(logging.WARNING) - + def process_topic(self, topic: str, retry_failed: bool = False) -> Dict[str, Any]: - """Process all videos in a topic - + """Process all videos in a topic. 
+ Args: topic: Topic name to process - retry_failed: If True, only reprocess videos with empty demographics + retry_failed: If True, only reprocess videos with empty + demographics """ self.logger.info(f"Processing topic: {topic}") if retry_failed: - self.logger.info("RETRY MODE: Only reprocessing videos with empty demographics") - + self.logger.info( + "RETRY MODE: Only reprocessing videos with empty demographics" + ) + # Get paths paths = self.config.get_topic_paths(topic) - + # Check if metadata exists - if not paths['metadata'].exists(): + if not paths["metadata"].exists(): self.logger.error(f"Metadata file not found: {paths['metadata']}") return {"topic": topic, "error": "Metadata not found"} - + # Load existing metadata - with open(paths['metadata'], 'r') as f: + with open(paths["metadata"], "r") as f: metadata_list = json.load(f) - + # Create backup if configured if self.config.create_backup: - backup_path = paths['metadata'].with_suffix('.backup.json') + backup_path = paths["metadata"].with_suffix(".backup.json") if not backup_path.exists(): - shutil.copy(paths['metadata'], backup_path) + shutil.copy(paths["metadata"], backup_path) self.logger.info(f"Created backup: {backup_path}") - + # Load existing enhanced metadata - output_path = paths['videos'] / self.config.output_format + output_path = paths["videos"] / self.config.output_format if output_path.exists(): - with open(output_path, 'r') as f: + with open(output_path, "r") as f: enhanced_metadata = json.load(f) - self.logger.info(f"Loaded {len(enhanced_metadata)} existing processed videos") + self.logger.info( + f"Loaded {len(enhanced_metadata)} existing processed videos" + ) else: enhanced_metadata = [] - + # Determine which videos to process if retry_failed: if not enhanced_metadata: - self.logger.error("No enhanced metadata found. Cannot retry failed videos.") + self.logger.error( + "No enhanced metadata found. Cannot retry failed videos." 
+ ) return {"topic": topic, "error": "No enhanced metadata to retry"} - + # Get indices of failed videos - failed_indices = self._get_failed_video_indices(metadata_list, enhanced_metadata) - + failed_indices = self._get_failed_video_indices( + metadata_list, enhanced_metadata + ) + if not failed_indices: - self.logger.info("No failed videos found. All videos have valid demographics.") + self.logger.info( + "No failed videos found. All videos have valid demographics." + ) return { "topic": topic, "total_videos": len(metadata_list), "failed_videos": 0, "reprocessed": 0, - "output_path": str(output_path) + "output_path": str(output_path), } - - self.logger.info(f"Found {len(failed_indices)} videos with empty demographics") + + self.logger.info( + f"Found {len(failed_indices)} videos with empty demographics" + ) indices_to_process = failed_indices else: # Normal mode: resume from checkpoint or start fresh - checkpoint_metadata = self._load_checkpoint(paths['videos']) + checkpoint_metadata = self._load_checkpoint(paths["videos"]) if checkpoint_metadata: enhanced_metadata = checkpoint_metadata - + if enhanced_metadata: start_idx = len(enhanced_metadata) indices_to_process = list(range(start_idx, len(metadata_list))) - self.logger.info(f"Resuming from video {start_idx + 1}/{len(metadata_list)}") + self.logger.info( + f"Resuming from video {start_idx + 1}/{len(metadata_list)}" + ) else: indices_to_process = list(range(len(metadata_list))) - self.logger.info(f"Starting from beginning") - + self.logger.info("Starting from beginning") + # Create raw responses directory if self.config.save_raw_responses: - raw_dir = paths['videos'] / 'raw_responses' + raw_dir = paths["videos"] / "raw_responses" raw_dir.mkdir(exist_ok=True) - + # Process videos processed_count = 0 for idx in indices_to_process: video_metadata = metadata_list[idx] - video_number = video_metadata.get('video_number', f"{idx+1:03d}") - - self.logger.info(f"Processing video {idx + 1}/{len(metadata_list)} 
(video_{video_number})") - + video_number = video_metadata.get("video_number", f"{idx + 1:03d}") + + self.logger.info( + f"Processing video {idx + 1}/{len(metadata_list)} " + f"(video_{video_number})" + ) + # Get file paths using config patterns video_path = self.config.get_file_path(topic, "video", video_number) audio_path = self.config.get_file_path(topic, "audio", video_number) caption_path = self.config.get_file_path(topic, "caption", video_number) - + # Determine available media files has_video = video_path.exists() has_audio = audio_path.exists() has_caption = caption_path.exists() - + # Apply preference logic if has_video and has_audio and self.config.prefer_video_with_audio: self.logger.info( @@ -216,7 +242,7 @@ def process_topic(self, topic: str, retry_failed: bool = False) -> Dict[str, Any f"Using video only (with embedded audio) as per config." ) has_audio = False - + # Log what we found media_info = [] if has_video: @@ -225,7 +251,7 @@ def process_topic(self, topic: str, retry_failed: bool = False) -> Dict[str, Any media_info.append(f"audio ({audio_path.name})") if has_caption: media_info.append(f"caption ({caption_path.name})") - + if not has_video and not has_audio: self.logger.warning( f"No media files found for video {video_number}. 
" @@ -234,55 +260,60 @@ def process_topic(self, topic: str, retry_failed: bool = False) -> Dict[str, Any if not retry_failed: enhanced_metadata.append(video_metadata) continue - - self.logger.info(f"Processing video {video_number} with: {', '.join(media_info)}") - + + self.logger.info( + f"Processing video {video_number} with: {', '.join(media_info)}" + ) + # Process media with multimodal support demographics = self.annotator.process_media( video_path=video_path if has_video else None, audio_path=audio_path if has_audio else None, transcript_path=caption_path if has_caption else None, metadata=video_metadata, - config=self.config + config=self.config, ) - + # Save raw response if configured - if self.config.save_raw_responses and 'raw_response' in demographics: + if self.config.save_raw_responses and "raw_response" in demographics: raw_path = raw_dir / f"video_{video_number}_response.json" - with open(raw_path, 'w') as f: - json.dump({"raw_response": demographics['raw_response']}, f, indent=2) - del demographics['raw_response'] - + with open(raw_path, "w") as f: + json.dump( + {"raw_response": demographics["raw_response"]}, f, indent=2 + ) + del demographics["raw_response"] + # Merge demographics with original metadata enhanced_video_metadata = video_metadata.copy() enhanced_video_metadata.update(demographics) - + # Add processing info - enhanced_video_metadata['processing_info'] = { - 'had_video': has_video, - 'had_audio': has_audio, - 'had_caption': has_caption, - 'processed_at': datetime.now().strftime("%Y-%m-%d %H:%M:%S") + enhanced_video_metadata["processing_info"] = { + "had_video": has_video, + "had_audio": has_audio, + "had_caption": has_caption, + "processed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), } - + if retry_failed: # Update existing entry enhanced_metadata[idx] = enhanced_video_metadata else: # Append new entry enhanced_metadata.append(enhanced_video_metadata) - + processed_count += 1 - + # Save checkpoint periodically if processed_count % 
self.config.save_interval == 0: - self._save_checkpoint(paths['videos'], enhanced_metadata) - + self._save_checkpoint(paths["videos"], enhanced_metadata) + # Rate limiting between videos - if idx < indices_to_process[-1]: # Don't wait after last video + # Don't wait after last video + if idx < indices_to_process[-1]: # Get video duration from metadata - video_duration = video_metadata.get('duration', 0) - + video_duration = video_metadata.get("duration", 0) + # Determine wait time based on video length if video_duration > self.config.long_video_threshold: wait_time = self.config.delay_after_long_video @@ -293,96 +324,125 @@ def process_topic(self, topic: str, retry_failed: bool = False) -> Dict[str, Any else: wait_time = self.config.delay_between_videos self.logger.info(f"Waiting {wait_time}s before next video...") - + time.sleep(wait_time) - + # Save final enhanced metadata (atomic write) - temp_output = output_path.with_suffix('.tmp') - - with open(temp_output, 'w') as f: + temp_output = output_path.with_suffix(".tmp") + + with open(temp_output, "w") as f: json.dump(enhanced_metadata, f, indent=2) - + temp_output.replace(output_path) - + # Clean up checkpoint - checkpoint_path = self._get_checkpoint_path(paths['videos']) + checkpoint_path = self._get_checkpoint_path(paths["videos"]) if checkpoint_path.exists(): checkpoint_path.unlink() self.logger.info("Checkpoint cleaned up") - - self.logger.info(f"Completed topic {topic}. Enhanced metadata saved to {output_path}") - + + self.logger.info( + f"Completed topic {topic}. 
Enhanced metadata saved to {output_path}" + ) + result = { "topic": topic, "total_videos": len(metadata_list), "processed": processed_count, - "output_path": str(output_path) + "output_path": str(output_path), } - + if retry_failed: result["failed_videos"] = len(indices_to_process) - + return result + def main(): - """Main entry point""" + """Run main entry point.""" parser = argparse.ArgumentParser(description="Run demographics annotation pipeline") - parser.add_argument("--config", type=str, default="config.yaml", - help="Path to configuration file") + parser.add_argument( + "--config", type=str, default="config.yaml", help="Path to configuration file" + ) parser.add_argument("--topic", type=str, help="Process specific topic only") parser.add_argument("--api-key", type=str, help="Override Gemini API key") - parser.add_argument("--no-cache", action="store_true", - help="Reprocess all videos even if already done") - parser.add_argument("--retry-failed", action="store_true", - help="Only reprocess videos with empty detailed demographics") - + parser.add_argument( + "--no-cache", + action="store_true", + help="Reprocess all videos even if already done", + ) + parser.add_argument( + "--retry-failed", + action="store_true", + help="Only reprocess videos with empty detailed demographics", + ) + args = parser.parse_args() - + # Create pipeline pipeline = AnnotationPipeline(config_path=args.config) - + # Override settings if provided if args.api_key: pipeline.config.api_key = args.api_key if args.no_cache: pipeline.config.use_cache = False - + # Check API key if not pipeline.config.api_key: - pipeline.logger.error("API key not provided. Set GEMINI_API_KEY environment variable or update config.yaml") + pipeline.logger.error( + "API key not provided. 
Set GEMINI_API_KEY environment " + "variable or update config.yaml" + ) sys.exit(1) - + + # Validate topic early (if specified) + if args.topic and args.topic not in pipeline.config.topics: + pipeline.logger.error( + f"Topic '{args.topic}' not found in configuration. " + f"Available topics: {', '.join(pipeline.config.topics)}" + ) + sys.exit(1) + # Process topics - if args.topic: - # Process single topic - if args.topic not in pipeline.config.topics: - pipeline.logger.error(f"Topic {args.topic} not found in configuration") - sys.exit(1) - result = pipeline.process_topic(args.topic, retry_failed=args.retry_failed) - print(f"Completed processing: {result}") + results = [] + topics_to_process = [args.topic] if args.topic else pipeline.config.topics + + if len(topics_to_process) == 0: + pipeline.logger.error("No topics to process") + sys.exit(1) + + for topic in topics_to_process: + try: + result = pipeline.process_topic(topic, retry_failed=args.retry_failed) + results.append(result) + except Exception as e: + pipeline.logger.error( + f"Failed to process topic {topic}: {e}", exc_info=True + ) + results.append({"topic": topic, "error": str(e)}) + + # Save summary report if processing multiple topics + if len(topics_to_process) > 1: + report_path = ( + Path(pipeline.config.base_path) / "demographics_annotation_report.json" + ) + with open(report_path, "w") as f: + json.dump( + { + "processing_date": datetime.now().isoformat(), + "model_used": pipeline.config.model_name, + "retry_mode": args.retry_failed, + "topics_processed": len(results), + "results": results, + }, + f, + indent=2, + ) + pipeline.logger.info(f"Completed all topics. 
Report saved to {report_path}") else: - # Process all topics - results = [] - for topic in pipeline.config.topics: - try: - result = pipeline.process_topic(topic, retry_failed=args.retry_failed) - results.append(result) - except Exception as e: - pipeline.logger.error(f"Failed to process topic {topic}: {e}", exc_info=True) - results.append({"topic": topic, "error": str(e)}) - - # Save summary report - report_path = Path(pipeline.config.base_path) / "demographics_annotation_report.json" - with open(report_path, 'w') as f: - json.dump({ - "processing_date": datetime.now().isoformat(), - "model_used": pipeline.config.model_name, - "retry_mode": args.retry_failed, - "topics_processed": len(results), - "results": results - }, f, indent=2) - - print(f"Completed all topics. Report saved to {report_path}") + pipeline.logger.info(f"Completed processing: {results[0]}") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/sonic-o1/04_vqa_generation/README.md b/sonic-o1/04_vqa_generation/README.md index edd3a63..11adcb6 100644 --- a/sonic-o1/04_vqa_generation/README.md +++ b/sonic-o1/04_vqa_generation/README.md @@ -1,150 +1,353 @@ # Video Question-Answer (VQA) Generation +## Overview + This directory handles automatic generation of three video QA-related tasks using Gemini-based multimodal models: 1. **Summarization** - Short + detailed summaries 2. **Multiple Choice Questions (MCQ)** - Segment-level questions with options 3. 
**Temporal Localization** - Segment-level time-related questions +### Directory Structure + +``` +04_vqa_generation/ +โ”œโ”€โ”€ main.py # Main VQA generation script +โ”œโ”€โ”€ fill_empty_demographics.py # Fill empty demographics in VQA files +โ”œโ”€โ”€ standardize_demographics.py # Standardize demographics to canonical categories +โ”œโ”€โ”€ vqa_config.yaml # Configuration file +โ”œโ”€โ”€ .env # API keys (create this file) +โ”œโ”€โ”€ README.md # This file +โ”‚ +โ”œโ”€โ”€ models/ # Model implementations +โ”‚ โ”œโ”€โ”€ base_gemini.py # Base Gemini client (shared API logic + dry-run) +โ”‚ โ”œโ”€โ”€ summarization_model.py # Task 1: video summarization +โ”‚ โ”œโ”€โ”€ mcq_model.py # Task 2: multiple-choice questions +โ”‚ โ”œโ”€โ”€ temporal_localization_model.py # Task 3: temporal localization +โ”‚ โ””โ”€โ”€ temporal_question_judge.py # GPT-4V validation judge +โ”‚ +โ”œโ”€โ”€ prompts/ # Prompt templates +โ”‚ โ”œโ”€โ”€ summarization_prompts.py +โ”‚ โ”œโ”€โ”€ mcq_prompts.py +โ”‚ โ”œโ”€โ”€ temporal_localization_prompts.py +โ”‚ โ””โ”€โ”€ temporal_judge_prompts.py +โ”‚ +โ””โ”€โ”€ utils/ # Shared utility modules + โ”œโ”€โ”€ config_utils.py # Config class and YAML loader (shared) + โ”œโ”€โ”€ file_utils.py # JSON backup/save helpers (shared) + โ”œโ”€โ”€ demographics_expander.py + โ”œโ”€โ”€ frame_sampler.py + โ””โ”€โ”€ video_segmenter.py +``` + +### Pipeline Workflow + +``` +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Step 1: VQA Generation โ”‚ +โ”‚ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ โ”‚ +โ”‚ Input: dataset/videos//metadata_enhanced.json โ”‚ +โ”‚ dataset/videos//video_*.mp4 โ”‚ +โ”‚ dataset/audios//audio_*.m4a โ”‚ +โ”‚ 
dataset/captions//caption_*.srt โ”‚ +โ”‚ Output: vqa/ โ”‚ +โ”‚ โ”œโ”€โ”€ task1_summarization/ โ”‚ +โ”‚ โ”‚ โ”œโ”€โ”€ 01_.json โ”‚ +โ”‚ โ”‚ โ””โ”€โ”€ ... โ”‚ +โ”‚ โ”œโ”€โ”€ task2_mcq/ โ”‚ +โ”‚ โ”‚ โ”œโ”€โ”€ 01_.json โ”‚ +โ”‚ โ”‚ โ””โ”€โ”€ ... โ”‚ +โ”‚ โ””โ”€โ”€ task3_temporal_localization/ โ”‚ +โ”‚ โ”œโ”€โ”€ 01_.json โ”‚ +โ”‚ โ””โ”€โ”€ ... โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ†“ +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Step 2: Fill Empty Demographics (Optional) โ”‚ +โ”‚ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ โ”‚ +โ”‚ Input: vqa/task*/*.json (with empty demographics arrays) โ”‚ +โ”‚ Output: vqa/task*/*.json (with demographics filled) โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ†“ +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Step 3: Standardize Demographics (Optional) โ”‚ +โ”‚ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ โ”‚ +โ”‚ Input: vqa/task*/*.json (with variant demographic terms) โ”‚ +โ”‚ Output: vqa/task*/*.json (with canonical demographic 
categories) โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ +``` + +### Features + +- **Three Task Types**: Summarization, MCQ generation, and temporal localization +- **Multimodal Analysis**: Combines video, audio, and transcript data +- **Automatic Segmentation**: Handles long videos by splitting into segments +- **Demographics Integration**: Includes demographic information in VQA entries +- **Dry-Run Mode**: Test the full pipeline without API calls (`--dry-run`) +- **Rate Limiting**: Configurable delays to avoid API limits +- **Error Handling**: Retries, validation, and graceful degradation +- **Skip Existing**: Automatically skips videos that already have VQA generated +- **Shared Utilities**: Common config loading (`utils/config_utils.py`) and file I/O (`utils/file_utils.py`) across scripts + ## Prerequisites Before running this step, you must have completed: + 1. **Data Curation** (see [01_data_curation](../01_data_curation/)) + - Downloaded videos and audio files + - Generated metadata.json files + 2. **Caption Generation** (see [02_caption_generation](../02_caption_generation/)) + - Generated captions for all videos (SRT format) + 3. **Demographics Annotation** (see [03_demographics_annotation](../03_demographics_annotation/)) + - Generated metadata_enhanced.json files with demographics -Your `dataset/` directory should have: +Your `dataset/` directory should have this structure: ``` - dataset/ -โ”œโ”€โ”€ videos/ -โ”‚ โ”œโ”€โ”€ 01_/ -โ”‚ โ”‚ โ”œโ”€โ”€ video_001.mp4 -โ”‚ โ”‚ โ”œโ”€โ”€ video_002.mp4 -โ”‚ โ”‚ โ””โ”€โ”€ metadata_enhanced.json -โ”‚ โ”œโ”€โ”€ 02_/ -โ”‚ โ”‚ โ””โ”€โ”€ ... -โ”‚ โ””โ”€โ”€ ... -โ”œโ”€โ”€ audios/ -โ”‚ โ”œโ”€โ”€ 01_/ -โ”‚ โ”‚ โ”œโ”€โ”€ audio_001.m4a -โ”‚ โ”‚ โ””โ”€โ”€ ... -โ”‚ โ””โ”€โ”€ ... 
-└── captions/ -├── 01_/ -│ ├── caption_001.srt -│ └── ... -└── ... +├── videos// +│ ├── video_001.mp4 +│ ├── video_002.mp4 +│ └── metadata_enhanced.json +├── audios// +│ ├── audio_001.m4a +│ └── audio_002.m4a +└── captions// + ├── caption_001.srt + └── caption_002.srt +``` -```` +4. **API Setup** + 1. **Get Gemini API Key** (Required) + - Go to [Google AI Studio](https://makersuite.google.com/app/apikey) + - Create or select a project + - Generate an API key -## Required Packages + 2. **Get OpenAI API Key** (Optional) + - Required only if temporal localization uses OpenAI-based judging/validation + - Go to [OpenAI Platform](https://platform.openai.com/api-keys) + - Create an API key -All required Python packages are included in [requirements_venv.txt](../../requirements_venv.txt), including: -- `google-generativeai` -- `openai` -- `python-dotenv` -- `pyyaml` -- `tqdm` + 3. **Set API Keys** + Create a `.env` file in this directory: + ```bash + GEMINI_API_KEY=your_gemini_api_key_here + OPENAI_API_KEY=your_openai_api_key_here # Optional + ``` -## API Setup + Or export them as environment variables: + ```bash + export GEMINI_API_KEY=your_gemini_api_key_here + export OPENAI_API_KEY=your_openai_api_key_here # Optional + ``` -### Gemini API Key (Required) -1. Go to Google AI Studio -2. Generate an API key +## Installation -### OpenAI API Key (Optional) -Required only if your temporal pipeline uses OpenAI-based judging/validation. +### Required Packages + +All required Python packages are included in the project's +[requirements_venv.txt](../../requirements_venv.txt). -### Set API Keys -Create a `.env` file in this directory: -```bash -GEMINI_API_KEY=your_gemini_api_key_here -OPENAI_API_KEY=your_openai_api_key_here -```` ## Configuration -Edit [config/vqa_config.yaml](config/vqa_config.yaml) to customize processing settings. 
+Edit [vqa_config.yaml](vqa_config.yaml) to customize processing settings. ### Model Settings - ```yaml gemini: - model_name: "gemini-2.5-flash" - temperature: 0.3 - max_output_tokens: 2048 + model_name: "gemini-2.5-flash" # Model to use + temperature: 0.3 # Lower = more deterministic + max_output_tokens: 2048 # Response length + retry_attempts: 3 # Number of retries on failure + file_processing_timeout: 7200 # Max time for file processing (2 hours) ``` -### Rate Limiting +### Video Processing Settings +```yaml +video: + summarization_segment_duration: 600 # 10 minutes for summarization + mcq_segment_duration: 180 # 3 minutes for MCQ + temporal_localization_segment_duration: 180 # 3 minutes for temporal + segment_overlap: 30 # Overlap between segments (30 sec) +``` +### Rate Limiting ```yaml rate_limit: - delay_between_videos: 45 - delay_after_long_video: 60 + delay_between_videos: 45 # Seconds between videos + delay_after_segment: 10 # Seconds after each segment + delay_after_long_video: 90 # Extra delay after long videos + long_video_threshold: 1800 # Threshold for "long" video (30 min) + delay_after_api_call: 15 # Seconds after each Gemini call +``` + +### Task-Specific Settings +```yaml +summarization: + constraints: + max_words_detailed: 300 + max_words_segment: 120 + timeline_items_min: 5 + timeline_items_max: 12 + +mcq: + num_options: 5 # Always 5 options + questions_per_segment: 1 # MCQs per segment + +temporal_localization: + questions_per_segment: 3 # Temporal questions per segment + judge_enabled: true # Enable GPT-4V validation (optional) + judge_model: "gpt-4o" # OpenAI model for validation +``` + +### Processing Options +```yaml +processing: + save_raw_responses: true # Save raw Gemini responses for debugging + skip_existing: true # Skip videos that already have VQA generated + parallel_processing: false # Set to true for faster processing ``` ## Usage -Run from the project root so relative paths resolve correctly. 
+**IMPORTANT**: Always run the scripts from the project root +(sonic-o1/sonic-o1 directory) so relative paths work correctly. + +### Dry Run (No API Calls) + +Verify the pipeline end-to-end without spending API credits: + +```bash +# Dry run - exercises all code paths with stub responses, writes nothing +python 04_vqa_generation/main.py --dry-run --all + +# Dry run for a single topic and task +python 04_vqa_generation/main.py --dry-run --topics 1 --task summarization +``` ### Process All Topics - All Tasks ```bash +# Navigate to working directory (note: sonic-o1/sonic-o1) +cd /path/to/sonic-o1/sonic-o1 + +# Run VQA generation for all topics python 04_vqa_generation/main.py --all ``` ### Process Specific Topics ```bash +# Process topics 1, 2, and 3 python 04_vqa_generation/main.py --topics 1,2,3 + +# Process single topic python 04_vqa_generation/main.py --topics 5 ``` ### Process Specific Task Only ```bash +# Generate only summarization for topics 1 and 2 python 04_vqa_generation/main.py --topics 1,2 --task summarization + +# Generate only MCQ for all topics python 04_vqa_generation/main.py --all --task mcq + +# Generate only temporal localization for all topics python 04_vqa_generation/main.py --all --task temporal ``` +### Use Custom Configuration + +```bash +python 04_vqa_generation/main.py --config path/to/custom_config.yaml +``` + ### Fill Empty Demographics +After generating VQA, fill empty demographics arrays: + ```bash +# Dry run to see what would be filled python 04_vqa_generation/fill_empty_demographics.py --dry-run + +# Fill demographics for all topics python 04_vqa_generation/fill_empty_demographics.py + +# Fill demographics for specific topics python 04_vqa_generation/fill_empty_demographics.py --topics 10,11 ``` ### Standardize Demographics +Standardize demographic values to canonical categories: + ```bash +# Dry run to see what would be standardized python 04_vqa_generation/standardize_demographics.py --dry-run + +# Standardize demographics for all 
topics python 04_vqa_generation/standardize_demographics.py + +# Standardize demographics for specific topics python 04_vqa_generation/standardize_demographics.py --topics 1,2,3 ``` -## Output Structure +### Command-Line Arguments + +The main script supports several command-line arguments: + +| Argument | Description | Example | +|----------|-------------|---------| +| `--config` | Path to configuration file (default: `vqa_config.yaml`) | `--config my_config.yaml` | +| `--topics` | Comma-separated topic IDs to process | `--topics 1,2,3` | +| `--all` | Process all topics | `--all` | +| `--task` | Process specific task only | `--task summarization` | +| `--output` | Output directory (overrides config) | `--output custom_output/` | +| `--dry-run` | Run without API calls; generates stub outputs | `--dry-run` | + +**Examples:** + +```bash +# Process single topic with custom config +python 04_vqa_generation/main.py \ + --topics 1 \ + --config custom_config.yaml + +# Process specific task for multiple topics +python 04_vqa_generation/main.py \ + --topics 1,2,3 \ + --task mcq \ + --output vqa_custom/ +``` + +## Output -Outputs are written to the configured output directory (default: `/vqa/`) in per-task folders, with one JSON file per topic: +The script creates JSON files in the configured output directory (default: `vqa/`) +organized by task type, with one JSON file per topic. +### Output Location ``` -/ +vqa/ ├── task1_summarization/ -│ ├── 01_.json -│ ├── 02_.json +│ ├── 01_Patient-Doctor_Consultations.json +│ ├── 02_Job_Interviews.json │ └── ... ├── task2_mcq/ -│ ├── 01_.json -│ ├── 02_.json +│ ├── 01_Patient-Doctor_Consultations.json +│ ├── 02_Job_Interviews.json │ └── ... └── task3_temporal_localization/ - ├── 01_.json - ├── 02_.json + ├── 01_Patient-Doctor_Consultations.json + ├── 02_Job_Interviews.json └── ... 
``` -Each output file has a shared wrapper: +### Output Format + +Each output file has a shared wrapper structure: ```json { @@ -167,16 +370,32 @@ Each output file has a shared wrapper: "topic_id": 1, "topic_name": "Patient-Doctor Consultations", "summary_short": [ + "Bullet point 1", + "Bullet point 2", "..." ], - "summary_detailed": "...", + "summary_detailed": "Full detailed summary text...", + "timeline": [ + {"time": "00:05", "event": "Doctor introduces himself"}, + {"time": "00:12", "event": "Patient describes symptoms"} + ], + "glossary": [ + {"term": "Hypertension", "definition": "High blood pressure"}, + {"term": "Systolic", "definition": "Upper blood pressure reading"} + ], + "demographics": { + "race": ["White"], + "gender": ["Male", "Female"], + "age": ["Middle (25-39)"], + "language": ["English"] + }, "confidence": 0.92 } ``` ### Task 2: MCQ (`task2_mcq/*.json`) -`entries` is a list with one entry per generated question (typically multiple per video). Segment fields are used for merging/replacement: +`entries` is a list with one entry per generated question (typically multiple per video): ```json { @@ -187,16 +406,29 @@ Each output file has a shared wrapper: "start": 120.0, "end": 180.0 }, - "question": "...", - "options": ["...", "...", "...", "...", "..."], + "question": "What is the patient's primary concern?", + "options": [ + "High blood pressure", + "Chest pain", + "Headache", + "Fatigue", + "Not enough evidence" + ], "correct_answer": 0, + "evidence_tags": ["medical_monitors", "beds"], + "demographics": { + "race": ["White"], + "gender": ["Male"], + "age": ["Middle (25-39)"], + "language": ["English"] + }, "confidence": 0.85 } ``` ### Task 3: Temporal Localization (`task3_temporal_localization/*.json`) -`entries` is a list with one entry per generated temporal question. 
Segment fields are used for merging/replacement: +`entries` is a list with one entry per generated temporal question: ```json { @@ -207,60 +439,121 @@ Each output file has a shared wrapper: "start": 45.0, "end": 90.0 }, - "question": "...", - "answer": "...", + "question": "What happens after the doctor takes the patient's blood pressure?", + "answer": "The doctor reviews the readings and discusses treatment options.", + "temporal_relation": "after", + "demographics": { + "race": ["White"], + "gender": ["Male", "Female"], + "age": ["Middle (25-39)"], + "language": ["English"] + }, "confidence": 0.81 } ``` -## Processing Pipeline +## Processing Time + +- **Per video**: ~30-120 seconds depending on video length and task type +- **Long videos**: Automatically segmented and may take longer +- **Rate limiting**: Script includes delays between videos to avoid API limits + +### Estimated Time for Full Dataset +- 13 topics × 25 videos = 325 videos +- Average 60 seconds per video = ~325 minutes +- With rate limiting: ~6-8 hours total (all tasks) + +## Troubleshooting + +### API Key Not Found + +**Problem**: `ERROR: API key not set!` + +**Solution**: +```bash +# Create .env file in 04_vqa_generation directory +echo "GEMINI_API_KEY=your_key_here" > 04_vqa_generation/.env + +# Or export environment variable +export GEMINI_API_KEY=your_key_here +``` + +### File Not Found Errors -1. **Generate VQA** +**Problem**: `FileNotFoundError: dataset/videos/...` +**Solution**: Make sure you're running from the project root: ```bash +cd /path/to/sonic-o1/sonic-o1 python 04_vqa_generation/main.py --all ``` -2. 
**Fill Empty Demographics** +### API Rate Limit Exceeded -```bash -python 04_vqa_generation/fill_empty_demographics.py +**Problem**: `429 Too Many Requests` + +**Solution**: Increase delays in [vqa_config.yaml](vqa_config.yaml): +```yaml +rate_limit: + delay_between_videos: 60 # Increase from 45 + delay_after_api_call: 20 # Increase from 15 + delay_after_long_video: 120 # Increase from 90 ``` -3. **Standardize Demographics** +### Empty Demographics + +**Problem**: Some VQA entries have empty demographics arrays +**Solution**: Run the fill demographics script: ```bash -python 04_vqa_generation/standardize_demographics.py +python 04_vqa_generation/fill_empty_demographics.py ``` -## Directory Structure +### Video Too Long + +**Problem**: Processing takes too long or times out +**Solution**: Adjust segment duration in [vqa_config.yaml](vqa_config.yaml): +```yaml +video: + summarization_segment_duration: 300 # Reduce from 600 + mcq_segment_duration: 120 # Reduce from 180 ``` -04_vqa_generation/ -├── config/ -│ └── vqa_config.yaml -├── models/ -│ ├── base_gemini.py -│ ├── summarization_model.py -│ ├── mcq_model.py -│ └── temporal_localization_model.py -├── prompts/ -│ ├── summarization_prompts.py -│ ├── mcq_prompts.py -│ └── temporal_localization_prompts.py -├── utils/ -├── main.py -├── fill_empty_demographics.py -├── standardize_demographics.py -└── .env + +### Timeout Errors + +**Problem**: `TimeoutError: Processing timed out` + +**Solution**: Increase timeout in [vqa_config.yaml](vqa_config.yaml): +```yaml +gemini: + file_processing_timeout: 10800 # 3 hours; increase from 7200 (2 hours) ``` +## Files + +- `main.py` - Main VQA generation script (supports `--dry-run`) +- `fill_empty_demographics.py` - Fill empty demographics in VQA files (supports `--dry-run`) +- `standardize_demographics.py` - Standardize demographics to canonical categories (supports `--dry-run`) +- 
`vqa_config.yaml` - Configuration file +- `models/` - Model implementations (base_gemini, summarization, mcq, temporal, judge) +- `prompts/` - Prompt templates for each task +- `utils/` - Shared utilities + - `config_utils.py` - `Config` class and `load_config()` (used by all scripts) + - `file_utils.py` - `save_json_with_backup()` (used by demographics scripts) + - `demographics_expander.py` - Expand human-reviewed demographics via Gemini + - `frame_sampler.py` - Sample video frames for GPT-4V judge + - `video_segmenter.py` - Segment video/audio via FFmpeg +- `.env` - Environment variables (API keys) - create this file ## Notes -- The script skips videos that already have VQA generated (configurable) -- Raw API responses are saved if `save_raw_responses: true` in config +- The script processes videos in order by video number +- Already processed videos are skipped automatically if `skip_existing: true` +- Raw model responses are saved for debugging if `save_raw_responses: true` - Temporal localization uses GPT-4V for validation (optional, requires OpenAI key) -- All paths are relative to project root, so always run from sonic-o1 directory -- Demographics optimization: Task 2 reuses Task 3 demographics (same segments), Task 1 generates separately -- Check scripts (check_empty_demographics.py, check_failed_summary.py) are helper utilities not in git +- All paths are relative to the project root, so always run from the + sonic-o1/sonic-o1 directory (the inner sonic-o1 directory that + contains the pipeline code) +- Demographics optimization: Task 2 reuses demographics from Task 3 + (same segments), Task 1 generates separately for full videos diff --git a/sonic-o1/04_vqa_generation/fill_empty_demographics.py b/sonic-o1/04_vqa_generation/fill_empty_demographics.py index 45cb3ac..6c9810a 100644 --- a/sonic-o1/04_vqa_generation/fill_empty_demographics.py +++ b/sonic-o1/04_vqa_generation/fill_empty_demographics.py @@ -1,334 +1,398 @@ #!/usr/bin/env python3 -""" -Fill 
Empty Demographics in VQA Files -This script fills empty demographics arrays in task1, task2, and task3 VQA JSON files -by using the DemographicsExpander with Gemini to analyze video/audio content. +"""fill_empty_demographics.py. -OPTIMIZATION: -- Task 3 generates demographics for segments via API -- Task 2 reuses demographics from Task 3 (same segments, Task 3 is reviewed and correct) -- Task 1 generates demographics for full videos via API +Fill empty demographics arrays in task1, task2, and task3 VQA JSON files +using DemographicsExpander with Gemini. Task 3 generates per-segment; +Task 2 reuses Task 3 demographics; Task 1 generates for full videos. Usage: - python fill_empty_demographics.py --config config/vqa_config.yaml - python fill_empty_demographics.py --config config/vqa_config.yaml --dry-run + python fill_empty_demographics.py --config vqa_config.yaml + python fill_empty_demographics.py --config vqa_config.yaml --dry-run python fill_empty_demographics.py --topics 10,11 --dry-run + +Author: SONIC-O1 Team """ + import argparse +import copy import json import logging -import yaml import os import sys import time -from pathlib import Path -from datetime import datetime -from typing import Dict, List, Any, Optional from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Optional + +from utils.config_utils import Config, load_config +from utils.file_utils import save_json_with_backup + + +_VIDEO_EXTENSIONS = frozenset({".mp4", ".avi", ".mov", ".webm", ".mkv", ".m4v"}) # Load environment variables try: from dotenv import load_dotenv - env_path = Path(__file__).parent / '.env' + + env_path = Path(__file__).parent / ".env" if env_path.exists(): load_dotenv(env_path) print(f"Loaded environment variables from {env_path}") except ImportError: - print("python-dotenv not installed, ensure GEMINI_API_KEY is set in environment") + print("python-dotenv not installed; set GEMINI_API_KEY in environment") # Setup logging 
logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', + format="%(asctime)s - %(levelname)s - %(message)s", handlers=[ - logging.FileHandler('fill_demographics.log'), - logging.StreamHandler() - ] + logging.FileHandler("fill_demographics.log"), + logging.StreamHandler(), + ], ) logger = logging.getLogger(__name__) -class Config: - """Configuration wrapper""" - def __init__(self, config_dict): - for key, value in config_dict.items(): - if isinstance(value, dict): - setattr(self, key, Config(value)) - else: - setattr(self, key, value) - -def load_config(config_path: str) -> Config: - """Load configuration from YAML file""" - with open(config_path, 'r') as f: - config_dict = yaml.safe_load(f) - return Config(config_dict) class DemographicsFiller: - """Fill empty demographics in VQA files""" - + """Fill empty demographics in VQA files.""" + def __init__(self, config: Config, dry_run: bool = False): """ Initialize demographics filler. - + Args: config: Configuration object - dry_run: If True, only report what would be done without making changes + dry_run: If True, only report what would be done (no changes). 
""" self.config = config self.dry_run = dry_run self.stats = { - 'task1': {'total': 0, 'empty': 0, 'filled': 0, 'failed': 0, 'reused': 0}, - 'task2': {'total': 0, 'empty': 0, 'filled': 0, 'failed': 0, 'reused': 0}, - 'task3': {'total': 0, 'empty': 0, 'filled': 0, 'failed': 0, 'reused': 0} + "task1": {"total": 0, "empty": 0, "filled": 0, "failed": 0, "reused": 0}, + "task2": {"total": 0, "empty": 0, "filled": 0, "failed": 0, "reused": 0}, + "task3": {"total": 0, "empty": 0, "filled": 0, "failed": 0, "reused": 0}, } - - # Import models (only if not dry-run) + + # Lazy import: only load when not dry-run (env-dependent, avoid heavy deps) if not dry_run: - from models.base_gemini import BaseGeminiClient - from utils.demographics_expander import DemographicsExpander - from utils.video_segmenter import VideoSegmenter - + from models.base_gemini import BaseGeminiClient # noqa: PLC0415 + from utils.demographics_expander import ( # noqa: PLC0415 + DemographicsExpander, + ) + from utils.video_segmenter import VideoSegmenter # noqa: PLC0415 + self.demographics_expander = DemographicsExpander(config) self.segmenter = VideoSegmenter(config) self.gemini_client = BaseGeminiClient(config) - + # Load all metadata self.metadata_by_topic = self._load_all_metadata() - - # Cache for Task 3 demographics (video_id -> {segment_key -> demographics}) + + # Cache: video_id -> {segment_key -> demographics} self.task3_demographics_cache = {} - + def _load_all_metadata(self) -> Dict[int, Dict[str, Any]]: - """Load metadata_enhanced.json for all topics""" + """Load metadata_enhanced.json for all topics.""" metadata_map = {} dataset_root = Path(self.config.paths.dataset_root) - videos_dir = dataset_root / 'videos' - + videos_dir = dataset_root / "videos" + if not videos_dir.exists(): logger.error(f"Videos directory not found: {videos_dir}") return {} - + for topic_dir in sorted(videos_dir.iterdir()): if topic_dir.is_dir() and topic_dir.name[0].isdigit(): - parts = topic_dir.name.split('_', 1) + 
parts = topic_dir.name.split("_", 1) if len(parts) == 2: topic_id = int(parts[0]) - - metadata_file = topic_dir / 'metadata_enhanced.json' + + metadata_file = topic_dir / "metadata_enhanced.json" if metadata_file.exists(): try: - with open(metadata_file, 'r') as f: + with open(metadata_file, "r") as f: metadata_list = json.load(f) - + # Create map by video_id topic_metadata = {} for meta in metadata_list: - video_id = meta.get('video_id', meta.get('video_number')) + video_id = meta.get( + "video_id", meta.get("video_number") + ) topic_metadata[video_id] = meta - + metadata_map[topic_id] = topic_metadata - logger.info(f"Loaded metadata for topic {topic_id}: {len(topic_metadata)} videos") + logger.info( + "Loaded metadata for topic %s: %d videos", + topic_id, + len(topic_metadata), + ) except Exception as e: - logger.error(f"Failed to load metadata for topic {topic_id}: {e}") - + logger.error( + "Failed to load metadata for topic %s: %s", + topic_id, + e, + ) + return metadata_map - + def _load_task3_demographics(self, topic_id: int): - """Load all Task 3 demographics for a topic into cache""" + """Load all Task 3 demographics for a topic into cache.""" if topic_id in self.task3_demographics_cache: return # Already loaded - + output_dir = Path(self.config.paths.output_dir) - task3_dir = output_dir / 'task3_temporal_localization' - + task3_dir = output_dir / "task3_temporal_localization" + if not task3_dir.exists(): logger.warning(f"Task 3 directory not found: {task3_dir}") self.task3_demographics_cache[topic_id] = {} return - + # Find Task 3 JSON file for this topic task3_files = list(task3_dir.glob(f"{topic_id:02d}_*.json")) - + if not task3_files: logger.warning(f"No Task 3 file found for topic {topic_id}") self.task3_demographics_cache[topic_id] = {} return - + task3_path = task3_files[0] - + try: - with open(task3_path, 'r') as f: + with open(task3_path, "r") as f: task3_data = json.load(f) - + # Build cache: video_id -> {start_time -> demographics_dict} cache = 
defaultdict(dict) - - for entry in task3_data.get('entries', []): - video_id = entry.get('video_id') - segment = entry.get('segment', {}) - start = segment.get('start', 0) - + + for entry in task3_data.get("entries", []): + video_id = entry.get("video_id") + segment = entry.get("segment", {}) + start = segment.get("start", 0) + # Use start time as key (rounded to handle float differences) start_key = round(start, 1) # Round to 0.1s precision - - demographics = entry.get('demographics', []) - + + demographics = entry.get("demographics", []) + # Only cache if demographics exist if demographics: cache[video_id][start_key] = { - 'demographics': demographics, - 'demographics_total_individuals': entry.get('demographics_total_individuals', 0), - 'demographics_confidence': entry.get('demographics_confidence', 0.0), - 'demographics_explanation': entry.get('demographics_explanation', ''), - 'segment_end': segment.get('end', 0) # Store for logging + "demographics": demographics, + "demographics_total_individuals": entry.get( + "demographics_total_individuals", 0 + ), + "demographics_confidence": entry.get( + "demographics_confidence", 0.0 + ), + "demographics_explanation": entry.get( + "demographics_explanation", "" + ), + "segment_end": segment.get("end", 0), } - - self.task3_demographics_cache[topic_id] = dict(cache) - logger.info(f"Loaded Task 3 demographics cache for topic {topic_id}: " - f"{len(cache)} videos, {sum(len(v) for v in cache.values())} segments") - + + self.task3_demographics_cache[topic_id] = copy.deepcopy(dict(cache)) + n_segments = sum(len(v) for v in cache.values()) + logger.info( + "Loaded Task 3 cache for topic %s: %d videos, %d segments", + topic_id, + len(cache), + n_segments, + ) + except Exception as e: - logger.error(f"Failed to load Task 3 demographics for topic {topic_id}: {e}") + logger.error( + "Failed to load Task 3 demographics for topic %s: %s", + topic_id, + e, + ) self.task3_demographics_cache[topic_id] = {} - - def 
_get_task3_demographics(self, video_id: str, start_time: float, end_time: float, - topic_id: int) -> Optional[Dict[str, Any]]: + + def _get_task3_demographics( + self, video_id: str, start_time: float, end_time: float, topic_id: int + ) -> Optional[Dict[str, Any]]: """ Get demographics from Task 3 cache for a specific segment. - Uses start time as key for robust matching (ignores end time differences). + + Uses start time as key (ignores end time differences). """ # Ensure cache is loaded self._load_task3_demographics(topic_id) - + topic_cache = self.task3_demographics_cache.get(topic_id, {}) video_cache = topic_cache.get(video_id, {}) - + # Use start time as key (rounded to 0.1s precision) start_key = round(start_time, 1) - + demographics = video_cache.get(start_key) - + if demographics: - cached_end = demographics.get('segment_end', end_time) - logger.debug(f"Found Task3 match for {video_id} at start={start_time}s " - f"(Task2 end={end_time}s, Task3 end={cached_end}s)") - + cached_end = demographics.get("segment_end", end_time) + logger.debug( + f"Found Task3 match for {video_id} at start={start_time}s " + f"(Task2 end={end_time}s, Task3 end={cached_end}s)" + ) + return demographics - - def _update_task3_cache(self, topic_id: int, video_id: str, start_time: float, - end_time: float, demographics_data: Dict[str, Any]): + + def _update_task3_cache( + self, + topic_id: int, + video_id: str, + start_time: float, + end_time: float, + demographics_data: Dict[str, Any], + ): """Update Task 3 cache with newly filled demographics.""" # Ensure cache exists for this topic if topic_id not in self.task3_demographics_cache: self.task3_demographics_cache[topic_id] = {} - + if video_id not in self.task3_demographics_cache[topic_id]: self.task3_demographics_cache[topic_id][video_id] = {} - + # Use start time as key (rounded to 0.1s precision) start_key = round(start_time, 1) - + # Add segment_end for logging purposes - demographics_data['segment_end'] = end_time - + 
demographics_data["segment_end"] = end_time + # Update cache self.task3_demographics_cache[topic_id][video_id][start_key] = demographics_data - - logger.debug(f"Updated Task3 cache for {video_id} at start={start_key}s") - + logger.debug("Updated Task3 cache for %s at start=%s", video_id, start_key) + def get_file_paths(self, video_id: str, topic_id: int) -> Dict[str, Optional[Path]]: - """Get paths for video, audio, and transcript files""" + """Get paths for video, audio, and transcript files.""" dataset_root = Path(self.config.paths.dataset_root) - videos_dir = dataset_root / 'videos' - + videos_dir = dataset_root / "videos" + # Find topic directory topic_dir = None for d in videos_dir.iterdir(): if d.is_dir() and d.name.startswith(f"{topic_id:02d}_"): topic_dir = d break - + if not topic_dir: logger.warning(f"Topic directory not found for topic {topic_id}") - return {'video_path': None, 'audio_path': None, 'transcript_path': None} - + return { + "video_path": None, + "audio_path": None, + "transcript_path": None, + } + # Get metadata to find video_number metadata = self.metadata_by_topic.get(topic_id, {}).get(video_id, {}) - video_number = metadata.get('video_number', video_id) - + video_number = metadata.get("video_number", video_id) + # Construct paths video_path = topic_dir / f"video_{video_number}.mp4" - audio_dir = dataset_root / 'audios' / topic_dir.name + audio_dir = dataset_root / "audios" / topic_dir.name audio_path = audio_dir / f"audio_{video_number}.m4a" - caption_dir = dataset_root / 'captions' / topic_dir.name - transcript_path = caption_dir / f"caption_{video_number}.srt" # โœ“ Fixed: .srt not .txt - + caption_dir = dataset_root / "captions" / topic_dir.name + transcript_path = ( + caption_dir / f"caption_{video_number}.srt" + ) # โœ“ Fixed: .srt not .txt + return { - 'video_path': video_path if video_path.exists() else None, - 'audio_path': audio_path if audio_path.exists() else None, - 'transcript_path': transcript_path if 
transcript_path.exists() else None + "video_path": video_path if video_path.exists() else None, + "audio_path": audio_path if audio_path.exists() else None, + "transcript_path": (transcript_path if transcript_path.exists() else None), } - + def _load_transcript(self, transcript_path: Optional[Path]) -> str: - """Load and truncate transcript if needed""" + """Load and truncate transcript if needed.""" if not transcript_path or not transcript_path.exists(): return "" - + try: - with open(transcript_path, 'r', encoding='utf-8') as f: + with open(transcript_path, "r", encoding="utf-8") as f: text = f.read() - + max_length = self.config.file_processing.max_transcript_length if len(text) > max_length: text = text[:max_length] + "\n...[truncated]" - + return text except Exception as e: logger.warning(f"Failed to load transcript: {e}") return "" - - def _generate_with_retry(self, media_files: List, prompt: str, - video_id: str, context: str, max_retries: int = 3) -> Optional[str]: + + def _generate_with_retry( + self, + media_files: List, + prompt: str, + video_id: str, + context: str, + max_retries: int = 3, + ) -> Optional[str]: """ - Generate demographics with retry logic for empty responses (safety filter handling). - + Generate demographics with retry for empty responses (safety filter). + Args: media_files: List of (type, path) tuples prompt: Generation prompt video_id: Video ID for logging - context: Context string for logging (e.g., "Task1 video", "Task3 segment 0-210s") + context: Context for logging (e.g. 
"Task1 video", "Task3 segment") max_retries: Maximum number of retry attempts - - Returns: + + Returns + ------- Response text or None if all retries failed """ for attempt in range(max_retries): try: - logger.info(f"Generating demographics for {context} (attempt {attempt+1}/{max_retries})") - - response_text = self.gemini_client.generate_content(media_files, prompt, video_fps=0.25) - + logger.info( + "Generating demographics for %s (attempt %d/%d)", + context, + attempt + 1, + max_retries, + ) + + response_text = self.gemini_client.generate_content( + media_files, prompt, video_fps=0.25 + ) + # Check for empty response if not response_text or len(response_text.strip()) < 10: - logger.warning(f"Empty/minimal response for {context} on attempt {attempt+1}/{max_retries}") - logger.warning(f"This is likely due to safety filters being triggered") - + logger.warning( + "Empty response for %s attempt %d/%d", + context, + attempt + 1, + max_retries, + ) + logger.warning( + "This is likely due to safety filters being triggered" + ) + if attempt < max_retries - 1: # Exponential backoff: 10s, 20s, 30s wait_time = 10 * (attempt + 1) - logger.info(f"Retrying after {wait_time}s (safety filters can be inconsistent)...") + logger.info( + "Retrying after %ds (safety filters)...", + wait_time, + ) time.sleep(wait_time) continue - else: - logger.error(f"Failed after {max_retries} attempts - empty responses (likely safety filter)") - return None - + logger.error( + "Failed after %d attempts (empty/safety filter)", + max_retries, + ) + return None + # Got a valid response - logger.info(f"โœ“ Received valid response: {len(response_text)} characters") + logger.info("Received valid response: %d chars", len(response_text)) return response_text - + except Exception as e: - logger.error(f"Attempt {attempt+1}/{max_retries} failed with error: {e}") - + logger.error( + "Attempt %d/%d failed: %s", + attempt + 1, + max_retries, + e, + ) + if attempt < max_retries - 1: wait_time = 10 * (attempt + 
1) logger.info(f"Retrying after {wait_time}s...") @@ -336,403 +400,447 @@ def _generate_with_retry(self, media_files: List, prompt: str, else: logger.error(f"Failed after {max_retries} attempts") return None - + return None - + + def _get_human_demographics(self, video_id: str, topic_id: int) -> Optional[tuple]: + """Return (metadata, human_demographics) for *video_id*, or None on failure.""" + metadata = self.metadata_by_topic.get(topic_id, {}).get(video_id, {}) + if not metadata: + logger.error("No metadata found for video %s", video_id) + return None + human_demographics = metadata.get("demographics_detailed_reviewed", {}) + if not human_demographics: + logger.warning("No human-reviewed demographics for %s", video_id) + return None + return metadata, human_demographics + def fill_task1_demographics(self, entry: Dict[str, Any], topic_id: int) -> bool: - """Fill demographics for Task 1 (Summarization) - full video""" - video_id = entry.get('video_id') - + """Fill demographics for Task 1 (Summarization) - full video.""" + video_id = entry.get("video_id") + if self.dry_run: - logger.info(f"[DRY-RUN] Would fill demographics for Task1 video {video_id}") + logger.info( + "[DRY-RUN] Would fill demographics for Task1 video %s", + video_id, + ) return True - + try: - # Get metadata - metadata = self.metadata_by_topic.get(topic_id, {}).get(video_id, {}) - if not metadata: - logger.error(f"No metadata found for video {video_id}") - return False - - human_demographics = metadata.get('demographics_detailed_reviewed', {}) - if not human_demographics: - logger.warning(f"No human-reviewed demographics for {video_id}") + result = self._get_human_demographics(video_id, topic_id) + if result is None: return False - - # Get file paths + _, human_demographics = result + paths = self.get_file_paths(video_id, topic_id) - - # Build prompt prompt = self.demographics_expander.build_expansion_prompt( - human_demographics, - segment_info=None # Full video for task1 + human_demographics, 
segment_info=None ) - - # Prepare media files + media_files = [] - if paths['video_path']: - media_files.append(('video', paths['video_path'])) - if paths['audio_path']: - media_files.append(('audio', paths['audio_path'])) - - # Add transcript context - transcript_text = self._load_transcript(paths['transcript_path']) + if paths["video_path"]: + media_files.append(("video", paths["video_path"])) + if paths["audio_path"]: + media_files.append(("audio", paths["audio_path"])) + + transcript_text = self._load_transcript(paths["transcript_path"]) if transcript_text: prompt += f"\n\nTRANSCRIPT SUMMARY:\n{transcript_text[:2000]}" - - # Generate demographics with retry + response_text = self._generate_with_retry( media_files, prompt, video_id, f"Task1 video {video_id}", max_retries=3 ) - if not response_text: return False - - demographics_data = self.demographics_expander.parse_demographics_response(response_text) - - # Update entry - entry['demographics'] = demographics_data.get('demographics', []) - - logger.info(f"โœ“ Filled demographics for Task1 video {video_id}: " - f"{len(entry['demographics'])} entries, " - f"confidence={demographics_data.get('confidence', 0):.2f}") - + + demographics_data = self.demographics_expander.parse_demographics_response( + response_text + ) + entry["demographics"] = demographics_data.get("demographics", []) + logger.info( + "Filled Task1 video %s: %d entries conf=%.2f", + video_id, + len(entry["demographics"]), + demographics_data.get("confidence", 0), + ) return True - + except Exception as e: - logger.error(f"Failed to fill Task1 demographics for {video_id}: {e}", exc_info=True) + logger.error( + "Failed to fill Task1 demographics for %s: %s", + video_id, + e, + exc_info=True, + ) return False - + + def _find_matching_segment( + self, + segments: list, + start_time: float, + end_time: float, + ) -> Optional[Path]: + """Return the segment_path matching the given time range, or None.""" + return next( + ( + seg["segment_path"] + for seg in 
segments + if seg["start"] == start_time and seg["end"] == end_time + ), + None, + ) + + def _cleanup_segments(self, *segment_lists: Optional[list]) -> None: + """Clean up all non-None segment lists, suppressing errors.""" + for segments in segment_lists: + if segments: + try: + self.segmenter.cleanup_segments(segments) + except Exception as e: + logger.warning("Failed to cleanup segments: %s", e) + + def _prepare_segment_media( + self, + video_id: str, + topic_id: int, + start_time: float, + end_time: float, + metadata: Dict[str, Any], + ) -> Optional[tuple]: + """Build media_files and transcript for a segment. + + Returns (media_files, transcript, video_segments, audio_segments) + or None when the target segment cannot be created. + """ + paths = self.get_file_paths(video_id, topic_id) + duration = metadata.get("duration_seconds", end_time) + + video_segments = self.segmenter.segment_video( + paths["video_path"], duration, task_type="temporal_localization" + ) + seg_path = self._find_matching_segment(video_segments, start_time, end_time) + if not seg_path or not seg_path.exists(): + logger.error( + "Could not create segment for %s at %s-%ss", + video_id, + start_time, + end_time, + ) + return None + + audio_segments = None + audio_seg_path = None + if paths["audio_path"]: + audio_segments = self.segmenter.segment_audio( + paths["audio_path"], duration, task_type="temporal_localization" + ) + audio_seg_path = self._find_matching_segment( + audio_segments, start_time, end_time + ) + + transcript_text = "" + if paths["transcript_path"]: + transcript_text = self.segmenter.extract_transcript_segment( + paths["transcript_path"], start_time, end_time, strip_timestamps=True + ) + + media_type = ( + "video" if seg_path.suffix.lower() in _VIDEO_EXTENSIONS else "audio" + ) + media_files = [(media_type, seg_path)] + if audio_seg_path: + media_files.append(("audio", audio_seg_path)) + + return media_files, transcript_text, video_segments, audio_segments + def 
fill_task3_demographics(self, entry: Dict[str, Any], topic_id: int) -> bool: - """Fill demographics for Task 3 (Temporal Localization) - segment level""" - video_id = entry.get('video_id') - segment = entry.get('segment', {}) - start_time = segment.get('start', 0) - end_time = segment.get('end', 0) - + """Fill demographics for Task 3 (Temporal Localization) - segment.""" + video_id = entry.get("video_id") + segment = entry.get("segment", {}) + start_time = segment.get("start", 0) + end_time = segment.get("end", 0) + if self.dry_run: - logger.info(f"[DRY-RUN] Would fill demographics for Task3 video {video_id} " - f"segment {start_time}-{end_time}s") + logger.info( + "[DRY-RUN] Would fill Task3 video %s segment %s-%ss", + video_id, + start_time, + end_time, + ) return True - + video_segments = None audio_segments = None - try: - # Get metadata - metadata = self.metadata_by_topic.get(topic_id, {}).get(video_id, {}) - if not metadata: - logger.error(f"No metadata found for video {video_id}") + result = self._get_human_demographics(video_id, topic_id) + if result is None: return False - - human_demographics = metadata.get('demographics_detailed_reviewed', {}) - if not human_demographics: - logger.warning(f"No human-reviewed demographics for {video_id}") - return False - - # Get file paths - paths = self.get_file_paths(video_id, topic_id) - duration = metadata.get('duration_seconds', end_time) - - # Create video segment - video_segments = self.segmenter.segment_video( - paths['video_path'], duration, task_type='temporal_localization' + metadata, human_demographics = result + + prepared = self._prepare_segment_media( + video_id, topic_id, start_time, end_time, metadata ) - - # Find the matching segment - seg_path = None - for seg in video_segments: - if seg['start'] == start_time and seg['end'] == end_time: - seg_path = seg['segment_path'] - break - - if not seg_path or not seg_path.exists(): - logger.error(f"Could not create segment for {video_id} at 
{start_time}-{end_time}s") + if prepared is None: return False - - # Create audio segment if available - audio_seg_path = None - if paths['audio_path']: - audio_segments = self.segmenter.segment_audio( - paths['audio_path'], duration, task_type='temporal_localization' - ) - for seg in audio_segments: - if seg['start'] == start_time and seg['end'] == end_time: - audio_seg_path = seg['segment_path'] - break - - # Extract transcript for segment - transcript_text = "" - if paths['transcript_path']: - transcript_text = self.segmenter.extract_transcript_segment( - paths['transcript_path'], start_time, end_time, strip_timestamps=True - ) - - # Build prompt + media_files, transcript_text, video_segments, audio_segments = prepared + prompt = self.demographics_expander.build_expansion_prompt( - human_demographics, - segment_info={'start': start_time, 'end': end_time} + human_demographics, segment_info={"start": start_time, "end": end_time} ) - - # Prepare media files - media_files = [] - if seg_path.suffix.lower() in ['.mp4', '.avi', '.mov', '.webm', '.mkv', '.m4v']: - media_files.append(('video', seg_path)) - else: - media_files.append(('audio', seg_path)) - - if audio_seg_path: - media_files.append(('audio', audio_seg_path)) - - # Add transcript context if transcript_text: prompt += f"\n\nSEGMENT TRANSCRIPT:\n{transcript_text[:1000]}" - - # Generate demographics with retry + response_text = self._generate_with_retry( - media_files, prompt, video_id, - f"Task3 video {video_id} segment {start_time}-{end_time}s", - max_retries=3 + media_files, + prompt, + video_id, + f"Task3 video {video_id} segment {start_time}-{end_time}s", + max_retries=3, ) - if not response_text: return False - - demographics_data = self.demographics_expander.parse_demographics_response(response_text) - - # Update entry - entry['demographics'] = demographics_data.get('demographics', []) - entry['demographics_total_individuals'] = demographics_data.get('total_individuals', 0) - 
entry['demographics_confidence'] = demographics_data.get('confidence', 0.0) - entry['demographics_explanation'] = demographics_data.get('explanation', '') - - # Update cache so Task 2 can reuse this immediately - cache_data = { - 'demographics': entry['demographics'], - 'demographics_total_individuals': entry['demographics_total_individuals'], - 'demographics_confidence': entry['demographics_confidence'], - 'demographics_explanation': entry['demographics_explanation'] + + demographics_data = self.demographics_expander.parse_demographics_response( + response_text + ) + demo_fields = { + "demographics": demographics_data.get("demographics", []), + "demographics_total_individuals": demographics_data.get( + "total_individuals", 0 + ), + "demographics_confidence": demographics_data.get("confidence", 0.0), + "demographics_explanation": demographics_data.get("explanation", ""), } - self._update_task3_cache(topic_id, video_id, start_time, end_time, cache_data) - - logger.info(f"โœ“ Filled demographics for Task3 video {video_id} segment {start_time}-{end_time}s: " - f"{len(entry['demographics'])} entries, " - f"confidence={demographics_data.get('confidence', 0):.2f}") - + entry.update(demo_fields) + self._update_task3_cache( + topic_id, video_id, start_time, end_time, demo_fields + ) + + logger.info( + "Filled Task3 %s segment %s-%ss: %d entries conf=%.2f", + video_id, + start_time, + end_time, + len(entry["demographics"]), + entry["demographics_confidence"], + ) return True - + except Exception as e: - logger.error(f"Failed to fill Task3 demographics for {video_id} " - f"segment {start_time}-{end_time}s: {e}", exc_info=True) + logger.error( + "Failed Task3 %s segment %s-%ss: %s", + video_id, + start_time, + end_time, + e, + exc_info=True, + ) return False - + finally: - # Always cleanup segments - if video_segments: - try: - self.segmenter.cleanup_segments(video_segments) - except Exception as e: - logger.warning(f"Failed to cleanup video segments: {e}") - - if 
audio_segments: - try: - self.segmenter.cleanup_segments(audio_segments) - except Exception as e: - logger.warning(f"Failed to cleanup audio segments: {e}") - - def process_json_file(self, json_path: Path, task_name: str) -> Dict[str, int]: + self._cleanup_segments(video_segments, audio_segments) + + def _fill_stats(self, stats: Dict[str, int], success: bool) -> None: + """Increment filled or failed counter in *stats* based on *success*.""" + if success: + stats["filled"] += 1 + else: + stats["failed"] += 1 + + def _fill_empty_demographics( + self, + entries: List[Dict[str, Any]], + task_name: str, + topic_id: int, + stats: Dict[str, int], + ) -> None: + """Iterate *entries* once, filling any entry whose demographics are empty. + + Updates *stats* in-place (empty / filled / failed counters). """ - Process a single VQA JSON file and fill empty demographics. - - For Task 2: ALWAYS overwrites demographics from Task 3 when available (Task 3 is reviewed and correct). + for entry in entries: + if entry.get("demographics"): + continue + + stats["empty"] += 1 + + if not self.dry_run: + delay = self.config.rate_limit.delay_after_api_call + logger.info("Rate limiting: waiting %ss...", delay) + time.sleep(delay) + + if task_name == "task1": + success = self.fill_task1_demographics(entry, topic_id) + elif task_name == "task3": + success = self.fill_task3_demographics(entry, topic_id) + else: + logger.error("Unknown task: %s", task_name) + success = False + + self._fill_stats(stats, success) + + def process_task_file(self, json_path: Path, task_name: str) -> Dict[str, int]: + """Process a single VQA JSON file and fill empty demographics. + + Task 2: overwrites demographics from Task 3 when available. 
""" - logger.info(f"\n{'='*80}") + logger.info(f"\n{'=' * 80}") logger.info(f"Processing {json_path.name}") - logger.info(f"{'='*80}") - + logger.info(f"{'=' * 80}") + # Load JSON try: - with open(json_path, 'r') as f: + with open(json_path, "r") as f: data = json.load(f) except Exception as e: logger.error(f"Failed to load {json_path}: {e}") - return {'total': 0, 'empty': 0, 'filled': 0, 'failed': 0, 'reused': 0} - - topic_id = data.get('topic_id') - topic_name = data.get('topic_name') - entries = data.get('entries', []) - - stats = {'total': len(entries), 'empty': 0, 'filled': 0, 'failed': 0, 'reused': 0} - - # For Task 2: Check ALL entries for Task 3 matches (not just empty ones) - if task_name == 'task2': - logger.info(f"Task 2: Checking ALL {len(entries)} entries for Task 3 demographics to copy") - - # Pre-load Task 3 cache - self._load_task3_demographics(topic_id) - + return { + "total": 0, + "empty": 0, + "filled": 0, + "failed": 0, + "reused": 0, + } + + topic_id = data.get("topic_id") + entries = data.get("entries", []) + + stats = { + "total": len(entries), + "empty": 0, + "filled": 0, + "failed": 0, + "reused": 0, + } + + # For Task 2: check all entries for Task 3 matches + if task_name == "task2": + logger.info( + "Task 2: Checking %d entries for Task 3 demographics", + len(entries), + ) + updated_count = 0 for entry in entries: - video_id = entry.get('video_id') - segment = entry.get('segment', {}) - start_time = segment.get('start', 0) - end_time = segment.get('end', 0) - + video_id = entry.get("video_id") + segment = entry.get("segment", {}) + start_time = segment.get("start", 0) + end_time = segment.get("end", 0) + # Try to get Task 3 demographics - task3_demo = self._get_task3_demographics(video_id, start_time, end_time, topic_id) - - if task3_demo and task3_demo.get('demographics'): - # ALWAYS overwrite with Task 3 (it's reviewed and correct) - entry['demographics'] = task3_demo['demographics'] - entry['demographics_total_individuals'] = 
task3_demo['demographics_total_individuals'] - entry['demographics_confidence'] = task3_demo['demographics_confidence'] - entry['demographics_explanation'] = task3_demo['demographics_explanation'] - + task3_demo = self._get_task3_demographics( + video_id, start_time, end_time, topic_id + ) + + if task3_demo and task3_demo.get("demographics"): + # ALWAYS overwrite with Task 3 (it's reviewed and correct). + # deepcopy so mutations to `entry` don't affect the cache. + demo_copy = copy.deepcopy(task3_demo) + entry["demographics"] = demo_copy["demographics"] + entry["demographics_total_individuals"] = demo_copy[ + "demographics_total_individuals" + ] + entry["demographics_confidence"] = demo_copy[ + "demographics_confidence" + ] + entry["demographics_explanation"] = demo_copy[ + "demographics_explanation" + ] + updated_count += 1 - self.stats['task2']['reused'] += 1 - - logger.info(f"โœ“ Updated {updated_count} Task 2 entries with Task 3 demographics") - stats['filled'] = updated_count - stats['reused'] = updated_count - + self.stats["task2"]["reused"] += 1 + + logger.info( + "Updated %d Task 2 entries with Task 3 demographics", + updated_count, + ) + stats["filled"] = updated_count + stats["reused"] = updated_count + # Save updated JSON (if not dry-run) if not self.dry_run and updated_count > 0: - try: - # Create backup - backup_path = json_path.with_suffix('.json.backup') - with open(backup_path, 'w') as f: - json.dump(data, f, indent=2, ensure_ascii=False) - logger.info(f"Created backup: {backup_path.name}") - - # Save updated file - with open(json_path, 'w') as f: - json.dump(data, f, indent=2, ensure_ascii=False) - logger.info(f"โœ“ Saved updated file: {json_path.name}") - except Exception as e: - logger.error(f"Failed to save {json_path}: {e}") - + save_json_with_backup(data, json_path) + logger.info(f"โœ“ Saved updated file: {json_path.name}") + return stats - - # For Task 1 and Task 3: Original logic (only fill empty) - # Find empty demographics - empty_indices 
= [] - for i, entry in enumerate(entries): - demographics = entry.get('demographics', []) - if not demographics or demographics == []: - empty_indices.append(i) - stats['empty'] += 1 - - if not empty_indices: - logger.info(f"โœ“ No empty demographics found in {json_path.name}") + + # For Task 1 and Task 3: fill entries that have empty demographics + self._fill_empty_demographics(entries, task_name, topic_id, stats) + + if stats["empty"] == 0: + logger.info("โœ“ No empty demographics found in %s", json_path.name) return stats - - logger.info(f"Found {len(empty_indices)} entries with empty demographics") - - # Process each empty entry - for i in empty_indices: - entry = entries[i] - - # Rate limiting before each API call - if not self.dry_run: - delay = self.config.rate_limit.delay_after_api_call - logger.info(f"Rate limiting: waiting {delay}s...") - time.sleep(delay) - - # Fill demographics based on task - if task_name == 'task1': - success = self.fill_task1_demographics(entry, topic_id) - elif task_name == 'task3': - success = self.fill_task3_demographics(entry, topic_id) - else: - logger.error(f"Unknown task: {task_name}") - success = False - - if success: - stats['filled'] += 1 - else: - stats['failed'] += 1 - + + logger.info("Found %d entries with empty demographics", stats["empty"]) + # Save updated JSON (if not dry-run) - if not self.dry_run and stats['filled'] > 0: - try: - # Create backup - backup_path = json_path.with_suffix('.json.backup') - with open(backup_path, 'w') as f: - json.dump(data, f, indent=2, ensure_ascii=False) - logger.info(f"Created backup: {backup_path.name}") - - # Save updated file - with open(json_path, 'w') as f: - json.dump(data, f, indent=2, ensure_ascii=False) - logger.info(f"โœ“ Saved updated file: {json_path.name}") - except Exception as e: - logger.error(f"Failed to save {json_path}: {e}") - + if not self.dry_run and stats["filled"] > 0: + save_json_with_backup(data, json_path) + logger.info(f"โœ“ Saved updated file: 
{json_path.name}") + return stats - + def process_all_tasks(self, topic_filter: Optional[List[int]] = None): - """ - Process all VQA task directories. - OPTIMIZED: Processes Task 3 first, then Task 1, then Task 2 (which reuses from Task 3). - """ + """Process all VQA task dirs. Order: Task 3, Task 1, Task 2.""" output_dir = Path(self.config.paths.output_dir) - + if not output_dir.exists(): logger.error(f"Output directory not found: {output_dir}") return - + task_dirs = { - 'task1': output_dir / 'task1_summarization', - 'task2': output_dir / 'task2_mcq', - 'task3': output_dir / 'task3_temporal_localization' + "task1": output_dir / "task1_summarization", + "task2": output_dir / "task2_mcq", + "task3": output_dir / "task3_temporal_localization", } - - # IMPORTANT: Process Task 3 first, then Task 1, then Task 2 (which reuses from Task 3) - task_order = ['task3', 'task1', 'task2'] - + + # Order: Task 3, Task 1, Task 2 (Task 2 reuses Task 3) + task_order = ["task3", "task1", "task2"] + for task_name in task_order: task_dir = task_dirs[task_name] - + if not task_dir.exists(): logger.warning(f"Task directory not found: {task_dir}") continue - - logger.info(f"\n{'#'*80}") + + logger.info(f"\n{'#' * 80}") logger.info(f"# Processing {task_name.upper()}") - logger.info(f"{'#'*80}") - + logger.info(f"{'#' * 80}") + # Get all JSON files json_files = sorted(task_dir.glob("*.json")) - + # Filter by topic if specified if topic_filter: json_files = [ - f for f in json_files + f + for f in json_files if any(f.name.startswith(f"{tid:02d}_") for tid in topic_filter) ] - + logger.info(f"Found {len(json_files)} JSON files to process") - + for json_path in json_files: - stats = self.process_json_file(json_path, task_name) - + stats = self.process_task_file(json_path, task_name) + # Update global stats - self.stats[task_name]['total'] += stats['total'] - self.stats[task_name]['empty'] += stats['empty'] - self.stats[task_name]['filled'] += stats['filled'] - self.stats[task_name]['failed'] 
+= stats['failed'] + self.stats[task_name]["total"] += stats["total"] + self.stats[task_name]["empty"] += stats["empty"] + self.stats[task_name]["filled"] += stats["filled"] + self.stats[task_name]["failed"] += stats["failed"] # reused is already tracked in self.stats - + # Print final summary self.print_summary() - + def print_summary(self): - """Print final statistics summary""" - logger.info(f"\n{'='*80}") + """Print final statistics summary.""" + logger.info(f"\n{'=' * 80}") logger.info("FINAL SUMMARY") - logger.info(f"{'='*80}") - - for task_name in ['task1', 'task2', 'task3']: + logger.info(f"{'=' * 80}") + + for task_name in ["task1", "task2", "task3"]: stats = self.stats[task_name] logger.info(f"\n{task_name.upper()}:") logger.info(f" Total entries: {stats['total']}") @@ -740,85 +848,98 @@ def print_summary(self): logger.info(f" Successfully filled: {stats['filled']}") logger.info(f" Reused from Task3: {stats['reused']}") logger.info(f" Failed: {stats['failed']}") - - if stats['empty'] > 0: - success_rate = (stats['filled'] / stats['empty']) * 100 + + if stats["empty"] > 0: + success_rate = (stats["filled"] / stats["empty"]) * 100 logger.info(f" Success rate: {success_rate:.1f}%") - - if stats['reused'] > 0: - reuse_rate = (stats['reused'] / stats['empty']) * 100 + + if stats["reused"] > 0: + reuse_rate = (stats["reused"] / stats["empty"]) * 100 logger.info(f" Reuse rate: {reuse_rate:.1f}%") - - total_empty = sum(s['empty'] for s in self.stats.values()) - total_filled = sum(s['filled'] for s in self.stats.values()) - total_reused = sum(s['reused'] for s in self.stats.values()) - total_failed = sum(s['failed'] for s in self.stats.values()) - - logger.info(f"\nOVERALL:") + + total_empty = sum(s["empty"] for s in self.stats.values()) + total_filled = sum(s["filled"] for s in self.stats.values()) + total_reused = sum(s["reused"] for s in self.stats.values()) + total_failed = sum(s["failed"] for s in self.stats.values()) + + logger.info("\nOVERALL:") 
logger.info(f" Total empty: {total_empty}") logger.info(f" Successfully filled: {total_filled}") logger.info(f" Reused from Task3: {total_reused}") logger.info(f" Failed: {total_failed}") - + if total_empty > 0: success_rate = (total_filled / total_empty) * 100 logger.info(f" Success rate: {success_rate:.1f}%") - + if total_reused > 0: reuse_rate = (total_reused / total_empty) * 100 api_savings = (total_reused / total_empty) * 100 logger.info(f" Reuse rate: {reuse_rate:.1f}%") logger.info(f" API calls saved: {api_savings:.1f}%") + def main(): - """Main entry point""" - parser = argparse.ArgumentParser(description='Fill Empty Demographics in VQA Files') - parser.add_argument('--config', type=str, default='config/vqa_config.yaml', - help='Path to configuration file') - parser.add_argument('--topics', type=str, default=None, - help='Comma-separated topic IDs to process (e.g., "10,11")') - parser.add_argument('--dry-run', action='store_true', - help='Show what would be done without making changes') - + """Run main entry point.""" + parser = argparse.ArgumentParser(description="Fill Empty Demographics in VQA Files") + parser.add_argument( + "--config", + type=str, + default="vqa_config.yaml", + help="Path to configuration file", + ) + parser.add_argument( + "--topics", + type=str, + default=None, + help='Comma-separated topic IDs to process (e.g., "10,11")', + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Show what would be done without making changes", + ) + args = parser.parse_args() - + # Load config config_path = Path(args.config) if not config_path.exists(): print(f"Error: Config file not found: {config_path}") sys.exit(1) - - config = load_config(str(config_path)) - + + config = load_config(str(config_path), base_dir=Path(__file__).parent) + # Parse topic filter topic_filter = None if args.topics: try: - topic_filter = [int(t.strip()) for t in args.topics.split(',')] + topic_filter = [int(t.strip()) for t in args.topics.split(",")] 
logger.info(f"Processing topics: {topic_filter}") except ValueError: logger.error(f"Invalid topics format: {args.topics}") sys.exit(1) - + # Check API key - if not args.dry_run and not os.getenv('GEMINI_API_KEY'): + if not args.dry_run and not os.getenv("GEMINI_API_KEY"): logger.error("GEMINI_API_KEY not found in environment!") sys.exit(1) - + # Create filler and run filler = DemographicsFiller(config, dry_run=args.dry_run) - + if args.dry_run: logger.info("=" * 80) logger.info("DRY RUN MODE - No changes will be made") logger.info("=" * 80) - - logger.info("\nOPTIMIZATION: Task 3 will be processed first, then Task 1, then Task 2 reuses demographics from Task 3") + + logger.info("\nOPTIMIZATION: Task 3 first, then Task 1, then Task 2 reuses Task 3") logger.info("This will significantly reduce API calls!\n") - + filler.process_all_tasks(topic_filter) - + logger.info("\nDone!") -if __name__ == '__main__': - main() \ No newline at end of file + +if __name__ == "__main__": + main() diff --git a/sonic-o1/04_vqa_generation/main.py b/sonic-o1/04_vqa_generation/main.py index 2601d10..550e650 100644 --- a/sonic-o1/04_vqa_generation/main.py +++ b/sonic-o1/04_vqa_generation/main.py @@ -1,578 +1,596 @@ -""" -Main VQA Generation Script +"""main.py. + +Main VQA generation script. Generates summarization, MCQ, and temporal +localization tasks using Gemini-based multimodal models. 
Usage: python main.py --topics 1,2,3 python main.py --all python main.py --topics 1 --task summarization + +Author: SONIC-O1 Team """ + import argparse import json import logging -import yaml -import os import time -from pathlib import Path from datetime import datetime -from typing import Dict, List, Any +from pathlib import Path +from typing import Any, Dict, List + from tqdm import tqdm +from utils.config_utils import Config, load_config + # Load environment variables from .env file if it exists try: from dotenv import load_dotenv - env_path = Path(__file__).parent / '.env' + + env_path = Path(__file__).parent / ".env" if env_path.exists(): load_dotenv(env_path) logging.info("Loaded environment variables from .env file") except ImportError: - # python-dotenv not installed, skip pass -from models import SummarizationModel, MCQModel, TemporalLocalizationModel +from models import MCQModel, SummarizationModel, TemporalLocalizationModel + # Setup logging logger = logging.getLogger(__name__) - -class Config: - """Configuration wrapper""" - def __init__(self, config_dict): - for key, value in config_dict.items(): - if isinstance(value, dict): - setattr(self, key, Config(value)) - else: - setattr(self, key, value) - - -def load_config(config_path: str) -> Config: - """Load configuration from YAML file""" - with open(config_path, 'r') as f: - config_dict = yaml.safe_load(f) - return Config(config_dict) +# Registry of all tasks: (cli_key, model_class, output_dir_name, multi_entry) +# multi_entry=True -- model.process_video() returns List[Dict] (one per segment) +# multi_entry=False -- model.process_video() returns Dict (one per video) +TASKS = [ + ("summarization", SummarizationModel, "task1_summarization", False), + ("mcq", MCQModel, "task2_mcq", True), + ("temporal", TemporalLocalizationModel, "task3_temporal_localization", True), +] def load_metadata_for_topic(topic_dir: Path) -> List[Dict[str, Any]]: """ Load metadata_enhanced.json for a topic directory. 
- + Args: - topic_dir: Path to topic directory (e.g., dataset/videos/01_Patient-Doctor_Consultations/) - - Returns: + topic_dir: Path to topic directory (e.g. dataset/videos/01_.../) + + Returns + ------- List of video metadata dicts """ - metadata_file = topic_dir / 'metadata_enhanced.json' - + metadata_file = topic_dir / "metadata_enhanced.json" + if not metadata_file.exists(): - logger.warning(f"No metadata_enhanced.json found in {topic_dir}") + logger.warning("No metadata_enhanced.json found in %s", topic_dir) return [] - + try: - with open(metadata_file, 'r', encoding='utf-8') as f: + with open(metadata_file, "r", encoding="utf-8") as f: metadata_list = json.load(f) - - logger.info(f"Loaded {len(metadata_list)} videos from {topic_dir.name}") + + logger.info("Loaded %d videos from %s", len(metadata_list), topic_dir.name) return metadata_list - + except Exception as e: - logger.error(f"Failed to load metadata from {topic_dir}: {e}") + logger.error("Failed to load metadata from %s: %s", topic_dir, e) return [] def get_file_paths(video_meta: Dict[str, Any], topic_dir: Path) -> Dict[str, Path]: """ Get file paths for video, audio, and transcript. 
- + Args: video_meta: Video metadata dict topic_dir: Topic directory path - - Returns: + + Returns + ------- Dict with keys: video_path, audio_path, transcript_path """ - video_number = video_meta.get('video_number', video_meta.get('video_id', '001')) - + video_number = video_meta.get("video_number", video_meta.get("video_id", "001")) + # Video path video_filename = f"video_{video_number}.mp4" video_path = topic_dir / video_filename - + # Audio path (in parent audios directory) - audio_dir = topic_dir.parent.parent / 'audios' / topic_dir.name + audio_dir = topic_dir.parent.parent / "audios" / topic_dir.name audio_filename = f"audio_{video_number}.m4a" audio_path = audio_dir / audio_filename - + # Transcript path (in parent captions directory) - captions_dir = topic_dir.parent.parent / 'captions' / topic_dir.name + captions_dir = topic_dir.parent.parent / "captions" / topic_dir.name transcript_filename = f"caption_{video_number}.srt" transcript_path = captions_dir / transcript_filename - + return { - 'video_path': video_path if video_path.exists() else None, - 'audio_path': audio_path if audio_path.exists() else None, - 'transcript_path': transcript_path if transcript_path.exists() else None + "video_path": video_path if video_path.exists() else None, + "audio_path": audio_path if audio_path.exists() else None, + "transcript_path": (transcript_path if transcript_path.exists() else None), } -def process_topic(topic_id: int, - topic_name: str, - topic_dir: Path, - config: Config, - output_dir: Path, - task_filter: str = None) -> tuple: - """ - Process all videos in a topic for VQA generation. 
- +def get_confidence(entry: Dict[str, Any]) -> float: + """Return the confidence score of a VQA entry.""" + return float(entry.get("confidence", 0)) + + +def get_summary_detailed(entry: Dict[str, Any]) -> str: + """Return the detailed summary text of a Task 1 entry.""" + return entry.get("summary_detailed", "") + + +_SUMMARY_FAIL_PATTERNS = ( + "unavailable", + "summary generation failed", + "could not be generated", + "summary failed", +) + + +def get_summary_failed(entry: Dict[str, Any]) -> bool: + """Return True if a Task 1 entry contains known failure markers.""" + for item in entry.get("summary_short", []): + if isinstance(item, str): + lower = item.lower() + if any(p in lower for p in _SUMMARY_FAIL_PATTERNS): + return True + if "first segment" in lower and "failed" in lower: + return True + + lower_detailed = get_summary_detailed(entry).lower() + return ( + "could not be generated" in lower_detailed + or "summary generation failed" in lower_detailed + or "parsing error" in lower_detailed + or "explicitly reported a failure" in lower_detailed + or ("failed to" in lower_detailed and "summary" in lower_detailed) + ) + + +def skip_task(task_name: str, existing: Dict, video_id: str) -> bool: + """Return True if this video already has valid, complete entries for the task.""" + if video_id not in existing: + return False + + if task_name == "task1_summarization": + entry = existing[video_id] + if get_summary_failed(entry): + logger.info("Reprocessing Task 1 for %s (previous failure)", video_id) + return False + if get_confidence(entry) == 0: + logger.info("Reprocessing Task 1 for %s (confidence was 0)", video_id) + return False + logger.info("Skipping Task 1 for %s (already processed)", video_id) + return True + + # Task 2 / Task 3 โ€” list of segment entries per video + segment_entries = existing[video_id] + if segment_entries and all(get_confidence(e) > 0 for e in segment_entries): + logger.info("Skipping %s for %s (already processed)", task_name, video_id) + 
return True + return False + + +def _apply_rate_limit(video_category: str, config: Config) -> None: + """Sleep between videos according to config rate limits.""" + delay = int(getattr(config.rate_limit, "delay_between_videos", 15)) + if video_category == "long": + delay += int(getattr(config.rate_limit, "delay_after_long_video", 60)) + logger.info("Long video โ€” waiting %ss before next", delay) + else: + logger.info("Waiting %ss before next video (rate limit)", delay) + time.sleep(delay) + + +def process_topic( + topic_id: int, + topic_name: str, + topic_dir: Path, + task_name: str, + model, + existing: Dict, + config: Config, + multi_entry: bool = False, +) -> List[Dict[str, Any]]: + """Process all videos in a topic for a single task. + Args: - topic_id: Topic ID (1-13) - topic_name: Topic name (e.g., "Patient-Doctor Consultations") - topic_dir: Path to topic directory - config: Configuration object - task_filter: Optional filter - "summarization" or "mcq" (None = both) - - Returns: - Tuple of (task1_entries, task2_entries) + topic_id: Topic ID (1-13). + topic_name: Human-readable topic name. + topic_dir: Path to topic video directory. + task_name: Output subdirectory name, e.g. "task1_summarization". + model: Instantiated task model. + existing: Pre-loaded existing entries indexed by video_id. + config: Configuration object. + multi_entry: True when model.process_video() returns List[Dict] + (MCQ/temporal); False when it returns a single Dict (summarization). + + Returns + ------- + List of VQA entry dicts for this task across all videos. 
""" - logger.info(f"Processing Topic {topic_id}: {topic_name}") - - # Load metadata metadata_list = load_metadata_for_topic(topic_dir) if not metadata_list: - logger.warning(f"No videos found for topic {topic_id}") - return ([], [], []) - - # Initialize models - task1_entries = [] - task2_entries = [] - task3_entries = [] - # Load existing results if they exist - existing_task1 = {} - existing_task2 = {} - existing_task3 = {} - if task_filter is None or task_filter == "summarization": - summarizer = SummarizationModel(config) - # Check for existing Task 1 output - task1_output_file = output_dir / 'task1_summarization' / f"{topic_id:02d}_{topic_name.replace(' ', '_')}.json" - if task1_output_file.exists(): - with open(task1_output_file, 'r') as f: - task1_data = json.load(f) - # Index by video_id - for entry in task1_data.get('entries', []): - existing_task1[entry['video_id']] = entry - logger.info(f"Loaded {len(existing_task1)} existing Task 1 entries") - - if task_filter is None or task_filter == "mcq": - mcq_generator = MCQModel(config) - # Check for existing Task 2 output - task2_output_file = output_dir / 'task2_mcq' / f"{topic_id:02d}_{topic_name.replace(' ', '_')}.json" - if task2_output_file.exists(): - with open(task2_output_file, 'r') as f: - task2_data = json.load(f) - # Index by video_id - for entry in task2_data.get('entries', []): - vid = entry['video_id'] - if vid not in existing_task2: - existing_task2[vid] = [] - existing_task2[vid].append(entry) - logger.info(f"Loaded {sum(len(v) for v in existing_task2.values())} existing Task 2 entries") - - if task_filter is None or task_filter == "temporal": - temporal_generator = TemporalLocalizationModel(config) - # Check for existing Task 3 output - task3_output_file = output_dir / 'task3_temporal_localization' / f"{topic_id:02d}_{topic_name.replace(' ', '_')}.json" - if task3_output_file.exists(): - with open(task3_output_file, 'r') as f: - task3_data = json.load(f) - for entry in task3_data.get('entries', 
[]): - vid = entry['video_id'] - if vid not in existing_task3: - existing_task3[vid] = [] - existing_task3[vid].append(entry) - logger.info(f"Loaded {sum(len(v) for v in existing_task3.values())} existing Task 3 entries") - - # Process each video - for video_meta in tqdm(metadata_list, desc=f"Processing {topic_name}"): - video_id = video_meta.get('video_id', video_meta.get('video_number', 'unknown')) - - video_category = video_meta.get('duration_category', 'short') - duration = video_meta.get('duration_seconds', 0) - if video_category not in ['short', 'medium', 'long']: - if duration <= 300: # <= 5 minutes - video_category = 'short' - elif duration <= 1800: # <= 30 minutes - video_category = 'medium' - else: # > 30 minutes - video_category = 'long' - # Check if already successfully processed - skip_task1 = False - skip_task2 = False - skip_task3 = False - - if video_id in existing_task1: - entry = existing_task1[video_id] - - # Check for failure indicators in summary - has_summary_failure = False - - # Check summary_short for failures (more specific patterns) - summary_short = entry.get('summary_short', []) - if isinstance(summary_short, list): - for item in summary_short: - if isinstance(item, str): - lower_item = item.lower() - # Look for specific failure patterns, not just the word "failure" - if ('unavailable' in lower_item or - 'summary generation failed' in lower_item or - 'could not be generated' in lower_item or - 'summary failed' in lower_item or - 'first segment' in lower_item and 'failed' in lower_item): - has_summary_failure = True - break - - # Check summary_detailed for failures (more specific patterns) - summary_detailed = entry.get('summary_detailed', '') - if isinstance(summary_detailed, str): - lower_detailed = summary_detailed.lower() - if ('could not be generated' in lower_detailed or - 'summary generation failed' in lower_detailed or - 'parsing error' in lower_detailed or - 'failed to' in lower_detailed and 'summary' in lower_detailed or - 
'explicitly reported a failure' in lower_detailed): - has_summary_failure = True - - # Only skip if no failures detected AND confidence > 0 - if not has_summary_failure and entry.get('confidence', 0) > 0: - skip_task1 = True - task1_entries.append(entry) - logger.info(f"Skipping Task 1 for {video_id} (already processed successfully)") - else: - if has_summary_failure: - logger.info(f"Reprocessing Task 1 for {video_id} (detected failure in previous attempt)") - else: - logger.info(f"Reprocessing Task 1 for {video_id} (confidence was 0)") - if video_id in existing_task2: - # Check if all MCQ entries have good confidence - all_good = all(e.get('confidence', 0) > 0 for e in existing_task2[video_id]) - if all_good and len(existing_task2[video_id]) > 0: - skip_task2 = True - task2_entries.extend(existing_task2[video_id]) - logger.info(f"Skipping Task 2 for {video_id} (already processed successfully)") - - if video_id in existing_task3: - # Check if all temporal entries have good confidence (same pattern as Task 2) - all_good = all(e.get('confidence', 0) > 0 for e in existing_task3[video_id]) - if all_good and len(existing_task3[video_id]) > 0: - skip_task3 = True - task3_entries.extend(existing_task3[video_id]) - logger.info(f"Skipping Task 3 for {video_id} (already processed successfully)") - - # If both tasks should be skipped, continue - if task_filter == "summarization" and skip_task1: - continue - elif task_filter == "mcq" and skip_task2: - continue - elif task_filter == "temporal" and skip_task3: - continue - elif task_filter is None and skip_task1 and skip_task2 and skip_task3: + logger.warning("No videos found for topic %s", topic_id) + return [] + + entries: List[Dict[str, Any]] = [] + + for video_meta in tqdm(metadata_list, desc=f"Topic {topic_id} / {task_name}"): + video_id = video_meta.get("video_id", video_meta.get("video_number", "unknown")) + duration = video_meta.get("duration_seconds", 0) + video_category = video_meta.get("duration_category", "") + if 
video_category not in ("short", "medium", "long"): + video_category = ( + "short" if duration <= 300 else "medium" if duration <= 1800 else "long" + ) + + if skip_task(task_name, existing, video_id): + cached = existing[video_id] + entries.extend([cached] if not multi_entry else cached) continue - try: - # Enhance metadata with topic info - video_meta['topic_id'] = topic_id - video_meta['topic_name'] = topic_name - - # Get file paths + video_meta["topic_id"] = topic_id + video_meta["topic_name"] = topic_name file_paths = get_file_paths(video_meta, topic_dir) - - if not file_paths['video_path'] and not file_paths['audio_path']: - logger.warning(f"No video or audio found for {video_id}, skipping") + + if not file_paths["video_path"] and not file_paths["audio_path"]: + logger.warning("No video or audio for %s, skipping", video_id) continue - - # Task 1: Summarization - if (task_filter is None or task_filter == "summarization") and not skip_task1: - logger.info(f"Generating summarization for {video_id}") - summary_entry = summarizer.process_video( - video_path=file_paths['video_path'], - audio_path=file_paths['audio_path'], - transcript_path=file_paths['transcript_path'], - metadata=video_meta - ) - task1_entries.append(summary_entry) - - # Task 2: MCQ - if (task_filter is None or task_filter == "mcq") and not skip_task2: - logger.info(f"Generating MCQs for {video_id}") - - new_mcq_entries = mcq_generator.process_video( - video_path=file_paths['video_path'], - audio_path=file_paths['audio_path'], - transcript_path=file_paths['transcript_path'], - metadata=video_meta - ) - - # If we have existing entries, merge intelligently - if video_id in existing_task2 and len(existing_task2[video_id]) > 0: - merged_entries, kept, replaced = merge_entries_keep_good( - existing_task2[video_id], - new_mcq_entries + + logger.info("Generating %s for %s", task_name, video_id) + new = model.process_video( + video_path=file_paths["video_path"], + audio_path=file_paths["audio_path"], + 
transcript_path=file_paths["transcript_path"], + metadata=video_meta, + ) + + if multi_entry: + new_list = new if isinstance(new, list) else [new] + if video_id in existing and existing[video_id]: + merged, kept, replaced = merge_entries_keep_good( + existing[video_id], new_list ) - task2_entries.extend(merged_entries) - logger.info(f"Task 2 for {video_id}: kept {kept} good entries, replaced {replaced} failed entries") - else: - # No existing entries, use all new ones - task2_entries.extend(new_mcq_entries) - - # Task 3: Temporal Localization - if (task_filter is None or task_filter == "temporal") and not skip_task3: - logger.info(f"Generating temporal questions for {video_id}") - - new_temporal_entries = temporal_generator.process_video( - video_path=file_paths['video_path'], - audio_path=file_paths['audio_path'], - transcript_path=file_paths['transcript_path'], - metadata=video_meta - ) - - # If we have existing entries, merge intelligently - if video_id in existing_task3 and len(existing_task3[video_id]) > 0: - merged_entries, kept, replaced = merge_entries_keep_good( - existing_task3[video_id], - new_temporal_entries + entries.extend(merged) + logger.info( + "%s for %s: kept %d good, replaced %d failed", + task_name, + video_id, + kept, + replaced, ) - task3_entries.extend(merged_entries) - logger.info(f"Task 3 for {video_id}: kept {kept} good entries, replaced {replaced} failed entries") else: - # No existing entries, use all new ones - task3_entries.extend(new_temporal_entries) - - # Rate limiting: Add delay after processing each video - if not (skip_task1 and skip_task2 and skip_task3): - delay_between_videos = int(getattr(config.rate_limit, 'delay_between_videos', 15)) - - # Extra delay for long videos (use metadata category instead of threshold) - if video_category == 'long': - extra_delay = int(getattr(config.rate_limit, 'delay_after_long_video', 60)) - total_delay = delay_between_videos + extra_delay - logger.info(f"Long video ({video_category}) - 
waiting {total_delay}s before next video") - time.sleep(total_delay) - else: - logger.info(f"Waiting {delay_between_videos}s before next video (rate limiting)") - time.sleep(delay_between_videos) + entries.extend(new_list) else: - logger.info(f"Skipped all tasks for {video_id} - no rate limiting delay needed") + entries.append(new) + + _apply_rate_limit(video_category, config) except Exception as e: - logger.error(f"Failed to process video {video_id}: {e}", exc_info=True) - continue - - logger.info(f"Completed Topic {topic_id}: {len(task1_entries)} summaries, {len(task2_entries)} MCQs") - return (task1_entries, task2_entries,task3_entries) + logger.error("Failed to process %s: %s", video_id, e, exc_info=True) + + logger.info( + "Completed %s for Topic %s: %d entries", task_name, topic_id, len(entries) + ) + return entries + + +def get_topic_output_path( + output_dir: Path, task_name: str, topic_id: int, topic_name: str +) -> Path: + """ + Build the output JSON path for a given task and topic. + + Args: + output_dir: Root output directory (e.g., vqa/) + task_name: e.g. "task1_summarization" + topic_id: Topic ID + topic_name: Topic name (spaces allowed; converted to underscores) + + Returns + ------- + Full Path to the output JSON file. + """ + filename = f"{topic_id:02d}_{topic_name.replace(' ', '_')}.json" + return output_dir / task_name / filename + + +def init_task( + task_filter_key: str, + model_class, + task_name: str, + config: "Config", + output_dir: Path, + topic_id: int, + topic_name: str, + task_filter: str, + list_per_video: bool = False, + dry_run: bool = False, +) -> tuple: + """ + Initialize a task model and load any pre-existing output entries. + + Args: + task_filter_key: Filter string for this task, e.g. "summarization". + model_class: Model class to instantiate, e.g. SummarizationModel. + task_name: Output subdirectory name, e.g. "task1_summarization". + config: Configuration object. + output_dir: Root output directory. + topic_id: Topic ID. 
+ topic_name: Topic name. + task_filter: Active CLI filter (None = run all tasks). + list_per_video: If True, existing entries are indexed as + dict[video_id -> list]; if False, dict[video_id -> entry]. + + Returns + ------- + Tuple of (model_instance_or_None, existing_entries_dict). + model is None when this task is excluded by task_filter. + """ + if task_filter is not None and task_filter != task_filter_key: + return None, {} + + model = model_class(config, dry_run=dry_run) + existing = {} + + output_file = get_topic_output_path(output_dir, task_name, topic_id, topic_name) + if output_file.exists(): + with open(output_file, "r") as f: + data = json.load(f) + for entry in data.get("entries", []): + vid = entry["video_id"] + if list_per_video: + existing.setdefault(vid, []).append(entry) + else: + existing[vid] = entry + n = sum(len(v) for v in existing.values()) if list_per_video else len(existing) + logger.info("Loaded %d existing %s entries", n, task_name) + return model, existing -def save_task_results(task_name: str, - topic_id: int, - topic_name: str, - entries: List[Dict[str, Any]], - output_dir: Path): + +def save_task_results( + task_name: str, + topic_id: int, + topic_name: str, + entries: List[Dict[str, Any]], + output_dir: Path, +): """ Save VQA entries to JSON file. - + Args: - task_name: "task1_summarization" or "task2_mcq" + task_name: e.g. 
"task1_summarization" or "task2_mcq" topic_id: Topic ID topic_name: Topic name entries: List of VQA entry dicts output_dir: Output directory (e.g., vqa/) """ if not entries: - logger.warning(f"No entries to save for {task_name} - {topic_name}") + logger.warning("No entries to save for %s - %s", task_name, topic_name) return - - # Create output directory - task_dir = output_dir / task_name - task_dir.mkdir(parents=True, exist_ok=True) - - # Build output JSON + + output_file = get_topic_output_path(output_dir, task_name, topic_id, topic_name) + output_file.parent.mkdir(parents=True, exist_ok=True) + output_data = { - 'topic_id': topic_id, - 'topic_name': topic_name, - 'task': task_name.replace('task1_', '').replace('task2_', ''), - 'generated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), - 'num_entries': len(entries), - 'entries': entries + "topic_id": topic_id, + "topic_name": topic_name, + "task": task_name.split("_", 1)[1] if "_" in task_name else task_name, + "generated_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "num_entries": len(entries), + "entries": entries, } - - # Save to file - output_file = task_dir / f"{topic_id:02d}_{topic_name.replace(' ', '_')}.json" - with open(output_file, 'w', encoding='utf-8') as f: + + with open(output_file, "w", encoding="utf-8") as f: json.dump(output_data, f, indent=2, ensure_ascii=False) - - logger.info(f"Saved {len(entries)} entries to {output_file}") + + logger.info("Saved %d entries to %s", len(entries), output_file) def get_all_topic_dirs(dataset_root: Path) -> List[tuple]: """ Get all topic directories. 
- - Returns: + + Returns + ------- List of tuples: (topic_id, topic_name, topic_dir_path) """ - videos_dir = dataset_root / 'videos' + videos_dir = dataset_root / "videos" if not videos_dir.exists(): logger.error(f"Videos directory not found: {videos_dir}") return [] - + topics = [] for topic_dir in sorted(videos_dir.iterdir()): if topic_dir.is_dir() and topic_dir.name[0].isdigit(): - # Extract topic ID and name from directory name (e.g., "01_Patient-Doctor_Consultations") - parts = topic_dir.name.split('_', 1) + # Extract topic ID and name from dir (e.g. 01_Patient-Doctor_...) + parts = topic_dir.name.split("_", 1) if len(parts) == 2: topic_id = int(parts[0]) - topic_name = parts[1].replace('_', ' ') + topic_name = parts[1].replace("_", " ") topics.append((topic_id, topic_name, topic_dir)) - + return topics -def merge_entries_keep_good(existing_entries: List[Dict], new_entries: List[Dict]) -> List[Dict]: + +def merge_entries_keep_good( + existing_entries: List[Dict], new_entries: List[Dict] +) -> List[Dict]: """ - Merge existing and new entries by: + Merge existing and new entries. + - Keeping existing entries with confidence > 0 - Replacing existing entries with confidence 0.0 with matching new entries - - Adding new entries for segments that don't exist in existing - + - Adding new entries for segments that don't exist in existing. 
+ Args: existing_entries: List of existing entries for a video_id new_entries: List of newly generated entries - - Returns: + + Returns + ------- Merged list with good existing + new replacements for failed """ merged = [] - + # Keep all existing entries with confidence > 0 for existing in existing_entries: - if existing.get('confidence', 0) > 0: + if existing.get("confidence", 0) > 0: merged.append(existing) - - # For failed entries (confidence 0.0), replace with new entries if segment matches + + # Replace failed entries (confidence 0) with new if segment matches failed_segments = { - (e.get('segment', {}).get('start'), e.get('segment', {}).get('end')): e - for e in existing_entries - if e.get('confidence', 0) == 0.0 + (e.get("segment", {}).get("start"), e.get("segment", {}).get("end")): e + for e in existing_entries + if e.get("confidence", 0) == 0.0 } - + new_segments_used = set() - + # Replace failed segments with new entries for new in new_entries: - new_seg = (new.get('segment', {}).get('start'), new.get('segment', {}).get('end')) - + new_seg = ( + new.get("segment", {}).get("start"), + new.get("segment", {}).get("end"), + ) + if new_seg in failed_segments: - # This new entry replaces a failed one - merged.append(new) - new_segments_used.add(new_seg) - elif new_seg not in [(e.get('segment', {}).get('start'), e.get('segment', {}).get('end')) - for e in existing_entries]: - # This is a completely new segment (shouldn't happen usually) merged.append(new) new_segments_used.add(new_seg) - + else: + exist_segs = [ + (e.get("segment", {}).get("start"), e.get("segment", {}).get("end")) + for e in existing_entries + ] + if new_seg not in exist_segs: + merged.append(new) + new_segments_used.add(new_seg) + # Log what happened replaced = len(set(failed_segments.keys()).intersection(new_segments_used)) - kept_good = len([e for e in existing_entries if e.get('confidence', 0) > 0]) - + kept_good = sum(1 for e in existing_entries if e.get("confidence", 0) > 0) + return 
merged, kept_good, replaced + def main(): - """Main entry point""" - parser = argparse.ArgumentParser(description='VQA Generation System') - parser.add_argument('--config', type=str, default='04_vqa_generation/config/vqa_config.yaml', - help='Path to configuration file') - parser.add_argument('--topics', type=str, default=None, - help='Comma-separated topic IDs (e.g., "1,2,3")') - parser.add_argument('--all', action='store_true', - help='Process all topics') - parser.add_argument('--task', type=str, choices=['summarization', 'mcq','temporal'], default=None, - help='Process only specific task (default: both)') - parser.add_argument('--output', type=str, default=None, - help='Output directory (overrides config)') - + """Run main entry point.""" + parser = argparse.ArgumentParser(description="VQA Generation System") + parser.add_argument("--config", type=str, default="vqa_config.yaml") + parser.add_argument( + "--topics", + type=str, + default=None, + help='Comma-separated topic IDs (e.g., "1,2,3")', + ) + parser.add_argument("--all", action="store_true", help="Process all topics") + parser.add_argument( + "--task", + type=str, + default=None, + choices=["summarization", "mcq", "temporal"], + help="Process only a specific task (default: all)", + ) + parser.add_argument( + "--output", type=str, default=None, help="Output directory (overrides config)" + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Run without API calls; generates stub outputs", + ) args = parser.parse_args() - - # Load config - logger.info(f"Loading configuration from {args.config}") - config = load_config(args.config) - - # Set output directory - if args.output: - output_dir = Path(args.output) - else: - output_dir = Path(config.paths.output_dir) - + + config = load_config(args.config, base_dir=Path(__file__).parent) + + if args.dry_run: + logger.info("=" * 60) + logger.info("[DRY-RUN] No API calls will be made; outputs are stubs") + logger.info("=" * 60) + + output_dir = 
Path(args.output) if args.output else Path(config.paths.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"Output directory: {output_dir}") - - # Get dataset root + dataset_root = Path(config.paths.dataset_root) if not dataset_root.exists(): - logger.error(f"Dataset root not found: {dataset_root}") + logger.error("Dataset root not found: %s", dataset_root) return - - # Get all topics + all_topics = get_all_topic_dirs(dataset_root) - logger.info(f"Found {len(all_topics)} topics in dataset") - - # Filter topics if specified + logger.info("Found %d topics in dataset", len(all_topics)) + if args.topics: - topic_ids = [int(tid.strip()) for tid in args.topics.split(',')] + topic_ids = {int(t.strip()) for t in args.topics.split(",")} topics_to_process = [t for t in all_topics if t[0] in topic_ids] elif args.all: topics_to_process = all_topics else: logger.error("Must specify either --topics or --all") return - - logger.info(f"Processing {len(topics_to_process)} topics") - - # Process each topic - total_task1 = 0 - total_task2 = 0 - total_task3 = 0 + + logger.info("Processing %d topics", len(topics_to_process)) + + totals = {task_name: 0 for _, _, task_name, _ in TASKS} + for topic_id, topic_name, topic_dir in topics_to_process: - try: - task1_entries, task2_entries,task3_entries = process_topic( - topic_id, topic_name, topic_dir, config,output_dir ,task_filter=args.task - ) - - # Save results - if args.task is None or args.task == "summarization": - save_task_results('task1_summarization', topic_id, topic_name, task1_entries, output_dir) - total_task1 += len(task1_entries) - - if args.task is None or args.task == "mcq": - save_task_results('task2_mcq', topic_id, topic_name, task2_entries, output_dir) - total_task2 += len(task2_entries) - - if args.task is None or args.task == "temporal": - save_task_results('task3_temporal_localization', topic_id, topic_name, task3_entries, output_dir) - total_task3 += len(task3_entries) - + for task_key, 
model_class, task_name, multi_entry in TASKS: + if args.task and args.task != task_key: + continue + try: + model, existing = init_task( + task_key, + model_class, + task_name, + config, + output_dir, + topic_id, + topic_name, + task_filter=None, + list_per_video=multi_entry, + dry_run=args.dry_run, + ) + entries = process_topic( + topic_id, + topic_name, + topic_dir, + task_name, + model, + existing, + config, + multi_entry, + ) + if not args.dry_run: + save_task_results( + task_name, topic_id, topic_name, entries, output_dir + ) + else: + logger.info( + "[DRY-RUN] Would save %d entries for %s", + len(entries), + task_name, + ) + totals[task_name] += len(entries) + except Exception as e: + logger.error( + "Failed %s for topic %s: %s", task_name, topic_id, e, exc_info=True + ) - except Exception as e: - logger.error(f"Failed to process topic {topic_id}: {e}", exc_info=True) - continue - - # Final statistics logger.info("=" * 60) logger.info("VQA Generation Complete!") - logger.info(f"Topics processed: {len(topics_to_process)}") - logger.info(f"Task 1 (Summarization): {total_task1} entries") - logger.info(f"Task 2 (MCQ): {total_task2} entries") - logger.info(f"Task 3 (Temporal): {total_task3} entries") - logger.info(f"Output directory: {output_dir}") + logger.info("Topics processed: %d", len(topics_to_process)) + for _, _, task_name, _ in TASKS: + logger.info(" %s: %d entries", task_name, totals[task_name]) + logger.info("Output: %s", output_dir) logger.info("=" * 60) -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/sonic-o1/04_vqa_generation/models/__init__.py b/sonic-o1/04_vqa_generation/models/__init__.py index 427ec54..9c4049c 100644 --- a/sonic-o1/04_vqa_generation/models/__init__.py +++ b/sonic-o1/04_vqa_generation/models/__init__.py @@ -1,7 +1,19 @@ -"""VQA Generation Models""" +"""__init__.py. + +VQA Generation Models. 
+ +Author: SONIC-O1 Team +""" + from .base_gemini import BaseGeminiClient -from .summarization_model import SummarizationModel from .mcq_model import MCQModel +from .summarization_model import SummarizationModel from .temporal_localization_model import TemporalLocalizationModel -__all__ = ['BaseGeminiClient', 'SummarizationModel', 'MCQModel','TemporalLocalizationModel'] + +__all__ = [ + "BaseGeminiClient", + "SummarizationModel", + "MCQModel", + "TemporalLocalizationModel", +] diff --git a/sonic-o1/04_vqa_generation/models/base_gemini.py b/sonic-o1/04_vqa_generation/models/base_gemini.py index 60f4505..842305b 100644 --- a/sonic-o1/04_vqa_generation/models/base_gemini.py +++ b/sonic-o1/04_vqa_generation/models/base_gemini.py @@ -1,130 +1,179 @@ +"""base_gemini.py. + +Base Gemini API client - reusable logic for all VQA tasks. + +Author: SONIC-O1 Team """ -Base Gemini API client - reusable logic for all VQA tasks -""" -from google import genai -from google.genai import types + +import logging import os import time -import logging from pathlib import Path -from typing import List, Tuple +from typing import List, Optional, Tuple + +from google import genai +from google.genai import types + logger = logging.getLogger(__name__) class BaseGeminiClient: - """Base class for Gemini API interactions""" - - def __init__(self, config): + """Base class for Gemini API interactions.""" + + _DRY_RUN_RESPONSE = '{"summary_short":[],"summary_detailed":"[DRY-RUN]","timeline":[],"glossary":[],"confidence":0.1,"question":"[DRY-RUN]","options":["(A) Dry run","(B) Dry run","(C) Dry run","(D) Not enough evidence"],"answer_index":3,"answer_letter":"D","rationale":"dry-run","evidence_tags":[],"requires_audio":false,"demographics":[],"total_individuals":0,"explanation":"dry-run"}' + + def __init__(self, config, dry_run: bool = False): """ Initialize Gemini client with configuration. 
- + Args: config: Configuration object with API settings + dry_run: If True, skip API client setup and return stub responses. """ self.config = config + self.dry_run = dry_run self.model_name = config.gemini.model_name self.retry_attempts = int(config.gemini.retry_attempts) self.retry_delay = int(config.gemini.retry_delay) self.file_processing_timeout = int(config.gemini.file_processing_timeout) - self.inline_threshold = int(config.file_processing.inline_threshold_mb) * 1024 * 1024 - - # Rate limiting settings (with type conversion) - self.rate_limit_delay = int(getattr(config.rate_limit, 'delay_after_api_call', 2)) - self.rate_limit_max_retries = int(getattr(config.rate_limit, 'max_retries_on_rate_limit', 5)) - self.rate_limit_backoff = int(getattr(config.rate_limit, 'rate_limit_backoff_multiplier', 2)) - - self.setup_client() - + self.inline_threshold = ( + int(config.file_processing.inline_threshold_mb) * 1024 * 1024 + ) + + self.rate_limit_delay = int( + getattr(config.rate_limit, "delay_after_api_call", 2) + ) + self.rate_limit_max_retries = int( + getattr(config.rate_limit, "max_retries_on_rate_limit", 5) + ) + self.rate_limit_backoff = int( + getattr(config.rate_limit, "rate_limit_backoff_multiplier", 2) + ) + + if dry_run: + self.client = None + logger.info("[DRY-RUN] Skipping Gemini client setup") + else: + self.setup_client() + def setup_client(self): - """Initialize the Gemini client""" + """Initialize the Gemini client.""" api_key = self.config.gemini.api_key - if api_key.startswith('${') and api_key.endswith('}'): + if api_key.startswith("${") and api_key.endswith("}"): # Extract environment variable name env_var = api_key[2:-1] api_key = os.getenv(env_var) if not api_key: raise ValueError(f"Environment variable {env_var} not set") - - os.environ['GEMINI_API_KEY'] = api_key + + os.environ["GEMINI_API_KEY"] = api_key self.client = genai.Client() logger.info(f"Initialized Gemini client with model: {self.model_name}") - - def generate_content(self, - 
media_files: List[Tuple[str, Path]], - prompt: str, - video_fps: float = 1.0) -> str: + + def generate_content( + self, media_files: List[Tuple[str, Path]], prompt: str, video_fps: float = 1.0 + ) -> str: """ Generate content using Gemini with multimodal inputs. - + Args: - media_files: List of tuples (media_type, Path) - e.g., [('video', path), ('audio', path)] - prompt: Text prompt for generation - video_fps: FPS for video sampling (default: 1.0) - - Returns: - Generated text response + media_files: List of (media_type, Path), e.g. [('video', path)]. + prompt: Text prompt for generation. + video_fps: FPS for video sampling (default: 1.0). + + Returns + ------- + Generated text response. """ - # Calculate total size to determine processing method + if self.dry_run: + logger.info( + "[DRY-RUN] Would call Gemini with %d media files", len(media_files) + ) + return self._DRY_RUN_RESPONSE + total_size = sum(os.path.getsize(path) for _, path in media_files) - + if total_size > self.inline_threshold: - logger.info(f"Using File API for large media (size: {total_size / (1024*1024):.2f}MB)") + logger.info( + f"Using File API for large media " + f"(size: {total_size / (1024 * 1024):.2f}MB)" + ) return self._process_with_file_api(media_files, prompt, video_fps) - else: - logger.info(f"Using inline processing (size: {total_size / (1024*1024):.2f}MB)") - return self._process_inline(media_files, prompt, video_fps) + logger.info( + f"Using inline processing (size: {total_size / (1024 * 1024):.2f}MB)" + ) + return self._process_inline(media_files, prompt, video_fps) + + def _wait_for_files( + self, uploaded_files: List[Tuple[str, object]] + ) -> List[Tuple[str, object]]: + """Poll *uploaded_files* until all reach ACTIVE state. + + Args: + uploaded_files: List of (media_type, uploaded_file) pairs. + + Returns + ------- + Updated list with the latest file objects from the File API. + + Raises + ------ + Exception: If any file transitions to FAILED, or the timeout is exceeded. 
+ """ + max_wait = self.file_processing_timeout + wait_time = 0 + + while wait_time < max_wait: + all_processed = True + for i, (media_type, uploaded_file) in enumerate(uploaded_files): + updated_file = self.client.files.get(name=uploaded_file.name) + uploaded_files[i] = (media_type, updated_file) + if updated_file.state == "PROCESSING": + all_processed = False + elif updated_file.state == "FAILED": + error_msg = getattr(updated_file, "error", "Unknown error") + raise Exception(f"File processing failed: {error_msg}") + + if all_processed: + return uploaded_files + + time.sleep(10) + wait_time += 10 + if wait_time % 60 == 0: + logger.info( + "Still waiting for file processing (%ds elapsed)", wait_time + ) + raise Exception(f"File processing timeout after {max_wait}s") - def _process_with_file_api(self, media_files: List[Tuple[str, Path]], prompt: str, video_fps: float = 1.0) -> str: - """Process large files using Gemini File API""" + def _process_with_file_api( + self, media_files: List[Tuple[str, Path]], prompt: str, video_fps: float = 1.0 + ) -> str: + """Process large files using Gemini File API.""" uploaded_files = [] try: - # Upload all media files for media_type, media_path in media_files: uploaded_file = self.client.files.upload(file=str(media_path)) logger.info(f"Uploaded {media_type}: {uploaded_file.name}") - uploaded_files.append((media_type, uploaded_file)) # Store type with file - - # Wait for all files to process - max_wait = self.file_processing_timeout - wait_time = 0 - all_processed = False - - while not all_processed and wait_time < max_wait: - all_processed = True - for i, (media_type, uploaded_file) in enumerate(uploaded_files): - updated_file = self.client.files.get(name=uploaded_file.name) - uploaded_files[i] = (media_type, updated_file) - if updated_file.state == "PROCESSING": - all_processed = False - elif updated_file.state == "FAILED": - error_msg = getattr(updated_file, 'error', 'Unknown error') - raise Exception(f"File processing failed: 
{error_msg}") - - if not all_processed: - time.sleep(10) - wait_time += 10 - if wait_time % 60 == 0: # Log every minute - logger.info(f"Still waiting for file processing... ({wait_time}s elapsed)") - - if not all_processed: - raise Exception(f"File processing timeout after {max_wait}s") - + uploaded_files.append((media_type, uploaded_file)) + + uploaded_files = self._wait_for_files(uploaded_files) + # Generate content with uploaded files + prompt for attempt in range(self.retry_attempts): try: # Build content parts with video_metadata for video files content_parts = [] for media_type, uploaded_file in uploaded_files: - if media_type == 'video': + if media_type == "video": content_parts.append( types.Part( file_data=types.FileData( file_uri=uploaded_file.uri, - mime_type=uploaded_file.mime_type + mime_type=uploaded_file.mime_type, ), - video_metadata=types.VideoMetadata(fps=video_fps) + video_metadata=types.VideoMetadata(fps=video_fps), ) ) else: @@ -132,17 +181,17 @@ def _process_with_file_api(self, media_files: List[Tuple[str, Path]], prompt: st types.Part( file_data=types.FileData( file_uri=uploaded_file.uri, - mime_type=uploaded_file.mime_type + mime_type=uploaded_file.mime_type, ) ) ) - + # Add prompt content_parts.append(types.Part(text=prompt)) - + response = self.client.models.generate_content( model=self.model_name, - contents=types.Content(parts=content_parts) + contents=types.Content(parts=content_parts), ) return response.text except Exception as e: @@ -153,56 +202,54 @@ def _process_with_file_api(self, media_files: List[Tuple[str, Path]], prompt: st raise finally: # Cleanup uploaded files - for media_type, uploaded_file in uploaded_files: + for _, uploaded_file in uploaded_files: try: self.client.files.delete(name=uploaded_file.name) logger.debug(f"Deleted uploaded file: {uploaded_file.name}") except Exception as e: logger.warning(f"Failed to delete file {uploaded_file.name}: {e}") - def _process_inline(self, media_files: List[Tuple[str, Path]], prompt: 
str, video_fps: float = 1.0) -> str: - """Process small files using inline data""" + def _process_inline( + self, media_files: List[Tuple[str, Path]], prompt: str, video_fps: float = 1.0 + ) -> Optional[str]: + """Process small files using inline data.""" parts = [] - + # Add all media files as inline data for media_type, media_path in media_files: - with open(media_path, 'rb') as f: + with open(media_path, "rb") as f: media_bytes = f.read() - + mime_type = self._get_mime_type(media_path) - + # Add video metadata only for video files - if media_type == 'video': + if media_type == "video": parts.append( types.Part( - inline_data=types.Blob( - data=media_bytes, - mime_type=mime_type - ), - video_metadata=types.VideoMetadata(fps=video_fps) + inline_data=types.Blob(data=media_bytes, mime_type=mime_type), + video_metadata=types.VideoMetadata(fps=video_fps), ) ) - logger.info(f"Added {media_type} ({mime_type}) as inline data with fps={video_fps}") + logger.info( + f"Added {media_type} ({mime_type}) as inline " + f"data with fps={video_fps}" + ) else: parts.append( types.Part( - inline_data=types.Blob( - data=media_bytes, - mime_type=mime_type - ) + inline_data=types.Blob(data=media_bytes, mime_type=mime_type) ) ) logger.info(f"Added {media_type} ({mime_type}) as inline data") - + # Add text prompt parts.append(types.Part(text=prompt)) - + # Generate with retries for attempt in range(self.retry_attempts): try: response = self.client.models.generate_content( - model=self.model_name, - contents=types.Content(parts=parts) + model=self.model_name, contents=types.Content(parts=parts) ) return response.text except Exception as e: @@ -211,24 +258,26 @@ def _process_inline(self, media_files: List[Tuple[str, Path]], prompt: str, vide time.sleep(self.retry_delay) else: raise + return None + def _get_mime_type(self, file_path: Path) -> str: - """Get MIME type for media file""" + """Get MIME type for media file.""" extension_map = { # Video - '.mp4': 'video/mp4', - '.avi': 
'video/x-msvideo', - '.mov': 'video/quicktime', - '.webm': 'video/webm', - '.mkv': 'video/x-matroska', - '.m4v': 'video/x-m4v', + ".mp4": "video/mp4", + ".avi": "video/x-msvideo", + ".mov": "video/quicktime", + ".webm": "video/webm", + ".mkv": "video/x-matroska", + ".m4v": "video/x-m4v", # Audio - '.m4a': 'audio/m4a', - '.mp3': 'audio/mpeg', - '.wav': 'audio/wav', - '.ogg': 'audio/ogg', - '.flac': 'audio/flac', - '.aac': 'audio/aac', + ".m4a": "audio/m4a", + ".mp3": "audio/mpeg", + ".wav": "audio/wav", + ".ogg": "audio/ogg", + ".flac": "audio/flac", + ".aac": "audio/aac", } - + ext = file_path.suffix.lower() - return extension_map.get(ext, 'application/octet-stream') \ No newline at end of file + return extension_map.get(ext, "application/octet-stream") diff --git a/sonic-o1/04_vqa_generation/models/mcq_model.py b/sonic-o1/04_vqa_generation/models/mcq_model.py index 9f58290..948ccb7 100644 --- a/sonic-o1/04_vqa_generation/models/mcq_model.py +++ b/sonic-o1/04_vqa_generation/models/mcq_model.py @@ -1,94 +1,115 @@ +"""mcq_model.py. + +Task 2: MCQ (Multiple Choice Questions) Generation Model. 
+ +Author: SONIC-O1 Team """ -Task 2: MCQ (Multiple Choice Questions) Generation Model -""" + import json import logging +import time from pathlib import Path -from typing import Dict, List, Any, Optional -from .base_gemini import BaseGeminiClient -from utils.video_segmenter import VideoSegmenter -from utils.demographics_expander import DemographicsExpander +from typing import Any, Dict, List, Optional + from prompts.mcq_prompts import get_mcq_prompt -import time +from utils.demographics_expander import DemographicsExpander +from utils.video_segmenter import VideoSegmenter + +from .base_gemini import BaseGeminiClient + + logger = logging.getLogger(__name__) + class MCQModel(BaseGeminiClient): - """Generate segment-level MCQ VQA entries""" - - def __init__(self, config): + """Generate segment-level MCQ VQA entries.""" + + def __init__(self, config, dry_run: bool = False): """ Initialize MCQ model. - + Args: config: Configuration object + dry_run: If True, skip API calls and return stub responses. """ - super().__init__(config) + super().__init__(config, dry_run=dry_run) self.config = config self.segmenter = VideoSegmenter(config) self.demographics_expander = DemographicsExpander(config) - + # Get num_options from config for validation self.num_options = config.mcq.num_options self.option_letters = [chr(65 + i) for i in range(self.num_options)] - - def process_video(self, - video_path: Path, - audio_path: Optional[Path], - transcript_path: Optional[Path], - metadata: Dict[str, Any]) -> List[Dict[str, Any]]: + + def process_video( + self, + video_path: Path, + audio_path: Optional[Path], + transcript_path: Optional[Path], + metadata: Dict[str, Any], + ) -> List[Dict[str, Any]]: """ Process a video and generate MCQ VQA entries (one per segment). 
- + Args: - video_path: Path to video file - audio_path: Path to audio file (optional) - transcript_path: Path to transcript/caption file (optional) - metadata: Video metadata from metadata_enhanced.json - - Returns: - List of VQA entry dicts for Task 2 (one per segment) + video_path: Path to video file. + audio_path: Path to audio file (optional). + transcript_path: Path to transcript/caption file (optional). + metadata: Video metadata from metadata_enhanced.json. + + Returns + ------- + List of VQA entry dicts for Task 2 (one per segment). """ try: - video_id = metadata.get('video_id', metadata.get('video_number', 'unknown')) - duration = metadata.get('duration_seconds', 0) - - logger.info(f"Processing video {video_id} for MCQ generation (duration: {duration}s)") - + video_id = metadata.get("video_id", metadata.get("video_number", "unknown")) + duration = metadata.get("duration_seconds", 0) + + logger.info(f"Processing video {video_id} for MCQ (duration: {duration}s)") + # Always segment videos (even short ones get 1 MCQ) - video_segments = self.segmenter.segment_video(video_path, duration, task_type='mcq') + video_segments = self.segmenter.segment_video( + video_path, duration, task_type="mcq" + ) audio_segments = None if audio_path and audio_path.exists(): - audio_segments = self.segmenter.segment_audio(audio_path, duration, task_type='mcq') - + audio_segments = self.segmenter.segment_audio( + audio_path, duration, task_type="mcq" + ) + logger.info(f"Created {len(video_segments)} segments for MCQ generation") - + # Generate MCQ for each segment mcq_entries = [] for i, seg in enumerate(video_segments): try: # Get corresponding audio segment - audio_seg_path = audio_segments[i]['segment_path'] if audio_segments else None - + audio_seg_path = ( + audio_segments[i]["segment_path"] if audio_segments else None + ) + # Extract transcript for this segment transcript_text = "" if transcript_path and transcript_path.exists(): transcript_text = 
self.segmenter.extract_transcript_segment( - transcript_path, seg['start'], seg['end'], - strip_timestamps=True + transcript_path, + seg["start"], + seg["end"], + strip_timestamps=True, ) - + # Generate MCQ for this segment mcq_entry = self._generate_mcq_for_segment( seg, audio_seg_path, transcript_text, metadata ) - + if mcq_entry: mcq_entries.append(mcq_entry) - + except Exception as e: logger.error(f"Failed to generate MCQ for segment {i}: {e}") continue - + # Cleanup temporary segment files try: self.segmenter.cleanup_segments(video_segments) @@ -96,78 +117,113 @@ def process_video(self, self.segmenter.cleanup_segments(audio_segments) except Exception as e: logger.warning(f"Failed to cleanup segments: {e}") - - logger.info(f"Generated {len(mcq_entries)} MCQ entries for video {video_id}") + + logger.info( + f"Generated {len(mcq_entries)} MCQ entries for video {video_id}" + ) return mcq_entries - + except Exception as e: - logger.error(f"Error processing video {video_id} for MCQ: {e}", exc_info=True) + logger.error( + f"Error processing video {video_id} for MCQ: {e}", exc_info=True + ) return [] - - def _generate_mcq_for_segment(self, - segment_info: Dict, - audio_path: Optional[Path], - transcript_text: str, - metadata: Dict[str, Any]) -> Optional[Dict[str, Any]]: - """Generate one MCQ for a video segment with retry on parse failure""" - video_id = metadata.get('video_id', metadata.get('video_number', 'unknown')) - seg_num = segment_info['segment_number'] - - logger.info(f"Generating MCQ for {video_id} segment {seg_num} ({segment_info['start']}s-{segment_info['end']}s)") - + + def _generate_mcq_for_segment( + self, + segment_info: Dict, + audio_path: Optional[Path], + transcript_text: str, + metadata: Dict[str, Any], + ) -> Optional[Dict[str, Any]]: + """Generate one MCQ for a video segment with retry on parse failure.""" + video_id = metadata.get("video_id", metadata.get("video_number", "unknown")) + seg_num = segment_info["segment_number"] + + logger.info( + 
f"Generating MCQ for {video_id} segment {seg_num} " + f"({segment_info['start']}s-{segment_info['end']}s)" + ) + max_attempts = 3 mcq_data = None - + # Prepare media files once (outside retry loop) media_files = [] - seg_path = segment_info['segment_path'] + seg_path = segment_info["segment_path"] if seg_path.exists(): # Determine if it's video or audio - if seg_path.suffix.lower() in ['.mp4', '.avi', '.mov', '.webm', '.mkv', '.m4v']: - media_files.append(('video', seg_path)) + if seg_path.suffix.lower() in [ + ".mp4", + ".avi", + ".mov", + ".webm", + ".mkv", + ".m4v", + ]: + media_files.append(("video", seg_path)) else: - media_files.append(('audio', seg_path)) - + media_files.append(("audio", seg_path)) + if audio_path and audio_path.exists(): - media_files.append(('audio', audio_path)) - + media_files.append(("audio", audio_path)) + # Retry loop for MCQ generation for attempt in range(max_attempts): try: # Build MCQ generation prompt - prompt = get_mcq_prompt(segment_info, metadata, transcript_text, self.config) - + prompt = get_mcq_prompt( + segment_info, metadata, transcript_text, self.config + ) + # On retry, add explicit JSON validation instructions if attempt > 0: - prompt += "\n\nPREVIOUS ATTEMPT RETURNED INVALID JSON. Critical requirements:\n" + prompt += ( + "\n\nPREVIOUS ATTEMPT RETURNED INVALID JSON. 
" + "Critical requirements:\n" + ) prompt += f"- Return ONLY valid JSON with exactly {self.num_options} options\n" - prompt += f"- Use commas between all properties/elements except the last\n" + prompt += ( + "- Use commas between all properties/elements except the last\n" + ) prompt += f"- Last option MUST be '({self.option_letters[-1]}) Not enough evidence'\n" - prompt += f"- Include both 'answer_index' (0-{self.num_options - 1}) and 'answer_letter' (A-{self.option_letters[-1]})\n" + prompt += ( + f"- Include 'answer_index' (0-{self.num_options - 1}) " + f"and 'answer_letter' (A-{self.option_letters[-1]})\n" + ) prompt += "- Use double quotes for all strings\n" prompt += "- No trailing commas before closing brackets\n" - + # Generate MCQ - response_text = self.generate_content(media_files, prompt, video_fps=0.5) + response_text = self.generate_content( + media_files, prompt, video_fps=0.5 + ) mcq_data = self._parse_mcq_response(response_text) - - # Check if parsing succeeded (confidence > 0 and rationale is not failure message) - if (mcq_data.get('confidence', 0) > 0 and - mcq_data.get('rationale', '') != 'Failed to generate MCQ'): + + # Check if parsing succeeded (confidence > 0 and rationale valid) + if ( + mcq_data.get("confidence", 0) > 0 + and mcq_data.get("rationale", "") != "Failed to generate MCQ" + ): break - else: - if attempt < max_attempts - 1: - time.sleep(30) - + if attempt < max_attempts - 1: + time.sleep(30) + except Exception as e: - logger.error(f"Failed to generate MCQ for segment {seg_num}: {e}", exc_info=True) + logger.error( + f"Failed to generate MCQ for segment {seg_num}: {e}", exc_info=True + ) if attempt < max_attempts - 1: time.sleep(30) - + # If all attempts failed, use default MCQ - if not mcq_data or mcq_data.get('confidence', 0) == 0 or mcq_data.get('rationale', '') == 'Failed to generate MCQ': + if ( + not mcq_data + or mcq_data.get("confidence", 0) == 0 + or mcq_data.get("rationale", "") == "Failed to generate MCQ" + ): mcq_data = 
self._get_default_mcq() - + # Get demographics for this segment try: demographics_data = self._get_segment_demographics( @@ -175,214 +231,255 @@ def _generate_mcq_for_segment(self, ) except Exception as e: demographics_data = { - 'demographics': [], - 'total_individuals': 0, - 'confidence': 0.0, - 'explanation': f'Error generating demographics: {str(e)}' + "demographics": [], + "total_individuals": 0, + "confidence": 0.0, + "explanation": f"Error generating demographics: {str(e)}", } - + # Build MCQ entry with all demographic information - entry = { - 'video_id': video_id, - 'video_number': metadata.get('video_number', video_id), - 'segment': { - 'start': segment_info['start'], - 'end': segment_info['end'] - }, - 'question': mcq_data.get('question', ''), - 'options': mcq_data.get('options', []), - 'answer_index': mcq_data.get('answer_index', self.num_options - 1), - 'answer_letter': mcq_data.get('answer_letter', self.option_letters[-1]), - 'rationale': mcq_data.get('rationale', ''), - 'evidence_tags': mcq_data.get('evidence_tags', []), - 'requires_audio': mcq_data.get('requires_audio', False), - 'demographics': demographics_data.get('demographics', []), - 'demographics_total_individuals': demographics_data.get('total_individuals', 0), - 'demographics_confidence': demographics_data.get('confidence', 0.0), - 'demographics_explanation': demographics_data.get('explanation', ''), - 'confidence': mcq_data.get('confidence', 0.0) + return { + "video_id": video_id, + "video_number": metadata.get("video_number", video_id), + "segment": {"start": segment_info["start"], "end": segment_info["end"]}, + "question": mcq_data.get("question", ""), + "options": mcq_data.get("options", []), + "answer_index": mcq_data.get("answer_index", self.num_options - 1), + "answer_letter": mcq_data.get("answer_letter", self.option_letters[-1]), + "rationale": mcq_data.get("rationale", ""), + "evidence_tags": mcq_data.get("evidence_tags", []), + "requires_audio": mcq_data.get("requires_audio", 
False), + "demographics": demographics_data.get("demographics", []), + "demographics_total_individuals": demographics_data.get( + "total_individuals", 0 + ), + "demographics_confidence": demographics_data.get("confidence", 0.0), + "demographics_explanation": demographics_data.get("explanation", ""), + "confidence": mcq_data.get("confidence", 0.0), } - - return entry - - def _get_segment_demographics(self, - segment_info: Dict, - video_path: Path, - audio_path: Optional[Path], - transcript_text: str, - metadata: Dict[str, Any]) -> Dict[str, Any]: - """Get expanded demographics for a specific segment""" + + def _get_segment_demographics( + self, + segment_info: Dict, + video_path: Path, + audio_path: Optional[Path], + transcript_text: str, + metadata: Dict[str, Any], + ) -> Dict[str, Any]: + """Get expanded demographics for a specific segment.""" try: # Get human-reviewed demographics from metadata (video-level) - human_demographics = metadata.get('demographics_detailed_reviewed', {}) + human_demographics = metadata.get("demographics_detailed_reviewed", {}) if not human_demographics: - logger.warning(f"No human-reviewed demographics found for {metadata.get('video_id')}") + logger.warning( + f"No human-reviewed demographics for {metadata.get('video_id')}" + ) return { - 'demographics': [], - 'total_individuals': 0, - 'confidence': 0.0, - 'explanation': 'No human-reviewed demographics available' + "demographics": [], + "total_individuals": 0, + "confidence": 0.0, + "explanation": "No human-reviewed demographics available", } - + # Build expansion prompt (segment-level) prompt = self.demographics_expander.build_expansion_prompt( human_demographics, - segment_info={'start': segment_info['start'], 'end': segment_info['end']} + segment_info={ + "start": segment_info["start"], + "end": segment_info["end"], + }, ) - + # Prepare media files media_files = [] if video_path and video_path.exists(): - if video_path.suffix.lower() in ['.mp4', '.avi', '.mov', '.webm', '.mkv', 
'.m4v']: - media_files.append(('video', video_path)) + if video_path.suffix.lower() in [ + ".mp4", + ".avi", + ".mov", + ".webm", + ".mkv", + ".m4v", + ]: + media_files.append(("video", video_path)) else: - media_files.append(('audio', video_path)) - + media_files.append(("audio", video_path)) + if audio_path and audio_path.exists(): - media_files.append(('audio', audio_path)) - + media_files.append(("audio", audio_path)) + # Add transcript context to prompt if transcript_text: prompt += f"\n\nSEGMENT TRANSCRIPT:\n{transcript_text[:1000]}" - + # Generate demographics - response_text = self.generate_content(media_files, prompt,video_fps=0.25) - demographics_data = self.demographics_expander.parse_demographics_response(response_text) - + response_text = self.generate_content(media_files, prompt, video_fps=0.25) + demographics_data = self.demographics_expander.parse_demographics_response( + response_text + ) + # Ensure all fields are present - if 'explanation' not in demographics_data: - demographics_data['explanation'] = 'No explanation provided' - + if "explanation" not in demographics_data: + demographics_data["explanation"] = "No explanation provided" + return demographics_data - + except Exception as e: logger.error(f"Failed to get segment demographics: {e}") return { - 'demographics': [], - 'total_individuals': 0, - 'confidence': 0.0, - 'explanation': f'Error generating demographics: {str(e)}' + "demographics": [], + "total_individuals": 0, + "confidence": 0.0, + "explanation": f"Error generating demographics: {str(e)}", } - + + def _extract_json_from_markdown(self, text: str) -> str: + """Extract JSON from markdown code blocks.""" + text = text.strip() + if "```json" in text: + start = text.find("```json") + 7 + end = text.rfind("```") + if end > start: + return text[start:end].strip() + elif "```" in text: + start = text.find("```") + 3 + end = text.rfind("```") + if end > start: + return text[start:end].strip() + return text + + def _normalize_mcq_data(self, 
data: Any) -> Optional[Dict[str, Any]]: + """Normalize parsed data to dict format.""" + if isinstance(data, list): + if len(data) > 0 and isinstance(data[0], dict): + logger.warning("API returned list, extracting first element") + return data[0] + logger.error("API returned invalid list format") + return None + + if not isinstance(data, dict): + logger.error(f"Parsed data is not a dict, got {type(data)}") + return None + + return data + + def _normalize_options_count(self, options: List[str]) -> List[str]: + """Normalize options count to expected num_options.""" + expected_num_options = self.num_options + + if len(options) != expected_num_options: + logger.warning( + f"MCQ has {len(options)} options, expected {expected_num_options}" + ) + # Pad with "Not enough evidence" if needed + while len(options) < expected_num_options: + last_letter = self.option_letters[len(options)] + options.append(f"({last_letter}) Not enough evidence") + options = options[:expected_num_options] + + # Ensure last option is "Not enough evidence" + last_letter = self.option_letters[-1] + if "Not enough evidence" not in options[-1].lower(): + options[-1] = f"({last_letter}) Not enough evidence" + + return options + + def _format_options_with_letters(self, options: List[str]) -> List[str]: + """Assign correct letter prefixes, replacing any existing ones.""" + formatted = [] + for letter, option in zip(self.option_letters, options): + cleaned = option.strip() + if len(cleaned) >= 3 and cleaned[0] == "(" and cleaned[2] == ")": + cleaned = cleaned[3:].strip() + formatted.append(f"({letter}) {cleaned}") + return formatted + + def _validate_answer_fields(self, data: Dict[str, Any]) -> tuple[int, str]: + """Validate and normalize answer_index and answer_letter.""" + expected_num_options = self.num_options + answer_index = int(data.get("answer_index", expected_num_options - 1)) + max_index = expected_num_options - 1 + + if answer_index < 0 or answer_index > max_index: + logger.warning( + f"Invalid 
answer_index {answer_index}, defaulting to {max_index}" + ) + answer_index = max_index + + answer_letter = data.get("answer_letter", self.option_letters[answer_index]) + expected_letter = self.option_letters[answer_index] + + if answer_letter != expected_letter: + logger.warning( + f"Mismatch answer_letter={answer_letter}, " + f"answer_index={answer_index} (expected {expected_letter})" + ) + answer_letter = expected_letter + + return answer_index, answer_letter + def _parse_mcq_response(self, response_text: str) -> Dict[str, Any]: - """Parse MCQ response from Gemini""" + """Parse MCQ response from Gemini.""" try: - response_text = response_text.strip() - - # Remove markdown code blocks - if "```json" in response_text: - start = response_text.find("```json") + 7 - end = response_text.rfind("```") - if end > start: - response_text = response_text[start:end] - elif "```" in response_text: - start = response_text.find("```") + 3 - end = response_text.rfind("```") - if end > start: - response_text = response_text[start:end] - - response_text = response_text.strip() - data = json.loads(response_text) - - if isinstance(data, list): - if len(data) > 0 and isinstance(data[0], dict): - logger.warning("API returned list, extracting first element") - data = data[0] - else: - logger.error("API returned invalid list format") - return self._get_default_mcq() - - if not isinstance(data, dict): - logger.error(f"Parsed data is not a dict, got {type(data)}") + cleaned_text = self._extract_json_from_markdown(response_text) + data = json.loads(cleaned_text) + + normalized_data = self._normalize_mcq_data(data) + if normalized_data is None: return self._get_default_mcq() - - # Validate structure - use config's num_options - options = data.get('options', []) - expected_num_options = self.num_options - - if len(options) != expected_num_options: - logger.warning(f"MCQ has {len(options)} options instead of {expected_num_options}") - # Pad with "Not enough evidence" if needed - while 
len(options) < expected_num_options: - last_letter = self.option_letters[len(options)] - options.append(f"({last_letter}) Not enough evidence") - options = options[:expected_num_options] - - # Ensure last option is "Not enough evidence" (with or without letter) - last_letter = self.option_letters[-1] - if "Not enough evidence" not in options[-1].lower(): - options[-1] = f"({last_letter}) Not enough evidence" - - # Ensure all options have letters - add if missing - formatted_options = [] - for i, opt in enumerate(options): - opt = opt.strip() - letter = self.option_letters[i] - # Check if option already has a letter prefix - if not opt.startswith(f"({letter})"): - # Remove any existing letter prefix first - for existing_letter in self.option_letters: - if opt.startswith(f"({existing_letter})"): - opt = opt[3:].strip() - break - opt = f"({letter}) {opt}" - formatted_options.append(opt) - - answer_index = int(data.get('answer_index', expected_num_options - 1)) - max_index = expected_num_options - 1 - - if answer_index < 0 or answer_index > max_index: - logger.warning(f"Invalid answer_index {answer_index}, defaulting to {max_index}") - answer_index = max_index - - # Get or derive answer_letter - answer_letter = data.get('answer_letter', self.option_letters[answer_index]) - - # Validate answer_letter matches answer_index - expected_letter = self.option_letters[answer_index] - if answer_letter != expected_letter: - logger.warning(f"Mismatch: answer_letter={answer_letter}, answer_index={answer_index} (expected {expected_letter}). 
Using index.") - answer_letter = expected_letter - + + # Normalize options count and format + options = normalized_data.get("options", []) + options = self._normalize_options_count(options) + formatted_options = self._format_options_with_letters(options) + + # Validate answer fields + answer_index, answer_letter = self._validate_answer_fields(normalized_data) + return { - 'question': data.get('question', 'What is happening in the video and audio?'), - 'options': formatted_options, - 'answer_index': answer_index, - 'answer_letter': answer_letter, - 'rationale': data.get('rationale', ''), - 'evidence_tags': data.get('evidence_tags', []), - 'requires_audio': bool(data.get('requires_audio', False)), - 'confidence': float(data.get('confidence', 0.0)) + "question": normalized_data.get( + "question", "What is happening in the video and audio?" + ), + "options": formatted_options, + "answer_index": answer_index, + "answer_letter": answer_letter, + "rationale": normalized_data.get("rationale", ""), + "evidence_tags": normalized_data.get("evidence_tags", []), + "requires_audio": bool(normalized_data.get("requires_audio", False)), + "confidence": float(normalized_data.get("confidence", 0.0)), } - + except json.JSONDecodeError as e: logger.error(f"Failed to parse MCQ JSON: {e}") logger.debug(f"Response text: {response_text[:500]}...") return self._get_default_mcq() except Exception as e: logger.error(f"Error parsing MCQ response: {e}") - logger.debug(f"Response text (first 500 chars): {response_text[:500] if response_text else 'None'}...") + logger.debug( + f"Response (500 chars): " + f"{response_text[:500] if response_text else 'None'}..." 
+ ) return self._get_default_mcq() - + def _get_default_mcq(self) -> Dict[str, Any]: - """Return default MCQ structure when parsing fails""" + """Return default MCQ structure when parsing fails.""" # Build default options dynamically based on num_options default_options = [] for i in range(self.num_options - 1): letter = self.option_letters[i] default_options.append(f"({letter}) Unable to generate option") - + # Last option is always "Not enough evidence" last_letter = self.option_letters[-1] default_options.append(f"({last_letter}) Not enough evidence") - + return { - 'question': 'What is happening in the video and audio?', - 'options': default_options, - 'answer_index': self.num_options - 1, - 'answer_letter': last_letter, - 'rationale': 'Failed to generate MCQ', - 'evidence_tags': [], - 'requires_audio': False, - 'confidence': 0.0 - } \ No newline at end of file + "question": "What is happening in the video and audio?", + "options": default_options, + "answer_index": self.num_options - 1, + "answer_letter": last_letter, + "rationale": "Failed to generate MCQ", + "evidence_tags": [], + "requires_audio": False, + "confidence": 0.0, + } diff --git a/sonic-o1/04_vqa_generation/models/summarization_model.py b/sonic-o1/04_vqa_generation/models/summarization_model.py index d0c5094..f8c5e6f 100644 --- a/sonic-o1/04_vqa_generation/models/summarization_model.py +++ b/sonic-o1/04_vqa_generation/models/summarization_model.py @@ -1,363 +1,422 @@ +"""summarization_model.py. + +Task 1: Video Summarization Model. 
+ +Author: SONIC-O1 Team """ -Task 1: Video Summarization Model -""" + import json import logging -import time import re +import time from pathlib import Path -from typing import Dict, List, Any, Optional +from typing import Any, Dict, List, Optional -from .base_gemini import BaseGeminiClient -from utils.video_segmenter import VideoSegmenter -from utils.demographics_expander import DemographicsExpander from prompts.summarization_prompts import ( - get_map_prompt, - get_reduce_prompt, get_direct_prompt, get_initialize_prompt, - get_streaming_update_prompt + get_map_prompt, + get_streaming_update_prompt, ) +from utils.demographics_expander import DemographicsExpander +from utils.video_segmenter import VideoSegmenter + +from .base_gemini import BaseGeminiClient + logger = logging.getLogger(__name__) class SummarizationModel(BaseGeminiClient): - """Generate video-level summarization VQA entries""" - - def __init__(self, config): + """Generate video-level summarization VQA entries.""" + + def __init__(self, config, dry_run: bool = False): """ Initialize summarization model. - + Args: config: Configuration object + dry_run: If True, skip API calls and return stub responses. """ - super().__init__(config) + super().__init__(config, dry_run=dry_run) self.config = config self.segmenter = VideoSegmenter(config) self.demographics_expander = DemographicsExpander(config) - - def process_video(self, - video_path: Path, - audio_path: Optional[Path], - transcript_path: Optional[Path], - metadata: Dict[str, Any]) -> Dict[str, Any]: + + def process_video( + self, + video_path: Path, + audio_path: Optional[Path], + transcript_path: Optional[Path], + metadata: Dict[str, Any], + ) -> Dict[str, Any]: """ Process a video and generate summarization VQA entry. 
- + Args: - video_path: Path to video file - audio_path: Path to audio file (optional) - transcript_path: Path to transcript/caption file (optional) - metadata: Video metadata from metadata_enhanced.json - - Returns: - VQA entry dict for Task 1 + video_path: Path to video file. + audio_path: Path to audio file (optional). + transcript_path: Path to transcript/caption file (optional). + metadata: Video metadata from metadata_enhanced.json. + + Returns + ------- + VQA entry dict for Task 1. """ try: - video_id = metadata.get('video_id', metadata.get('video_number', 'unknown')) - duration = metadata.get('duration_seconds', 0) - category = metadata.get('duration_category', 'short') - - if category not in ['short', 'medium', 'long']: + video_id = metadata.get("video_id", metadata.get("video_number", "unknown")) + duration = metadata.get("duration_seconds", 0) + category = metadata.get("duration_category", "short") + + if category not in ["short", "medium", "long"]: if duration <= 300: - category = 'short' + category = "short" elif duration <= 1800: - category = 'medium' + category = "medium" else: - category = 'long' + category = "long" logger.warning( - f"Video {video_id}: No valid category in metadata, " - f"determined as '{category}' based on duration ({duration}s)" + f"Video {video_id}: no valid category, " + f"using '{category}' from duration ({duration}s)" + ) + + print( + f"Processing video {video_id} for summarization " + f"(duration: {duration}s, category: {category})" + ) + + if category == "short": + result = self._process_short_video( + video_path, audio_path, transcript_path, metadata ) - - print(f"Processing video {video_id} for summarization (duration: {duration}s, category: {category})") - - if category == 'short': - result = self._process_short_video(video_path, audio_path, transcript_path, metadata) else: - result = self._process_segmented_video(video_path, audio_path, transcript_path, metadata) - + result = self._process_segmented_video( + video_path, 
audio_path, transcript_path, metadata + ) + return result - + except Exception as e: logger.error(f"Error processing video {video_id}: {e}", exc_info=True) return self._get_error_entry(metadata, str(e)) - - def _process_short_video(self, - video_path: Path, - audio_path: Optional[Path], - transcript_path: Optional[Path], - metadata: Dict[str, Any]) -> Dict[str, Any]: - """Process short video with direct summarization (no segmentation)""" - video_id = metadata.get('video_id', metadata.get('video_number', 'unknown')) + + def _process_short_video( + self, + video_path: Path, + audio_path: Optional[Path], + transcript_path: Optional[Path], + metadata: Dict[str, Any], + ) -> Dict[str, Any]: + """Process short video with direct summarization (no segmentation).""" + video_id = metadata.get("video_id", metadata.get("video_number", "unknown")) print(f"Direct summarization for short video {video_id}") - + transcript_text = self._load_transcript(transcript_path) prompt = get_direct_prompt(video_id, metadata, transcript_text, self.config) - + media_files = [] if video_path and video_path.exists(): - media_files.append(('video', video_path)) + media_files.append(("video", video_path)) if audio_path and audio_path.exists(): - media_files.append(('audio', audio_path)) - + media_files.append(("audio", audio_path)) + max_attempts = 3 summary_data = None - + for attempt in range(max_attempts): try: - response_text = self.generate_content(media_files, prompt, video_fps=0.5) + response_text = self.generate_content( + media_files, prompt, video_fps=0.5 + ) summary_data = self._parse_summary_response(response_text) - - if summary_data.get('confidence', 0) > 0: + + if summary_data.get("confidence", 0) > 0: print(f"Successfully generated summary for {video_id}") break - else: - logger.warning(f"Parse attempt {attempt + 1} failed for {video_id}, retrying...") - if attempt < max_attempts - 1: - time.sleep(5) + logger.warning( + f"Parse attempt {attempt + 1} failed for {video_id}, 
retrying..." + ) + if attempt < max_attempts - 1: + time.sleep(5) except Exception as e: logger.error(f"Generation attempt {attempt + 1} failed: {e}") if attempt < max_attempts - 1: time.sleep(5) - - if not summary_data or summary_data.get('confidence', 0) == 0: - logger.error(f"FAILED to generate valid summary for {video_id} after {max_attempts} attempts") + + if not summary_data or summary_data.get("confidence", 0) == 0: + logger.error( + f"FAILED to generate summary for {video_id} " + f"after {max_attempts} attempts" + ) summary_data = self._get_default_summary() - + demographics = self._get_video_demographics( video_path, audio_path, transcript_path, metadata, segments_info=None ) - - entry = { - 'video_id': video_id, - 'video_number': metadata.get('video_number', video_id), - 'duration_seconds': metadata.get('duration_seconds', 0), - 'segments_processed': None, - 'summary_short': summary_data.get('summary_short', []), - 'summary_detailed': summary_data.get('summary_detailed', ''), - 'timeline': summary_data.get('timeline', []), - 'glossary': summary_data.get('glossary', []), - 'demographics': demographics.get('demographics', []), - 'confidence': summary_data.get('confidence', 0.0) - } - - return entry - - def _process_segmented_video(self, - video_path: Path, - audio_path: Optional[Path], - transcript_path: Optional[Path], - metadata: Dict[str, Any]) -> Dict[str, Any]: - """Process medium/long video with MAP-REDUCE approach""" - video_id = metadata.get('video_id', metadata.get('video_number', 'unknown')) - duration = metadata.get('duration_seconds', 0) - + + return self._create_entry(video_id, metadata, summary_data, demographics) + + def _process_segmented_video( + self, + video_path: Path, + audio_path: Optional[Path], + transcript_path: Optional[Path], + metadata: Dict[str, Any], + ) -> Dict[str, Any]: + """Process medium/long video with MAP-REDUCE approach.""" + video_id = metadata.get("video_id", metadata.get("video_number", "unknown")) + duration = 
metadata.get("duration_seconds", 0) + print(f"MAP-REDUCE processing for {video_id} ({duration}s)") - - video_segments = self.segmenter.segment_video(video_path, duration, task_type='summarization') + + video_segments = self.segmenter.segment_video( + video_path, duration, task_type="summarization" + ) audio_segments = None if audio_path and audio_path.exists(): - audio_segments = self.segmenter.segment_audio(audio_path, duration, task_type='summarization') - + audio_segments = self.segmenter.segment_audio( + audio_path, duration, task_type="summarization" + ) + print(f"Created {len(video_segments)} segments") - + segment_summaries = [] for i, seg in enumerate(video_segments): try: - audio_seg_path = audio_segments[i]['segment_path'] if audio_segments else None - + audio_seg_path = ( + audio_segments[i]["segment_path"] if audio_segments else None + ) + transcript_text = "" if transcript_path and transcript_path.exists(): transcript_text = self.segmenter.extract_transcript_segment( - transcript_path, seg['start'], seg['end'] + transcript_path, seg["start"], seg["end"] ) - + segment_summary = self._generate_segment_summary( seg, audio_seg_path, transcript_text, metadata ) segment_summaries.append(segment_summary) - + except Exception as e: logger.error(f"Failed to process segment {i}: {e}") continue - - merged_summary = self._merge_segment_summaries(video_id, metadata, segment_summaries) - + + merged_summary = self._merge_segment_summaries( + video_id, metadata, segment_summaries + ) + demographics = self._get_video_demographics( - video_path, audio_path, transcript_path, metadata, - segments_info=[{'start': s['start'], 'end': s['end']} for s in video_segments] + video_path, + audio_path, + transcript_path, + metadata, + segments_info=[ + {"start": s["start"], "end": s["end"]} for s in video_segments + ], ) - + try: self.segmenter.cleanup_segments(video_segments) if audio_segments: self.segmenter.cleanup_segments(audio_segments) except Exception as e: 
logger.warning(f"Failed to cleanup segments: {e}") - - entry = { - 'video_id': video_id, - 'video_number': metadata.get('video_number', video_id), - 'duration_seconds': duration, - 'segments_processed': [{'start': s['start'], 'end': s['end']} for s in video_segments], - 'summary_short': merged_summary.get('summary_short', []), - 'summary_detailed': merged_summary.get('summary_detailed', ''), - 'timeline': merged_summary.get('timeline', []), - 'glossary': merged_summary.get('glossary', []), - 'demographics': demographics.get('demographics', []), - 'confidence': merged_summary.get('confidence', 0.0) - } - - return entry - - def _generate_segment_summary(self, - segment_info: Dict, - audio_path: Optional[Path], - transcript_text: str, - metadata: Dict[str, Any]) -> Dict[str, Any]: - """MAP phase: Generate summary for one segment with retry on JSON parse failure""" - seg_num = segment_info['segment_number'] - print(f"Generating summary for segment {seg_num} ({segment_info['start']}s-{segment_info['end']}s)") - + + segments_processed = [ + {"start": s["start"], "end": s["end"]} for s in video_segments + ] + return self._create_entry( + video_id, metadata, merged_summary, demographics, segments_processed + ) + + def _generate_segment_summary( + self, + segment_info: Dict, + audio_path: Optional[Path], + transcript_text: str, + metadata: Dict[str, Any], + ) -> Dict[str, Any]: + """MAP: Generate summary for one segment with retry on JSON parse.""" + seg_num = segment_info["segment_number"] + print( + f"Generating summary for segment {seg_num} " + f"({segment_info['start']}s-{segment_info['end']}s)" + ) + max_attempts = 2 - + for attempt in range(max_attempts): try: - prompt = get_map_prompt(segment_info, metadata, transcript_text, self.config) - + prompt = get_map_prompt( + segment_info, metadata, transcript_text, self.config + ) + if attempt > 0: - prompt += "\n\nPREVIOUS ATTEMPT RETURNED INVALID JSON. 
Requirements:\n" - prompt += "- Use commas between all object properties except the last\n" - prompt += "- Use commas between all array elements except the last\n" + prompt += ( + "\n\nPREVIOUS ATTEMPT RETURNED INVALID JSON. Requirements:\n" + ) + prompt += ( + "- Use commas between all object properties except the last\n" + ) + prompt += ( + "- Use commas between all array elements except the last\n" + ) prompt += "- Use double quotes for all strings\n" prompt += "- No trailing commas before closing brackets\n" prompt += "- Escape special characters in strings\n" - + media_files = [] - seg_path = segment_info['segment_path'] + seg_path = segment_info["segment_path"] if seg_path.exists(): - if seg_path.suffix.lower() in ['.mp4', '.avi', '.mov', '.webm', '.mkv', '.m4v']: - media_files.append(('video', seg_path)) + if seg_path.suffix.lower() in [ + ".mp4", + ".avi", + ".mov", + ".webm", + ".mkv", + ".m4v", + ]: + media_files.append(("video", seg_path)) else: - media_files.append(('audio', seg_path)) - + media_files.append(("audio", seg_path)) + if audio_path and audio_path.exists(): - media_files.append(('audio', audio_path)) - + media_files.append(("audio", audio_path)) + total_size = sum(p.stat().st_size for _, p in media_files if p.exists()) - print(f"Calling Gemini with {len(media_files)} files, total {total_size/1024/1024:.1f}MB") - - response_text = self.generate_content(media_files, prompt, video_fps=0.5) + print( + f"Calling Gemini with {len(media_files)} files, " + f"total {total_size / 1024 / 1024:.1f}MB" + ) + + response_text = self.generate_content( + media_files, prompt, video_fps=0.5 + ) summary_data = self._parse_segment_summary_response(response_text) - - if summary_data.get('confidence', 0) > 0: + + if summary_data.get("confidence", 0) > 0: return summary_data - else: - logger.warning(f"Segment {seg_num} parsing failed on attempt {attempt + 1}/{max_attempts}") - if attempt < max_attempts - 1: - time.sleep(3) - + logger.warning( + f"Segment {seg_num} 
parse failed " + f"attempt {attempt + 1}/{max_attempts}" + ) + if attempt < max_attempts - 1: + time.sleep(3) + except Exception as e: - logger.error(f"Segment {seg_num} generation failed on attempt {attempt + 1}/{max_attempts}: {e}") + logger.error( + f"Segment {seg_num} failed " + f"attempt {attempt + 1}/{max_attempts}: {e}" + ) if attempt < max_attempts - 1: time.sleep(3) - - logger.error(f"Failed to generate valid summary for segment {seg_num} after {max_attempts} attempts") + + logger.error( + f"Failed to generate valid summary for segment {seg_num} after {max_attempts} attempts" + ) return self._get_default_segment_summary() - - def _merge_segment_summaries(self, - video_id: str, - metadata: Dict[str, Any], - segment_summaries: List[Dict]) -> Dict[str, Any]: - """REDUCE phase: Incrementally build summary by adding segments one at a time""" + + def _merge_segment_summaries( + self, + video_id: str, + metadata: Dict[str, Any], + segment_summaries: List[Dict], + ) -> Dict[str, Any]: + """REDUCE: Build summary by adding segments one at a time.""" print(f"Streaming accumulation for {len(segment_summaries)} segments") - + if not segment_summaries: return self._get_default_summary() - + accumulated_summary = self._initialize_summary_from_segment( segment_summaries[0], video_id, metadata ) - + for i, segment in enumerate(segment_summaries[1:], start=2): print(f"Adding segment {i}/{len(segment_summaries)} to accumulated summary") - + accumulated_summary = self._add_segment_to_summary( - accumulated_summary, - segment, - video_id, + accumulated_summary, + segment, + video_id, metadata, segment_num=i, - total_segments=len(segment_summaries) + total_segments=len(segment_summaries), ) - + time.sleep(30) - + return accumulated_summary - + def _sanitize_metadata_for_prompt(self, metadata: Dict[str, Any]) -> Dict[str, Any]: - """Sanitize metadata to avoid triggering safety filters""" + """Sanitize metadata to avoid triggering safety filters.""" sanitized = metadata.copy() - 
title = sanitized.get('title', '') - - if title and re.search(r'\d+\s*year\s*old|minor|child', title, re.IGNORECASE): - sanitized['title'] = f"[{metadata.get('topic', 'Video')} Incident]" + title = sanitized.get("title", "") + + if title and re.search(r"\d+\s*year\s*old|minor|child", title, re.IGNORECASE): + sanitized["title"] = f"[{metadata.get('topic', 'Video')} Incident]" logger.info(f"Sanitized title to avoid safety filter: {title[:50]}...") - + return sanitized - - def _initialize_summary_from_segment(self, - first_segment: Dict, - video_id: str, - metadata: Dict[str, Any]) -> Dict[str, Any]: - """Convert first segment into initial video-level summary structure with retry""" - + + def _initialize_summary_from_segment( + self, first_segment: Dict, video_id: str, metadata: Dict[str, Any] + ) -> Dict[str, Any]: + """Convert first segment to video-level summary with retry.""" safe_metadata = self._sanitize_metadata_for_prompt(metadata) max_attempts = 3 - + for attempt in range(max_attempts): try: prompt = get_initialize_prompt(first_segment, video_id, safe_metadata) response_text = self.generate_content([], prompt, video_fps=0.5) - + if not response_text or not response_text.strip(): - logger.warning(f"Empty response on attempt {attempt + 1}, using direct conversion") + logger.warning( + f"Empty response attempt {attempt + 1}, using direct conversion" + ) if attempt == max_attempts - 1: return self._direct_convert_segment_to_summary(first_segment) time.sleep(5) continue - + summary_data = self._parse_summary_response(response_text) - - if summary_data.get('confidence', 0) > 0: + + if summary_data.get("confidence", 0) > 0: return summary_data - else: - logger.warning(f"Initialize attempt {attempt + 1}/{max_attempts} failed for {video_id}") - if attempt < max_attempts - 1: - time.sleep(5) - + logger.warning( + f"Initialize attempt {attempt + 1}/{max_attempts} " + f"failed for {video_id}" + ) + if attempt < max_attempts - 1: + time.sleep(5) + except Exception as e: - 
logger.error(f"Initialize attempt {attempt + 1}/{max_attempts} failed: {e}") + logger.error( + f"Initialize attempt {attempt + 1}/{max_attempts} failed: {e}" + ) if attempt < max_attempts - 1: time.sleep(5) - - logger.error(f"Failed to initialize summary for {video_id}, using direct conversion") + + logger.error( + f"Failed to initialize summary for {video_id}, using direct conversion" + ) return self._direct_convert_segment_to_summary(first_segment) - - def _add_segment_to_summary(self, - current_summary: Dict[str, Any], - new_segment: Dict, - video_id: str, - metadata: Dict[str, Any], - segment_num: int, - total_segments: int) -> Dict[str, Any]: - """Add one new segment to the accumulated summary with retry""" - + + def _add_segment_to_summary( + self, + current_summary: Dict[str, Any], + new_segment: Dict, + video_id: str, + metadata: Dict[str, Any], + segment_num: int, + total_segments: int, + ) -> Dict[str, Any]: + """Add one new segment to the accumulated summary with retry.""" safe_metadata = self._sanitize_metadata_for_prompt(metadata) max_attempts = 3 - + for attempt in range(max_attempts): try: prompt = get_streaming_update_prompt( @@ -367,130 +426,149 @@ def _add_segment_to_summary(self, safe_metadata, segment_num, total_segments, - self.config + self.config, ) - + response_text = self.generate_content([], prompt, video_fps=0.5) - + if not response_text or not response_text.strip(): - logger.warning(f"Empty response for segment {segment_num}, attempt {attempt + 1}") + logger.warning( + f"Empty response for segment {segment_num}, attempt {attempt + 1}" + ) if attempt == max_attempts - 1: return self._programmatic_merge(current_summary, new_segment) time.sleep(5) continue - + summary_data = self._parse_summary_response(response_text) - - if summary_data.get('confidence', 0) > 0: + + if summary_data.get("confidence", 0) > 0: return summary_data - else: - logger.warning(f"Merge attempt {attempt + 1}/{max_attempts} failed for segment {segment_num}") - if 
attempt < max_attempts - 1: - time.sleep(5) - + logger.warning( + f"Merge attempt {attempt + 1}/{max_attempts} " + f"failed for segment {segment_num}" + ) + if attempt < max_attempts - 1: + time.sleep(5) + except Exception as e: logger.error(f"Merge attempt {attempt + 1}/{max_attempts} failed: {e}") if attempt < max_attempts - 1: time.sleep(5) - + logger.error(f"Failed to merge segment {segment_num}, using programmatic merge") return self._programmatic_merge(current_summary, new_segment) - + def _direct_convert_segment_to_summary(self, segment: Dict) -> Dict[str, Any]: - """Directly convert segment format to summary format without LLM""" - summary_text = segment.get('summary', '') - - lines = [l.strip().lstrip('โ€ข-* ') for l in summary_text.split('\n') if l.strip()] - bullet_points = [l for l in lines if len(l) > 10][:5] - + """Directly convert segment format to summary format without LLM.""" + summary_text = segment.get("summary", "") + + lines = [ + line.strip().lstrip("โ€ข-* ") + for line in summary_text.split("\n") + if line.strip() + ] + bullet_points = [line for line in lines if len(line) > 10][:5] + if not bullet_points and summary_text: bullet_points = [summary_text[:200]] - + return { - 'summary_short': bullet_points, - 'summary_detailed': summary_text, - 'timeline': segment.get('mini_timeline', []), - 'glossary': self._entities_to_glossary(segment.get('entities', [])), - 'confidence': segment.get('confidence', 0.5) + "summary_short": bullet_points, + "summary_detailed": summary_text, + "timeline": segment.get("mini_timeline", []), + "glossary": self._entities_to_glossary(segment.get("entities", [])), + "confidence": segment.get("confidence", 0.5), } - + def _programmatic_merge(self, current: Dict, new_segment: Dict) -> Dict: - """Programmatically merge segments without LLM""" - new_summary = new_segment.get('summary', '') + """Programmatically merge segments without LLM.""" + new_summary = new_segment.get("summary", "") if new_summary: - lines = 
[l.strip().lstrip('โ€ข-* ') for l in new_summary.split('\n') - if l.strip() and len(l) > 10] - current['summary_short'].extend(lines[:3]) - + lines = [ + line.strip().lstrip("โ€ข-* ") + for line in new_summary.split("\n") + if line.strip() and len(line) > 10 + ] + current["summary_short"].extend(lines[:3]) + if new_summary: - current['summary_detailed'] += f"\n\n{new_summary}" - - current['timeline'].extend(new_segment.get('mini_timeline', [])) - - new_terms = self._entities_to_glossary(new_segment.get('entities', [])) - existing = {t['term'].lower() for t in current.get('glossary', [])} + current["summary_detailed"] += f"\n\n{new_summary}" + + current["timeline"].extend(new_segment.get("mini_timeline", [])) + + new_terms = self._entities_to_glossary(new_segment.get("entities", [])) + existing = {t["term"].lower() for t in current.get("glossary", [])} for term in new_terms: - if term['term'].lower() not in existing: - current.setdefault('glossary', []).append(term) - existing.add(term['term'].lower()) - - current['confidence'] = min(current.get('confidence', 1.0), - new_segment.get('confidence', 1.0)) - + if term["term"].lower() not in existing: + current.setdefault("glossary", []).append(term) + existing.add(term["term"].lower()) + + current["confidence"] = min( + current.get("confidence", 1.0), new_segment.get("confidence", 1.0) + ) + return current - + def _entities_to_glossary(self, entities: List) -> List[Dict]: - """Convert entities to glossary format""" + """Convert entities to glossary format.""" glossary = [] for entity in entities: if isinstance(entity, dict): - glossary.append({ - 'term': entity.get('name', entity.get('term', '')), - 'definition': entity.get('description', entity.get('definition', '')), - 'category': entity.get('type', entity.get('category', 'entity')) - }) + glossary.append( + { + "term": entity.get("name", entity.get("term", "")), + "definition": entity.get( + "description", entity.get("definition", "") + ), + "category": entity.get( + 
"type", entity.get("category", "entity") + ), + } + ) return glossary - - def _get_video_demographics(self, - video_path: Path, - audio_path: Optional[Path], - transcript_path: Optional[Path], - metadata: Dict[str, Any], - segments_info: Optional[List[Dict]]) -> Dict[str, Any]: - """Get expanded demographics for full video""" + + def _get_video_demographics( + self, + video_path: Path, + audio_path: Optional[Path], + transcript_path: Optional[Path], + metadata: Dict[str, Any], + segments_info: Optional[List[Dict]], + ) -> Dict[str, Any]: + """Get expanded demographics for full video.""" temp_video = None - + try: - human_demographics = metadata.get('demographics_detailed_reviewed', {}) + human_demographics = metadata.get("demographics_detailed_reviewed", {}) if not human_demographics: - logger.warning(f"No human-reviewed demographics found for {metadata.get('video_id')}") - return {'demographics': [], 'total_individuals': 0, 'confidence': 0.0} - + logger.warning( + f"No human-reviewed demographics for {metadata.get('video_id')}" + ) + return {"demographics": [], "total_individuals": 0, "confidence": 0.0} + prompt = self.demographics_expander.build_expansion_prompt( - human_demographics, - segment_info=None + human_demographics, segment_info=None ) - + media_files = [] if video_path and video_path.exists(): - media_files.append(('video', video_path)) - + media_files.append(("video", video_path)) + if audio_path and audio_path.exists(): - media_files.append(('audio', audio_path)) - + media_files.append(("audio", audio_path)) + transcript_text = self._load_transcript(transcript_path) if transcript_text: prompt += f"\n\nTRANSCRIPT SUMMARY:\n{transcript_text[:2000]}" - + response_text = self.generate_content(media_files, prompt, video_fps=0.25) - demographics_data = self.demographics_expander.parse_demographics_response(response_text) - - return demographics_data - + return self.demographics_expander.parse_demographics_response(response_text) + except Exception as e: 
logger.error(f"Failed to get video demographics: {e}", exc_info=True) - return {'demographics': [], 'total_individuals': 0, 'confidence': 0.0} - + return {"demographics": [], "total_individuals": 0, "confidence": 0.0} + finally: if temp_video and temp_video != video_path and temp_video.exists(): try: @@ -498,235 +576,232 @@ def _get_video_demographics(self, logger.info(f"Cleaned up temporary video: {temp_video.name}") except Exception as e: logger.warning(f"Failed to cleanup temporary video: {e}") - + def _get_default_summary(self) -> Dict[str, Any]: - """Return default summary structure when parsing fails""" + """Return default summary structure when parsing fails.""" return { - 'summary_short': [], - 'summary_detailed': 'Summary generation failed due to parsing error', - 'timeline': [], - 'glossary': [], - 'confidence': 0.0 + "summary_short": [], + "summary_detailed": "Summary generation failed due to parsing error", + "timeline": [], + "glossary": [], + "confidence": 0.0, } - + def _get_default_segment_summary(self) -> Dict[str, Any]: - """Return default segment summary structure when parsing fails""" + """Return default segment summary structure when parsing fails.""" return { - 'segment_start': '', - 'segment_end': '', - 'summary': 'Segment summary generation failed', - 'mini_timeline': [], - 'entities': [], - 'confidence': 0.0 + "segment_start": "", + "segment_end": "", + "summary": "Segment summary generation failed", + "mini_timeline": [], + "entities": [], + "confidence": 0.0, } - + def _load_transcript(self, transcript_path: Optional[Path]) -> str: - """Load and truncate transcript if needed""" + """Load and truncate transcript if needed.""" if not transcript_path or not transcript_path.exists(): return "" - + try: - with open(transcript_path, 'r', encoding='utf-8') as f: + with open(transcript_path, "r", encoding="utf-8") as f: text = f.read() - + max_length = self.config.file_processing.max_transcript_length if len(text) > max_length: print(f"Truncating 
transcript from {len(text)} to {max_length} chars") text = text[:max_length] + "\n...[truncated]" - + return text except Exception as e: logger.warning(f"Failed to load transcript: {e}") return "" - + def _fix_common_json_errors(self, text: str) -> str: - """Attempt to fix common JSON formatting errors from Gemini""" - text = re.sub(r',(\s*[}\]])', r'\1', text) - text = re.sub(r'(\"[^\"]*\")\s+(\"\w+\"\s*:)', r'\1,\2', text) - text = re.sub(r'(\})\s*\n\s*(\{)', r'\1,\2', text) - text = re.sub(r'(\])\s*\n\s*(\[)', r'\1,\2', text) - text = re.sub(r'(\})\s*\n\s*(\[)', r'\1,\2', text) - text = re.sub(r'(\])\s*\n\s*(\{)', r'\1,\2', text) - text = re.sub(r'("\s*)\n(\s*")', r'\1,\2', text) - text = re.sub(r'(\":\s*\"[^\"]*\")\s+(\"\w+\":)', r'\1,\2', text) - text = re.sub(r'(\":\s*\d+\.?\d*)\s+(\"\w+\":)', r'\1,\2', text) - text = re.sub(r'(\":\s*(?:true|false|null))\s+(\"\w+\":)', r'\1,\2', text) - text = re.sub(r'(\])\s+(\"\w+\":)', r'\1,\2', text) - text = re.sub(r'(\})\s+(\"\w+\":)', r'\1,\2', text) - text = re.sub(r',\s*,', r',', text) - + """Attempt to fix common JSON formatting errors from Gemini.""" + text = re.sub(r",(\s*[}\]])", r"\1", text) + text = re.sub(r"(\"[^\"]*\")\s+(\"\w+\"\s*:)", r"\1,\2", text) + text = re.sub(r"(\})\s*\n\s*(\{)", r"\1,\2", text) + text = re.sub(r"(\])\s*\n\s*(\[)", r"\1,\2", text) + text = re.sub(r"(\})\s*\n\s*(\[)", r"\1,\2", text) + text = re.sub(r"(\])\s*\n\s*(\{)", r"\1,\2", text) + text = re.sub(r'("\s*)\n(\s*")', r"\1,\2", text) + text = re.sub(r"(\":\s*\"[^\"]*\")\s+(\"\w+\":)", r"\1,\2", text) + text = re.sub(r"(\":\s*\d+\.?\d*)\s+(\"\w+\":)", r"\1,\2", text) + text = re.sub(r"(\":\s*(?:true|false|null))\s+(\"\w+\":)", r"\1,\2", text) + text = re.sub(r"(\])\s+(\"\w+\":)", r"\1,\2", text) + text = re.sub(r"(\})\s+(\"\w+\":)", r"\1,\2", text) + return re.sub(r",\s*,", r",", text) + + def _extract_json_from_markdown(self, text: str) -> str: + """Extract JSON from markdown code blocks.""" + text = text.strip() + if "```json" in 
text: + start = text.find("```json") + 7 + end = text.rfind("```") + if end > start: + return text[start:end].strip() + elif "```" in text: + start = text.find("```") + 3 + end = text.rfind("```") + if end > start: + return text[start:end].strip() return text - + + def _log_json_error_details( + self, error: json.JSONDecodeError, corrected_text: str, original_text: str + ) -> None: + """Log detailed JSON error information.""" + error_line = getattr(error, "lineno", 0) + error_col = getattr(error, "colno", 0) + + if error_line > 0: + lines = corrected_text.split("\n") + start_line = max(0, error_line - 3) + end_line = min(len(lines), error_line + 2) + + logger.error(f"Error at line {error_line}, column {error_col}:") + for i in range(start_line, end_line): + if i < len(lines): + marker = " >>> " if i == error_line - 1 else " " + logger.error(f"{marker}Line {i + 1}: {lines[i][:200]}") + + debug_file = Path(f"debug_json_error_{int(time.time())}.txt") + with open(debug_file, "w") as f: + f.write("=== ORIGINAL ===\n") + f.write(original_text) + f.write("\n\n=== CORRECTED ===\n") + f.write(corrected_text) + logger.error(f"Saved problematic JSON to {debug_file}") + + def _parse_json_with_retry(self, response_text: str) -> Optional[Dict[str, Any]]: + """Parse JSON with auto-fix retry and error logging.""" + try: + return json.loads(response_text) + except json.JSONDecodeError as e: + logger.warning(f"JSON decode error, attempting auto-fix: {e}") + corrected_text = self._fix_common_json_errors(response_text) + + try: + data = json.loads(corrected_text) + logger.info("Successfully fixed JSON formatting!") + return data + except json.JSONDecodeError as e2: + logger.error(f"JSON still invalid after auto-fix: {e2}") + self._log_json_error_details(e2, corrected_text, response_text) + return None + + def _normalize_parsed_data(self, data: Any) -> Optional[Dict[str, Any]]: + """Normalize parsed data to dict format.""" + if isinstance(data, list): + if len(data) > 0 and 
isinstance(data[0], dict): + logger.warning("API returned list, extracting first element") + return data[0] + logger.error("API returned invalid list format") + return None + + if not isinstance(data, dict): + logger.error(f"Parsed data is not a dict, got {type(data)}") + return None + + return data + def _parse_summary_response(self, response_text: str) -> Dict[str, Any]: - """Parse summary response from Gemini with enhanced error handling""" + """Parse summary response from Gemini with enhanced error handling.""" try: if not response_text or not response_text.strip(): logger.warning("Empty summary response received") return self._get_default_summary() - - response_text = response_text.strip() - - if "```json" in response_text: - start = response_text.find("```json") + 7 - end = response_text.rfind("```") - if end > start: - response_text = response_text[start:end] - elif "```" in response_text: - start = response_text.find("```") + 3 - end = response_text.rfind("```") - if end > start: - response_text = response_text[start:end] - - response_text = response_text.strip() - - try: - data = json.loads(response_text) - except json.JSONDecodeError as e: - logger.warning(f"JSON decode error, attempting auto-fix: {e}") - - corrected_text = self._fix_common_json_errors(response_text) - - try: - data = json.loads(corrected_text) - logger.info("Successfully fixed JSON formatting!") - except json.JSONDecodeError as e2: - logger.error(f"JSON still invalid after auto-fix: {e2}") - - error_line = getattr(e2, 'lineno', 0) - error_col = getattr(e2, 'colno', 0) - - if error_line > 0: - lines = corrected_text.split('\n') - start_line = max(0, error_line - 3) - end_line = min(len(lines), error_line + 2) - - logger.error(f"Error at line {error_line}, column {error_col}:") - for i in range(start_line, end_line): - if i < len(lines): - marker = " >>> " if i == error_line - 1 else " " - logger.error(f"{marker}Line {i+1}: {lines[i][:200]}") - - debug_file = 
Path(f"debug_json_error_{int(time.time())}.txt") - with open(debug_file, 'w') as f: - f.write("=== ORIGINAL ===\n") - f.write(response_text) - f.write("\n\n=== CORRECTED ===\n") - f.write(corrected_text) - logger.error(f"Saved problematic JSON to {debug_file}") - - return self._get_default_summary() - - if isinstance(data, list): - if len(data) > 0 and isinstance(data[0], dict): - logger.warning("API returned list, extracting first element") - data = data[0] - else: - logger.error("API returned invalid list format") - return self._get_default_summary() - - if not isinstance(data, dict): - logger.error(f"Parsed data is not a dict, got {type(data)}") + + cleaned_text = self._extract_json_from_markdown(response_text) + data = self._parse_json_with_retry(cleaned_text) + + if data is None: + return self._get_default_summary() + + normalized_data = self._normalize_parsed_data(data) + if normalized_data is None: return self._get_default_summary() - + return { - 'summary_short': data.get('summary_short', []), - 'summary_detailed': data.get('summary_detailed', ''), - 'timeline': data.get('timeline', []), - 'glossary': data.get('glossary', []), - 'confidence': float(data.get('confidence', 0.0)) + "summary_short": normalized_data.get("summary_short", []), + "summary_detailed": normalized_data.get("summary_detailed", ""), + "timeline": normalized_data.get("timeline", []), + "glossary": normalized_data.get("glossary", []), + "confidence": float(normalized_data.get("confidence", 0.0)), } - + except Exception as e: logger.error(f"Failed to parse summary response: {e}", exc_info=True) return self._get_default_summary() - + def _parse_segment_summary_response(self, response_text: str) -> Dict[str, Any]: - """Parse segment summary response from Gemini with enhanced error handling""" + """Parse segment summary response from Gemini with enhanced error handling.""" try: if not response_text or not response_text.strip(): logger.warning("Empty response text received") return 
self._get_default_segment_summary() - - response_text = response_text.strip() - - if "```json" in response_text: - start = response_text.find("```json") + 7 - end = response_text.rfind("```") - if end > start: - response_text = response_text[start:end] - elif "```" in response_text: - start = response_text.find("```") + 3 - end = response_text.rfind("```") - if end > start: - response_text = response_text[start:end] - - response_text = response_text.strip() - - try: - data = json.loads(response_text) - except json.JSONDecodeError as e: - logger.warning(f"JSON decode error in segment, attempting auto-fix: {e}") - - corrected_text = self._fix_common_json_errors(response_text) - - try: - data = json.loads(corrected_text) - logger.info("Successfully fixed segment JSON formatting!") - except json.JSONDecodeError as e2: - logger.error(f"Segment JSON still invalid after auto-fix: {e2}") - - error_line = getattr(e2, 'lineno', 0) - if error_line > 0: - lines = corrected_text.split('\n') - start_line = max(0, error_line - 3) - end_line = min(len(lines), error_line + 2) - - logger.error(f"Error at line {error_line}:") - for i in range(start_line, end_line): - if i < len(lines): - marker = " >>> " if i == error_line - 1 else " " - logger.error(f"{marker}Line {i+1}: {lines[i][:200]}") - - return self._get_default_segment_summary() - - if isinstance(data, list): - if len(data) > 0 and isinstance(data[0], dict): - logger.warning("API returned list in segment, extracting first element") - data = data[0] - else: - logger.error("API returned invalid list format in segment") - return self._get_default_segment_summary() - - if not isinstance(data, dict): - logger.error(f"Segment data is not a dict, got {type(data)}") + + cleaned_text = self._extract_json_from_markdown(response_text) + data = self._parse_json_with_retry(cleaned_text) + + if data is None: return self._get_default_segment_summary() - + + normalized_data = self._normalize_parsed_data(data) + if normalized_data is None: + 
return self._get_default_segment_summary() + return { - 'segment_start': data.get('segment_start', ''), - 'segment_end': data.get('segment_end', ''), - 'summary': data.get('summary', ''), - 'mini_timeline': data.get('mini_timeline', []), - 'entities': data.get('entities', []), - 'confidence': float(data.get('confidence', 0.0)) + "segment_start": normalized_data.get("segment_start", ""), + "segment_end": normalized_data.get("segment_end", ""), + "summary": normalized_data.get("summary", ""), + "mini_timeline": normalized_data.get("mini_timeline", []), + "entities": normalized_data.get("entities", []), + "confidence": float(normalized_data.get("confidence", 0.0)), } - + except Exception as e: logger.error(f"Failed to parse segment summary: {e}", exc_info=True) return self._get_default_segment_summary() - - def _get_error_entry(self, metadata: Dict[str, Any], error_msg: str) -> Dict[str, Any]: - """Return error entry structure""" - video_id = metadata.get('video_id', metadata.get('video_number', 'unknown')) + + def _create_entry( + self, + video_id: str, + metadata: Dict[str, Any], + summary_data: Dict[str, Any], + demographics: Dict[str, Any], + segments_processed: Optional[List[Dict]] = None, + ) -> Dict[str, Any]: + """Assemble a Task 1 VQA entry from summary and demographics data.""" + return { + "video_id": video_id, + "video_number": metadata.get("video_number", video_id), + "duration_seconds": metadata.get("duration_seconds", 0), + "segments_processed": segments_processed, + "summary_short": summary_data.get("summary_short", []), + "summary_detailed": summary_data.get("summary_detailed", ""), + "timeline": summary_data.get("timeline", []), + "glossary": summary_data.get("glossary", []), + "demographics": demographics.get("demographics", []), + "confidence": summary_data.get("confidence", 0.0), + } + + def _get_error_entry( + self, metadata: Dict[str, Any], error_msg: str + ) -> Dict[str, Any]: + """Return error entry structure.""" + video_id = 
metadata.get("video_id", metadata.get("video_number", "unknown")) return { - 'video_id': video_id, - 'video_number': metadata.get('video_number', video_id), - 'duration_seconds': metadata.get('duration_seconds', 0), - 'segments_processed': None, - 'summary_short': [], - 'summary_detailed': f'Error: {error_msg}', - 'timeline': [], - 'glossary': [], - 'demographics': [], - 'confidence': 0.0, - 'error': True - } \ No newline at end of file + "video_id": video_id, + "video_number": metadata.get("video_number", video_id), + "duration_seconds": metadata.get("duration_seconds", 0), + "segments_processed": None, + "summary_short": [], + "summary_detailed": f"Error: {error_msg}", + "timeline": [], + "glossary": [], + "demographics": [], + "confidence": 0.0, + "error": True, + } diff --git a/sonic-o1/04_vqa_generation/models/temporal_localization_model.py b/sonic-o1/04_vqa_generation/models/temporal_localization_model.py index 86156ed..01ceeea 100644 --- a/sonic-o1/04_vqa_generation/models/temporal_localization_model.py +++ b/sonic-o1/04_vqa_generation/models/temporal_localization_model.py @@ -1,24 +1,35 @@ +"""temporal_localization_model.py. + +Task 3: Temporal Action Localization (Open-Ended) Generation Model. 
+ +Author: SONIC-O1 Team """ -Task 3: Temporal Action Localization (Open-Ended) Generation Model -""" + +import base64 import json import logging -import time -from pathlib import Path -from typing import Dict, List, Any, Optional -from .base_gemini import BaseGeminiClient -from utils.video_segmenter import VideoSegmenter -from utils.demographics_expander import DemographicsExpander -from utils.frame_sampler import FrameSampler -from prompts.temporal_localization_prompts import get_temporal_localization_prompt -import base64 -import openai import os +import re import subprocess -import tempfile -import shutil -from prompts.temporal_judge_prompts import build_validation_prompt, build_batch_validation_system_prompt import sys +import tempfile +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import openai +from prompts.temporal_judge_prompts import ( + BATCH_VALIDATION_SYSTEM_PROMPT, + build_validation_prompt, +) +from prompts.temporal_localization_prompts import get_temporal_localization_prompt +from utils.demographics_expander import DemographicsExpander +from utils.frame_sampler import FrameSampler +from utils.video_segmenter import VideoSegmenter + +from .base_gemini import BaseGeminiClient + + logging.basicConfig( level=logging.INFO, # show info and above format="%(message)s", @@ -28,37 +39,51 @@ logger = logging.getLogger(__name__) + class TemporalLocalizationModel(BaseGeminiClient): - """Generate segment-level temporal localization VQA entries""" - - def __init__(self, config): + """Generate segment-level temporal localization VQA entries.""" + + def __init__(self, config, dry_run: bool = False): """ Initialize Temporal Localization model. - + Args: config: Configuration object + dry_run: If True, skip API calls and return stub responses. 
""" - super().__init__(config) + super().__init__(config, dry_run=dry_run) self.config = config self.segmenter = VideoSegmenter(config) self.demographics_expander = DemographicsExpander(config) - self.frame_sampler = FrameSampler(config) - - # Get questions per segment from config - self.questions_per_segment = int(config.temporal_localization.questions_per_segment) - - # Retry configuration - self.max_retries = getattr(config.temporal_localization, 'max_retries', 3) - self.retry_delay = getattr(config.temporal_localization, 'retry_delay', 2) # seconds - # Initialize GPT-4V judge + self.frame_sampler = FrameSampler(config) + + self.questions_per_segment = int( + config.temporal_localization.questions_per_segment + ) + + self.max_retries = getattr(config.temporal_localization, "max_retries", 3) + self.retry_delay = getattr(config.temporal_localization, "retry_delay", 2) + + if dry_run: + self.judge_enabled = False + self.openai_client = None + logger.info("[DRY-RUN] GPT-4V judge disabled") + return + try: - api_key = os.getenv('OPENAI_API_KEY') + api_key = os.getenv("OPENAI_API_KEY") if api_key: self.openai_client = openai.OpenAI(api_key=api_key) - self.judge_enabled = getattr(config.temporal_localization, 'judge_enabled', True) - self.judge_model = getattr(config.temporal_localization, 'judge_model', 'gpt-4o') - self.judge_frame_count = getattr(config.temporal_localization, 'judge_frame_count', 32) - self.temp_frames_dir = Path(tempfile.mkdtemp(prefix='temporal_judge_')) + self.judge_enabled = getattr( + config.temporal_localization, "judge_enabled", True + ) + self.judge_model = getattr( + config.temporal_localization, "judge_model", "gpt-4o" + ) + self.judge_frame_count = getattr( + config.temporal_localization, "judge_frame_count", 32 + ) + self.temp_frames_dir = Path(tempfile.mkdtemp(prefix="temporal_judge_")) logger.info(f"โœ“ GPT-4V judge enabled (model: {self.judge_model})") else: self.judge_enabled = False @@ -67,82 +92,97 @@ def __init__(self, config): 
logger.warning(f"GPT-4V judge initialization failed: {e}") self.judge_enabled = False - def process_video(self, - video_path: Path, - audio_path: Optional[Path], - transcript_path: Optional[Path], - metadata: Dict[str, Any]) -> List[Dict[str, Any]]: + def process_video( + self, + video_path: Path, + audio_path: Optional[Path], + transcript_path: Optional[Path], + metadata: Dict[str, Any], + ) -> List[Dict[str, Any]]: """ Process a video and generate temporal localization VQA entries. - + Args: - video_path: Path to video file - audio_path: Path to audio file (optional) - transcript_path: Path to transcript/caption file (optional) - metadata: Video metadata from metadata_enhanced.json - - Returns: - List of VQA entry dicts for Task 3 (one entry per segment, each with multiple questions) + video_path: Path to video file. + audio_path: Path to audio file (optional). + transcript_path: Path to transcript/caption file (optional). + metadata: Video metadata from metadata_enhanced.json. + + Returns + ------- + List of VQA entry dicts for Task 3 (one per segment, multi-Q). 
""" # Track segments to cleanup AFTER processing completes video_segments = None audio_segments = None - + try: - video_id = metadata.get('video_id', metadata.get('video_number', 'unknown')) - duration = metadata.get('duration_seconds', 0) - - logger.info(f"Processing video {video_id} for temporal localization (duration: {duration}s)") - + video_id = metadata.get("video_id", metadata.get("video_number", "unknown")) + duration = metadata.get("duration_seconds", 0) + + logger.info( + f"Processing video {video_id} for temporal localization " + f"(duration: {duration}s)" + ) + # Always segment videos into 3-minute chunks video_segments = self.segmenter.segment_video( - video_path, - duration, - task_type='temporal_localization' + video_path, duration, task_type="temporal_localization" ) audio_segments = None if audio_path and audio_path.exists(): audio_segments = self.segmenter.segment_audio( - audio_path, - duration, - task_type='temporal_localization' + audio_path, duration, task_type="temporal_localization" ) - - logger.info(f"Created {len(video_segments)} segments for temporal localization") - + + logger.info( + f"Created {len(video_segments)} segments for temporal localization" + ) + # Generate temporal questions for each segment temporal_entries = [] for i, seg in enumerate(video_segments): try: # Get corresponding audio segment - audio_seg_path = audio_segments[i]['segment_path'] if audio_segments else None - + audio_seg_path = ( + audio_segments[i]["segment_path"] if audio_segments else None + ) + # Extract transcript for this segment transcript_text = "" if transcript_path and transcript_path.exists(): transcript_text = self.segmenter.extract_transcript_segment( - transcript_path, seg['start'], seg['end'] + transcript_path, seg["start"], seg["end"] ) - + # Generate temporal questions for this segment with retry segment_entry = self._generate_temporal_questions_with_retry( seg, audio_seg_path, transcript_text, metadata ) - + if segment_entry: 
temporal_entries.append(segment_entry) - + except Exception as e: - logger.error(f"Failed to generate temporal questions for segment {i}: {e}") + logger.error( + f"Failed to generate temporal questions for segment {i}: {e}" + ) continue - - logger.info(f"Generated {len(temporal_entries)} segment entries (with {sum(e['num_questions'] for e in temporal_entries)} total questions) for video {video_id}") + + nq = sum(e["num_questions"] for e in temporal_entries) + logger.info( + f"Generated {len(temporal_entries)} segment entries " + f"({nq} questions) for video {video_id}" + ) return temporal_entries - + except Exception as e: - logger.error(f"Error processing video {video_id} for temporal localization: {e}", exc_info=True) + logger.error( + f"Error processing video {video_id} for temporal localization: {e}", + exc_info=True, + ) return [] - + finally: # This ensures segments exist during GPT-4V validation logger.info("====== CLEANUP FINALLY BLOCK STARTING ======") @@ -162,470 +202,655 @@ def process_video(self, except Exception as e: logger.warning(f"Failed to cleanup frame sampler: {e}") - - def _generate_temporal_questions_with_retry(self, - segment_info: Dict, - audio_path: Optional[Path], - transcript_text: str, - metadata: Dict[str, Any]) -> Optional[Dict[str, Any]]: - """Generate temporal questions with retry logic""" - video_id = metadata.get('video_id', metadata.get('video_number', 'unknown')) - seg_num = segment_info['segment_number'] - + def _generate_temporal_questions_with_retry( + self, + segment_info: Dict, + audio_path: Optional[Path], + transcript_text: str, + metadata: Dict[str, Any], + ) -> Optional[Dict[str, Any]]: + """Generate temporal questions with retry logic.""" + video_id = metadata.get("video_id", metadata.get("video_number", "unknown")) + seg_num = segment_info["segment_number"] + last_error = None - + for attempt in range(self.max_retries): try: - logger.info(f"Attempt {attempt + 1}/{self.max_retries}: Generating temporal questions for 
{video_id} segment {seg_num}") - + logger.info( + f"Attempt {attempt + 1}/{self.max_retries}: " + f"Generating temporal Qs for {video_id} seg {seg_num}" + ) + result = self._generate_temporal_questions_for_segment( segment_info, audio_path, transcript_text, metadata ) - + if result: - logger.info(f"โœ“ Successfully generated temporal questions for segment {seg_num} on attempt {attempt + 1}") + logger.info( + f"โœ“ Generated temporal Qs for segment {seg_num} " + f"on attempt {attempt + 1}" + ) return result - else: - last_error = "Empty result" - + last_error = "Empty result" + except Exception as e: last_error = str(e) - logger.warning(f"Attempt {attempt + 1} failed for segment {seg_num}: {e}") - + logger.warning( + f"Attempt {attempt + 1} failed for segment {seg_num}: {e}" + ) + if attempt < self.max_retries - 1: delay = self.retry_delay * (attempt + 1) # Exponential backoff logger.info(f"Retrying in {delay}s...") time.sleep(delay) - - logger.error(f"โœ— Failed to generate temporal questions for segment {seg_num} after {self.max_retries} attempts. Last error: {last_error}") + + logger.error( + f"โœ— Failed temporal Qs for segment {seg_num} after " + f"{self.max_retries} attempts. 
Last error: {last_error}" + ) return None - - def _generate_temporal_questions_for_segment(self, - segment_info: Dict, - audio_path: Optional[Path], - transcript_text: str, - metadata: Dict[str, Any]) -> Optional[Dict[str, Any]]: - """Generate multiple temporal questions for a video segment - returns ONE entry with questions list""" - video_id = metadata.get('video_id', metadata.get('video_number', 'unknown')) - seg_num = segment_info['segment_number'] - - logger.info(f"Generating {self.questions_per_segment} temporal questions for {video_id} segment {seg_num}") - + + def _prepare_segment_media_files( + self, segment_info: Dict, audio_path: Optional[Path] + ) -> List[tuple]: + """Prepare media files list for segment processing.""" + media_files = [] + seg_path = segment_info["segment_path"] + if seg_path.exists(): + video_extensions = [".mp4", ".avi", ".mov", ".webm", ".mkv", ".m4v"] + if seg_path.suffix.lower() in video_extensions: + media_files.append(("video", seg_path)) + else: + media_files.append(("audio", seg_path)) + + if audio_path and audio_path.exists(): + media_files.append(("audio", audio_path)) + + return media_files + + def _build_question_entry( + self, q_data: Dict, q_idx: int, segment_start: float + ) -> Dict[str, Any]: + """Build a single question entry with absolute timestamps.""" + question_id = f"{(q_idx + 1):03d}" + answer = q_data.get("answer", {}) + answer_start_relative = answer.get("start_s") + answer_end_relative = answer.get("end_s") + + answer_start_absolute = None + answer_end_absolute = None + if answer_start_relative is not None: + answer_start_absolute = round(segment_start + answer_start_relative, 3) + if answer_end_relative is not None: + answer_end_absolute = round(segment_start + answer_end_relative, 3) + + return { + "question_id": question_id, + "question": q_data.get("question", ""), + "temporal_relation": q_data.get("temporal_relation", "after"), + "anchor_event": q_data.get("anchor_event", ""), + "target_event": 
q_data.get("target_event", ""), + "answer": { + "start_s": answer_start_absolute, + "end_s": answer_end_absolute, + }, + "requires_audio": q_data.get("requires_audio", False), + "confidence": q_data.get("confidence", 0.0), + "abstained": q_data.get("abstained", False), + "rationale_model": q_data.get("rationale_model", ""), + } + + def _build_questions_list( + self, questions_data: List[Dict], segment_start: float + ) -> List[Dict[str, Any]]: + """Build questions list with IDs and absolute timestamps.""" + questions_list = [] + for q_idx, q_data in enumerate(questions_data): + question_entry = self._build_question_entry(q_data, q_idx, segment_start) + questions_list.append(question_entry) + return questions_list + + def _calculate_segment_confidence(self, questions_list: List[Dict]) -> float: + """Calculate segment-level confidence as average of questions.""" + if not questions_list: + return 0.0 + return sum(q["confidence"] for q in questions_list) / len(questions_list) + + def _setup_judge_validation( + self, segment_info: Dict, video_id: str, seg_num: int + ) -> List[Path]: + """Set up GPT-4V judge validation by sampling frames.""" + video_path = segment_info.get("segment_path") + if not video_path or not Path(video_path).exists(): + return [] + + logger.info(f"Judge video path: {video_path}") + logger.info("Judge video exists: True") + logger.info(f"Judge video size: {Path(video_path).stat().st_size} bytes") + + result = subprocess.run( + [ + "ffprobe", + "-v", + "error", + "-show_entries", + "format=duration", + "-of", + "default=noprint_wrappers=1:nokey=1", + str(video_path), + ], + capture_output=True, + text=True, + check=False, + ) + actual_duration = float(result.stdout.strip()) + frame_end = min(actual_duration, segment_info["end"] - segment_info["start"]) + + cfg = self.config.temporal_localization + return self.frame_sampler.sample_frames_from_segment( + video_path=Path(video_path), + segment_start=0.0, + segment_end=frame_end, + 
num_frames=self.judge_frame_count, + strategy=cfg.judge_frame_strategy, + ) + + def _apply_judge_validation( + self, + entry: Dict[str, Any], + segment_info: Dict, + frame_paths: List[Path], + transcript_text: str, + ) -> None: + """Apply GPT-4V judge validation to questions.""" + if not frame_paths: + return + + validated_questions, validation_stats = self._validate_questions( + entry["questions"], + { + "start": segment_info["start"], + "end": segment_info["end"], + }, + frame_paths, + transcript_text, + ) + + entry["questions"] = validated_questions + entry["num_questions"] = len(validated_questions) + entry["validation"] = validation_stats + entry["validation"]["judge_used"] = True + + if validation_stats["total"] > 0: + val_rate = validation_stats["valid"] / validation_stats["total"] + entry["confidence"] = round(entry["confidence"] * (0.7 + 0.3 * val_rate), 3) + + v, t, f = ( + validation_stats["valid"], + validation_stats["total"], + validation_stats["fixed"], + ) + logger.info(f"โœ“ {v}/{t} valid, {f} fixed") + + def _cleanup_frame_paths(self, frame_paths: List[Path]) -> None: + """Clean up temporary frame files.""" + for fp in frame_paths: + try: + if fp.exists(): + fp.unlink() + except Exception as e: + logger.debug(f"Failed to delete frame {fp}: {e}") + + def _generate_temporal_questions_for_segment( + self, + segment_info: Dict, + audio_path: Optional[Path], + transcript_text: str, + metadata: Dict[str, Any], + ) -> Optional[Dict[str, Any]]: + """ + Generate temporal questions for segment. + + Returns one entry with questions list. 
+ """ + video_id = metadata.get("video_id", metadata.get("video_number", "unknown")) + seg_num = segment_info["segment_number"] + + logger.info( + f"Generating {self.questions_per_segment} temporal Qs " + f"for {video_id} segment {seg_num}" + ) + try: - # Build temporal localization prompt prompt = get_temporal_localization_prompt( - segment_info, - metadata, - transcript_text, - self.config + segment_info, metadata, transcript_text, self.config ) - - # Prepare media files - media_files = [] - seg_path = segment_info['segment_path'] - if seg_path.exists(): - # Determine if it's video or audio - if seg_path.suffix.lower() in ['.mp4', '.avi', '.mov', '.webm', '.mkv', '.m4v']: - media_files.append(('video', seg_path)) - else: - media_files.append(('audio', seg_path)) - - if audio_path and audio_path.exists(): - media_files.append(('audio', audio_path)) - - # Generate temporal questions - response_text = self.generate_content(media_files, prompt,video_fps=0.5) - print(f"๐Ÿ” Segment {seg_num}: start={segment_info['start']}, end={segment_info['end']}, duration={segment_info['end']-segment_info['start']}") - print(f"๐Ÿ“น Segment file path: {seg_path}") - print(f"๐Ÿ“น Segment file exists: {seg_path.exists()}") - if seg_path.exists(): - import subprocess - result = subprocess.run(['ffprobe', '-v', 'error', '-show_entries', - 'format=duration', '-of', - 'default=noprint_wrappers=1:nokey=1', str(seg_path)], - capture_output=True, text=True) - actual_duration = float(result.stdout.strip()) - print(f"๐Ÿ“น ACTUAL segment file duration: {actual_duration}s") - print(f"๐Ÿ“น EXPECTED duration: {segment_info['duration']}s") - - print(f"๐Ÿ“ RAW MODEL RESPONSE:\n{response_text}") - - + + media_files = self._prepare_segment_media_files(segment_info, audio_path) + response_text = self.generate_content(media_files, prompt, video_fps=0.5) + + seg_path = segment_info["segment_path"] + dur = segment_info["end"] - segment_info["start"] + logger.info( + f"Segment {seg_num}: 
start={segment_info['start']}, " + f"end={segment_info['end']}, duration={dur}" + ) + questions_data = self._parse_temporal_response(response_text) - + for i, q in enumerate(questions_data): - logger.info(f"Question {i+1}: start_s={q.get('answer',{}).get('start_s')}, end_s={q.get('answer',{}).get('end_s')}") - - # Get demographics for this segment (shared across all questions) + ans = q.get("answer", {}) + logger.info( + f"Question {i + 1}: start_s={ans.get('start_s')}, " + f"end_s={ans.get('end_s')}" + ) + demographics_data = self._get_segment_demographics( segment_info, seg_path, audio_path, transcript_text, metadata ) - - # Build questions list with IDs and absolute timestamps - questions_list = [] - segment_start = segment_info['start'] # Get segment start time - - for q_idx, q_data in enumerate(questions_data): - # Generate question ID with zero-padded index - question_id = f"{(q_idx+1):03d}" # 001, 002, 003 - - # Get answer times (relative to segment) - answer_start_relative = q_data.get('answer', {}).get('start_s') - answer_end_relative = q_data.get('answer', {}).get('end_s') - - # Convert to absolute times (relative to full video) - answer_start_absolute = None - answer_end_absolute = None - if answer_start_relative is not None: - answer_start_absolute = round(segment_start + answer_start_relative, 3) - if answer_end_relative is not None: - answer_end_absolute = round(segment_start + answer_end_relative, 3) - - question_entry = { - 'question_id': question_id, - 'question': q_data.get('question', ''), - 'temporal_relation': q_data.get('temporal_relation', 'after'), - 'anchor_event': q_data.get('anchor_event', ''), - 'target_event': q_data.get('target_event', ''), - 'answer': { - 'start_s': answer_start_absolute, # Absolute timestamp - 'end_s': answer_end_absolute # Absolute timestamp - }, - 'requires_audio': q_data.get('requires_audio', False), - 'confidence': q_data.get('confidence', 0.0), - 'abstained': q_data.get('abstained', False), - 'rationale_model': 
q_data.get('rationale_model', '') - } - - questions_list.append(question_entry) - - # Calculate segment-level confidence (average of all questions) - segment_confidence = 0.0 - if questions_list: - segment_confidence = sum(q['confidence'] for q in questions_list) / len(questions_list) - - # Build single entry for this segment + + segment_start = segment_info["start"] + questions_list = self._build_questions_list(questions_data, segment_start) + segment_confidence = self._calculate_segment_confidence(questions_list) + entry = { - 'video_id': video_id, - 'video_number': metadata.get('video_number', video_id), - 'segment': { - 'start': segment_info['start'], - 'end': segment_info['end'] - }, - 'questions': questions_list, # List of questions - 'num_questions': len(questions_list), - 'confidence': round(segment_confidence, 3), # Segment-level confidence - - # Segment-level demographics (shared across all questions) - 'demographics': demographics_data.get('demographics', []), - 'demographics_total_individuals': demographics_data.get('total_individuals', 0), - 'demographics_confidence': demographics_data.get('confidence', 0.0), - 'demographics_explanation': demographics_data.get('explanation', '') + "video_id": video_id, + "video_number": metadata.get("video_number", video_id), + "segment": {"start": segment_info["start"], "end": segment_info["end"]}, + "questions": questions_list, + "num_questions": len(questions_list), + "confidence": round(segment_confidence, 3), + "demographics": demographics_data.get("demographics", []), + "demographics_total_individuals": demographics_data.get( + "total_individuals", 0 + ), + "demographics_confidence": demographics_data.get("confidence", 0.0), + "demographics_explanation": demographics_data.get("explanation", ""), } if self.judge_enabled: logger.info(f"[{video_id} seg {seg_num}] Validating with GPT-4V...") frame_paths = [] try: - video_path = segment_info.get('segment_path') - logger.info(f"๐Ÿ” Judge video path: {video_path}") - 
logger.info(f"๐Ÿ” Judge video exists: {Path(video_path).exists() if video_path else False}") - if video_path and Path(video_path).exists(): - logger.info(f"๐Ÿ” Judge video size: {Path(video_path).stat().st_size} bytes") - - if video_path and Path(video_path).exists(): - import subprocess - result = subprocess.run(['ffprobe', '-v', 'error', '-show_entries', - 'format=duration', '-of', - 'default=noprint_wrappers=1:nokey=1', str(video_path)], - capture_output=True, text=True) - actual_duration = float(result.stdout.strip()) - frame_end = min(actual_duration, segment_info['end'] - segment_info['start']) - - frame_paths = self.frame_sampler.sample_frames_from_segment( - video_path=Path(video_path), - segment_start=0.0, # Segment file always starts at 0 - segment_end=frame_end,# Duration of segment file - num_frames=self.judge_frame_count, - strategy=self.config.temporal_localization.judge_frame_strategy + frame_paths = self._setup_judge_validation( + segment_info, video_id, seg_num + ) + if frame_paths: + self._apply_judge_validation( + entry, segment_info, frame_paths, transcript_text ) - - if frame_paths: - validated_questions, validation_stats = self._validate_questions( - entry['questions'], - {'start': segment_info['start'], 'end': segment_info['end']}, - frame_paths, - transcript_text - ) - entry['questions'] = validated_questions - entry['num_questions'] = len(validated_questions) - entry['validation'] = validation_stats - entry['validation']['judge_used'] = True - - if validation_stats['total'] > 0: - val_rate = validation_stats['valid'] / validation_stats['total'] - entry['confidence'] = round(entry['confidence'] * (0.7 + 0.3 * val_rate), 3) - - logger.info(f"โœ“ {validation_stats['valid']}/{validation_stats['total']} valid, {validation_stats['fixed']} fixed") except Exception as e: logger.error(f"Validation error: {e}") - finally: - # Clean up ONLY the frames from this segment - for fp in frame_paths: - try: - if fp.exists(): - fp.unlink() - except 
Exception as e: - logger.debug(f"Failed to delete frame {fp}: {e}") + self._cleanup_frame_paths(frame_paths) return entry - + except Exception as e: - logger.error(f"Failed to generate temporal questions for segment {seg_num}: {e}", exc_info=True) - return None # Return None to trigger retry - - def _get_segment_demographics(self, - segment_info: Dict, - video_path: Path, - audio_path: Optional[Path], - transcript_text: str, - metadata: Dict[str, Any]) -> Dict[str, Any]: - """Get expanded demographics for a specific segment (same as MCQ)""" + logger.error( + f"Failed to generate temporal questions for segment {seg_num}: {e}", + exc_info=True, + ) + return None + + def _get_segment_demographics( + self, + segment_info: Dict, + video_path: Path, + audio_path: Optional[Path], + transcript_text: str, + metadata: Dict[str, Any], + ) -> Dict[str, Any]: + """Get expanded demographics for a specific segment (same as MCQ).""" try: # Get human-reviewed demographics from metadata (video-level) - human_demographics = metadata.get('demographics_detailed_reviewed', {}) + human_demographics = metadata.get("demographics_detailed_reviewed", {}) if not human_demographics: - logger.warning(f"No human-reviewed demographics found for {metadata.get('video_id')}") + logger.warning( + f"No human-reviewed demographics for {metadata.get('video_id')}" + ) return { - 'demographics': [], - 'total_individuals': 0, - 'confidence': 0.0, - 'explanation': 'No human-reviewed demographics available' + "demographics": [], + "total_individuals": 0, + "confidence": 0.0, + "explanation": "No human-reviewed demographics available", } - + # Build expansion prompt (segment-level) prompt = self.demographics_expander.build_expansion_prompt( human_demographics, - segment_info={'start': segment_info['start'], 'end': segment_info['end']} + segment_info={ + "start": segment_info["start"], + "end": segment_info["end"], + }, ) - + # Prepare media files media_files = [] if video_path and video_path.exists(): - if 
video_path.suffix.lower() in ['.mp4', '.avi', '.mov', '.webm', '.mkv', '.m4v']: - media_files.append(('video', video_path)) + if video_path.suffix.lower() in [ + ".mp4", + ".avi", + ".mov", + ".webm", + ".mkv", + ".m4v", + ]: + media_files.append(("video", video_path)) else: - media_files.append(('audio', video_path)) - + media_files.append(("audio", video_path)) + if audio_path and audio_path.exists(): - media_files.append(('audio', audio_path)) - + media_files.append(("audio", audio_path)) + # Add transcript context to prompt if transcript_text: prompt += f"\n\nSEGMENT TRANSCRIPT:\n{transcript_text[:1000]}" - + # Generate demographics - response_text = self.generate_content(media_files, prompt,video_fps=0.25) - demographics_data = self.demographics_expander.parse_demographics_response(response_text) - + response_text = self.generate_content(media_files, prompt, video_fps=0.25) + demographics_data = self.demographics_expander.parse_demographics_response( + response_text + ) + # Ensure all fields are present - if 'explanation' not in demographics_data: - demographics_data['explanation'] = 'No explanation provided' - + if "explanation" not in demographics_data: + demographics_data["explanation"] = "No explanation provided" + return demographics_data - + except Exception as e: logger.error(f"Failed to get segment demographics: {e}") return { - 'demographics': [], - 'total_individuals': 0, - 'confidence': 0.0, - 'explanation': f'Error generating demographics: {str(e)}' + "demographics": [], + "total_individuals": 0, + "confidence": 0.0, + "explanation": f"Error generating demographics: {str(e)}", } - def _validate_with_gpt4v(self, question: Dict, segment_info: Dict, - frame_paths: List[Path], transcript_text: str) -> Dict: - """Use GPT-4V to validate question""" + def _validate_with_gpt4v( + self, + question: Dict, + segment_info: Dict, + frame_paths: List[Path], + transcript_text: str, + ) -> Dict: + """Use GPT-4V to validate question.""" try: # Encode frames 
image_contents = [] for i, frame_path in enumerate(frame_paths): - with open(frame_path, 'rb') as f: - image_data = base64.b64encode(f.read()).decode('utf-8') - image_contents.append({ - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{image_data}", "detail": "high"} - }) - image_contents.append({"type": "text", "text": f"[Frame {i+1}/{len(frame_paths)}]"}) - + with open(frame_path, "rb") as f: + image_data = base64.b64encode(f.read()).decode("utf-8") + image_contents.append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_data}", + "detail": "high", + }, + } + ) + image_contents.append( + {"type": "text", "text": f"[Frame {i + 1}/{len(frame_paths)}]"} + ) + # Build prompt using imported function prompt = build_validation_prompt(question, segment_info, transcript_text) - + # Call GPT-4V response = self.openai_client.chat.completions.create( model=self.judge_model, messages=[ - {"role": "system", "content": build_batch_validation_system_prompt()}, - {"role": "user", "content": [{"type": "text", "text": prompt}] + image_contents} + { + "role": "system", + "content": BATCH_VALIDATION_SYSTEM_PROMPT, + }, + { + "role": "user", + "content": [{"type": "text", "text": prompt}] + image_contents, + }, ], max_tokens=800, temperature=0.1, - response_format={"type": "json_object"} + response_format={"type": "json_object"}, ) - + return json.loads(response.choices[0].message.content) - + except Exception as e: logger.error(f"GPT-4V error: {e}") - return {'valid': True, 'message': f'Error: {e}'} - def _convert_rationale_to_absolute(self, rationale: str, segment_start: float) -> str: - import re - + return {"valid": True, "message": f"Error: {e}"} + + def _convert_rationale_to_absolute( + self, rationale: str, segment_start: float + ) -> str: + """Convert relative timestamps in rationale to absolute timestamps.""" + def replace_timestamp(match): timestamp_str = match.group(1) try: timestamp = float(timestamp_str) absolute = 
segment_start + timestamp - + if absolute == int(absolute): return f"{int(absolute)}.0s" - else: - formatted = f"{absolute:.3f}".rstrip('0').rstrip('.') - return f"{formatted}s" - except: + formatted = f"{absolute:.3f}".rstrip("0").rstrip(".") + return f"{formatted}s" + except (ValueError, TypeError): return match.group(0) - - pattern = r'(?:^|(?<=[\s=,\(\.]))(\d+\.?\d*)s(?=[\s,\.\)\-โ€“]|$)' + + pattern = r"(?:^|(?<=[\s=,\(\.]))(\d+\.?\d*)s(?=[\s,\.\)\-โ€“]|$)" return re.sub(pattern, replace_timestamp, rationale) - def _validate_questions(self, questions: List[Dict], segment_info: Dict, - frame_paths: List[Path], transcript_text: str) -> tuple: - """Validate questions, fix absolute timestamps, drop invalid""" - segment_start = segment_info['start'] - segment_end = segment_info['end'] + + def _fix_double_converted_timestamps( + self, + start_s: float, + end_s: float, + segment_start: float, + segment_end: float, + segment_duration: float, + question_idx: int, + ) -> Optional[Tuple[float, float]]: + """Fix double-converted timestamps (absolute that was converted again).""" + if start_s > segment_end or end_s > segment_end: + test_start = start_s - segment_start + test_end = end_s - segment_start + + if ( + 0 <= test_start <= segment_duration + and 0 <= test_end <= segment_duration + ): + logger.warning( + f"Q{question_idx + 1}: Double-converted [{start_s}, " + f"{end_s}] โ†’ relative [{test_start:.1f}, {test_end:.1f}]" + ) + return test_start, test_end + return None + + def _fix_absolute_timestamps( + self, + start_s: float, + end_s: float, + segment_start: float, + segment_end: float, + question_idx: int, + ) -> Optional[Tuple[float, float]]: + """Convert absolute timestamps to relative if in segment range.""" + if ( + start_s >= segment_start + and start_s < segment_end + and end_s > segment_start + and end_s <= segment_end + ): + logger.warning( + f"Q{question_idx + 1}: Absolute timestamps [{start_s}, " + f"{end_s}] โ†’ converting to relative" + ) + return 
start_s - segment_start, end_s - segment_start + return None + + def _check_relative_bounds( + self, + start_s: float, + end_s: float, + segment_duration: float, + question_idx: int, + ) -> bool: + """Check if relative timestamps are within valid bounds.""" + if start_s < 0 or end_s > segment_duration or start_s >= end_s: + logger.warning( + f"Q{question_idx + 1}: Out of bounds [{start_s:.1f}, " + f"{end_s:.1f}] (segment: 0-{segment_duration}s)" + ) + return False + return True + + def _apply_gpt4v_corrections( + self, question: Dict, gpt_result: Dict, question_idx: int + ) -> bool: + """Apply GPT-4V timestamp corrections if available.""" + if gpt_result.get("corrected_timestamps"): + corrected = gpt_result["corrected_timestamps"] + question["answer"]["start_s"] = corrected["start_s"] + question["answer"]["end_s"] = corrected["end_s"] + reason = gpt_result.get("correction_reason", "adjusted")[:50] + question["rationale_model"] += f" [Judge: {reason}]" + return True + return False + + def _validate_questions( + self, + questions: List[Dict], + segment_info: Dict, + frame_paths: List[Path], + transcript_text: str, + ) -> Tuple[List[Dict], Dict[str, Any]]: + """Validate questions, fix absolute timestamps, drop invalid.""" + segment_start = segment_info["start"] + segment_end = segment_info["end"] segment_duration = segment_end - segment_start - + valid_questions = [] - stats = {'total': len(questions), 'valid': 0, 'fixed': 0, 'dropped': 0, 'reasons': {}} - + stats = { + "total": len(questions), + "valid": 0, + "fixed": 0, + "dropped": 0, + "reasons": {}, + } + for i, q in enumerate(questions): try: # Skip abstained - if q.get('abstained', False): - stats['dropped'] += 1 - stats['reasons']['abstained'] = stats['reasons'].get('abstained', 0) + 1 + if q.get("abstained", False): + stats["dropped"] += 1 + stats["reasons"]["abstained"] = ( + stats["reasons"].get("abstained", 0) + 1 + ) continue - - answer = q.get('answer', {}) - start_s = answer.get('start_s') - end_s = 
answer.get('end_s') - + + answer = q.get("answer", {}) + start_s = answer.get("start_s") + end_s = answer.get("end_s") + if start_s is None or end_s is None: - stats['dropped'] += 1 - stats['reasons']['missing_timestamps'] = stats['reasons'].get('missing_timestamps', 0) + 1 + stats["dropped"] += 1 + stats["reasons"]["missing_timestamps"] = ( + stats["reasons"].get("missing_timestamps", 0) + 1 + ) continue - - # Detect if Gemini output absolute instead of relative timestamps + fixed = False - - # Strategy: Timestamps after YOUR conversion should be: - # - In range [segment_start, segment_end] for absolute - # - OR in range [0, segment_duration] for relative that wasn't converted yet - # - # We detect Gemini mistakes by checking if timestamps look like they're - # in the segment's absolute range when they should already be absolute - # (i.e., Gemini gave us absolute, we added segment_start, now they're way off) - - if segment_start > 0: # Not first segment (can't detect for segment 0) - # Check if timestamps are suspiciously high (likely double-converted) - # OR if they look like segment-relative absolute times - - # Case 1: Way too high - definitely double-converted - if start_s > segment_end or end_s > segment_end: - # These are likely double-converted: Gemini gave absolute, we added segment_start - # Try converting back - test_start = start_s - segment_start - test_end = end_s - segment_start - - # Check if this makes sense - if 0 <= test_start <= segment_duration and 0 <= test_end <= segment_duration: - logger.warning(f"Q{i+1}: Double-converted timestamps [{start_s}, {end_s}] โ†’ fixing to relative [{test_start:.1f}, {test_end:.1f}]") - start_s = test_start - end_s = test_end - q['answer']['start_s'] = start_s - q['answer']['end_s'] = end_s - q['rationale_model'] += " [Judge: fixed double-conversion]" - fixed = True - else: - # Can't fix, drop it - stats['dropped'] += 1 - stats['reasons']['out_of_bounds'] = stats['reasons'].get('out_of_bounds', 0) + 1 - 
logger.warning(f"Q{i+1}: Unfixable out of bounds [{start_s}, {end_s}]") - continue - - # Case 2: In segment absolute range - Gemini gave absolute, we converted correctly - # BUT need to convert to relative for bounds checking - elif (start_s >= segment_start and start_s < segment_end and - end_s > segment_start and end_s <= segment_end): - logger.warning(f"Q{i+1}: Absolute timestamps detected [{start_s}, {end_s}] โ†’ converting to relative") - start_s = start_s - segment_start - end_s = end_s - segment_start - q['answer']['start_s'] = start_s - q['answer']['end_s'] = end_s - q['rationale_model'] += " [Judge: absoluteโ†’relative]" + + # Fix timestamp conversion issues + if segment_start > 0: + # Try fixing double-converted timestamps + fixed_times = self._fix_double_converted_timestamps( + start_s, end_s, segment_start, segment_end, segment_duration, i + ) + if fixed_times: + start_s, end_s = fixed_times + q["answer"]["start_s"] = start_s + q["answer"]["end_s"] = end_s + q["rationale_model"] += " [Judge: fixed double-conversion]" fixed = True - - # Now check bounds on what should be RELATIVE timestamps - if start_s < 0 or end_s > segment_duration or start_s >= end_s: - stats['dropped'] += 1 - stats['reasons']['out_of_bounds'] = stats['reasons'].get('out_of_bounds', 0) + 1 - logger.warning(f"Q{i+1}: Out of bounds after fixing [{start_s:.1f}, {end_s:.1f}] (segment: 0-{segment_duration}s)") + else: + # Try fixing absolute timestamps + fixed_times = self._fix_absolute_timestamps( + start_s, end_s, segment_start, segment_end, i + ) + if fixed_times: + start_s, end_s = fixed_times + q["answer"]["start_s"] = start_s + q["answer"]["end_s"] = end_s + q["rationale_model"] += " [Judge: absoluteโ†’relative]" + fixed = True + + # Check bounds on relative timestamps + if not self._check_relative_bounds(start_s, end_s, segment_duration, i): + stats["dropped"] += 1 + stats["reasons"]["out_of_bounds"] = ( + stats["reasons"].get("out_of_bounds", 0) + 1 + ) continue - + # GPT-4V 
validation (only if bounds check passed) if self.judge_enabled and frame_paths: - gpt_result = self._validate_with_gpt4v(q, segment_info, frame_paths, transcript_text) - - if not gpt_result.get('valid', False): - stats['dropped'] += 1 - reason = gpt_result.get('reason', 'gpt4v_rejected') - stats['reasons'][reason] = stats['reasons'].get(reason, 0) + 1 - logger.warning(f"Q{i+1}: GPT-4V rejected - {gpt_result.get('message', '')[:100]}") + gpt_result = self._validate_with_gpt4v( + q, segment_info, frame_paths, transcript_text + ) + + if not gpt_result.get("valid", False): + stats["dropped"] += 1 + reason = gpt_result.get("reason", "gpt4v_rejected") + stats["reasons"][reason] = stats["reasons"].get(reason, 0) + 1 + msg = gpt_result.get("message", "")[:100] + logger.warning(f"Q{i + 1}: GPT-4V rejected - {msg}") continue - - if gpt_result.get('corrected_timestamps'): - corrected = gpt_result['corrected_timestamps'] - q['answer']['start_s'] = corrected['start_s'] - q['answer']['end_s'] = corrected['end_s'] - q['rationale_model'] += f" [Judge: {gpt_result.get('correction_reason', 'adjusted')[:50]}]" + + if self._apply_gpt4v_corrections(q, gpt_result, i): fixed = True - # After all validation passes, convert back to absolute + start_s = q["answer"]["start_s"] + end_s = q["answer"]["end_s"] + + # Convert back to absolute timestamps answer_start_absolute = round(segment_start + start_s, 3) answer_end_absolute = round(segment_start + end_s, 3) - q['answer']['start_s'] = answer_start_absolute - q['answer']['end_s'] = answer_end_absolute - q['rationale_model'] = self._convert_rationale_to_absolute(q['rationale_model'], segment_start) - # Question passed validation + q["answer"]["start_s"] = answer_start_absolute + q["answer"]["end_s"] = answer_end_absolute + q["rationale_model"] = self._convert_rationale_to_absolute( + q["rationale_model"], segment_start + ) + valid_questions.append(q) - stats['valid'] += 1 + stats["valid"] += 1 if fixed: - stats['fixed'] += 1 - 
stats['reasons']['fixed_timestamps'] = stats['reasons'].get('fixed_timestamps', 0) + 1 - + stats["fixed"] += 1 + stats["reasons"]["fixed_timestamps"] = ( + stats["reasons"].get("fixed_timestamps", 0) + 1 + ) + except Exception as e: - logger.error(f"Q{i+1} validation error: {e}") - stats['dropped'] += 1 - stats['reasons']['validation_error'] = stats['reasons'].get('validation_error', 0) + 1 - - logger.info(f"Validation: {stats['valid']} valid, {stats['fixed']} fixed, {stats['dropped']} dropped") + logger.error(f"Q{i + 1} validation error: {e}") + stats["dropped"] += 1 + stats["reasons"]["validation_error"] = ( + stats["reasons"].get("validation_error", 0) + 1 + ) + + logger.info( + f"Validation: {stats['valid']} valid, {stats['fixed']} fixed, " + f"{stats['dropped']} dropped" + ) return valid_questions, stats def _parse_temporal_response(self, response_text: str) -> List[Dict[str, Any]]: - """Parse temporal localization response from Gemini - ORIGINAL VERSION""" + """Parse temporal localization response from Gemini.""" try: response_text = response_text.strip() - + # Remove markdown code blocks if "```json" in response_text: start = response_text.find("```json") + 7 @@ -637,87 +862,96 @@ def _parse_temporal_response(self, response_text: str) -> List[Dict[str, Any]]: end = response_text.rfind("```") if end > start: response_text = response_text[start:end] - + data = json.loads(response_text.strip()) - + # Ensure it's a list if isinstance(data, dict): data = [data] - + # Validate and clean each question validated_questions = [] for i, q in enumerate(data): validated_q = self._validate_temporal_question(q, i) validated_questions.append(validated_q) - + # Ensure we have exactly the expected number of questions while len(validated_questions) < self.questions_per_segment: - validated_questions.append(self._get_default_temporal_question(len(validated_questions))) - - return validated_questions[:self.questions_per_segment] - + validated_questions.append( + 
self._get_default_temporal_question(len(validated_questions)) + ) + + return validated_questions[: self.questions_per_segment] + except json.JSONDecodeError as e: logger.error(f"Failed to parse temporal JSON: {e}") logger.debug(f"Response text: {response_text[:500]}...") - return [self._get_default_temporal_question(i) for i in range(self.questions_per_segment)] + return [ + self._get_default_temporal_question(i) + for i in range(self.questions_per_segment) + ] except Exception as e: logger.error(f"Error parsing temporal response: {e}") - return [self._get_default_temporal_question(i) for i in range(self.questions_per_segment)] - - def _validate_temporal_question(self, question_data: Dict, index: int) -> Dict[str, Any]: - """Validate and clean a temporal question""" + return [ + self._get_default_temporal_question(i) + for i in range(self.questions_per_segment) + ] + + def _validate_temporal_question( + self, question_data: Dict, index: int + ) -> Dict[str, Any]: + """Validate and clean a temporal question.""" try: # Extract answer times - answer = question_data.get('answer', {}) - start_s = answer.get('start_s') - end_s = answer.get('end_s') - + answer = question_data.get("answer", {}) + start_s = answer.get("start_s") + end_s = answer.get("end_s") + # Determine if abstained - abstained = question_data.get('abstained', False) + abstained = question_data.get("abstained", False) if start_s is None or end_s is None: abstained = True - + # Validate temporal_relation - valid_relations = ['after', 'once_finished', 'next', 'during', 'before'] - temporal_relation = question_data.get('temporal_relation', 'after') + valid_relations = ["after", "once_finished", "next", "during", "before"] + temporal_relation = question_data.get("temporal_relation", "after") if temporal_relation not in valid_relations: - logger.warning(f"Invalid temporal_relation '{temporal_relation}', defaulting to 'after'") - temporal_relation = 'after' - + logger.warning( + f"Invalid temporal_relation 
'{temporal_relation}', " + f"defaulting to 'after'" + ) + temporal_relation = "after" + return { - 'question_index': question_data.get('question_index', index), - 'question': question_data.get('question', 'When does an event occur?'), - 'temporal_relation': temporal_relation, - 'anchor_event': question_data.get('anchor_event', 'Unknown anchor'), - 'target_event': question_data.get('target_event', 'Unknown target'), - 'answer': { - 'start_s': start_s, - 'end_s': end_s - }, - 'requires_audio': bool(question_data.get('requires_audio', False)), - 'confidence': float(question_data.get('confidence', 0.0)), - 'abstained': abstained, - 'rationale_model': question_data.get('rationale_model', 'No rationale provided') + "question_index": question_data.get("question_index", index), + "question": question_data.get("question", "When does an event occur?"), + "temporal_relation": temporal_relation, + "anchor_event": question_data.get("anchor_event", "Unknown anchor"), + "target_event": question_data.get("target_event", "Unknown target"), + "answer": {"start_s": start_s, "end_s": end_s}, + "requires_audio": bool(question_data.get("requires_audio", False)), + "confidence": float(question_data.get("confidence", 0.0)), + "abstained": abstained, + "rationale_model": question_data.get( + "rationale_model", "No rationale provided" + ), } - + except Exception as e: logger.error(f"Error validating temporal question: {e}") return self._get_default_temporal_question(index) - + def _get_default_temporal_question(self, index: int) -> Dict[str, Any]: - """Return default temporal question structure when parsing fails""" + """Return default temporal question structure when parsing fails.""" return { - 'question_index': index, - 'question': 'When does an event occur in this segment?', - 'temporal_relation': 'after', - 'anchor_event': 'Unable to identify anchor', - 'target_event': 'Unable to identify target', - 'answer': { - 'start_s': None, - 'end_s': None - }, - 'requires_audio': False, - 
'confidence': 0.0, - 'abstained': True, - 'rationale_model': 'Failed to generate temporal question' - } \ No newline at end of file + "question_index": index, + "question": "When does an event occur in this segment?", + "temporal_relation": "after", + "anchor_event": "Unable to identify anchor", + "target_event": "Unable to identify target", + "answer": {"start_s": None, "end_s": None}, + "requires_audio": False, + "confidence": 0.0, + "abstained": True, + "rationale_model": "Failed to generate temporal question", + } diff --git a/sonic-o1/04_vqa_generation/models/temporal_question_judge.py b/sonic-o1/04_vqa_generation/models/temporal_question_judge.py index f231267..cbb8fd2 100644 --- a/sonic-o1/04_vqa_generation/models/temporal_question_judge.py +++ b/sonic-o1/04_vqa_generation/models/temporal_question_judge.py @@ -1,339 +1,444 @@ -""" -GPT-4V Temporal Question Judge +"""temporal_question_judge.py. + +GPT-4V Temporal Question Judge. -Validates temporal localization questions by: +Validate temporal localization questions by: 1. Checking if timestamps are within segment bounds 2. Verifying events exist in the video frames 3. Validating temporal relationships 4. Attempting to fix correctable errors + +Author: SONIC-O1 Team """ + +import base64 import json import logging -import base64 +import os from pathlib import Path -from typing import Dict, List, Any, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple + import openai -import os + logger = logging.getLogger(__name__) +VALIDATION_PROMPT_TEMPLATE = """\ +You are validating a temporal localization question generated for a video segment. + +SEGMENT INFO: +- Segment absolute time: {segment_start}s to {segment_end}s (duration: {segment_duration}s) +- All timestamps in the question should be RELATIVE to segment start (0.0s to {segment_duration}s) + +QUESTION TO VALIDATE: +{question_json} + +TRANSCRIPT (if available): +{transcript_text} + +VALIDATION CRITERIA: + +1. 
**Event Existence**: Do BOTH the anchor event and target event actually exist in the frames you see? + - Check if the described events are visible or can be inferred from the frames + - For audio events (requires_audio=true), check if transcript supports the events + +2. **Timestamp Accuracy**: Are the provided timestamps [{start_s}s, {end_s}s] reasonable? + - Timestamps should be in SEGMENT-RELATIVE time (0.0 to {segment_duration}s) + - Do the frames near the target timestamps show the target event? + - Allow ยฑ5 second tolerance for minor inaccuracies + +3. **Temporal Relationship**: Does the temporal relationship make sense? + - Relation: {temporal_relation} + - Check if anchor and target have the stated relationship + +TASKS: + +1. Determine if the question is VALID (events exist, timestamps reasonable, relation makes sense) + +2. If timestamps are slightly off but events are identifiable: + - Provide corrected timestamps if you can identify better times + - Only correct if deviation is โ‰ค5 seconds + +3. If question is invalid (events don't exist, wrong relation, timestamps way off): + - Mark as invalid and provide reason + +OUTPUT FORMAT (JSON): +{{ + "valid": true/false, + "reason": "events_not_found|invalid_relation|timestamps_way_off|other", + "message": "Detailed explanation", + "corrected_timestamps": {{ + "start_s": , + "end_s": + }}, + "correction_reason": "Brief explanation of correction" +}} + +Respond with ONLY the JSON object, no other text.\ +""" + class TemporalQuestionJudge: - """GPT-4V-based judge for validating temporal localization questions""" - + """GPT-4V-based judge for validating temporal localization questions.""" + def __init__(self, config=None): """ Initialize the temporal question judge. 
- + Args: config: Optional configuration object """ self.config = config - + # Initialize OpenAI client - api_key = os.getenv('OPENAI_API_KEY') + api_key = os.getenv("OPENAI_API_KEY") if not api_key: raise ValueError("OPENAI_API_KEY environment variable not set") - + self.client = openai.OpenAI(api_key=api_key) - + # GPT-4V model to use - self.model = getattr(config, 'judge_model', 'gpt-4o') if config else 'gpt-4o' - - # Validation thresholds - self.max_timestamp_deviation = 5.0 # seconds tolerance for fixing timestamps - + self.model = getattr(config, "judge_model", "gpt-4o") if config else "gpt-4o" + # Validation thresholds (seconds tolerance for fixing timestamps) + self.max_timestamp_deviation = 5.0 + def validate_segment_questions( self, questions: List[Dict[str, Any]], segment_info: Dict[str, Any], frame_paths: List[Path], - transcript_text: str = "" + transcript_text: str = "", ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: """ Validate all questions for a segment. - + Args: - questions: List of temporal questions to validate - segment_info: Segment metadata (start, end, duration) - frame_paths: List of sampled frame paths from segment - transcript_text: Optional transcript text for the segment - - Returns: - Tuple of (valid_questions, validation_stats) + questions: List of temporal questions to validate. + segment_info: Segment metadata (start, end, duration). + frame_paths: List of sampled frame paths from segment. + transcript_text: Optional transcript text for the segment. + + Returns + ------- + Tuple of (valid_questions, validation_stats). 
""" - segment_start = segment_info['start'] - segment_end = segment_info['end'] - segment_duration = segment_end - segment_start - - logger.info(f"Validating {len(questions)} questions for segment [{segment_start}, {segment_end}]") - + segment_start = segment_info["start"] + segment_end = segment_info["end"] + + logger.info( + f"Validating {len(questions)} questions for segment " + f"[{segment_start}, {segment_end}]" + ) + valid_questions = [] stats = { - 'total': len(questions), - 'valid': 0, - 'fixed': 0, - 'dropped': 0, - 'reasons': { - 'out_of_bounds': 0, - 'events_not_found': 0, - 'invalid_relation': 0, - 'failed_generation': 0, - 'judge_error': 0, - 'fixed_timestamps': 0, - 'abstained': 0 - } + "total": len(questions), + "valid": 0, + "fixed": 0, + "dropped": 0, + "reasons": { + "out_of_bounds": 0, + "events_not_found": 0, + "invalid_relation": 0, + "failed_generation": 0, + "judge_error": 0, + "fixed_timestamps": 0, + "abstained": 0, + }, } - + for i, question in enumerate(questions): try: - logger.info(f"Validating question {i+1}/{len(questions)}") - + logger.info(f"Validating question {i + 1}/{len(questions)}") + # Check if question is already marked as failed - if question.get('abstained', False): - stats['reasons']['abstained'] += 1 - stats['dropped'] += 1 - logger.info(f"Question {i+1} was abstained by generator, dropping") + if question.get("abstained", False): + stats["reasons"]["abstained"] += 1 + stats["dropped"] += 1 + logger.info(f"Question {i + 1} abstained by generator, dropping") continue - - if question.get('rationale_model') == 'Failed to generate temporal question': - stats['reasons']['failed_generation'] += 1 - stats['dropped'] += 1 - logger.info(f"Question {i+1} failed generation, dropping") + + if question.get("rationale_model") == ( + "Failed to generate temporal question" + ): + stats["reasons"]["failed_generation"] += 1 + stats["dropped"] += 1 + logger.info(f"Question {i + 1} failed generation, dropping") continue - + # Validate and 
potentially fix the question validation_result = self._validate_single_question( question, segment_info, frame_paths, transcript_text ) - - if validation_result['valid']: + + if validation_result["valid"]: # Question is valid (possibly after fixing) - valid_question = validation_result['question'] + valid_question = validation_result["question"] valid_questions.append(valid_question) - stats['valid'] += 1 - - if validation_result['fixed']: - stats['fixed'] += 1 - stats['reasons']['fixed_timestamps'] += 1 - logger.info(f"โœ“ Question {i+1} validated and fixed") + stats["valid"] += 1 + + if validation_result["fixed"]: + stats["fixed"] += 1 + stats["reasons"]["fixed_timestamps"] += 1 + logger.info(f"โœ“ Question {i + 1} validated and fixed") else: - logger.info(f"โœ“ Question {i+1} validated successfully") + logger.info(f"โœ“ Question {i + 1} validated successfully") else: # Question is invalid and cannot be fixed - stats['dropped'] += 1 - reason = validation_result.get('reason', 'unknown') - stats['reasons'][reason] = stats['reasons'].get(reason, 0) + 1 - logger.warning(f"โœ— Question {i+1} dropped: {reason}") - + stats["dropped"] += 1 + reason = validation_result.get("reason", "unknown") + stats["reasons"][reason] = stats["reasons"].get(reason, 0) + 1 + logger.warning(f"โœ— Question {i + 1} dropped: {reason}") + except Exception as e: - logger.error(f"Error validating question {i+1}: {e}") - stats['dropped'] += 1 - stats['reasons']['judge_error'] += 1 - - logger.info(f"Validation complete: {stats['valid']} valid, {stats['fixed']} fixed, {stats['dropped']} dropped") + logger.error(f"Error validating question {i + 1}: {e}") + stats["dropped"] += 1 + stats["reasons"]["judge_error"] += 1 + + logger.info( + f"Validation complete: {stats['valid']} valid, " + f"{stats['fixed']} fixed, {stats['dropped']} dropped" + ) return valid_questions, stats - + + def _check_absolute( + self, + start_s: float, + end_s: float, + segment_start: float, + segment_end: float, + 
segment_duration: float, + ) -> Tuple[bool, Optional[Dict[str, Any]]]: + """Determine whether timestamps are absolute (video-level) or segment-relative. + + Returns + ------- + (is_absolute, error_result): + - (True, None) โ€“ timestamps are absolute and can be converted + - (False, None) โ€“ timestamps are already valid relative values + - (False, error_dict) โ€“ timestamps are out of bounds and cannot be fixed + """ + if start_s >= segment_start and end_s <= segment_end: + logger.warning("Absolute timestamps [%s, %s], converting", start_s, end_s) + return True, None + + if start_s < 0 or end_s > segment_duration: + if ( + segment_start <= start_s <= segment_end + and segment_start <= end_s <= segment_end + ): + logger.warning( + "Timestamps [%s, %s] seem to be absolute, converting", + start_s, + end_s, + ) + return True, None + return False, { + "valid": False, + "reason": "out_of_bounds", + "message": ( + f"Timestamps [{start_s}, {end_s}] out of bounds [0, {segment_duration}]" + ), + } + + return False, None + + def _convert_to_relative( + self, + question: Dict[str, Any], + start_s: float, + end_s: float, + segment_start: float, + ) -> Dict[str, Any]: + """Convert absolute timestamps to segment-relative in-place.""" + start_rel = start_s - segment_start + end_rel = end_s - segment_start + question["answer"]["start_s"] = start_rel + question["answer"]["end_s"] = end_rel + logger.info( + "Converted: absolute [%s, %s] โ†’ relative [%.2f, %.2f]", + start_s, + end_s, + start_rel, + end_rel, + ) + return { + "valid": True, + "question": question, + "fixed": True, + "message": "Fixed absolute timestamps to relative", + } + def _validate_single_question( self, question: Dict[str, Any], segment_info: Dict[str, Any], frame_paths: List[Path], - transcript_text: str + transcript_text: str, ) -> Dict[str, Any]: """ Validate a single temporal question. 
- - Returns: - Dict with keys: - - valid: bool - - question: corrected question dict (if valid) - - fixed: bool (True if timestamps were corrected) - - reason: str (if invalid) + + Returns + ------- + Dict with keys: valid (bool), question (corrected dict if valid), + fixed (bool), reason (str if invalid). """ - segment_start = segment_info['start'] - segment_end = segment_info['end'] + segment_start = segment_info["start"] + segment_end = segment_info["end"] segment_duration = segment_end - segment_start - - answer = question.get('answer', {}) - start_s = answer.get('start_s') - end_s = answer.get('end_s') - - # Check 1: Timestamp bounds + + answer = question.get("answer", {}) + start_s = answer.get("start_s") + end_s = answer.get("end_s") + if start_s is None or end_s is None: return { - 'valid': False, - 'reason': 'out_of_bounds', - 'message': 'Missing timestamps' + "valid": False, + "reason": "out_of_bounds", + "message": "Missing timestamps", } - - # CRITICAL CHECK: Are timestamps in ABSOLUTE time instead of SEGMENT-RELATIVE time? 
- # If start_s >= segment_start, it's likely absolute time - is_absolute = False - if start_s >= segment_start and end_s <= segment_end: - # Timestamps are in absolute video time, need to convert to segment-relative - is_absolute = True - logger.warning(f"Detected absolute timestamps: [{start_s}, {end_s}], converting to relative") - - # Check if timestamps are within segment bounds (for relative time) - if not is_absolute: - if start_s < 0 or end_s > segment_duration: - # Check if adding segment_start would put them in bounds (model might have used absolute time) - if segment_start <= start_s <= segment_end and segment_start <= end_s <= segment_end: - is_absolute = True - logger.warning(f"Timestamps [{start_s}, {end_s}] seem to be absolute, converting") - else: - return { - 'valid': False, - 'reason': 'out_of_bounds', - 'message': f'Timestamps [{start_s}, {end_s}] out of segment bounds [0, {segment_duration}]' - } - - # Convert absolute to relative if needed + + is_absolute, error = self._check_absolute( + start_s, end_s, segment_start, segment_end, segment_duration + ) + if error: + return error + if is_absolute: - original_start = start_s - original_end = end_s - start_s_relative = start_s - segment_start - end_s_relative = end_s - segment_start - - # Update question with relative times - question['answer']['start_s'] = start_s_relative - question['answer']['end_s'] = end_s_relative - - logger.info(f"Converted timestamps: absolute [{original_start}, {original_end}] โ†’ relative [{start_s_relative:.2f}, {end_s_relative:.2f}]") - - return { - 'valid': True, - 'question': question, - 'fixed': True, - 'message': f'Fixed absolute timestamps to relative' - } - - # Check 2: Use GPT-4V to validate events and temporal relationship + return self._convert_to_relative(question, start_s, end_s, segment_start) + + # Use GPT-4V to validate events and temporal relationship try: gpt4v_result = self._validate_with_gpt4v( question, segment_info, frame_paths, transcript_text ) - - 
if gpt4v_result['valid']: - # GPT-4V validated the question - if gpt4v_result.get('corrected_timestamps'): - # GPT-4V suggested timestamp corrections - corrected = gpt4v_result['corrected_timestamps'] - question['answer']['start_s'] = corrected['start_s'] - question['answer']['end_s'] = corrected['end_s'] - question['rationale_model'] += f" [Judge corrected: {gpt4v_result.get('correction_reason', 'timestamps adjusted')}]" - - return { - 'valid': True, - 'question': question, - 'fixed': True, - 'message': 'GPT-4V corrected timestamps' - } - else: - return { - 'valid': True, - 'question': question, - 'fixed': False, - 'message': 'GPT-4V validated' - } - else: + + if not gpt4v_result["valid"]: return { - 'valid': False, - 'reason': gpt4v_result.get('reason', 'events_not_found'), - 'message': gpt4v_result.get('message', 'GPT-4V validation failed') + "valid": False, + "reason": gpt4v_result.get("reason", "events_not_found"), + "message": gpt4v_result.get("message", "GPT-4V validation failed"), } - + + fixed = False + message = "GPT-4V validated" + if gpt4v_result.get("corrected_timestamps"): + corrected = gpt4v_result["corrected_timestamps"] + question["answer"]["start_s"] = corrected["start_s"] + question["answer"]["end_s"] = corrected["end_s"] + correction_reason = gpt4v_result.get( + "correction_reason", "timestamps adjusted" + ) + question["rationale_model"] += ( + f" [Judge corrected: {correction_reason}]" + ) + fixed = True + message = "GPT-4V corrected timestamps" + + return { + "valid": True, + "question": question, + "fixed": fixed, + "message": message, + } + except Exception as e: logger.error(f"GPT-4V validation error: {e}") # If GPT-4V fails but timestamps are in bounds, keep the question return { - 'valid': True, - 'question': question, - 'fixed': False, - 'message': f'GPT-4V error, keeping question with valid timestamps: {e}' + "valid": True, + "question": question, + "fixed": False, + "message": ( + f"GPT-4V error, keeping question with valid timestamps: 
{e}" + ), } - - def _validate_with_gpt4v( - self, - question: Dict[str, Any], - segment_info: Dict[str, Any], - frame_paths: List[Path], - transcript_text: str - ) -> Dict[str, Any]: - """ - Use GPT-4V to validate the question against visual evidence. - - Returns: - Dict with keys: - - valid: bool - - reason: str (if invalid) - - corrected_timestamps: dict (if timestamps need correction) - - correction_reason: str + + def _create_frame_payload( + self, frame_paths: List[Path], segment_start: float + ) -> List[Dict[str, Any]]: + """Encode *frame_paths* as base64 and build image_url + caption content blocks. + + Raises + ------ + ValueError: If no frames could be encoded successfully. """ - segment_start = segment_info['start'] - segment_end = segment_info['end'] - segment_duration = segment_end - segment_start - - # Build the validation prompt - prompt = self._build_validation_prompt( - question, segment_info, transcript_text - ) - - # Prepare images image_contents = [] + total = len(frame_paths) + for i, frame_path in enumerate(frame_paths): try: - with open(frame_path, 'rb') as f: - image_data = base64.b64encode(f.read()).decode('utf-8') - - # Extract timestamp from filename if possible - frame_filename = frame_path.name + with open(frame_path, "rb") as f: + image_data = base64.b64encode(f.read()).decode("utf-8") + timestamp_info = "" - if '_t' in frame_filename: + if "_t" in frame_path.name: try: - ts_part = frame_filename.split('_t')[1].split('s')[0] + ts_part = frame_path.name.split("_t")[1].split("s")[0] absolute_time = float(ts_part) relative_time = absolute_time - segment_start - timestamp_info = f" (absolute: {absolute_time:.2f}s, relative: {relative_time:.2f}s)" - except: + timestamp_info = ( + f" (abs: {absolute_time:.2f}s, rel: {relative_time:.2f}s)" + ) + except (ValueError, IndexError): pass - - image_contents.append({ - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_data}", - "detail": "high" + + image_contents.append( 
+ { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_data}", + "detail": "high", + }, } - }) - - # Add caption for each frame - image_contents.append({ - "type": "text", - "text": f"[Frame {i+1}/{len(frame_paths)}{timestamp_info}]" - }) - + ) + image_contents.append( + {"type": "text", "text": f"[Frame {i + 1}/{total}{timestamp_info}]"} + ) + except Exception as e: - logger.error(f"Error encoding frame {frame_path}: {e}") + logger.error("Error encoding frame %s: %s", frame_path, e) continue - + if not image_contents: raise ValueError("No valid frames to send to GPT-4V") - + + return image_contents + + def _validate_with_gpt4v( + self, + question: Dict[str, Any], + segment_info: Dict[str, Any], + frame_paths: List[Path], + transcript_text: str, + ) -> Dict[str, Any]: + """ + Use GPT-4V to validate the question against visual evidence. + + Returns + ------- + Dict with keys: valid (bool), reason (str), corrected_timestamps + (dict), correction_reason (str). + """ + prompt = self._build_validation_prompt(question, segment_info, transcript_text) + image_contents = self._create_frame_payload(frame_paths, segment_info["start"]) + # Build messages messages = [ { "role": "system", - "content": "You are an expert video analyst validating temporal localization questions. You must provide responses in valid JSON format." + "content": ( + "You are an expert video analyst validating temporal " + "localization questions. Respond in valid JSON." 
+ ), }, { "role": "user", - "content": [ - {"type": "text", "text": prompt} - ] + image_contents - } + "content": [{"type": "text", "text": prompt}] + image_contents, + }, ] - + # Call GPT-4V try: response = self.client.chat.completions.create( @@ -341,15 +446,15 @@ def _validate_with_gpt4v( messages=messages, max_tokens=1000, temperature=0.1, # Low temperature for consistent validation - response_format={"type": "json_object"} + response_format={"type": "json_object"}, ) - + response_text = response.choices[0].message.content result = json.loads(response_text) - + logger.debug(f"GPT-4V validation result: {result}") return result - + except json.JSONDecodeError as e: logger.error(f"Failed to parse GPT-4V response as JSON: {e}") logger.debug(f"Response text: {response_text}") @@ -357,72 +462,24 @@ def _validate_with_gpt4v( except Exception as e: logger.error(f"GPT-4V API call failed: {e}") raise - + def _build_validation_prompt( self, question: Dict[str, Any], segment_info: Dict[str, Any], - transcript_text: str + transcript_text: str, ) -> str: - """Build the validation prompt for GPT-4V""" - segment_start = segment_info['start'] - segment_end = segment_info['end'] - segment_duration = segment_end - segment_start - - answer = question.get('answer', {}) - start_s = answer.get('start_s') - end_s = answer.get('end_s') - - prompt = f"""You are validating a temporal localization question generated for a video segment. - - SEGMENT INFO: - - Segment absolute time: {segment_start}s to {segment_end}s (duration: {segment_duration}s) - - All timestamps in the question should be RELATIVE to segment start (0.0s to {segment_duration}s) - - QUESTION TO VALIDATE: - {json.dumps(question, indent=2)} - - TRANSCRIPT (if available): - {transcript_text if transcript_text else "No transcript available"} - - VALIDATION CRITERIA: - - 1. **Event Existence**: Do BOTH the anchor event and target event actually exist in the frames you see? 
- - Check if the described events are visible or can be inferred from the frames - - For audio events (requires_audio=true), check if transcript supports the events - - 2. **Timestamp Accuracy**: Are the provided timestamps [{start_s}s, {end_s}s] reasonable? - - Timestamps should be in SEGMENT-RELATIVE time (0.0 to {segment_duration}s) - - Do the frames near the target timestamps show the target event? - - Allow ยฑ5 second tolerance for minor inaccuracies - - 3. **Temporal Relationship**: Does the temporal relationship make sense? - - Relation: {question.get('temporal_relation', 'unknown')} - - Check if anchor and target have the stated relationship - - TASKS: - - 1. Determine if the question is VALID (events exist, timestamps reasonable, relation makes sense) - - 2. If timestamps are slightly off but events are identifiable: - - Provide corrected timestamps if you can identify better times - - Only correct if deviation is โ‰ค5 seconds - - 3. If question is invalid (events don't exist, wrong relation, timestamps way off): - - Mark as invalid and provide reason - - OUTPUT FORMAT (JSON): - {{ - "valid": true/false, - "reason": "events_not_found|invalid_relation|timestamps_way_off|other", - "message": "Detailed explanation", - "corrected_timestamps": {{ - "start_s": , - "end_s": - }}, // Only if minor corrections needed - "correction_reason": "Brief explanation of correction" - }} - - Respond with ONLY the JSON object, no other text.""" - - return prompt \ No newline at end of file + """Build the validation prompt for GPT-4V.""" + segment_start = segment_info["start"] + segment_end = segment_info["end"] + answer = question.get("answer", {}) + return VALIDATION_PROMPT_TEMPLATE.format( + segment_start=segment_start, + segment_end=segment_end, + segment_duration=segment_end - segment_start, + question_json=json.dumps(question, indent=2), + transcript_text=transcript_text or "No transcript available", + start_s=answer.get("start_s"), + end_s=answer.get("end_s"), + 
temporal_relation=question.get("temporal_relation", "unknown"), + ) diff --git a/sonic-o1/04_vqa_generation/prompts/__init__.py b/sonic-o1/04_vqa_generation/prompts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sonic-o1/04_vqa_generation/prompts/mcq_prompts.py b/sonic-o1/04_vqa_generation/prompts/mcq_prompts.py index 7835675..bbce929 100644 --- a/sonic-o1/04_vqa_generation/prompts/mcq_prompts.py +++ b/sonic-o1/04_vqa_generation/prompts/mcq_prompts.py @@ -1,3 +1,5 @@ +"""Prompts for MCQ generation.""" + MCQ_GENERATION_PROMPT = """You are a meticulous multimodal annotator creating challenging multiple-choice questions that test deep understanding of the content. SEGMENT INFORMATION: @@ -118,38 +120,45 @@ Now generate ONE MCQ question as a single JSON object for this segment:""" + def get_mcq_prompt(segment_info: dict, metadata: dict, transcript: str, config) -> str: - """Generate MCQ prompt for a segment""" + """Generate MCQ prompt for a segment.""" evidence_tags = config.mcq.evidence_tags num_options = config.mcq.num_options # e.g., 5 - + # Calculate derived values num_content_options = num_options - 1 # e.g., 4 (excluding "Not enough evidence") max_index = num_options - 1 # e.g., 4 - + # Generate option letters - option_letters = [chr(65 + i) for i in range(num_options)] # ['A', 'B', 'C', 'D', 'E'] + option_letters = [ + chr(65 + i) for i in range(num_options) + ] # ['A', 'B', 'C', 'D', 'E'] last_option_letter = option_letters[-1] # e.g., 'E' second_to_last_letter = option_letters[-2] # e.g., 'D' - + # Calculate distribution percentage (now including "Not enough evidence" position) distribution_percentage = round(100 / num_options, 1) # e.g., 20.0% for 5 options - + # Generate extra positions text for distribution - extra_positions = '' + extra_positions = "" if num_content_options > 4: for i in range(4, num_content_options): - extra_positions += f"\n * Position {option_letters[i]}: ~{distribution_percentage}%" - - evidence_tags_str = 
'\n'.join([f" - {tag}" for tag in evidence_tags]) - + extra_positions += ( + f"\n * Position {option_letters[i]}: ~{distribution_percentage}%" + ) + + evidence_tags_str = "\n".join([f" - {tag}" for tag in evidence_tags]) + return MCQ_GENERATION_PROMPT.format( - video_id=metadata.get('video_id', 'unknown'), - topic_name=metadata.get('topic_name', 'Unknown'), - start_time=segment_info['start'], - end_time=segment_info['end'], - duration=segment_info['duration'], - transcript_text=transcript if transcript else "No transcript available for this segment", + video_id=metadata.get("video_id", "unknown"), + topic_name=metadata.get("topic_name", "Unknown"), + start_time=segment_info["start"], + end_time=segment_info["end"], + duration=segment_info["duration"], + transcript_text=transcript + if transcript + else "No transcript available for this segment", evidence_tags_list=evidence_tags_str, num_options=num_options, num_content_options=num_content_options, @@ -157,5 +166,5 @@ def get_mcq_prompt(segment_info: dict, metadata: dict, transcript: str, config) last_option_letter=last_option_letter, second_to_last_letter=second_to_last_letter, distribution_percentage=distribution_percentage, - extra_positions=extra_positions - ) \ No newline at end of file + extra_positions=extra_positions, + ) diff --git a/sonic-o1/04_vqa_generation/prompts/summarization_prompts.py b/sonic-o1/04_vqa_generation/prompts/summarization_prompts.py index a0e6254..7552728 100644 --- a/sonic-o1/04_vqa_generation/prompts/summarization_prompts.py +++ b/sonic-o1/04_vqa_generation/prompts/summarization_prompts.py @@ -1,6 +1,7 @@ -""" -Task 1: Summarization prompts (map and reduce phases) -""" +"""Task 1: Summarization prompts (map and reduce phases).""" + +import json + # MAP PHASE: Per-segment summarization MAP_PHASE_PROMPT = """You are a precise video segment summarizer. 
@@ -152,7 +153,7 @@ CRITICAL: - Return ONLY valid JSON -- No markdown, no extra text +- No markdown, no extra text - Timeline in HH:MM:SS format - Aim for {timeline_min}-{timeline_max} timeline items @@ -246,47 +247,47 @@ # Helper functions + def get_map_prompt(segment_info: dict, metadata: dict, transcript: str, config) -> str: - """Generate map phase prompt for a segment""" + """Generate map phase prompt for a segment.""" max_words = config.summarization.constraints.max_words_segment - + return MAP_PHASE_PROMPT.format( - start_time=segment_info['start'], - end_time=segment_info['end'], - duration=segment_info['duration'], - title=metadata.get('title', 'Unknown'), - topic_name=metadata.get('topic_name', 'Unknown'), + start_time=segment_info["start"], + end_time=segment_info["end"], + duration=segment_info["duration"], + title=metadata.get("title", "Unknown"), + topic_name=metadata.get("topic_name", "Unknown"), transcript_text=transcript if transcript else "No transcript available", - max_words=max_words + max_words=max_words, ) + def get_initialize_prompt(first_segment: dict, video_id: str, metadata: dict) -> str: - """Generate prompt to initialize summary from first segment""" - import json - + """Generate prompt to initialize summary from first segment.""" return INITIALIZE_SUMMARY_PROMPT.format( video_id=video_id, - title=metadata.get('title', 'Unknown'), - duration=metadata.get('duration_seconds', 0), - first_segment_json=json.dumps(first_segment, indent=2) + title=metadata.get("title", "Unknown"), + duration=metadata.get("duration_seconds", 0), + first_segment_json=json.dumps(first_segment, indent=2), ) -def get_streaming_update_prompt(current_summary: dict, - new_segment: dict, - video_id: str, - metadata: dict, - segment_num: int, - total_segments: int, - config) -> str: - """Generate prompt to add new segment to accumulated summary""" - import json - +def get_streaming_update_prompt( + current_summary: dict, + new_segment: dict, + video_id: str, + 
metadata: dict, + segment_num: int, + total_segments: int, + config, +) -> str: + """Generate prompt to add new segment to accumulated summary.""" constraints = config.summarization.constraints - + return STREAMING_UPDATE_PROMPT.format( video_id=video_id, - title=metadata.get('title', 'Unknown'), + title=metadata.get("title", "Unknown"), segment_num=segment_num, total_segments=total_segments, current_summary_json=json.dumps(current_summary, indent=2), @@ -296,20 +297,21 @@ def get_streaming_update_prompt(current_summary: dict, timeline_min=constraints.timeline_items_min, timeline_max=constraints.timeline_items_max, glossary_min=constraints.glossary_items_min, - glossary_max=constraints.glossary_items_max + glossary_max=constraints.glossary_items_max, ) -def get_reduce_prompt(video_id: str, metadata: dict, segment_summaries: list, config) -> str: - """Generate reduce phase prompt for merging segments""" - import json - + +def get_reduce_prompt( + video_id: str, metadata: dict, segment_summaries: list, config +) -> str: + """Generate reduce phase prompt for merging segments.""" constraints = config.summarization.constraints - + return REDUCE_PHASE_PROMPT.format( video_id=video_id, - title=metadata.get('title', 'Unknown'), - topic_name=metadata.get('topic_name', 'Unknown'), - duration=metadata.get('duration_seconds', 0), + title=metadata.get("title", "Unknown"), + topic_name=metadata.get("topic_name", "Unknown"), + duration=metadata.get("duration_seconds", 0), num_segments=len(segment_summaries), segment_summaries_json=json.dumps(segment_summaries, indent=2), num_bullets=constraints.summary_short_bullets, @@ -317,24 +319,24 @@ def get_reduce_prompt(video_id: str, metadata: dict, segment_summaries: list, co timeline_min=constraints.timeline_items_min, timeline_max=constraints.timeline_items_max, glossary_min=constraints.glossary_items_min, - glossary_max=constraints.glossary_items_max + glossary_max=constraints.glossary_items_max, ) def get_direct_prompt(video_id: str, 
metadata: dict, transcript: str, config) -> str: - """Generate direct summarization prompt for short videos""" + """Generate direct summarization prompt for short videos.""" constraints = config.summarization.constraints - + return DIRECT_SUMMARY_PROMPT.format( video_id=video_id, - title=metadata.get('title', 'Unknown'), - topic_name=metadata.get('topic_name', 'Unknown'), - duration=metadata.get('duration_seconds', 0), + title=metadata.get("title", "Unknown"), + topic_name=metadata.get("topic_name", "Unknown"), + duration=metadata.get("duration_seconds", 0), transcript_text=transcript if transcript else "No transcript available", num_bullets=constraints.summary_short_bullets, max_words_detailed=constraints.max_words_detailed, timeline_min=constraints.timeline_items_min, timeline_max=constraints.timeline_items_max, glossary_min=constraints.glossary_items_min, - glossary_max=constraints.glossary_items_max - ) \ No newline at end of file + glossary_max=constraints.glossary_items_max, + ) diff --git a/sonic-o1/04_vqa_generation/prompts/temporal_judge_prompts.py b/sonic-o1/04_vqa_generation/prompts/temporal_judge_prompts.py index f8f5a6c..93725f7 100644 --- a/sonic-o1/04_vqa_generation/prompts/temporal_judge_prompts.py +++ b/sonic-o1/04_vqa_generation/prompts/temporal_judge_prompts.py @@ -1,35 +1,33 @@ -""" -Prompts for GPT-4V Temporal Question Validation -""" -from typing import Dict, Any +"""Prompts for GPT-4V Temporal Question Validation.""" + import json +from typing import Any, Dict def build_validation_prompt( - question: Dict[str, Any], - segment_info: Dict[str, Any], - transcript_text: str = "" + question: Dict[str, Any], segment_info: Dict[str, Any], transcript_text: str = "" ) -> str: """ Build GPT-4V validation prompt for a temporal question. 
- + Args: question: The temporal question to validate segment_info: Segment metadata (start, end) transcript_text: Optional transcript text - - Returns: + + Returns + ------- Formatted validation prompt """ - segment_start = segment_info['start'] - segment_end = segment_info['end'] + segment_start = segment_info["start"] + segment_end = segment_info["end"] segment_duration = segment_end - segment_start - - answer = question.get('answer', {}) - start_s = answer.get('start_s') - end_s = answer.get('end_s') - - prompt = f"""You are validating a temporal localization question for a video segment. + + answer = question.get("answer", {}) + start_s = answer.get("start_s") + end_s = answer.get("end_s") + + return f"""You are validating a temporal localization question for a video segment. โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ• SEGMENT INFORMATION @@ -74,7 +72,7 @@ def build_validation_prompt( - If end_s cuts off the event prematurely โ†’ provide correction 3. **TEMPORAL RELATIONSHIP** - - Stated relation: {question.get('temporal_relation', 'unknown')} + - Stated relation: {question.get("temporal_relation", "unknown")} - Does this relationship make sense between anchor and target? - Relations: * "after" = target occurs sometime after anchor completes @@ -168,17 +166,8 @@ def build_validation_prompt( Begin your validation. Respond with ONLY the JSON object.""" - return prompt - -def build_batch_validation_system_prompt() -> str: - """ - System prompt for GPT-4V judge validation. - - Returns: - System prompt string - """ - return """You are an expert video analyst specializing in temporal event validation. +BATCH_VALIDATION_SYSTEM_PROMPT = """You are an expert video analyst specializing in temporal event validation. Your role is to: 1. 
Carefully examine video frames to identify events @@ -193,4 +182,4 @@ def build_batch_validation_system_prompt() -> str: - Allow minor timestamp adjustments (โ‰ค5s) if events are clearly identifiable - Validate temporal relationships carefully (after, before, during, etc.) - Use transcript to validate audio events when available -- When uncertain, mark as invalid rather than guessing""" \ No newline at end of file +- When uncertain, mark as invalid rather than guessing""" diff --git a/sonic-o1/04_vqa_generation/prompts/temporal_localization_prompts.py b/sonic-o1/04_vqa_generation/prompts/temporal_localization_prompts.py index 9c85e70..74482e4 100644 --- a/sonic-o1/04_vqa_generation/prompts/temporal_localization_prompts.py +++ b/sonic-o1/04_vqa_generation/prompts/temporal_localization_prompts.py @@ -1,4 +1,7 @@ -from typing import Dict, Any +"""Prompts for temporal localization.""" + +from typing import Any, Dict + # Main prompt for temporal localization question generation TEMPORAL_LOCALIZATION_PROMPT = """You are a careful video annotator creating temporal reasoning questions. @@ -263,31 +266,34 @@ def get_temporal_localization_prompt( segment_info: Dict[str, Any], metadata: Dict[str, Any], transcript_text: str, - config: Any + config: Any, ) -> str: """ Generate temporal localization prompt for a video segment. 
- + Args: segment_info: Segment metadata with start/end times metadata: Video metadata transcript_text: Transcript for this segment config: Configuration object - - Returns: + + Returns + ------- Formatted prompt string optimized for Gemini 2.5 with thinking mode """ - video_id = metadata.get('video_id', metadata.get('video_number', 'unknown')) - + video_id = metadata.get("video_id", metadata.get("video_number", "unknown")) + # Calculate duration (model only sees the segment video, not full video) - duration = segment_info['end'] - segment_info['start'] - + duration = segment_info["end"] - segment_info["start"] + # Get number of questions from config num_questions = int(config.temporal_localization.questions_per_segment) - + return TEMPORAL_LOCALIZATION_PROMPT.format( video_id=video_id, duration=duration, num_questions=num_questions, - transcript_text=transcript_text if transcript_text else "No transcript available" - ) \ No newline at end of file + transcript_text=transcript_text + if transcript_text + else "No transcript available", + ) diff --git a/sonic-o1/04_vqa_generation/standardize_demographics.py b/sonic-o1/04_vqa_generation/standardize_demographics.py index 0d2322d..2497489 100644 --- a/sonic-o1/04_vqa_generation/standardize_demographics.py +++ b/sonic-o1/04_vqa_generation/standardize_demographics.py @@ -1,719 +1,776 @@ #!/usr/bin/env python3 -""" -Standardize Demographics in VQA Files -This script standardizes demographic values across all VQA JSON files to ensure -consistency with the canonical categories defined in vqa_config.yaml. +"""standardize_demographics.py. 
-Fixes: -- Maps variant race/ethnicity terms to canonical values (e.g., "South Asian" -> "Asian") -- Standardizes gender terms -- Converts age descriptors to numeric brackets (e.g., "Young (18-24)" -> "18-24") -- Normalizes language variants (e.g., "English American accent" -> "English") -- Removes "Unknown" entries where possible +Standardize demographic values across VQA JSON files to match canonical +categories in vqa_config.yaml. Maps variant race/gender/age/language +terms to canonical values and removes "Unknown" where possible. Usage: - python standardize_demographics.py --config config/vqa_config.yaml --dry-run - python standardize_demographics.py --config config/vqa_config.yaml + python standardize_demographics.py --config vqa_config.yaml --dry-run + python standardize_demographics.py --config vqa_config.yaml python standardize_demographics.py --topics 10,11 --dry-run + +Author: SONIC-O1 Team """ + import argparse import json import logging -import yaml import sys +from collections import Counter from pathlib import Path -from typing import Dict, List, Any, Optional, Tuple -from collections import defaultdict, Counter +from typing import Any, Dict, List, Optional, Tuple + +from utils.config_utils import Config, load_config +from utils.file_utils import save_json_with_backup + # Setup logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', + format="%(asctime)s - %(levelname)s - %(message)s", handlers=[ - logging.FileHandler('standardize_demographics.log'), - logging.StreamHandler() - ] + logging.FileHandler("standardize_demographics.log"), + logging.StreamHandler(), + ], ) logger = logging.getLogger(__name__) -class Config: - """Configuration wrapper""" - def __init__(self, config_dict): - for key, value in config_dict.items(): - if isinstance(value, dict): - setattr(self, key, Config(value)) - else: - setattr(self, key, value) - -def load_config(config_path: str) -> Config: - """Load configuration from YAML 
file""" - with open(config_path, 'r') as f: - config_dict = yaml.safe_load(f) - return Config(config_dict) class DemographicsStandardizer: - """Standardize demographics to canonical categories""" - + """Standardize demographics to canonical categories.""" + # Canonical categories from config - CANONICAL_RACE = {'White', 'Black', 'Asian', 'Indigenous', 'Arab', 'Hispanic'} - CANONICAL_GENDER = {'Male', 'Female'} - CANONICAL_AGE = {'18-24', '25-39', '40+'} + CANONICAL_RACE = {"White", "Black", "Asian", "Indigenous", "Arab", "Hispanic"} + CANONICAL_GENDER = {"Male", "Female"} + CANONICAL_AGE = {"18-24", "25-39", "40+"} # Extended language list to include common languages in dataset - CANONICAL_LANGUAGE = {'English', 'Hindi', 'Arabic', 'Spanish', 'Chinese', - 'French', 'Korean', 'Thai', 'Swahili', 'Punjabi', - 'Telugu', 'Malay', 'Urdu'} - + CANONICAL_LANGUAGE = { + "English", + "Hindi", + "Arabic", + "Spanish", + "Chinese", + "French", + "Korean", + "Thai", + "Swahili", + "Punjabi", + "Telugu", + "Malay", + "Urdu", + } + # Mapping rules for race/ethnicity RACE_MAPPINGS = { # Asian variants - 'south asian': 'Asian', - 'east asian': 'Asian', - 'southeast asian': 'Asian', - 'east/southeast asian': 'Asian', - 'asian american': 'Asian', - 'south east asian': 'Asian', - 'indian/south asian': 'Asian', - 'middle eastern': 'Arab', - + "south asian": "Asian", + "east asian": "Asian", + "southeast asian": "Asian", + "east/southeast asian": "Asian", + "asian american": "Asian", + "south east asian": "Asian", + "indian/south asian": "Asian", + "middle eastern": "Arab", # Hispanic variants - 'latino': 'Hispanic', - 'latina': 'Hispanic', - 'latinx': 'Hispanic', - 'latin american': 'Hispanic', - 'spanish': 'Hispanic', - 'hispanic': 'Hispanic', - + "latino": "Hispanic", + "latina": "Hispanic", + "latinx": "Hispanic", + "latin american": "Hispanic", + "spanish": "Hispanic", + "hispanic": "Hispanic", # Black variants - 'african american': 'Black', - 'african': 'Black', - 'afro-american': 
'Black', - + "african american": "Black", + "african": "Black", + "afro-american": "Black", # White variants - 'caucasian': 'White', - 'european': 'White', - 'european american': 'White', - + "caucasian": "White", + "european": "White", + "european american": "White", # Indigenous variants - 'native american': 'Indigenous', - 'aboriginal': 'Indigenous', - 'first nations': 'Indigenous', - + "native american": "Indigenous", + "aboriginal": "Indigenous", + "first nations": "Indigenous", # Arab variants - 'arab american': 'Arab', - 'middle-eastern': 'Arab', - + "arab american": "Arab", + "middle-eastern": "Arab", # Asian - lowercase variant - 'asian': 'Asian', - + "asian": "Asian", # Mixed-race - map to first mentioned or most specific - 'mixed-race': 'Asian', # Review case-by-case if needed + "mixed-race": "Asian", # Review case-by-case if needed } - + # Mapping rules for gender GENDER_MAPPINGS = { - 'man': 'Male', - 'woman': 'Female', - 'm': 'Male', - 'f': 'Female', - 'female': 'Female', - 'male': 'Male', + "man": "Male", + "woman": "Female", + "m": "Male", + "f": "Female", + "female": "Female", + "male": "Male", } - + # Mapping rules for age (from descriptive to numeric brackets) AGE_MAPPINGS = { - 'young (18-24)': '18-24', - 'young': '18-24', - 'young adult': '18-24', - 'young adults': '18-24', - - 'middle (25-39)': '25-39', - 'middle': '25-39', - 'middle age': '25-39', - 'middle-aged': '25-39', - 'middle aged': '25-39', - 'middle adults': '25-39', - 'older adults (25-39)': '25-39', - - 'older adults (40+)': '40+', - 'older (40+)': '40+', - 'older adults': '40+', - 'older adult': '40+', - 'older adult (40+)': '40+', - 'older adults (+40)': '40+', - 'older 40+': '40+', - 'older': '40+', - 'old': '40+', - 'senior': '40+', - 'elderly': '40+', + "young (18-24)": "18-24", + "young": "18-24", + "young adult": "18-24", + "young adults": "18-24", + "middle (25-39)": "25-39", + "middle": "25-39", + "middle age": "25-39", + "middle-aged": "25-39", + "middle aged": "25-39", + 
"middle adults": "25-39", + "older adults (25-39)": "25-39", + "older adults (40+)": "40+", + "older (40+)": "40+", + "older adults": "40+", + "older adult": "40+", + "older adult (40+)": "40+", + "older adults (+40)": "40+", + "older 40+": "40+", + "older": "40+", + "old": "40+", + "senior": "40+", + "elderly": "40+", } - + # Mapping rules for language LANGUAGE_MAPPINGS = { # English variants with accents - all map to English - 'English ': 'English', # Capitalized with trailing space - 'english ': 'English', # lowercase with trailing space - 'english american accent': 'English', - 'english (american)': 'English', - 'english (british)': 'English', - 'english (british accent)': 'English', - 'english (british/uk accent)': 'English', - 'english (american accent)': 'English', - 'english (north american accent)': 'English', - 'english (north american)': 'English', - 'english (uk accent)': 'English', - 'english (australian accent)': 'English', - 'english (indian accent)': 'English', - 'english (south asian accent)': 'English', - 'english (south asian/middle eastern accent)': 'English', - 'english (singaporean accent)': 'English', - 'english (singaporean accent), mandarin': 'English', - 'english (eastern european accent)': 'English', - 'english (spanish/latin american accent)': 'English', - 'english (middle eastern/foreign accent)': 'English', - 'english (general/non-native kenyan accent)': 'English', - 'english (general/non-native accent)': 'English', - 'english (non-indian accent)': 'English', - 'english (jamaican accent)': 'English', - 'english (caribbean/jamaican accent)': 'English', - 'english (african accent)': 'English', - 'english (arab accent)': 'English', - 'english with accent': 'English', - 'english with indian accent': 'English', - 'english with indian/south asian accent': 'English', - 'english with east asian accent': 'English', - 'english with arab accent': 'English', - 'american english': 'English', - 'british english': 'English', - 'english, singaporean 
accent': 'English', - 'english, arabic accent': 'English', - 'english ': 'English', # with trailing space - + "English ": "English", # Capitalized with trailing space + "english ": "English", # lowercase with trailing space + "english american accent": "English", + "english (american)": "English", + "english (british)": "English", + "english (british accent)": "English", + "english (british/uk accent)": "English", + "english (american accent)": "English", + "english (north american accent)": "English", + "english (north american)": "English", + "english (uk accent)": "English", + "english (australian accent)": "English", + "english (indian accent)": "English", + "english (south asian accent)": "English", + "english (south asian/middle eastern accent)": "English", + "english (singaporean accent)": "English", + "english (singaporean accent), mandarin": "English", + "english (eastern european accent)": "English", + "english (spanish/latin american accent)": "English", + "english (middle eastern/foreign accent)": "English", + "english (general/non-native kenyan accent)": "English", + "english (general/non-native accent)": "English", + "english (non-indian accent)": "English", + "english (jamaican accent)": "English", + "english (caribbean/jamaican accent)": "English", + "english (african accent)": "English", + "english (arab accent)": "English", + "english with accent": "English", + "english with indian accent": "English", + "english with indian/south asian accent": "English", + "english with east asian accent": "English", + "english with arab accent": "English", + "american english": "English", + "british english": "English", + "english, singaporean accent": "English", + "english, arabic accent": "English", # with trailing space # Multilingual - take first/primary language - 'english, spanish': 'English', - 'english, hindi': 'English', - 'english, arabic': 'English', - 'english, swahili': 'English', - 'english, telugu': 'English', - 'english, mandarin': 'English', - 
'english, mandarin, cantonese': 'English', - 'english, hindi, punjabi': 'English', - 'english, urdu': 'English', - 'english, italian accent': 'English', - 'french, english': 'French', - 'thai, english': 'Thai', - 'thai, korean, english': 'English', - 'malay, english': 'English', - 'malay, chinese (mandarin), english': 'English', - 'mandarin, english': 'Chinese', - 'hokkien, mandarin': 'Chinese', - 'urdu, english': 'Urdu', - + "english, spanish": "English", + "english, hindi": "English", + "english, arabic": "English", + "english, swahili": "English", + "english, telugu": "English", + "english, mandarin": "English", + "english, mandarin, cantonese": "English", + "english, hindi, punjabi": "English", + "english, urdu": "English", + "english, italian accent": "English", + "french, english": "French", + "thai, english": "Thai", + "thai, korean, english": "English", + "malay, english": "English", + "malay, chinese (mandarin), english": "English", + "mandarin, english": "Chinese", + "hokkien, mandarin": "Chinese", + "urdu, english": "Urdu", # Chinese variants - 'mandarin': 'Chinese', - 'cantonese': 'Chinese', - 'mandarin chinese': 'Chinese', - 'chinese (mandarin)': 'Chinese', - 'hokkien': 'Chinese', - + "mandarin": "Chinese", + "cantonese": "Chinese", + "mandarin chinese": "Chinese", + "chinese (mandarin)": "Chinese", + "hokkien": "Chinese", # Spanish variants - 'spanish (latin american)': 'Spanish', - 'spanish (spain)': 'Spanish', - 'latin american spanish': 'Spanish', - + "spanish (latin american)": "Spanish", + "spanish (spain)": "Spanish", + "latin american spanish": "Spanish", # Hindi variants - 'hindi/urdu': 'Hindi', - + "hindi/urdu": "Hindi", # Arabic variants - 'modern standard arabic': 'Arabic', - 'arabic (egyptian)': 'Arabic', - 'arabic (levantine)': 'Arabic', - 'arabic accent': 'Arabic', - + "modern standard arabic": "Arabic", + "arabic (egyptian)": "Arabic", + "arabic (levantine)": "Arabic", + "arabic accent": "Arabic", # Urdu standalone - 'urdu': 'Urdu', - + 
"urdu": "Urdu", # Sign language - map to English (most common in dataset) - 'asl': 'English', + "asl": "English", } - + def __init__(self, config: Config, dry_run: bool = False): """ Initialize standardizer. - + Args: config: Configuration object dry_run: If True, only report what would be changed """ self.config = config self.dry_run = dry_run - + # Statistics tracking self.stats = { - 'total_entries': 0, - 'entries_with_demographics': 0, - 'entries_modified': 0, - 'total_demographic_items': 0, - 'items_modified': 0, - 'race_changes': Counter(), - 'gender_changes': Counter(), - 'age_changes': Counter(), - 'language_changes': Counter(), - 'out_of_scope_race': Counter(), - 'out_of_scope_gender': Counter(), - 'out_of_scope_age': Counter(), - 'out_of_scope_language': Counter(), - 'unknown_removed': 0, + "total_entries": 0, + "entries_with_demographics": 0, + "entries_modified": 0, + "total_demographic_items": 0, + "items_modified": 0, + "race_changes": Counter(), + "gender_changes": Counter(), + "age_changes": Counter(), + "language_changes": Counter(), + "out_of_scope_race": Counter(), + "out_of_scope_gender": Counter(), + "out_of_scope_age": Counter(), + "out_of_scope_language": Counter(), + "unknown_removed": 0, } - - def standardize_value(self, value: str, category: str) -> Tuple[str, bool]: - """ - Standardize a single demographic value. - - Args: - value: Original value - category: 'race', 'gender', 'age', or 'language' - - Returns: - Tuple of (standardized_value, was_changed) + + def _apply_mapping( + self, + value: str, + value_lower: str, + original: str, + mappings: Dict, + canonical: set, + changes_counter: Counter, + out_of_scope_counter: Counter, + ) -> Tuple[str, bool]: + """Check canonical set, try mapping table, track stats. + + Shared by all per-category standardize methods. 
""" - value = value.strip() - if not value or value.lower() == 'unknown': - return value, False - - original = value - value_lower = value.lower() - - # Special case handling before mapping - if category == 'race': - # Handle multi-racial entries - keep first race mentioned - if 'mixed-race' in value_lower or 'mixed race' in value_lower: - # Map to Asian as default (most common in dataset) - pass - elif ',' in value and 'white' in value_lower and 'hispanic' in value_lower: - # "White, Hispanic" -> "Hispanic" (prioritize minority) - value_lower = 'hispanic' - mappings = self.RACE_MAPPINGS - canonical = self.CANONICAL_RACE - out_of_scope = self.stats['out_of_scope_race'] - new_value = 'Hispanic' - self.stats['race_changes'][f"{original} -> {new_value}"] += 1 - return new_value, True - elif 'not specified' in value_lower: - # Map to Unknown - return 'Unknown', True - - elif category == 'age': - # Handle children - map to youngest bracket - if 'under' in value_lower or 'child' in value_lower or value_lower.startswith('young (under'): - value_lower = 'young (18-24)' - - # Select appropriate mappings - if category == 'race': - mappings = self.RACE_MAPPINGS - canonical = self.CANONICAL_RACE - out_of_scope = self.stats['out_of_scope_race'] - elif category == 'gender': - mappings = self.GENDER_MAPPINGS - canonical = self.CANONICAL_GENDER - out_of_scope = self.stats['out_of_scope_gender'] - elif category == 'age': - mappings = self.AGE_MAPPINGS - canonical = self.CANONICAL_AGE - out_of_scope = self.stats['out_of_scope_age'] - elif category == 'language': - mappings = self.LANGUAGE_MAPPINGS - canonical = self.CANONICAL_LANGUAGE - out_of_scope = self.stats['out_of_scope_language'] - else: - return value, False - - # Check if already canonical if value in canonical: return value, False - - # Try to map if value_lower in mappings: new_value = mappings[value_lower] - if category == 'race': - self.stats['race_changes'][f"{original} -> {new_value}"] += 1 - elif category == 
'gender': - self.stats['gender_changes'][f"{original} -> {new_value}"] += 1 - elif category == 'age': - self.stats['age_changes'][f"{original} -> {new_value}"] += 1 - elif category == 'language': - self.stats['language_changes'][f"{original} -> {new_value}"] += 1 + changes_counter[f"{original} -> {new_value}"] += 1 return new_value, True - - # Not found in mappings - track as out of scope - if value.lower() != 'unknown': - out_of_scope[original] += 1 - + if value_lower != "unknown": + out_of_scope_counter[original] += 1 return value, False - - def standardize_demographic_entry(self, entry: Dict[str, Any]) -> bool: + + def standardize_race(self, value: str) -> Tuple[str, bool]: + """Standardize a race/ethnicity value.""" + original = value + value_lower = value.lower() + # "White, Hispanic" โ†’ prioritise minority + if "," in value and "white" in value_lower and "hispanic" in value_lower: + self.stats["race_changes"][f"{original} -> Hispanic"] += 1 + return "Hispanic", True + # "Not specified" โ†’ Unknown + if "not specified" in value_lower: + return "Unknown", True + return self._apply_mapping( + value, + value_lower, + original, + self.RACE_MAPPINGS, + self.CANONICAL_RACE, + self.stats["race_changes"], + self.stats["out_of_scope_race"], + ) + + def standardize_gender(self, value: str) -> Tuple[str, bool]: + """Standardize a gender value.""" + original = value + value_lower = value.lower() + return self._apply_mapping( + value, + value_lower, + original, + self.GENDER_MAPPINGS, + self.CANONICAL_GENDER, + self.stats["gender_changes"], + self.stats["out_of_scope_gender"], + ) + + def standardize_age(self, value: str) -> Tuple[str, bool]: + """Standardize an age value.""" + original = value + value_lower = value.lower() + # Children / under-18 โ†’ youngest bracket + if ( + "under" in value_lower + or "child" in value_lower + or value_lower.startswith("young (under") + ): + value_lower = "young (18-24)" + return self._apply_mapping( + value, + value_lower, + original, 
+ self.AGE_MAPPINGS, + self.CANONICAL_AGE, + self.stats["age_changes"], + self.stats["out_of_scope_age"], + ) + + def standardize_language(self, value: str) -> Tuple[str, bool]: + """Standardize a language value.""" + original = value + value_lower = value.lower() + return self._apply_mapping( + value, + value_lower, + original, + self.LANGUAGE_MAPPINGS, + self.CANONICAL_LANGUAGE, + self.stats["language_changes"], + self.stats["out_of_scope_language"], + ) + + def standardize_value(self, value: str, category: str) -> Tuple[str, bool]: """ - Standardize a single demographic entry. - + Dispatch to the appropriate per-category standardize method. + Args: - entry: Demographic entry dict with race, gender, age, language, count - - Returns: - True if any changes were made + value: Original demographic value. + category: One of "race", "gender", "age", "language". + + Returns + ------- + Tuple of (standardized_value, was_changed). """ - changed = False - - # Standardize race - if 'race' in entry: - new_race, race_changed = self.standardize_value(entry['race'], 'race') - if race_changed: - entry['race'] = new_race - changed = True - - # Standardize gender - if 'gender' in entry: - new_gender, gender_changed = self.standardize_value(entry['gender'], 'gender') - if gender_changed: - entry['gender'] = new_gender - changed = True - - # Standardize age - if 'age' in entry: - new_age, age_changed = self.standardize_value(entry['age'], 'age') - if age_changed: - entry['age'] = new_age - changed = True - - # Standardize language - if 'language' in entry: - new_language, language_changed = self.standardize_value(entry['language'], 'language') - if language_changed: - entry['language'] = new_language - changed = True - + value = value.strip() + if not value or value.lower() == "unknown": + return value, False + + dispatch = { + "race": self.standardize_race, + "gender": self.standardize_gender, + "age": self.standardize_age, + "language": self.standardize_language, + } + fn = 
dispatch.get(category) + if fn is None: + return value, False + return fn(value) + + def standardize_field(self, key: str, entry: Dict[str, Any]) -> bool: + """ + Standardize a single field of a demographic entry in-place. + + Args: + key: Field name โ€” one of "race", "gender", "age", "language". + entry: Demographic entry dict to update. + + Returns + ------- + True if the field value was changed. + """ + if key not in entry: + return False + new_val, changed = self.standardize_value(entry[key], key) + if changed: + entry[key] = new_val return changed - + + def standardize_demographic_entry(self, entry: Dict[str, Any]) -> bool: + """ + Standardize all demographic fields of a single entry. + + Args: + entry: Demographic entry dict with race, gender, age, language, count. + + Returns + ------- + True if any field was changed. + """ + results = [ + self.standardize_field(key, entry) + for key in ("race", "gender", "age", "language") + ] + return any(results) + def should_remove_entry(self, entry: Dict[str, Any]) -> bool: """ Check if demographic entry should be removed (all unknowns). 
- + Args: entry: Demographic entry dict - - Returns: + + Returns + ------- True if entry should be removed """ - race = entry.get('race', '').lower() - gender = entry.get('gender', '').lower() - age = entry.get('age', '').lower() - language = entry.get('language', '').lower() - + race = entry.get("race", "").lower() + gender = entry.get("gender", "").lower() + age = entry.get("age", "").lower() + language = entry.get("language", "").lower() + # Remove if all fields are unknown - return (race == 'unknown' and - gender == 'unknown' and - age == 'unknown' and - language == 'unknown') - - def standardize_demographics_array(self, demographics: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], bool]: + return ( + race == "unknown" + and gender == "unknown" + and age == "unknown" + and language == "unknown" + ) + + def standardize_demographics_array( + self, demographics: List[Dict[str, Any]] + ) -> Tuple[List[Dict[str, Any]], bool]: """ Standardize an array of demographic entries. - + Args: demographics: List of demographic dicts - - Returns: + + Returns + ------- Tuple of (standardized_list, was_changed) """ if not demographics: return demographics, False - + changed = False standardized = [] - + for entry in demographics: - self.stats['total_demographic_items'] += 1 - + self.stats["total_demographic_items"] += 1 + # Check if should be removed if self.should_remove_entry(entry): - self.stats['unknown_removed'] += 1 + self.stats["unknown_removed"] += 1 changed = True continue - + # Standardize entry entry_changed = self.standardize_demographic_entry(entry) if entry_changed: - self.stats['items_modified'] += 1 + self.stats["items_modified"] += 1 changed = True - + standardized.append(entry) - + return standardized, changed - + def process_vqa_entry(self, entry: Dict[str, Any], task_name: str) -> bool: """ Process a single VQA entry. 
- + Args: entry: VQA entry dict task_name: 'task1', 'task2', or 'task3' - - Returns: + + Returns + ------- True if entry was modified """ - self.stats['total_entries'] += 1 - - demographics = entry.get('demographics', []) + self.stats["total_entries"] += 1 + + demographics = entry.get("demographics", []) if not demographics: return False - - self.stats['entries_with_demographics'] += 1 - + + self.stats["entries_with_demographics"] += 1 + # Standardize demographics array standardized, changed = self.standardize_demographics_array(demographics) - + if changed and not self.dry_run: - entry['demographics'] = standardized - + entry["demographics"] = standardized + # Update total_individuals if field exists (task2/task3) - if 'demographics_total_individuals' in entry: - total = sum(d.get('count', 0) for d in standardized) - entry['demographics_total_individuals'] = total - + if "demographics_total_individuals" in entry: + total = sum(d.get("count", 0) for d in standardized) + entry["demographics_total_individuals"] = total + if changed: - self.stats['entries_modified'] += 1 - + self.stats["entries_modified"] += 1 + return changed - - def process_json_file(self, json_path: Path, task_name: str) -> Dict[str, int]: + + def process_task_file(self, json_path: Path, task_name: str) -> Dict[str, int]: """ Process a single VQA JSON file. 
- + Args: json_path: Path to JSON file task_name: 'task1', 'task2', or 'task3' - - Returns: + + Returns + ------- Dict with stats """ - logger.info(f"\n{'='*80}") + logger.info(f"\n{'=' * 80}") logger.info(f"Processing {json_path.name}") - logger.info(f"{'='*80}") - + logger.info(f"{'=' * 80}") + # Load JSON try: - with open(json_path, 'r') as f: + with open(json_path, "r") as f: data = json.load(f) except Exception as e: logger.error(f"Failed to load {json_path}: {e}") - return {'total': 0, 'modified': 0} - - entries = data.get('entries', []) - + return {"total": 0, "modified": 0} + + entries = data.get("entries", []) + # Track changes for this file - file_stats = {'total': len(entries), 'modified': 0} - + file_stats = {"total": len(entries), "modified": 0} + # Process each entry for entry in entries: if self.process_vqa_entry(entry, task_name): - file_stats['modified'] += 1 - - # Save if changes were made (and not dry-run) - if file_stats['modified'] > 0: + file_stats["modified"] += 1 + + if file_stats["modified"] > 0: if self.dry_run: - logger.info(f"[DRY-RUN] Would modify {file_stats['modified']} entries in {json_path.name}") + logger.info( + f"[DRY-RUN] Would modify {file_stats['modified']} entries in {json_path.name}" + ) else: - try: - # Create backup - backup_path = json_path.with_suffix('.json.backup_standardize') - with open(backup_path, 'w') as f: - json.dump(data, f, indent=2, ensure_ascii=False) - logger.info(f"Created backup: {backup_path.name}") - - # Save updated file - with open(json_path, 'w') as f: - json.dump(data, f, indent=2, ensure_ascii=False) - logger.info(f"โœ“ Saved {file_stats['modified']} changes to {json_path.name}") - except Exception as e: - logger.error(f"Failed to save {json_path}: {e}") + save_json_with_backup(data, json_path, ".json.backup_standardize") + logger.info( + f"โœ“ Saved {file_stats['modified']} changes to {json_path.name}" + ) else: logger.info(f"โœ“ No changes needed for {json_path.name}") - + return file_stats - + 
def process_all_tasks(self, topic_filter: Optional[List[int]] = None): """ Process all VQA task directories. - + Args: topic_filter: Optional list of topic IDs to process """ output_dir = Path(self.config.paths.output_dir) - + if not output_dir.exists(): logger.error(f"Output directory not found: {output_dir}") return - + task_dirs = { - 'task1': output_dir / 'task1_summarization', - 'task2': output_dir / 'task2_mcq', - 'task3': output_dir / 'task3_temporal_localization' + "task1": output_dir / "task1_summarization", + "task2": output_dir / "task2_mcq", + "task3": output_dir / "task3_temporal_localization", } - + for task_name, task_dir in task_dirs.items(): if not task_dir.exists(): logger.warning(f"Task directory not found: {task_dir}") continue - - logger.info(f"\n{'#'*80}") + + logger.info(f"\n{'#' * 80}") logger.info(f"# Processing {task_name.upper()}") - logger.info(f"{'#'*80}") - + logger.info(f"{'#' * 80}") + # Get all JSON files json_files = sorted(task_dir.glob("*.json")) - + # Skip backup files - json_files = [f for f in json_files if not f.name.endswith('.backup') - and not f.name.endswith('.backup_standardize')] - + json_files = [ + f + for f in json_files + if not f.name.endswith(".backup") + and not f.name.endswith(".backup_standardize") + ] + # Filter by topic if specified if topic_filter: json_files = [ - f for f in json_files + f + for f in json_files if any(f.name.startswith(f"{tid:02d}_") for tid in topic_filter) ] - + logger.info(f"Found {len(json_files)} JSON files to process") - + for json_path in json_files: - self.process_json_file(json_path, task_name) - + self.process_task_file(json_path, task_name) + # Print final summary self.print_summary() - + def print_summary(self): - """Print comprehensive statistics summary""" - logger.info(f"\n{'='*80}") + """Print comprehensive statistics summary.""" + logger.info(f"\n{'=' * 80}") logger.info("STANDARDIZATION SUMMARY") - logger.info(f"{'='*80}") - - logger.info(f"\nOVERALL STATISTICS:") - 
logger.info(f" Total VQA entries processed: {self.stats['total_entries']}") - logger.info(f" Entries with demographics: {self.stats['entries_with_demographics']}") - logger.info(f" Entries modified: {self.stats['entries_modified']}") - logger.info(f" Total demographic items: {self.stats['total_demographic_items']}") - logger.info(f" Items modified: {self.stats['items_modified']}") - logger.info(f" Unknown entries removed: {self.stats['unknown_removed']}") - + logger.info(f"{'=' * 80}") + + logger.info("\nOVERALL STATISTICS:") + logger.info( + f" Total VQA entries processed: {self.stats['total_entries']}" + ) + logger.info( + f" Entries with demographics: {self.stats['entries_with_demographics']}" + ) + logger.info( + f" Entries modified: {self.stats['entries_modified']}" + ) + logger.info( + f" Total demographic items: {self.stats['total_demographic_items']}" + ) + logger.info( + f" Items modified: {self.stats['items_modified']}" + ) + logger.info( + f" Unknown entries removed: {self.stats['unknown_removed']}" + ) + # Race changes - if self.stats['race_changes']: - logger.info(f"\nRACE/ETHNICITY CHANGES ({sum(self.stats['race_changes'].values())} total):") - for change, count in self.stats['race_changes'].most_common(): + if self.stats["race_changes"]: + logger.info( + f"\nRACE/ETHNICITY CHANGES ({sum(self.stats['race_changes'].values())} total):" + ) + for change, count in self.stats["race_changes"].most_common(): logger.info(f" {change}: {count}x") - + # Gender changes - if self.stats['gender_changes']: - logger.info(f"\nGENDER CHANGES ({sum(self.stats['gender_changes'].values())} total):") - for change, count in self.stats['gender_changes'].most_common(): + if self.stats["gender_changes"]: + logger.info( + f"\nGENDER CHANGES ({sum(self.stats['gender_changes'].values())} total):" + ) + for change, count in self.stats["gender_changes"].most_common(): logger.info(f" {change}: {count}x") - + # Age changes - if self.stats['age_changes']: - logger.info(f"\nAGE CHANGES 
({sum(self.stats['age_changes'].values())} total):") - for change, count in self.stats['age_changes'].most_common(): + if self.stats["age_changes"]: + logger.info( + f"\nAGE CHANGES ({sum(self.stats['age_changes'].values())} total):" + ) + for change, count in self.stats["age_changes"].most_common(): logger.info(f" {change}: {count}x") - + # Language changes - if self.stats['language_changes']: - logger.info(f"\nLANGUAGE CHANGES ({sum(self.stats['language_changes'].values())} total):") - for change, count in self.stats['language_changes'].most_common(): + if self.stats["language_changes"]: + logger.info( + f"\nLANGUAGE CHANGES ({sum(self.stats['language_changes'].values())} total):" + ) + for change, count in self.stats["language_changes"].most_common(): logger.info(f" {change}: {count}x") - + # Out of scope values (these need manual review) - logger.info(f"\n{'='*80}") + logger.info(f"\n{'=' * 80}") logger.info("OUT-OF-SCOPE VALUES (Need Manual Review)") - logger.info(f"{'='*80}") - - if self.stats['out_of_scope_race']: - logger.info(f"\nRACE/ETHNICITY values not in canonical list:") - for value, count in self.stats['out_of_scope_race'].most_common(): + logger.info(f"{'=' * 80}") + + if self.stats["out_of_scope_race"]: + logger.info("\nRACE/ETHNICITY values not in canonical list:") + for value, count in self.stats["out_of_scope_race"].most_common(): logger.info(f" '{value}': {count}x") else: - logger.info(f"\nโœ“ All race/ethnicity values are canonical") - - if self.stats['out_of_scope_gender']: - logger.info(f"\nGENDER values not in canonical list:") - for value, count in self.stats['out_of_scope_gender'].most_common(): + logger.info("\nโœ“ All race/ethnicity values are canonical") + + if self.stats["out_of_scope_gender"]: + logger.info("\nGENDER values not in canonical list:") + for value, count in self.stats["out_of_scope_gender"].most_common(): logger.info(f" '{value}': {count}x") else: - logger.info(f"\nโœ“ All gender values are canonical") - - if 
self.stats['out_of_scope_age']: - logger.info(f"\nAGE values not in canonical list:") - for value, count in self.stats['out_of_scope_age'].most_common(): + logger.info("\nโœ“ All gender values are canonical") + + if self.stats["out_of_scope_age"]: + logger.info("\nAGE values not in canonical list:") + for value, count in self.stats["out_of_scope_age"].most_common(): logger.info(f" '{value}': {count}x") else: - logger.info(f"\nโœ“ All age values are canonical") - - if self.stats['out_of_scope_language']: - logger.info(f"\nLANGUAGE values not in canonical list:") - for value, count in self.stats['out_of_scope_language'].most_common(): + logger.info("\nโœ“ All age values are canonical") + + if self.stats["out_of_scope_language"]: + logger.info("\nLANGUAGE values not in canonical list:") + for value, count in self.stats["out_of_scope_language"].most_common(): logger.info(f" '{value}': {count}x") else: - logger.info(f"\nโœ“ All language values are canonical") - + logger.info("\nโœ“ All language values are canonical") + # Summary - total_out_of_scope = (len(self.stats['out_of_scope_race']) + - len(self.stats['out_of_scope_gender']) + - len(self.stats['out_of_scope_age']) + - len(self.stats['out_of_scope_language'])) - - logger.info(f"\n{'='*80}") + total_out_of_scope = ( + len(self.stats["out_of_scope_race"]) + + len(self.stats["out_of_scope_gender"]) + + len(self.stats["out_of_scope_age"]) + + len(self.stats["out_of_scope_language"]) + ) + + logger.info(f"\n{'=' * 80}") if total_out_of_scope > 0: - logger.warning(f"โš  Found {total_out_of_scope} unique out-of-scope values requiring manual review") - logger.warning(f" Consider adding these to the mapping rules in this script") + logger.warning( + f"โš  Found {total_out_of_scope} unique out-of-scope values requiring manual review" + ) + logger.warning( + " Consider adding these to the mapping rules in this script" + ) else: - logger.info(f"โœ“ All demographic values conform to canonical categories!") + logger.info("โœ“ 
All demographic values conform to canonical categories!") + def main(): - """Main entry point""" - parser = argparse.ArgumentParser(description='Standardize Demographics in VQA Files') - parser.add_argument('--config', type=str, default='config/vqa_config.yaml', - help='Path to configuration file') - parser.add_argument('--topics', type=str, default=None, - help='Comma-separated topic IDs to process (e.g., "10,11")') - parser.add_argument('--dry-run', action='store_true', - help='Show what would be changed without making modifications') - + """Run main entry point.""" + parser = argparse.ArgumentParser( + description="Standardize Demographics in VQA Files" + ) + parser.add_argument( + "--config", + type=str, + default="vqa_config.yaml", + help="Path to configuration file", + ) + parser.add_argument( + "--topics", + type=str, + default=None, + help='Comma-separated topic IDs to process (e.g., "10,11")', + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Show what would be changed without making modifications", + ) + args = parser.parse_args() - + # Load config config_path = Path(args.config) if not config_path.exists(): print(f"Error: Config file not found: {config_path}") sys.exit(1) - - config = load_config(str(config_path)) - + + config = load_config(str(config_path), base_dir=Path(__file__).parent) + # Parse topic filter topic_filter = None if args.topics: try: - topic_filter = [int(t.strip()) for t in args.topics.split(',')] + topic_filter = [int(t.strip()) for t in args.topics.split(",")] logger.info(f"Processing topics: {topic_filter}") except ValueError: logger.error(f"Invalid topics format: {args.topics}") sys.exit(1) - + # Create standardizer and run standardizer = DemographicsStandardizer(config, dry_run=args.dry_run) - + if args.dry_run: logger.info("=" * 80) logger.info("DRY RUN MODE - No changes will be made") logger.info("=" * 80) - + standardizer.process_all_tasks(topic_filter) - + logger.info("\nโœ“ Done!") -if __name__ == 
'__main__': - main() \ No newline at end of file + +if __name__ == "__main__": + main() diff --git a/sonic-o1/04_vqa_generation/utils/__init__.py b/sonic-o1/04_vqa_generation/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sonic-o1/04_vqa_generation/utils/config_utils.py b/sonic-o1/04_vqa_generation/utils/config_utils.py new file mode 100644 index 0000000..d724b4c --- /dev/null +++ b/sonic-o1/04_vqa_generation/utils/config_utils.py @@ -0,0 +1,48 @@ +"""config_utils.py. + +Shared configuration loader for VQA generation scripts. + +Author: SONIC-O1 Team +""" + +from pathlib import Path + +import yaml + + +class Config: + """Configuration wrapper with nested attribute access.""" + + def __init__(self, config_dict): + """Initialize from nested dict. Recursively wraps dicts as Config.""" + for key, value in config_dict.items(): + if isinstance(value, dict): + setattr(self, key, Config(value)) + else: + setattr(self, key, value) + + +def load_config(config_path: str, base_dir: Path | None = None) -> Config: + """Load configuration from a YAML file. + + Args: + config_path: Path to the config file. If relative, resolved + relative to *base_dir* (or the current working directory when + *base_dir* is ``None``). + base_dir: Optional directory used to resolve relative paths. + When ``None``, ``Path.cwd()`` is used; pass + ``Path(__file__).parent`` for script-relative paths. + + Returns + ------- + Config object with nested attribute access. 
+ """ + config_file = Path(config_path) + if not config_file.is_absolute(): + if base_dir is None: + base_dir = Path.cwd() + config_file = base_dir / config_path + + with open(config_file, "r", encoding="utf-8") as f: + config_dict = yaml.safe_load(f) + return Config(config_dict) diff --git a/sonic-o1/04_vqa_generation/utils/demographics_expander.py b/sonic-o1/04_vqa_generation/utils/demographics_expander.py index 995c68d..de26379 100644 --- a/sonic-o1/04_vqa_generation/utils/demographics_expander.py +++ b/sonic-o1/04_vqa_generation/utils/demographics_expander.py @@ -1,64 +1,69 @@ +"""demographics_expander.py. + +Transform metadata demographics to VQA format with per-segment counting. +Uses human-reviewed demographics from metadata_enhanced.json. + +Author: SONIC-O1 Team """ -Demographics expansion utility - transform metadata format to VQA format -""" + +import json import logging -from typing import Dict, List, Any +from typing import Any, Dict, List + logger = logging.getLogger(__name__) class DemographicsExpander: - """ - Expand demographics from metadata format to VQA format with individual counts. - - Takes human-reviewed demographics from metadata_enhanced.json and transforms them - into the VQA format with individual counting per segment. - """ - + """Expand metadata demographics to VQA format with individual counts.""" + def __init__(self, config): - """ - Initialize expander with configuration. - + """Initialize expander with configuration. + Args: config: Configuration object """ self.config = config self.demographics_categories = config.demographics.categories - - def build_expansion_prompt(self, - metadata_demographics: Dict[str, List[str]], - segment_info: Dict = None) -> str: - """ - Build prompt for Gemini to expand demographics with counting. - + + def build_expansion_prompt( + self, + metadata_demographics: Dict[str, List[str]], + segment_info: Dict = None, + ) -> str: + """Build prompt for Gemini to expand demographics with counting. 
+ Args: - metadata_demographics: Demographics from metadata_enhanced.json - Format: {"race": ["Arab"], "gender": ["Male"], "age": ["Young (18-24)"], "language": ["English"]} - segment_info: Optional segment information (start, end times) - - Returns: + metadata_demographics: From metadata_enhanced.json; keys race, + gender, age, language; values lists of strings. + segment_info: Optional dict with start/end times. + + Returns + ------- Prompt string for Gemini """ - race_list = metadata_demographics.get('race', []) - gender_list = metadata_demographics.get('gender', []) - age_list = metadata_demographics.get('age', []) - language_list = metadata_demographics.get('language', []) - + race_list = metadata_demographics.get("race", []) + gender_list = metadata_demographics.get("gender", []) + age_list = metadata_demographics.get("age", []) + language_list = metadata_demographics.get("language", []) + segment_context = "" if segment_info: - segment_context = f""" -This is a SEGMENT of the full video from {segment_info['start']}s to {segment_info['end']}s. -Only count individuals visible/audible in THIS segment. -""" - - prompt = f"""You are analyzing video content for demographics annotation. + seg_start = segment_info["start"] + seg_end = segment_info["end"] + segment_context = ( + f"\nThis is a SEGMENT from {seg_start}s to {seg_end}s.\n" + "Only count individuals visible/audible in THIS segment.\n" + ) + + return f"""You are analyzing video content for demographics annotation. 
HUMAN-REVIEWED DEMOGRAPHICS (Ground Truth): The video contains individuals with these demographic characteristics: - - Race/Ethnicity: {', '.join(race_list) if race_list else 'Not specified'} - - Gender: {', '.join(gender_list) if gender_list else 'Not specified'} - - Age: {', '.join(age_list) if age_list else 'Not specified'} - - Language: {', '.join(language_list) if language_list else 'Not specified'} + - Race/Ethnicity: {", ".join(race_list) if race_list else "Not specified"} + - Gender: {", ".join(gender_list) if gender_list else "Not specified"} + - Age: {", ".join(age_list) if age_list else "Not specified"} + - Language: {", ".join(language_list) if language_list else "Not specified"} {segment_context} @@ -93,7 +98,7 @@ def build_expansion_prompt(self, "demographics": [ {{ "race": "Arab", - "gender": "Male", + "gender": "Male", "age": "Young (18-24)", "language": "English", "count": 2 @@ -101,7 +106,7 @@ def build_expansion_prompt(self, {{ "race": "White", "gender": "Female", - "age": "Middle (25-39)", + "age": "Middle (25-39)", "language": "English", "count": 1 }} @@ -121,25 +126,22 @@ def build_expansion_prompt(self, - Confidence should be 0.8+ if you're using the provided demographics correctly Begin analysis:""" - - return prompt - + def parse_demographics_response(self, response_text: str) -> Dict[str, Any]: """ Parse Gemini's response and validate demographics format. 
- + Args: response_text: JSON response from Gemini - - Returns: + + Returns + ------- Parsed and validated demographics dict """ - import json - try: # Clean response response_text = response_text.strip() - + # Remove markdown code blocks if present if "```json" in response_text: start = response_text.find("```json") + 7 @@ -151,108 +153,118 @@ def parse_demographics_response(self, response_text: str) -> Dict[str, Any]: end = response_text.rfind("```") if end > start: response_text = response_text[start:end] - + # Parse JSON data = json.loads(response_text.strip()) - + # Validate structure - if 'demographics' not in data: + if "demographics" not in data: logger.error("Missing 'demographics' field in response") return self._get_empty_demographics() - - if not isinstance(data['demographics'], list): + + if not isinstance(data["demographics"], list): logger.error("'demographics' field is not a list") return self._get_empty_demographics() - + # Validate each demographic entry validated_demographics = [] total_count = 0 unknown_count = 0 - - for entry in data['demographics']: + + for entry in data["demographics"]: if not isinstance(entry, dict): continue - + # Ensure required fields - if 'count' not in entry: - logger.warning(f"Missing 'count' in demographic entry: {entry}") + if "count" not in entry: + logger.warning("Missing 'count' in entry: %s", entry) continue - + # Convert count to int if needed try: - count = int(entry['count']) + count = int(entry["count"]) except (ValueError, TypeError): - logger.warning(f"Invalid count value: {entry.get('count')}") + logger.warning("Invalid count: %s", entry.get("count")) count = 0 - + # Check for "Unknown" usage - race = entry.get('race', 'Unknown') - gender = entry.get('gender', 'Unknown') - age = entry.get('age', 'Unknown') - language = entry.get('language', 'Unknown') - + race = entry.get("race", "Unknown") + gender = entry.get("gender", "Unknown") + age = entry.get("age", "Unknown") + language = entry.get("language", 
"Unknown") + # Count how many unknowns in this entry - if race == 'Unknown' or gender == 'Unknown' or age == 'Unknown' or language == 'Unknown': + if ( + race == "Unknown" + or gender == "Unknown" + or age == "Unknown" + or language == "Unknown" + ): unknown_count += 1 - logger.warning(f"Entry contains 'Unknown' values: {entry}") - + logger.warning("Entry contains Unknown: %s", entry) + validated_entry = { - 'race': race, - 'gender': gender, - 'age': age, - 'language': language, - 'count': count + "race": race, + "gender": gender, + "age": age, + "language": language, + "count": count, } - + validated_demographics.append(validated_entry) total_count += count - + # Log warning if too many unknowns if unknown_count > 0: - logger.warning(f"{unknown_count} demographic entries contain 'Unknown' values - model may be too conservative") - - confidence = float(data.get('confidence', 0.0)) - + logger.warning( + "%d entries contain Unknown - model may be conservative", + unknown_count, + ) + + confidence = float(data.get("confidence", 0.0)) + # Reduce confidence if too many unknowns if unknown_count > len(validated_demographics) / 2: - logger.warning("More than 50% of entries have Unknown values, reducing confidence") + logger.warning(">50%% entries have Unknown, reducing confidence") confidence *= 0.5 - + return { - 'demographics': validated_demographics, - 'total_individuals': data.get('total_individuals', total_count), - 'confidence': confidence, - 'explanation': data.get('explanation', '') + "demographics": validated_demographics, + "total_individuals": data.get("total_individuals", total_count), + "confidence": confidence, + "explanation": data.get("explanation", ""), } - + except json.JSONDecodeError as e: - logger.error(f"Failed to parse demographics JSON: {e}") - logger.debug(f"Response text: {response_text[:500]}...") + logger.error("Failed to parse demographics JSON: %s", e) + logger.debug("Response: %s...", response_text[:500]) return 
self._get_empty_demographics() except Exception as e: - logger.error(f"Error parsing demographics response: {e}") + logger.error("Error parsing demographics: %s", e) return self._get_empty_demographics() - + def _get_empty_demographics(self) -> Dict[str, Any]: - """Return empty demographics structure""" + """Return empty demographics structure.""" return { - 'demographics': [], - 'total_individuals': 0, - 'confidence': 0.0, - 'explanation': 'Failed to parse demographics' + "demographics": [], + "total_individuals": 0, + "confidence": 0.0, + "explanation": "Failed to parse demographics", } - - def merge_segment_demographics(self, - segment_demographics: List[Dict[str, Any]]) -> Dict[str, Any]: + + def merge_segment_demographics( + self, segment_demographics: List[Dict[str, Any]] + ) -> Dict[str, Any]: """ Merge demographics from multiple segments for video-level summary. - + Uses maximum count seen across all segments (conservative approach). - + Args: segment_demographics: List of demographics dicts from each segment - - Returns: + + Returns + ------- Merged demographics dict """ # Track unique demographic combinations and their max counts @@ -260,40 +272,44 @@ def merge_segment_demographics(self, total_max = 0 all_explanations = [] min_confidence = 1.0 - + for seg_demo in segment_demographics: - for entry in seg_demo.get('demographics', []): + for entry in seg_demo.get("demographics", []): # Create key from demographic attributes key = ( - entry.get('race', 'Unknown'), - entry.get('gender', 'Unknown'), - entry.get('age', 'Unknown'), - entry.get('language', 'Unknown') + entry.get("race", "Unknown"), + entry.get("gender", "Unknown"), + entry.get("age", "Unknown"), + entry.get("language", "Unknown"), ) - - count = entry.get('count', 0) - - # Keep maximum count seen for this combination - if key not in demographic_map or count > demographic_map[key]['count']: + + count = entry.get("count", 0) + + # Keep maximum count for this combination + if key not in 
demographic_map or count > demographic_map[key].get( + "count", 0 + ): demographic_map[key] = entry.copy() - + # Track total - total_max = max(total_max, seg_demo.get('total_individuals', 0)) - + total_max = max(total_max, seg_demo.get("total_individuals", 0)) + # Collect explanations - if seg_demo.get('explanation'): - all_explanations.append(seg_demo['explanation']) - + if seg_demo.get("explanation"): + all_explanations.append(seg_demo["explanation"]) + # Track minimum confidence - min_confidence = min(min_confidence, seg_demo.get('confidence', 1.0)) - + min_confidence = min(min_confidence, seg_demo.get("confidence", 1.0)) + # Convert back to list merged_demographics = list(demographic_map.values()) - + return { - 'demographics': merged_demographics, - 'total_individuals': total_max, - 'confidence': min_confidence, - 'explanation': f"Merged from {len(segment_demographics)} segments. " + - " | ".join(all_explanations[:2]) - } \ No newline at end of file + "demographics": merged_demographics, + "total_individuals": total_max, + "confidence": min_confidence, + "explanation": ( + "Merged from %d segments. " % len(segment_demographics) + + " | ".join(all_explanations[:2]) + ), + } diff --git a/sonic-o1/04_vqa_generation/utils/file_utils.py b/sonic-o1/04_vqa_generation/utils/file_utils.py new file mode 100644 index 0000000..f3172f2 --- /dev/null +++ b/sonic-o1/04_vqa_generation/utils/file_utils.py @@ -0,0 +1,44 @@ +"""file_utils.py. + +Shared file I/O helpers for VQA generation scripts. + +Author: SONIC-O1 Team +""" + +import json +import logging +from pathlib import Path +from typing import Any, Dict + + +logger = logging.getLogger(__name__) + + +def save_json_with_backup( + data: Dict[str, Any], + json_path: Path, + backup_suffix: str = ".json.backup", +) -> None: + """Write a backup of *json_path* then overwrite it with *data*. + + Args: + data: JSON-serialisable data to write. + json_path: Destination path (original file to overwrite). 
+ backup_suffix: Suffix appended to *json_path* for the backup file. + Defaults to ".json.backup". + + Raises + ------ + Exception: Re-raises any I/O error after logging it. + """ + try: + backup_path = json_path.with_suffix(backup_suffix) + with open(backup_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + logger.info("Created backup: %s", backup_path.name) + + with open(json_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + except Exception as e: + logger.error("Failed to save %s: %s", json_path, e) + raise diff --git a/sonic-o1/04_vqa_generation/utils/frame_sampler.py b/sonic-o1/04_vqa_generation/utils/frame_sampler.py index 4345d40..c8fcb18 100644 --- a/sonic-o1/04_vqa_generation/utils/frame_sampler.py +++ b/sonic-o1/04_vqa_generation/utils/frame_sampler.py @@ -1,19 +1,24 @@ -""" -Frame Sampler Utility +"""frame_sampler.py. + +Extract sample frames from video segments for GPT-4V validation. +Uses PyAV for frame extraction (faster than FFmpeg subprocess). -Extracts sample frames from video segments for GPT-4V validation. -Uses PyAV for direct video frame extraction (faster and more reliable than FFmpeg subprocess). +Author: SONIC-O1 Team """ + import logging -from pathlib import Path -from typing import List, Dict, Optional, Tuple -import tempfile import shutil +import tempfile +import traceback +from pathlib import Path +from typing import List, Tuple + logger = logging.getLogger(__name__) try: import av + PYAV_AVAILABLE = True except ImportError: PYAV_AVAILABLE = False @@ -21,333 +26,370 @@ class FrameSampler: - """Sample frames from video segments for visual validation""" - + """Sample frames from video segments for visual validation.""" + def __init__(self, config=None): """ Initialize frame sampler. - + Args: config: Optional configuration object """ if not PYAV_AVAILABLE: raise ImportError("PyAV is required. 
Install with: pip install av") - + self.config = config # Use scratch directory like video segmenter - scratch_base = Path.home() / 'scratch' / 'frame_sampler' + scratch_base = Path.home() / "scratch" / "frame_sampler" scratch_base.mkdir(parents=True, exist_ok=True) - self.temp_dir = Path(tempfile.mkdtemp(prefix='frames_', dir=scratch_base)) + self.temp_dir = Path(tempfile.mkdtemp(prefix="frames_", dir=scratch_base)) self._cleaned_up = False - logger.info(f"Frame sampler temporary directory: {self.temp_dir}") - + logger.info("Frame sampler temp dir: %s", self.temp_dir) + def sample_frames_from_segment( self, video_path: Path, segment_start: float, segment_end: float, num_frames: int = 8, - strategy: str = 'uniform' + strategy: str = "uniform", ) -> List[Path]: """ Sample frames from a video segment. - + Args: video_path: Path to video file segment_start: Start time of segment in seconds segment_end: End time of segment in seconds num_frames: Number of frames to sample strategy: Sampling strategy ('uniform', 'keyframes', or 'adaptive') - - Returns: - List of paths to extracted frame images + + Returns + ------- + list of Path + Paths to extracted frame images. 
""" try: - if strategy == 'uniform': + if strategy == "uniform": return self._sample_uniform_frames( video_path, segment_start, segment_end, num_frames ) - elif strategy == 'keyframes': + if strategy == "keyframes": return self._sample_keyframes( video_path, segment_start, segment_end, num_frames ) - elif strategy == 'adaptive': + if strategy == "adaptive": return self._sample_adaptive_frames( video_path, segment_start, segment_end, num_frames ) - else: - raise ValueError(f"Unknown sampling strategy: {strategy}") - + raise ValueError("Unknown sampling strategy: %s" % strategy) + except Exception as e: - logger.error(f"Error sampling frames: {e}") + logger.error("Error sampling frames: %s", e) return [] - + def _sample_uniform_frames( self, video_path: Path, segment_start: float, segment_end: float, - num_frames: int + num_frames: int, ) -> List[Path]: - """ - Sample frames uniformly across the segment using PyAV. - """ + """Sample frames uniformly across the segment using PyAV.""" frame_paths = [] - + # Ensure temp directory exists if not self.temp_dir.exists(): self.temp_dir.mkdir(parents=True, exist_ok=True) - logger.warning(f"Temp dir didn't exist, recreated: {self.temp_dir}") - + logger.warning("Temp dir recreated: %s", self.temp_dir) + # Add epsilon buffer epsilon = 0.1 safe_end = max(segment_start, segment_end - epsilon) safe_duration = safe_end - segment_start - + # Calculate timestamps if num_frames == 1: timestamps = [segment_start + safe_duration / 2] else: interval = safe_duration / (num_frames - 1) timestamps = [segment_start + i * interval for i in range(num_frames)] - + try: container = av.open(str(video_path)) video_stream = container.streams.video[0] time_base = video_stream.time_base - + for i, timestamp in enumerate(timestamps): try: - frame_path = self.temp_dir / f"frame_{i:03d}_t{timestamp:.2f}s.jpg" - + frame_path = ( + self.temp_dir / ("frame_%03d_t%.2fs.jpg" % (i, timestamp)) + ) + # Convert timestamp to PTS pts = int(timestamp / 
float(time_base)) - + # Seek to timestamp container.seek(pts, stream=video_stream) - + # Decode next frame frame_found = False for frame in container.decode(video=0): img = frame.to_image() - img.save(str(frame_path), 'JPEG', quality=95) + img.save(str(frame_path), "JPEG", quality=95) if frame_path.exists(): frame_paths.append(frame_path) - logger.debug(f"Extracted frame at {timestamp:.2f}s") + logger.debug("Extracted frame at %.2fs", timestamp) frame_found = True else: - logger.warning(f"Frame saved but file doesn't exist: {frame_path}") + logger.warning("Frame saved but missing: %s", frame_path) break - + if not frame_found: - logger.warning(f"No frame decoded at {timestamp:.2f}s") - + logger.warning("No frame at %.2fs", timestamp) + except Exception as e: - logger.warning(f"Failed to extract frame at {timestamp:.2f}s: {e}") - import traceback - logger.warning(f"Traceback: {traceback.format_exc()}") + logger.warning("Failed frame at %.2fs: %s", timestamp, e) + logger.warning("Traceback: %s", traceback.format_exc()) continue - + container.close() - + except Exception as e: - logger.error(f"Error opening video: {e}") + logger.error("Error opening video: %s", e) return [] - - logger.info(f"Extracted {len(frame_paths)}/{num_frames} uniform frames") + + logger.info( + "Extracted %d/%d uniform frames", + len(frame_paths), + num_frames, + ) return frame_paths - + def _sample_keyframes( self, video_path: Path, segment_start: float, segment_end: float, - num_frames: int + num_frames: int, ) -> List[Path]: - """ - Sample keyframes (I-frames) from the segment using PyAV. 
- """ + """Sample keyframes (I-frames) from the segment using PyAV.""" frame_paths = [] - + try: container = av.open(str(video_path)) video_stream = container.streams.video[0] time_base = float(video_stream.time_base) - + # Convert to PTS start_pts = int(segment_start / time_base) - end_pts = int(segment_end / time_base) - + int(segment_end / time_base) + container.seek(start_pts, stream=video_stream) - + keyframe_count = 0 for frame in container.decode(video=0): frame_time = frame.pts * time_base - + if frame_time > segment_end: break if frame_time < segment_start: continue - + # Only keyframes if frame.key_frame: - frame_path = self.temp_dir / f"keyframe_{keyframe_count:03d}_t{frame_time:.2f}s.jpg" + name = "keyframe_%03d_t%.2fs.jpg" % (keyframe_count, frame_time) + frame_path = self.temp_dir / name img = frame.to_image() img.save(str(frame_path), quality=95) frame_paths.append(frame_path) keyframe_count += 1 - + if keyframe_count >= num_frames: break - + container.close() - - logger.info(f"Extracted {len(frame_paths)} keyframes") - + + logger.info("Extracted %d keyframes", len(frame_paths)) + # Supplement if needed if len(frame_paths) < num_frames: - logger.info(f"Supplementing with uniform frames") + logger.info("Supplementing with uniform frames") uniform_frames = self._sample_uniform_frames( - video_path, segment_start, segment_end, num_frames - len(frame_paths) + video_path, + segment_start, + segment_end, + num_frames - len(frame_paths), ) frame_paths.extend(uniform_frames) - + except Exception as e: - logger.error(f"Error extracting keyframes: {e}") - return self._sample_uniform_frames(video_path, segment_start, segment_end, num_frames) - + logger.error("Error extracting keyframes: %s", e) + return self._sample_uniform_frames( + video_path, segment_start, segment_end, num_frames + ) + return frame_paths[:num_frames] - + def _sample_adaptive_frames( self, video_path: Path, segment_start: float, segment_end: float, - num_frames: int + num_frames: int, ) -> 
List[Path]: - """ - Adaptive sampling: denser at start/end, sparse in middle. - """ + """Adaptive sampling: denser at start/end, sparse in middle.""" epsilon = 0.1 safe_end = max(segment_start, segment_end - epsilon) - + # Adaptive distribution num_start = max(2, int(num_frames * 0.3)) num_end = max(2, int(num_frames * 0.3)) num_middle = num_frames - num_start - num_end - + timestamps = [] - + # Start frames start_zone = (safe_end - segment_start) * 0.2 for i in range(num_start): - t = segment_start + (i / (num_start - 1) if num_start > 1 else 0.5) * start_zone + t = ( + segment_start + + (i / (num_start - 1) if num_start > 1 else 0.5) * start_zone + ) timestamps.append(t) - + # Middle frames middle_start = segment_start + start_zone middle_end = safe_end - start_zone middle_duration = middle_end - middle_start for i in range(num_middle): - t = middle_start + (i / (num_middle - 1) if num_middle > 1 else 0.5) * middle_duration + t = ( + middle_start + + (i / (num_middle - 1) if num_middle > 1 else 0.5) * middle_duration + ) timestamps.append(t) - + # End frames end_zone_start = safe_end - start_zone for i in range(num_end): - t = end_zone_start + (i / (num_end - 1) if num_end > 1 else 0.5) * start_zone + t = ( + end_zone_start + + (i / (num_end - 1) if num_end > 1 else 0.5) * start_zone + ) timestamps.append(t) - + # Extract frames frame_paths = [] - + try: container = av.open(str(video_path)) video_stream = container.streams.video[0] time_base = video_stream.time_base - + for i, timestamp in enumerate(sorted(timestamps)): try: - frame_path = self.temp_dir / f"adaptive_frame_{i:03d}_t{timestamp:.2f}s.jpg" - + frame_name = ( + "adaptive_frame_%03d_t%.2fs.jpg" + % (i, timestamp) + ) + frame_path = self.temp_dir / frame_name + pts = int(timestamp / float(time_base)) container.seek(pts, stream=video_stream) - + for frame in container.decode(video=0): img = frame.to_image() img.save(str(frame_path), quality=95) frame_paths.append(frame_path) break - + except Exception as e: - 
logger.error(f"Error at {timestamp:.2f}s: {e}") + logger.error("Error at %.2fs: %s", timestamp, e) continue - + container.close() - + except Exception as e: - logger.error(f"Error with adaptive sampling: {e}") + logger.error("Error adaptive sampling: %s", e) return [] - - logger.info(f"Extracted {len(frame_paths)} adaptive frames") + + logger.info("Extracted %d adaptive frames", len(frame_paths)) return frame_paths - + def sample_frames_at_timestamps( self, video_path: Path, timestamps: List[float], - segment_start: float = 0.0 + segment_start: float = 0.0, ) -> List[Tuple[float, Path]]: """ Sample frames at specific timestamps using PyAV. + + Args + ---- + video_path : Path + Path to video file. + timestamps : list of float + Timestamps in seconds. + segment_start : float + Segment start (for relative naming). + + Returns + ------- + list of (float, Path) + (timestamp, frame_path) pairs. """ frame_data = [] - + try: container = av.open(str(video_path)) video_stream = container.streams.video[0] time_base = video_stream.time_base - + for timestamp in timestamps: if timestamp is None: continue - + try: relative_time = timestamp - segment_start - frame_path = self.temp_dir / f"verify_t{timestamp:.2f}s_rel{relative_time:.2f}s.jpg" - + frame_name = ( + "verify_t%.2fs_rel%.2fs.jpg" + % (timestamp, relative_time) + ) + frame_path = self.temp_dir / frame_name + pts = int(timestamp / float(time_base)) container.seek(pts, stream=video_stream) - + for frame in container.decode(video=0): img = frame.to_image() img.save(str(frame_path), quality=95) frame_data.append((timestamp, frame_path)) - logger.debug(f"Extracted frame at {timestamp:.2f}s") + logger.debug("Extracted frame at %.2fs", timestamp) break - + except Exception as e: - logger.error(f"Error at {timestamp:.2f}s: {e}") + logger.error("Error at %.2fs: %s", timestamp, e) continue - + container.close() - + except Exception as e: - logger.error(f"Error sampling timestamps: {e}") + logger.error("Error sampling timestamps: %s", e) return [] - - 
logger.info(f"Extracted {len(frame_data)} verification frames") + + logger.info("Extracted %d verification frames", len(frame_data)) return frame_data - + def cleanup(self): - """Clean up temporary frame files""" + """Clean up temporary frame files.""" if self._cleaned_up: return - + try: if self.temp_dir and self.temp_dir.exists(): shutil.rmtree(self.temp_dir) - logger.info(f"Cleaned up frame sampler temp directory: {self.temp_dir}") + logger.info("Cleaned up frame sampler temp: %s", self.temp_dir) self._cleaned_up = True except Exception as e: - logger.warning(f"Failed to cleanup frame sampler temp dir: {e}") + logger.warning("Failed to cleanup frame sampler: %s", e) diff --git a/sonic-o1/04_vqa_generation/utils/video_segmenter.py b/sonic-o1/04_vqa_generation/utils/video_segmenter.py index d7659af..be5a341 100644 --- a/sonic-o1/04_vqa_generation/utils/video_segmenter.py +++ b/sonic-o1/04_vqa_generation/utils/video_segmenter.py @@ -1,384 +1,560 @@ """ -Video segmentation utility using FFmpeg +video_segmenter.py. + +Video segmentation utility using FFmpeg. + +Author: SONIC-O1 Team """ -import subprocess -import tempfile -import shutil -import re + import logging -from pathlib import Path -from typing import List, Dict, Optional +import re +import shutil +import subprocess import time +from pathlib import Path +from typing import Dict, List, Optional, Set + logger = logging.getLogger(__name__) +def _default_temp_segment_dir(task_type: str, kind: str) -> Path: + """ + Project-local temp under sonic-o1/.tmp_video_segments/ (removed after use via cleanup_segments). + + `kind` is "video" or "audio" so parallel calls never collide. 
+ """ + # .../04_vqa_generation/utils/video_segmenter.py -> sonic-o1 + sonic_o1 = Path(__file__).resolve().parent.parent.parent + root = sonic_o1 / ".tmp_video_segments" + sub = root / f"{task_type}_{kind}_{time.time_ns()}" + sub.mkdir(parents=True, exist_ok=True) + return sub + + class VideoSegmenter: - """Handle video segmentation using FFmpeg""" - + """Handle video segmentation using FFmpeg.""" + def __init__(self, config): """ Initialize segmenter with configuration. - + Args: config: Configuration object with video settings """ - self.summarization_segment_duration = int(config.video.summarization_segment_duration) + self.summarization_segment_duration = int( + config.video.summarization_segment_duration + ) self.mcq_segment_duration = int(config.video.mcq_segment_duration) - self.temporal_localization_segment_duration = int(config.video.temporal_localization_segment_duration) + self.temporal_localization_segment_duration = int( + config.video.temporal_localization_segment_duration + ) self.segment_overlap = int(config.video.segment_overlap) - + @staticmethod def get_actual_duration(video_path: Path) -> float: """ Get actual video duration using multiple methods. - - FIXED: Try stream duration first, then format duration. + + Tries stream duration first, then format duration. Stream duration is more reliable for videos with metadata issues. + + Returns + ------- + float + Duration in seconds. + + Raises + ------ + Exception + If duration cannot be obtained from the video. 
""" # Method 1: Try stream duration (more reliable) cmd = [ - "ffprobe", "-v", "error", - "-select_streams", "v:0", - "-show_entries", "stream=duration", - "-of", "default=nw=1:nk=1", + "ffprobe", + "-v", + "error", + "-select_streams", + "v:0", + "-show_entries", + "stream=duration", + "-of", + "default=nw=1:nk=1", str(video_path), ] - + try: result = subprocess.run( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, - timeout=10 + timeout=10, + check=False, ) - + if result.returncode == 0: output = result.stdout.strip() - if output and output != 'N/A': + if output and output != "N/A": try: duration = float(output) if duration > 0: - logger.debug(f"Got stream duration: {duration:.3f}s") + logger.debug("Got stream duration: %.3fs", duration) return duration except ValueError: pass except Exception as e: - logger.debug(f"Stream duration failed: {e}") - + logger.debug("Stream duration failed: %s", e) + # Method 2: Try format duration (fallback) cmd = [ - "ffprobe", "-v", "error", - "-show_entries", "format=duration", - "-of", "default=nw=1:nk=1", + "ffprobe", + "-v", + "error", + "-show_entries", + "format=duration", + "-of", + "default=nw=1:nk=1", str(video_path), ] - + try: result = subprocess.run( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, - timeout=10 + timeout=10, + check=False, ) - + if result.returncode == 0: output = result.stdout.strip() if output: try: duration = float(output) if duration > 0: - logger.debug(f"Got format duration: {duration:.3f}s") + logger.debug("Got format duration: %.3fs", duration) return duration except ValueError: pass except Exception as e: - logger.debug(f"Format duration failed: {e}") - - raise Exception(f"Could not get duration from {video_path}") - - def segment_video(self, - video_path: Path, - duration_seconds: float, - task_type: str = 'summarization', - output_dir: 
Optional[Path] = None) -> List[Dict]: + logger.debug("Format duration failed: %s", e) + + raise Exception("Could not get duration from %s" % video_path) + + def segment_video( + self, + video_path: Path, + duration_seconds: float, + task_type: str = "summarization", + output_dir: Optional[Path] = None, + ) -> List[Dict]: """ Segment video into chunks based on task type. + + Args + ---- + video_path : Path + Path to the video file. + duration_seconds : float + Duration in seconds (from metadata; ffprobe may override). + task_type : str + One of "summarization", "mcq", "temporal_localization". + output_dir : Path, optional + Directory for segment files; created if None. + + Returns + ------- + list of dict + Each dict has segment_path, start, end, duration, + segment_number, and optionally is_temp. """ - try: actual_duration = self.get_actual_duration(video_path) - - # FIXED: Detect severe duration mismatch (likely corrupted metadata) - duration_ratio = actual_duration / duration_seconds if duration_seconds > 0 else 999 - + + # Detect severe duration mismatch (likely corrupted metadata) + duration_ratio = ( + actual_duration / duration_seconds if duration_seconds > 0 else 999 + ) + if duration_ratio > 5 or duration_ratio < 0.2: logger.error( - f"SEVERE duration mismatch for {video_path}: " - f"metadata={duration_seconds:.1f}s, ffprobe={actual_duration:.1f}s (ratio={duration_ratio:.1f}x). " - f"Video metadata is likely corrupted. Using metadata value to be safe." + "SEVERE duration mismatch for %s: metadata=%.1fs, " + "ffprobe=%.1fs (ratio=%.1fx). Video metadata likely " + "corrupted. Using metadata value.", + video_path, + duration_seconds, + actual_duration, + duration_ratio, ) - # Don't trust ffprobe when ratio is extreme actual_duration = duration_seconds elif abs(actual_duration - duration_seconds) > 0.5: logger.warning( - f"Duration mismatch for {video_path}: " - f"metadata={duration_seconds:.3f}s, ffprobe={actual_duration:.3f}s. " - f"Using ffprobe value." 
+ "Duration mismatch for %s: metadata=%.3fs, " + "ffprobe=%.3fs. Using ffprobe value.", + video_path, + duration_seconds, + actual_duration, ) - + duration_seconds = actual_duration except Exception as e: logger.warning( - f"Could not get actual duration with ffprobe for {video_path}, " - f"falling back to provided {duration_seconds:.3f}s. Error: {e}" + "Could not get actual duration with ffprobe for %s, " + "falling back to %.3fs. Error: %s", + video_path, + duration_seconds, + e, ) # Small epsilon to avoid sampling exactly at the end epsilon = 0.05 duration_seconds = max(0.0, duration_seconds - epsilon) - if task_type == 'summarization': + if task_type == "summarization": max_segment_duration = self.summarization_segment_duration - elif task_type == 'mcq': + elif task_type == "mcq": max_segment_duration = self.mcq_segment_duration - elif task_type == 'temporal_localization': + elif task_type == "temporal_localization": max_segment_duration = self.temporal_localization_segment_duration else: max_segment_duration = self.mcq_segment_duration - logger.warning(f"Unknown task_type '{task_type}', defaulting to MCQ segment duration") - + logger.warning( + "Unknown task_type '%s', defaulting to MCQ segment duration", + task_type, + ) + if duration_seconds <= max_segment_duration: logger.info( - f"Video duration ({duration_seconds:.3f}s) <= max segment ({max_segment_duration}s) " - f"for {task_type}, returning as single segment" + "Video duration (%.3fs) <= max segment (%s)s for %s, " + "returning as single segment", + duration_seconds, + max_segment_duration, + task_type, ) - return [{ - 'segment_path': video_path, - 'start': 0.0, - 'end': duration_seconds, - 'duration': duration_seconds, - 'segment_number': 0 - }] - + return [ + { + "segment_path": video_path, + "start": 0.0, + "end": duration_seconds, + "duration": duration_seconds, + "segment_number": 0, + } + ] + if output_dir is None: - output_dir = Path.home() / 'scratch' / 'video_segments' / 
f"{task_type}_{int(time.time())}" - output_dir.mkdir(parents=True, exist_ok=True) + output_dir = _default_temp_segment_dir(task_type, "video") temp_dir = output_dir else: output_dir.mkdir(parents=True, exist_ok=True) temp_dir = None - + num_segments = int(duration_seconds / max_segment_duration) + 1 logger.info( - f"Segmenting {duration_seconds:.3f}s video for {task_type} into {num_segments} chunks " - f"(max {max_segment_duration}s each with {self.segment_overlap}s overlap)" + "Segmenting %.3fs video for %s into %d chunks " + "(max %ss each with %ss overlap)", + duration_seconds, + task_type, + num_segments, + max_segment_duration, + self.segment_overlap, ) - + segments = [] - + try: for i in range(num_segments): - start_time = max(0, i * max_segment_duration - (self.segment_overlap if i > 0 else 0)) + start_time = max( + 0, + i * max_segment_duration - (self.segment_overlap if i > 0 else 0), + ) segment_duration = min( max_segment_duration + self.segment_overlap, - duration_seconds - start_time + duration_seconds - start_time, ) if segment_duration <= 0: break end_time = start_time + segment_duration - - segment_path = output_dir / f"segment_{i:03d}{video_path.suffix}" - - logger.info(f"Creating segment {i+1}/{num_segments}: {start_time:.1f}s - {end_time:.1f}s") - + + segment_path = output_dir / ("segment_%03d%s" % (i, video_path.suffix)) + + logger.info( + "Creating segment %d/%d: %.1fs - %.1fs", + i + 1, + num_segments, + start_time, + end_time, + ) + # FIXED: Use copy codec when possible (much faster) # Calculate timeout based on segment duration (2x for safety) timeout = max(300, int(segment_duration * 2)) - + cmd = [ - 'ffmpeg', '-y', - '-ss', str(start_time), - '-i', str(video_path), - '-t', str(segment_duration), - '-c', 'copy', # FIXED: Copy codec (no re-encoding) - '-avoid_negative_ts', 'make_zero', # FIXED: Better timestamp handling - str(segment_path) + "ffmpeg", + "-y", + "-ss", + str(start_time), + "-i", + str(video_path), + "-t", + 
str(segment_duration), + "-c", + "copy", # FIXED: Copy codec (no re-encoding) + "-avoid_negative_ts", + "make_zero", # FIXED: Better timestamp handling + str(segment_path), ] - + result = subprocess.run( - cmd, - capture_output=True, + cmd, + capture_output=True, text=True, - timeout=timeout + timeout=timeout, + check=False, ) - + if result.returncode != 0: - logger.error(f"FFmpeg error: {result.stderr}") - raise Exception(f"Failed to create segment {i}: {result.stderr}") - + logger.error("FFmpeg error: %s", result.stderr) + raise Exception( + "Failed to create segment %d: %s" % (i, result.stderr) + ) + if not segment_path.exists(): - raise Exception(f"Segment file not created: {segment_path}") - - segments.append({ - 'segment_path': segment_path, - 'start': start_time, - 'end': end_time, - 'duration': segment_duration, - 'segment_number': i, - 'is_temp': temp_dir is not None - }) - - logger.info(f"Successfully created {len(segments)} segments for {task_type}") + raise Exception("Segment file not created: %s" % segment_path) + + segments.append( + { + "segment_path": segment_path, + "start": start_time, + "end": end_time, + "duration": segment_duration, + "segment_number": i, + "is_temp": temp_dir is not None, + } + ) + + logger.info( + "Successfully created %d segments for %s", + len(segments), + task_type, + ) return segments - + except Exception as e: - logger.error(f"Error during segmentation: {e}") + logger.error("Error during segmentation: %s", e) if temp_dir and temp_dir.exists(): shutil.rmtree(temp_dir) raise - - def segment_audio(self, - audio_path: Path, - duration_seconds: float, - task_type: str = 'summarization', - output_dir: Optional[Path] = None) -> List[Dict]: + def segment_audio( + self, + audio_path: Path, + duration_seconds: float, + task_type: str = "summarization", + output_dir: Optional[Path] = None, + ) -> List[Dict]: """ Segment audio file into chunks based on task type. + + Args + ---- + audio_path : Path + Path to the audio file. 
+ duration_seconds : float + Duration in seconds. + task_type : str + One of "summarization", "mcq", "temporal_localization". + output_dir : Path, optional + Directory for segment files; created if None. + + Returns + ------- + list of dict + Each dict has segment_path, start, end, duration, + segment_number, and optionally is_temp. """ - if task_type == 'summarization': + if task_type == "summarization": max_segment_duration = self.summarization_segment_duration - elif task_type == 'mcq': + elif task_type == "mcq": max_segment_duration = self.mcq_segment_duration - elif task_type == 'temporal_localization': + elif task_type == "temporal_localization": max_segment_duration = self.temporal_localization_segment_duration else: max_segment_duration = self.mcq_segment_duration - logger.warning(f"Unknown task_type '{task_type}', defaulting to MCQ segment duration") - - + logger.warning( + "Unknown task_type '%s', defaulting to MCQ segment duration", + task_type, + ) + if duration_seconds <= max_segment_duration: - return [{ - 'segment_path': audio_path, - 'start': 0, - 'end': duration_seconds, - 'duration': duration_seconds, - 'segment_number': 0 - }] - + return [ + { + "segment_path": audio_path, + "start": 0, + "end": duration_seconds, + "duration": duration_seconds, + "segment_number": 0, + } + ] + if output_dir is None: - output_dir = Path.home() / 'scratch' / 'audio_segments' / f"{task_type}_{int(time.time())}" - output_dir.mkdir(parents=True, exist_ok=True) + output_dir = _default_temp_segment_dir(task_type, "audio") temp_dir = output_dir else: output_dir.mkdir(parents=True, exist_ok=True) temp_dir = None num_segments = int(duration_seconds / max_segment_duration) + 1 segments = [] - + try: for i in range(num_segments): - start_time = max(0, i * max_segment_duration - (self.segment_overlap if i > 0 else 0)) + start_time = max( + 0, + i * max_segment_duration - (self.segment_overlap if i > 0 else 0), + ) segment_duration = min( max_segment_duration + 
self.segment_overlap, - duration_seconds - start_time + duration_seconds - start_time, ) - - segment_path = output_dir / f"segment_{i:03d}{audio_path.suffix}" - + + segment_path = output_dir / ("segment_%03d%s" % (i, audio_path.suffix)) + cmd = [ - 'ffmpeg', '-y', - '-ss', str(start_time), - '-i', str(audio_path), - '-t', str(segment_duration), - '-c', 'copy', - str(segment_path) + "ffmpeg", + "-y", + "-ss", + str(start_time), + "-i", + str(audio_path), + "-t", + str(segment_duration), + "-c", + "copy", + str(segment_path), ] - - result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) - + + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=300, + check=False, + ) + if result.returncode != 0: - logger.error(f"FFmpeg audio segment error: {result.stderr}") - raise Exception(f"Failed to create audio segment {i}") - - segments.append({ - 'segment_path': segment_path, - 'start': start_time, - 'end': start_time + segment_duration, - 'duration': segment_duration, - 'segment_number': i, - 'is_temp': temp_dir is not None - }) - - logger.info(f"Successfully created {len(segments)} audio segments for {task_type}") + logger.error("FFmpeg audio segment error: %s", result.stderr) + raise Exception("Failed to create audio segment %d" % i) + + segments.append( + { + "segment_path": segment_path, + "start": start_time, + "end": start_time + segment_duration, + "duration": segment_duration, + "segment_number": i, + "is_temp": temp_dir is not None, + } + ) + + logger.info( + "Successfully created %d audio segments for %s", + len(segments), + task_type, + ) return segments - - except Exception as e: + + except Exception: if temp_dir and temp_dir.exists(): shutil.rmtree(temp_dir) raise - - def extract_transcript_segment(self, - transcript_path: Path, - start_time: float, - end_time: float, - strip_timestamps: bool = False) -> str: + + def extract_transcript_segment( + self, + transcript_path: Path, + start_time: float, + end_time: float, + 
strip_timestamps: bool = False, + ) -> str: """ Extract portion of SRT transcript for a time segment. + + Args + ---- + transcript_path : Path + Path to SRT file. + start_time : float + Segment start time in seconds. + end_time : float + Segment end time in seconds. + strip_timestamps : bool + If True, return plain text; else keep SRT blocks. + + Returns + ------- + str + Extracted transcript text or empty string on error. """ try: - with open(transcript_path, 'r', encoding='utf-8') as f: + with open(transcript_path, "r", encoding="utf-8") as f: content = f.read() - - segments = content.strip().split('\n\n') + + segments = content.strip().split("\n\n") extracted = [] - + for segment in segments: - lines = segment.split('\n') + lines = segment.split("\n") if len(lines) < 3: continue - - timestamp_pattern = r'(\d{2}):(\d{2}):(\d{2}),(\d{3})\s*-->\s*(\d{2}):(\d{2}):(\d{2}),(\d{3})' + + timestamp_pattern = ( + r"(\d{2}):(\d{2}):(\d{2}),(\d{3})\s*-->\s*" + r"(\d{2}):(\d{2}):(\d{2}),(\d{3})" + ) match = re.search(timestamp_pattern, lines[1]) - + if match: h1, m1, s1, ms1, h2, m2, s2, ms2 = map(int, match.groups()) - seg_start = h1*3600 + m1*60 + s1 + ms1/1000 - seg_end = h2*3600 + m2*60 + s2 + ms2/1000 - + seg_start = h1 * 3600 + m1 * 60 + s1 + ms1 / 1000 + seg_end = h2 * 3600 + m2 * 60 + s2 + ms2 / 1000 + if seg_start < end_time and seg_end > start_time: if strip_timestamps: text_lines = lines[2:] - extracted.append(' '.join(text_lines)) + extracted.append(" ".join(text_lines)) else: extracted.append(segment) - + if strip_timestamps: - return ' '.join(extracted) - else: - return '\n\n'.join(extracted) - + return " ".join(extracted) + return "\n\n".join(extracted) + except Exception as e: - logger.warning(f"Could not extract transcript segment: {e}") + logger.warning("Could not extract transcript segment: %s", e) return "" - + def cleanup_segments(self, segments: List[Dict]): """ - Cleanup temporary segment files. + Clean up temporary segment files. 
+ + Args + ---- + segments : list of dict + List of segment dicts with segment_path and is_temp. """ + seen: Set[Path] = set() for seg in segments: - if seg.get('is_temp', False): - try: - seg_path = seg['segment_path'] - if seg_path.exists(): - temp_dir = seg_path.parent - if temp_dir.exists() and 'segments' in temp_dir.name: - shutil.rmtree(temp_dir) - logger.info(f"Cleaned up temp directory: {temp_dir}") - break - except Exception as e: - logger.warning(f"Failed to cleanup segment: {e}") \ No newline at end of file + if not seg.get("is_temp", False): + continue + try: + seg_path = Path(seg["segment_path"]) + temp_dir = seg_path.parent + if temp_dir in seen or not temp_dir.exists(): + continue + seen.add(temp_dir) + shutil.rmtree(temp_dir) + logger.info("Cleaned up temp directory: %s", temp_dir) + except Exception as e: + logger.warning("Failed to cleanup segment: %s", e) diff --git a/sonic-o1/04_vqa_generation/config/vqa_config.yaml b/sonic-o1/04_vqa_generation/vqa_config.yaml similarity index 97% rename from sonic-o1/04_vqa_generation/config/vqa_config.yaml rename to sonic-o1/04_vqa_generation/vqa_config.yaml index faab81f..a4efb4e 100644 --- a/sonic-o1/04_vqa_generation/config/vqa_config.yaml +++ b/sonic-o1/04_vqa_generation/vqa_config.yaml @@ -8,7 +8,7 @@ gemini: retry_delay: 20 # Base delay in seconds (will use exponential backoff) file_processing_timeout: 7200 # 2 hours for large files max_output_tokens: 2048 - + # Video Processing Settings video: # Segment duration settings per task @@ -16,7 +16,7 @@ video: mcq_segment_duration: 180 # 3 minutes for MCQ (each treated independently) temporal_localization_segment_duration: 180 # 3 minutes for temporal localization segment_overlap: 30 # Overlap between segments in seconds (30 sec) - + # File size thresholds for processing strategy file_processing: inline_threshold_mb: 20 # Files >20MB use File API @@ -47,7 +47,7 @@ summarization: mcq: num_options: 5 # Always 5 options (4 + "Not enough evidence") 
questions_per_segment: 1 # How many MCQs per segment - + # Controlled vocabulary for evidence tags evidence_tags: - signage @@ -83,22 +83,22 @@ mcq: - crowd_cheering - sirens - police_lights - + # Task 3: Temporal Localization Settings temporal_localization: questions_per_segment: 3 # Number of temporal questions per 3-min segment - + # Temporal relation distribution (optional - for balanced generation) temporal_relations: after: 0.25 # 25% of questions use "after" relation once_finished: 0.25 # 25% use "once_finished" next: 0.20 # 20% use "next" during: 0.15 # 15% use "during" - before: 0.15 # 15% use "before" - + before: 0.15 # 15% use "before" + # Minimum confidence for temporal localization min_confidence: 0.5 - + # Abstention threshold (if confidence below this, consider abstaining) abstention_threshold: 0.3 judge_enabled: true # Enable GPT-4V validation @@ -131,7 +131,7 @@ demographics: - Arabic - Spanish - Chinese - + # Minimum confidence for demographics (0-1) min_confidence: 0.6 # Dataset Paths @@ -152,4 +152,4 @@ processing: # Logging logging: level: "INFO" # DEBUG, INFO, WARNING, ERROR - log_file: "vqa_generation.log" \ No newline at end of file + log_file: "vqa_generation.log" diff --git a/sonic-o1/05_evaluation_inference/README.md b/sonic-o1/05_evaluation_inference/README.md index e654406..70deae5 100644 --- a/sonic-o1/05_evaluation_inference/README.md +++ b/sonic-o1/05_evaluation_inference/README.md @@ -1,173 +1,666 @@ # Evaluation & Inference Pipeline -This directory contains the evaluation and inference pipeline for video question-answering models. It supports multiple open-source and commercial models with proper environment management and metrics computation. +## Overview -## Important Prerequisites +This directory handles model evaluation and inference for video question-answering tasks. It supports multiple open-source and commercial models with proper environment management, inference execution, and comprehensive metrics computation. -### 1. 
Working Directory -**IMPORTANT:** Always run scripts from the `sonic-o1` directory (parent directory), not from within `05_evaluation_inference`. This is required for relative paths to work correctly. +### Directory Structure -```bash -# Correct - run from sonic-o1 directory -cd /path/to/sonic-o1 -python 05_evaluation_inference/run_evaluation.py --config configs/eval_config.yaml +``` +05_evaluation_inference/ +├── run_evaluation.py # Main evaluation pipeline orchestrator +├── README.md # This file +│ +├── inference/ # Inference execution +│ └── run_inference.py # Standalone inference script +│ +├── models/ # Model implementations +│ ├── base_model.py # Base class for all models +│ ├── gemini.py # Google Gemini API +│ ├── gpt4o.py # OpenAI GPT-4o API +│ ├── qwen3.py # Qwen3 VL model +│ ├── minicpm.py # MiniCPM-V model +│ ├── phi4.py # Phi-4 Vision model +│ ├── unimoe.py # Uni-MoE model +│ ├── videollama.py # Video-LLaMA model +│ └── vita.py # VITA model +│ +├── metrics/ # Metrics computation +│ ├── compute_metrics.py # Main metrics computation +│ ├── t1_metrics.py # Task 1 (Summarization) metrics +│ ├── t2_metrics.py # Task 2 (MCQ) metrics +│ ├── t3_metrics.py # Task 3 (Temporal) metrics +│ ├── llm_judge_gpt.py # GPT-based LLM judge +│ └── llm_judge_qwen.py # Qwen-based LLM judge +│ +├── prompts/ # Task-specific prompts +│ ├── t1_prompts.py # Task 1 prompts +│ ├── t2_prompts.py # Task 2 prompts +│ └── t3_prompts.py # Task 3 prompts +│ +├── utils/ # Utility functions +│ ├── audio_processor.py # Audio extraction and processing +│ ├── caption_handler.py # Caption/subtitle processing +│ ├── config_loader.py # Configuration management +│ ├── frame_sampler.py # Video frame sampling strategies +│ ├── mm_process_pyav.py # Multimedia 
processing with PyAV +โ”‚ โ””โ”€โ”€ segmenter.py # Video segmentation utilities +โ”‚ +โ”œโ”€โ”€ models_config.yaml # Model and evaluation configuration +โ”‚ +โ”œโ”€โ”€ models_requirements/ # Model-specific requirements +โ”‚ โ”œโ”€โ”€ requirements_venv_llama.txt +โ”‚ โ”œโ”€โ”€ requirements_venv_minicpm.txt +โ”‚ โ”œโ”€โ”€ requirements_venv_phi4.txt +โ”‚ โ”œโ”€โ”€ requirements_venv_unimoe.txt +โ”‚ โ”œโ”€โ”€ requirements_venv_vita.txt +โ”‚ โ””โ”€โ”€ requirements_qwen3.txt +โ”‚ +โ”œโ”€โ”€ external_repos/ # External model repositories +โ”‚ โ””โ”€โ”€ README.md # Details on included repos +โ”‚ +โ””โ”€โ”€ results/ # Output directory + โ”œโ”€โ”€ predictions/ # Model predictions + โ””โ”€โ”€ scores/ # Evaluation scores +``` + +### Pipeline Workflow + +``` +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Step 1: Model Inference โ”‚ +โ”‚ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ โ”‚ +โ”‚ Input: dataset/videos//video_*.mp4 โ”‚ +โ”‚ dataset/audios//audio_*.m4a โ”‚ +โ”‚ dataset/captions//caption_*.srt โ”‚ +โ”‚ vqa/task*/.json (ground truth) โ”‚ +โ”‚ Output: results/predictions///.json โ”‚ +โ”‚ โ”œโ”€โ”€ Model predictions for each VQA entry โ”‚ +โ”‚ โ””โ”€โ”€ Per-task, per-topic prediction files โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ†“ +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Step 2: Metrics Computation 
โ”‚ +โ”‚ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ โ”‚ +โ”‚ Input: results/predictions///.json โ”‚ +โ”‚ vqa/task*/.json (ground truth) โ”‚ +โ”‚ Output: results/scores///_scores.json โ”‚ +โ”‚ โ”œโ”€โ”€ Task-specific metrics (BLEU, ROUGE, accuracy, etc.) โ”‚ +โ”‚ โ”œโ”€โ”€ LLM judge scores (if enabled) โ”‚ +โ”‚ โ””โ”€โ”€ Aggregated evaluation results โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ†“ +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Step 3: Results Analysis โ”‚ +โ”‚ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ โ”‚ +โ”‚ โ€ข Review prediction files in results/predictions/ โ”‚ +โ”‚ โ€ข Analyze metric scores in results/scores/ โ”‚ +โ”‚ โ€ข Compare model performance across tasks โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ +``` + +### Supported Models + +The pipeline supports multiple video understanding models: + +**API-Based Models:** +- **Gemini** - Google Gemini multimodal API +- **GPT-4o** - OpenAI GPT-4o Vision API + +**Open-Source Models:** +- **Qwen3** - Qwen3 VL model +- **MiniCPM-V** - MiniCPM Vision model +- **Phi-4** - Phi-4 Vision model +- **Uni-MoE** - Uni-MoE multimodal model +- **Video-LLaMA** - Video-LLaMA2 model +- **VITA** 
- VITA video understanding model + +### Features + +- **Multiple Model Support**: Evaluate various open-source and commercial models +- **Three Task Types**: Task 1 (Summarization), Task 2 (MCQ), Task 3 (Temporal Localization) +- **Flexible Environment Management**: Model-specific virtual environments for compatibility +- **Comprehensive Metrics**: Task-specific metrics plus LLM judge evaluation +- **Resume Capability**: Skip already processed entries and resume from checkpoints +- **Batch Processing**: Process multiple models, tasks, and topics in one run +- **Error Handling**: Retry logic with fallback strategies for processing failures + +## Prerequisites + +Before running this step, you must have completed: + +1. **Data Curation** (see [01_data_curation](../01_data_curation/)) + - Downloaded videos and audio files + - Generated metadata.json files + +2. **Caption Generation** (see [02_caption_generation](../02_caption_generation/)) + - Generated captions for all videos (SRT format) + +3. **Demographics Annotation** (see [03_demographics_annotation](../03_demographics_annotation/)) + - Generated metadata_enhanced.json files with demographics + +4. 
**VQA Generation** (see [04_vqa_generation](../04_vqa_generation/)) + - Generated VQA ground truth files in `vqa/` directory: + - `vqa/task1_summarization/.json` + - `vqa/task2_mcq/.json` + - `vqa/task3_temporal_localization/.json` -# Incorrect - will fail due to broken relative paths -cd 05_evaluation_inference -python run_evaluation.py --config configs/eval_config.yaml +Your `dataset/` and `vqa/` directories should have this structure: ``` +dataset/ +├── videos// +│ ├── video_001.mp4 +│ └── metadata_enhanced.json +├── audios// +│ └── audio_001.m4a +└── captions// + └── caption_001.srt + +vqa/ +├── task1_summarization/ +│ └── 01_Patient-Doctor_Consultations.json +├── task2_mcq/ +│ └── 01_Patient-Doctor_Consultations.json +└── task3_temporal_localization/ + └── 01_Patient-Doctor_Consultations.json +``` + +5. **API Setup** (for API-based models) + - **Gemini API Key** (for Gemini model) + - Get from [Google AI Studio](https://makersuite.google.com/app/apikey) + - Set in `.env` file: `GEMINI_API_KEY=your_key_here` + - **OpenAI API Key** (for GPT-4o model) + - Get from [OpenAI Platform](https://platform.openai.com/api-keys) + - Set in `.env` file: `OPENAI_API_KEY=your_key_here` + +## Installation -### 2. Environment Activation +### 1. General Environment Setup -**General Environment:** Most models use the general project environment defined in `../pyproject.toml` (parent directory of `sonic-o1`). Install and activate this environment first: +Most models use the general project environment. Install dependencies from the project root: ```bash -# From the parent directory containing pyproject.toml +# From the project root (parent of sonic-o1) +cd /path/to/VideoQA-Agentic +source .venv/bin/activate pip install -e . # or pip install -r requirements.txt ``` -**Model-Specific Environments:** Some models require specialized dependencies and have their own virtual environments. 
Only use these if you're running the specific models listed below: +### 2. Model-Specific Environments -Available model-specific environments (see `models_requirements/` directory): -- `venv_llama` - For Video-LLaMA models -- `venv_minicpm` - For MiniCPM-V models -- `venv_phi4` - For Phi-4 Vision models -- `venv_unimoe` - For Uni-MoE models -- `venv_vita` - For VITA models +Some models require specialized dependencies. Only install these if you plan to use the specific models: -Example activation for model-specific environments: +**Video-LLaMA:** ```bash -# Activate environment for Video-LLaMA +cd /path/to/sonic-o1/05_evaluation_inference +python -m venv venv_llama source venv_llama/bin/activate +pip install -r models_requirements/requirements_venv_llama.txt +``` + +**MiniCPM-V:** +```bash +python -m venv venv_minicpm +source venv_minicpm/bin/activate +pip install -r models_requirements/requirements_venv_minicpm.txt +``` + +**Phi-4 Vision:** +```bash +python -m venv venv_phi4 +source venv_phi4/bin/activate +pip install -r models_requirements/requirements_venv_phi4.txt +``` -# Activate environment for Uni-MoE +**Uni-MoE:** +```bash +python -m venv venv_unimoe source venv_unimoe/bin/activate +pip install -r models_requirements/requirements_venv_unimoe.txt +``` + +**VITA:** +```bash +python -m venv venv_vita +source venv_vita/bin/activate +pip install -r models_requirements/requirements_venv_vita.txt +``` + +**Qwen3:** +```bash +# Uses general environment, but may need: +pip install -r models_requirements/requirements_qwen3.txt ``` -**Rule of thumb:** Use the general environment unless the model has a `requirements_venv_.txt` file in `models_requirements/`. 
- -## Directory Structure - -### Core Scripts -- [run_evaluation.py](run_evaluation.py) - Main evaluation pipeline orchestrator -- [inference/run_inference.py](inference/run_inference.py) - Standalone inference script for model predictions - -### Configuration & Setup -- **configs/** - YAML configuration files for evaluation runs -- **models_requirements/** - Python requirements files for each model's virtual environment - - `requirements_venv_llama.txt` - - `requirements_venv_minicpm.txt` - - `requirements_venv_phi4.txt` - - `requirements_venv_unimoe.txt` - - `requirements_venv_vita.txt` - -### Model Implementations -- **models/** - Model wrapper classes (only `.py` files, no backup `.txt` files) - - [base_model.py](models/base_model.py) - Base class for all models - - [gemini.py](models/gemini.py) - Google Gemini API - - [gpt4o.py](models/gpt4o.py) - OpenAI GPT-4o API - - [minicpm.py](models/minicpm.py) - MiniCPM-V model - - [phi4.py](models/phi4.py) - Phi-4 Vision model - - [qwen3.py](models/qwen3.py) - Qwen3 VL model - - [unimoe.py](models/unimoe.py) - Uni-MoE model - - [videollama.py](models/videollama.py) - Video-LLaMA model - - [vita.py](models/vita.py) - VITA model - -### Metrics & Evaluation -- **metrics/** - Metric computation scripts - - [compute_metrics.py](metrics/compute_metrics.py) - Main metrics computation - - [t1_metrics.py](metrics/t1_metrics.py) - Task 1 specific metrics - - [t2_metrics.py](metrics/t2_metrics.py) - Task 2 specific metrics - - [t3_metrics.py](metrics/t3_metrics.py) - Task 3 specific metrics - - [llm_judge_gpt.py](metrics/llm_judge_gpt.py) - GPT-based LLM judge - - [llm_judge_qwen.py](metrics/llm_judge_qwen.py) - Qwen-based LLM judge - -### Supporting Components -- **prompts/** - Task-specific prompt templates - - [t1_prompts.py](prompts/t1_prompts.py) - - [t2_prompts.py](prompts/t2_prompts.py) - - [t3_prompts.py](prompts/t3_prompts.py) - -- **utils/** - Utility functions for data processing - - 
[audio_processor.py](utils/audio_processor.py) - Audio extraction and processing - - [frame_sampler.py](utils/frame_sampler.py) - Video frame sampling strategies - - [caption_handler.py](utils/caption_handler.py) - Caption/subtitle processing - - [segmenter.py](utils/segmenter.py) - Video segmentation utilities - - [config_loader.py](utils/config_loader.py) - Configuration management - - [mm_process_pyav.py](utils/mm_process_pyav.py) - Multimedia processing with PyAV - -- **external_repos/** - Open-source model repositories - See [external_repos/README.md](external_repos/README.md) for details on included repositories (Uni-MoE, VideoLLaMA2, VITA) with compatibility fixes applied. - -- **results/** - Output directory for evaluation results - -## Usage Examples - -### Running Full Evaluation -```bash -# From sonic-o1 directory, with correct env activated +### 3. External Repositories + +Some models require external repositories with compatibility fixes. See [external_repos/README.md](external_repos/README.md) for setup instructions. + +## Configuration + +Edit [models_config.yaml](models_config.yaml) to customize evaluation settings. + +### Dataset Paths +```yaml +dataset_path: "dataset" +vqa_path: "vqa" + +results: + predictions_path: "05_evaluation_inference/results/predictions" + scores_path: "05_evaluation_inference/results/scores" +``` + +### Tasks and Topics +```yaml +tasks: + - task1_summarization + - task2_mcq + - task3_temporal_localization + +topics: + - "01_Patient-Doctor_Consultations" + - "02_Job_Interviews" + # ... 
(all 13 topics) +``` + +### Preprocessing Settings +```yaml +preprocessing: + t2_t3: + segment_max_duration: 180 # Max segment duration (seconds) + image_model_frames: 128 # Frames for image-based models + video_model_fps: 1 # FPS for video models +``` + +### Retry Logic +```yaml +retry: + max_attempts: 4 + fps_fallback: [1, 0.5, 0.25] + frame_count_fallback: [256, 128, 64, 32] + audio_chunks_fallback: [null, 64, 32, 16] + audio_chunk_duration_sec: 10.0 +``` + +### Metrics Configuration +```yaml +metrics: + llm_judge_model: "gpt-5-mini" # or "Qwen/Qwen3-8B" + # Task-specific metric settings... +``` + +### Model-Specific Configuration + +Each model has its own entry in the `models` list in `models_config.yaml`; the `name` field is the value passed to `--model`: +```yaml +models: + - name: "gemini" + class: "Gemini" + api_key: "${GEMINI_API_KEY}" + # ... model-specific settings + + - name: "videollama" + class: "VideoLLaMA2" + model_path: "/path/to/model" + # ... model-specific settings +``` + +## Usage + +**IMPORTANT**: Always run scripts from the `sonic-o1` directory (parent directory), not from within `05_evaluation_inference`. This is required for relative paths to work correctly. 
+ +### Full Evaluation Pipeline + +Run both inference and metrics computation: + +```bash +# Navigate to sonic-o1 directory cd /path/to/sonic-o1 -source venv_llama/bin/activate # Activate appropriate environment + +# Activate appropriate environment +source .venv/bin/activate # or source venv_/bin/activate + +# Run full evaluation for one model python 05_evaluation_inference/run_evaluation.py \ - --config 05_evaluation_inference/configs/eval_config.yaml \ - --model videollama \ - --task t1 + --model gemini \ + --tasks all + +# Run for specific tasks +python 05_evaluation_inference/run_evaluation.py \ + --model gpt4o \ + --tasks t1 t2 + +# Run for specific topics +python 05_evaluation_inference/run_evaluation.py \ + --model qwen3 \ + --topics "01_Patient-Doctor_Consultations" "02_Job_Interviews" ``` -### Running Inference Only +### Inference Only + +Run inference without computing metrics: + +```bash +python 05_evaluation_inference/run_evaluation.py \ + --model gemini \ + --tasks all \ + --inference-only +``` + +### Metrics Only + +Compute metrics on existing predictions: + +```bash +python 05_evaluation_inference/run_evaluation.py \ + --model gemini \ + --metrics-only +``` + +### Multiple Models + +Evaluate multiple models in sequence: + +```bash +python 05_evaluation_inference/run_evaluation.py \ + --models gemini gpt4o qwen3 \ + --tasks all +``` + +### Standalone Inference + +Run inference directly without the orchestrator: + ```bash -# From sonic-o1 directory -cd /path/to/sonic-o1 -source venv_unimoe/bin/activate python 05_evaluation_inference/inference/run_inference.py \ - --model unimoe \ - --input_data data/test_videos.json \ - --output_dir 05_evaluation_inference/results/unimoe_inference + --model gemini \ + --tasks task1_summarization task2_mcq \ + --topics "01_Patient-Doctor_Consultations" \ + --config 05_evaluation_inference/models_config.yaml ``` -### Computing Metrics on Results +### Standalone Metrics + +Compute metrics directly: + ```bash -# From 
sonic-o1 directory -cd /path/to/sonic-o1 python 05_evaluation_inference/metrics/compute_metrics.py \ - --predictions 05_evaluation_inference/results/predictions.json \ - --ground_truth data/ground_truth.json \ + --predictions results/predictions/gemini/task1_summarization/01_Patient-Doctor_Consultations.json \ + --ground_truth vqa/task1_summarization/01_Patient-Doctor_Consultations.json \ --task t1 ``` -## Common Workflow +### Command-Line Arguments + +The main evaluation script supports many options: + +| Argument | Description | Example | +|----------|-------------|---------| +| `--model` | Single model name to evaluate | `--model gemini` | +| `--models` | Multiple model names | `--models gemini gpt4o` | +| `--config` | Path to config file | `--config custom_config.yaml` | +| `--tasks` | Tasks to evaluate (t1, t2, t3, or 'all') | `--tasks t1 t2` | +| `--topics` | Topics to evaluate | `--topics "01_Patient-Doctor_Consultations"` | +| `--inference-only` | Run inference only, skip metrics | `--inference-only` | +| `--metrics-only` | Compute metrics only, skip inference | `--metrics-only` | +| `--skip-existing` | Skip already processed entries (default: True) | `--skip-existing` | +| `--force-rerun` | Force re-run even if outputs exist | `--force-rerun` | +| `--retry-failed` | Retry only failed entries | `--retry-failed` | +| `--experiment-name` | Experiment name for organizing results | `--experiment-name "frames_16"` | +| `--no-llm-judge` | Skip LLM judge evaluation | `--no-llm-judge` | +| `--dataset-path` | Override dataset path | `--dataset-path custom_dataset/` | +| `--vqa-path` | Override VQA path | `--vqa-path custom_vqa/` | + +**Examples:** + +```bash +# Full evaluation with experiment name +python 05_evaluation_inference/run_evaluation.py \ + --model gemini \ + --tasks all \ + --experiment-name "modality_audio_only" + +# Retry failed entries only +python 05_evaluation_inference/run_evaluation.py \ + --model videollama \ + --tasks t2 \ + --retry-failed + +# 
Force rerun with custom paths +python 05_evaluation_inference/run_evaluation.py \ + --model gpt4o \ + --tasks t1 \ + --force-rerun \ + --dataset-path custom_dataset/ \ + --vqa-path custom_vqa/ +``` + +## Output + +The pipeline creates organized output files in the `results/` directory. + +### Output Location +``` +results/ +├── predictions/ +│ └── / +│ ├── task1_summarization/ +│ │ ├── 01_Patient-Doctor_Consultations.json +│ │ └── ... +│ ├── task2_mcq/ +│ │ └── ... +│ └── task3_temporal_localization/ +│ └── ... +│ +└── scores/ + └── / + ├── task1_summarization/ + │ ├── 01_Patient-Doctor_Consultations_scores.json + │ └── ... + ├── task2_mcq/ + │ └── ... + └── task3_temporal_localization/ + └── ... +``` + +### Prediction Format + +Predictions follow the same structure as ground truth VQA files, with model predictions added: + +**Task 1 (Summarization):** +```json +{ + "video_id": "001", + "prediction": { + "summary_short": ["...", "..."], + "summary_detailed": "...", + "timeline": [...], + "glossary": [...] + }, + "ground_truth": { ... } +} +``` + +**Task 2 (MCQ):** +```json +{ + "video_id": "001", + "segment": {"start": 120.0, "end": 180.0}, + "prediction": { + "answer": 0, + "explanation": "..." + }, + "ground_truth": { ... } +} +``` + +**Task 3 (Temporal):** +```json +{ + "video_id": "001", + "segment": {"start": 45.0, "end": 90.0}, + "prediction": { + "answer": "...", + "temporal_relation": "after" + }, + "ground_truth": { ... 
} +} +``` + +### Metrics Format + +Scores files contain comprehensive evaluation metrics: + +```json +{ + "model": "gemini", + "task": "task1_summarization", + "topic": "01_Patient-Doctor_Consultations", + "metrics": { + "bleu": 0.45, + "rouge_l": 0.52, + "rouge_1": 0.58, + "rouge_2": 0.41, + "llm_judge_score": 0.78 + }, + "num_entries": 25, + "computed_at": "2026-01-14 12:34:56" +} +``` + +## Processing Time + +Processing time varies significantly by model and task: + +- **API Models (Gemini, GPT-4o)**: ~5-30 seconds per entry +- **Open-Source Models**: ~10-120 seconds per entry (depends on GPU) +- **Task 1 (Summarization)**: Faster (single summary per video) +- **Task 2 (MCQ)**: Medium (multiple questions per video) +- **Task 3 (Temporal)**: Slower (multiple questions with temporal reasoning) + +### Estimated Time for Full Dataset + +- **13 topics × 25 videos = 325 videos** +- **API models**: ~2-4 hours (all tasks) +- **Open-source models**: ~4-12 hours (all tasks, GPU-dependent) + +## Troubleshooting + +### Environment Issues + +**Problem**: `ModuleNotFoundError` or import errors + +**Solution**: Ensure you're using the correct environment: +```bash +# Check which environment is active +which python + +# Activate correct environment +source .venv/bin/activate # General environment +# or +source venv_/bin/activate # Model-specific environment +``` + +### Path Resolution Errors -1. **Setup Environment** - ```bash - cd /path/to/sonic-o1 - source venv_/bin/activate # Replace with target model - ``` +**Problem**: `FileNotFoundError` for dataset or VQA files -2. **Run Inference** - ```bash - python 05_evaluation_inference/run_evaluation.py --config --model - ``` +**Solution**: Always run from the `sonic-o1` directory: +```bash +cd /path/to/sonic-o1 +python 05_evaluation_inference/run_evaluation.py --model gemini +``` + +### API Key Not Found + +**Problem**: `ERROR: API key not set!` (for Gemini/GPT-4o) -3. 
**Compute Metrics** - ```bash - python 05_evaluation_inference/metrics/compute_metrics.py --predictions --ground_truth - ``` +**Solution**: +```bash +# Create .env file in project root +echo "GEMINI_API_KEY=your_key_here" >> .env +echo "OPENAI_API_KEY=your_key_here" >> .env + +# Or export environment variables +export GEMINI_API_KEY=your_key_here +export OPENAI_API_KEY=your_key_here +``` -4. **Review Results** - - Check `results/` directory for output files - - Prediction files contain model responses - - Metric files contain computed evaluation scores +### GPU Out of Memory + +**Problem**: `CUDA out of memory` (for open-source models) + +**Solution**: Adjust preprocessing settings in `models_config.yaml`: +```yaml +preprocessing: + t2_t3: + image_model_frames: 64 # Reduce from 128 + video_model_fps: 0.5 # Reduce from 1 +``` + +### Model Loading Errors + +**Problem**: Model fails to load or initialize + +**Solution**: +1. Check model path in `models_config.yaml` +2. Verify model-specific environment is activated +3. Check external repository setup (see `external_repos/README.md`) + +### Retry Logic Issues + +**Problem**: Processing fails repeatedly + +**Solution**: Check retry configuration in `models_config.yaml`: +```yaml +retry: + max_attempts: 4 # Increase if needed + frame_count_fallback: [256, 128, 64, 32] # Adjust fallback values +``` + +### LLM Judge Errors + +**Problem**: LLM judge evaluation fails + +**Solution**: +```bash +# Skip LLM judge for faster evaluation +python 05_evaluation_inference/run_evaluation.py \ + --model gemini \ + --no-llm-judge +``` + +Or check LLM judge model configuration in `models_config.yaml`. 
+ +## Files + +- `run_evaluation.py` - Main evaluation pipeline orchestrator +- `inference/run_inference.py` - Standalone inference script +- `metrics/compute_metrics.py` - Main metrics computation +- `metrics/t1_metrics.py` - Task 1 specific metrics +- `metrics/t2_metrics.py` - Task 2 specific metrics +- `metrics/t3_metrics.py` - Task 3 specific metrics +- `models_config.yaml` - Model and evaluation configuration +- `models/` - Model implementations (base_model, gemini, gpt4o, etc.) +- `prompts/` - Task-specific prompt templates +- `utils/` - Utility modules (audio_processor, frame_sampler, etc.) +- `models_requirements/` - Model-specific Python requirements +- `external_repos/` - External model repositories with fixes ## Notes -- **Environment Management**: Most models work with the general environment defined in `../pyproject.toml`. Only use model-specific environments (in `models_requirements/`) when running models with special requirements (Uni-MoE, VITA, Video-LLaMA, MiniCPM-V, Phi-4). Mismatched environments will cause import errors or compatibility issues. +- **Environment Management**: Most models work with the general environment. Only use model-specific environments (in `models_requirements/`) when running models with special requirements (Uni-MoE, VITA, Video-LLaMA, MiniCPM-V, Phi-4). - **Path Resolution**: All scripts expect to be run from the `sonic-o1` parent directory to properly resolve relative imports and data paths. -- **Model Files**: The `models/` directory contains only final `.py` implementations. Backup `.txt` files are not included in the repository. +- **Resume Capability**: The pipeline automatically skips already processed entries. Use `--force-rerun` to reprocess everything. + +- **Experiment Names**: Use `--experiment-name` to organize results for different experimental conditions (e.g., "frames_16", "modality_audio_only"). + +- **API Rate Limits**: API-based models (Gemini, GPT-4o) have rate limits. 
The pipeline includes automatic retry logic, but you may need to adjust delays for large-scale evaluation. -- **External Dependencies**: Some models require external repositories with fixes applied (see `external_repos/` directory). +- **GPU Requirements**: Open-source models require GPUs. Ensure sufficient VRAM and CUDA setup before running GPU-based models. -- **API Models**: For Gemini and GPT-4o, ensure API keys are properly configured in your environment or `.env` file. +- **External Dependencies**: Some models require external repositories with compatibility fixes (see `external_repos/README.md` for setup instructions). diff --git a/sonic-o1/05_evaluation_inference/configs/__init__.py b/sonic-o1/05_evaluation_inference/configs/__init__.py deleted file mode 100644 index d1f64a3..0000000 --- a/sonic-o1/05_evaluation_inference/configs/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -configs/__init__.py -""" \ No newline at end of file diff --git a/sonic-o1/05_evaluation_inference/inference/__init__.py b/sonic-o1/05_evaluation_inference/inference/__init__.py index 5e32b99..9926803 100644 --- a/sonic-o1/05_evaluation_inference/inference/__init__.py +++ b/sonic-o1/05_evaluation_inference/inference/__init__.py @@ -1,5 +1,5 @@ """ -inference/__init__.py +inference/__init__.py. Inference pipeline for model evaluation. -""" \ No newline at end of file +""" diff --git a/sonic-o1/05_evaluation_inference/inference/run_inference.py b/sonic-o1/05_evaluation_inference/inference/run_inference.py index 0a699ec..e3a98e4 100644 --- a/sonic-o1/05_evaluation_inference/inference/run_inference.py +++ b/sonic-o1/05_evaluation_inference/inference/run_inference.py @@ -1,1068 +1,1237 @@ """ -inference/run_inference.py +inference/run_inference.py. + Main inference pipeline for model evaluation with resume capability. 
""" + import json import logging +import os +import re +import shutil import sys -from pathlib import Path -from typing import Dict, List, Any, Optional, Tuple -from datetime import datetime import time +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + import yaml -from tqdm import tqdm -import os -import shutil -import re from dotenv import load_dotenv +from tqdm import tqdm + load_dotenv() sys.path.append(str(Path(__file__).parent.parent)) -from prompts.t1_prompts import get_t1_prompt, get_t1_empathy_prompt -from prompts.t2_prompts import get_t2_prompt -from prompts.t3_prompts import get_t3_prompt -from utils.frame_sampler import FrameSampler -from utils.segmenter import VideoSegmenter +# Local package imports after path setup (required for script/CLI usage) +from prompts.t1_prompts import get_t1_empathy_prompt, get_t1_prompt # noqa: E402 +from prompts.t2_prompts import get_t2_prompt # noqa: E402 +from prompts.t3_prompts import get_t3_prompt # noqa: E402 +from utils.segmenter import VideoSegmenter # noqa: E402 + logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) + class InferenceRunner: - """Run inference for model evaluation with resume capability""" - - def __init__(self, config_path: str, experiment_name: Optional[str] = None): - with open(config_path, 'r') as f: + """Run inference for model evaluation with resume capability.""" + + def __init__(self, config_path: str, experiment_name: Optional[str] = None) -> None: + """ + Initialize runner from a YAML config file. + + Args: + config_path: Path to models configuration YAML. + experiment_name: Optional experiment label for organizing outputs. 
+ """ + with open(config_path, "r") as f: self.config = yaml.safe_load(f) - - self.dataset_path = Path(self.config['dataset_path']) - self.vqa_path = Path(self.config['vqa_path']) - + + self.dataset_path = Path(self.config["dataset_path"]) + self.vqa_path = Path(self.config["vqa_path"]) + self.model = None self.model_name = None self.model_config = None - self.experiment_name = experiment_name - + self.experiment_name = experiment_name + self.frame_sampler = None self.video_segmenter = VideoSegmenter() - - self.retry_config = self.config['retry'] - self.preprocessing_config = self.config['preprocessing'] - - self.video_metadata = {} + + self.retry_config = self.config["retry"] + self.preprocessing_config = self.config["preprocessing"] + + self.video_metadata = {} self.failed_entries = [] - + def _get_temp_dir(self) -> Path: - """Get temporary directory for video segments""" + """Get temporary directory for video segments.""" unique_suffix = f"temp_segments_{os.getpid()}" - temp_base = os.environ.get('SCRATCH_DIR') or os.environ.get('TMPDIR') - + temp_base = os.environ.get("SCRATCH_DIR") or os.environ.get("TMPDIR") + if temp_base: temp_dir = Path(temp_base) / unique_suffix else: - temp_dir = Path.home() / 'scratch' / unique_suffix - + temp_dir = Path.home() / "scratch" / unique_suffix + return temp_dir - + def _load_video_metadata(self, topic_name: str) -> Dict: - """Load metadata_enhanced.json for specific topic""" - metadata_path = self.dataset_path / 'videos' / topic_name / 'metadata_enhanced.json' - + """Load metadata_enhanced.json for specific topic.""" + metadata_path = ( + self.dataset_path / "videos" / topic_name / "metadata_enhanced.json" + ) + if not metadata_path.exists(): logger.warning(f"Metadata file not found: {metadata_path}") return {} - - with open(metadata_path, 'r') as f: + + with open(metadata_path, "r") as f: metadata_list = json.load(f) - - metadata_dict = {item['video_id']: item for item in metadata_list} + + metadata_dict = 
{item["video_id"]: item for item in metadata_list} logger.info(f"Loaded metadata for {len(metadata_dict)} videos") return metadata_dict - + def get_video_category(self, video_id: str) -> str: - """Get duration category (short/medium/long) for video""" + """Get duration category (short/medium/long) for video.""" if video_id in self.video_metadata: - return self.video_metadata[video_id].get('duration_category', 'medium') - return 'medium' - - def load_model(self, model_name: str): - """Load specified model""" + return self.video_metadata[video_id].get("duration_category", "medium") + return "medium" + + def load_model(self, model_name: str) -> None: + """Load specified model from config and initialize it.""" model_config = None - for m in self.config['models']: - if m['name'] == model_name: + for m in self.config["models"]: + if m["name"] == model_name: model_config = m break - + if model_config is None: raise ValueError(f"Model {model_name} not found in config") - + self.model_config = model_config - model_class = model_config['class'] - - if model_class == 'Gemini': - from models.gemini import Gemini + model_class = model_config["class"] + + # Lazy imports: load only the selected model (env-dependent, avoids heavy deps) + if model_class == "Gemini": + from models.gemini import Gemini # noqa: PLC0415 + self.model = Gemini(model_name, model_config) - elif model_class == 'Qwen3Omni': - from models.qwen3 import Qwen3Omni + elif model_class == "Qwen3Omni": + from models.qwen3 import Qwen3Omni # noqa: PLC0415 + self.model = Qwen3Omni(model_name, model_config) - elif model_class == 'MiniCPM': - from models.minicpm import MiniCPM + elif model_class == "MiniCPM": + from models.minicpm import MiniCPM # noqa: PLC0415 + self.model = MiniCPM(model_name, model_config) - elif model_class == 'UniMoe': - from models.unimoe import UniMoe + elif model_class == "UniMoe": + from models.unimoe import UniMoe # noqa: PLC0415 + self.model = UniMoe(model_name, model_config) - elif 
model_class == 'VITA': - from models.vita import VITA + elif model_class == "VITA": + from models.vita import VITA # noqa: PLC0415 + self.model = VITA(model_name, model_config) - elif model_class == 'VideoLLaMA2': - from models.videollama import VideoLLaMA2 + elif model_class == "VideoLLaMA2": + from models.videollama import VideoLLaMA2 # noqa: PLC0415 + self.model = VideoLLaMA2(model_name, model_config) - elif model_class == 'Phi4': - from models.phi4 import Phi4 + elif model_class == "Phi4": + from models.phi4 import Phi4 # noqa: PLC0415 + self.model = Phi4(model_name, model_config) - elif model_class == 'GPT4o': - from models.gpt4o import GPT4o + elif model_class == "GPT4o": + from models.gpt4o import GPT4o # noqa: PLC0415 + self.model = GPT4o(model_name, model_config) else: raise ValueError(f"Unknown model class: {model_class}") - + self.model.load() self.model_name = model_name - - #if not model_config.get('supports_video', True): + + # --- KEPT FOR LEGACY COMPATIBILITY (NOT USED IN O1) --- + # if not model_config.get('supports_video', True): # self.frame_sampler = FrameSampler() - + logger.info(f"Loaded model: {model_name}") - + def get_video_path(self, topic_name: str, video_number: str) -> Path: - """Build path to video file""" - video_path = self.dataset_path / 'videos' / topic_name / f'video_{video_number}.mp4' + """Build path to video file.""" + video_path = ( + self.dataset_path / "videos" / topic_name / f"video_{video_number}.mp4" + ) if not video_path.exists(): raise FileNotFoundError(f"Video not found: {video_path}") return video_path - + def get_audio_path(self, topic_name: str, video_number: str) -> Optional[Path]: - """Build path to audio file""" - audio_path = self.dataset_path / 'audios' / topic_name / f'audio_{video_number}.m4a' + """Build path to audio file.""" + audio_path = ( + self.dataset_path / "audios" / topic_name / f"audio_{video_number}.m4a" + ) if audio_path.exists(): return audio_path return None - + def load_ground_truth(self, 
task: str, topic_name: str) -> Dict: - """Load ground truth JSON for task and topic""" - gt_path = self.vqa_path / task / f'{topic_name}.json' + """Load ground truth JSON for task and topic.""" + gt_path = self.vqa_path / task / f"{topic_name}.json" if not gt_path.exists(): raise FileNotFoundError(f"Ground truth not found: {gt_path}") - - with open(gt_path, 'r') as f: + + with open(gt_path, "r") as f: return json.load(f) - + def get_prediction_path(self, task: str, topic_name: str) -> Path: - """Get path to prediction file""" + """Get path to prediction file.""" if self.experiment_name: - output_dir = Path('results/predictions') / self.experiment_name / self.model_name / task + output_dir = ( + Path("results/predictions") + / self.experiment_name + / self.model_name + / task + ) else: - output_dir = Path('results/predictions') / self.model_name / task - - return output_dir / f'{topic_name}.json' - + output_dir = Path("results/predictions") / self.model_name / task + + return output_dir / f"{topic_name}.json" + def load_existing_predictions(self, task: str, topic_name: str) -> Optional[Dict]: - """Load existing predictions if they exist""" + """Load existing predictions if they exist.""" pred_path = self.get_prediction_path(task, topic_name) if pred_path.exists(): try: - with open(pred_path, 'r') as f: + with open(pred_path, "r") as f: return json.load(f) except Exception as e: logger.warning(f"Failed to load existing predictions: {e}") return None - + def validate_output(self, output: Dict, task: str) -> Tuple[bool, str]: - """Validate model output format""" + """Validate model output format.""" try: - if task == 'task1_summarization': - required = ['summary_detailed', 'summary_short', 'timeline', 'glossary', 'confidence'] + if task == "task1_summarization": + required = [ + "summary_detailed", + "summary_short", + "timeline", + "glossary", + "confidence", + ] for field in required: if field not in output: return False, f"Missing field: {field}" - - if not 
isinstance(output['summary_short'], list): + + if not isinstance(output["summary_short"], list): return False, "summary_short must be a list" - if not isinstance(output['timeline'], list): + if not isinstance(output["timeline"], list): return False, "timeline must be a list" - if not isinstance(output['glossary'], list): + if not isinstance(output["glossary"], list): return False, "glossary must be a list" - - elif task == 'task1_empathy': - required = ['summary_detailed_empathic', 'confidence'] + + elif task == "task1_empathy": + required = ["summary_detailed_empathic", "confidence"] for field in required: if field not in output: return False, f"Missing field: {field}" - if not isinstance(output['summary_detailed_empathic'], str): + if not isinstance(output["summary_detailed_empathic"], str): return False, "summary_detailed_empathic must be a string" - - elif task == 'task2_mcq': - required = ['answer_letter', 'answer_index', 'rationale', 'confidence'] + + elif task == "task2_mcq": + required = ["answer_letter", "answer_index", "rationale", "confidence"] for field in required: if field not in output: return False, f"Missing field: {field}" - - if output['answer_letter'] not in ['A', 'B', 'C', 'D', 'E']: + + if output["answer_letter"] not in ["A", "B", "C", "D", "E"]: return False, f"Invalid answer_letter: {output['answer_letter']}" - if not (0 <= output['answer_index'] <= 4): + if not (0 <= output["answer_index"] <= 4): return False, f"Invalid answer_index: {output['answer_index']}" - - elif task == 'task3_temporal_localization': - if 'questions' not in output: + + elif task == "task3_temporal_localization": + if "questions" not in output: return False, "Missing field: questions" - if not isinstance(output['questions'], list): + if not isinstance(output["questions"], list): return False, "questions must be a list" - - for q in output['questions']: - required = ['question_id', 'start_s', 'end_s', 'confidence', 'rationale_model'] + + for q in output["questions"]: + 
required = [ + "question_id", + "start_s", + "end_s", + "confidence", + "rationale_model", + ] for field in required: if field not in q: return False, f"Missing field in question: {field}" - + return True, "Valid" - + except Exception as e: return False, f"Validation error: {e}" - + def _clean_json_response(self, response: str) -> str: - """Clean markdown code blocks and common JSON errors from response""" + """Clean markdown code blocks and common JSON errors from response.""" response = response.strip() - - if response.startswith('```'): - lines = response.split('\n') - if lines[0].startswith('```'): + + if response.startswith("```"): + lines = response.split("\n") + if lines[0].startswith("```"): lines = lines[1:] - if lines and lines[-1].strip() == '```': + if lines and lines[-1].strip() == "```": lines = lines[:-1] - response = '\n'.join(lines).strip() - - response = response.replace('```json', '').replace('```', '') - response = re.sub(r':\s*"?(\d+\.?\d*)s"?([,\s\n}])', r': \1\2', response) - response = re.sub(r',(\s*[}\]])', r'\1', response) + response = "\n".join(lines).strip() + + response = response.replace("```json", "").replace("```", "") + response = re.sub(r':\s*"?(\d+\.?\d*)s"?([,\s\n}])', r": \1\2", response) + response = re.sub(r",(\s*[}\]])", r"\1", response) response = re.sub(r'"\s*\n\s*"', '",\n"', response) - - start = response.find('{') - end = response.rfind('}') + + start = response.find("{") + end = response.rfind("}") if start != -1 and end != -1: - response = response[start:end+1] - + response = response[start : end + 1] + return response - + def _generate_with_retry( - self, - video_path: Path, - audio_path: Optional[Path], - prompt: str, - video_category: str, - task_type: str = 't1', - topic_name: Optional[str] = None, - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - """Generate response with retry logic, JSON parsing, and validation""" - if topic_name: - self.current_topic_name = topic_name - - # Apply model capability filters - 
supports_video = self.model_config.get('supports_video', True) - supports_audio = self.model_config.get('supports_audio', True) - use_captions = self.model_config.get('use_captions', False) - - # For video/audio inputs to the model - actual_video_path = video_path if supports_video else None - actual_audio_path = audio_path if supports_audio else None - - # IMPORTANT: Keep original video_path for caption discovery - # even if model doesn't use video frames - - max_attempts = self.retry_config['max_attempts'] - - # Determine retry strategy from config - retry_strategy = self.model_config.get('retry_strategy', 'auto') - - # Auto-detect strategy if not explicitly set - if retry_strategy == 'auto': - is_api_model = 'api_key_env' in self.model_config - retry_strategy = 'fps' if is_api_model else 'frame_count' - - # Get fallback options based on strategy - if retry_strategy == 'frame_count_caption': - # Frame count + caption chunks (e.g., GPT-4o) - if 'retry_override' in self.model_config: - frame_options = self.model_config['retry_override'].get('frame_count_fallback', self.retry_config['frame_count_fallback']) - caption_options = self.model_config['retry_override'].get('caption_chunks_fallback', [None, None, None, None]) - else: - frame_options = self.retry_config['frame_count_fallback'] - caption_options = [None, None, None, None] - - elif retry_strategy == 'frame_count': - # Frame count + audio chunks (e.g., local models) - if 'retry_override' in self.model_config: - frame_options = self.model_config['retry_override'].get('frame_count_fallback', self.retry_config['frame_count_fallback']) - audio_options = self.model_config['retry_override'].get('audio_chunks_fallback', self.retry_config['audio_chunks_fallback']) - else: - frame_options = self.retry_config['frame_count_fallback'] - audio_options = self.retry_config.get('audio_chunks_fallback', [None, None, None, None]) - audio_chunk_duration = self.retry_config.get('audio_chunk_duration_sec', 10.0) - - elif 
retry_strategy == 'fps': - # FPS-based (e.g., Gemini) - fps_options = self.retry_config['fps_fallback'] - + self, + video_path: Path, + audio_path: Optional[Path], + prompt: str, + video_category: str, + task_type: str = "t1", + topic_name: Optional[str] = None, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Generate response with retry logic, JSON parsing, and validation.""" + if topic_name: + self.current_topic_name = topic_name + + # Apply model capability filters + supports_video = self.model_config.get("supports_video", True) + supports_audio = self.model_config.get("supports_audio", True) + use_captions = self.model_config.get("use_captions", False) + + # For video/audio inputs to the model + actual_video_path = video_path if supports_video else None + actual_audio_path = audio_path if supports_audio else None + + # IMPORTANT: Keep original video_path for caption discovery + # even if model doesn't use video frames + + max_attempts = self.retry_config["max_attempts"] + + # Determine retry strategy from config + retry_strategy = self.model_config.get("retry_strategy", "auto") + + # Auto-detect strategy if not explicitly set + if retry_strategy == "auto": + is_api_model = "api_key_env" in self.model_config + retry_strategy = "fps" if is_api_model else "frame_count" + + # Get fallback options based on strategy + if retry_strategy == "frame_count_caption": + # Frame count + caption chunks (e.g., GPT-4o) + if "retry_override" in self.model_config: + frame_options = self.model_config["retry_override"].get( + "frame_count_fallback", self.retry_config["frame_count_fallback"] + ) + caption_options = self.model_config["retry_override"].get( + "caption_chunks_fallback", [None, None, None, None] + ) else: - raise ValueError(f"Unknown retry_strategy: {retry_strategy}") - - last_error = None - - for attempt in range(max_attempts): - try: - if retry_strategy == 'frame_count_caption': - # Frame count + caption chunks - max_frames = frame_options[min(attempt, 
len(frame_options) - 1)] - max_caption_chunks = caption_options[min(attempt, len(caption_options) - 1)] - - # Discover caption path for GPT-4o - caption_path = None - if use_captions and video_path: - caption_path = self._get_caption_path_for_video(video_path, self.current_topic_name) - - if caption_path: - logger.info(f"Discovered caption path: {caption_path}") - else: - logger.warning(f"Could not discover caption path for {video_path}") - - preprocessing = { - 'max_frames': max_frames, - 'max_caption_chunks': max_caption_chunks, - 'attempt': attempt + 1, - 'method': 'frame_caption_sampling', - 'video_category': video_category - } - - logger.info(f"Attempt {attempt + 1}/{max_attempts}: max_frames={max_frames}, max_caption_chunks={max_caption_chunks}") - - start_time = time.time() - - # Model will handle whether to use it for frames based on supports_video - response = self.model.generate( - frames=str(video_path) if video_path else None, # โ† Pass video_path, not actual_video_path - audio=str(actual_audio_path) if actual_audio_path else None, - prompt=prompt, - max_frames=max_frames, - max_caption_chunks=max_caption_chunks, - caption_path=str(caption_path) if caption_path else None, - video_category=video_category - ) - - elif retry_strategy == 'frame_count': - # Frame count + audio chunks - max_frames = frame_options[min(attempt, len(frame_options) - 1)] - max_audio_chunks = audio_options[min(attempt, len(audio_options) - 1)] - - preprocessing = { - 'max_frames': max_frames, - 'attempt': attempt + 1, - 'method': 'internal_sampling', - 'video_category': video_category - } - - start_time = time.time() - response = self.model.generate( - frames=str(actual_video_path) if actual_video_path else None, - audio=str(actual_audio_path) if actual_audio_path else None, - prompt=prompt, - max_frames=max_frames, - max_audio_chunks=max_audio_chunks, - audio_chunk_duration_sec=audio_chunk_duration, - video_category=video_category - ) - - else: # retry_strategy == 'fps' - # 
FPS-based - fps = fps_options[min(attempt, len(fps_options) - 1)] - - preprocessing = { - 'fps_used': fps, - 'attempt': attempt + 1, - 'method': 'api_fps_sampling', - 'video_category': video_category - } - - start_time = time.time() - response = self.model.generate( - frames=str(actual_video_path) if actual_video_path else None, - audio=str(actual_audio_path) if actual_audio_path else None, - prompt=prompt, - fps=fps, - video_category=video_category + frame_options = self.retry_config["frame_count_fallback"] + caption_options = [None, None, None, None] + + elif retry_strategy == "frame_count": + # Frame count + audio chunks (e.g., local models) + if "retry_override" in self.model_config: + frame_options = self.model_config["retry_override"].get( + "frame_count_fallback", self.retry_config["frame_count_fallback"] + ) + audio_options = self.model_config["retry_override"].get( + "audio_chunks_fallback", self.retry_config["audio_chunks_fallback"] + ) + else: + frame_options = self.retry_config["frame_count_fallback"] + audio_options = self.retry_config.get( + "audio_chunks_fallback", [None, None, None, None] + ) + audio_chunk_duration = self.retry_config.get( + "audio_chunk_duration_sec", 10.0 + ) + + elif retry_strategy == "fps": + # FPS-based (e.g., Gemini) + fps_options = self.retry_config["fps_fallback"] + + else: + raise ValueError(f"Unknown retry_strategy: {retry_strategy}") + + last_error = None + + for attempt in range(max_attempts): + try: + if retry_strategy == "frame_count_caption": + # Frame count + caption chunks + max_frames = frame_options[min(attempt, len(frame_options) - 1)] + max_caption_chunks = caption_options[ + min(attempt, len(caption_options) - 1) + ] + + # Discover caption path for GPT-4o + caption_path = None + if use_captions and video_path: + caption_path = self._get_caption_path_for_video( + video_path, self.current_topic_name ) - - preprocessing['inference_time'] = time.time() - start_time - response_clean = 
self._clean_json_response(response) - + + if caption_path: + logger.info(f"Discovered caption path: {caption_path}") + else: + logger.warning( + f"Could not discover caption path for {video_path}" + ) + + preprocessing = { + "max_frames": max_frames, + "max_caption_chunks": max_caption_chunks, + "attempt": attempt + 1, + "method": "frame_caption_sampling", + "video_category": video_category, + } + + logger.info( + f"Attempt {attempt + 1}/{max_attempts}: max_frames={max_frames}, max_caption_chunks={max_caption_chunks}" + ) + + start_time = time.time() + + # Model uses frames or not based on supports_video + response = self.model.generate( + frames=str(video_path) + if video_path + else None, # โ† Pass video_path, not actual_video_path + audio=str(actual_audio_path) if actual_audio_path else None, + prompt=prompt, + max_frames=max_frames, + max_caption_chunks=max_caption_chunks, + caption_path=str(caption_path) if caption_path else None, + video_category=video_category, + ) + + elif retry_strategy == "frame_count": + # Frame count + audio chunks + max_frames = frame_options[min(attempt, len(frame_options) - 1)] + max_audio_chunks = audio_options[ + min(attempt, len(audio_options) - 1) + ] + + preprocessing = { + "max_frames": max_frames, + "attempt": attempt + 1, + "method": "internal_sampling", + "video_category": video_category, + } + + start_time = time.time() + response = self.model.generate( + frames=str(actual_video_path) if actual_video_path else None, + audio=str(actual_audio_path) if actual_audio_path else None, + prompt=prompt, + max_frames=max_frames, + max_audio_chunks=max_audio_chunks, + audio_chunk_duration_sec=audio_chunk_duration, + video_category=video_category, + ) + + else: # retry_strategy == 'fps' + # FPS-based + fps = fps_options[min(attempt, len(fps_options) - 1)] + + preprocessing = { + "fps_used": fps, + "attempt": attempt + 1, + "method": "api_fps_sampling", + "video_category": video_category, + } + + start_time = time.time() + response = 
self.model.generate( + frames=str(actual_video_path) if actual_video_path else None, + audio=str(actual_audio_path) if actual_audio_path else None, + prompt=prompt, + fps=fps, + video_category=video_category, + ) + + preprocessing["inference_time"] = time.time() - start_time + response_clean = self._clean_json_response(response) + + try: + output = json.loads(response_clean) + except json.JSONDecodeError as e: try: - output = json.loads(response_clean) - except json.JSONDecodeError as e: - try: - from json_repair import repair_json - repaired_json = repair_json(response_clean, return_objects=False) - output = json.loads(repaired_json) - except Exception: - raise ValueError(f"JSON parsing failed: {e}") - - is_valid, validation_msg = self.validate_output(output, task_type) - if not is_valid: - raise ValueError(f"Validation failed: {validation_msg}") - - return output, preprocessing - - except Exception as e: - last_error = str(e) - logger.warning(f"Attempt {attempt + 1}/{max_attempts} failed: {last_error}") - - if attempt < max_attempts - 1: - time.sleep(2) - continue - - raise RuntimeError(f"All {max_attempts} attempts failed. Last error: {last_error}") - + # Optional dep: only import when repair is needed + from json_repair import repair_json # noqa: PLC0415 + + repaired_json = repair_json( + response_clean, return_objects=False + ) + output = json.loads(repaired_json) + except Exception: + raise ValueError(f"JSON parsing failed: {e}") from e + + is_valid, validation_msg = self.validate_output(output, task_type) + if not is_valid: + raise ValueError(f"Validation failed: {validation_msg}") + + return output, preprocessing + + except Exception as e: + last_error = str(e) + logger.warning( + f"Attempt {attempt + 1}/{max_attempts} failed: {last_error}" + ) + + if attempt < max_attempts - 1: + time.sleep(2) + continue + + raise RuntimeError( + f"All {max_attempts} attempts failed. 
Last error: {last_error}" + ) + def run_task1( self, topic_name: str, enable_empathy: bool = False, dry_run: bool = False, overwrite: bool = False, - retry_failed: bool = False + retry_failed: bool = False, ) -> List[Dict]: - """Run Task 1: Video Summarization""" + """Run Task 1: Video Summarization.""" logger.info(f"Running Task 1 (Summarization) for {topic_name}") self.video_metadata = self._load_video_metadata(topic_name) - ground_truth = self.load_ground_truth('task1_summarization', topic_name) - + ground_truth = self.load_ground_truth("task1_summarization", topic_name) + existing_predictions = None if not overwrite: - existing_predictions = self.load_existing_predictions('task1_summarization', topic_name) - + existing_predictions = self.load_existing_predictions( + "task1_summarization", topic_name + ) + # Initialize predictions list with None for each GT entry - num_gt_entries = len(ground_truth['entries']) + num_gt_entries = len(ground_truth["entries"]) predictions = [None] * num_gt_entries - + # Load existing predictions by matching to GT index if existing_predictions and not overwrite: - existing_entries = existing_predictions.get('entries', []) - + existing_entries = existing_predictions.get("entries", []) + for pred in existing_entries: - pred_key = (pred.get('video_id'), pred.get('video_number')) - - for gt_idx, gt_entry in enumerate(ground_truth['entries']): + pred_key = (pred.get("video_id"), pred.get("video_number")) + + for gt_idx, gt_entry in enumerate(ground_truth["entries"]): if predictions[gt_idx] is not None: continue - - gt_key = (gt_entry['video_id'], gt_entry['video_number']) - + + gt_key = (gt_entry["video_id"], gt_entry["video_number"]) + if pred_key == gt_key: - if 'error' not in pred: + if "error" not in pred: predictions[gt_idx] = pred break - + num_done = sum(1 for p in predictions if p is not None) logger.info(f"Found {num_done}/{num_gt_entries} already processed videos") - + indices_to_process = [i for i, p in enumerate(predictions) 
if p is None] - + if not indices_to_process: logger.info("No entries to process") return [p for p in predictions if p is not None] - + logger.info(f"Processing {len(indices_to_process)} videos") - + if dry_run: logger.info("DRY RUN - No actual inference will be performed") return [p for p in predictions if p is not None] - + for gt_idx in tqdm(indices_to_process, desc=f"Task 1 - {topic_name}"): - entry = ground_truth['entries'][gt_idx] - video_id = entry['video_id'] - video_number = entry['video_number'] - duration = entry['duration_seconds'] + entry = ground_truth["entries"][gt_idx] + video_id = entry["video_id"] + video_number = entry["video_number"] + duration = entry["duration_seconds"] video_category = self.get_video_category(video_id) - + converted_audio_path = None - + try: video_path = self.get_video_path(topic_name, video_number) audio_path = self.get_audio_path(topic_name, video_number) - - if audio_path and self._get_audio_format() == 'wav': - wav_path = audio_path.with_suffix('.wav') + + if audio_path and self._get_audio_format() == "wav": + wav_path = audio_path.with_suffix(".wav") converted_audio_path = self.video_segmenter.convert_audio_format( - audio_path, wav_path, 'wav' + audio_path, wav_path, "wav" ) audio_path = converted_audio_path - + prompt = get_t1_prompt(duration) output, preprocessing = self._generate_with_retry( - video_path, audio_path, prompt, video_category, 'task1_summarization', topic_name + video_path, + audio_path, + prompt, + video_category, + "task1_summarization", + topic_name, ) - + prediction = { - 'video_id': video_id, - 'video_number': video_number, - 'duration_seconds': duration, - 'preprocessing': preprocessing, - 'outputs': output + "video_id": video_id, + "video_number": video_number, + "duration_seconds": duration, + "preprocessing": preprocessing, + "outputs": output, } - + if enable_empathy: try: empathy_prompt = get_t1_empathy_prompt(duration) empathy_output, _ = self._generate_with_retry( - video_path, audio_path, 
empathy_prompt, video_category, 'task1_empathy',topic_name + video_path, + audio_path, + empathy_prompt, + video_category, + "task1_empathy", + topic_name, ) - prediction['empathy'] = empathy_output + prediction["empathy"] = empathy_output except Exception as e: - prediction['empathy_error'] = str(e) - + prediction["empathy_error"] = str(e) + predictions[gt_idx] = prediction - + except Exception as e: logger.error(f"Error processing video {video_number}: {e}") predictions[gt_idx] = { - 'video_id': video_id, - 'video_number': video_number, - 'error': str(e), - 'timestamp': datetime.now().isoformat() + "video_id": video_id, + "video_number": video_number, + "error": str(e), + "timestamp": datetime.now().isoformat(), } - self.failed_entries.append({ - 'task': 'task1_summarization', - 'topic': topic_name, - 'video_id': video_id, - 'error': str(e) - }) - + self.failed_entries.append( + { + "task": "task1_summarization", + "topic": topic_name, + "video_id": video_id, + "error": str(e), + } + ) + finally: if converted_audio_path and converted_audio_path.exists(): converted_audio_path.unlink() - - return [p for p in predictions if p is not None] + return [p for p in predictions if p is not None] def run_task2( self, topic_name: str, dry_run: bool = False, overwrite: bool = False, - retry_failed: bool = False + retry_failed: bool = False, ) -> List[Dict]: - """Run Task 2: Question Answering""" + """Run Task 2: Question Answering.""" logger.info(f"Running Task 2 (MCQ) for {topic_name}") self.video_metadata = self._load_video_metadata(topic_name) - ground_truth = self.load_ground_truth('task2_mcq', topic_name) - + ground_truth = self.load_ground_truth("task2_mcq", topic_name) + existing_predictions = None if not overwrite: - existing_predictions = self.load_existing_predictions('task2_mcq', topic_name) - + existing_predictions = self.load_existing_predictions( + "task2_mcq", topic_name + ) + # Initialize predictions list with None for each GT entry - num_gt_entries = 
len(ground_truth['entries']) + num_gt_entries = len(ground_truth["entries"]) predictions = [None] * num_gt_entries - + # Load existing predictions by matching to GT index if existing_predictions and not overwrite: - existing_entries = existing_predictions.get('entries', []) - + existing_entries = existing_predictions.get("entries", []) + for pred in existing_entries: - pred_seg = pred.get('segment', {}) + pred_seg = pred.get("segment", {}) pred_key = ( - pred.get('video_id'), - pred.get('video_number'), - pred_seg.get('start'), - pred_seg.get('end') + pred.get("video_id"), + pred.get("video_number"), + pred_seg.get("start"), + pred_seg.get("end"), ) - - for gt_idx, gt_entry in enumerate(ground_truth['entries']): + + for gt_idx, gt_entry in enumerate(ground_truth["entries"]): if predictions[gt_idx] is not None: continue - + gt_key = ( - gt_entry['video_id'], - gt_entry['video_number'], - gt_entry['segment']['start'], - gt_entry['segment']['end'] + gt_entry["video_id"], + gt_entry["video_number"], + gt_entry["segment"]["start"], + gt_entry["segment"]["end"], ) - + if pred_key == gt_key: - if 'error' not in pred: + if "error" not in pred: predictions[gt_idx] = pred break - + num_done = sum(1 for p in predictions if p is not None) logger.info(f"Found {num_done}/{num_gt_entries} already processed segments") - + indices_to_process = [i for i, p in enumerate(predictions) if p is None] - + if not indices_to_process: logger.info("No entries to process") return [p for p in predictions if p is not None] - + logger.info(f"Processing {len(indices_to_process)} segments") - + if dry_run: logger.info("DRY RUN - No actual inference will be performed") return [p for p in predictions if p is not None] - + segment_dir = self._get_temp_dir() segment_dir.mkdir(parents=True, exist_ok=True) - + try: for gt_idx in tqdm(indices_to_process, desc=f"Task 2 - {topic_name}"): - entry = ground_truth['entries'][gt_idx] - video_id = entry['video_id'] - video_number = entry['video_number'] - segment 
= entry['segment'] - question = entry['question'] - options = entry['options'] + entry = ground_truth["entries"][gt_idx] + video_id = entry["video_id"] + video_number = entry["video_number"] + segment = entry["segment"] + question = entry["question"] + options = entry["options"] video_category = self.get_video_category(video_id) - + video_segment_path = None audio_segment_path = None - + try: video_path = self.get_video_path(topic_name, video_number) audio_path = self.get_audio_path(topic_name, video_number) - - video_segment_path = segment_dir / f'seg_{video_number}_{segment["start"]}_{segment["end"]}_{gt_idx}.mp4' + + video_segment_path = ( + segment_dir + / f"seg_{video_number}_{segment['start']}_{segment['end']}_{gt_idx}.mp4" + ) self.video_segmenter.extract_video_segment( - video_path, segment['start'], segment['end'], video_segment_path + video_path, segment["start"], segment["end"], video_segment_path ) - + if audio_path: audio_format = self._get_audio_format() - audio_segment_path = segment_dir / f'seg_{video_number}_{segment["start"]}_{segment["end"]}_{gt_idx}.{audio_format}' + audio_segment_path = ( + segment_dir + / f"seg_{video_number}_{segment['start']}_{segment['end']}_{gt_idx}.{audio_format}" + ) self.video_segmenter.extract_audio_segment( - audio_path, segment['start'], segment['end'], - audio_segment_path, output_format=audio_format + audio_path, + segment["start"], + segment["end"], + audio_segment_path, + output_format=audio_format, ) - + prompt = get_t2_prompt(question, options) output, preprocessing = self._generate_with_retry( - video_segment_path, audio_segment_path, prompt, video_category, 'task2_mcq', topic_name + video_segment_path, + audio_segment_path, + prompt, + video_category, + "task2_mcq", + topic_name, ) - + predictions[gt_idx] = { - 'video_id': video_id, - 'video_number': video_number, - 'segment': segment, - 'question': question, - 'options': options, - 'preprocessing': preprocessing, - 'outputs': output + "video_id": video_id, + 
"video_number": video_number, + "segment": segment, + "question": question, + "options": options, + "preprocessing": preprocessing, + "outputs": output, } - + except Exception as e: logger.error(f"Error processing segment: {e}") predictions[gt_idx] = { - 'video_id': video_id, - 'video_number': video_number, - 'segment': segment, - 'error': str(e), - 'timestamp': datetime.now().isoformat() + "video_id": video_id, + "video_number": video_number, + "segment": segment, + "error": str(e), + "timestamp": datetime.now().isoformat(), } - self.failed_entries.append({ - 'task': 'task2_mcq', - 'topic': topic_name, - 'video_id': video_id, - 'segment': segment, - 'error': str(e) - }) - + self.failed_entries.append( + { + "task": "task2_mcq", + "topic": topic_name, + "video_id": video_id, + "segment": segment, + "error": str(e), + } + ) + finally: if video_segment_path and video_segment_path.exists(): video_segment_path.unlink() if audio_segment_path and audio_segment_path.exists(): audio_segment_path.unlink() - + if video_segment_path: converted_dir = video_segment_path.parent / "converted" - potential_converted = converted_dir / f"{video_segment_path.stem}_h264{video_segment_path.suffix}" + potential_converted = ( + converted_dir + / f"{video_segment_path.stem}_h264{video_segment_path.suffix}" + ) if potential_converted.exists(): potential_converted.unlink() - + return [p for p in predictions if p is not None] - + finally: shutil.rmtree(segment_dir, ignore_errors=True) - + def run_task3( self, topic_name: str, dry_run: bool = False, overwrite: bool = False, - retry_failed: bool = False + retry_failed: bool = False, ) -> List[Dict]: - """Run Task 3: Temporal Localization""" + """Run Task 3: Temporal Localization.""" logger.info(f"Running Task 3 (Temporal) for {topic_name}") self.video_metadata = self._load_video_metadata(topic_name) - ground_truth = self.load_ground_truth('task3_temporal_localization', topic_name) - + ground_truth = 
self.load_ground_truth("task3_temporal_localization", topic_name) + existing_predictions = None if not overwrite: - existing_predictions = self.load_existing_predictions('task3_temporal_localization', topic_name) - + existing_predictions = self.load_existing_predictions( + "task3_temporal_localization", topic_name + ) + # Initialize predictions list with None for each GT entry - num_gt_entries = len(ground_truth['entries']) + num_gt_entries = len(ground_truth["entries"]) predictions = [None] * num_gt_entries - + # Load existing predictions by matching to GT index if existing_predictions and not overwrite: - existing_entries = existing_predictions.get('entries', []) - + existing_entries = existing_predictions.get("entries", []) + # Match each existing prediction to its GT index for pred in existing_entries: - pred_seg = pred.get('segment', {}) + pred_seg = pred.get("segment", {}) pred_key = ( - pred.get('video_id'), - pred.get('video_number'), - pred_seg.get('start'), - pred_seg.get('end') + pred.get("video_id"), + pred.get("video_number"), + pred_seg.get("start"), + pred_seg.get("end"), ) - + # Find matching GT index (first unmatched one with same key) - for gt_idx, gt_entry in enumerate(ground_truth['entries']): + for gt_idx, gt_entry in enumerate(ground_truth["entries"]): if predictions[gt_idx] is not None: continue # Already filled - + gt_key = ( - gt_entry['video_id'], - gt_entry['video_number'], - gt_entry['segment']['start'], - gt_entry['segment']['end'] + gt_entry["video_id"], + gt_entry["video_number"], + gt_entry["segment"]["start"], + gt_entry["segment"]["end"], ) - + if pred_key == gt_key: # Only keep if successful - if 'error' not in pred: + if "error" not in pred: predictions[gt_idx] = pred break - + # Count how many are done num_done = sum(1 for p in predictions if p is not None) logger.info(f"Found {num_done}/{num_gt_entries} already processed segments") - + # Find indices to process indices_to_process = [i for i, p in enumerate(predictions) if p is 
None] - + if not indices_to_process: logger.info("No entries to process") # Convert to list (remove None placeholders - shouldn't be any) return [p for p in predictions if p is not None] - + logger.info(f"Processing {len(indices_to_process)} segments") - + if dry_run: logger.info("DRY RUN - No actual inference will be performed") return [p for p in predictions if p is not None] - + segment_dir = self._get_temp_dir() segment_dir.mkdir(parents=True, exist_ok=True) - + try: for gt_idx in tqdm(indices_to_process, desc=f"Task 3 - {topic_name}"): - entry = ground_truth['entries'][gt_idx] - video_id = entry['video_id'] - video_number = entry['video_number'] - segment = entry['segment'] - questions = entry['questions'] + entry = ground_truth["entries"][gt_idx] + video_id = entry["video_id"] + video_number = entry["video_number"] + segment = entry["segment"] + questions = entry["questions"] video_category = self.get_video_category(video_id) - + video_segment_path = None audio_segment_path = None - + try: video_path = self.get_video_path(topic_name, video_number) audio_path = self.get_audio_path(topic_name, video_number) - - video_segment_path = segment_dir / f'seg_{video_number}_{segment["start"]}_{segment["end"]}_{gt_idx}.mp4' + + video_segment_path = ( + segment_dir + / f"seg_{video_number}_{segment['start']}_{segment['end']}_{gt_idx}.mp4" + ) self.video_segmenter.extract_video_segment( - video_path, segment['start'], segment['end'], video_segment_path + video_path, segment["start"], segment["end"], video_segment_path ) - + if audio_path: audio_format = self._get_audio_format() - audio_segment_path = segment_dir / f'seg_{video_number}_{segment["start"]}_{segment["end"]}_{gt_idx}.{audio_format}' + audio_segment_path = ( + segment_dir + / f"seg_{video_number}_{segment['start']}_{segment['end']}_{gt_idx}.{audio_format}" + ) self.video_segmenter.extract_audio_segment( - audio_path, segment['start'], segment['end'], - audio_segment_path, output_format=audio_format + 
audio_path, + segment["start"], + segment["end"], + audio_segment_path, + output_format=audio_format, ) - - prompt = get_t3_prompt(questions, segment['start'], segment['end']) + + prompt = get_t3_prompt(questions, segment["start"], segment["end"]) output, preprocessing = self._generate_with_retry( - video_segment_path, audio_segment_path, prompt, - video_category, 'task3_temporal_localization', topic_name + video_segment_path, + audio_segment_path, + prompt, + video_category, + "task3_temporal_localization", + topic_name, ) - + predictions[gt_idx] = { - 'video_id': video_id, - 'video_number': video_number, - 'segment': segment, - 'preprocessing': preprocessing, - 'outputs': output + "video_id": video_id, + "video_number": video_number, + "segment": segment, + "preprocessing": preprocessing, + "outputs": output, } - + except Exception as e: logger.error(f"Error processing segment: {e}") predictions[gt_idx] = { - 'video_id': video_id, - 'video_number': video_number, - 'segment': segment, - 'error': str(e), - 'timestamp': datetime.now().isoformat() + "video_id": video_id, + "video_number": video_number, + "segment": segment, + "error": str(e), + "timestamp": datetime.now().isoformat(), } - self.failed_entries.append({ - 'task': 'task3_temporal_localization', - 'topic': topic_name, - 'video_id': video_id, - 'segment': segment, - 'error': str(e) - }) - + self.failed_entries.append( + { + "task": "task3_temporal_localization", + "topic": topic_name, + "video_id": video_id, + "segment": segment, + "error": str(e), + } + ) + finally: if video_segment_path and video_segment_path.exists(): video_segment_path.unlink() if audio_segment_path and audio_segment_path.exists(): audio_segment_path.unlink() - + if video_segment_path: converted_dir = video_segment_path.parent / "converted" - potential_converted = converted_dir / f"{video_segment_path.stem}_h264{video_segment_path.suffix}" + potential_converted = ( + converted_dir + / 
f"{video_segment_path.stem}_h264{video_segment_path.suffix}" + ) if potential_converted.exists(): potential_converted.unlink() - + # Return only non-None entries (all should be filled now) return [p for p in predictions if p is not None] - + finally: shutil.rmtree(segment_dir, ignore_errors=True) - + def save_predictions(self, task: str, topic_name: str, predictions: List[Dict]): - """Save predictions to JSON""" + """Save predictions to JSON.""" if self.experiment_name: - output_dir = Path('05_evaluation_inference/results/predictions') / self.experiment_name / self.model_name / task + output_dir = ( + Path("05_evaluation_inference/results/predictions") + / self.experiment_name + / self.model_name + / task + ) else: - output_dir = Path('05_evaluation_inference/results/predictions') / self.model_name / task - + output_dir = ( + Path("05_evaluation_inference/results/predictions") + / self.model_name + / task + ) + output_dir.mkdir(parents=True, exist_ok=True) - output_file = output_dir / f'{topic_name}.json' - - successful = len([p for p in predictions if 'error' not in p]) - failed = len([p for p in predictions if 'error' in p]) - + output_file = output_dir / f"{topic_name}.json" + + successful = len([p for p in predictions if "error" not in p]) + failed = len([p for p in predictions if "error" in p]) + output_data = { - 'model': self.model_name, - 'topic_name': topic_name, - 'task': task, - 'generated_at': datetime.now().isoformat(), - 'num_entries': len(predictions), - 'num_successful': successful, - 'num_failed': failed, - 'entries': predictions + "model": self.model_name, + "topic_name": topic_name, + "task": task, + "generated_at": datetime.now().isoformat(), + "num_entries": len(predictions), + "num_successful": successful, + "num_failed": failed, + "entries": predictions, } - - with open(output_file, 'w') as f: + + with open(output_file, "w") as f: json.dump(output_data, f, indent=2) - - logger.info(f"Saved predictions to {output_file} ({successful} 
successful, {failed} failed)") - - def _deduplicate_predictions(self, predictions: List[Dict], task: str) -> List[Dict]: - """Remove failed entries if successful entry exists for same key""" - + logger.info( + f"Saved predictions to {output_file} ({successful} successful, {failed} failed)" + ) + + def _deduplicate_predictions( + self, predictions: List[Dict], task: str + ) -> List[Dict]: + """Remove failed entries if successful entry exists for same key.""" + def get_key(entry: Dict) -> Tuple: - if task == 'task1_summarization': - return (entry.get('video_id'), entry.get('video_number')) - else: - seg = entry.get('segment', {}) - return ( - entry.get('video_id'), - entry.get('video_number'), - seg.get('start'), - seg.get('end') - ) - + if task == "task1_summarization": + return (entry.get("video_id"), entry.get("video_number")) + seg = entry.get("segment", {}) + return ( + entry.get("video_id"), + entry.get("video_number"), + seg.get("start"), + seg.get("end"), + ) + # First pass: collect all successful keys successful_keys = set() for p in predictions: - if 'error' not in p: + if "error" not in p: successful_keys.add(get_key(p)) - + # Second pass: keep only if successful OR no successful version exists result = [] seen_keys = set() - + for p in predictions: key = get_key(p) - + if key in seen_keys: continue # Already added this key - - if 'error' in p and key in successful_keys: + + if "error" in p and key in successful_keys: continue # Skip failed, successful exists - + result.append(p) seen_keys.add(key) - + removed = len(predictions) - len(result) if removed > 0: logger.info(f"Removed {removed} duplicate/superseded entries") - + return result - def save_failed_log(self): - """Save log of failed entries""" + def save_failed_log(self) -> None: + """Save log of failed entries to a JSON file.""" if not self.failed_entries: return - + if self.experiment_name: - failed_log_dir = Path('05_evaluation_inference/results/failed_logs') / self.experiment_name / 
self.model_name + failed_log_dir = ( + Path("05_evaluation_inference/results/failed_logs") + / self.experiment_name + / self.model_name + ) else: - failed_log_dir = Path('05_evaluation_inference/results/failed_logs') / self.model_name - + failed_log_dir = ( + Path("05_evaluation_inference/results/failed_logs") / self.model_name + ) + failed_log_dir.mkdir(parents=True, exist_ok=True) - - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - failed_log_file = failed_log_dir / f'failed_{timestamp}.json' - - with open(failed_log_file, 'w') as f: + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + failed_log_file = failed_log_dir / f"failed_{timestamp}.json" + + with open(failed_log_file, "w") as f: json.dump(self.failed_entries, f, indent=2) - + logger.info(f"Saved failed entries log to {failed_log_file}") - + def run( self, model_name: str, tasks: List[str], topics: List[str], - enable_empathy: bool = None, + enable_empathy: Optional[bool] = None, dry_run: bool = False, overwrite: bool = False, - retry_failed: bool = False - ): - """Run inference for specified model, tasks, and topics""" + retry_failed: bool = False, + ) -> None: + """ + Run inference for specified model, tasks, and topics. + + Args: + model_name: Name of the model (must exist in config). + tasks: List of task identifiers to run. + topics: List of topic names to process. + enable_empathy: If True/False, override config; if None, use config. + dry_run: If True, skip actual model loading and inference. + overwrite: If True, recompute even when predictions exist. + retry_failed: If True, retry only previously failed entries. 
+ """ self.model_name = model_name if enable_empathy is None: - enable_empathy = self.config.get('empathy', {}).get('enabled', False) - + enable_empathy = self.config.get("empathy", {}).get("enabled", False) + if not dry_run: self.load_model(model_name) - + for task in tasks: for topic in topics: logger.info(f"Processing {task} - {topic}") - + try: - if task == 'task1_summarization': - predictions = self.run_task1(topic, enable_empathy, dry_run, overwrite, retry_failed) - elif task == 'task2_mcq': - predictions = self.run_task2(topic, dry_run, overwrite, retry_failed) - elif task == 'task3_temporal_localization': - predictions = self.run_task3(topic, dry_run, overwrite, retry_failed) + if task == "task1_summarization": + predictions = self.run_task1( + topic, enable_empathy, dry_run, overwrite, retry_failed + ) + elif task == "task2_mcq": + predictions = self.run_task2( + topic, dry_run, overwrite, retry_failed + ) + elif task == "task3_temporal_localization": + predictions = self.run_task3( + topic, dry_run, overwrite, retry_failed + ) else: logger.error(f"Unknown task: {task}") continue - + if not dry_run: self.save_predictions(task, topic, predictions) - + except Exception as e: logger.error(f"Error processing {task} - {topic}: {e}") - import traceback traceback.print_exc() continue - + if not dry_run: if self.frame_sampler: self.frame_sampler.cleanup() - + if self.failed_entries: self.save_failed_log() logger.warning(f"Total failed entries: {len(self.failed_entries)}") - + self.model.unload() - + logger.info("Inference complete") - + def _get_audio_format(self) -> str: - """Get required audio format from model config""" - return self.model_config.get('audio_format', 'm4a') - - def _get_caption_path_for_video(self, video_path: Path, topic_name: str) -> Optional[Path]: + """Get required audio format from model config.""" + return self.model_config.get("audio_format", "m4a") + + def _get_caption_path_for_video( + self, video_path: Path, topic_name: str + ) -> 
Optional[Path]: """ Get caption path for a video (full or segment). - Extracts video number from either pattern and builds caption path. - + + Extracts video number from path stem and builds caption file path. + Args: - video_path: Path to video file (full video or segment) - topic_name: Topic name for building caption path - + video_path: Path to video file (full video or segment). + topic_name: Topic name for building caption path. + Returns: - Path to caption file if it exists, None otherwise + Path to caption file if it exists, None otherwise. """ - import re - video_name = video_path.stem - + # Try pattern 1: "video_001" - match = re.search(r'video_(\d+)', video_name) - + match = re.search(r"video_(\d+)", video_name) + # Try pattern 2: "seg_001_30_60_5" if not match: - match = re.search(r'seg_(\d+)_', video_name) - + match = re.search(r"seg_(\d+)_", video_name) + if match: video_number = match.group(1) - caption_path = self.dataset_path / 'captions' / topic_name / f'caption_{video_number}.srt' - + caption_path = ( + self.dataset_path + / "captions" + / topic_name + / f"caption_{video_number}.srt" + ) + if caption_path.exists(): return caption_path - else: - logger.debug(f"Caption file not found: {caption_path}") - + logger.debug(f"Caption file not found: {caption_path}") + return None -if __name__ == '__main__': - import argparse - - parser = argparse.ArgumentParser(description='Run model inference') - parser.add_argument('--config', type=str, default='configs/models_config.yaml') - parser.add_argument('--model', type=str, required=True) - parser.add_argument('--tasks', nargs='+', default=None) - parser.add_argument('--topics', nargs='+', default=None) - parser.add_argument('--experiment-name', type=str, default=None) - parser.add_argument('--empathy', action='store_true', default=None) - parser.add_argument('--no-empathy', action='store_true') - parser.add_argument('--dry-run', action='store_true') - parser.add_argument('--overwrite', action='store_true') - 
parser.add_argument('--retry-failed', action='store_true') - + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run model inference") + parser.add_argument("--config", type=str, default="models_config.yaml") + parser.add_argument("--model", type=str, required=True) + parser.add_argument("--tasks", nargs="+", default=None) + parser.add_argument("--topics", nargs="+", default=None) + parser.add_argument("--experiment-name", type=str, default=None) + parser.add_argument("--empathy", action="store_true", default=None) + parser.add_argument("--no-empathy", action="store_true") + parser.add_argument("--dry-run", action="store_true") + parser.add_argument("--overwrite", action="store_true") + parser.add_argument("--retry-failed", action="store_true") + args = parser.parse_args() - + runner = InferenceRunner(args.config, experiment_name=args.experiment_name) - - tasks = args.tasks if args.tasks else runner.config['tasks'] - topics = args.topics if args.topics else runner.config['topics'] - + + tasks = args.tasks if args.tasks else runner.config["tasks"] + topics = args.topics if args.topics else runner.config["topics"] + if args.no_empathy: enable_empathy = False elif args.empathy: enable_empathy = True else: enable_empathy = None - + runner.run( args.model, tasks, @@ -1070,5 +1239,5 @@ def _get_caption_path_for_video(self, video_path: Path, topic_name: str) -> Opti enable_empathy, args.dry_run, args.overwrite, - args.retry_failed - ) \ No newline at end of file + args.retry_failed, + ) diff --git a/sonic-o1/05_evaluation_inference/metrics/__init__.py b/sonic-o1/05_evaluation_inference/metrics/__init__.py index 6f90795..652897c 100644 --- a/sonic-o1/05_evaluation_inference/metrics/__init__.py +++ b/sonic-o1/05_evaluation_inference/metrics/__init__.py @@ -1,5 +1,6 @@ -""" -metrics/__init__.py +"""metrics/__init__.py + +Metrics pipeline for model scoring. -metrics pipeline for model scoring. 
-""" \ No newline at end of file +Author: SONIC-O1 Team +""" diff --git a/sonic-o1/05_evaluation_inference/metrics/compute_metrics.py b/sonic-o1/05_evaluation_inference/metrics/compute_metrics.py index 59856ec..4c102b8 100644 --- a/sonic-o1/05_evaluation_inference/metrics/compute_metrics.py +++ b/sonic-o1/05_evaluation_inference/metrics/compute_metrics.py @@ -1,55 +1,72 @@ -""" -Main Metrics Computation Script -Orchestrates evaluation for all tasks and topics +"""compute_metrics.py + +Main metrics computation: orchestrates evaluation for all tasks and topics. + +Author: SONIC-O1 Team """ import argparse import json import logging -from pathlib import Path -from typing import List, Optional import sys +from pathlib import Path +from typing import Any, Dict, List, Optional + +sys.path.append(str(Path(__file__).parent.parent)) from t1_metrics import evaluate_t1_topic from t2_metrics import evaluate_t2_topic from t3_metrics import evaluate_t3_topic - -sys.path.append(str(Path(__file__).parent.parent)) - from utils.config_loader import get_config + # Setup logging logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) -def load_existing_topic_results(output_path: Path, model_name: str, task_name: str, - topics: List[str], judge_dir: str, - experiment_name: Optional[str] = None) -> dict: - """Load previously computed topic results""" + +def load_existing_topic_results( + output_path: Path, + model_name: str, + task_name: str, + topics: List[str], + judge_dir: str, + experiment_name: Optional[str] = None, +) -> Dict[str, Any]: + """Load previously computed topic results from per-topic JSON files.""" existing_results = {} - + for topic in topics: if experiment_name: - topic_file = output_path / judge_dir / experiment_name / model_name / task_name / f"{topic}.json" + topic_file = ( + output_path + / 
judge_dir + / experiment_name + / model_name + / task_name + / f"{topic}.json" + ) else: - topic_file = output_path / judge_dir / model_name / task_name / f"{topic}.json" - + topic_file = ( + output_path / judge_dir / model_name / task_name / f"{topic}.json" + ) + if topic_file.exists(): try: - with open(topic_file, 'r') as f: + with open(topic_file, "r") as f: data = json.load(f) existing_results[topic] = data.get("aggregated_metrics", {}) logger.info(f" Loaded existing results for {topic}") except Exception as e: logger.warning(f" Failed to load {topic_file}: {e}") - + return existing_results -def get_task_mapping(config): - """Get task mapping from config""" + +def get_task_mapping(config: Any) -> Dict[str, str]: + """Get task key to full name from config (e.g. t1 -> task1_summarization).""" tasks = config.get("tasks", []) mapping = {} for i, task in enumerate(tasks, 1): @@ -65,88 +82,179 @@ def compute_metrics_for_model( predictions_path: Path, output_path: Path, use_llm_judge: bool = True, - config = None, - experiment_name: Optional[str] = None -): + config: Optional[Any] = None, + experiment_name: Optional[str] = None, +) -> Dict[str, Any]: """ - Compute metrics for all tasks and topics for a model - + Compute metrics for all tasks and topics for a model. + Args: - model_name: Name of the model - tasks: List of task names to evaluate - topics: List of topics to evaluate - vqa_path: Path to VQA ground truth directory - predictions_path: Path to predictions directory - output_path: Path to save results - use_llm_judge: Whether to use LLM judge + model_name: Name of the model to evaluate. + tasks: List of task keys (e.g. t1, t2, t3) to evaluate. + topics: List of topic names to evaluate. + vqa_path: Path to VQA ground truth directory. + predictions_path: Path to predictions directory. + output_path: Path to save results. + use_llm_judge: Whether to use LLM judge. + config: ConfigLoader instance (from get_config). 
+ experiment_name: Optional experiment label for paths. + + Returns: + Aggregated results dict with tasks and overall metrics. """ + if config is None: + raise ValueError("config is required for compute_metrics_for_model") + logger.info(f"Computing metrics for model: {model_name}") task_mapping = get_task_mapping(config) judge_name = config.get_llm_judge_model() logger.info(f"Using LLM judge: {judge_name} (enabled: {use_llm_judge})") - if 'gpt' in judge_name.lower(): - judge_name = 'gpt' - judge_dir = 'gpt_judge' - elif 'qwen' in judge_name.lower(): - judge_dir = 'qwen_judge' - judge_name= 'qwen' + if "gpt" in judge_name.lower(): + judge_name = "gpt" + judge_dir = "gpt_judge" + elif "qwen" in judge_name.lower(): + judge_dir = "qwen_judge" + judge_name = "qwen" else: raise ValueError(f"Unknown judge name: {judge_name}") - results = { - "model": model_name, - "experiment_name": experiment_name, - "tasks": {} - } + if experiment_name: + overall_output_path = ( + output_path + / judge_dir + / experiment_name + / model_name + / "overall_metrics.json" + ) + else: + overall_output_path = ( + output_path / judge_dir / model_name / "overall_metrics.json" + ) + + if overall_output_path.exists(): + try: + with open(overall_output_path, "r") as f: + results = json.load(f) + logger.info(f"Loaded existing results from {overall_output_path}") + logger.info(f"Existing tasks: {list(results.get('tasks', {}).keys())}") + # Ensure tasks dict exists + if "tasks" not in results: + results["tasks"] = {} + except Exception as e: + logger.warning(f"Could not load existing results: {e}, creating new") + results = { + "model": model_name, + "experiment_name": experiment_name, + "tasks": {}, + } + else: + logger.info("Creating new results file (no existing file found)") + results = {"model": model_name, "experiment_name": experiment_name, "tasks": {}} + + # Reconstruct missing tasks from per-topic JSONs + all_possible_tasks = ["t1", "t2", "t3"] + missing_tasks = [ + t for t in 
all_possible_tasks if t not in results.get("tasks", {}) + ] + + if missing_tasks: + logger.info(f"Attempting to reconstruct missing tasks: {missing_tasks}") + for missing_task in missing_tasks: + task_name = task_mapping.get(missing_task) + if not task_name: + continue + + logger.info(f" Reconstructing {missing_task} ({task_name})...") + + # Load all available per-topic results for this task + all_topics_for_task = config.get_topics() + existing_topic_results = load_existing_topic_results( + output_path, + model_name, + task_name, + all_topics_for_task, + judge_dir, + experiment_name, + ) + + if existing_topic_results: + # Reconstruct task with aggregated metrics + reconstructed_task = { + "task_name": task_name, + "topics": existing_topic_results, + "aggregated_across_topics": aggregate_topic_metrics( + existing_topic_results, missing_task + ), + } + results["tasks"][missing_task] = reconstructed_task + logger.info( + f" Reconstructed {missing_task} with " + f"{len(existing_topic_results)} topics" + ) + else: + logger.warning(f" No per-topic results found for {missing_task}") - for task_key in tasks: task_name = task_mapping[task_key] logger.info(f"Evaluating task: {task_name}") - - task_results = { - "task_name": task_name, - "topics": {} - } - all_topics_for_aggregation = config.get_topics() # Get all possible topics + + task_results = {"task_name": task_name, "topics": {}} + all_topics_for_aggregation = config.get_topics() existing_results = load_existing_topic_results( - output_path, model_name, task_name, - all_topics_for_aggregation, judge_dir, experiment_name + output_path, + model_name, + task_name, + all_topics_for_aggregation, + judge_dir, + experiment_name, ) task_results["topics"] = existing_results # Start with existing - - + for topic in topics: logger.info(f" Processing topic: {topic}") - + # Paths gt_path = vqa_path / task_name / f"{topic}.json" if experiment_name: - pred_path = predictions_path / experiment_name / model_name / task_name / 
f"{topic}.json" + pred_path = ( + predictions_path + / experiment_name + / model_name + / task_name + / f"{topic}.json" + ) else: pred_path = predictions_path / model_name / task_name / f"{topic}.json" - - + # Check if files exist if not gt_path.exists(): logger.warning(f"Ground truth not found: {gt_path}") continue - + if not pred_path.exists(): logger.warning(f"Prediction not found: {pred_path}") continue - + # Compute task-specific metrics if experiment_name: - topic_output_path = output_path /judge_dir/ experiment_name/ model_name / task_name / f"{topic}.json" + topic_output_path = ( + output_path + / judge_dir + / experiment_name + / model_name + / task_name + / f"{topic}.json" + ) else: - topic_output_path = output_path /judge_dir/ model_name / task_name / f"{topic}.json" - + topic_output_path = ( + output_path / judge_dir / model_name / task_name / f"{topic}.json" + ) + logger.info(f" Saving topic results to: {topic_output_path}") topic_output_path.parent.mkdir(parents=True, exist_ok=True) - + try: if task_key == "t1": topic_results = evaluate_t1_topic( @@ -154,7 +262,7 @@ def compute_metrics_for_model( str(pred_path), str(topic_output_path), use_llm_judge=use_llm_judge, - judge_name= judge_name + judge_name=judge_name, ) elif task_key == "t2": topic_results = evaluate_t2_topic( @@ -162,7 +270,7 @@ def compute_metrics_for_model( str(pred_path), str(topic_output_path), use_llm_judge=use_llm_judge, - judge_name= judge_name + judge_name=judge_name, ) elif task_key == "t3": topic_results = evaluate_t3_topic( @@ -170,106 +278,117 @@ def compute_metrics_for_model( str(pred_path), str(topic_output_path), use_llm_judge=use_llm_judge, - judge_name= judge_name + judge_name=judge_name, ) - + task_results["topics"][topic] = topic_results["aggregated_metrics"] - + except Exception as e: - logger.error(f"Failed to evaluate {topic} for task {task_name}: {e}", exc_info=True) + logger.error( + f"Failed to evaluate {topic} for task {task_name}: {e}", + exc_info=True, + ) 
continue - + # Aggregate across topics - task_results["aggregated_across_topics"] = aggregate_topic_metrics(task_results["topics"], task_key) - + task_results["aggregated_across_topics"] = aggregate_topic_metrics( + task_results["topics"], task_key + ) + results["tasks"][task_key] = task_results - - # Save overall results - if experiment_name: - overall_output_path = output_path / judge_dir/experiment_name /model_name / "overall_metrics.json" - else: - overall_output_path = output_path / judge_dir /model_name / "overall_metrics.json" - + overall_output_path.parent.mkdir(parents=True, exist_ok=True) - - with open(overall_output_path, 'w') as f: + + with open(overall_output_path, "w") as f: json.dump(results, f, indent=2) - + logger.info(f"Overall metrics saved to {overall_output_path}") - + return results -def aggregate_topic_metrics(topic_metrics: dict, task_key: str) -> dict: - """Aggregate metrics across all topics for a task""" +def aggregate_topic_metrics( + topic_metrics: Dict[str, Any], task_key: str +) -> Dict[str, Any]: + """Aggregate metrics across all topics for a task.""" if not topic_metrics: return {} - + aggregated = {} - + if task_key == "t1": # Aggregate T1 metrics for summary_type in ["detailed", "short"]: rouge_scores = [] sim_scores = [] llm_scores = [] - - for topic, metrics in topic_metrics.items(): + + for _topic, metrics in topic_metrics.items(): if summary_type in metrics: rouge_scores.append(metrics[summary_type]["rouge_l_mean"]) sim_scores.append(metrics[summary_type]["text_similarity_mean"]) if "llm_judge_score_mean" in metrics[summary_type]: llm_scores.append(metrics[summary_type]["llm_judge_score_mean"]) - + aggregated[summary_type] = { - "rouge_l_mean": float(sum(rouge_scores) / len(rouge_scores)) if rouge_scores else 0.0, - "text_similarity_mean": float(sum(sim_scores) / len(sim_scores)) if sim_scores else 0.0, + "rouge_l_mean": float(sum(rouge_scores) / len(rouge_scores)) + if rouge_scores + else 0.0, + "text_similarity_mean": 
float(sum(sim_scores) / len(sim_scores)) + if sim_scores + else 0.0, } - + if llm_scores: - aggregated[summary_type]["llm_judge_score_mean"] = float(sum(llm_scores) / len(llm_scores)) - + aggregated[summary_type]["llm_judge_score_mean"] = float( + sum(llm_scores) / len(llm_scores) + ) + # Aggregate CIDEr cider_detailed = [] cider_short = [] - for topic, metrics in topic_metrics.items(): + for _topic, metrics in topic_metrics.items(): if "cider" in metrics: cider_detailed.append(metrics["cider"]["cider_detailed"]) cider_short.append(metrics["cider"]["cider_short"]) - + if cider_detailed: aggregated["cider"] = { "cider_detailed_mean": float(sum(cider_detailed) / len(cider_detailed)), - "cider_short_mean": float(sum(cider_short) / len(cider_short)) + "cider_short_mean": float(sum(cider_short) / len(cider_short)), } - + elif task_key == "t2": # Aggregate T2 metrics accuracies = [] rouge_scores = [] sim_scores = [] llm_scores = [] - - for topic, metrics in topic_metrics.items(): + + for _topic, metrics in topic_metrics.items(): accuracies.append(metrics["accuracy"]) - + if "rationale" in metrics: rouge_scores.append(metrics["rationale"]["rouge_l_mean"]) sim_scores.append(metrics["rationale"]["text_similarity_mean"]) if "llm_judge_score_mean" in metrics["rationale"]: llm_scores.append(metrics["rationale"]["llm_judge_score_mean"]) - - aggregated["accuracy_mean"] = float(sum(accuracies) / len(accuracies)) if accuracies else 0.0 - + + aggregated["accuracy_mean"] = ( + float(sum(accuracies) / len(accuracies)) if accuracies else 0.0 + ) + if rouge_scores: aggregated["rationale"] = { "rouge_l_mean": float(sum(rouge_scores) / len(rouge_scores)), "text_similarity_mean": float(sum(sim_scores) / len(sim_scores)), } - + if llm_scores: - aggregated["rationale"]["llm_judge_score_mean"] = float(sum(llm_scores) / len(llm_scores)) - + aggregated["rationale"]["llm_judge_score_mean"] = float( + sum(llm_scores) / len(llm_scores) + ) + elif task_key == "t3": # Aggregate T3 metrics mean_ious 
= [] @@ -278,71 +397,76 @@ def aggregate_topic_metrics(topic_metrics: dict, task_key: str) -> dict: sim_scores = [] llm_scores = [] recall_metrics = {0.3: [], 0.5: [], 0.7: []} - - for topic, metrics in topic_metrics.items(): + + for _topic, metrics in topic_metrics.items(): mean_ious.append(metrics["mean_iou"]) mae_avgs.append(metrics["mae"]["average_mean"]) - + for threshold in [0.3, 0.5, 0.7]: key = f"R@{threshold}" if key in metrics: recall_metrics[threshold].append(metrics[key]["recall"]) - + if "rationale" in metrics: rouge_scores.append(metrics["rationale"]["rouge_l_mean"]) sim_scores.append(metrics["rationale"]["text_similarity_mean"]) if "llm_judge_score_mean" in metrics["rationale"]: llm_scores.append(metrics["rationale"]["llm_judge_score_mean"]) - - aggregated["mean_iou"] = float(sum(mean_ious) / len(mean_ious)) if mean_ious else 0.0 - aggregated["mae_average"] = float(sum(mae_avgs) / len(mae_avgs)) if mae_avgs else 0.0 - + + aggregated["mean_iou"] = ( + float(sum(mean_ious) / len(mean_ious)) if mean_ious else 0.0 + ) + aggregated["mae_average"] = ( + float(sum(mae_avgs) / len(mae_avgs)) if mae_avgs else 0.0 + ) + for threshold, recalls in recall_metrics.items(): if recalls: aggregated[f"R@{threshold}"] = float(sum(recalls) / len(recalls)) - + if rouge_scores: aggregated["rationale"] = { "rouge_l_mean": float(sum(rouge_scores) / len(rouge_scores)), "text_similarity_mean": float(sum(sim_scores) / len(sim_scores)), } - + if llm_scores: - aggregated["rationale"]["llm_judge_score_mean"] = float(sum(llm_scores) / len(llm_scores)) - + aggregated["rationale"]["llm_judge_score_mean"] = float( + sum(llm_scores) / len(llm_scores) + ) + return aggregated -def main(): - """Main entry point""" + +def main() -> int: + """ + Parse arguments, load config, and compute metrics for selected models. + + Returns: + Exit code: 0 on success, 1 on failure or invalid options. 
+ """ parser = argparse.ArgumentParser(description="Compute evaluation metrics") - - parser.add_argument( - "--model", - type=str, - help="Model name to evaluate" - ) + + parser.add_argument("--model", type=str, help="Model name to evaluate") parser.add_argument( - "--models", - type=str, - nargs="+", - help="Multiple model names to evaluate" + "--models", type=str, nargs="+", help="Multiple model names to evaluate" ) parser.add_argument( "--all", action="store_true", - help="Evaluate all models in predictions directory" + help="Evaluate all models in predictions directory", ) parser.add_argument( "--config", type=str, - default="configs/models_config.yaml", - help="Path to configuration file" + default="models_config.yaml", + help="Path to configuration file", ) parser.add_argument( "--experiment-name", type=str, default=None, - help="Optional experiment name (must match inference experiment name)" + help="Optional experiment name (must match inference experiment name)", ) parser.add_argument( @@ -350,58 +474,59 @@ def main(): type=str, nargs="+", choices=["t1", "t2", "t3"], - help="Tasks to evaluate (default: from config or all tasks)" + help="Tasks to evaluate (default: from config or all tasks)", ) parser.add_argument( "--topics", type=str, nargs="+", - help="Topics to evaluate (default: from config or all topics)" + help="Topics to evaluate (default: from config or all topics)", ) + parser.add_argument("--vqa-path", type=str, help="Override VQA path from config") parser.add_argument( - "--vqa-path", - type=str, - help="Override VQA path from config" + "--predictions-path", type=str, help="Override predictions path from config" ) parser.add_argument( - "--predictions-path", - type=str, - help="Override predictions path from config" - ) - parser.add_argument( - "--output-path", - type=str, - help="Override output path from config" + "--output-path", type=str, help="Override output path from config" ) parser.add_argument( "--no-llm-judge", action="store_true", - 
help="Disable LLM judge evaluation (faster)" + help="Disable LLM judge evaluation (faster)", ) - + args = parser.parse_args() - + # Load config config = get_config(args.config) - + # Get all values from config with optional CLI overrides tasks = args.tasks if args.tasks else ["t1", "t2", "t3"] topics = args.topics if args.topics else config.get_topics() vqa_path = args.vqa_path if args.vqa_path else config.get_vqa_path() - predictions_path = args.predictions_path if args.predictions_path else config.get("results.predictions_path", "results/predictions") - output_path = args.output_path if args.output_path else config.get("results.scores_path", "results/scores") - + predictions_path = ( + args.predictions_path + if args.predictions_path + else config.get("results.predictions_path", "results/predictions") + ) + output_path = ( + args.output_path + if args.output_path + else config.get("results.scores_path", "results/scores") + ) + # Determine which models to evaluate models_to_evaluate = [] - + if args.all: predictions_path_obj = Path(predictions_path) if args.experiment_name: predictions_path_obj = predictions_path_obj / args.experiment_name - + if predictions_path_obj.exists(): models_to_evaluate = [ - d.name for d in predictions_path_obj.iterdir() + d.name + for d in predictions_path_obj.iterdir() if d.is_dir() and not d.name.startswith(".") ] else: @@ -414,18 +539,18 @@ def main(): else: logger.error("Must specify --model, --models, or --all") return 1 - + logger.info(f"Evaluating models: {models_to_evaluate}") logger.info(f"Tasks: {tasks}") logger.info(f"Topics: {len(topics)} topics") - if args.experiment_name: + if args.experiment_name: logger.info(f"Experiment: {args.experiment_name}") logger.info(f"VQA path: {vqa_path}") logger.info(f"Predictions path: {predictions_path}") logger.info(f"Output path: {output_path}") logger.info(f"LLM Judge: {'disabled' if args.no_llm_judge else 'enabled'}") - + # Evaluate each model for model_name in models_to_evaluate: try: 
@@ -438,15 +563,15 @@ def main(): output_path=Path(output_path), use_llm_judge=not args.no_llm_judge, config=config, - experiment_name=args.experiment_name + experiment_name=args.experiment_name, ) except Exception as e: logger.error(f"Failed to evaluate model {model_name}: {e}", exc_info=True) continue - + logger.info("Metrics computation complete") return 0 if __name__ == "__main__": - sys.exit(main()) \ No newline at end of file + sys.exit(main()) diff --git a/sonic-o1/05_evaluation_inference/metrics/llm_judge_gpt.py b/sonic-o1/05_evaluation_inference/metrics/llm_judge_gpt.py index 97d8d13..1c86739 100644 --- a/sonic-o1/05_evaluation_inference/metrics/llm_judge_gpt.py +++ b/sonic-o1/05_evaluation_inference/metrics/llm_judge_gpt.py @@ -1,19 +1,24 @@ +"""llm_judge_gpt.py + +LLM-as-Judge using GPT: semantic similarity, factual correctness, completeness. + +Author: SONIC-O1 Team """ -LLM-as-Judge Implementation using GPT-5-mini -Evaluates semantic similarity, factual correctness, and completeness -""" -from openai import OpenAI + import json -from typing import Dict, Any, Optional import logging -import os from pathlib import Path +from typing import Any, Dict, Optional + +from openai import OpenAI + # Load environment variables from .env file try: from dotenv import load_dotenv + # Look for .env in the evaluation_Inference directory - env_path = Path(__file__).parent.parent / '.env' + env_path = Path(__file__).parent.parent / ".env" if env_path.exists(): load_dotenv(env_path) logger = logging.getLogger(__name__) @@ -24,9 +29,10 @@ logger = logging.getLogger(__name__) + class LLMJudge: - """LLM-as-Judge evaluator using GPT-5-mini""" - + """LLM-as-Judge evaluator using GPT-5-mini.""" + SYSTEM_PROMPT = """You are an intelligent and fair evaluator AI that specializes in assessing the correctness and semantic alignment between ground truth answers and predicted responses for question-answering tasks, including those based on video content. 
Your role is to evaluate how well a predicted answer matches the correct (reference) answer based on the following detailed criteria: @@ -62,33 +68,33 @@ class LLMJudge: def __init__(self, api_key: Optional[str] = None, model: str = "gpt-5-mini"): """ - Initialize LLM Judge - + Initialize LLM Judge. + Args: api_key: OpenAI API key (if None, uses env variable) model: Model to use for judging (default: gpt-5-mini) """ self.client = OpenAI(api_key=api_key) if api_key else OpenAI() self.model = model - + def evaluate( - self, - question: str, - correct_answer: str, + self, + question: str, + correct_answer: str, predicted_answer: str, - task_type: str = "general" + task_type: str = "general", ) -> Dict[str, Any]: """ - Evaluate a predicted answer against ground truth - + Evaluate a predicted answer against ground truth. + Args: question: The question being answered correct_answer: Ground truth answer predicted_answer: Model's predicted answer task_type: Type of task (for logging purposes) - + Returns: - Dict with 'score' (0-10) and 'justification' (str) + Dict with "score" (0-10) and "justification" (str). 
""" user_prompt = f"""Please evaluate the following video-based question-answer pair: @@ -108,11 +114,10 @@ def evaluate( reasoning={"effort": "medium"}, # Balance between speed and accuracy text={"verbosity": "low"}, ) - + # Parse the JSON response response_text = result.output_text.strip() - - + # Try to extract JSON from response if "```json" in response_text: # Extract JSON from markdown code block @@ -124,64 +129,59 @@ def evaluate( json_start = response_text.find("```") + 3 json_end = response_text.find("```", json_start) response_text = response_text[json_start:json_end].strip() - - from json_repair import repair_json + # Optional dep: only import when repair is needed + from json_repair import repair_json # noqa: PLC0415 + repaired_json = repair_json(response_text, return_objects=False) response_text = json.loads(repaired_json) - evaluation = json.loads(response_text) - + # Validate structure if "score" not in evaluation or "justification" not in evaluation: raise ValueError("Missing required fields in LLM response") - + # Ensure score is integer 0-10 evaluation["score"] = max(0, min(10, int(evaluation["score"]))) - + return evaluation - + except json.JSONDecodeError as e: logger.error(f"Failed to parse LLM judge response: {e}") logger.error(f"Raw response: {response_text}") return { "score": 0, - "justification": f"Error parsing LLM response: {str(e)}" + "justification": f"Error parsing LLM response: {str(e)}", } except Exception as e: logger.error(f"LLM judge evaluation failed: {e}") - return { - "score": 0, - "justification": f"Evaluation error: {str(e)}" - } - + return {"score": 0, "justification": f"Evaluation error: {str(e)}"} + def batch_evaluate( - self, - evaluations: list[Dict[str, str]], - task_type: str = "general" + self, evaluations: list[Dict[str, str]], task_type: str = "general" ) -> list[Dict[str, Any]]: """ - Evaluate multiple question-answer pairs - + Evaluate multiple question-answer pairs. 
+ Args: - evaluations: List of dicts with 'question', 'correct_answer', 'predicted_answer' + evaluations: List of dicts with question, correct_answer, predicted_answer task_type: Type of task - + Returns: - List of evaluation results + List of evaluation results (score and justification per item). """ results = [] for i, eval_item in enumerate(evaluations): - logger.info(f"Evaluating {i+1}/{len(evaluations)}") + logger.info(f"Evaluating {i + 1}/{len(evaluations)}") result = self.evaluate( question=eval_item["question"], correct_answer=eval_item["correct_answer"], predicted_answer=eval_item["predicted_answer"], - task_type=task_type + task_type=task_type, ) results.append(result) - + return results @@ -191,20 +191,20 @@ def evaluate_with_llm_judge( correct_answer: str, predicted_answer: str, api_key: Optional[str] = None, - model: str = "gpt-4o-mini" + model: str = "gpt-4o-mini", ) -> Dict[str, Any]: """ - Convenience function for single evaluation - + Evaluate single instance. + Args: question: The question correct_answer: Ground truth predicted_answer: Model prediction api_key: OpenAI API key model: Model to use - + Returns: - Dict with score and justification + Dict with "score" and "justification". """ judge = LLMJudge(api_key=api_key, model=model) - return judge.evaluate(question, correct_answer, predicted_answer) \ No newline at end of file + return judge.evaluate(question, correct_answer, predicted_answer) diff --git a/sonic-o1/05_evaluation_inference/metrics/llm_judge_qwen.py b/sonic-o1/05_evaluation_inference/metrics/llm_judge_qwen.py index 27ac1a9..f8746ae 100644 --- a/sonic-o1/05_evaluation_inference/metrics/llm_judge_qwen.py +++ b/sonic-o1/05_evaluation_inference/metrics/llm_judge_qwen.py @@ -1,44 +1,47 @@ +"""llm_judge_qwen.py + +LLM-as-Judge using Qwen3: text-only evaluation with multi-GPU support. 
+ +Author: SONIC-O1 Team """ -LLM-as-Judge Implementation using Qwen3-8B -Evaluates semantic similarity, factual correctness, and completeness -Self-contained implementation for text-only evaluation with multi-GPU support -Reads configuration from models_config.yaml -""" + +import gc import json import logging -from typing import Dict, Any, Optional from pathlib import Path -import yaml +from typing import Any, Dict, Optional import torch -from transformers import AutoTokenizer, AutoModelForCausalLM +import yaml +from transformers import AutoModelForCausalLM, AutoTokenizer + logger = logging.getLogger(__name__) def load_config(config_path: str = "models_config.yaml") -> Dict[str, Any]: - """Load configuration from YAML file""" + """Load configuration from YAML file; search parent dirs if path not found.""" config_file = Path(config_path) - + # Try to find config in common locations search_paths = [ config_file, Path(__file__).parent / config_path, Path(__file__).parent.parent / config_path, ] - + for path in search_paths: if path.exists(): - with open(path, 'r') as f: + with open(path, "r") as f: return yaml.safe_load(f) - - logger.warning(f"Config file not found, using defaults") + + logger.warning("Config file not found, using defaults") return {} class LLMJudge: - """LLM-as-Judge evaluator using Qwen3-8B for text-only evaluation with multi-GPU support""" - + """LLM-as-Judge using Qwen3-8B for text-only evaluation (multi-GPU).""" + SYSTEM_PROMPT = """You are an intelligent and fair evaluator AI that specializes in assessing the correctness and semantic alignment between ground truth answers and predicted responses for question-answering tasks, including those based on video content. Your role is to evaluate how well a predicted answer matches the correct (reference) answer based on the following detailed criteria: @@ -73,16 +76,16 @@ class LLMJudge: Be fair, consistent, and concise. 
Follow the format exactly.""" def __init__( - self, + self, model_path: Optional[str] = None, device_map: Optional[str] = None, dtype: Optional[str] = None, max_memory: Optional[Dict[int, str]] = None, - config_path: str = "models_config.yaml" + config_path: str = "models_config.yaml", ): """ - Initialize LLM Judge with Qwen3-8B - + Initialize LLM Judge with Qwen3-8B. + Args: model_path: HuggingFace model path (if None, reads from config) device_map: Device mapping strategy (if None, reads from config) @@ -93,51 +96,53 @@ def __init__( # Load config config = load_config(config_path) metrics_config = config.get("metrics", {}) - + # Get model parameters from config or use provided/defaults - self.model_path = model_path or metrics_config.get("llm_judge_model", "Qwen/Qwen3-8B") - + self.model_path = model_path or metrics_config.get( + "llm_judge_model", "Qwen/Qwen3-8B" + ) + # Device map if device_map is None: device_map = metrics_config.get("llm_judge_device_map", "auto") self.device_map = device_map - + # Dtype if dtype is None: dtype_str = metrics_config.get("llm_judge_dtype", "bfloat16") else: dtype_str = dtype - + # Convert string to torch dtype dtype_map = { "bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32, - "auto": None + "auto": None, } self.dtype = dtype_map.get(dtype_str, torch.bfloat16) - + # Max memory if max_memory is None: max_memory = metrics_config.get("llm_judge_max_memory") self.max_memory = max_memory - + # Generation config gen_config = metrics_config.get("llm_judge_generation", {}) self.temperature = gen_config.get("temperature", 0.0) self.top_p = gen_config.get("top_p", 0.95) self.max_new_tokens = gen_config.get("max_new_tokens", 512) - + self.model = None self.tokenizer = None - + # Log configuration - logger.info(f"LLM Judge Configuration:") + logger.info("LLM Judge Configuration:") logger.info(f" Model: {self.model_path}") logger.info(f" Device map: {self.device_map}") logger.info(f" Dtype: {dtype_str}") 
logger.info(f" Temperature: {self.temperature}") - + # Get GPU info if torch.cuda.is_available(): num_gpus = torch.cuda.device_count() @@ -146,25 +151,24 @@ def __init__( props = torch.cuda.get_device_properties(i) memory_gb = props.total_memory / 1024**3 logger.info(f" GPU {i}: {props.name} ({memory_gb:.1f} GB)") - - def load(self): - """Load model and tokenizer with multi-GPU support""" + + def load(self) -> None: + """Load model and tokenizer with multi-GPU support.""" try: logger.info(f"Loading LLM Judge from {self.model_path}...") - + # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained( - self.model_path, - trust_remote_code=True + self.model_path, trust_remote_code=True ) - + # Set pad token if not set if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token - + # Load model with device_map for multi-GPU distribution logger.info("Loading model with multi-GPU distribution...") - + self.model = AutoModelForCausalLM.from_pretrained( self.model_path, torch_dtype=self.dtype, @@ -173,51 +177,50 @@ def load(self): max_memory=self.max_memory, offload_folder="offload", # Fallback to disk if needed ) - + self.model.eval() - + # Log device map - if hasattr(self.model, 'hf_device_map'): + if hasattr(self.model, "hf_device_map"): logger.info(f"Model device map: {self.model.hf_device_map}") - + logger.info("LLM Judge loaded successfully") - + except Exception as e: - raise RuntimeError(f"Failed to load LLM Judge model: {e}") - - def unload(self): - """Unload model to free memory""" + raise RuntimeError(f"Failed to load LLM Judge model: {e}") from e + + def unload(self) -> None: + """Unload model to free memory.""" if self.model is not None: del self.model self.model = None - + if self.tokenizer is not None: del self.tokenizer self.tokenizer = None - + if torch.cuda.is_available(): for i in range(torch.cuda.device_count()): with torch.cuda.device(i): torch.cuda.empty_cache() - - import gc + gc.collect() - + logger.info("LLM Judge 
unloaded from all GPUs") - + def evaluate( - self, - question: str, - correct_answer: str, + self, + question: str, + correct_answer: str, predicted_answer: str, task_type: str = "general", temperature: Optional[float] = None, top_p: Optional[float] = None, - max_new_tokens: Optional[int] = None + max_new_tokens: Optional[int] = None, ) -> Dict[str, Any]: """ - Evaluate a predicted answer against ground truth - + Evaluate a predicted answer against ground truth. + Args: question: The question being answered correct_answer: Ground truth answer @@ -226,15 +229,15 @@ def evaluate( temperature: Generation temperature (uses config default if None) top_p: Nucleus sampling parameter (uses config default if None) max_new_tokens: Maximum tokens to generate (uses config default if None) - + Returns: - Dict with 'score' (0-10) and 'justification' (str) + Dict with "score" (0-10) and "justification" (str). """ # Lazy load if needed if self.model is None or self.tokenizer is None: logger.info("Model not loaded, loading now...") self.load() - + # Use config defaults if not provided if temperature is None: temperature = self.temperature @@ -242,7 +245,7 @@ def evaluate( top_p = self.top_p if max_new_tokens is None: max_new_tokens = self.max_new_tokens - + # Construct evaluation prompt user_prompt = f"""Please evaluate the following video-based question-answer pair: @@ -256,41 +259,32 @@ def evaluate( # Build conversation using Qwen3 chat template messages = [ - { - "role": "system", - "content": self.SYSTEM_PROMPT - }, - { - "role": "user", - "content": user_prompt - } + {"role": "system", "content": self.SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, ] - + try: # Apply chat template text_prompt = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, - enable_thinking=False + enable_thinking=False, ) - + # Tokenize inputs = self.tokenizer( - text_prompt, - return_tensors="pt", - padding=True, - truncation=True + text_prompt, 
return_tensors="pt", padding=True, truncation=True ) - + # Move inputs to first device (accelerate handles the rest) - if hasattr(self.model, 'hf_device_map'): + if hasattr(self.model, "hf_device_map"): first_device = next(iter(self.model.hf_device_map.values())) else: first_device = next(self.model.parameters()).device - + inputs = {k: v.to(first_device) for k, v in inputs.items()} - + # Generate with torch.no_grad(): outputs = self.model.generate( @@ -300,11 +294,11 @@ def evaluate( do_sample=temperature > 0, top_p=top_p if temperature > 0 else None, pad_token_id=self.tokenizer.pad_token_id, - eos_token_id=self.tokenizer.eos_token_id + eos_token_id=self.tokenizer.eos_token_id, ) - + # Decode response (only the new tokens) - generated_ids = outputs[0][inputs['input_ids'].shape[1]:].tolist() + generated_ids = outputs[0][inputs["input_ids"].shape[1] :].tolist() # Parse thinking content using token ID try: @@ -318,34 +312,30 @@ def evaluate( response_text = self.tokenizer.decode( generated_ids[index:], skip_special_tokens=True, - clean_up_tokenization_spaces=True + clean_up_tokenization_spaces=True, ).strip() # Parse JSON response - evaluation = self._parse_response(response_text) - return evaluation + return self._parse_response(response_text) except Exception as e: logger.error(f"LLM judge evaluation failed: {e}", exc_info=True) - return { - "score": 0, - "justification": f"Evaluation error: {str(e)}" - } - + return {"score": 0, "justification": f"Evaluation error: {str(e)}"} + def _parse_response(self, response_text: str) -> Dict[str, Any]: """ - Parse JSON response from model output - + Parse JSON response from model output. + Args: response_text: Raw model output - + Returns: - Dict with score and justification + Dict with "score" and "justification". 
""" try: - # NEW: Handle thinking blocks + # Handle thinking blocks if "" in response_text: response_text = response_text.split("")[-1].strip() - + # Try to extract JSON from response if "```json" in response_text: # Extract JSON from markdown code block @@ -357,68 +347,64 @@ def _parse_response(self, response_text: str) -> Dict[str, Any]: json_start = response_text.find("```") + 3 json_end = response_text.find("```", json_start) response_text = response_text[json_start:json_end].strip() - + # Try to find JSON object with curly braces if "{" in response_text and "}" in response_text: json_start = response_text.find("{") json_end = response_text.rfind("}") + 1 response_text = response_text[json_start:json_end] - + evaluation = json.loads(response_text) - + # Validate structure if "score" not in evaluation or "justification" not in evaluation: raise ValueError("Missing required fields in LLM response") - + # Ensure score is integer 0-10 evaluation["score"] = max(0, min(10, int(evaluation["score"]))) - + return evaluation - + except json.JSONDecodeError as e: logger.error(f"Failed to parse LLM judge response: {e}") logger.error(f"Raw response: {response_text}") return { "score": 0, - "justification": f"Error parsing LLM response: {str(e)}" + "justification": f"Error parsing LLM response: {str(e)}", } except Exception as e: logger.error(f"Error processing response: {e}") - return { - "score": 0, - "justification": f"Response processing error: {str(e)}" - } - - + return {"score": 0, "justification": f"Response processing error: {str(e)}"} + def batch_evaluate( self, evaluations: list[Dict[str, str]], task_type: str = "general", temperature: Optional[float] = None, top_p: Optional[float] = None, - max_new_tokens: Optional[int] = None + max_new_tokens: Optional[int] = None, ) -> list[Dict[str, Any]]: """ - Evaluate multiple question-answer pairs - + Evaluate multiple question-answer pairs. 
+ Args: - evaluations: List of dicts with 'question', 'correct_answer', 'predicted_answer' + evaluations: List of dicts with question, correct_answer, predicted_answer task_type: Type of task temperature: Generation temperature (uses config default if None) top_p: Nucleus sampling parameter (uses config default if None) max_new_tokens: Max tokens per generation (uses config default if None) - + Returns: - List of evaluation results + List of evaluation results (score and justification per item). """ results = [] - + # Ensure model is loaded if self.model is None: self.load() - + for i, eval_item in enumerate(evaluations): - logger.info(f"Evaluating {i+1}/{len(evaluations)}") + logger.info(f"Evaluating {i + 1}/{len(evaluations)}") result = self.evaluate( question=eval_item["question"], correct_answer=eval_item["correct_answer"], @@ -426,30 +412,30 @@ def batch_evaluate( task_type=task_type, temperature=temperature, top_p=top_p, - max_new_tokens=max_new_tokens + max_new_tokens=max_new_tokens, ) results.append(result) - + return results - + def get_model_info(self) -> Dict[str, Any]: - """Get information about the loaded model""" + """Get information about the loaded model.""" info = { "model_path": self.model_path, "dtype": str(self.dtype), "device_map": self.device_map, } - - if self.model is not None and hasattr(self.model, 'hf_device_map'): + + if self.model is not None and hasattr(self.model, "hf_device_map"): info["loaded_device_map"] = self.model.hf_device_map - + if torch.cuda.is_available(): info["num_gpus"] = torch.cuda.device_count() info["gpu_memory_allocated"] = { i: f"{torch.cuda.memory_allocated(i) / 1024**3:.2f} GB" for i in range(torch.cuda.device_count()) } - + return info @@ -460,11 +446,11 @@ def evaluate_with_llm_judge( predicted_answer: str, model_path: Optional[str] = None, device_map: Optional[str] = None, - dtype: Optional[str] = None + dtype: Optional[str] = None, ) -> Dict[str, Any]: """ - Convenience function for single evaluation - + 
Evaluate single instance. + Args: question: The question correct_answer: Ground truth @@ -472,9 +458,9 @@ def evaluate_with_llm_judge( model_path: Path to Qwen model (reads from config if None) device_map: Device mapping strategy (reads from config if None) dtype: torch dtype string (reads from config if None) - + Returns: - Dict with score and justification + Dict with "score" and "justification". """ judge = LLMJudge(model_path=model_path, device_map=device_map, dtype=dtype) judge.load() diff --git a/sonic-o1/05_evaluation_inference/metrics/t1_metrics.py b/sonic-o1/05_evaluation_inference/metrics/t1_metrics.py index 1497818..7f7f952 100644 --- a/sonic-o1/05_evaluation_inference/metrics/t1_metrics.py +++ b/sonic-o1/05_evaluation_inference/metrics/t1_metrics.py @@ -1,143 +1,148 @@ -""" -T1 Metrics: Video Summarization Evaluation -Computes ROUGE-L, CIDEr, text similarity, and LLM-as-Judge scores +"""t1_metrics.py + +T1 metrics: video summarization (ROUGE-L, CIDEr, text similarity, LLM-as-Judge). + +Author: SONIC-O1 Team """ import json -import numpy as np -from pathlib import Path -from typing import Dict, List, Any, Optional import logging +from pathlib import Path +from typing import Any, Dict, List, Optional -from rouge_score import rouge_scorer +import numpy as np from pycocoevalcap.cider.cider import Cider -from sklearn.metrics.pairwise import cosine_similarity +from rouge_score import rouge_scorer from sentence_transformers import SentenceTransformer +from sklearn.metrics.pairwise import cosine_similarity logger = logging.getLogger(__name__) class T1Metrics: - """Compute metrics for T1: Video Summarization""" - + """Compute metrics for T1: Video Summarization.""" + def __init__( self, use_llm_judge: bool = True, embedding_model: str = "all-MiniLM-L6-v2", - judge_name: str = "gpt" - ): + judge_name: str = "gpt", + ) -> None: """ - Initialize T1 metrics - + Initialize T1 metrics. 
+ Args: - use_llm_judge: Whether to use LLM judge evaluation - embedding_model: Model for computing text similarity + use_llm_judge: Whether to use LLM judge evaluation. + embedding_model: Model for computing text similarity. + judge_name: Judge backend: "gpt" or "qwen". """ - self.rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True) + self.rouge_scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True) self.cider_scorer = Cider() self.embedding_model = SentenceTransformer(embedding_model) - + self.use_llm_judge = use_llm_judge logger.info(f"LLM Judge name: {judge_name}") + # Lazy import: load only selected judge backend (env-dependent) if use_llm_judge: if judge_name == "gpt": - from llm_judge_gpt import LLMJudge + from llm_judge_gpt import LLMJudge # noqa: PLC0415 + self.llm_judge = LLMJudge() elif judge_name == "qwen": - from llm_judge_qwen import LLMJudge + from llm_judge_qwen import LLMJudge # noqa: PLC0415 + self.llm_judge = LLMJudge() else: raise ValueError(f"Unknown judge name: {judge_name}") - - + def compute_rouge_l(self, reference: str, prediction: str) -> float: """ - Compute ROUGE-L F1 score - + Compute ROUGE-L F1 score. + Args: - reference: Ground truth text - prediction: Predicted text - + reference: Ground truth text. + prediction: Predicted text. + Returns: - ROUGE-L F1 score (0-1) + ROUGE-L F1 score in [0, 1]. """ scores = self.rouge_scorer.score(reference, prediction) - return scores['rougeL'].fmeasure - + return scores["rougeL"].fmeasure + def compute_cider(self, references: List[str], predictions: List[str]) -> float: """ - Compute CIDEr score for a set of summaries - + Compute CIDEr score for a set of summaries. + Args: - references: List of ground truth summaries - predictions: List of predicted summaries - + references: List of ground truth summaries. + predictions: List of predicted summaries. + Returns: - Average CIDEr score + Average CIDEr score. 
""" # CIDEr expects dict format: {id: [text]} gts = {i: [ref] for i, ref in enumerate(references)} res = {i: [pred] for i, pred in enumerate(predictions)} - + score, scores = self.cider_scorer.compute_score(gts, res) return float(score) - + def compute_text_similarity(self, reference: str, prediction: str) -> float: """ - Compute cosine similarity between embeddings - + Compute cosine similarity between embeddings. + Args: - reference: Ground truth text - prediction: Predicted text - + reference: Ground truth text. + prediction: Predicted text. + Returns: - Cosine similarity (0-1) + Cosine similarity in [0, 1]. """ ref_embedding = self.embedding_model.encode([reference]) pred_embedding = self.embedding_model.encode([prediction]) - + similarity = cosine_similarity(ref_embedding, pred_embedding)[0][0] return float(similarity) - + def evaluate_entry( - self, - ground_truth: Dict[str, Any], - prediction: Dict[str, Any] + self, ground_truth: Dict[str, Any], prediction: Dict[str, Any] ) -> Optional[Dict[str, Any]]: """ - Evaluate a single video entry - + Evaluate a single video entry. + Args: - ground_truth: Ground truth entry - prediction: Predicted entry - + ground_truth: Ground truth entry. + prediction: Predicted entry. + Returns: - Dict with all metrics, or None if prediction failed + Dict with all metrics, or None if prediction failed. 
""" # Skip failed predictions if "error" in prediction or "outputs" not in prediction: - logger.warning(f"Skipping failed entry: {prediction.get('video_id', 'unknown')}") + logger.warning( + f"Skipping failed entry: {prediction.get('video_id', 'unknown')}" + ) return None - + results = { "video_id": ground_truth["video_id"], "video_number": ground_truth["video_number"], } - + # Extract texts gt_detailed = ground_truth.get("summary_detailed", "") pred_detailed = prediction["outputs"].get("summary_detailed", "") - + gt_short = " ".join(ground_truth.get("summary_short", [])) pred_short = " ".join(prediction["outputs"].get("summary_short", [])) - + # Compute metrics for detailed summary results["detailed"] = { "rouge_l": self.compute_rouge_l(gt_detailed, pred_detailed), - "text_similarity": self.compute_text_similarity(gt_detailed, pred_detailed) + "text_similarity": self.compute_text_similarity(gt_detailed, pred_detailed), } - + # LLM Judge for detailed summary if self.use_llm_judge and gt_detailed and pred_detailed: try: @@ -145,21 +150,23 @@ def evaluate_entry( question="Provide a detailed summary of the video content.", correct_answer=gt_detailed, predicted_answer=pred_detailed, - task_type="summarization_detailed" + task_type="summarization_detailed", ) results["detailed"]["llm_judge_score"] = llm_eval["score"] - results["detailed"]["llm_judge_justification"] = llm_eval["justification"] + results["detailed"]["llm_judge_justification"] = llm_eval[ + "justification" + ] except Exception as e: logger.error(f"LLM judge failed for {results['video_id']}: {e}") results["detailed"]["llm_judge_score"] = None results["detailed"]["llm_judge_justification"] = str(e) - + # Compute metrics for short summary results["short"] = { "rouge_l": self.compute_rouge_l(gt_short, pred_short), - "text_similarity": self.compute_text_similarity(gt_short, pred_short) + "text_similarity": self.compute_text_similarity(gt_short, pred_short), } - + # LLM Judge for short summary if 
self.use_llm_judge and gt_short and pred_short: try: @@ -167,146 +174,162 @@ def evaluate_entry( question="Provide a brief bullet-point summary of the video.", correct_answer=gt_short, predicted_answer=pred_short, - task_type="summarization_short" + task_type="summarization_short", ) results["short"]["llm_judge_score"] = llm_eval["score"] results["short"]["llm_judge_justification"] = llm_eval["justification"] except Exception as e: - logger.error(f"LLM judge failed for short summary {results['video_id']}: {e}") + logger.error( + f"LLM judge failed for short summary {results['video_id']}: {e}" + ) results["short"]["llm_judge_score"] = None results["short"]["llm_judge_justification"] = str(e) - + return results - + def evaluate_topic( - self, - ground_truth_path: Path, - prediction_path: Path + self, ground_truth_path: Path, prediction_path: Path ) -> Dict[str, Any]: """ - Evaluate all entries for a topic - + Evaluate all entries for a topic. + Args: - ground_truth_path: Path to ground truth JSON - prediction_path: Path to prediction JSON - + ground_truth_path: Path to ground truth JSON. + prediction_path: Path to prediction JSON. + Returns: - Dict with aggregated metrics + Dict with topic_id, topic_name, num_evaluated, aggregated_metrics. 
""" # Load data - with open(ground_truth_path, 'r') as f: + with open(ground_truth_path, "r") as f: ground_truth = json.load(f) - - with open(prediction_path, 'r') as f: + + with open(prediction_path, "r") as f: prediction = json.load(f) - + # Match entries by video_id gt_entries = {e["video_id"]: e for e in ground_truth["entries"]} pred_entries = {e["video_id"]: e for e in prediction["entries"]} - + # Evaluate each entry entry_results = [] for video_id, gt_entry in gt_entries.items(): if video_id not in pred_entries: logger.warning(f"Missing prediction for video {video_id}") continue - + pred_entry = pred_entries[video_id] result = self.evaluate_entry(gt_entry, pred_entry) - + # Skip None results from failed entries if result is not None: entry_results.append(result) - + # Aggregate metrics aggregated = self._aggregate_results(entry_results) - + return { "topic_id": ground_truth["topic_id"], "topic_name": ground_truth["topic_name"], "num_evaluated": len(entry_results), "aggregated_metrics": aggregated, - "per_entry_results": entry_results + "per_entry_results": entry_results, } - + def _aggregate_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]: - """Aggregate metrics across all entries""" + """Aggregate metrics across all entries.""" aggregated = {} - + for summary_type in ["detailed", "short"]: - metrics = { - "rouge_l": [], - "text_similarity": [], - "llm_judge_score": [] - } - + metrics = {"rouge_l": [], "text_similarity": [], "llm_judge_score": []} + for result in results: if summary_type in result: metrics["rouge_l"].append(result[summary_type]["rouge_l"]) - metrics["text_similarity"].append(result[summary_type]["text_similarity"]) - + metrics["text_similarity"].append( + result[summary_type]["text_similarity"] + ) + if result[summary_type].get("llm_judge_score") is not None: - metrics["llm_judge_score"].append(result[summary_type]["llm_judge_score"]) - + metrics["llm_judge_score"].append( + result[summary_type]["llm_judge_score"] + ) + 
aggregated[summary_type] = { - "rouge_l_mean": float(np.mean(metrics["rouge_l"])) if metrics["rouge_l"] else 0.0, - "rouge_l_std": float(np.std(metrics["rouge_l"])) if metrics["rouge_l"] else 0.0, - "text_similarity_mean": float(np.mean(metrics["text_similarity"])) if metrics["text_similarity"] else 0.0, - "text_similarity_std": float(np.std(metrics["text_similarity"])) if metrics["text_similarity"] else 0.0, + "rouge_l_mean": float(np.mean(metrics["rouge_l"])) + if metrics["rouge_l"] + else 0.0, + "rouge_l_std": float(np.std(metrics["rouge_l"])) + if metrics["rouge_l"] + else 0.0, + "text_similarity_mean": float(np.mean(metrics["text_similarity"])) + if metrics["text_similarity"] + else 0.0, + "text_similarity_std": float(np.std(metrics["text_similarity"])) + if metrics["text_similarity"] + else 0.0, } - + if metrics["llm_judge_score"]: - aggregated[summary_type]["llm_judge_score_mean"] = float(np.mean(metrics["llm_judge_score"])) - aggregated[summary_type]["llm_judge_score_std"] = float(np.std(metrics["llm_judge_score"])) - + aggregated[summary_type]["llm_judge_score_mean"] = float( + np.mean(metrics["llm_judge_score"]) + ) + aggregated[summary_type]["llm_judge_score_std"] = float( + np.std(metrics["llm_judge_score"]) + ) + return aggregated - + def compute_cider_for_topic( - self, - ground_truth_path: Path, - prediction_path: Path + self, ground_truth_path: Path, prediction_path: Path ) -> Dict[str, float]: """ - Compute CIDEr scores for a topic - + Compute CIDEr scores for a topic. + Args: - ground_truth_path: Path to ground truth JSON - prediction_path: Path to prediction JSON - + ground_truth_path: Path to ground truth JSON. + prediction_path: Path to prediction JSON. + Returns: - Dict with CIDEr scores for detailed and short summaries + Dict with cider_detailed and cider_short. 
""" - with open(ground_truth_path, 'r') as f: + with open(ground_truth_path, "r") as f: ground_truth = json.load(f) - - with open(prediction_path, 'r') as f: + + with open(prediction_path, "r") as f: prediction = json.load(f) - + gt_entries = {e["video_id"]: e for e in ground_truth["entries"]} pred_entries = {e["video_id"]: e for e in prediction["entries"]} - + gt_detailed = [] pred_detailed = [] gt_short = [] pred_short = [] - + for video_id, gt_entry in gt_entries.items(): if video_id in pred_entries: pred_entry = pred_entries[video_id] - + # Skip failed predictions if "error" in pred_entry or "outputs" not in pred_entry: continue - + gt_detailed.append(gt_entry.get("summary_detailed", "")) pred_detailed.append(pred_entry["outputs"].get("summary_detailed", "")) - + gt_short.append(" ".join(gt_entry.get("summary_short", []))) - pred_short.append(" ".join(pred_entry["outputs"].get("summary_short", []))) - + pred_short.append( + " ".join(pred_entry["outputs"].get("summary_short", [])) + ) + return { - "cider_detailed": self.compute_cider(gt_detailed, pred_detailed) if gt_detailed else 0.0, - "cider_short": self.compute_cider(gt_short, pred_short) if gt_short else 0.0 + "cider_detailed": self.compute_cider(gt_detailed, pred_detailed) + if gt_detailed + else 0.0, + "cider_short": self.compute_cider(gt_short, pred_short) + if gt_short + else 0.0, } @@ -315,35 +338,38 @@ def evaluate_t1_topic( prediction_path: str, output_path: str, use_llm_judge: bool = True, - judge_name: str = "gpt" -): + judge_name: str = "gpt", +) -> Dict[str, Any]: """ - Convenience function to evaluate a single topic - + Evaluate T1 (summarization) for a single topic. + Args: - ground_truth_path: Path to ground truth JSON - prediction_path: Path to prediction JSON - output_path: Where to save results - use_llm_judge: Whether to use LLM judge + ground_truth_path: Path to ground truth JSON. + prediction_path: Path to prediction JSON. + output_path: Where to save results. 
+ use_llm_judge: Whether to use LLM judge. + judge_name: Judge backend: "gpt" or "qwen". + + Returns: + Results dict with aggregated_metrics and per_entry_results. """ logger.info(f"Evaluating T1: {Path(ground_truth_path).stem}") - + metrics = T1Metrics(use_llm_judge=use_llm_judge, judge_name=judge_name) results = metrics.evaluate_topic( Path(ground_truth_path), Path(prediction_path), ) - + # Add CIDEr scores cider_scores = metrics.compute_cider_for_topic( - Path(ground_truth_path), - Path(prediction_path) + Path(ground_truth_path), Path(prediction_path) ) results["aggregated_metrics"]["cider"] = cider_scores - + # Save results - with open(output_path, 'w') as f: + with open(output_path, "w") as f: json.dump(results, f, indent=2) - + logger.info(f"Results saved to {output_path}") - return results \ No newline at end of file + return results diff --git a/sonic-o1/05_evaluation_inference/metrics/t2_metrics.py b/sonic-o1/05_evaluation_inference/metrics/t2_metrics.py index 0251055..711c8de 100644 --- a/sonic-o1/05_evaluation_inference/metrics/t2_metrics.py +++ b/sonic-o1/05_evaluation_inference/metrics/t2_metrics.py @@ -1,141 +1,157 @@ -""" -T2 Metrics: Question Answering (MCQ) Evaluation -Computes accuracy, ROUGE-L, CIDEr, text similarity, and LLM-as-Judge for rationales +"""t2_metrics.py + +T2 metrics: MCQ (accuracy, ROUGE-L, text similarity, LLM-as-Judge for rationales). 
+ +Author: SONIC-O1 Team """ import json -import numpy as np -from pathlib import Path -from typing import Dict, List, Any, Optional import logging +from pathlib import Path +from typing import Any, Dict, List, Optional -from rouge_score import rouge_scorer +import numpy as np from pycocoevalcap.cider.cider import Cider -from sklearn.metrics.pairwise import cosine_similarity +from rouge_score import rouge_scorer from sentence_transformers import SentenceTransformer +from sklearn.metrics.pairwise import cosine_similarity + logger = logging.getLogger(__name__) -def make_key(entry): + +def make_key(entry: Dict[str, Any]) -> tuple: + """Create unique key for an entry (video_id, video_number, segment start, end).""" seg = entry["segment"] - return (entry['video_id'], entry['video_number'], float(seg['start']), float(seg['end'])) + return ( + entry["video_id"], + entry["video_number"], + float(seg["start"]), + float(seg["end"]), + ) + class T2Metrics: - """Compute metrics for T2: Question Answering (MCQ)""" - + """Compute metrics for T2: Question Answering (MCQ).""" + def __init__( self, use_llm_judge: bool = True, embedding_model: str = "all-MiniLM-L6-v2", - judge_name: str = "gpt" - ): + judge_name: str = "gpt", + ) -> None: """ - Initialize T2 metrics - + Initialize T2 metrics. + Args: - use_llm_judge: Whether to use LLM judge for rationale evaluation - embedding_model: Model for computing text similarity + use_llm_judge: Whether to use LLM judge for rationale evaluation. + embedding_model: Model for computing text similarity. + judge_name: Judge backend: "gpt" or "qwen". 
""" - self.rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True) + self.rouge_scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True) self.cider_scorer = Cider() self.embedding_model = SentenceTransformer(embedding_model) - + self.use_llm_judge = use_llm_judge + # Lazy import: load only selected judge backend (env-dependent) if use_llm_judge: if judge_name == "gpt": - from llm_judge_gpt import LLMJudge + from llm_judge_gpt import LLMJudge # noqa: PLC0415 + self.llm_judge = LLMJudge() elif judge_name == "qwen": - from llm_judge_qwen import LLMJudge + from llm_judge_qwen import LLMJudge # noqa: PLC0415 + self.llm_judge = LLMJudge() else: raise ValueError(f"Unknown judge name: {judge_name}") - + def compute_accuracy(self, ground_truth_answer: str, predicted_answer: str) -> int: """ - Compute exact match accuracy - + Compute exact match accuracy. + Args: - ground_truth_answer: Ground truth answer (letter or index) - predicted_answer: Predicted answer - + ground_truth_answer: Ground truth answer (letter or index). + predicted_answer: Predicted answer. + Returns: - 1 if correct, 0 if incorrect + 1 if correct, 0 if incorrect. 
""" # Normalize answers (handle both letter and index format) gt_normalized = str(ground_truth_answer).strip().upper() pred_normalized = str(predicted_answer).strip().upper() - + return 1 if gt_normalized == pred_normalized else 0 - + def compute_rouge_l(self, reference: str, prediction: str) -> float: - """Compute ROUGE-L F1 score""" + """Compute ROUGE-L F1 score.""" scores = self.rouge_scorer.score(reference, prediction) - return scores['rougeL'].fmeasure - + return scores["rougeL"].fmeasure + def compute_text_similarity(self, reference: str, prediction: str) -> float: - """Compute cosine similarity between embeddings""" + """Compute cosine similarity between embeddings.""" ref_embedding = self.embedding_model.encode([reference]) pred_embedding = self.embedding_model.encode([prediction]) - + similarity = cosine_similarity(ref_embedding, pred_embedding)[0][0] return float(similarity) - + def evaluate_entry( - self, - ground_truth: Dict[str, Any], - prediction: Dict[str, Any] + self, ground_truth: Dict[str, Any], prediction: Dict[str, Any] ) -> Optional[Dict[str, Any]]: """ - Evaluate a single question entry - + Evaluate a single question entry. + Args: - ground_truth: Ground truth entry - prediction: Predicted entry - + ground_truth: Ground truth entry. + prediction: Predicted entry. + Returns: - Dict with all metrics, or None if prediction failed + Dict with all metrics, or None if prediction failed. 
""" # Skip failed predictions if "error" in prediction or "outputs" not in prediction: - logger.warning( - f"Skipping failed entry: video={prediction.get('video_id', 'unknown')}, " - f"segment={prediction.get('segment', 'unknown')}" - ) + vid = prediction.get("video_id", "unknown") + seg = prediction.get("segment", "unknown") + logger.warning(f"Skipping failed entry: video={vid}, segment={seg}") return None - + results = { "video_id": ground_truth["video_id"], "video_number": ground_truth["video_number"], "segment": ground_truth["segment"], - "question": ground_truth["question"] + "question": ground_truth["question"], } - + # Answer accuracy gt_answer_letter = ground_truth.get("answer_letter", "") gt_answer_index = ground_truth.get("answer_index", -1) - + pred_answer_letter = prediction["outputs"].get("answer_letter", "") pred_answer_index = prediction["outputs"].get("answer_index", -1) - + # Check both letter and index accuracy_letter = self.compute_accuracy(gt_answer_letter, pred_answer_letter) accuracy_index = self.compute_accuracy(gt_answer_index, pred_answer_index) - - results["answer_correct"] = max(accuracy_letter, accuracy_index) # Either format matches + + results["answer_correct"] = max( + accuracy_letter, accuracy_index + ) # Either format matches results["gt_answer"] = gt_answer_letter results["pred_answer"] = pred_answer_letter - + # Rationale evaluation gt_rationale = ground_truth.get("rationale", "") pred_rationale = prediction["outputs"].get("rationale", "") - + if gt_rationale and pred_rationale: results["rationale_metrics"] = { "rouge_l": self.compute_rouge_l(gt_rationale, pred_rationale), - "text_similarity": self.compute_text_similarity(gt_rationale, pred_rationale) + "text_similarity": self.compute_text_similarity( + gt_rationale, pred_rationale + ), } - + # LLM Judge for rationale quality if self.use_llm_judge: try: @@ -143,165 +159,164 @@ def evaluate_entry( question=ground_truth["question"], correct_answer=gt_rationale, 
predicted_answer=pred_rationale, - task_type="rationale" + task_type="rationale", ) results["rationale_metrics"]["llm_judge_score"] = llm_eval["score"] - results["rationale_metrics"]["llm_judge_justification"] = llm_eval["justification"] + results["rationale_metrics"]["llm_judge_justification"] = llm_eval[ + "justification" + ] except Exception as e: logger.error(f"LLM judge failed for question: {e}") results["rationale_metrics"]["llm_judge_score"] = None results["rationale_metrics"]["llm_judge_justification"] = str(e) else: results["rationale_metrics"] = None - + return results - + def evaluate_topic( - self, - ground_truth_path: Path, - prediction_path: Path + self, ground_truth_path: Path, prediction_path: Path ) -> Dict[str, Any]: """ - Evaluate all questions for a topic - + Evaluate all questions for a topic. + Args: - ground_truth_path: Path to ground truth JSON - prediction_path: Path to prediction JSON - + ground_truth_path: Path to ground truth JSON. + prediction_path: Path to prediction JSON. + Returns: - Dict with aggregated metrics + Dict with topic_id, topic_name, num_evaluated, aggregated_metrics. 
""" # Load data - with open(ground_truth_path, 'r') as f: + with open(ground_truth_path, "r") as f: ground_truth = json.load(f) - - with open(prediction_path, 'r') as f: + + with open(prediction_path, "r") as f: prediction = json.load(f) - + # Match entries by video_id + segment gt_entries = {make_key(e): e for e in ground_truth["entries"]} pred_entries = {make_key(e): e for e in prediction["entries"]} - + # Evaluate each entry entry_results = [] for key, gt_entry in gt_entries.items(): if key not in pred_entries: logger.warning(f"Missing prediction for question: {key}") continue - + pred_entry = pred_entries[key] result = self.evaluate_entry(gt_entry, pred_entry) - + # Skip None results from failed entries if result is not None: entry_results.append(result) - + # Aggregate metrics aggregated = self._aggregate_results(entry_results) - + # Compute CIDEr for rationales - cider_score = self._compute_cider_rationales( - ground_truth_path, - prediction_path - ) + cider_score = self._compute_cider_rationales(ground_truth_path, prediction_path) if cider_score is not None: aggregated["rationale_cider"] = cider_score - + return { "topic_id": ground_truth["topic_id"], "topic_name": ground_truth["topic_name"], "num_evaluated": len(entry_results), "aggregated_metrics": aggregated, - "per_entry_results": entry_results + "per_entry_results": entry_results, } - + def _aggregate_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]: - """Aggregate metrics across all questions""" + """Aggregate metrics across all questions.""" # Answer accuracy correct = sum(r["answer_correct"] for r in results) total = len(results) accuracy = correct / total if total > 0 else 0.0 - - aggregated = { - "accuracy": float(accuracy), - "correct": correct, - "total": total - } - + + aggregated = {"accuracy": float(accuracy), "correct": correct, "total": total} + # Rationale metrics rationale_metrics = { "rouge_l": [], "text_similarity": [], - "llm_judge_score": [] + "llm_judge_score": [], } - + 
for result in results: if result.get("rationale_metrics"): rm = result["rationale_metrics"] rationale_metrics["rouge_l"].append(rm["rouge_l"]) rationale_metrics["text_similarity"].append(rm["text_similarity"]) - + if rm.get("llm_judge_score") is not None: rationale_metrics["llm_judge_score"].append(rm["llm_judge_score"]) - + if rationale_metrics["rouge_l"]: aggregated["rationale"] = { "rouge_l_mean": float(np.mean(rationale_metrics["rouge_l"])), "rouge_l_std": float(np.std(rationale_metrics["rouge_l"])), - "text_similarity_mean": float(np.mean(rationale_metrics["text_similarity"])), - "text_similarity_std": float(np.std(rationale_metrics["text_similarity"])), + "text_similarity_mean": float( + np.mean(rationale_metrics["text_similarity"]) + ), + "text_similarity_std": float( + np.std(rationale_metrics["text_similarity"]) + ), } - + if rationale_metrics["llm_judge_score"]: - aggregated["rationale"]["llm_judge_score_mean"] = float(np.mean(rationale_metrics["llm_judge_score"])) - aggregated["rationale"]["llm_judge_score_std"] = float(np.std(rationale_metrics["llm_judge_score"])) - + aggregated["rationale"]["llm_judge_score_mean"] = float( + np.mean(rationale_metrics["llm_judge_score"]) + ) + aggregated["rationale"]["llm_judge_score_std"] = float( + np.std(rationale_metrics["llm_judge_score"]) + ) + return aggregated - + def _compute_cider_rationales( - self, - ground_truth_path: Path, - prediction_path: Path + self, ground_truth_path: Path, prediction_path: Path ) -> Optional[float]: - """Compute CIDEr score for rationales""" + """Compute CIDEr score for rationales.""" try: - with open(ground_truth_path, 'r') as f: + with open(ground_truth_path, "r") as f: ground_truth = json.load(f) - - with open(prediction_path, 'r') as f: + + with open(prediction_path, "r") as f: prediction = json.load(f) - + gt_entries = {make_key(e): e for e in ground_truth["entries"]} pred_entries = {make_key(e): e for e in prediction["entries"]} - + gt_rationales = [] pred_rationales = [] - + 
for key, gt_entry in gt_entries.items(): if key in pred_entries: pred_entry = pred_entries[key] - + # Skip failed predictions if "error" in pred_entry or "outputs" not in pred_entry: continue - + gt_rat = gt_entry.get("rationale", "") pred_rat = pred_entry["outputs"].get("rationale", "") - + if gt_rat and pred_rat: gt_rationales.append(gt_rat) pred_rationales.append(pred_rat) - + if gt_rationales and pred_rationales: gts = {i: [ref] for i, ref in enumerate(gt_rationales)} res = {i: [pred] for i, pred in enumerate(pred_rationales)} - + score, _ = self.cider_scorer.compute_score(gts, res) return float(score) - + except Exception as e: logger.error(f"Failed to compute CIDEr for rationales: {e}") - + return None @@ -311,29 +326,30 @@ def evaluate_t2_topic( output_path: str, use_llm_judge: bool = True, judge_name: str = "gpt", -): +) -> Dict[str, Any]: """ - Convenience function to evaluate a single topic - + Evaluate T2 (MCQ) for a single topic. + Args: - ground_truth_path: Path to ground truth JSON - prediction_path: Path to prediction JSON - output_path: Where to save results - use_llm_judge: Whether to use LLM judge + ground_truth_path: Path to ground truth JSON. + prediction_path: Path to prediction JSON. + output_path: Where to save results. + use_llm_judge: Whether to use LLM judge. + judge_name: Judge backend: "gpt" or "qwen". + + Returns: + Results dict with aggregated_metrics and per_entry_results. 
""" logger.info(f"Evaluating T2: {Path(ground_truth_path).stem}") - - metrics = T2Metrics(use_llm_judge=use_llm_judge,judge_name=judge_name) - results = metrics.evaluate_topic( - Path(ground_truth_path), - Path(prediction_path) - ) - + + metrics = T2Metrics(use_llm_judge=use_llm_judge, judge_name=judge_name) + results = metrics.evaluate_topic(Path(ground_truth_path), Path(prediction_path)) + # Save results - with open(output_path, 'w') as f: + with open(output_path, "w") as f: json.dump(results, f, indent=2) - + logger.info(f"Results saved to {output_path}") logger.info(f"Accuracy: {results['aggregated_metrics']['accuracy']:.2%}") - - return results \ No newline at end of file + + return results diff --git a/sonic-o1/05_evaluation_inference/metrics/t3_metrics.py b/sonic-o1/05_evaluation_inference/metrics/t3_metrics.py index 336dd4a..654974e 100644 --- a/sonic-o1/05_evaluation_inference/metrics/t3_metrics.py +++ b/sonic-o1/05_evaluation_inference/metrics/t3_metrics.py @@ -1,203 +1,209 @@ -""" -T3 Metrics: Temporal Localization Evaluation -Computes IoU, Mean IoU, Recall@ฮธ, MAE, and LLM-as-Judge for rationales +"""t3_metrics.py + +T3 metrics: temporal localization (IoU, Recall@ฮธ, MAE, LLM-as-Judge for rationales). 
+ +Author: SONIC-O1 Team """ import json -import numpy as np -from pathlib import Path -from typing import Dict, List, Any, Tuple, Optional import logging +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple -from rouge_score import rouge_scorer +import numpy as np from pycocoevalcap.cider.cider import Cider -from sklearn.metrics.pairwise import cosine_similarity +from rouge_score import rouge_scorer from sentence_transformers import SentenceTransformer +from sklearn.metrics.pairwise import cosine_similarity logger = logging.getLogger(__name__) -def make_key(entry): + +def make_key(entry: Dict[str, Any]) -> tuple: + """Create unique key for an entry (video_id, video_number, segment start, end).""" seg = entry["segment"] - return (entry['video_id'], entry['video_number'], float(seg['start']), float(seg['end'])) + return ( + entry["video_id"], + entry["video_number"], + float(seg["start"]), + float(seg["end"]), + ) + class T3Metrics: - """Compute metrics for T3: Temporal Localization""" - + """Compute metrics for T3: Temporal Localization.""" + def __init__( self, use_llm_judge: bool = True, embedding_model: str = "all-MiniLM-L6-v2", - iou_thresholds: List[float] = [0.3, 0.5, 0.7], - judge_name: str = "gpt" - ): + iou_thresholds: Optional[List[float]] = None, + judge_name: str = "gpt", + ) -> None: """ - Initialize T3 metrics - + Initialize T3 metrics. + Args: - use_llm_judge: Whether to use LLM judge for rationale evaluation - embedding_model: Model for computing text similarity - iou_thresholds: IoU thresholds for Recall@θ computation + use_llm_judge: Whether to use LLM judge for rationale evaluation. + embedding_model: Model for computing text similarity. + iou_thresholds: IoU thresholds for Recall@θ (default [0.3, 0.5, 0.7]). + judge_name: Judge backend: "gpt" or "qwen". 
""" - self.rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True) + if iou_thresholds is None: + iou_thresholds = [0.3, 0.5, 0.7] + self.rouge_scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True) self.cider_scorer = Cider() self.embedding_model = SentenceTransformer(embedding_model) self.iou_thresholds = iou_thresholds - + self.use_llm_judge = use_llm_judge + # Lazy import: load only selected judge backend (env-dependent) if use_llm_judge: if judge_name == "gpt": - from llm_judge_gpt import LLMJudge + from llm_judge_gpt import LLMJudge # noqa: PLC0415 + self.llm_judge = LLMJudge() elif judge_name == "qwen": - from llm_judge_qwen import LLMJudge + from llm_judge_qwen import LLMJudge # noqa: PLC0415 + self.llm_judge = LLMJudge() else: raise ValueError(f"Unknown judge name: {judge_name}") - + def compute_iou( - self, - gt_start: float, - gt_end: float, - pred_start: float, - pred_end: float + self, gt_start: float, gt_end: float, pred_start: float, pred_end: float ) -> float: """ - Compute Intersection over Union (IoU) for temporal segments - + Compute Intersection over Union (IoU) for temporal segments. + Args: - gt_start: Ground truth start time (seconds) - gt_end: Ground truth end time (seconds) - pred_start: Predicted start time (seconds) - pred_end: Predicted end time (seconds) - + gt_start: Ground truth start time (seconds). + gt_end: Ground truth end time (seconds). + pred_start: Predicted start time (seconds). + pred_end: Predicted end time (seconds). + Returns: - IoU score (0-1) + IoU score in [0, 1]. 
""" # Compute intersection intersection_start = max(gt_start, pred_start) intersection_end = min(gt_end, pred_end) intersection = max(0, intersection_end - intersection_start) - + # Compute union union_start = min(gt_start, pred_start) union_end = max(gt_end, pred_end) union = union_end - union_start - + # Avoid division by zero if union == 0: return 0.0 - + iou = intersection / union return float(iou) - + def compute_mae( - self, - gt_start: float, - gt_end: float, - pred_start: float, - pred_end: float + self, gt_start: float, gt_end: float, pred_start: float, pred_end: float ) -> Tuple[float, float, float]: """ - Compute Mean Absolute Error for start, end, and average - + Compute Mean Absolute Error for start, end, and average. + Args: - gt_start: Ground truth start time - gt_end: Ground truth end time - pred_start: Predicted start time - pred_end: Predicted end time - + gt_start: Ground truth start time. + gt_end: Ground truth end time. + pred_start: Predicted start time. + pred_end: Predicted end time. + Returns: - Tuple of (start_mae, end_mae, avg_mae) + Tuple of (start_mae, end_mae, avg_mae). 
""" start_mae = abs(gt_start - pred_start) end_mae = abs(gt_end - pred_end) avg_mae = (start_mae + end_mae) / 2.0 - + return float(start_mae), float(end_mae), float(avg_mae) - + def compute_rouge_l(self, reference: str, prediction: str) -> float: - """Compute ROUGE-L F1 score""" + """Compute ROUGE-L F1 score.""" scores = self.rouge_scorer.score(reference, prediction) - return scores['rougeL'].fmeasure - + return scores["rougeL"].fmeasure + def compute_text_similarity(self, reference: str, prediction: str) -> float: - """Compute cosine similarity between embeddings""" + """Compute cosine similarity between embeddings.""" ref_embedding = self.embedding_model.encode([reference]) pred_embedding = self.embedding_model.encode([prediction]) - + similarity = cosine_similarity(ref_embedding, pred_embedding)[0][0] return float(similarity) - + def evaluate_question( self, ground_truth: Dict[str, Any], prediction: Dict[str, Any], - segment_info: Dict[str, Any] + segment_info: Dict[str, Any], ) -> Optional[Dict[str, Any]]: """ - Evaluate a single temporal localization question - + Evaluate a single temporal localization question. + Args: - ground_truth: Ground truth question entry - prediction: Predicted question entry (the question object itself, not wrapped) - segment_info: Segment information (video_id, etc.) - + ground_truth: Ground truth question entry. + prediction: Predicted question entry (question object, not wrapped). + segment_info: Segment information (video_id, video_number, segment). + Returns: - Dict with all metrics, or None if prediction failed + Dict with all metrics, or None if prediction failed. 
""" results = { "question_id": ground_truth["question_id"], "question": ground_truth["question"], - **segment_info + **segment_info, } - + # Extract temporal bounds # GT: has "answer" wrapper gt_answer = ground_truth.get("answer", {}) gt_start = float(gt_answer.get("start_s", 0)) - gt_end = float(gt_answer.get("end_s", 0)) + gt_end = float(gt_answer.get("end_s", 0)) - # Prediction: NO "answer" wrapper - values are directly in the question object pred_start = float(prediction.get("start_s", 0)) - pred_end = float(prediction.get("end_s", 0)) + pred_end = float(prediction.get("end_s", 0)) - results["gt_interval"] = {"start": gt_start, "end": gt_end} results["pred_interval"] = {"start": pred_start, "end": pred_end} - + # Compute IoU iou = self.compute_iou(gt_start, gt_end, pred_start, pred_end) results["iou"] = iou - + # Compute Recall@ฮธ for each threshold results["recall_at_threshold"] = {} for threshold in self.iou_thresholds: - results["recall_at_threshold"][f"R@{threshold}"] = 1 if iou >= threshold else 0 - + results["recall_at_threshold"][f"R@{threshold}"] = ( + 1 if iou >= threshold else 0 + ) + # Compute MAE start_mae, end_mae, avg_mae = self.compute_mae( gt_start, gt_end, pred_start, pred_end ) - results["mae"] = { - "start": start_mae, - "end": end_mae, - "average": avg_mae - } - + results["mae"] = {"start": start_mae, "end": end_mae, "average": avg_mae} + # Rationale evaluation # GT: at question level gt_rationale = ground_truth.get("rationale_model", "") # Prediction: directly in question object pred_rationale = prediction.get("rationale_model", "") - + if gt_rationale and pred_rationale: results["rationale_metrics"] = { "rouge_l": self.compute_rouge_l(gt_rationale, pred_rationale), - "text_similarity": self.compute_text_similarity(gt_rationale, pred_rationale) + "text_similarity": self.compute_text_similarity( + gt_rationale, pred_rationale + ), } - + # LLM Judge for rationale if self.use_llm_judge: try: @@ -205,113 +211,113 @@ def evaluate_question( 
question=ground_truth["question"], correct_answer=gt_rationale, predicted_answer=pred_rationale, - task_type="temporal_rationale" + task_type="temporal_rationale", ) results["rationale_metrics"]["llm_judge_score"] = llm_eval["score"] - results["rationale_metrics"]["llm_judge_justification"] = llm_eval["justification"] + results["rationale_metrics"]["llm_judge_justification"] = llm_eval[ + "justification" + ] except Exception as e: logger.error(f"LLM judge failed: {e}") results["rationale_metrics"]["llm_judge_score"] = None results["rationale_metrics"]["llm_judge_justification"] = str(e) else: results["rationale_metrics"] = None - + return results - + def evaluate_topic( - self, - ground_truth_path: Path, - prediction_path: Path + self, ground_truth_path: Path, prediction_path: Path ) -> Dict[str, Any]: """ - Evaluate all temporal questions for a topic - + Evaluate all temporal questions for a topic. + Args: - ground_truth_path: Path to ground truth JSON - prediction_path: Path to prediction JSON - + ground_truth_path: Path to ground truth JSON. + prediction_path: Path to prediction JSON. + Returns: - Dict with aggregated metrics + Dict with topic_id, topic_name, num_evaluated, aggregated_metrics. 
""" # Load data - with open(ground_truth_path, 'r') as f: + with open(ground_truth_path, "r") as f: ground_truth = json.load(f) - - with open(prediction_path, 'r') as f: + + with open(prediction_path, "r") as f: prediction = json.load(f) - + # Match entries by video_id + segment - + gt_entries = {make_key(e): e for e in ground_truth["entries"]} pred_entries = {make_key(e): e for e in prediction["entries"]} - + # Evaluate each question all_question_results = [] - + for key, gt_entry in gt_entries.items(): if key not in pred_entries: logger.warning(f"Missing prediction for segment: {key}") continue - + pred_entry = pred_entries[key] - + # Skip failed segment predictions - CHECK FOR "outputs" NOT "questions" if "error" in pred_entry or "outputs" not in pred_entry: logger.warning(f"Skipping failed segment: {key}") continue - + segment_info = { "video_id": gt_entry["video_id"], "video_number": gt_entry["video_number"], - "segment": gt_entry["segment"] + "segment": gt_entry["segment"], } - + # Match questions by question_id gt_questions = {q["question_id"]: q for q in gt_entry.get("questions", [])} # GET questions from outputs - pred_questions = {q["question_id"]: q for q in pred_entry.get("outputs", {}).get("questions", [])} - + pred_questions = { + q["question_id"]: q + for q in pred_entry.get("outputs", {}).get("questions", []) + } + for qid, gt_q in gt_questions.items(): if qid not in pred_questions: logger.warning(f"Missing prediction for question {qid}") continue - + pred_q = pred_questions[qid] result = self.evaluate_question(gt_q, pred_q, segment_info) - + # Skip None results from failed questions if result is not None: all_question_results.append(result) - + # Aggregate metrics aggregated = self._aggregate_results(all_question_results) - + # Compute CIDEr for rationales - cider_score = self._compute_cider_rationales( - ground_truth_path, - prediction_path - ) + cider_score = self._compute_cider_rationales(ground_truth_path, prediction_path) if cider_score is not 
None: aggregated["rationale_cider"] = cider_score - + return { "topic_id": ground_truth["topic_id"], "topic_name": ground_truth["topic_name"], "num_evaluated": len(all_question_results), "aggregated_metrics": aggregated, - "per_question_results": all_question_results + "per_question_results": all_question_results, } - + def _aggregate_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]: - """Aggregate metrics across all questions""" + """Aggregate metrics across all questions.""" ious = [r["iou"] for r in results] - + aggregated = { "mean_iou": float(np.mean(ious)) if ious else 0.0, "std_iou": float(np.std(ious)) if ious else 0.0, - "median_iou": float(np.median(ious)) if ious else 0.0 + "median_iou": float(np.median(ious)) if ious else 0.0, } - + # Recall@θ for threshold in self.iou_thresholds: key = f"R@{threshold}" @@ -321,104 +327,114 @@ def _aggregate_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]: aggregated[key] = { "recall": float(count / total) if total > 0 else 0.0, "count": count, - "total": total + "total": total, } - + # MAE start_maes = [r["mae"]["start"] for r in results] end_maes = [r["mae"]["end"] for r in results] avg_maes = [r["mae"]["average"] for r in results] - + aggregated["mae"] = { "start_mean": float(np.mean(start_maes)) if start_maes else 0.0, "end_mean": float(np.mean(end_maes)) if end_maes else 0.0, - "average_mean": float(np.mean(avg_maes)) if avg_maes else 0.0 + "average_mean": float(np.mean(avg_maes)) if avg_maes else 0.0, } - + # Rationale metrics rationale_metrics = { "rouge_l": [], "text_similarity": [], - "llm_judge_score": [] + "llm_judge_score": [], } - + for result in results: if result.get("rationale_metrics"): rm = result["rationale_metrics"] rationale_metrics["rouge_l"].append(rm["rouge_l"]) rationale_metrics["text_similarity"].append(rm["text_similarity"]) - + if rm.get("llm_judge_score") is not None: rationale_metrics["llm_judge_score"].append(rm["llm_judge_score"]) - + if 
rationale_metrics["rouge_l"]: aggregated["rationale"] = { "rouge_l_mean": float(np.mean(rationale_metrics["rouge_l"])), "rouge_l_std": float(np.std(rationale_metrics["rouge_l"])), - "text_similarity_mean": float(np.mean(rationale_metrics["text_similarity"])), - "text_similarity_std": float(np.std(rationale_metrics["text_similarity"])), + "text_similarity_mean": float( + np.mean(rationale_metrics["text_similarity"]) + ), + "text_similarity_std": float( + np.std(rationale_metrics["text_similarity"]) + ), } - + if rationale_metrics["llm_judge_score"]: - aggregated["rationale"]["llm_judge_score_mean"] = float(np.mean(rationale_metrics["llm_judge_score"])) - aggregated["rationale"]["llm_judge_score_std"] = float(np.std(rationale_metrics["llm_judge_score"])) - + aggregated["rationale"]["llm_judge_score_mean"] = float( + np.mean(rationale_metrics["llm_judge_score"]) + ) + aggregated["rationale"]["llm_judge_score_std"] = float( + np.std(rationale_metrics["llm_judge_score"]) + ) + return aggregated - + def _compute_cider_rationales( - self, - ground_truth_path: Path, - prediction_path: Path + self, ground_truth_path: Path, prediction_path: Path ) -> Optional[float]: - """Compute CIDEr score for rationales""" + """Compute CIDEr score for rationales.""" try: - with open(ground_truth_path, 'r') as f: + with open(ground_truth_path, "r") as f: ground_truth = json.load(f) - - with open(prediction_path, 'r') as f: + + with open(prediction_path, "r") as f: prediction = json.load(f) - - + gt_entries = {make_key(e): e for e in ground_truth["entries"]} pred_entries = {make_key(e): e for e in prediction["entries"]} - + gt_rationales = [] pred_rationales = [] - + for key, gt_entry in gt_entries.items(): if key in pred_entries: pred_entry = pred_entries[key] - + # Skip failed segment predictions if "error" in pred_entry or "outputs" not in pred_entry: continue - - gt_questions = {q["question_id"]: q for q in gt_entry.get("questions", [])} + + gt_questions = { + q["question_id"]: q for q 
in gt_entry.get("questions", []) + } # GET questions from outputs - pred_questions = {q["question_id"]: q for q in pred_entry.get("outputs", {}).get("questions", [])} - + pred_questions = { + q["question_id"]: q + for q in pred_entry.get("outputs", {}).get("questions", []) + } + for qid, gt_q in gt_questions.items(): if qid in pred_questions: pred_q = pred_questions[qid] - + # rationale_model is directly in question objects gt_rat = gt_q.get("rationale_model", "") pred_rat = pred_q.get("rationale_model", "") - + if gt_rat and pred_rat: gt_rationales.append(gt_rat) pred_rationales.append(pred_rat) - + if gt_rationales and pred_rationales: gts = {i: [ref] for i, ref in enumerate(gt_rationales)} res = {i: [pred] for i, pred in enumerate(pred_rationales)} - + score, _ = self.cider_scorer.compute_score(gts, res) return float(score) - + except Exception as e: logger.error(f"Failed to compute CIDEr for rationales: {e}") - + return None @@ -427,36 +443,39 @@ def evaluate_t3_topic( prediction_path: str, output_path: str, use_llm_judge: bool = True, - iou_thresholds: List[float] = [0.3, 0.5, 0.7], - judge_name: str = "gpt" -): + iou_thresholds: Optional[List[float]] = None, + judge_name: str = "gpt", +) -> Dict[str, Any]: """ - Convenience function to evaluate a single topic - + Evaluate T3 (temporal localization) for a single topic. + Args: - ground_truth_path: Path to ground truth JSON - prediction_path: Path to prediction JSON - output_path: Where to save results - use_llm_judge: Whether to use LLM judge - iou_thresholds: IoU thresholds for recall computation + ground_truth_path: Path to ground truth JSON. + prediction_path: Path to prediction JSON. + output_path: Where to save results. + use_llm_judge: Whether to use LLM judge. + iou_thresholds: IoU thresholds for recall (default [0.3, 0.5, 0.7]). + judge_name: Judge backend: "gpt" or "qwen". + + Returns: + Results dict with aggregated_metrics and per_question_results. 
""" + if iou_thresholds is None: + iou_thresholds = [0.3, 0.5, 0.7] logger.info(f"Evaluating T3: {Path(ground_truth_path).stem}") - + metrics = T3Metrics( use_llm_judge=use_llm_judge, iou_thresholds=iou_thresholds, - judge_name=judge_name - ) - results = metrics.evaluate_topic( - Path(ground_truth_path), - Path(prediction_path) + judge_name=judge_name, ) - + results = metrics.evaluate_topic(Path(ground_truth_path), Path(prediction_path)) + # Save results - with open(output_path, 'w') as f: + with open(output_path, "w") as f: json.dump(results, f, indent=2) - + logger.info(f"Results saved to {output_path}") logger.info(f"Mean IoU: {results['aggregated_metrics']['mean_iou']:.3f}") - - return results \ No newline at end of file + + return results diff --git a/sonic-o1/05_evaluation_inference/models/__init__.py b/sonic-o1/05_evaluation_inference/models/__init__.py index 9c0774a..382fdb2 100644 --- a/sonic-o1/05_evaluation_inference/models/__init__.py +++ b/sonic-o1/05_evaluation_inference/models/__init__.py @@ -1,6 +1,10 @@ +"""models/__init__.py + +Models package: base class and model implementations for evaluation. + +Author: SONIC-O1 Team """ -Models package -""" + from .base_model import BaseModel -__all__ = ['BaseModel'] \ No newline at end of file +__all__ = ["BaseModel"] diff --git a/sonic-o1/05_evaluation_inference/models/baichuan_omni.py b/sonic-o1/05_evaluation_inference/models/baichuan_omni.py new file mode 100644 index 0000000..2e9bb07 --- /dev/null +++ b/sonic-o1/05_evaluation_inference/models/baichuan_omni.py @@ -0,0 +1,633 @@ +"""baichuan_omni.py +Baichuan-Omni-1.5 wrapper following BaseModel pattern. + +Architecture (traced from source): +- LLM backbone: Qwen2.5-7B (bfloat16) +- Vision encoder: CLIP ViT-L/14 (patch=14, spatial_merge=2) +- Audio encoder: Whisper-large (16kHz mel, max 30s window โ€” same constraint as OLA) +- Video handling: processor extracts frames internally at 1fps, saves jpgs to cache dir. 
+ We patch max_frame_num after load to honour our config value. + +Audio strategy: + Inputs may be .m4a; we convert to 16 kHz mono wav with ffmpeg, then run the + same 30s-chunk pipeline as below. Temporary wavs are removed after generate(). + Like OLA, Baichuan's audio encoder is Whisper-based (max_audio_seconds=30). + Their processor hard-truncates at 30s. We override this by pre-processing the + audio ourselves: uniformly sample up to max_chunks windows of 30s each, + concatenate into a trimmed wav, write to cache dir, and pass that path to the + processor. This gives us the same uniform-sampling coverage as OLA/VITA. + +Input format (raw string โ€” processor handles all tensor prep internally): + {system} + {"local": "/abs/path.mp4"} + {"path": "/abs/path.wav"} (.m4a converted to wav first) + {prompt} + +Text-only output: stop at audiogen_start_token_id=151700. No vocoder loaded. + +Author: SONIC-O1 Team +""" + +import json +import logging +import os +import subprocess +import sys +from pathlib import Path +from typing import Any, Dict, Literal, Optional, Union + +import numpy as np +import torch +import torchaudio +import re + +from .base_model import BaseModel + + +logger = logging.getLogger(__name__) + +# โ”€โ”€ Role/tag constants (from web_demo/constants.py) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ +SYS_START = "" +USER_START = "" +ASST_START = "" +AUDIOTEXT = "" +VIDEO_START = "" +VIDEO_END = "" +AUDIO_START = "" +AUDIO_END = "" + +# Audio encoder constants (from config.json: audio_config) +AUDIO_SR = 16000 # sampling_rate +CHUNK_SAMPLES = 30 * AUDIO_SR # max_audio_seconds=30 โ†’ 480 000 samples + + +class BaichuanOmni(BaseModel): + """Baichuan-Omni-1.5 wrapper following BaseModel pattern.""" + + def __init__(self, model_name: str, config: Dict[str, Any]) -> None: + super().__init__(model_name, config) + + self.model_path = config.get("model_path", "baichuan-inc/Baichuan-Omni-1.5") + self.baichuan_repo_path = 
config.get("baichuan_repo_path") + + # Cache dir โ€” processor writes extracted video frame jpgs here. + # Must be absolute path on persistent storage (not $SCRATCH). + self.cache_dir = config.get( + "cache_dir", + "/projects/aixpert/users/ahmadradw/VideoQA-Agentic/sonic-o1/.cache", + ) + + # Frame config โ€” patched into model.config.video_config.max_frame_num after load + self.default_max_frames = config.get("max_frames", 32) + self.default_min_frames = config.get("min_frames", 8) + + # Generation config + gen_config = config.get("generation_config", {}) + self.temperature = gen_config.get("temperature", 0.7) + self.top_p = gen_config.get("top_p", 0.95) + self.num_beams = gen_config.get("num_beams", 1) + self.max_new_tokens = gen_config.get("max_new_tokens", 8192) + + # Model components (populated in load()) + self.tokenizer = None + self.model = None + + # Stats + self.stats = { + "total_samples": 0, + "audio_chunks_sampled": 0, + } + def convert_av1_to_h264( + self, video_path: Path, output_dir: Optional[Path] = None + ) -> Path: + """ + Convert AV1 video to H.264 for Decord compatibility. + + Args: + video_path: Input video path + output_dir: Output directory (default: creates 'converted' subdir) + + Returns: + Path to converted video. 
+ """ + import subprocess + + if output_dir is None: + output_dir = video_path.parent / "converted" + output_dir.mkdir(exist_ok=True) + + output_path = output_dir / f"{video_path.stem}_h264{video_path.suffix}" + + # Check if already converted + if output_path.exists(): + logger.info(f"Using cached converted video: {output_path.name}") + return output_path + + logger.info(f"Converting AV1 to H.264: {video_path.name}") + + cmd = [ + "ffmpeg", + "-i", + str(video_path), + "-c:v", + "libx264", + "-preset", + "fast", + "-crf", + "23", + "-c:a", + "copy", + "-y", + str(output_path), + ] + + try: + subprocess.run(cmd, check=True, capture_output=True, text=True) + logger.info(f"โœ“ Conversion successful: {output_path.name}") + return output_path + except subprocess.CalledProcessError as e: + logger.error(f"โœ— Conversion failed: {e.stderr}") + raise RuntimeError(f"Failed to convert AV1 video: {e.stderr}") + + def _check_video_compatibility(self, video_path: Path) -> Optional[Path]: + """ + Check if video is compatible with Decord. + If AV1, automatically convert to H.264. + + Args: + video_path: Original video path. + + Returns: + Path to compatible video (original or converted). 
+ """ + import subprocess + + from decord import VideoReader, cpu + + try: + vr = VideoReader(str(video_path), ctx=cpu(0), num_threads=1) + frame_count = len(vr) + del vr + logger.info(f"โœ“ Video compatible: {video_path.name} ({frame_count} frames)") + return video_path + except Exception: + logger.warning(f"โš  Video incompatible with Decord: {video_path.name}") + + # Detect codec + try: + result = subprocess.run( + [ + "ffprobe", + "-v", + "error", + "-select_streams", + "v:0", + "-show_entries", + "stream=codec_name", + "-of", + "default=noprint_wrappers=1:nokey=1", + str(video_path), + ], + capture_output=True, + text=True, + check=True, + ) + codec = result.stdout.strip() + + if codec == "av1": + logger.info(" โ†’ Detected AV1 codec - converting to H.264...") + try: + return self.convert_av1_to_h264(video_path) + except Exception as conv_error: + logger.error(f" โœ— Conversion failed: {conv_error}") + return None + else: + logger.warning(f" โœ— Unsupported codec '{codec}' - skipping") + return None + except Exception as probe_error: + logger.error(f" โœ— Failed to detect codec: {probe_error}") + return None + + # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + # Loading + # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + def load(self) -> None: + """Load Baichuan-Omni-1.5 and bind processor.""" + # Add repo + model subdir to sys.path so trust_remote_code resolves + # modeling_omni.py, processor_omni.py, configuration_omni.py etc. 
+ if self.baichuan_repo_path: + repo_abs = os.path.abspath(self.baichuan_repo_path) + model_dir = os.path.join(repo_abs, "baichuan-omni", "model") + for p in [repo_abs, model_dir]: + if p not in sys.path: + sys.path.insert(0, p) + logger.info(f"Added to sys.path: {p}") + + from transformers import AutoModelForCausalLM, AutoTokenizer + + model_path = os.path.abspath(self.model_path) + logger.info(f"Loading Baichuan-Omni-1.5 from {model_path}") + + self.model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ).cuda() + + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, + trust_remote_code=True, + ) + + self.model.training = False + + # bind_processor attaches OmniMMProcessor to model.processor. + # relative_path = cache dir where processor saves video frame jpgs. + os.makedirs(self.cache_dir, exist_ok=True) + self.model.bind_processor( + self.tokenizer, + training=False, + relative_path=self.cache_dir, + ) + + # Patch frame cap to honour our config (config.json default is 32). + self.model.config.video_config.max_frame_num = self.default_max_frames + logger.info(f"Patched video_config.max_frame_num = {self.default_max_frames}") + + self.model.eval() + logger.info("Baichuan-Omni-1.5 loaded successfully") + + # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + # Audio pre-processing + # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + def _m4a_to_wav(self, m4a_path: str) -> str: + """ + Decode m4a to a temporary wav (16 kHz mono) via ffmpeg. + torchaudio often cannot load m4a reliably; Baichuan expects a wav path. 
+ Caller must delete the returned path when done. + """ + os.makedirs(self.cache_dir, exist_ok=True) + out_name = f"baichuan_m4a_conv_{os.getpid()}_{abs(hash(m4a_path)) % 10**8}.wav" + out_path = os.path.join(self.cache_dir, out_name) + cmd = [ + "ffmpeg", + "-nostdin", + "-hide_banner", + "-loglevel", + "error", + "-y", + "-i", + m4a_path, + "-ac", + "1", + "-ar", + str(AUDIO_SR), + "-f", + "wav", + out_path, + ] + try: + subprocess.run(cmd, check=True, capture_output=True, text=True) + except FileNotFoundError as e: + raise RuntimeError( + "ffmpeg is required to convert .m4a audio; install ffmpeg and retry." + ) from e + except subprocess.CalledProcessError as e: + if os.path.exists(out_path): + try: + os.remove(out_path) + except OSError: + pass + err = (e.stderr or e.stdout or "").strip() + raise RuntimeError(f"ffmpeg m4aโ†’wav failed: {err or e}") from e + logger.info(f"Converted m4a โ†’ wav: {out_path}") + return out_path + + def _maybe_convert_m4a( + self, audio_path: str + ) -> tuple[str, Optional[str]]: + """ + If ``audio_path`` is .m4a, convert to wav and return + (path_to_wav, temp_path_to_delete). Otherwise return (audio_path, None). 
+ """ + if Path(audio_path).suffix.lower() != ".m4a": + return audio_path, None + converted = self._m4a_to_wav(audio_path) + return converted, converted + + def _audio_has_stream(self, audio_path: str) -> bool: + """Return True if ffprobe sees at least one audio stream.""" + cmd = [ + "ffprobe", + "-v", + "error", + "-select_streams", + "a:0", + "-show_entries", + "stream=codec_type", + "-of", + "default=nw=1:nk=1", + audio_path, + ] + try: + out = subprocess.run( + cmd, check=False, capture_output=True, text=True + ).stdout.strip() + return bool(out) + except Exception: + return False + + def _prepare_audio( + self, + audio_path: str, + max_chunks: Optional[int], + ) -> str: + """ + Uniformly sample up to max_chunks ร— 30s windows from the full audio, + concatenate, write a trimmed wav to cache dir, return its absolute path. + + Like OLA, we use 30s windows because Baichuan's audio encoder is + Whisper-based (max_audio_seconds=30 in config.json). Their processor + would hard-truncate at 30s โ€” we override that here. + + ``audio_path`` should be a format ``torchaudio.load`` can read (e.g. wav); + use :meth:`_maybe_convert_m4a` first for .m4a inputs. 
+ """ + waveform, sr = torchaudio.load(audio_path) + + # Resample to 16 kHz if needed + if sr != AUDIO_SR: + waveform = torchaudio.functional.resample(waveform, sr, AUDIO_SR) + + # Downmix to mono + if waveform.shape[0] > 1: + waveform = waveform.mean(dim=0, keepdim=True) + + waveform = waveform.squeeze(0) # [samples] + total_samples = waveform.shape[0] + logger.info(f"Audio: {total_samples / AUDIO_SR:.1f}s total") + + # Build non-overlapping 30s chunks + all_chunks = [] + for start in range(0, total_samples, CHUNK_SAMPLES): + chunk = waveform[start: start + CHUNK_SAMPLES] + if chunk.shape[0] < CHUNK_SAMPLES: + chunk = torch.nn.functional.pad( + chunk, (0, CHUNK_SAMPLES - chunk.shape[0]) + ) + all_chunks.append(chunk) + + total_chunks = len(all_chunks) + logger.info(f"Audio: {total_chunks} chunk(s) of 30s") + + # Uniform sampling if over budget + if max_chunks is not None and total_chunks > max_chunks: + indices = np.linspace(0, total_chunks - 1, max_chunks, dtype=int) + sampled = [all_chunks[i] for i in indices] + logger.info(f"Uniformly sampled {max_chunks}/{total_chunks} audio chunks") + self.stats["audio_chunks_sampled"] += 1 + else: + sampled = all_chunks + logger.info(f"Using all {total_chunks} audio chunk(s)") + + # Concatenate and write to cache + trimmed = torch.cat(sampled, dim=0).unsqueeze(0) # [1, samples] + out_name = f"baichuan_audio_{os.getpid()}_{abs(hash(audio_path)) % 10**8}.wav" + out_path = os.path.join(self.cache_dir, out_name) + torchaudio.save(out_path, trimmed, AUDIO_SR) + logger.info(f"Trimmed audio โ†’ {out_path}") + return out_path + + # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + # Message builder + # 
โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + def _build_message( + self, + video_path: Optional[str], + audio_path: Optional[str], + prompt: str, + system: str = "You are a helpful assistant.", + ) -> str: + """ + Build the raw input string the OmniMMProcessor expects. + + Format (from traced web_demo/s2s_gradio_demo_cosy_multiturn.py): + {system} + {"local": "..."} โ† optional + {"path": "..."} โ† optional + {prompt} + """ + msg = SYS_START + system + USER_START + + if video_path is not None: + msg += VIDEO_START + json.dumps({"local": video_path}) + VIDEO_END + + if audio_path is not None: + msg += AUDIO_START + json.dumps({"path": audio_path}) + AUDIO_END + + # signals the model to respond in text mode. + # Must be present even for text-only output โ€” matches training format. + msg += AUDIOTEXT + prompt + ASST_START + return msg + + # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + # Generate + # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + def generate( + self, + frames: Union[str, Path, None] = None, + audio: Optional[Union[str, Path]] = None, + prompt: str = "Describe what you see and hear.", + fps: Optional[float] = None, + video_category: Optional[Literal["short", "medium", "long"]] = None, + max_frames: Optional[int] = None, + max_audio_chunks: Optional[int] = None, + **kwargs, + ) -> str: + """Generate a text response from video and/or audio. + + Args: + frames: Path to video file (.mp4), or None. 
+ audio: Path to audio file (.m4a / .wav), or None. + prompt: Text question/prompt. + fps: Unused โ€” processor samples at 1fps internally. + video_category: Unused; reserved for future frame budgeting. + max_frames: Per-call frame cap (patches model config). + max_audio_chunks: Max 30s audio chunks (None = no limit). + **kwargs: temperature, top_p, num_beams, max_new_tokens. + """ + temperature = kwargs.get("temperature", self.temperature) + top_p = kwargs.get("top_p", self.top_p) + num_beams = kwargs.get("num_beams", self.num_beams) + max_new_tokens = kwargs.get("max_new_tokens", self.max_new_tokens) + + # Per-call frame cap override + if max_frames is not None and max_frames != self.default_max_frames: + self.model.config.video_config.max_frame_num = max_frames + logger.info(f"Per-call frame override: max_frame_num={max_frames}") + + self.stats["total_samples"] += 1 + tmp_audio_path = None # trimmed wav from _prepare_audio + m4a_converted_path = None # intermediate wav from m4a conversion + + try: + # โ”€โ”€ 1. Video โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + has_video = frames is not None and os.path.exists(str(frames)) + video_path = os.path.abspath(str(frames)) if has_video else None + if has_video: + logger.info(f"Video: {Path(frames).name}") + + compatible = self._check_video_compatibility(Path(video_path)) + if compatible is None: + raise RuntimeError(f"Video codec incompatible with Decord: {video_path}") + + video_path = str(compatible) # use converted path going into processor + # โ”€โ”€ 2. 
Audio โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + has_audio = audio is not None and os.path.exists(str(audio)) + if has_audio and not self._audio_has_stream(str(audio)): + logger.warning( + f"Audio file has no decodable stream, falling back to video-only: {audio}" + ) + has_audio = False + if has_audio: + audio_for_prepare, m4a_converted_path = self._maybe_convert_m4a( + str(audio) + ) + tmp_audio_path = self._prepare_audio( + audio_for_prepare, max_audio_chunks + ) + + # โ”€โ”€ 3. Build message string and run processor โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + message = self._build_message(video_path, tmp_audio_path, prompt) + logger.info("Running processor...") + + # processor([str]) โ†’ batch mode โ†’ OmniProcessorOutput + ret = self.model.processor([message]) + + # โ”€โ”€ 4. Move tensors to GPU โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + input_ids = ret.input_ids.cuda() + attention_mask = ret.attention_mask.cuda() + + audios = ret.audios.cuda() if ret.audios is not None else None + encoder_length = ret.encoder_length.cuda() if ret.encoder_length is not None else None + bridge_length = ret.bridge_length.cuda() if ret.bridge_length is not None else None + + # Static images: not used (video frames handled via videos= field) + images = None + patch_nums = None + images_grid = None + + videos = ( + [torch.tensor(v, dtype=torch.float32).cuda() for v in ret.videos] + if ret.videos is not None else None + ) + videos_patch_nums = ret.videos_patch_nums if ret.videos_patch_nums is not None else None + videos_grid = ret.videos_grid if ret.videos_grid is not None else None + + # โ”€โ”€ 5. 
Generate (text only โ€” stop before audio generation) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + logger.info( + f"Generating: video={'yes' if has_video else 'no'}, " + f"audio={'yes' if has_audio else 'no'}" + ) + + with torch.inference_mode(): + output = self.model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + audios=audios, + images=images, + patch_nums=patch_nums, + images_grid=images_grid, + videos=videos, + videos_patch_nums=videos_patch_nums, + videos_grid=videos_grid, + encoder_length=encoder_length, + bridge_length=bridge_length, + tokenizer=self.tokenizer, + # Stop before TTS generation โ€” text output only + stop_strings=[""], + max_new_tokens=max_new_tokens, + do_sample=(temperature > 0), + temperature=temperature if temperature > 0 else None, + top_p=top_p if temperature > 0 else None, + num_beams=num_beams, + return_dict_in_generate=True, + use_cache=True, + ) + + # โ”€โ”€ 6. Decode new tokens only โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + input_len = input_ids.shape[1] + response = self.tokenizer.decode( + output.sequences[0, input_len:], + skip_special_tokens=True, + ).strip() + + # Clean control characters that break JSON parsing + response = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', ' ', response) + # Truncate if excessively long (Baichuan sometimes runs away) + if len(response) > 15000: + # Find last complete JSON structure + last_brace = response.rfind('}') + if last_brace > 0: + response = response[:last_brace + 1] + + logger.info(f"Generated {len(response)} characters") + return self.postprocess_output(response) + + except (torch.cuda.OutOfMemoryError, RuntimeError) as e: + error_msg = str(e) + if "out of memory" in error_msg.lower() or "size of tensor" in error_msg.lower(): + logger.error(f"OOM: {error_msg[:200]}...") + torch.cuda.empty_cache() + raise RuntimeError(f"Out of memory: {e}") + logger.error(f"Generation failed: {e}") + raise RuntimeError(f"Generation 
failed: {e}") + + finally: + # Remove trimmed wav from _prepare_audio and any m4aโ†’wav intermediate + for p in (tmp_audio_path, m4a_converted_path): + if p is not None and os.path.exists(p): + try: + os.remove(p) + except OSError: + pass + + # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + # Cleanup + # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + def unload(self) -> None: + """Unload model and free GPU memory.""" + logger.info("Unloading Baichuan-Omni model...") + + if self.model is not None: + del self.model + self.model = None + if self.tokenizer is not None: + del self.tokenizer + self.tokenizer = None + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + logger.info("Baichuan-Omni unloaded") + + def get_model_info(self) -> Dict[str, Any]: + info = super().get_model_info() + info.update({ + "model_path": self.model_path, + "backbone": "Qwen2.5-7B (bfloat16)", + "vision_encoder": "CLIP ViT-L/14 (1fps, max_frame_num patched from config)", + "audio_encoder": "Whisper-large (30s window, uniform sampling โ€” same as OLA)", + "native_video": True, + "native_audio": True, + "default_max_frames": self.default_max_frames, + "cache_dir": self.cache_dir, + "statistics": self.stats, + }) + return info diff --git a/sonic-o1/05_evaluation_inference/models/base_model.py b/sonic-o1/05_evaluation_inference/models/base_model.py index 72caf2b..f856b66 100644 --- a/sonic-o1/05_evaluation_inference/models/base_model.py +++ b/sonic-o1/05_evaluation_inference/models/base_model.py @@ -1,52 +1,50 @@ -""" -Base model class for evaluation framework. 
-All model implementations should inherit from this class. +"""base_model.py + +Abstract base class for all multimodal models in the evaluation framework. + +Author: SONIC-O1 Team """ from abc import ABC, abstractmethod -from typing import Union, List, Optional, Dict, Any, Literal +from typing import Any, Dict, List, Literal, Optional, Union + import numpy as np -from pathlib import Path class BaseModel(ABC): """ Abstract base class for all multimodal models. - + All model implementations must inherit from this class and implement the abstract methods. """ - - def __init__(self, model_name: str, config: Dict[str, Any]): + + def __init__(self, model_name: str, config: Dict[str, Any]) -> None: """ Initialize the base model. - + Args: - model_name: Name of the model - config: Configuration dictionary from models_config.yaml + model_name: Name of the model. + config: Configuration dict from models_config.yaml. """ self.model_name = model_name self.config = config self.model = None - self.supports_video = config.get('supports_video', True) - self.supports_audio = config.get('supports_audio', True) - + self.supports_video = config.get("supports_video", True) + self.supports_audio = config.get("supports_audio", True) + @abstractmethod - def load(self): + def load(self) -> None: """ Initialize and load the model. - - This method should: - - Load model weights - - Initialize processors/tokenizers - - Set up any required configurations - - Move model to appropriate device - + + Load model weights, processors/tokenizers, and move to device. + Raises: - Exception: If model loading fails + Exception: If model loading fails. 
""" pass - + @abstractmethod def generate( self, @@ -54,14 +52,14 @@ def generate( audio: Optional[Union[np.ndarray, str]], prompt: str, fps: Optional[float] = None, - video_category: Optional[Literal['short', 'medium', 'long']] = None, - max_frames: Optional[int] = None, + video_category: Optional[Literal["short", "medium", "long"]] = None, + max_frames: Optional[int] = None, max_audio_chunks: Optional[int] = None, - **kwargs + **kwargs, ) -> str: """ Generate response from video frames and audio. - + Args: frames: Either: - List of video frames (for image models) @@ -72,103 +70,93 @@ def generate( - Audio file path - None if audio not available prompt: Text prompt for the model - fps: Optional FPS for video processing (used by video models for memory optimization) - video_category: Optional video length category for timeout/memory optimization: + fps: Optional FPS for video (used by video models for memory). + video_category: Optional video length category for timeout/memory: - 'short': < 5 minutes - 'medium': 5-20 minutes - 'long': > 20 minutes - **kwargs: Additional model-specific parameters such as: - - temperature: Sampling temperature - - max_tokens: Maximum generation length - - top_p: Nucleus sampling parameter - + **kwargs: Additional model-specific parameters (e.g. temperature, + max_tokens, top_p). + Returns: - str: Model's text response - + Model's text response. + Raises: - Exception: If generation fails + Exception: If generation fails. """ pass - + @abstractmethod - def unload(self): + def unload(self) -> None: """ Clean up model resources. - - This method should: - - Clear model from memory - - Release GPU memory - - Close any open file handles + + Clear model from memory, release GPU memory, close file handles. """ pass - + def preprocess_frames( - self, - frames: Union[List[np.ndarray], np.ndarray], - **kwargs + self, frames: Union[List[np.ndarray], np.ndarray], **kwargs ) -> Any: """ Preprocess frames for model input. 
- + This is an optional method that can be overridden for custom preprocessing. - + Args: frames: Input frames - **kwargs: Additional preprocessing parameters - + **kwargs: Additional preprocessing parameters. + Returns: - Preprocessed frames in model-specific format + Preprocessed frames in model-specific format. """ return frames - - def preprocess_audio( - self, - audio: Union[np.ndarray, str], - **kwargs - ) -> Any: + + def preprocess_audio(self, audio: Union[np.ndarray, str], **kwargs) -> Any: """ Preprocess audio for model input. - + This is an optional method that can be overridden for custom preprocessing. - + Args: audio: Input audio - **kwargs: Additional preprocessing parameters - + **kwargs: Additional preprocessing parameters. + Returns: - Preprocessed audio in model-specific format + Preprocessed audio in model-specific format. """ return audio - + def postprocess_output(self, output: Any) -> str: """ Postprocess model output. - + This is an optional method that can be overridden for custom postprocessing. - + Args: - output: Raw model output - + output: Raw model output. + Returns: - str: Cleaned and formatted output text + Cleaned and formatted output text. """ if isinstance(output, str): return output.strip() return str(output).strip() - + def get_model_info(self) -> Dict[str, Any]: """ Get model information. - + Returns: - Dictionary containing model metadata + Dict with name, supports_video, supports_audio, config. 
""" return { - 'name': self.model_name, - 'supports_video': self.supports_video, - 'supports_audio': self.supports_audio, - 'config': self.config + "name": self.model_name, + "supports_video": self.supports_video, + "supports_audio": self.supports_audio, + "config": self.config, } - + def __repr__(self) -> str: - return f"{self.__class__.__name__}(model_name='{self.model_name}')" \ No newline at end of file + """Return string representation of the model.""" + return f"{self.__class__.__name__}(model_name='{self.model_name}')" diff --git a/sonic-o1/05_evaluation_inference/models/gemini.py b/sonic-o1/05_evaluation_inference/models/gemini.py index 746c958..cf8a761 100644 --- a/sonic-o1/05_evaluation_inference/models/gemini.py +++ b/sonic-o1/05_evaluation_inference/models/gemini.py @@ -1,156 +1,163 @@ +"""gemini.py + +Gemini 3.0 Pro implementation with native video and audio support. + +Author: SONIC-O1 Team """ -models/gemini.py -Gemini 3.0 Pro implementation adapted from working Gemini 2.5 code. -""" + +import logging import os import time -import logging -from typing import Optional, Dict, Any, Literal from pathlib import Path +from typing import Any, Dict, Literal, Optional + try: from google import genai from google.genai import types -except ImportError: +except ImportError as e: raise ImportError( "Please install google-genai: pip install google-genai" - ) + ) from e from .base_model import BaseModel + logger = logging.getLogger(__name__) class Gemini(BaseModel): - """ - Gemini 3.0 Pro Preview wrapper with native video and audio support. 
- """ - + """Gemini 3.0 Pro Preview wrapper with native video and audio support.""" + # Timeout configurations for different video lengths # short: < 5 minutes # medium: 5-20 minutes # long: > 20 minutes TIMEOUT_CONFIG = { - 'short': 180, # 3 minutes - videos under 5 minutes - 'medium': 600, # 10 minutes - videos 5-20 minutes - 'long': 1800, # 30 minutes - videos over 20 minutes + "short": 180, # 3 minutes - videos under 5 minutes + "medium": 600, # 10 minutes - videos 5-20 minutes + "long": 1800, # 30 minutes - videos over 20 minutes } - - def __init__(self, model_name: str, config: Dict[str, Any]): + + def __init__(self, model_name: str, config: Dict[str, Any]) -> None: super().__init__(model_name, config) - + self.supports_video = True self.supports_audio = True - - api_key_env = config.get('api_key_env', 'GEMINI_API_KEY') + + api_key_env = config.get("api_key_env", "GEMINI_API_KEY") self.api_key = os.getenv(api_key_env) - + if not self.api_key: raise ValueError( f"API key not found. Please set {api_key_env} environment variable." 
) - - self.model_version = config.get('model_version', 'gemini-3-pro-preview') - - gen_config = config.get('generation_config', {}) + + self.model_version = config.get("model_version", "gemini-3-pro-preview") + + gen_config = config.get("generation_config", {}) self.generation_config = types.GenerateContentConfig( - temperature=gen_config.get('temperature', 0.7), - top_p=gen_config.get('top_p', 0.95), - top_k=gen_config.get('top_k', 40), - max_output_tokens=gen_config.get('max_output_tokens', 8192), + temperature=gen_config.get("temperature", 0.7), + top_p=gen_config.get("top_p", 0.95), + top_k=gen_config.get("top_k", 40), + max_output_tokens=gen_config.get("max_output_tokens", 8192), ) - + # Retry configuration - self.retry_attempts = config.get('retry_attempts', 3) - self.retry_delay = config.get('retry_delay', 2) - + self.retry_attempts = config.get("retry_attempts", 3) + self.retry_delay = config.get("retry_delay", 2) + # Default timeout (can be overridden per request) - self.default_timeout = config.get('file_processing_timeout', - self.TIMEOUT_CONFIG['medium']) - + self.default_timeout = config.get( + "file_processing_timeout", self.TIMEOUT_CONFIG["medium"] + ) + self.client = None - + def _get_timeout_for_category( - self, - category: Optional[Literal['short', 'medium', 'long']] = None + self, category: Optional[Literal["short", "medium", "long"]] = None ) -> int: """ Get appropriate timeout based on video category. - + Args: category: Video length category - 'short': < 5 minutes - - 'medium': 5-20 minutes + - 'medium': 5-20 minutes - 'long': > 20 minutes - - None: uses default timeout - + - None: uses default timeout. + Returns: - Timeout in seconds + Timeout in seconds. 
""" if category is None: return self.default_timeout - + timeout = self.TIMEOUT_CONFIG.get(category) if timeout is None: logger.warning(f"Unknown category '{category}', using default timeout") return self.default_timeout - + logger.info(f"Using {timeout}s timeout for '{category}' video") return timeout - + def _estimate_timeout_from_file(self, video_path: Path) -> int: """ Estimate timeout based on file size as a fallback. - Rule of thumb: ~1 second per MB + base timeout of 180s - + + Rule of thumb: ~1 second per MB + base timeout of 180s. + Args: - video_path: Path to video file - + video_path: Path to video file. + Returns: - Estimated timeout in seconds + Estimated timeout in seconds. """ try: file_size_mb = video_path.stat().st_size / (1024 * 1024) - + # Base timeout + file size factor base_timeout = 180 size_factor = int(file_size_mb) # 1 second per MB - + estimated = base_timeout + size_factor - + # Cap at 'long' timeout maximum - max_timeout = self.TIMEOUT_CONFIG['long'] + max_timeout = self.TIMEOUT_CONFIG["long"] timeout = min(estimated, max_timeout) - - logger.info(f"Estimated timeout from file size ({file_size_mb:.1f}MB): {timeout}s") + + logger.info( + f"Estimated timeout from file size ({file_size_mb:.1f}MB): {timeout}s" + ) return timeout - + except Exception as e: logger.warning(f"Could not estimate timeout from file: {e}") return self.default_timeout - - def load(self): + + def load(self) -> None: + """Load the Gemini model and initialize client.""" try: - os.environ['GEMINI_API_KEY'] = self.api_key + os.environ["GEMINI_API_KEY"] = self.api_key self.client = genai.Client() - logger.info(f"Loaded Gemini 3.0 Pro Preview") + logger.info("Loaded Gemini 3.0 Pro Preview") except Exception as e: - raise RuntimeError(f"Failed to load Gemini client: {e}") - + raise RuntimeError(f"Failed to load Gemini client: {e}") from e + def generate( self, frames: str, audio: Optional[str], prompt: str, - video_category: Optional[Literal['short', 'medium', 'long']] = 
None, + video_category: Optional[Literal["short", "medium", "long"]] = None, fps: Optional[float] = None, max_frames: Optional[int] = None, max_audio_chunks: Optional[int] = None, - **kwargs + **kwargs, ) -> str: """ Generate content from video/audio with Gemini. - + Args: frames: Path to video file audio: Optional path to audio file @@ -159,58 +166,58 @@ def generate( - 'short': < 5 minutes (180s timeout) - 'medium': 5-20 minutes (600s timeout) - 'long': > 20 minutes (1800s timeout) - **kwargs: Additional generation parameters - + **kwargs: Additional generation parameters. + Returns: - Generated text response + Generated text response. """ if self.client is None: raise RuntimeError("Client not loaded. Call load() first.") - + if not isinstance(frames, str): raise ValueError( f"Gemini requires video file path (str), got {type(frames)}" ) - + try: - media_files = [('video', Path(frames))] - + media_files = [("video", Path(frames))] + if audio is not None and isinstance(audio, str) and os.path.exists(audio): - media_files.append(('audio', Path(audio))) - + media_files.append(("audio", Path(audio))) + config = self.generation_config - if 'temperature' in kwargs: - config.temperature = kwargs['temperature'] - if 'max_output_tokens' in kwargs: - config.max_output_tokens = kwargs['max_output_tokens'] - + if "temperature" in kwargs: + config.temperature = kwargs["temperature"] + if "max_output_tokens" in kwargs: + config.max_output_tokens = kwargs["max_output_tokens"] + # Determine timeout if video_category: timeout = self._get_timeout_for_category(video_category) else: # Fallback: estimate from file size timeout = self._estimate_timeout_from_file(Path(frames)) - + response_text = self._process_with_file_api( media_files, prompt, config, timeout, fps ) - + return self.postprocess_output(response_text) - + except Exception as e: - raise RuntimeError(f"Generation failed: {e}") - + raise RuntimeError(f"Generation failed: {e}") from e + def _process_with_file_api( self, 
media_files: list, prompt: str, config: types.GenerateContentConfig, timeout: int, - fps: Optional[float] = None + fps: Optional[float] = None, ) -> str: """ - Process files using Gemini File API - + Process files using Gemini File API. + Args: media_files: List of (media_type, Path) tuples prompt: Text prompt @@ -218,7 +225,7 @@ def _process_with_file_api( timeout: Processing timeout in seconds """ uploaded_files = [] - + try: # Upload all files for media_type, media_path in media_files: @@ -226,53 +233,55 @@ def _process_with_file_api( uploaded_file = self.client.files.upload(file=str(media_path)) logger.info(f"Uploaded {media_type}: {uploaded_file.name}") uploaded_files.append((media_type, uploaded_file)) - + # Wait for processing with configurable timeout check_interval = 2 # Check every 2 seconds wait_time = 0 all_processed = False - + logger.info(f"Waiting for file processing (timeout: {timeout}s)...") - + while not all_processed and wait_time < timeout: all_processed = True for i, (media_type, uploaded_file) in enumerate(uploaded_files): updated_file = self.client.files.get(name=uploaded_file.name) uploaded_files[i] = (media_type, updated_file) - + if updated_file.state == "PROCESSING": all_processed = False if wait_time % 10 == 0: # Log every 10 seconds - logger.info(f"Still processing {media_type} ({wait_time}s elapsed)...") + logger.info( + f"Still processing {media_type} ({wait_time}s elapsed)..." 
+ ) elif updated_file.state == "FAILED": - error_msg = getattr(updated_file, 'error', 'Unknown error') + error_msg = getattr(updated_file, "error", "Unknown error") raise Exception(f"File processing failed: {error_msg}") - + if not all_processed: time.sleep(check_interval) wait_time += check_interval - + if not all_processed: raise Exception(f"File processing timeout after {timeout}s") - + logger.info(f"All files processed successfully in {wait_time}s") - + # Generate content with retries for attempt in range(self.retry_attempts): try: content_parts = [] - + # Add all media files for media_type, uploaded_file in uploaded_files: - if media_type == 'video' and fps is not None: + if media_type == "video" and fps is not None: # Add video with FPS metadata content_parts.append( types.Part( file_data=types.FileData( file_uri=uploaded_file.uri, - mime_type=uploaded_file.mime_type + mime_type=uploaded_file.mime_type, ), - video_metadata=types.VideoMetadata(fps=fps) + video_metadata=types.VideoMetadata(fps=fps), ) ) else: @@ -281,56 +290,63 @@ def _process_with_file_api( types.Part( file_data=types.FileData( file_uri=uploaded_file.uri, - mime_type=uploaded_file.mime_type + mime_type=uploaded_file.mime_type, ) ) ) - + # Add prompt content_parts.append(types.Part(text=prompt)) - + response = self.client.models.generate_content( model=self.model_version, contents=types.Content(parts=content_parts), - config=config + config=config, ) - + return response.text - + except Exception as e: error_str = str(e) - if "429" in error_str or "quota" in error_str.lower() or "resource_exhausted" in error_str.lower(): - if attempt < self.retry_attempts - 1: - wait_time = (attempt + 1) * 5 - logger.warning(f"Rate limit hit, waiting {wait_time}s...") - time.sleep(wait_time) - continue + if ( + "429" in error_str + or "quota" in error_str.lower() + or "resource_exhausted" in error_str.lower() + ) and attempt < self.retry_attempts - 1: + wait_time = (attempt + 1) * 5 + logger.warning(f"Rate limit 
hit, waiting {wait_time}s...") + time.sleep(wait_time) + continue logger.warning(f"Generation attempt {attempt + 1} failed: {e}") if attempt < self.retry_attempts - 1: time.sleep(self.retry_delay) else: raise - + finally: # Cleanup uploaded files - for media_type, uploaded_file in uploaded_files: + for _, uploaded_file in uploaded_files: try: self.client.files.delete(name=uploaded_file.name) logger.debug(f"Deleted uploaded file: {uploaded_file.name}") except Exception as e: logger.warning(f"Failed to delete file {uploaded_file.name}: {e}") - - def unload(self): + + def unload(self) -> None: + """Unload the model and clean up resources.""" self.client = None - + def get_model_info(self) -> Dict[str, Any]: + """Get model information and configuration.""" info = super().get_model_info() - info.update({ - 'model_version': self.model_version, - 'api_based': True, - 'native_video': True, - 'native_audio': True, - 'sdk': 'google-genai', - 'timeout_config': self.TIMEOUT_CONFIG - }) - return info \ No newline at end of file + info.update( + { + "model_version": self.model_version, + "api_based": True, + "native_video": True, + "native_audio": True, + "sdk": "google-genai", + "timeout_config": self.TIMEOUT_CONFIG, + } + ) + return info diff --git a/sonic-o1/05_evaluation_inference/models/gpt4o.py b/sonic-o1/05_evaluation_inference/models/gpt4o.py index 65274e4..a4c55b4 100644 --- a/sonic-o1/05_evaluation_inference/models/gpt4o.py +++ b/sonic-o1/05_evaluation_inference/models/gpt4o.py @@ -1,104 +1,114 @@ -""" -models/gpt4o.py +"""gpt4o.py GPT-4o implementation with video frames and caption support. 
+ +Author: SONIC-O1 Team """ -import os -import time + import base64 import logging -from typing import Optional, Dict, Any, List, Union +import os +import time from pathlib import Path +from typing import Any, Dict, List, Optional, Union + try: from openai import OpenAI -except ImportError: - raise ImportError( - "Please install openai: pip install openai" - ) +except ImportError as e: + raise ImportError("Please install openai: pip install openai") from e -from .base_model import BaseModel -from utils.frame_sampler import FrameSampler from utils.caption_handler import CaptionHandler +from utils.frame_sampler import FrameSampler + +from .base_model import BaseModel + logger = logging.getLogger(__name__) class GPT4o(BaseModel): """ - GPT-4o wrapper with support for: + GPT-4o wrapper with multimodal support. + - Video frames (extracted and encoded as images) - Captions (from SRT files) - - Multimodal (frames + captions) + - Multimodal (frames + captions). """ - - def __init__(self, model_name: str, config: Dict[str, Any]): + + def __init__(self, model_name: str, config: Dict[str, Any]) -> None: super().__init__(model_name, config) - + # Model capabilities - self.supports_video = config.get('supports_video', True) + self.supports_video = config.get("supports_video", True) self.supports_audio = False # GPT-4o doesn't support audio streams - self.use_captions = config.get('use_captions', False) - + self.use_captions = config.get("use_captions", False) + # API configuration - api_key_env = config.get('api_key_env', 'OPENAI_API_KEY') + api_key_env = config.get("api_key_env", "OPENAI_API_KEY") self.api_key = os.getenv(api_key_env) - + if not self.api_key: raise ValueError( f"API key not found. Please set {api_key_env} environment variable." 
) - - self.model_version = config.get('model_version', 'gpt-4o') - + + self.model_version = config.get("model_version", "gpt-4o") + # Generation config - gen_config = config.get('generation_config', {}) - self.temperature = gen_config.get('temperature', 0.7) - self.max_tokens = gen_config.get('max_tokens', 4096) - self.top_p = gen_config.get('top_p', 1.0) - + gen_config = config.get("generation_config", {}) + self.temperature = gen_config.get("temperature", 0.7) + self.max_tokens = gen_config.get("max_tokens", 4096) + self.top_p = gen_config.get("top_p", 1.0) + # Frame configuration - self.max_frames = config.get('max_frames', 128) - self.image_detail = config.get('image_detail', 'auto') # 'auto', 'low', 'high' - + self.max_frames = config.get("max_frames", 128) + self.image_detail = config.get("image_detail", "auto") # 'auto', 'low', 'high' + # Retry configuration - retry_config = config.get('retry_override', config.get('retry', {})) - self.frame_count_fallback = retry_config.get('frame_count_fallback', [128, 64, 32, 16]) - self.caption_chunks_fallback = retry_config.get('caption_chunks_fallback', [None, 32, 16, 8]) - - self.retry_attempts = config.get('retry_attempts', 3) - self.retry_delay = config.get('retry_delay', 2) - + retry_config = config.get("retry_override", config.get("retry", {})) + self.frame_count_fallback = retry_config.get( + "frame_count_fallback", [128, 64, 32, 16] + ) + self.caption_chunks_fallback = retry_config.get( + "caption_chunks_fallback", [None, 32, 16, 8] + ) + + self.retry_attempts = config.get("retry_attempts", 3) + self.retry_delay = config.get("retry_delay", 2) + # Dataset root for caption discovery - self.dataset_root = config.get('dataset_root', None) - + self.dataset_root = config.get("dataset_root") + # Initialize handlers self.client = None self.frame_sampler = None self.caption_handler = None - - def load(self): - """Initialize OpenAI client and handlers""" + + def load(self) -> None: + """Initialize OpenAI client and 
handlers.""" try: self.client = OpenAI(api_key=self.api_key) - + # Initialize frame sampler if video support enabled if self.supports_video: self.frame_sampler = FrameSampler() logger.info("Frame sampler initialized") - + # Initialize caption handler if caption support enabled if self.use_captions: self.caption_handler = CaptionHandler( caption_chunks_fallback=self.caption_chunks_fallback ) logger.info("Caption handler initialized") - + logger.info(f"Loaded GPT-4o ({self.model_version})") - logger.info(f"Video support: {self.supports_video}, Caption support: {self.use_captions}") - + logger.info( + f"Video support: {self.supports_video}, Caption support: {self.use_captions}" + ) + except Exception as e: - raise RuntimeError(f"Failed to load GPT-4o: {e}") - + raise RuntimeError(f"Failed to load GPT-4o: {e}") from e + def generate( self, frames: Union[str, List[Path]], @@ -110,11 +120,11 @@ def generate( max_caption_chunks: Optional[int] = None, caption_path: Optional[str] = None, segment: Optional[Dict] = None, # NEW: {'start': 30.0, 'end': 60.0} - **kwargs + **kwargs, ) -> str: """ Generate response using GPT-4o. - + Args: frames: Video file path (str) for frame extraction audio: Ignored (GPT-4o doesn't support audio) @@ -125,229 +135,237 @@ def generate( max_caption_chunks: Maximum caption chunks (for retry logic) caption_path: Optional explicit caption file path segment: Optional segment info {'start': float, 'end': float} for caption filtering - **kwargs: Additional generation parameters - + **kwargs: Additional generation parameters. + Returns: - Generated text response + Generated text response. """ if self.client is None: raise RuntimeError("Client not loaded. 
Call load() first.") - + # Validate inputs based on configuration if self.supports_video and not self.use_captions: # Video-only mode if not isinstance(frames, str): - raise ValueError("GPT-4o requires video file path (str) for video-only mode") + raise ValueError( + "GPT-4o requires video file path (str) for video-only mode" + ) elif not self.supports_video and self.use_captions: # Text-only mode - captions required pass # Will auto-discover or use provided caption_path elif self.supports_video and self.use_captions: # Multimodal mode if not isinstance(frames, str): - raise ValueError("GPT-4o requires video file path (str) for multimodal mode") + raise ValueError( + "GPT-4o requires video file path (str) for multimodal mode" + ) else: raise ValueError("GPT-4o must have either video or caption support enabled") - + try: # Determine actual max frames and caption chunks actual_max_frames = max_frames or self.max_frames actual_max_caption_chunks = max_caption_chunks # None or int - + # Process inputs video_path = Path(frames) if isinstance(frames, str) else None - + # Auto-discover caption path if needed and not provided if self.use_captions and caption_path is None and video_path is not None: caption_path = self.caption_handler.auto_discover_caption_path( - video_path, - dataset_root=Path(self.dataset_root) if self.dataset_root else None + video_path, + dataset_root=Path(self.dataset_root) if self.dataset_root else None, ) - + # Extract segment info only if caption_handler is available - if segment is None and video_path is not None and self.caption_handler is not None: + if ( + segment is None + and video_path is not None + and self.caption_handler is not None + ): segment = self.caption_handler.extract_segment_info(video_path) if segment: logger.info(f"Auto-extracted segment info: {segment}") - + # Build message content content_parts = [] - + # Add frames if video mode if self.supports_video and video_path is not None: frame_paths = self._extract_frames(video_path, 
actual_max_frames) - + for frame_path in frame_paths: base64_image = self._encode_image_to_base64(frame_path) - content_parts.append({ - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}", - "detail": self.image_detail + content_parts.append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + "detail": self.image_detail, + }, } - }) - + ) + logger.info(f"Added {len(frame_paths)} frames to request") - + # Add captions if caption mode if self.use_captions and caption_path is not None: # Check if we need segment-based caption extraction if segment is not None: # Task 2/3: Extract captions for specific time segment - start_time = segment.get('start', 0.0) - end_time = segment.get('end', None) - + start_time = segment.get("start", 0.0) + end_time = segment.get("end", None) + if end_time is not None: - caption_text = self.caption_handler.get_caption_text_for_segment( - Path(caption_path), - start_time=start_time, - end_time=end_time, - num_chunks=actual_max_caption_chunks + caption_text = ( + self.caption_handler.get_caption_text_for_segment( + Path(caption_path), + start_time=start_time, + end_time=end_time, + num_chunks=actual_max_caption_chunks, + ) + ) + logger.info( + f"Extracted caption for segment [{start_time:.1f}s - {end_time:.1f}s]" ) - logger.info(f"Extracted caption for segment [{start_time:.1f}s - {end_time:.1f}s]") else: # No end time provided, fall back to full caption caption_text = self.caption_handler.get_caption_text( - Path(caption_path), - num_chunks=actual_max_caption_chunks + Path(caption_path), num_chunks=actual_max_caption_chunks ) else: # Task 1: Full video caption caption_text = self.caption_handler.get_caption_text( - Path(caption_path), - num_chunks=actual_max_caption_chunks + Path(caption_path), num_chunks=actual_max_caption_chunks ) - + if caption_text: # Add captions as a separate text block - content_parts.append({ - "type": "text", - "text": 
f"Transcript:\n{caption_text}" - }) - logger.info(f"Added caption text ({len(caption_text)} chars, chunks={actual_max_caption_chunks})") + content_parts.append( + {"type": "text", "text": f"Transcript:\n{caption_text}"} + ) + logger.info( + f"Added caption text ({len(caption_text)} chars, chunks={actual_max_caption_chunks})" + ) else: logger.warning("No caption text extracted") - + # Add prompt - content_parts.append({ - "type": "text", - "text": prompt - }) - + content_parts.append({"type": "text", "text": prompt}) + # Generate with retries for attempt in range(self.retry_attempts): try: response = self.client.chat.completions.create( model=self.model_version, - messages=[ - { - "role": "user", - "content": content_parts - } - ], - temperature=kwargs.get('temperature', self.temperature), - max_tokens=kwargs.get('max_tokens', self.max_tokens), - top_p=kwargs.get('top_p', self.top_p), + messages=[{"role": "user", "content": content_parts}], + temperature=kwargs.get("temperature", self.temperature), + max_tokens=kwargs.get("max_tokens", self.max_tokens), + top_p=kwargs.get("top_p", self.top_p), ) - + response_text = response.choices[0].message.content logger.info(f"Generated response ({len(response_text)} chars)") - + return self.postprocess_output(response_text) - + except Exception as e: error_str = str(e) - + # Handle rate limits - if "429" in error_str or "rate_limit" in error_str.lower(): - if attempt < self.retry_attempts - 1: - wait_time = (attempt + 1) * 5 - logger.warning(f"Rate limit hit, waiting {wait_time}s...") - time.sleep(wait_time) - continue - + if ("429" in error_str or "rate_limit" in error_str.lower()) and attempt < self.retry_attempts - 1: + wait_time = (attempt + 1) * 5 + logger.warning(f"Rate limit hit, waiting {wait_time}s...") + time.sleep(wait_time) + continue + # Handle context length errors - if "context_length" in error_str.lower() or "maximum context" in error_str.lower(): + if ( + "context_length" in error_str.lower() + or "maximum 
context" in error_str.lower() + ): logger.error(f"Context length exceeded: {e}") - raise RuntimeError(f"Context length exceeded: {e}") - + raise RuntimeError(f"Context length exceeded: {e}") from e + logger.warning(f"Generation attempt {attempt + 1} failed: {e}") if attempt < self.retry_attempts - 1: time.sleep(self.retry_delay) else: raise - + except Exception as e: - raise RuntimeError(f"Generation failed: {e}") - + raise RuntimeError(f"Generation failed: {e}") from e + def _extract_frames(self, video_path: Path, num_frames: int) -> List[Path]: """ Extract frames from video using FrameSampler. - + Args: video_path: Path to video file - num_frames: Number of frames to extract - + num_frames: Number of frames to extract. + Returns: - List of paths to extracted frame images + List of paths to extracted frame images. """ if self.frame_sampler is None: raise RuntimeError("Frame sampler not initialized") - + logger.info(f"Extracting {num_frames} frames from {video_path.name}") - + frame_paths = self.frame_sampler.sample_frames_uniform( - video_path=video_path, - num_frames=num_frames + video_path=video_path, num_frames=num_frames ) - + if not frame_paths: raise RuntimeError(f"Failed to extract frames from {video_path}") - + return frame_paths - + def _encode_image_to_base64(self, image_path: Path) -> str: """ Encode image file to base64 string. - + Args: - image_path: Path to image file - + image_path: Path to image file. + Returns: - Base64 encoded string + Base64 encoded string. 
""" try: - with open(image_path, 'rb') as f: - return base64.b64encode(f.read()).decode('utf-8') + with open(image_path, "rb") as f: + return base64.b64encode(f.read()).decode("utf-8") except Exception as e: - raise RuntimeError(f"Failed to encode image {image_path}: {e}") - - def unload(self): - """Clean up resources""" + raise RuntimeError(f"Failed to encode image {image_path}: {e}") from e + + def unload(self) -> None: + """Clean up resources.""" # Cleanup frame sampler if self.frame_sampler is not None: self.frame_sampler.cleanup() self.frame_sampler = None - + # Cleanup caption handler if self.caption_handler is not None: self.caption_handler.cleanup() self.caption_handler = None - + self.client = None logger.info("GPT-4o unloaded and resources cleaned up") - + def get_model_info(self) -> Dict[str, Any]: - """Get model information""" + """Get model information.""" info = super().get_model_info() - info.update({ - 'model_version': self.model_version, - 'api_based': True, - 'supports_video': self.supports_video, - 'supports_audio': False, - 'use_captions': self.use_captions, - 'max_frames': self.max_frames, - 'image_detail': self.image_detail, - 'frame_count_fallback': self.frame_count_fallback, - 'caption_chunks_fallback': self.caption_chunks_fallback, - }) - return info \ No newline at end of file + info.update( + { + "model_version": self.model_version, + "api_based": True, + "supports_video": self.supports_video, + "supports_audio": False, + "use_captions": self.use_captions, + "max_frames": self.max_frames, + "image_detail": self.image_detail, + "frame_count_fallback": self.frame_count_fallback, + "caption_chunks_fallback": self.caption_chunks_fallback, + } + ) + return info diff --git a/sonic-o1/05_evaluation_inference/models/minicpm.py b/sonic-o1/05_evaluation_inference/models/minicpm.py index f51f367..ac4394a 100644 --- a/sonic-o1/05_evaluation_inference/models/minicpm.py +++ b/sonic-o1/05_evaluation_inference/models/minicpm.py @@ -1,20 +1,23 @@ -""" 
-models/minicpm.py +"""minicpm.py MiniCPM-o-2.6 implementation with omni multimodal support. + +Author: SONIC-O1 Team """ -import os + import logging -from typing import Optional, Dict, Any, Union, List, Literal -from pathlib import Path import math +import os +import subprocess +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union + try: - import torch + import av + import librosa import numpy as np + import torch from transformers import AutoModel, AutoTokenizer - from PIL import Image - import librosa - import av except ImportError as e: raise ImportError( f"Please install required packages: {e}\n" @@ -23,57 +26,62 @@ from .base_model import BaseModel + logger = logging.getLogger(__name__) class MiniCPM(BaseModel): """ MiniCPM-o-2.6 wrapper with omni multimodal support. + Processes video frames and audio chunks in a specialized format. Automatically calculates optimal FPS based on max_frames limit. """ - - def __init__(self, model_name: str, config: Dict[str, Any]): + + def __init__(self, model_name: str, config: Dict[str, Any]) -> None: super().__init__(model_name, config) - + # Model configuration - self.model_path = config.get('model_path', 'openbmb/MiniCPM-o-2_6') - + self.model_path = config.get("model_path", "openbmb/MiniCPM-o-2_6") + # Device configuration - self.device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu') - self.dtype = config.get('dtype', torch.bfloat16) - self.attn_implementation = config.get('attn_implementation', 'sdpa') - + self.device = config.get( + "device", "cuda" if torch.cuda.is_available() else "cpu" + ) + self.dtype = config.get("dtype", torch.bfloat16) + self.attn_implementation = config.get("attn_implementation", "sdpa") + # Generation config - gen_config = config.get('generation_config', {}) - self.temperature = gen_config.get('temperature', 0.7) - self.top_p = gen_config.get('top_p', 0.95) - self.max_new_tokens = gen_config.get('max_new_tokens', 2048) - + gen_config = 
config.get("generation_config", {}) + self.temperature = gen_config.get("temperature", 0.7) + self.top_p = gen_config.get("top_p", 0.95) + self.max_new_tokens = gen_config.get("max_new_tokens", 2048) + # Audio processing config - self.audio_sr = config.get('audio_sample_rate', 16000) - self.audio_mono = config.get('audio_mono', True) - + self.audio_sr = config.get("audio_sample_rate", 16000) + self.audio_mono = config.get("audio_mono", True) + # Frame limits - self.default_min_frames = config.get('min_frames', 64) - self.default_max_frames = config.get('max_frames', 256) + self.default_min_frames = config.get("min_frames", 64) + self.default_max_frames = config.get("max_frames", 256) - logger.info(f"MiniCPM initialized with default frame limits: {self.default_min_frames}-{self.default_max_frames}") + logger.info( + f"MiniCPM frame limits: {self.default_min_frames}-{self.default_max_frames}" + ) # Model settings - self.init_vision = config.get('init_vision', True) - self.init_audio = config.get('init_audio', True) - self.init_tts = config.get('init_tts', False) - self.language = config.get('language', 'en') - + self.init_vision = config.get("init_vision", True) + self.init_audio = config.get("init_audio", True) + self.init_tts = config.get("init_tts", False) + self.language = config.get("language", "en") + self.model = None self.tokenizer = None - - - def load(self): - """Load the MiniCPM-o-2.6 model and tokenizer""" + + def load(self) -> None: + """Load the MiniCPM-o-2.6 model and tokenizer.""" try: logger.info(f"Loading MiniCPM-o-2.6 model from {self.model_path}") - + # Load model self.model = AutoModel.from_pretrained( self.model_path, @@ -83,38 +91,37 @@ def load(self): init_vision=self.init_vision, init_audio=self.init_audio, init_tts=self.init_tts, - ) self.model = self.model.eval().to(self.device) - + # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained( - self.model_path, - trust_remote_code=True + self.model_path, trust_remote_code=True ) - + 
logger.info(f"Successfully loaded MiniCPM-o-2.6 on {self.device}") - + except Exception as e: raise RuntimeError(f"Failed to load MiniCPM-o-2.6 model: {e}") - + def _calculate_optimal_fps(self, duration: float) -> float: """ Calculate optimal FPS to stay within max_frames limit. + Always uses 1.0 fps unless it would exceed max_frames. - + Args: - duration: Video duration in seconds - + duration: Video duration in seconds. + Returns: - Optimal FPS value (frames per second) + Optimal FPS value (frames per second). """ # Start with 1 fps (1 frame per second) target_fps = 1.0 - + # Calculate how many frames we'd get at 1 fps frames_at_1fps = int(duration * 1.0) - + # If 1 fps exceeds max_frames, reduce fps to hit max_frames exactly if frames_at_1fps > self.default_max_frames: target_fps = self.default_max_frames / duration @@ -127,7 +134,7 @@ def _calculate_optimal_fps(self, duration: float) -> float: f"Duration: {duration:.1f}s @ 1fps = {frames_at_1fps} frames " f"(within max_frames={self.default_max_frames})" ) - + # Check min_frames estimated_frames = int(duration * target_fps) if estimated_frames < self.default_min_frames: @@ -135,70 +142,83 @@ def _calculate_optimal_fps(self, duration: float) -> float: f"Estimated {estimated_frames} frames is below min_frames={self.default_min_frames}. " f"Video may be too short for meaningful analysis." ) - + logger.info(f"Using FPS: {target_fps:.4f} -> ~{estimated_frames} frames") return target_fps - + def _extract_video_audio_chunks( - self, - video_path: str, + self, + video_path: str, audio_path: str, target_num_frames: int, # max_frames from config - flatten: bool = True + flatten: bool = True, ) -> List: """ Extract video frames and audio chunks, then uniformly subsample. - + Strategy: 1. Extract ALL frames at 1 fps (proven working approach) 2. Extract ALL 1-second audio chunks 3. 
Uniformly subsample both to target_num_frames """ - logger.info(f"Extracting frames and audio at 1 fps, then subsampling to {target_num_frames} frames...") - - # Load audio with librosa - audio_np, sr = librosa.load(audio_path, sr=self.audio_sr, mono=self.audio_mono) + logger.info( + f"Extracting frames and audio at 1 fps, then subsampling to {target_num_frames} frames..." + ) + + try: + audio_np, sr = librosa.load(audio_path, sr=self.audio_sr, mono=self.audio_mono) + if len(audio_np) == 0: + raise ValueError("Empty audio") + except Exception: + logger.warning(f"Audio unloadable or empty ({Path(audio_path).name}), substituting silence.") + _tmp_container = av.open(video_path) + _tmp_stream = _tmp_container.streams.video[0] + _seg_duration = int(_tmp_stream.frames / float(_tmp_stream.average_rate)) + _tmp_container.close() + sr = self.audio_sr + audio_np = np.zeros(_seg_duration * sr, dtype=np.float32) audio_duration = len(audio_np) / sr - + # Load video with PyAV container = av.open(video_path) video_stream = container.streams.video[0] video_fps = float(video_stream.average_rate) total_frames = video_stream.frames video_duration = total_frames / video_fps - + logger.info(f" Video: {video_duration:.1f}s @ {video_fps:.1f}fps") logger.info(f" Audio: {audio_duration:.1f}s @ {sr}Hz") - + # Use the shorter duration duration = min(audio_duration, video_duration) num_units = math.ceil(duration) # 1 fps = ceil(duration) frames - + logger.info(f" Step 1: Extracting {num_units} units at 1 fps...") - + # Lists to collect all frames and audio chunks frames_list = [] audio_list = [] - + # Extract at 1 fps (matches working example) for i in range(num_units): # Frame at second i+1 (working example logic) target_time = min(i + 1, duration) - + # Skip if exceeds duration if target_time > duration: break - + # Calculate PTS for seeking target_pts = int( - target_time * video_stream.time_base.denominator / - video_stream.time_base.numerator + target_time + * 
video_stream.time_base.denominator + / video_stream.time_base.numerator ) - + try: # Seek and extract frame container.seek(target_pts, stream=video_stream) - + frame = None for packet in container.demux(video_stream): for frame_obj in packet.decode(): @@ -206,53 +226,58 @@ def _extract_video_audio_chunks( break if frame is not None: break - + if frame is not None: # Convert to PIL Image image = frame.to_image() - + # Get 1 second of audio (working example logic) - audio_chunk = audio_np[sr*i:sr*(i+1)] - - # Verify audio chunk - if len(audio_chunk) == 0: - logger.warning(f" Empty audio chunk at unit {i}, skipping") + audio_chunk = audio_np[sr * i : sr * (i + 1)] + + + if len(audio_chunk) < sr: + audio_chunk = np.pad(audio_chunk, (0, sr - len(audio_chunk))) + continue - + # Add to lists frames_list.append(image) audio_list.append(audio_chunk) - + except Exception as e: logger.warning(f" Error at unit {i}: {e}") continue - + container.close() - + total_extracted = len(frames_list) logger.info(f" Step 1 complete: Extracted {total_extracted} frame-audio pairs") - + # Step 2: Uniform subsampling if needed if total_extracted <= target_num_frames: # No subsampling needed - logger.info(f" Step 2: No subsampling needed ({total_extracted} <= {target_num_frames})") + logger.info( + f" Step 2: No subsampling needed ({total_extracted} <= {target_num_frames})" + ) selected_frames = frames_list selected_audio = audio_list else: # Uniformly subsample - logger.info(f" Step 2: Subsampling {total_extracted} -> {target_num_frames} frames...") - + logger.info( + f" Step 2: Subsampling {total_extracted} -> {target_num_frames} frames..." 
+ ) + # Calculate indices for uniform sampling indices = np.linspace(0, total_extracted - 1, target_num_frames, dtype=int) - + selected_frames = [frames_list[i] for i in indices] selected_audio = [audio_list[i] for i in indices] - + logger.info(f" Subsampling indices: {indices[:5]}...{indices[-5:]}") - + final_count = len(selected_frames) logger.info(f" Final: {final_count} frame-audio pairs ready") - + # Build contents in MiniCPM format contents = [] for i in range(final_count): @@ -260,7 +285,7 @@ def _extract_video_audio_chunks( contents.extend(["", selected_frames[i], selected_audio[i]]) else: contents.append(["", selected_frames[i], selected_audio[i]]) - + # Validate if flatten: expected_elements = final_count * 3 @@ -268,98 +293,104 @@ def _extract_video_audio_chunks( raise RuntimeError( f"Content structure error: {len(contents)} elements != {expected_elements} expected" ) - + logger.info(f" โœ“ Built {final_count} units for model input") - + return contents - + def generate( self, frames: Union[List[np.ndarray], np.ndarray, str], audio: Optional[Union[np.ndarray, str]], prompt: str, fps: Optional[float] = None, # Ignored - kept for API compatibility - video_category: Optional[Literal['short', 'medium', 'long']] = None, - max_frames: Optional[int] = None, + video_category: Optional[Literal["short", "medium", "long"]] = None, + max_frames: Optional[int] = None, max_audio_chunks: Optional[int] = None, - **kwargs + **kwargs, ) -> str: """ Generate response from video and audio. - + Note: fps parameter is ignored. Frame extraction is always based on max_frames. - + Args: frames: Video file path (str) audio: Audio file path (str) prompt: Text prompt for generation fps: Ignored - kept for API compatibility video_category: Unused - **kwargs: Additional generation parameters - + **kwargs: Additional generation parameters. + Returns: - Generated text response + Generated text response. 
""" if self.model is None or self.tokenizer is None: raise RuntimeError("Model not loaded. Call load() first.") - + # Use max_frames from parameter if provided, otherwise use default - actual_max_frames = max_frames if max_frames is not None else self.default_max_frames - + actual_max_frames = ( + max_frames if max_frames is not None else self.default_max_frames + ) + # Validate inputs if not isinstance(frames, str): raise ValueError( f"MiniCPM requires video file path (str), got {type(frames)}. " f"Ensure 'supports_video: true' in config." ) - + if not isinstance(audio, str): raise ValueError( f"MiniCPM requires audio file path (str), got {type(audio)}" ) - + try: video_path = Path(frames) audio_path = Path(audio) - + if not video_path.exists(): raise FileNotFoundError(f"Video file not found: {video_path}") if not audio_path.exists(): raise FileNotFoundError(f"Audio file not found: {audio_path}") - + logger.info(f"Processing video: {video_path.name}") logger.info(f"Processing audio: {audio_path.name}") - logger.info(f"Using max_frames: {actual_max_frames}") # LOG THE ACTUAL VALUE - + logger.info( + f"Using max_frames: {actual_max_frames}" + ) # LOG THE ACTUAL VALUE + # Extract frames and audio chunks contents = self._extract_video_audio_chunks( - str(video_path), + str(video_path), str(audio_path), - target_num_frames=actual_max_frames, - flatten=True + target_num_frames=actual_max_frames, + flatten=True, ) - + # Validate num_units = len(contents) // 3 logger.info(f"Total units ready for model: {num_units}") - - if num_units > actual_max_frames: + + if num_units > actual_max_frames: raise RuntimeError( f"BUG: Extracted {num_units} frames exceeds max_frames ({actual_max_frames})" ) # Build conversation - sys_msg = self.model.get_sys_prompt(mode='omni', language=self.language) + sys_msg = self.model.get_sys_prompt(mode="omni", language=self.language) msg = {"role": "user", "content": contents + [prompt]} msgs = [sys_msg, msg] - + # Generation parameters - 
temperature = kwargs.get('temperature', self.temperature) - top_p = kwargs.get('top_p', self.top_p) - max_new_tokens = kwargs.get('max_new_tokens', self.max_new_tokens) - - logger.info(f"Generating (temp={temperature}, max_tokens={max_new_tokens})...") - + temperature = kwargs.get("temperature", self.temperature) + kwargs.get("top_p", self.top_p) + max_new_tokens = kwargs.get("max_new_tokens", self.max_new_tokens) + + logger.info( + f"Generating (temp={temperature}, max_tokens={max_new_tokens})..." + ) + # Generate res = self.model.chat( msgs=msgs, @@ -372,44 +403,46 @@ def generate( generate_audio=False, max_slice_nums=1, use_image_id=False, - return_dict=True + return_dict=True, ) - - response_text = res['text'] + + response_text = res["text"] logger.info(f"Generated response ({len(response_text)} chars)") - + return self.postprocess_output(response_text) - + except Exception as e: logger.error(f"Generation failed: {e}", exc_info=True) raise RuntimeError(f"Generation failed: {e}") - - def unload(self): - """Clean up model resources""" + + def unload(self) -> None: + """Clean up model resources.""" if self.model is not None: del self.model self.model = None - + if self.tokenizer is not None: del self.tokenizer self.tokenizer = None - + if torch.cuda.is_available(): torch.cuda.empty_cache() - + logger.info("Model unloaded and memory cleared") - + def get_model_info(self) -> Dict[str, Any]: - """Get model information""" + """Get model information.""" info = super().get_model_info() - info.update({ - 'model_path': self.model_path, - 'model_type': 'Omni Multimodal', - 'device': str(self.device), - 'dtype': str(self.dtype), - 'audio_sample_rate': self.audio_sr, - 'default_frame_limits': f'{self.default_min_frames}-{self.default_max_frames}', - 'fps_strategy': 'Adaptive (always respects max_frames)', - 'input_format': 'Interleaved frames and audio chunks' - }) - return info \ No newline at end of file + info.update( + { + "model_path": self.model_path, + "model_type": 
"Omni Multimodal", + "device": str(self.device), + "dtype": str(self.dtype), + "audio_sample_rate": self.audio_sr, + "default_frame_limits": f"{self.default_min_frames}-{self.default_max_frames}", + "fps_strategy": "Adaptive (always respects max_frames)", + "input_format": "Interleaved frames and audio chunks", + } + ) + return info diff --git a/sonic-o1/05_evaluation_inference/models/ola.py b/sonic-o1/05_evaluation_inference/models/ola.py new file mode 100644 index 0000000..d45752c --- /dev/null +++ b/sonic-o1/05_evaluation_inference/models/ola.py @@ -0,0 +1,528 @@ +"""ola.py +OLA (Omni-modal Language Assistant) implementation following BaseModel pattern. + +Author: SONIC-O1 Team +""" + +import logging +import os +import sys +from pathlib import Path +from typing import Any, Dict, Literal, Optional, Union + +import numpy as np +import torch +from decord import VideoReader, cpu +from PIL import Image +from utils.audio_processor import sample_audio_chunks + +from .base_model import BaseModel + + +logger = logging.getLogger(__name__) + + +class OLA(BaseModel): + """OLA-7b wrapper following BaseModel pattern. + + Uses our own video frame sampling and audio chunking to handle + long videos (up to 1 hour), overriding OLA's built-in 12.5-min + audio truncation and fixed 64-frame video sampling. 
+ """ + + def __init__(self, model_name: str, config: Dict[str, Any]) -> None: + super().__init__(model_name, config) + + self.model_path = config.get("model_path", "THUdyh/Ola-7b") + self.ola_repo_path = config.get("ola_repo_path") + + # Frame config + self.default_max_frames = config.get("max_frames", 64) + self.default_min_frames = config.get("min_frames", 16) + + # Audio config + self.audio_sample_rate = config.get("audio_sample_rate", 16000) + self.audio_chunk_limit = 480000 # 30s at 16kHz โ€” OLA's window size + # self.max_speech_chunks = config.get("max_speech_chunks", 25) # OLA hard cap + + # Generation config + gen_config = config.get("generation_config", {}) + self.temperature = gen_config.get("temperature", 0.2) + self.top_p = gen_config.get("top_p", None) + self.num_beams = gen_config.get("num_beams", 1) + self.max_new_tokens = gen_config.get("max_new_tokens", 1024) + + # Model components (loaded in load()) + self.tokenizer = None + self.model = None + self.image_processor = None + self.context_len = None + + # Stats + self.stats = { + "total_samples": 0, + "audio_chunks_sampled": 0, + } + def convert_av1_to_h264( + self, video_path: Path, output_dir: Optional[Path] = None + ) -> Path: + """ + Convert AV1 video to H.264 for Decord compatibility. + + Args: + video_path: Input video path + output_dir: Output directory (default: creates 'converted' subdir) + + Returns: + Path to converted video. 
+ """ + import subprocess + + if output_dir is None: + output_dir = video_path.parent / "converted" + output_dir.mkdir(exist_ok=True) + + output_path = output_dir / f"{video_path.stem}_h264{video_path.suffix}" + + # Check if already converted + if output_path.exists(): + logger.info(f"Using cached converted video: {output_path.name}") + return output_path + + logger.info(f"Converting AV1 to H.264: {video_path.name}") + + cmd = [ + "ffmpeg", + "-i", + str(video_path), + "-c:v", + "libx264", + "-preset", + "fast", + "-crf", + "23", + "-c:a", + "copy", + "-y", + str(output_path), + ] + + try: + subprocess.run(cmd, check=True, capture_output=True, text=True) + logger.info(f"โœ“ Conversion successful: {output_path.name}") + return output_path + except subprocess.CalledProcessError as e: + logger.error(f"โœ— Conversion failed: {e.stderr}") + raise RuntimeError(f"Failed to convert AV1 video: {e.stderr}") + + def _check_video_compatibility(self, video_path: Path) -> Optional[Path]: + """ + Check if video is compatible with Decord. + If AV1, automatically convert to H.264. + + Args: + video_path: Original video path. + + Returns: + Path to compatible video (original or converted). 
+ """ + import subprocess + + from decord import VideoReader, cpu + + try: + vr = VideoReader(str(video_path), ctx=cpu(0), num_threads=1) + frame_count = len(vr) + del vr + logger.info(f"โœ“ Video compatible: {video_path.name} ({frame_count} frames)") + return video_path + except Exception: + logger.warning(f"โš  Video incompatible with Decord: {video_path.name}") + + # Detect codec + try: + result = subprocess.run( + [ + "ffprobe", + "-v", + "error", + "-select_streams", + "v:0", + "-show_entries", + "stream=codec_name", + "-of", + "default=noprint_wrappers=1:nokey=1", + str(video_path), + ], + capture_output=True, + text=True, + check=True, + ) + codec = result.stdout.strip() + + if codec == "av1": + logger.info(" โ†’ Detected AV1 codec - converting to H.264...") + try: + return self.convert_av1_to_h264(video_path) + except Exception as conv_error: + logger.error(f" โœ— Conversion failed: {conv_error}") + return None + else: + logger.warning(f" โœ— Unsupported codec '{codec}' - skipping") + return None + except Exception as probe_error: + logger.error(f" โœ— Failed to detect codec: {probe_error}") + return None + + def load(self) -> None: + """Load OLA model.""" + if self.ola_repo_path: + ola_path = os.path.expanduser(self.ola_repo_path) + if ola_path not in sys.path: + sys.path.insert(0, ola_path) + logger.info(f"Added OLA repo to path: {ola_path}") + + # Set OLA's required env vars before importing + os.environ.setdefault("LOWRES_RESIZE", "384x32") + os.environ.setdefault("HIGHRES_BASE", "0x32") + os.environ.setdefault("VIDEO_RESIZE", "0x64") + os.environ.setdefault("VIDEO_MAXRES", "480") + os.environ.setdefault("VIDEO_MINRES", "288") + os.environ.setdefault("MAXRES", "1536") + os.environ.setdefault("MINRES", "0") + os.environ.setdefault("FORCE_NO_DOWNSAMPLE", "1") + os.environ.setdefault("LOAD_VISION_EARLY", "1") + os.environ.setdefault("PAD2STRIDE", "1") + + try: + from ola.model.builder import load_pretrained_model + except ImportError as e: + raise 
ImportError( + f"Failed to import OLA modules. Set 'ola_repo_path' in config.\n" + f"Error: {e}" + ) + + logger.info(f"Loading OLA from {self.model_path}") + + self.tokenizer, self.model, self.image_processor, self.context_len = ( + load_pretrained_model(self.model_path, None) + ) + self.model = self.model.to("cuda").eval().bfloat16() + + logger.info("OLA loaded successfully") + + def _sample_video_frames(self, video_path: str, max_frames: int, min_frames: int): + video_path = Path(video_path) + compatible_path = self._check_video_compatibility(video_path) + if compatible_path is None: + raise RuntimeError(f"Video codec incompatible with Decord: {video_path}") + + vreader = VideoReader(str(compatible_path), ctx=cpu(0)) + + total_frames = len(vreader) + + num_frames = max(min_frames, min(max_frames, total_frames)) + indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) + frames = vreader.get_batch(indices.tolist()).asnumpy() + return [Image.fromarray(f) for f in frames], indices.tolist() + + def _load_audio_ola( + self, + audio_path: str, + max_chunks: Optional[int] = None, + ): + """ + Load and chunk audio for OLA, overriding their 12.5-min truncation. + + Our approach: uniformly sample up to max_chunks windows from the full + audio duration, instead of taking the first N chunks. This ensures + coverage of the full video for long recordings. + + Returns: + tuple: (mels, speech_lengths, speech_chunks, speech_wavs) โ€” all on CPU. 
+ """ + import librosa + import whisper + + speech_wav, _ = librosa.load(audio_path, sr=self.audio_sample_rate) + if len(speech_wav.shape) > 1: + speech_wav = speech_wav[:, 0] + speech_wav = speech_wav.astype(np.float32) + + total_samples = len(speech_wav) + chunk_lim = self.audio_chunk_limit # 30s window + + # Build all 30s chunks + all_chunks = [] + for i in range(0, total_samples, chunk_lim): + chunk = speech_wav[i: i + chunk_lim] + chunk = whisper.pad_or_trim(chunk) + all_chunks.append(chunk) + + total_chunks = len(all_chunks) + logger.info(f"Audio: {total_samples / self.audio_sample_rate:.1f}s โ†’ " + f"{total_chunks} chunks of 30s") + + # None = no limit, let all chunks through + if max_chunks is not None and total_chunks > max_chunks: + indices = np.linspace(0, total_chunks - 1, max_chunks, dtype=int) + sampled_chunks = [all_chunks[i] for i in indices] + logger.info(f"Uniformly sampled {max_chunks}/{total_chunks} audio chunks") + self.stats["audio_chunks_sampled"] += 1 + else: + sampled_chunks = all_chunks + logger.info(f"Using all {total_chunks} audio chunks") + + # Build tensors + mels = [] + speech_wavs = [] + for chunk in sampled_chunks: + mel = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0) + mels.append(mel) + speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0)) + + mels = torch.cat(mels, dim=0) # [N, 3000, 128] + speech_wavs = torch.cat(speech_wavs, dim=0) # [N, 480000] + + speech_lengths = torch.LongTensor([mels.shape[1]] * mels.shape[0]) + speech_chunks = torch.LongTensor([mels.shape[0]]) + + return mels, speech_lengths, speech_chunks, speech_wavs + + def _process_video_frames(self, frames: list): + """Process PIL frames through OLA's image processor.""" + from ola.mm_utils import process_anyres_video + + self.image_processor.do_resize = False + self.image_processor.do_center_crop = False + + video_processed = [] + for frame in frames: + frame_tensor = process_anyres_video(frame, self.image_processor) + 
video_processed.append(frame_tensor.unsqueeze(0)) + + video_processed = torch.cat(video_processed, dim=0).bfloat16().to("cuda") + return video_processed + + def generate( + self, + frames: Union[str, Path], + audio: Optional[Union[str, Path]] = None, + prompt: str = "Describe what you see and hear.", + fps: Optional[float] = None, + video_category: Optional[Literal["short", "medium", "long"]] = None, + max_frames: Optional[int] = None, + max_audio_chunks: Optional[int] = None, + **kwargs, + ) -> str: + """Generate response from video and optional audio.""" + from ola.conversation import conv_templates, SeparatorStyle + from ola.constants import ( + DEFAULT_IMAGE_TOKEN, + DEFAULT_SPEECH_TOKEN, + IMAGE_TOKEN_INDEX, + ) + from ola.datasets.preprocess import ( + tokenizer_image_token, + tokenizer_speech_image_token, + ) + from ola.mm_utils import KeywordsStoppingCriteria + + actual_max_frames = max_frames if max_frames is not None else self.default_max_frames + temperature = kwargs.get("temperature", self.temperature) + top_p = kwargs.get("top_p", self.top_p) + num_beams = kwargs.get("num_beams", self.num_beams) + max_new_tokens = kwargs.get("max_new_tokens", self.max_new_tokens) + + try: + self.stats["total_samples"] += 1 + + # โ”€โ”€ Audio โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + has_audio = audio is not None and os.path.exists(str(audio)) + speechs, speech_lengths, speech_wavs, speech_chunks = [], [], [], [] + + if has_audio: + logger.info(f"Loading audio: {audio}") + try: + mels, s_lengths, s_chunks, s_wavs = self._load_audio_ola( + str(audio), max_chunks=max_audio_chunks + ) + speechs.append(mels.bfloat16().to("cuda")) + speech_lengths.append(s_lengths.to("cuda")) + speech_chunks.append(s_chunks.to("cuda")) + speech_wavs.append(s_wavs.to("cuda")) + logger.info(f"Audio loaded: {s_chunks[0].item()} chunks") + except Exception as e: 
+ logger.error(f"Audio processing failed: {e}") + logger.warning("Falling back to dummy audio") + has_audio = False + + if not has_audio: + # Dummy audio โ€” OLA handles this gracefully + speechs = [torch.zeros(1, 3000, 128).bfloat16().to("cuda")] + speech_lengths = [torch.LongTensor([3000]).to("cuda")] + speech_wavs = [torch.zeros(1, 480000).to("cuda")] + speech_chunks = [torch.LongTensor([1]).to("cuda")] + + # โ”€โ”€ Video โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + has_video = frames is not None and os.path.exists(str(frames)) + + if has_video: + logger.info(f"Loading video: {frames}") + pil_frames, frame_idx = self._sample_video_frames( + str(frames), + max_frames=actual_max_frames, + min_frames=self.default_min_frames, + ) + video_tensor = self._process_video_frames(pil_frames) + video_data = ( + (video_tensor, video_tensor), # (images, images_highres) + (384, 384), + "video", + ) + logger.info(f"Video frames sampled: {len(pil_frames)}") + + # Build prompt + qs = DEFAULT_SPEECH_TOKEN + DEFAULT_IMAGE_TOKEN + "\n" + prompt + + else: + # No video โ€” audio only + qs = DEFAULT_SPEECH_TOKEN + "\n" + prompt + + # โ”€โ”€ Tokenize โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + conv_mode = "qwen_1_5" + conv = conv_templates[conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + full_prompt = conv.get_prompt() + + if has_video: + input_ids = ( + tokenizer_speech_image_token( + full_prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt" + ) + .unsqueeze(0) + .to("cuda") + ) + else: + from ola.datasets.preprocess import tokenizer_speech_token + from ola.constants import SPEECH_TOKEN_INDEX + input_ids = ( + tokenizer_speech_token( + full_prompt, self.tokenizer, 
SPEECH_TOKEN_INDEX, return_tensors="pt" + ) + .unsqueeze(0) + .to("cuda") + ) + + pad_token_ids = 151643 + attention_masks = input_ids.ne(pad_token_ids).long().to("cuda") + + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + stopping_criteria = KeywordsStoppingCriteria( + [stop_str], self.tokenizer, input_ids + ) + + # โ”€โ”€ Generate โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + logger.info("Generating response...") + with torch.inference_mode(): + if has_video: + output_ids = self.model.generate( + inputs=input_ids, + images=video_data[0][0], + images_highres=video_data[0][1], + modalities=video_data[2], + speech=speechs, + speech_lengths=speech_lengths, + speech_chunks=speech_chunks, + speech_wav=speech_wavs, + attention_mask=attention_masks, + use_cache=True, + stopping_criteria=[stopping_criteria], + do_sample=temperature > 0, + temperature=temperature, + top_p=top_p, + num_beams=num_beams, + max_new_tokens=max_new_tokens, + ) + else: + dummy_images = [ + torch.zeros(1, 3, 224, 224).bfloat16().to("cuda") + ] + output_ids = self.model.generate( + inputs=input_ids, + images=dummy_images, + images_highres=dummy_images, + image_sizes=[(224, 224)], + modalities=["text"], + speech=speechs, + speech_lengths=speech_lengths, + speech_chunks=speech_chunks, + speech_wav=speech_wavs, + attention_mask=attention_masks, + use_cache=True, + stopping_criteria=[stopping_criteria], + do_sample=temperature > 0, + temperature=temperature, + top_p=top_p, + num_beams=num_beams, + max_new_tokens=max_new_tokens, + ) + + outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] + outputs = outputs.strip() + if outputs.endswith(stop_str): + outputs = outputs[: -len(stop_str)] + outputs = outputs.strip() + + logger.info(f"Generated {len(outputs)} characters") + return self.postprocess_output(outputs) + + except 
(torch.cuda.OutOfMemoryError, RuntimeError) as e: + error_msg = str(e) + if "out of memory" in error_msg.lower() or "size of tensor" in error_msg.lower(): + logger.error(f"OOM error: {error_msg[:200]}...") + torch.cuda.empty_cache() + raise RuntimeError(f"Out of memory: {e}") + logger.error(f"Generation failed: {e}") + raise RuntimeError(f"Generation failed: {e}") + + def unload(self) -> None: + """Unload model and free memory.""" + logger.info("Unloading OLA model...") + + if self.model is not None: + del self.model + self.model = None + + if self.tokenizer is not None: + del self.tokenizer + self.tokenizer = None + + if self.image_processor is not None: + del self.image_processor + self.image_processor = None + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + logger.info("OLA model unloaded") + + def get_model_info(self) -> Dict[str, Any]: + """Get model information.""" + info = super().get_model_info() + info.update( + { + "model_path": self.model_path, + "backend": "HuggingFace Transformers", + "native_video": True, + "native_audio": True, + "default_max_frames": self.default_max_frames, + "default_min_frames": self.default_min_frames, + "max_speech_chunks": self.max_speech_chunks, + "context_length": self.context_len, + "statistics": self.stats, + } + ) + return info diff --git a/sonic-o1/05_evaluation_inference/models/omnivinci.py b/sonic-o1/05_evaluation_inference/models/omnivinci.py new file mode 100644 index 0000000..7db6e6d --- /dev/null +++ b/sonic-o1/05_evaluation_inference/models/omnivinci.py @@ -0,0 +1,419 @@ +"""omnivinci.py +OmniVinci (nvidia/omnivinci) implementation following BaseModel pattern. + +Accepts separate video (.mp4) and audio (.m4a) paths, merges them into a +single mp4 via ffmpeg (stream-copy video, re-encode audio to aac), then +passes the merged file to OmniVinci's processor. + +Audio is handled entirely internally by OmniVinci โ€” no custom chunking needed. 
+Frame count is controlled via model.config / processor.config before each call +so that retry fallbacks (frame_count_fallback) take effect correctly. + +Author: SONIC-O1 Team +""" + +import logging +import subprocess +import tempfile +from pathlib import Path +from typing import Any, Dict, Literal, Optional, Union + +import torch + +from .base_model import BaseModel + + +logger = logging.getLogger(__name__) + + +class OmniVinci(BaseModel): + """ + OmniVinci wrapper following BaseModel pattern. + + - Video + audio are merged into a single mp4 before inference. + - Frame count is set on model/processor config before each call. + - Audio length is capped at max_3600 (1 hour) at load time โ€” no chunking. + - No repo path injection needed (pure HF trust_remote_code model). + """ + + def __init__(self, model_name: str, config: Dict[str, Any]) -> None: + super().__init__(model_name, config) + + self.model_path = config.get("model_path", "nvidia/omnivinci") + + # modeling_vila.py does eval(torch_dtype) so must be a torch.* object + self.device_map = config.get("device_map", "auto") + dtype_cfg = config.get("torch_dtype", "float16") + self.torch_dtype = getattr(torch, dtype_cfg) if isinstance(dtype_cfg, str) else dtype_cfg + + # Frame limits + self.default_max_frames = config.get("max_frames", 128) + self.default_min_frames = config.get("min_frames", 8) + + # Generation config + gen_config = config.get("generation_config", {}) + self.temperature = gen_config.get("temperature", 0.7) + self.top_p = gen_config.get("top_p", 0.95) + self.max_new_tokens = gen_config.get("max_new_tokens", 8192) + + self.model = None + self.processor = None + self.generation_config = None + self._temp_dir = None + + # ------------------------------------------------------------------ # + # Load / Unload # + # ------------------------------------------------------------------ # + + def load(self) -> None: + """Load OmniVinci model and processor.""" + from transformers import AutoModel, 
AutoProcessor + + logger.info(f"Loading OmniVinci from {self.model_path}") + + self.model = AutoModel.from_pretrained( + self.model_path, + trust_remote_code=True, + torch_dtype=self.torch_dtype, + device_map=self.device_map, + ) + self.model = self.model.to("cuda") + self.model.eval() + + self.processor = AutoProcessor.from_pretrained( + self.model_path, + trust_remote_code=True, + ) + + self.generation_config = self.model.default_generation_config + self.generation_config.update( + max_new_tokens=self.max_new_tokens, + max_length=99999999, + ) + + # Audio: trust OmniVinci's internal pipeline, cap at 1 hour + self.model.config.load_audio_in_video = True + self.processor.config.load_audio_in_video = True + self.model.config.audio_chunk_length = "max_3600" + self.processor.config.audio_chunk_length = "max_3600" + + self._temp_dir = tempfile.mkdtemp(prefix="omnivinci_merged_") + logger.info(f"OmniVinci loaded successfully (device_map={self.device_map})") + + def unload(self) -> None: + """Unload model and free memory.""" + logger.info("Unloading OmniVinci model...") + + if self.model is not None: + del self.model + self.model = None + if self.processor is not None: + del self.processor + self.processor = None + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + if self._temp_dir and Path(self._temp_dir).exists(): + import shutil + shutil.rmtree(self._temp_dir, ignore_errors=True) + logger.info(f"Cleaned up temp dir: {self._temp_dir}") + + logger.info("OmniVinci unloaded") + + # ------------------------------------------------------------------ # + # Video helpers # + # ------------------------------------------------------------------ # + + def convert_av1_to_h264(self, video_path: Path, output_dir: Optional[Path] = None) -> Path: + """Convert AV1 video to H.264 for decoder compatibility.""" + if output_dir is None: + output_dir = Path(self._temp_dir) / "converted" + output_dir.mkdir(parents=True, exist_ok=True) + + output_path 
= output_dir / f"{video_path.stem}_h264{video_path.suffix}" + if output_path.exists(): + logger.info(f"Using cached converted video: {output_path.name}") + return output_path + + logger.info(f"Converting AV1 to H.264: {video_path.name}") + cmd = [ + "ffmpeg", "-i", str(video_path), + "-c:v", "libx264", "-preset", "fast", "-crf", "23", + "-c:a", "copy", "-y", str(output_path), + ] + try: + subprocess.run(cmd, check=True, capture_output=True, text=True) + logger.info(f"โœ“ Conversion successful: {output_path.name}") + return output_path + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Failed to convert AV1 video: {e.stderr[-200:]}") + + def _ensure_video_compatibility(self, video_path: Path) -> Path: + """Convert AV1 โ†’ H.264 if needed; return compatible path.""" + try: + result = subprocess.run( + [ + "ffprobe", "-v", "error", "-select_streams", "v:0", + "-show_entries", "stream=codec_name", + "-of", "default=noprint_wrappers=1:nokey=1", + str(video_path), + ], + check=True, capture_output=True, text=True, + ) + codec = result.stdout.strip().lower() + except subprocess.CalledProcessError as e: + logger.warning(f"Could not detect codec for {video_path.name}: {e}") + return video_path + + if codec == "av1": + logger.info("Detected AV1 codec โ€” converting to H.264 before inference") + return self.convert_av1_to_h264(video_path) + return video_path + + def _has_usable_audio_stream(self, audio_path: Path) -> bool: + """ + True if ffprobe finds at least one audio stream with positive duration. + + Segments can be saved as .m4a with no audio track (silent source, + failed extract, or empty mux) โ€” ffmpeg then errors on -map 1:a:0. + Some broken extracts report a container but Duration 00:00:00.00; + we require format duration > 0 as well. 
+ """ + try: + if not audio_path.is_file() or audio_path.stat().st_size == 0: + return False + except OSError: + return False + r = subprocess.run( + [ + "ffprobe", + "-v", "error", + "-select_streams", "a", + "-show_entries", "stream=index", + "-of", "csv=p=0", + str(audio_path), + ], + capture_output=True, + text=True, + ) + if r.returncode != 0 or not r.stdout.strip(): + return False + dur = subprocess.run( + [ + "ffprobe", + "-v", "error", + "-show_entries", "format=duration", + "-of", "default=noprint_wrappers=1:nokey=1", + str(audio_path), + ], + capture_output=True, + text=True, + ) + if dur.returncode != 0: + return False + try: + if float((dur.stdout.strip() or "0").split()[0]) <= 0: + return False + except (ValueError, IndexError): + return False + return True + + def _merge_video_audio(self, video_path: Path, audio_path: Path) -> Path: + """ + Merge video + audio into a single mp4 (cached across retries). + + On any ffmpeg failure (no audio stream, codec issue, corrupt segment, etc.), + logs a warning and returns ``video_path`` so inference runs video-only. + """ + output_path = Path(self._temp_dir) / f"merged_{video_path.stem}.mp4" + if output_path.exists(): + logger.info(f"Using cached merged file: {output_path.name}") + return output_path + + logger.info(f"Merging {video_path.name} + {audio_path.name} โ†’ {output_path.name}") + cmd = [ + "ffmpeg", + "-i", str(video_path), "-i", str(audio_path), + "-c:v", "copy", "-c:a", "aac", + "-map", "0:v:0", "-map", "1:a:0", + "-y", str(output_path), + ] + try: + subprocess.run(cmd, check=True, capture_output=True, text=True) + logger.info(f"โœ“ Merge successful: {output_path.name}") + return output_path + except subprocess.CalledProcessError as e: + err = (e.stderr or "") + (e.stdout or "") + if output_path.exists(): + try: + output_path.unlink() + except OSError: + pass + tail = (err[-400:] if err else str(e))[:400] + logger.warning( + f"ffmpeg merge failed ({audio_path.name}); using video without merged audio. 
" + f"Last output: {tail!r}" + ) + return video_path + + # ------------------------------------------------------------------ # + # Inference helpers # + # ------------------------------------------------------------------ # + + def _set_frame_count(self, num_frames: int) -> None: + """Must be called before processor([text]) so retry fallback takes effect.""" + self.model.config.num_video_frames = num_frames + self.processor.config.num_video_frames = num_frames + logger.info(f"num_video_frames set to: {num_frames}") + + def _inference_device(self) -> torch.device: + return next(self.model.parameters()).device + + @staticmethod + def _move_nested_to_device(obj: Any, device: torch.device) -> Any: + """Recursively move tensors in nested dicts/lists/tuples to device.""" + if isinstance(obj, torch.Tensor): + return obj.to(device) + if isinstance(obj, dict): + return {k: OmniVinci._move_nested_to_device(v, device) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + moved = [OmniVinci._move_nested_to_device(x, device) for x in obj] + return type(obj)(moved) + return obj + + @staticmethod + def _patch_media(media: Optional[Any]) -> Optional[Any]: + """ + modeling_vila.__embed_media_tokens unconditionally accesses + media['speech'] regardless of unified_audio_encoder flag. + When audio comes from video the processor only fills 'sound', + leaving 'speech' absent โ†’ KeyError: 'speech'. + Fix: ensure every expected key exists with an empty fallback. 
+ """ + if not isinstance(media, dict): + return media + + patched = dict(media) + + # Keys VILA's embed loop iterates over โ€” ensure they exist + for key in ("speech", "sound", "image", "video", "vision", "frames", "audio"): + if key not in patched: + patched[key] = [] + + logger.info(f"media keys after patch: {list(patched.keys())}") + return patched + + # ------------------------------------------------------------------ # + # generate # + # ------------------------------------------------------------------ # + + def generate( + self, + frames: Union[str, Path], + audio: Optional[Union[str, Path]] = None, + prompt: str = "", + fps: Optional[float] = None, + video_category: Optional[Literal["short", "medium", "long"]] = None, + max_frames: Optional[int] = None, + max_audio_chunks: Optional[int] = None, # ignored โ€” audio handled internally + **kwargs, + ) -> str: + if self.model is None or self.processor is None: + raise RuntimeError("Model not loaded. Call load() first.") + + actual_max_frames = max_frames if max_frames is not None else self.default_max_frames + + video_path = Path(frames) + if not video_path.exists(): + raise FileNotFoundError(f"Video not found: {video_path}") + + try: + compatible_video = self._ensure_video_compatibility(video_path) + + if audio is not None: + audio_path = Path(audio) + if not audio_path.exists(): + logger.warning(f"Audio not found: {audio_path} โ€” proceeding video-only") + input_video = compatible_video + elif not self._has_usable_audio_stream(audio_path): + logger.warning( + f"No audio stream in {audio_path.name} (empty/silent extract) โ€” " + "proceeding video-only" + ) + input_video = compatible_video + else: + input_video = self._merge_video_audio(compatible_video, audio_path) + else: + logger.warning("No audio provided โ€” proceeding video-only") + input_video = compatible_video + + # Must be set BEFORE processor call + self._set_frame_count(actual_max_frames) + logger.info(f"Running OmniVinci on: {input_video.name} 
(frames={actual_max_frames})") + + conversation = [{ + "role": "user", + "content": [ + {"type": "video", "video": str(input_video)}, + {"type": "text", "text": prompt}, + ], + }] + + text = self.processor.apply_chat_template( + conversation, tokenize=False, add_generation_prompt=True, + ) + inputs = self.processor([text]) + + # Processor returns input_ids on CPU; model weights are on cuda. + # media / media_config are handled internally by model.generate. + inputs.input_ids = inputs.input_ids.to("cuda") + + max_new_tokens = kwargs.get("max_new_tokens", self.max_new_tokens) + self.generation_config.update(max_new_tokens=max_new_tokens) + + logger.info(f"Generating (max_tokens={max_new_tokens})...") + + with torch.inference_mode(): + output_ids = self.model.generate( + input_ids=inputs.input_ids, + media=getattr(inputs, "media", None), + media_config=getattr(inputs, "media_config", None), + generation_config=self.generation_config, + ) + + response = self.processor.tokenizer.batch_decode( + output_ids, skip_special_tokens=True, + )[0].strip() + + logger.info(f"Generated response ({len(response)} chars)") + return self.postprocess_output(response) + + except (torch.cuda.OutOfMemoryError, RuntimeError) as e: + error_msg = str(e) + if "out of memory" in error_msg.lower(): + logger.error(f"OOM: {error_msg[:200]}") + torch.cuda.empty_cache() + raise RuntimeError(f"Out of memory: {e}") + logger.error(f"Generation failed: {e}") + raise RuntimeError(f"Generation failed: {e}") + + # ------------------------------------------------------------------ # + # Info # + # ------------------------------------------------------------------ # + + def get_model_info(self) -> Dict[str, Any]: + info = super().get_model_info() + info.update( + { + "model_path": self.model_path, + "model_type": "Omni Multimodal (VILAForCausalLM)", + "device_map": self.device_map, + "torch_dtype": str(self.torch_dtype), + "default_max_frames": self.default_max_frames, + "audio_handling": "Internal 
(max_3600) โ€” no custom chunking", + "input_format": "Merged mp4 (ffmpeg stream-copy video + aac audio)", + } + ) + return info \ No newline at end of file diff --git a/sonic-o1/05_evaluation_inference/models/phi4.py b/sonic-o1/05_evaluation_inference/models/phi4.py index d147312..d882a58 100644 --- a/sonic-o1/05_evaluation_inference/models/phi4.py +++ b/sonic-o1/05_evaluation_inference/models/phi4.py @@ -1,87 +1,90 @@ -""" -models/phi4.py - -Phi-4 Multimodal implementation following BaseModel pattern. -Self-contained with all Phi-4-specific logic. +"""phi4.py +Phi-4 Multimodal implementation (microsoft/Phi-4-multimodal-instruct). -Based on microsoft/Phi-4-multimodal-instruct -https://huggingface.co/microsoft/Phi-4-multimodal-instruct +Author: SONIC-O1 Team """ -import os import logging -import torch -import numpy as np +import os from pathlib import Path -from typing import Optional, Dict, Any, Union, Literal, List, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import numpy as np +import torch from PIL import Image from .base_model import BaseModel + logger = logging.getLogger(__name__) class Phi4(BaseModel): - """Phi-4 Multimodal wrapper following BaseModel pattern""" - - def __init__(self, model_name: str, config: Dict[str, Any]): + """Phi-4 Multimodal wrapper following BaseModel pattern.""" + + def __init__(self, model_name: str, config: Dict[str, Any]) -> None: super().__init__(model_name, config) - - self.model_path = config.get('model_path', 'microsoft/Phi-4-multimodal-instruct') - + + self.model_path = config.get( + "model_path", "microsoft/Phi-4-multimodal-instruct" + ) + # Device config - self.device_map = config.get('device_map', 'auto') - self.torch_dtype = config.get('torch_dtype', 'bfloat16') - self.attn_implementation = config.get('attn_implementation', 'eager') - self.trust_remote_code = config.get('trust_remote_code', True) - + self.device_map = config.get("device_map", "auto") + self.torch_dtype = 
config.get("torch_dtype", "bfloat16") + self.attn_implementation = config.get("attn_implementation", "eager") + self.trust_remote_code = config.get("trust_remote_code", True) + # Frame config - self.default_max_frames = config.get('max_frames', 256) - self.default_min_frames = config.get('min_frames', 64) - + self.default_max_frames = config.get("max_frames", 256) + self.default_min_frames = config.get("min_frames", 64) + # Audio config - self.audio_sample_rate = config.get('audio_sample_rate', 16000) - self.audio_mono = config.get('audio_mono', True) - + self.audio_sample_rate = config.get("audio_sample_rate", 16000) + self.audio_mono = config.get("audio_mono", True) + # Generation config - gen_config = config.get('generation_config', {}) - self.temperature = gen_config.get('temperature', 0.7) - self.top_p = gen_config.get('top_p', 0.95) - self.max_new_tokens = gen_config.get('max_new_tokens', 8192) - + gen_config = config.get("generation_config", {}) + self.temperature = gen_config.get("temperature", 0.7) + self.top_p = gen_config.get("top_p", 0.95) + self.max_new_tokens = gen_config.get("max_new_tokens", 8192) + # Model components (loaded in load()) self.processor = None self.model = None self.generation_config = None - + # Stats self.stats = { - 'total_samples': 0, - 'audio_chunks_sampled': 0, - 'avg_frames_per_sample': 0, - 'total_frames_processed': 0, + "total_samples": 0, + "audio_chunks_sampled": 0, + "avg_frames_per_sample": 0, + "total_frames_processed": 0, } - - def load(self): - """Load Phi-4 model and processor""" + + def load(self) -> None: + """Load Phi-4 model and processor.""" try: - from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig + from transformers import ( + AutoModelForCausalLM, + AutoProcessor, + GenerationConfig, + ) except ImportError as e: raise ImportError( f"Failed to import transformers. 
Please install: pip install transformers\n" f"Error: {e}" ) - + logger.info(f"Loading Phi-4 from {self.model_path}") logger.info(f"Using device_map={self.device_map}, dtype={self.torch_dtype}") logger.info(f"Found {torch.cuda.device_count()} GPUs") - + # Load processor self.processor = AutoProcessor.from_pretrained( - self.model_path, - trust_remote_code=self.trust_remote_code + self.model_path, trust_remote_code=self.trust_remote_code ) - + # Determine torch dtype if self.torch_dtype == "auto": torch_dtype = "auto" @@ -94,7 +97,7 @@ def load(self): else: torch_dtype = "auto" logger.warning(f"Unknown dtype {self.torch_dtype}, using 'auto'") - + # Load model self.model = AutoModelForCausalLM.from_pretrained( self.model_path, @@ -103,35 +106,33 @@ def load(self): trust_remote_code=self.trust_remote_code, _attn_implementation=self.attn_implementation, ) - + # Load generation config self.generation_config = GenerationConfig.from_pretrained(self.model_path) - + # Log device distribution if hasattr(self.model, "hf_device_map"): logger.info("=== Device Map ===") device_counts = {} - for module, device in self.model.hf_device_map.items(): + for _module, device in self.model.hf_device_map.items(): device_counts[device] = device_counts.get(device, 0) + 1 for device, count in sorted(device_counts.items(), key=lambda x: str(x[0])): logger.info(f" {device}: {count} modules") logger.info("==================") - + self.model.eval() logger.info("Phi-4 loaded successfully") - + def _extract_video_frames( - self, - video_path: Union[str, Path], - max_frames: int + self, video_path: Union[str, Path], max_frames: int ) -> List[Image.Image]: """ Extract frames from video using OpenCV with uniform sampling. - + Args: video_path: Path to video file max_frames: Maximum number of frames to extract - + Returns: List of PIL Images """ @@ -142,23 +143,23 @@ def _extract_video_frames( "OpenCV required for frame extraction. 
" "Install with: pip install opencv-python" ) - + video_path = str(video_path) - + # Open video cap = cv2.VideoCapture(video_path) if not cap.isOpened(): raise ValueError(f"Could not open video: {video_path}") - + # Get video info total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) fps = cap.get(cv2.CAP_PROP_FPS) duration = total_frames / fps if fps > 0 else 0 - + logger.info( f"Video info: {total_frames} frames, {fps:.2f} fps, {duration:.2f}s duration" ) - + # Calculate frame indices to extract (uniform sampling) if total_frames <= max_frames: frame_indices = list(range(total_frames)) @@ -166,13 +167,15 @@ def _extract_video_frames( frame_indices = np.linspace( 0, total_frames - 1, max_frames, dtype=int ).tolist() - + # Ensure minimum frames if len(frame_indices) < self.default_min_frames: # For very short videos, repeat frames if necessary - repeat_count = (self.default_min_frames + len(frame_indices) - 1) // len(frame_indices) - frame_indices = (frame_indices * repeat_count)[:self.default_min_frames] - + repeat_count = (self.default_min_frames + len(frame_indices) - 1) // len( + frame_indices + ) + frame_indices = (frame_indices * repeat_count)[: self.default_min_frames] + # Extract frames frames = [] for idx in frame_indices: @@ -186,26 +189,24 @@ def _extract_video_frames( frames.append(pil_image) else: logger.warning(f"Failed to read frame at index {idx}") - + cap.release() - + if not frames: raise ValueError(f"Could not extract any frames from video: {video_path}") - + logger.info(f"Extracted {len(frames)} frames from video") return frames - - def _load_audio( - self, - audio_path: Union[str, Path] - ) -> Tuple[np.ndarray, int]: + + def _load_audio(self, audio_path: Union[str, Path]) -> Tuple[np.ndarray, int]: """ Load audio from file with automatic format detection. + Supports m4a, mp3, wav, flac, etc. 
- + Args: audio_path: Path to audio file - + Returns: Tuple of (audio_array, sample_rate) """ @@ -213,79 +214,76 @@ def _load_audio( import librosa except ImportError: raise ImportError( - "librosa required for audio loading. " - "Install with: pip install librosa" + "librosa required for audio loading. Install with: pip install librosa" ) - + audio_path = str(audio_path) - + logger.info(f"Loading audio: {audio_path}") - + # Load audio with librosa (handles all formats via ffmpeg) audio, sr = librosa.load( - audio_path, - sr=self.audio_sample_rate, - mono=self.audio_mono + audio_path, sr=self.audio_sample_rate, mono=self.audio_mono ) - + duration = len(audio) / sr logger.info(f"Loaded audio: {duration:.2f}s @ {sr}Hz, mono={self.audio_mono}") - + return audio, sr - + def _chunk_audio( self, audio_array: np.ndarray, sample_rate: int, max_chunks: Optional[int], - chunk_duration_sec: float = 10.0 + chunk_duration_sec: float = 10.0, ) -> Tuple[np.ndarray, int]: """ Chunk audio to maximum duration if max_chunks is specified. 
- + Args: audio_array: Audio waveform sample_rate: Sample rate max_chunks: Maximum number of chunks (None = no chunking) chunk_duration_sec: Duration of each chunk in seconds - + Returns: Tuple of (chunked_audio, sample_rate) """ if max_chunks is None: return audio_array, sample_rate - + # Calculate max samples max_samples = int(max_chunks * chunk_duration_sec * sample_rate) - + if len(audio_array) > max_samples: original_duration = len(audio_array) / sample_rate chunked_duration = max_samples / sample_rate - + logger.info( f"Chunking audio: {original_duration:.2f}s -> {chunked_duration:.2f}s " f"({max_chunks} chunks ร— {chunk_duration_sec}s)" ) - - self.stats['audio_chunks_sampled'] += 1 + + self.stats["audio_chunks_sampled"] += 1 return audio_array[:max_samples], sample_rate - + return audio_array, sample_rate - + def generate( self, frames: Union[List[np.ndarray], np.ndarray, str], audio: Optional[Union[np.ndarray, str]], prompt: str, fps: Optional[float] = None, - video_category: Optional[Literal['short', 'medium', 'long']] = None, + video_category: Optional[Literal["short", "medium", "long"]] = None, max_frames: Optional[int] = None, max_audio_chunks: Optional[int] = None, - **kwargs + **kwargs, ) -> str: """ Generate response from video frames and audio. - + Args: frames: Video file path (str) - we extract frames externally audio: Audio file path (str) or None @@ -295,95 +293,101 @@ def generate( max_frames: Maximum frames to extract (set by external retry) max_audio_chunks: Maximum audio chunks (set by external retry) **kwargs: Additional generation parameters - + Returns: Generated text response """ if self.model is None or self.processor is None: raise RuntimeError("Model not loaded. 
Call load() first.") - + if not isinstance(frames, str): raise ValueError( f"Phi-4 requires video file path (str), got {type(frames)}" ) - + video_path = Path(frames) if not video_path.exists(): raise FileNotFoundError(f"Video file not found: {video_path}") - + # Use external max_frames or default - actual_max_frames = max_frames if max_frames is not None else self.default_max_frames - + actual_max_frames = ( + max_frames if max_frames is not None else self.default_max_frames + ) + try: logger.info( f"Processing: max_frames={actual_max_frames}, " f"max_audio_chunks={max_audio_chunks}" ) - + # 1. Extract video frames video_frames = self._extract_video_frames(video_path, actual_max_frames) num_frames = len(video_frames) - + # Track stats - self.stats['total_samples'] += 1 - self.stats['total_frames_processed'] += num_frames - self.stats['avg_frames_per_sample'] = ( - self.stats['total_frames_processed'] / self.stats['total_samples'] + self.stats["total_samples"] += 1 + self.stats["total_frames_processed"] += num_frames + self.stats["avg_frames_per_sample"] = ( + self.stats["total_frames_processed"] / self.stats["total_samples"] ) - + # 2. Load and chunk audio if provided - has_audio = audio is not None and isinstance(audio, str) and os.path.exists(audio) - + has_audio = ( + audio is not None and isinstance(audio, str) and os.path.exists(audio) + ) + if has_audio: audio_array, sr = self._load_audio(audio) - + # Chunk audio if needed audio_array, sr = self._chunk_audio( audio_array, sr, max_audio_chunks, - kwargs.get('audio_chunk_duration_sec', 10.0) + kwargs.get("audio_chunk_duration_sec", 10.0), ) - + audio_input = [(audio_array, sr)] else: audio_input = None logger.info("No audio provided") - + # 3. 
Build Phi-4 prompt with special tokens - user_prompt = '<|user|>' - assistant_prompt = '<|assistant|>' - prompt_suffix = '<|end|>' - + user_prompt = "<|user|>" + assistant_prompt = "<|assistant|>" + prompt_suffix = "<|end|>" + # Build image placeholders: <|image_1|><|image_2|>...<|image_n|> - image_placeholders = ''.join([f'<|image_{i+1}|>' for i in range(num_frames)]) - + image_placeholders = "".join( + [f"<|image_{i + 1}|>" for i in range(num_frames)] + ) + # Build full prompt based on modality if has_audio: # Video + Audio format full_prompt = ( - f'{user_prompt}{image_placeholders}<|audio_1|>' - f'{prompt}{prompt_suffix}{assistant_prompt}' + f"{user_prompt}{image_placeholders}<|audio_1|>" + f"{prompt}{prompt_suffix}{assistant_prompt}" ) else: # Video only format full_prompt = ( - f'{user_prompt}{image_placeholders}' - f'{prompt}{prompt_suffix}{assistant_prompt}' + f"{user_prompt}{image_placeholders}" + f"{prompt}{prompt_suffix}{assistant_prompt}" ) - + logger.info(f"Prompt: {full_prompt[:200]}...") - + # 4. Process inputs processor_inputs = { - 'text': full_prompt, - 'images': video_frames, - 'return_tensors': 'pt' + "text": full_prompt, + "images": video_frames, + "return_tensors": "pt", } - + if has_audio: - processor_inputs['audios'] = audio_input - + processor_inputs["audios"] = audio_input + inputs = self.processor(**processor_inputs) device = next(self.model.parameters()).device clean = {} @@ -396,12 +400,12 @@ def generate( clean[k] = v inputs = clean # Get generation parameters - temperature = kwargs.get('temperature', self.temperature) - top_p = kwargs.get('top_p', self.top_p) - max_new_tokens = kwargs.get('max_new_tokens', self.max_new_tokens) - + temperature = kwargs.get("temperature", self.temperature) + top_p = kwargs.get("top_p", self.top_p) + max_new_tokens = kwargs.get("max_new_tokens", self.max_new_tokens) + logger.info("Generating response...") - + # 5. 
Generate with torch.no_grad(): output_ids = self.model.generate( @@ -409,111 +413,124 @@ def generate( max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, - use_cache=False, + use_cache=False, generation_config=self.generation_config, ) - + # 6. Decode (skip input tokens) - input_length = inputs['input_ids'].shape[1] + input_length = inputs["input_ids"].shape[1] output_ids = output_ids[:, input_length:] - + response = self.processor.batch_decode( - output_ids, - skip_special_tokens=True, - clean_up_tokenization_spaces=False + output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] - + logger.info(f"Generated response ({len(response)} chars)") - + return self.postprocess_output(response) - + except (torch.cuda.OutOfMemoryError, RuntimeError) as e: error_msg = str(e) - + # Check for OOM - if "out of memory" in error_msg.lower() or "CUDA out of memory" in error_msg: + if ( + "out of memory" in error_msg.lower() + or "CUDA out of memory" in error_msg + ): logger.error(f"OOM error: {error_msg[:200]}...") self._clear_memory() if torch.cuda.is_available(): torch.cuda.empty_cache() raise RuntimeError(f"Out of memory: {e}") - + # Check for CUDA errors - elif "CUDA" in error_msg or "device" in error_msg: + if "CUDA" in error_msg or "device" in error_msg: logger.error(f"CUDA error: {error_msg[:200]}...") raise RuntimeError(f"CUDA error: {e}") - + # Check for context length errors - elif any(keyword in error_msg.lower() for keyword in [ - 'context', 'token', 'length', 'limit', 'maximum', 'exceed' - ]): + if any( + keyword in error_msg.lower() + for keyword in [ + "context", + "token", + "length", + "limit", + "maximum", + "exceed", + ] + ): logger.error(f"Context length error: {error_msg[:200]}...") raise RuntimeError(f"Context length exceeded: {e}") - - else: - logger.error(f"Generation failed: {e}", exc_info=True) - raise RuntimeError(f"Generation failed: {e}") - + + logger.error(f"Generation failed: {e}", exc_info=True) + raise 
RuntimeError(f"Generation failed: {e}") + except Exception as e: logger.error(f"Unexpected error: {e}", exc_info=True) raise RuntimeError(f"Generation failed: {e}") - - def unload(self): - """Unload model and free memory""" + + def unload(self) -> None: + """Unload model and free memory.""" logger.info("Unloading Phi-4 model...") - + if self.model is not None: del self.model self.model = None - + if self.processor is not None: del self.processor self.processor = None - + if self.generation_config is not None: del self.generation_config self.generation_config = None - + # Clear CUDA cache if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() - + logger.info("Phi-4 model unloaded") - + def get_model_info(self) -> Dict[str, Any]: - """Get model information""" + """Get model information.""" info = super().get_model_info() - info.update({ - 'model_path': self.model_path, - 'backend': 'HuggingFace Transformers', - 'native_video': False, # We extract frames externally - 'native_audio': True, - 'device_map': self.device_map, - 'torch_dtype': self.torch_dtype, - 'attn_implementation': self.attn_implementation, - 'default_max_frames': self.default_max_frames, - 'default_min_frames': self.default_min_frames, - 'audio_sample_rate': self.audio_sample_rate, - 'audio_mono': self.audio_mono, - 'statistics': self.stats, - }) + info.update( + { + "model_path": self.model_path, + "backend": "HuggingFace Transformers", + "native_video": False, # We extract frames externally + "native_audio": True, + "device_map": self.device_map, + "torch_dtype": self.torch_dtype, + "attn_implementation": self.attn_implementation, + "default_max_frames": self.default_max_frames, + "default_min_frames": self.default_min_frames, + "audio_sample_rate": self.audio_sample_rate, + "audio_mono": self.audio_mono, + "statistics": self.stats, + } + ) return info + def _clear_memory(self): - """Aggressively clear GPU memory""" + """Aggressively clear GPU memory.""" # Clear gradients if 
self.model is not None: self.model.zero_grad(set_to_none=True) - + # Clear all cached tensors import gc + gc.collect() - + if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() # Clear IPC cache torch.cuda.ipc_collect() + def get_statistics(self) -> Dict[str, Any]: - """Get processing statistics""" - return self.stats \ No newline at end of file + """Get processing statistics.""" + return self.stats diff --git a/sonic-o1/05_evaluation_inference/models/qwen3.py b/sonic-o1/05_evaluation_inference/models/qwen3.py index f7c95e9..16abb87 100644 --- a/sonic-o1/05_evaluation_inference/models/qwen3.py +++ b/sonic-o1/05_evaluation_inference/models/qwen3.py @@ -1,22 +1,26 @@ -""" -models/qwen3_omni.py +"""qwen3.py Qwen3-Omni implementation with vLLM for efficient inference. + +Author: SONIC-O1 Team """ -import os -import logging -import time -import shutil + import gc +import logging import multiprocessing -from typing import Optional, Dict, Any, Union +import os +import shutil +import time from pathlib import Path +from typing import Any, Dict, Optional + import torch import torch.distributed as dist + try: - from vllm import LLM, SamplingParams from transformers import Qwen3OmniMoeProcessor from utils import process_mm_info + from vllm import LLM, SamplingParams except ImportError as e: raise ImportError( f"Please install required packages: {e}\n" @@ -27,106 +31,113 @@ logger = logging.getLogger(__name__) + class Qwen3Omni(BaseModel): """ Qwen3-Omni wrapper with vLLM for efficient multi-GPU inference. + Supports both Instruct and Thinking variants with audio chunking. 
""" - + AUDIO_TOKENS_PER_SEC = 25 VIDEO_TOKENS_PER_FRAME = 250 - - def __init__(self, model_name: str, config: Dict[str, Any]): + + def __init__(self, model_name: str, config: Dict[str, Any]) -> None: super().__init__(model_name, config) - - self.model_path = config.get('model_path', 'Qwen/Qwen3-Omni-30B-A3B-Instruct') - self.use_thinking = config.get('use_thinking', False) - - self.gpu_memory_utilization = config.get('gpu_memory_utilization', 0.85) - self.tensor_parallel_size = config.get('tensor_parallel_size', torch.cuda.device_count()) - self.max_num_seqs = config.get('max_num_seqs', 1) - self.max_model_len = config.get('max_model_len', 65536) - - gen_config = config.get('generation_config', {}) - self.temperature = gen_config.get('temperature', 0.0) - self.top_p = gen_config.get('top_p', 0.95) - self.top_k = gen_config.get('top_k', 20) - self.max_tokens = gen_config.get('max_new_tokens', 8192) - - self.default_max_frames = config.get('max_frames', 256) - self.default_min_frames = config.get('min_frames', 64) - + + self.model_path = config.get("model_path", "Qwen/Qwen3-Omni-30B-A3B-Instruct") + self.use_thinking = config.get("use_thinking", False) + + self.gpu_memory_utilization = config.get("gpu_memory_utilization", 0.85) + self.tensor_parallel_size = config.get( + "tensor_parallel_size", torch.cuda.device_count() + ) + self.max_num_seqs = config.get("max_num_seqs", 1) + self.max_model_len = config.get("max_model_len", 65536) + + gen_config = config.get("generation_config", {}) + self.temperature = gen_config.get("temperature", 0.0) + self.top_p = gen_config.get("top_p", 0.95) + self.top_k = gen_config.get("top_k", 20) + self.max_tokens = gen_config.get("max_new_tokens", 8192) + + self.default_max_frames = config.get("max_frames", 256) + self.default_min_frames = config.get("min_frames", 64) + # Audio config self.audio_feature_rate = 25.0 # Qwen3 uses 25 tokens/sec for audio - - self.limit_mm_per_prompt = config.get('limit_mm_per_prompt', { - 'image': 1, - 
'video': 1, - 'audio': 1 - }) - + + self.limit_mm_per_prompt = config.get( + "limit_mm_per_prompt", {"image": 1, "video": 1, "audio": 1} + ) + self.llm = None self.processor = None - + self.stats = { - 'total_samples': 0, - 'audio_chunks_sampled': 0, + "total_samples": 0, + "audio_chunks_sampled": 0, } - + def _clear_vllm_cache(self): - vllm_cache = Path(os.environ.get('VLLM_CACHE_ROOT', Path.home() / '.cache/vllm')) - mm_cache = vllm_cache / 'multimodal_cache' + vllm_cache = Path( + os.environ.get("VLLM_CACHE_ROOT", Path.home() / ".cache/vllm") + ) + mm_cache = vllm_cache / "multimodal_cache" if mm_cache.exists(): try: shutil.rmtree(mm_cache, ignore_errors=True) logger.debug(f"Cleared multimodal cache: {mm_cache}") except Exception as e: logger.warning(f"Failed to clear cache: {e}") - + def _is_engine_alive(self) -> bool: if self.llm is None: return False - + try: - test_output = self.llm.generate([{ - 'prompt': 'test', - 'multi_modal_data': {} - }], SamplingParams(max_tokens=1)) + self.llm.generate( + [{"prompt": "test", "multi_modal_data": {}}], + SamplingParams(max_tokens=1), + ) return True except Exception: return False - + def _reload_engine(self): logger.warning("Engine crashed, attempting reload") - + try: self.unload() except Exception as e: logger.warning(f"Error during unload: {e}") - + self._clear_vllm_cache() time.sleep(15) - + try: self.load() logger.info("Engine reloaded successfully") except Exception as e: logger.error(f"Failed to reload engine: {e}") raise RuntimeError(f"Could not recover from engine crash: {e}") - - def load(self): + + def load(self) -> None: + """Load the Qwen3 model with vLLM.""" try: self._clear_vllm_cache() - - os.environ['VLLM_USE_V1'] = '0' - os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + + os.environ["VLLM_USE_V1"] = "0" + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" if self.max_model_len > 65536: - os.environ['VLLM_ALLOW_LONG_MAX_MODEL_LEN'] = '1' - + os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1" + 
logger.info(f"Loading Qwen3-Omni model from {self.model_path} with vLLM") - logger.info(f"Using {self.tensor_parallel_size} GPUs for tensor parallelism") + logger.info( + f"Using {self.tensor_parallel_size} GPUs for tensor parallelism" + ) logger.info(f"Context length: {self.max_model_len} tokens") - + self.llm = LLM( model=self.model_path, trust_remote_code=True, @@ -141,14 +152,16 @@ def load(self): enable_prefix_caching=False, mm_processor_kwargs={"cache_gb": 0}, ) - + self.processor = Qwen3OmniMoeProcessor.from_pretrained(self.model_path) - - logger.info(f"Successfully loaded Qwen3-Omni with vLLM ({'Thinking' if self.use_thinking else 'Instruct'} mode)") - + + logger.info( + f"Successfully loaded Qwen3-Omni with vLLM ({'Thinking' if self.use_thinking else 'Instruct'} mode)" + ) + except Exception as e: raise RuntimeError(f"Failed to load Qwen3-Omni model with vLLM: {e}") - + def generate( self, frames: Optional[str], @@ -158,12 +171,13 @@ def generate( video_category: Optional[str] = None, max_frames: Optional[int] = None, max_audio_chunks: Optional[int] = None, - **kwargs + **kwargs, ) -> str: """ Generate response from video and/or audio. + Supports modality ablation: video-only, audio-only, or both. - + Args: frames: Video file path (str) or None for audio-only mode audio: Audio file path (str) or None for video-only mode @@ -172,26 +186,28 @@ def generate( video_category: Ignored (kept for API compatibility) max_frames: Maximum frames to use (set by external retry) max_audio_chunks: Maximum audio chunks (set by external retry) - **kwargs: Additional generation parameters - + **kwargs: Additional generation parameters. + Returns: - Generated text response + Generated text response. """ if self.llm is None or self.processor is None: try: - logger.warning("Model found unloaded in generate(), attempting lazy load...") + logger.warning( + "Model found unloaded in generate(), attempting lazy load..." 
+ ) self.load() - except Exception as e: + except Exception: raise RuntimeError("Model not loaded. Call load() first.") - + # Validate: at least one modality must be provided if frames is None and audio is None: raise ValueError("At least one of 'frames' or 'audio' must be provided") - + # Determine active modalities has_video = frames is not None has_audio = audio is not None - + # Log modality mode if has_video and has_audio: modality_mode = "video+audio" @@ -200,24 +216,30 @@ def generate( else: modality_mode = "audio-only" logger.info(f"Modality mode: {modality_mode}") - + # Validate video if provided if has_video: if not isinstance(frames, str): - raise ValueError(f"Qwen3-Omni requires video file path (str), got {type(frames)}") + raise ValueError( + f"Qwen3-Omni requires video file path (str), got {type(frames)}" + ) video_path = Path(frames) if not video_path.exists(): raise FileNotFoundError(f"Video file not found: {video_path}") - + # Use external max_frames or default - actual_max_frames = max_frames if max_frames is not None else self.default_max_frames - + actual_max_frames = ( + max_frames if max_frames is not None else self.default_max_frames + ) + try: - logger.info(f"Processing: frames={actual_max_frames if has_video else 'N/A'}, max_audio_chunks={max_audio_chunks}") - + logger.info( + f"Processing: frames={actual_max_frames if has_video else 'N/A'}, max_audio_chunks={max_audio_chunks}" + ) + # Build content content = [] - + # Add video if provided if has_video: video_content = { @@ -226,17 +248,18 @@ def generate( "max_frames": actual_max_frames, "min_frames": self.default_min_frames, } - + if fps is not None: video_content["fps"] = fps - + content.append(video_content) - + # Add audio if provided - check if it has actual audio data if has_audio and isinstance(audio, str) and os.path.exists(audio): # Quick check if audio file has actual audio stream try: import av + test_container = av.open(audio) if len(test_container.streams.audio) > 0: 
content.append({"type": "audio", "audio": str(audio)}) @@ -248,124 +271,140 @@ def generate( logger.warning(f"Could not verify audio file {audio}: {e}") # Still try to add it content.append({"type": "audio", "audio": str(audio)}) - + content.append({"type": "text", "text": prompt}) - + conversation = [{"role": "user", "content": content}] - + # Apply chat template text = self.processor.apply_chat_template( conversation, tokenize=False, add_generation_prompt=True ) - + # Process multimodal with chunking (not truncation) audios, images, videos = process_mm_info( conversation, use_audio_in_video=False, max_audio_duration=None, # Don't truncate max_audio_chunks=max_audio_chunks, # Use chunking instead - audio_chunk_duration_sec=kwargs.get('audio_chunk_duration_sec', 10.0) + audio_chunk_duration_sec=kwargs.get("audio_chunk_duration_sec", 10.0), ) - + # Filter out empty audio arrays if audios is not None: audios = [a for a in audios if len(a) > 0] if len(audios) == 0: audios = None - + # Remove audio pad token from text if no audio (safety check) - if audios is None and '<|audio_pad|>' in text: - text = text.replace('<|audio_pad|>', '').strip() + if audios is None and "<|audio_pad|>" in text: + text = text.replace("<|audio_pad|>", "").strip() logger.info("Removed <|audio_pad|> from prompt (no audio available)") - + # Track stats if max_audio_chunks is not None and audios is not None: - self.stats['audio_chunks_sampled'] += 1 - + self.stats["audio_chunks_sampled"] += 1 + # Track stats - self.stats['total_samples'] += 1 - + self.stats["total_samples"] += 1 + # Build inputs - inputs = {'prompt': text, 'multi_modal_data': {}} - + inputs = {"prompt": text, "multi_modal_data": {}} + if audios is not None: - inputs['multi_modal_data']['audio'] = audios + inputs["multi_modal_data"]["audio"] = audios if images is not None: - inputs['multi_modal_data']['image'] = images + inputs["multi_modal_data"]["image"] = images if videos is not None: - inputs['multi_modal_data']['video'] = 
videos - + inputs["multi_modal_data"]["video"] = videos + # Sampling params - temperature = kwargs.get('temperature', self.temperature) - top_p = kwargs.get('top_p', self.top_p) - top_k = kwargs.get('top_k', self.top_k) - max_tokens = kwargs.get('max_new_tokens', self.max_tokens) - + temperature = kwargs.get("temperature", self.temperature) + top_p = kwargs.get("top_p", self.top_p) + top_k = kwargs.get("top_k", self.top_k) + max_tokens = kwargs.get("max_new_tokens", self.max_tokens) + sampling_params = SamplingParams( temperature=temperature, top_p=top_p, top_k=top_k, max_tokens=max_tokens, ) - + logger.info("Generating response...") - + # Generate outputs = self.llm.generate([inputs], sampling_params=sampling_params) response_text = outputs[0].outputs[0].text - + logger.info(f"Generated response ({len(response_text)} chars)") - + return self.postprocess_output(response_text) - + except Exception as e: error_msg = str(e) - + # Detect error types - is_cache_error = "Expected a cached item" in error_msg or "mm_hash" in error_msg or "AssertionError" in error_msg - is_engine_dead = "EngineDeadError" in error_msg or "EngineCore" in error_msg or "process_input_sockets" in error_msg - is_context_error = any(keyword in error_msg.lower() for keyword in [ - 'context', 'token', 'length', 'limit', 'maximum', 'exceed', 'longer than' - ]) + is_cache_error = ( + "Expected a cached item" in error_msg + or "mm_hash" in error_msg + or "AssertionError" in error_msg + ) + is_engine_dead = ( + "EngineDeadError" in error_msg + or "EngineCore" in error_msg + or "process_input_sockets" in error_msg + ) + is_context_error = any( + keyword in error_msg.lower() + for keyword in [ + "context", + "token", + "length", + "limit", + "maximum", + "exceed", + "longer than", + ] + ) is_oom = "out of memory" in error_msg.lower() or "OOM" in error_msg - + # Handle specific errors if is_engine_dead or is_cache_error: logger.error(f"Engine/Cache error: {e}") self._reload_engine() raise 
RuntimeError(f"Engine/cache error (engine reloaded): {e}") - - elif is_context_error: + + if is_context_error: logger.error(f"Context length error: {e}") - self._reload_engine() + self._reload_engine() raise RuntimeError(f"Context length exceeded: {e}") - - elif is_oom: + + if is_oom: logger.error(f"OOM error: {e}") self.unload() if torch.cuda.is_available(): torch.cuda.empty_cache() - self._clear_vllm_cache() + self._clear_vllm_cache() self.load() raise RuntimeError(f"Out of memory (engine reloaded): {e}") - - else: - logger.error(f"Generation failed: {e}", exc_info=True) - raise RuntimeError(f"Generation failed: {e}") - - def unload(self): - """Aggressively cleanup vLLM to prevent zombie processes""" + + logger.error(f"Generation failed: {e}", exc_info=True) + raise RuntimeError(f"Generation failed: {e}") + + def unload(self) -> None: + """Aggressively cleanup vLLM to prevent zombie processes.""" if self.llm is not None: try: del self.llm except Exception as e: logger.warning(f"Error deleting llm object: {e}") self.llm = None - + if self.processor is not None: del self.processor self.processor = None - + gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -381,11 +420,13 @@ def unload(self): logger.info("Distributed process group destroyed") except Exception as e: logger.warning(f"Failed to destroy process group: {e}") - + try: active_children = multiprocessing.active_children() if active_children: - logger.info(f"Found {len(active_children)} active child processes. Terminating...") + logger.info( + f"Found {len(active_children)} active child processes. Terminating..." 
+ ) for child in active_children: try: child.terminate() @@ -396,26 +437,30 @@ def unload(self): logger.warning(f"Failed to kill child {child.pid}: {e}") except Exception as e: logger.warning(f"Error during manual process cleanup: {e}") - + logger.info("Model unloaded, memory cleared, and child processes terminated") - + def get_model_info(self) -> Dict[str, Any]: + """Get model information and configuration.""" info = super().get_model_info() - info.update({ - 'model_path': self.model_path, - 'model_type': 'Thinking' if self.use_thinking else 'Instruct', - 'backend': 'vLLM', - 'native_video': True, - 'native_audio': True, - 'tensor_parallel_size': self.tensor_parallel_size, - 'gpu_memory_utilization': self.gpu_memory_utilization, - 'default_max_frames': self.default_max_frames, - 'default_min_frames': self.default_min_frames, - 'max_model_len': self.max_model_len, - 'audio_feature_rate': self.audio_feature_rate, - 'statistics': self.stats, - }) + info.update( + { + "model_path": self.model_path, + "model_type": "Thinking" if self.use_thinking else "Instruct", + "backend": "vLLM", + "native_video": True, + "native_audio": True, + "tensor_parallel_size": self.tensor_parallel_size, + "gpu_memory_utilization": self.gpu_memory_utilization, + "default_max_frames": self.default_max_frames, + "default_min_frames": self.default_min_frames, + "max_model_len": self.max_model_len, + "audio_feature_rate": self.audio_feature_rate, + "statistics": self.stats, + } + ) return info - + def get_statistics(self) -> Dict[str, Any]: - return self.stats \ No newline at end of file + """Get model statistics.""" + return self.stats diff --git a/sonic-o1/05_evaluation_inference/models/unimoe.py b/sonic-o1/05_evaluation_inference/models/unimoe.py index 1c55081..6e8b047 100644 --- a/sonic-o1/05_evaluation_inference/models/unimoe.py +++ b/sonic-o1/05_evaluation_inference/models/unimoe.py @@ -1,221 +1,245 @@ -""" -models/unimoe.py +"""unimoe.py Uni-MoE-2.0-Omni implementation with video and 
audio support. -Supports both single-GPU and multi-GPU inference. + +Author: SONIC-O1 Team """ + +import logging import os import sys -import logging -from typing import Optional, Dict, Any, Union from pathlib import Path +from typing import Any, Dict, Optional + try: - import torch import deepspeed + import torch import torch.distributed as dist except ImportError as e: raise ImportError( - f"Please install required packages: {e}\n" - "pip install torch deepspeed" + f"Please install required packages: {e}\npip install torch deepspeed" ) from .base_model import BaseModel + logger = logging.getLogger(__name__) class UniMoe(BaseModel): """ Uni-MoE-2.0-Omni wrapper with native video and audio support. + Supports both single-GPU and multi-GPU inference modes. """ - - def __init__(self, model_name: str, config: Dict[str, Any]): + + def __init__(self, model_name: str, config: Dict[str, Any]) -> None: super().__init__(model_name, config) - + # Model configuration - self.model_path = config.get('model_path', 'HIT-TMG/Uni-MoE-2.0-Omni') - + self.model_path = config.get("model_path", "HIT-TMG/Uni-MoE-2.0-Omni") + # Handle Uni-MoE package path - self.unimoe_package_path = config.get('unimoe_package_path', None) + self.unimoe_package_path = config.get("unimoe_package_path") if self.unimoe_package_path: unimoe_path = str(Path(self.unimoe_package_path).resolve()) if unimoe_path not in sys.path: sys.path.insert(0, unimoe_path) logger.info(f"Added Uni-MoE package path: {unimoe_path}") - + # Import Uni-MoE components after path is set try: - from uni_moe.model.processing_qwen2_vl import Qwen2VLProcessor - from uni_moe.model.modeling_out import GrinQwen2VLOutForConditionalGeneration - from uni_moe.qwen_vl_utils import process_mm_info # Import inference utils to patch DeepSpeed MoE for single-machine inference from uni_moe.model import deepspeed_moe_inference_utils - + from uni_moe.model.modeling_out import ( + GrinQwen2VLOutForConditionalGeneration, + ) + from 
uni_moe.model.processing_qwen2_vl import Qwen2VLProcessor + from uni_moe.qwen_vl_utils import process_mm_info + self.Qwen2VLProcessor = Qwen2VLProcessor - self.GrinQwen2VLOutForConditionalGeneration = GrinQwen2VLOutForConditionalGeneration + self.GrinQwen2VLOutForConditionalGeneration = ( + GrinQwen2VLOutForConditionalGeneration + ) self.process_mm_info = process_mm_info except ImportError as e: raise ImportError( f"Failed to import Uni-MoE components: {e}\n" "Please ensure 'unimoe_package_path' is set in config or Uni-MoE is in PYTHONPATH" ) - + # Device configuration # Options: 'cuda:0' (single GPU), 'auto' (multi-GPU), or specific device - self.device = config.get('device', 'cuda:0') - self.multi_gpu = self.device == 'auto' - - self.dtype = config.get('dtype', 'bfloat16') - + self.device = config.get("device", "cuda:0") + self.multi_gpu = self.device == "auto" + + self.dtype = config.get("dtype", "bfloat16") + # Parse dtype - if self.dtype == 'bfloat16': + if self.dtype == "bfloat16": self.torch_dtype = torch.bfloat16 - elif self.dtype == 'float16': + elif self.dtype == "float16": self.torch_dtype = torch.float16 - elif self.dtype == 'float32': + elif self.dtype == "float32": self.torch_dtype = torch.float32 else: self.torch_dtype = torch.bfloat16 - + # Generation config - gen_config = config.get('generation_config', {}) - self.temperature = gen_config.get('temperature', 0.7) - self.top_p = gen_config.get('top_p', 0.95) - self.max_new_tokens = gen_config.get('max_new_tokens', 2048) - + gen_config = config.get("generation_config", {}) + self.temperature = gen_config.get("temperature", 0.7) + self.top_p = gen_config.get("top_p", 0.95) + self.max_new_tokens = gen_config.get("max_new_tokens", 2048) + # Video processing config - self.default_max_frames = config.get('max_frames', 480) - self.default_min_frames = config.get('min_frames', 64) - + self.default_max_frames = config.get("max_frames", 480) + self.default_min_frames = config.get("min_frames", 64) + # 
DeepSpeed initialization flag (only for single-GPU mode) self._deepspeed_initialized = False - + self.model = None self.processor = None - + def _init_deepspeed_single_gpu(self): - """Initialize DeepSpeed for single-GPU mode (not needed for multi-GPU)""" + """Initialize DeepSpeed for single-GPU mode (not needed for multi-GPU).""" if self.multi_gpu: logger.info("Multi-GPU mode - skipping DeepSpeed initialization") return - + if self._deepspeed_initialized or dist.is_initialized(): logger.info("DeepSpeed already initialized, skipping...") return - + try: logger.info("Initializing DeepSpeed for single-GPU mode...") - + # Set environment variables for single-GPU distributed setup os.environ["RANK"] = "0" os.environ["WORLD_SIZE"] = "1" os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "29500" os.environ["LOCAL_RANK"] = "0" - + # Initialize DeepSpeed distributed backend deepspeed.init_distributed(dist_backend="nccl") - + self._deepspeed_initialized = True logger.info("DeepSpeed initialized successfully") - + except Exception as e: logger.warning(f"DeepSpeed initialization warning: {e}") # Continue anyway - might still work - - def load(self): - """Load the Uni-MoE model and processor""" + + def load(self) -> None: + """Load the Uni-MoE model and processor.""" try: logger.info(f"Loading Uni-MoE model from {self.model_path}") - + if self.multi_gpu: num_gpus = torch.cuda.device_count() - logger.info(f"Multi-GPU mode enabled - using {num_gpus} GPUs with device_map='auto'") + logger.info( + f"Multi-GPU mode enabled - using {num_gpus} GPUs with device_map='auto'" + ) else: logger.info(f"Single-GPU mode - using device: {self.device}") # Initialize DeepSpeed for single-GPU self._init_deepspeed_single_gpu() - + # Load processor self.processor = self.Qwen2VLProcessor.from_pretrained(self.model_path) - + # Load model with appropriate device mapping if self.multi_gpu: # Multi-GPU: use device_map="auto" for automatic layer distribution - self.model = 
self.GrinQwen2VLOutForConditionalGeneration.from_pretrained( - self.model_path, - torch_dtype=self.torch_dtype, - device_map="auto", # Automatically split across GPUs - low_cpu_mem_usage=True, + self.model = ( + self.GrinQwen2VLOutForConditionalGeneration.from_pretrained( + self.model_path, + torch_dtype=self.torch_dtype, + device_map="auto", # Automatically split across GPUs + low_cpu_mem_usage=True, + ) ) - + # Print device distribution - if hasattr(self.model, 'hf_device_map'): + if hasattr(self.model, "hf_device_map"): logger.info("=== Device Map ===") device_distribution = {} - for name, device in self.model.hf_device_map.items(): - device_distribution[device] = device_distribution.get(device, 0) + 1 + for _name, device in self.model.hf_device_map.items(): + device_distribution[device] = ( + device_distribution.get(device, 0) + 1 + ) for device, count in sorted(device_distribution.items()): logger.info(f" {device}: {count} modules") logger.info("==================") else: # Single-GPU: load normally and move to specified device - self.model = self.GrinQwen2VLOutForConditionalGeneration.from_pretrained( - self.model_path, - torch_dtype=self.torch_dtype, - low_cpu_mem_usage=True, + self.model = ( + self.GrinQwen2VLOutForConditionalGeneration.from_pretrained( + self.model_path, + torch_dtype=self.torch_dtype, + low_cpu_mem_usage=True, + ) ) self.model.to(self.device) - + # Set processor data args from model config self.processor.data_args = self.model.config - + if self.multi_gpu: - logger.info(f"Successfully loaded Uni-MoE across {torch.cuda.device_count()} GPUs") + logger.info( + f"Successfully loaded Uni-MoE across {torch.cuda.device_count()} GPUs" + ) else: logger.info(f"Successfully loaded Uni-MoE on {self.device}") - + except Exception as e: raise RuntimeError(f"Failed to load Uni-MoE model: {e}") - - def convert_av1_to_h264(self, video_path: Path, output_dir: Optional[Path] = None) -> Path: + + def convert_av1_to_h264( + self, video_path: Path, output_dir: 
Optional[Path] = None + ) -> Path: """ Convert AV1 video to H.264 for Decord compatibility. - + Args: video_path: Input video path output_dir: Output directory (default: creates 'converted' subdir) - + Returns: Path to converted video """ import subprocess - + if output_dir is None: output_dir = video_path.parent / "converted" output_dir.mkdir(exist_ok=True) - + output_path = output_dir / f"{video_path.stem}_h264{video_path.suffix}" - + # Check if already converted if output_path.exists(): logger.info(f"Using cached converted video: {output_path.name}") return output_path - + logger.info(f"Converting AV1 to H.264: {video_path.name}") - + cmd = [ - 'ffmpeg', - '-i', str(video_path), - '-c:v', 'libx264', - '-preset', 'fast', - '-crf', '23', - '-c:a', 'copy', - '-y', - str(output_path) + "ffmpeg", + "-i", + str(video_path), + "-c:v", + "libx264", + "-preset", + "fast", + "-crf", + "23", + "-c:a", + "copy", + "-y", + str(output_path), ] - + try: subprocess.run(cmd, check=True, capture_output=True, text=True) logger.info(f"โœ“ Conversion successful: {output_path.name}") @@ -223,46 +247,57 @@ def convert_av1_to_h264(self, video_path: Path, output_dir: Optional[Path] = Non except subprocess.CalledProcessError as e: logger.error(f"โœ— Conversion failed: {e.stderr}") raise RuntimeError(f"Failed to convert AV1 video: {e.stderr}") - + def _check_video_compatibility(self, video_path: Path) -> Optional[Path]: """ Check if video is compatible with Decord. + If AV1, automatically convert to H.264. 
- + Args: video_path: Original video path - + Returns: Path to compatible video (original or converted) """ - from decord import VideoReader, cpu import subprocess - + + from decord import VideoReader, cpu + try: vr = VideoReader(str(video_path), ctx=cpu(0), num_threads=1) frame_count = len(vr) del vr logger.info(f"โœ“ Video compatible: {video_path.name} ({frame_count} frames)") return video_path - except Exception as e: + except Exception: logger.warning(f"โš  Video incompatible with Decord: {video_path.name}") - + # Detect codec try: result = subprocess.run( - ['ffprobe', '-v', 'error', '-select_streams', 'v:0', - '-show_entries', 'stream=codec_name', - '-of', 'default=noprint_wrappers=1:nokey=1', - str(video_path)], - capture_output=True, text=True, check=True + [ + "ffprobe", + "-v", + "error", + "-select_streams", + "v:0", + "-show_entries", + "stream=codec_name", + "-of", + "default=noprint_wrappers=1:nokey=1", + str(video_path), + ], + capture_output=True, + text=True, + check=True, ) codec = result.stdout.strip() - - if codec == 'av1': - logger.info(f" โ†’ Detected AV1 codec - converting to H.264...") + + if codec == "av1": + logger.info(" โ†’ Detected AV1 codec - converting to H.264...") try: - converted_path = self.convert_av1_to_h264(video_path) - return converted_path + return self.convert_av1_to_h264(video_path) except Exception as conv_error: logger.error(f" โœ— Conversion failed: {conv_error}") return None @@ -272,21 +307,21 @@ def _check_video_compatibility(self, video_path: Path) -> Optional[Path]: except Exception as probe_error: logger.error(f" โœ— Failed to detect codec: {probe_error}") return None - + def generate( self, frames: Optional[str] = None, # Video file path (now optional) - audio: Optional[str] = None, # Audio file path (now optional) + audio: Optional[str] = None, # Audio file path (now optional) prompt: str = "", fps: Optional[float] = None, video_category: Optional[str] = None, max_frames: Optional[int] = None, max_audio_chunks: 
Optional[int] = None, - **kwargs + **kwargs, ) -> str: """ Generate response from video and/or audio. - + Args: frames: Path to video file (optional) audio: Path to audio file (optional) @@ -296,36 +331,35 @@ def generate( max_frames: Maximum frames for video processing max_audio_chunks: Unused for Uni-MoE **kwargs: Additional generation parameters - + Returns: Generated text response """ if self.model is None or self.processor is None: raise RuntimeError("Model not loaded. Call load() first.") - + # Determine modality mode if frames is None and audio is None: raise ValueError("At least one of 'frames' or 'audio' must be provided") - + has_video = frames is not None has_audio = audio is not None - + if has_video and has_audio: modality_mode = "video+audio" - modal_type = 'video' elif has_video: modality_mode = "video-only" - modal_type = 'video' else: modality_mode = "audio-only" - modal_type = 'audio' - + logger.info(f"Modality mode: {modality_mode}") - + # Use max_frames from parameter if provided, otherwise use default - actual_max_frames = max_frames if max_frames is not None else self.default_max_frames + actual_max_frames = ( + max_frames if max_frames is not None else self.default_max_frames + ) actual_min_frames = self.default_min_frames - + try: # Handle video path if present video_path = None @@ -340,11 +374,11 @@ def generate( video_path = self._check_video_compatibility(video_path) if video_path is None: raise RuntimeError( - f"Video codec incompatible with Decord (likely AV1). " - f"Skipping this video." + "Video codec incompatible with Decord (likely AV1). " + "Skipping this video." 
) logger.info(f"Processing video: {video_path.name}") - + # Handle audio path if present audio_path = None if has_audio: @@ -356,7 +390,7 @@ def generate( if not audio_path.exists(): raise FileNotFoundError(f"Audio file not found: {audio_path}") logger.info(f"Processing audio: {audio_path.name}") - + # Build text prompt with appropriate modality tokens text_prompt_parts = [] if has_video: @@ -365,45 +399,52 @@ def generate( text_prompt_parts.append("