Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions backend/src/vmarker/api/routes/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
MAX_DURATION = 300 # 5 分钟
PARALLEL_THRESHOLD_SECONDS = 180 # 超过 3 分钟自动使用并行合成
ALLOWED_EXTENSIONS = {".mp4", ".mov", ".webm", ".mkv", ".avi"}
UPLOAD_CHUNK_SIZE = 1024 * 1024 # 1MB


# =============================================================================
Expand Down Expand Up @@ -133,16 +134,26 @@ async def upload_video(
if ext not in ALLOWED_EXTENSIONS:
raise HTTPException(400, f"不支持的文件格式: {ext},支持: {', '.join(ALLOWED_EXTENSIONS)}")

# 读取文件内容
content = await file.read()

# 验证文件大小
if len(content) > MAX_FILE_SIZE:
raise HTTPException(400, f"文件大小超出限制 ({MAX_FILE_SIZE // 1024 // 1024}MB)")

# 创建会话并保存文件
session = TempSession()
video_path = session.save_upload(f"source{ext}", content)
video_path = session.get_path(f"source{ext}")

# 流式写入上传内容,避免一次性读入整个文件
total_size = 0
try:
with video_path.open("wb") as output:
while chunk := await file.read(UPLOAD_CHUNK_SIZE):
total_size += len(chunk)
if total_size > MAX_FILE_SIZE:
raise HTTPException(400, f"文件大小超出限制 ({MAX_FILE_SIZE // 1024 // 1024}MB)")
output.write(chunk)
except HTTPException:
session.cleanup()
raise
except Exception as e:
session.cleanup()
raise HTTPException(500, f"文件保存失败: {e}") from e
finally:
await file.close()

# 探测视频信息
try:
Expand Down
87 changes: 87 additions & 0 deletions backend/tests/test_video_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""
[INPUT]: 依赖 pytest, FastAPI TestClient, vmarker.api.main
[OUTPUT]: video API 路由测试
[POS]: tests/ 的视频上传接口测试
[PROTOCOL]: 变更时更新此头部,然后检查 CLAUDE.md
"""

from dataclasses import dataclass

import pytest
from fastapi import status
from fastapi.testclient import TestClient

from vmarker.api.main import app
from vmarker.api.routes import video as video_route
from vmarker import temp_manager


@dataclass
class FakeVideoInfo:
duration: float = 12.5
width: int = 1920
height: int = 1080
fps: float = 30.0
codec: str = "h264"
file_size: int = 0


@pytest.fixture
def client():
with TestClient(app) as test_client:
yield test_client


def test_upload_video_streams_to_disk(client, tmp_path, monkeypatch):
"""上传接口应分块写入并保留完整文件"""
monkeypatch.setattr(temp_manager, "BASE_DIR", tmp_path)

content = b"a" * (video_route.UPLOAD_CHUNK_SIZE + 17)

def fake_validate_video(video_path, max_duration, max_size_mb):
assert video_path.read_bytes() == content
return FakeVideoInfo(file_size=len(content))

monkeypatch.setattr(video_route.video_probe, "validate_video", fake_validate_video)

response = client.post(
"/api/v1/video/upload",
files={"file": ("sample.mp4", content, "video/mp4")},
)

assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["width"] == 1920
assert data["height"] == 1080
assert data["fps"] == 30.0
assert data["file_size_mb"] == pytest.approx(len(content) / 1024 / 1024)

session_dir = tmp_path / data["session_id"]
assert (session_dir / "source.mp4").read_bytes() == content


def test_upload_video_rejects_oversized_file_and_cleans_session(client, tmp_path, monkeypatch):
"""超限上传应立即失败并清理临时目录"""
monkeypatch.setattr(temp_manager, "BASE_DIR", tmp_path)
monkeypatch.setattr(video_route, "MAX_FILE_SIZE", video_route.UPLOAD_CHUNK_SIZE)

validate_called = False

def fake_validate_video(*args, **kwargs):
nonlocal validate_called
validate_called = True
raise AssertionError("validate_video should not be called for oversized uploads")

monkeypatch.setattr(video_route.video_probe, "validate_video", fake_validate_video)

content = b"b" * (video_route.UPLOAD_CHUNK_SIZE + 1)

response = client.post(
"/api/v1/video/upload",
files={"file": ("oversized.mp4", content, "video/mp4")},
)

assert response.status_code == status.HTTP_400_BAD_REQUEST
assert "文件大小超出限制" in response.text
assert validate_called is False
assert list(tmp_path.iterdir()) == []