-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathserialization.py
More file actions
260 lines (210 loc) · 7.03 KB
/
serialization.py
File metadata and controls
260 lines (210 loc) · 7.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
"""
Object Serialization
====================
Contains serializers for storage of objects on the Simvue server
"""
import contextlib
import typing
import pickle
import pandas
import json
import numpy
from io import BytesIO
if typing.TYPE_CHECKING:
from pandas import DataFrame
from plotly.graph_objects import Figure
from torch import Tensor
from typing_extensions import Buffer
from .types import DeserializedContent
from .utilities import check_extra
def _is_torch_tensor(data: typing.Any) -> bool:
    """Check whether a value is a PyTorch tensor or a state dict of tensors.

    Classes are identified by module and class name rather than with
    ``isinstance`` so that ``torch`` does not have to be importable here.

    Parameters
    ----------
    data : typing.Any
        object to inspect

    Returns
    -------
    bool
        ``True`` if ``data`` is a ``torch.Tensor``, or a non-empty
        ``collections.OrderedDict`` whose values are all ``torch.Tensor``
        instances; ``False`` otherwise.
    """

    def _looks_like_tensor(value: typing.Any) -> bool:
        # Name-based check for ``torch.Tensor``.
        cls = value.__class__
        return cls.__module__ == "torch" and cls.__name__ == "Tensor"

    cls = data.__class__
    if cls.__module__ == "collections" and cls.__name__ == "OrderedDict":
        # An empty mapping carries no evidence of being a state dict; treating
        # it as one would send a plain, JSON-serializable OrderedDict to the
        # torch serializer.
        return bool(data) and all(
            _looks_like_tensor(value) for value in data.values()
        )
    return _looks_like_tensor(data)
def serialize_object(data: typing.Any, allow_pickle: bool) -> tuple[str, str] | None:
    """Serialize an object with the first serializer matching its type.

    Types are recognised by class and module name, so optional dependencies
    are only touched when an object of the corresponding type is passed.
    Anything not matched by a dedicated serializer is tried as JSON and,
    if permitted, finally pickled.

    Parameters
    ----------
    data : typing.Any
        object to serialize
    allow_pickle : bool
        whether pickling may be used as the last resort

    Returns
    -------
    tuple[str, str] | None
        the serialized payload together with its MIME type, or ``None`` if
        no serializer could handle the object
    """
    module_name = data.__class__.__module__
    class_name = data.__class__.__name__

    if module_name == "plotly.graph_objs._figure" and class_name == "Figure":
        return _serialize_plotly_figure(data)
    elif module_name == "matplotlib.figure" and class_name == "Figure":
        return _serialize_matplotlib_figure(data)
    elif module_name == "numpy" and class_name == "ndarray":
        return _serialize_numpy_array(data)
    elif module_name == "pandas.core.frame" and class_name == "DataFrame":
        return _serialize_dataframe(data)
    elif _is_torch_tensor(data):
        return _serialize_torch_tensor(data)
    elif module_name == "builtins" and class_name == "module":
        # Module objects can be neither JSON-encoded nor pickled, so this
        # branch must not be skipped when pickling is allowed: otherwise
        # passing ``matplotlib.pyplot`` with ``allow_pickle=True`` would fall
        # through to ``pickle.dumps`` and raise ``TypeError``.
        with contextlib.suppress(ImportError):
            import matplotlib.pyplot

            if data is matplotlib.pyplot:
                return _serialize_matplotlib(data)
    elif serialized := _serialize_json(data):
        return serialized
    return _serialize_pickle(data) if allow_pickle else None
@check_extra("plot")
def _serialize_plotly_figure(data: typing.Any) -> tuple[str, str] | None:
    """Serialize a Plotly figure to its JSON representation.

    Parameters
    ----------
    data : typing.Any
        the Plotly figure to serialize

    Returns
    -------
    tuple[str, str] | None
        the UTF-8 encoded Plotly JSON and the Plotly MIME type, or ``None``
        if ``plotly`` cannot be imported
    """
    try:
        import plotly
    except ImportError:
        return None
    mimetype = "application/vnd.plotly.v1+json"
    # Encoding the JSON text directly yields the same bytes as writing it to
    # an in-memory buffer and reading the buffer back.
    payload = plotly.io.to_json(data, engine="json").encode()
    return payload, mimetype
@check_extra("plot")
def _serialize_matplotlib(data: typing.Any) -> tuple[str, str] | None:
    """Serialize the current figure of a pyplot-like object as Plotly JSON.

    The figure returned by ``data.gcf()`` is converted with
    ``plotly.tools.mpl_to_plotly`` and dumped to JSON.

    Parameters
    ----------
    data : typing.Any
        object exposing ``gcf()``, e.g. the ``matplotlib.pyplot`` module

    Returns
    -------
    tuple[str, str] | None
        the encoded Plotly JSON and the Plotly MIME type, or ``None`` if
        ``plotly`` cannot be imported
    """
    try:
        import plotly
    except ImportError:
        return None

    converted = plotly.tools.mpl_to_plotly(data.gcf())
    json_text = plotly.io.to_json(converted, engine="json")
    return json_text.encode(), "application/vnd.plotly.v1+json"
