Skip to content

Commit 18810d8

Browse files
dgozmanSkn0tt
andauthored
feat: add FormData class for form and multipart requests (#3060)
Co-authored-by: Simon Knott <info@simonknott.de>
1 parent b846cee commit 18810d8

11 files changed

Lines changed: 401 additions & 96 deletions

File tree

playwright/_impl/_fetch.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import base64
1616
import json
17+
import mimetypes
1718
import pathlib
1819
import typing
1920
from pathlib import Path
@@ -32,6 +33,7 @@
3233
)
3334
from playwright._impl._connection import ChannelOwner, from_channel
3435
from playwright._impl._errors import is_target_closed_error
36+
from playwright._impl._form_data import FormData
3537
from playwright._impl._helper import (
3638
Error,
3739
NameValue,
@@ -51,9 +53,9 @@
5153
from playwright._impl._playwright import Playwright
5254

5355

54-
FormType = Dict[str, Union[bool, float, str]]
56+
FormType = Union[Dict[str, Union[bool, float, str]], FormData]
5557
DataType = Union[Any, bytes, str]
56-
MultipartType = Dict[str, Union[bytes, bool, float, str, FilePayload]]
58+
MultipartType = Union[Dict[str, Union[bytes, bool, float, str, FilePayload]], FormData]
5759
ParamsType = Union[Dict[str, Union[bool, float, str]], str]
5860

5961

@@ -217,7 +219,7 @@ async def patch(
217219
headers: Headers = None,
218220
data: DataType = None,
219221
form: FormType = None,
220-
multipart: Dict[str, Union[bytes, bool, float, str, FilePayload]] = None,
222+
multipart: MultipartType = None,
221223
timeout: float = None,
222224
failOnStatusCode: bool = None,
223225
ignoreHTTPSErrors: bool = None,
@@ -246,7 +248,7 @@ async def put(
246248
headers: Headers = None,
247249
data: DataType = None,
248250
form: FormType = None,
249-
multipart: Dict[str, Union[bytes, bool, float, str, FilePayload]] = None,
251+
multipart: MultipartType = None,
250252
timeout: float = None,
251253
failOnStatusCode: bool = None,
252254
ignoreHTTPSErrors: bool = None,
@@ -275,7 +277,7 @@ async def post(
275277
headers: Headers = None,
276278
data: DataType = None,
277279
form: FormType = None,
278-
multipart: Dict[str, Union[bytes, bool, float, str, FilePayload]] = None,
280+
multipart: MultipartType = None,
279281
timeout: float = None,
280282
failOnStatusCode: bool = None,
281283
ignoreHTTPSErrors: bool = None,
@@ -305,7 +307,7 @@ async def fetch(
305307
headers: Headers = None,
306308
data: DataType = None,
307309
form: FormType = None,
308-
multipart: Dict[str, Union[bytes, bool, float, str, FilePayload]] = None,
310+
multipart: MultipartType = None,
309311
timeout: float = None,
310312
failOnStatusCode: bool = None,
311313
ignoreHTTPSErrors: bool = None,
@@ -346,7 +348,7 @@ async def _inner_fetch(
346348
data: DataType = None,
347349
params: ParamsType = None,
348350
form: FormType = None,
349-
multipart: Dict[str, Union[bytes, bool, float, str, FilePayload]] = None,
351+
multipart: MultipartType = None,
350352
timeout: float = None,
351353
failOnStatusCode: bool = None,
352354
ignoreHTTPSErrors: bool = None,
@@ -386,21 +388,36 @@ async def _inner_fetch(
386388
else:
387389
raise Error(f"Unsupported 'data' type: {type(data)}")
388390
elif form:
389-
form_data = object_to_array(form)
391+
if isinstance(form, FormData):
392+
form_data = []
393+
for fd_name, fd_value in form._fields:
394+
if isinstance(fd_value, (pathlib.Path, dict)):
395+
raise Error(
396+
f"Form field {fd_name!r} must be a string, number or boolean. Use 'multipart' for file uploads."
397+
)
398+
form_data.append(NameValue(name=fd_name, value=str(fd_value)))
399+
else:
400+
form_data = object_to_array(form)
390401
elif multipart:
391402
multipart_data = []
392-
# Convert file-like values to ServerFilePayload structs.
393-
for name, value in multipart.items():
394-
if is_file_payload(value):
395-
payload = cast(FilePayload, value)
396-
assert isinstance(
397-
payload["buffer"], bytes
398-
), f"Unexpected buffer type of 'data.{name}'"
403+
if isinstance(multipart, FormData):
404+
for fd_name, fd_value in multipart._fields:
399405
multipart_data.append(
400-
FormField(name=name, file=file_payload_to_json(payload))
406+
await _form_data_field_to_form_field(fd_name, fd_value)
401407
)
402-
elif isinstance(value, str):
403-
multipart_data.append(FormField(name=name, value=value))
408+
else:
409+
# Convert file-like values to ServerFilePayload structs.
410+
for name, value in multipart.items():
411+
if is_file_payload(value):
412+
payload = cast(FilePayload, value)
413+
assert isinstance(
414+
payload["buffer"], bytes
415+
), f"Unexpected buffer type of 'data.{name}'"
416+
multipart_data.append(
417+
FormField(name=name, file=file_payload_to_json(payload))
418+
)
419+
elif isinstance(value, str):
420+
multipart_data.append(FormField(name=name, value=value))
404421
if (
405422
post_data_buffer is None
406423
and json_data is None
@@ -455,6 +472,28 @@ def file_payload_to_json(payload: FilePayload) -> ServerFilePayload:
455472
)
456473

457474

475+
async def _form_data_field_to_form_field(name: str, value: Any) -> FormField:
476+
if isinstance(value, pathlib.Path):
477+
mime_type, _ = mimetypes.guess_type(str(value))
478+
return FormField(
479+
name=name,
480+
file=ServerFilePayload(
481+
name=value.name,
482+
mimeType=mime_type or "application/octet-stream",
483+
buffer=base64.b64encode(await async_readfile(str(value))).decode(),
484+
),
485+
)
486+
if is_file_payload(value):
487+
payload = cast(FilePayload, value)
488+
assert isinstance(
489+
payload["buffer"], bytes
490+
), f"Unexpected buffer type of form field {name!r}"
491+
return FormField(name=name, file=file_payload_to_json(payload))
492+
if isinstance(value, (str, int, float, bool)):
493+
return FormField(name=name, value=str(value))
494+
raise Error(f"Unsupported form field {name!r} value type: {type(value).__name__}")
495+
496+
458497
class APIResponse:
459498
def __init__(self, context: APIRequestContext, initializer: Dict) -> None:
460499
self._loop = context._loop

playwright/_impl/_form_data.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright (c) Microsoft Corporation.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pathlib
16+
from typing import List, Tuple, Union
17+
18+
from playwright._impl._api_structures import FilePayload
19+
20+
FormDataValue = Union[bool, float, str, pathlib.Path, FilePayload]
21+
22+
23+
class FormData:
24+
def __init__(self) -> None:
25+
self._fields: List[Tuple[str, FormDataValue]] = []
26+
27+
def set(self, name: str, value: FormDataValue) -> "FormData":
28+
self._fields = [(n, v) for (n, v) in self._fields if n != name]
29+
self._fields.append((name, value))
30+
return self
31+
32+
def append(self, name: str, value: FormDataValue) -> "FormData":
33+
self._fields.append((name, value))
34+
return self

playwright/async_api/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import playwright._impl._api_structures
2424
import playwright._impl._errors
25+
import playwright._impl._form_data
2526
import playwright.async_api._generated
2627
from playwright._impl._assertions import (
2728
APIResponseAssertions as APIResponseAssertionsImpl,
@@ -69,6 +70,7 @@
6970

7071
Cookie = playwright._impl._api_structures.Cookie
7172
FilePayload = playwright._impl._api_structures.FilePayload
73+
FormData = playwright._impl._form_data.FormData
7274
FloatRect = playwright._impl._api_structures.FloatRect
7375
Geolocation = playwright._impl._api_structures.Geolocation
7476
HttpCredentials = playwright._impl._api_structures.HttpCredentials
@@ -171,6 +173,7 @@ def __call__(
171173
"FileChooser",
172174
"FilePayload",
173175
"FloatRect",
176+
"FormData",
174177
"Frame",
175178
"FrameLocator",
176179
"Geolocation",

0 commit comments

Comments
 (0)