|
1 | 1 | from .future import DSSFuture |
| 2 | +from .local_model import DSSLocalModel |
2 | 3 | import json |
3 | 4 | import warnings |
4 | 5 | import logging |
@@ -267,7 +268,35 @@ def set_definition(self, definition): |
267 | 268 | return self.client._perform_json( |
268 | 269 | "PUT", "/admin/connections/%s" % self.name, |
269 | 270 | body = definition) |
270 | | - |
| 271 | + |
| 272 | + |
| 273 | + def list_local_models(self): |
| 274 | + """ |
| 275 | + :returns: List local models defined in this connection. |
| 276 | + :rtype: list[:class:`~dataikuapi.dss.local_model.DSSLocalModel`] |
| 277 | + :raises Exception: If this connection is not of type HuggingFaceLocal. |
| 278 | + """ |
| 279 | + definition = self.get_definition() |
| 280 | + if definition.get("type") != "HuggingFaceLocal": |
| 281 | + raise Exception("Connection %s is not a HuggingFaceLocal connection" % self.name) |
| 282 | + params = definition.get("params") or {} |
| 283 | + models = params.get("models") or [] |
| 284 | + return [DSSLocalModel(self.client, self.name, model.get("id")) for model in models] |
| 285 | + |
| 286 | + def get_local_model(self, model_id): |
| 287 | + """ |
| 288 | + Get a handle on this local model. |
| 289 | +
|
| 290 | + :param str model_id: Identifier of the model. |
| 291 | + :rtype: :class:`~dataikuapi.dss.local_model.DSSLocalModel` |
| 292 | + """ |
| 293 | + if not model_id.strip(): |
| 294 | + raise ValueError("model_id must be a non-empty string") |
| 295 | + if not self.name.strip(): |
| 296 | + raise ValueError("connection_name must be a non-empty string") |
| 297 | + |
| 298 | + return DSSLocalModel(self.client, self.name, model_id) |
| 299 | + |
271 | 300 | ######################################################## |
272 | 301 | # Security |
273 | 302 | ######################################################## |
@@ -2173,6 +2202,7 @@ def add_container_runtime_addition(self, container_runtime_addition): |
2173 | 2202 | * PYTHON36_SUPPORT |
2174 | 2203 | * PYTHON37_SUPPORT |
2175 | 2204 | * PYTHON38_SUPPORT |
| 2205 | + * HUGGING_FACE_LOCAL_GPU |
2176 | 2206 |
|
2177 | 2207 | :param dict container_runtime_addition: a dict with the container runtime addition definition |
2178 | 2208 | """ |
@@ -3361,4 +3391,3 @@ def get_counter(self, id): |
3361 | 3391 | :rtype: dict |
3362 | 3392 | """ |
3363 | 3393 | return next((counter for counter in self.counters if counter["id"] == id), None) |
3364 | | - |
|
0 commit comments