@check_extra("plot")
def _serialize_matplotlib_figure(data: typing.Any) -> tuple[str, str] | None:
    """Serialize a Matplotlib figure as Plotly JSON.

    The figure is converted with ``plotly.tools.mpl_to_plotly`` and dumped
    to JSON.

    Parameters
    ----------
    data : typing.Any
        the Matplotlib figure to serialize

    Returns
    -------
    tuple[str, str] | None
        the encoded Plotly JSON and the Plotly MIME type, or ``None`` if
        ``plotly`` cannot be imported
    """
    try:
        import plotly
    except ImportError:
        return None

    converted = plotly.tools.mpl_to_plotly(data)
    json_text = plotly.io.to_json(converted, engine="json")
    return json_text.encode(), "application/vnd.plotly.v1+json"
def _serialize_numpy_array(data: typing.Any) -> tuple[str, str] | None:
mimetype = "application/vnd.simvue.numpy.v1"
mfile = BytesIO()
numpy.save(mfile, data, allow_pickle=False)
mfile.seek(0)
data = mfile.read()
return data, mimetype
def _serialize_dataframe(data: typing.Any) -> tuple[str, str] | None:
mimetype = "application/vnd.simvue.df.v1"
mfile = BytesIO()
data.to_csv(mfile)
mfile.seek(0)
data = mfile.read()
return data, mimetype
@check_extra("torch")
def _serialize_torch_tensor(data: typing.Any) -> tuple[str, str] | None:
    """Serialize a PyTorch tensor or state dict with ``torch.save``.

    Parameters
    ----------
    data : typing.Any
        the object to hand to ``torch.save``

    Returns
    -------
    tuple[str, str] | None
        the saved bytes and the Simvue torch MIME type, or ``None`` if
        ``torch`` cannot be imported
    """
    try:
        import torch
    except ImportError:
        return None

    buffer = BytesIO()
    torch.save(data, buffer)
    return buffer.getvalue(), "application/vnd.simvue.torch.v1"
def _serialize_json(data: typing.Any) -> tuple[str, str] | None:
mimetype = "application/json"
try:
mfile = BytesIO()
mfile.write(json.dumps(data).encode())
mfile.seek(0)
data = mfile.read()
except (TypeError, json.JSONDecodeError):
return None
return data, mimetype
def _serialize_pickle(data: typing.Any) -> tuple[str, str] | None:
mimetype = "application/octet-stream"
data = pickle.dumps(data)
return data, mimetype
def deserialize_data(
data: "Buffer", mimetype: str, allow_pickle: bool
) -> typing.Optional["DeserializedContent"]:
"""
Determine which deserializer to use
"""
if mimetype == "application/vnd.plotly.v1+json":
return _deserialize_plotly_figure(data)
elif mimetype == "application/vnd.simvue.numpy.v1":
return _deserialize_numpy_array(data)
elif mimetype == "application/vnd.simvue.df.v1":
return _deserialize_dataframe(data)
elif mimetype == "application/vnd.simvue.torch.v1":
return _deserialize_torch_tensor(data)
elif mimetype == "application/json":
return _deserialize_json(data)
elif mimetype == "application/octet-stream" and allow_pickle:
return _deserialize_pickle(data)
return None
@check_extra("plot")
def _deserialize_plotly_figure(data: "Buffer") -> typing.Optional["Figure"]:
    """Rebuild a Plotly figure from its JSON representation.

    Parameters
    ----------
    data : Buffer
        the Plotly JSON payload

    Returns
    -------
    Figure | None
        the result of ``plotly.io.from_json``, or ``None`` if ``plotly``
        cannot be imported
    """
    try:
        import plotly
    except ImportError:
        return None
    else:
        return plotly.io.from_json(data)
@check_extra("plot")
def _deserialize_matplotlib_figure(data: "Buffer") -> typing.Optional["Figure"]:
    """Rebuild a figure that was stored as Plotly JSON.

    Matplotlib figures are serialized via Plotly in this module, so the
    payload is loaded with ``plotly.io.from_json``.

    Parameters
    ----------
    data : Buffer
        the Plotly JSON payload

    Returns
    -------
    Figure | None
        the result of ``plotly.io.from_json``, or ``None`` if ``plotly``
        cannot be imported
    """
    try:
        import plotly
    except ImportError:
        return None
    else:
        return plotly.io.from_json(data)
def _deserialize_numpy_array(data: "Buffer") -> typing.Any | None:
mfile = BytesIO(data)
mfile.seek(0)
data = numpy.load(mfile, allow_pickle=False)
return data
def _deserialize_dataframe(data: "Buffer") -> typing.Optional["DataFrame"]:
    """Load a DataFrame from CSV bytes, using the first column as index.

    Parameters
    ----------
    data : Buffer
        the CSV payload

    Returns
    -------
    DataFrame | None
        the frame returned by ``pandas.read_csv``
    """
    # A freshly constructed BytesIO already starts at position 0.
    csv_buffer = BytesIO(data)
    return pandas.read_csv(csv_buffer, index_col=0)
@check_extra("torch")
def _deserialize_torch_tensor(data: "Buffer") -> typing.Optional["Tensor"]:
    """Load a PyTorch object from bytes produced by ``torch.save``.

    Parameters
    ----------
    data : Buffer
        the saved payload

    Returns
    -------
    Tensor | None
        the object returned by ``torch.load``, or ``None`` if ``torch``
        cannot be imported
    """
    try:
        import torch
    except ImportError:
        return None

    # NOTE(review): torch.load relies on pickle; confirm payloads reaching
    # this function are trusted.
    return torch.load(BytesIO(data))
def _deserialize_pickle(data) -> typing.Any | None:
data = pickle.loads(data)
return data
def _deserialize_json(data) -> typing.Any | None:
data = json.loads(data)
return data