Skip to content

Commit e66b34c

Browse files
Normalize extract schema mappings safely
Co-authored-by: Shri Sukhani <shrisukhani@users.noreply.github.com>
1 parent 5312dc5 commit e66b34c

2 files changed

Lines changed: 162 additions & 2 deletions

File tree

hyperbrowser/tools/__init__.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,33 @@ def _format_tool_param_key_for_error(key: str) -> str:
5252
return f"{normalized_key[:available_length]}{_TRUNCATED_KEY_DISPLAY_SUFFIX}"
5353

5454

55+
def _normalize_extract_schema_mapping(schema_value: MappingABC[object, Any]) -> Dict[str, Any]:
56+
try:
57+
schema_keys = list(schema_value.keys())
58+
except HyperbrowserError:
59+
raise
60+
except Exception as exc:
61+
raise HyperbrowserError(
62+
"Failed to read extract tool `schema` object keys",
63+
original_error=exc,
64+
) from exc
65+
normalized_schema: Dict[str, Any] = {}
66+
for key in schema_keys:
67+
if not isinstance(key, str):
68+
raise HyperbrowserError("Extract tool `schema` object keys must be strings")
69+
try:
70+
normalized_schema[key] = schema_value[key]
71+
except HyperbrowserError:
72+
raise
73+
except Exception as exc:
74+
key_display = _format_tool_param_key_for_error(key)
75+
raise HyperbrowserError(
76+
f"Failed to read extract tool `schema` value for key '{key_display}'",
77+
original_error=exc,
78+
) from exc
79+
return normalized_schema
80+
81+
5582
def _prepare_extract_tool_params(params: Mapping[str, Any]) -> Dict[str, Any]:
5683
normalized_params = _to_param_dict(params)
5784
schema_value = normalized_params.get("schema")
@@ -69,11 +96,13 @@ def _prepare_extract_tool_params(params: Mapping[str, Any]) -> Dict[str, Any]:
6996
"Invalid JSON string provided for `schema` in extract tool params",
7097
original_error=exc,
7198
) from exc
72-
if parsed_schema is not None and not isinstance(parsed_schema, MappingABC):
99+
if not isinstance(parsed_schema, MappingABC):
73100
raise HyperbrowserError(
74101
"Extract tool `schema` must decode to a JSON object"
75102
)
76-
normalized_params["schema"] = parsed_schema
103+
normalized_params["schema"] = _normalize_extract_schema_mapping(parsed_schema)
104+
elif isinstance(schema_value, MappingABC):
105+
normalized_params["schema"] = _normalize_extract_schema_mapping(schema_value)
77106
return normalized_params
78107

79108

tests/test_tools_extract.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import asyncio
2+
from collections.abc import Mapping
3+
from types import MappingProxyType
24

35
import pytest
46

@@ -171,6 +173,135 @@ async def run():
171173
asyncio.run(run())
172174

173175

176+
def test_extract_tool_runnable_rejects_null_schema_json():
177+
client = _SyncClient()
178+
params = {
179+
"urls": ["https://example.com"],
180+
"schema": "null",
181+
}
182+
183+
with pytest.raises(
184+
HyperbrowserError, match="Extract tool `schema` must decode to a JSON object"
185+
):
186+
WebsiteExtractTool.runnable(client, params)
187+
188+
189+
def test_extract_tool_runnable_normalizes_mapping_schema_values():
190+
client = _SyncClient()
191+
schema_mapping = MappingProxyType({"type": "object", "properties": {}})
192+
193+
WebsiteExtractTool.runnable(
194+
client,
195+
{
196+
"urls": ["https://example.com"],
197+
"schema": schema_mapping,
198+
},
199+
)
200+
201+
assert isinstance(client.extract.last_params, StartExtractJobParams)
202+
assert isinstance(client.extract.last_params.schema_, dict)
203+
assert client.extract.last_params.schema_ == {"type": "object", "properties": {}}
204+
205+
206+
def test_extract_tool_runnable_rejects_non_string_schema_keys():
207+
client = _SyncClient()
208+
209+
with pytest.raises(
210+
HyperbrowserError, match="Extract tool `schema` object keys must be strings"
211+
):
212+
WebsiteExtractTool.runnable(
213+
client,
214+
{
215+
"urls": ["https://example.com"],
216+
"schema": {1: "invalid-key"}, # type: ignore[dict-item]
217+
},
218+
)
219+
220+
221+
def test_extract_tool_runnable_wraps_schema_key_read_failures():
222+
class _BrokenSchemaMapping(Mapping[object, object]):
223+
def __iter__(self):
224+
raise RuntimeError("cannot iterate schema keys")
225+
226+
def __len__(self) -> int:
227+
return 1
228+
229+
def __getitem__(self, key: object) -> object:
230+
return key
231+
232+
client = _SyncClient()
233+
234+
with pytest.raises(
235+
HyperbrowserError, match="Failed to read extract tool `schema` object keys"
236+
) as exc_info:
237+
WebsiteExtractTool.runnable(
238+
client,
239+
{
240+
"urls": ["https://example.com"],
241+
"schema": _BrokenSchemaMapping(),
242+
},
243+
)
244+
245+
assert exc_info.value.original_error is not None
246+
247+
248+
def test_extract_tool_runnable_wraps_schema_value_read_failures():
249+
class _BrokenSchemaMapping(Mapping[str, object]):
250+
def __iter__(self):
251+
yield "type"
252+
253+
def __len__(self) -> int:
254+
return 1
255+
256+
def __getitem__(self, key: str) -> object:
257+
_ = key
258+
raise RuntimeError("cannot read schema value")
259+
260+
client = _SyncClient()
261+
262+
with pytest.raises(
263+
HyperbrowserError,
264+
match="Failed to read extract tool `schema` value for key 'type'",
265+
) as exc_info:
266+
WebsiteExtractTool.runnable(
267+
client,
268+
{
269+
"urls": ["https://example.com"],
270+
"schema": _BrokenSchemaMapping(),
271+
},
272+
)
273+
274+
assert exc_info.value.original_error is not None
275+
276+
277+
def test_extract_tool_runnable_preserves_hyperbrowser_schema_value_read_failures():
278+
class _BrokenSchemaMapping(Mapping[str, object]):
279+
def __iter__(self):
280+
yield "type"
281+
282+
def __len__(self) -> int:
283+
return 1
284+
285+
def __getitem__(self, key: str) -> object:
286+
_ = key
287+
raise HyperbrowserError("custom schema value failure")
288+
289+
client = _SyncClient()
290+
291+
with pytest.raises(
292+
HyperbrowserError, match="custom schema value failure"
293+
) as exc_info:
294+
WebsiteExtractTool.runnable(
295+
client,
296+
{
297+
"urls": ["https://example.com"],
298+
"schema": _BrokenSchemaMapping(),
299+
},
300+
)
301+
302+
assert exc_info.value.original_error is None
303+
304+
174305
def test_extract_tool_runnable_serializes_empty_object_data():
175306
client = _SyncClient(response_data={})
176307

0 commit comments

Comments
 (0)