diff --git a/backend/src/vmarker/api/routes/video.py b/backend/src/vmarker/api/routes/video.py index 4862e33..e02d209 100644 --- a/backend/src/vmarker/api/routes/video.py +++ b/backend/src/vmarker/api/routes/video.py @@ -32,6 +32,7 @@ MAX_DURATION = 300 # 5 分钟 PARALLEL_THRESHOLD_SECONDS = 180 # 超过 3 分钟自动使用并行合成 ALLOWED_EXTENSIONS = {".mp4", ".mov", ".webm", ".mkv", ".avi"} +UPLOAD_CHUNK_SIZE = 1024 * 1024 # 1MB # ============================================================================= @@ -133,16 +134,26 @@ async def upload_video( if ext not in ALLOWED_EXTENSIONS: raise HTTPException(400, f"不支持的文件格式: {ext},支持: {', '.join(ALLOWED_EXTENSIONS)}") - # 读取文件内容 - content = await file.read() - - # 验证文件大小 - if len(content) > MAX_FILE_SIZE: - raise HTTPException(400, f"文件大小超出限制 ({MAX_FILE_SIZE // 1024 // 1024}MB)") - - # 创建会话并保存文件 session = TempSession() - video_path = session.save_upload(f"source{ext}", content) + video_path = session.get_path(f"source{ext}") + + # 流式写入上传内容,避免一次性读入整个文件 + total_size = 0 + try: + with video_path.open("wb") as output: + while chunk := await file.read(UPLOAD_CHUNK_SIZE): + total_size += len(chunk) + if total_size > MAX_FILE_SIZE: + raise HTTPException(400, f"文件大小超出限制 ({MAX_FILE_SIZE // 1024 // 1024}MB)") + output.write(chunk) + except HTTPException: + session.cleanup() + raise + except Exception as e: + session.cleanup() + raise HTTPException(500, f"文件保存失败: {e}") from e + finally: + await file.close() # 探测视频信息 try: diff --git a/backend/tests/test_video_api.py b/backend/tests/test_video_api.py new file mode 100644 index 0000000..a2dab05 --- /dev/null +++ b/backend/tests/test_video_api.py @@ -0,0 +1,87 @@ +""" +[INPUT]: 依赖 pytest, FastAPI TestClient, vmarker.api.main +[OUTPUT]: video API 路由测试 +[POS]: tests/ 的视频上传接口测试 +[PROTOCOL]: 变更时更新此头部,然后检查 CLAUDE.md +""" + +from dataclasses import dataclass + +import pytest +from fastapi import status +from fastapi.testclient import TestClient + +from vmarker.api.main import app +from vmarker.api.routes import video as video_route +from vmarker import temp_manager + + +@dataclass +class FakeVideoInfo: + duration: float = 12.5 + width: int = 1920 + height: int = 1080 + fps: float = 30.0 + codec: str = "h264" + file_size: int = 0 + + +@pytest.fixture +def client(): + with TestClient(app) as test_client: + yield test_client + + +def test_upload_video_streams_to_disk(client, tmp_path, monkeypatch): + """上传接口应分块写入并保留完整文件""" + monkeypatch.setattr(temp_manager, "BASE_DIR", tmp_path) + + content = b"a" * (video_route.UPLOAD_CHUNK_SIZE + 17) + + def fake_validate_video(video_path, max_duration, max_size_mb): + assert video_path.read_bytes() == content + return FakeVideoInfo(file_size=len(content)) + + monkeypatch.setattr(video_route.video_probe, "validate_video", fake_validate_video) + + response = client.post( + "/api/v1/video/upload", + files={"file": ("sample.mp4", content, "video/mp4")}, + ) + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["width"] == 1920 + assert data["height"] == 1080 + assert data["fps"] == 30.0 + assert data["file_size_mb"] == pytest.approx(len(content) / 1024 / 1024) + + session_dir = tmp_path / data["session_id"] + assert (session_dir / "source.mp4").read_bytes() == content + + +def test_upload_video_rejects_oversized_file_and_cleans_session(client, tmp_path, monkeypatch): + """超限上传应立即失败并清理临时目录""" + monkeypatch.setattr(temp_manager, "BASE_DIR", tmp_path) + monkeypatch.setattr(video_route, "MAX_FILE_SIZE", video_route.UPLOAD_CHUNK_SIZE) + + validate_called = False + + def fake_validate_video(*args, **kwargs): + nonlocal validate_called + validate_called = True + raise AssertionError("validate_video should not be called for oversized uploads") + + monkeypatch.setattr(video_route.video_probe, "validate_video", fake_validate_video) + + content = b"b" * (video_route.UPLOAD_CHUNK_SIZE + 1) + + response = client.post( + "/api/v1/video/upload", + files={"file": ("oversized.mp4", content, "video/mp4")}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "文件大小超出限制" in response.text + assert validate_called is False + assert list(tmp_path.iterdir()) == []