diff --git a/docs/job_definition_parameters.md b/docs/job_definition_parameters.md index b0b5697f2..5e6063c7d 100644 --- a/docs/job_definition_parameters.md +++ b/docs/job_definition_parameters.md @@ -160,6 +160,46 @@ datapoints=["image1.jpg", "image2.jpg"], contexts=["A cat sitting on a red couch", "A blue car in the rain"] ``` +**Length limit:** A context may be at most 400 characters; the backend rejects longer ones. If a context exceeds the limit, a warning is logged at creation time. See `auto_shorten` below to have over-long contexts shortened automatically. + +--- + +### `auto_shorten` + +| Property | Value | +|----------|-------| +| **Type** | `bool` | +| **Required** | No | +| **Default** | `False` | + +When `True`, any context longer than the 400-character limit is automatically shortened — tuned to the `instruction` so only the part relevant to the question is kept — before upload. When `False` (the default), an over-long context is left unchanged and a warning is logged explaining the backend would reject it. + +```python +order = rapi.order.create_classification_order( + name="Outfit check", + instruction="Does the main character wear the right clothing?", + answer_options=["Yes", "No"], + datapoints=["scene.jpg"], + contexts=[""], + auto_shorten=True, +) +``` + +You can also shorten contexts directly via the client, without creating an order: + +```python +short = rapi.context.shorten_context( + context="", + question="Does the main character wear the right clothing?", +) + +# Or a batch of (context, question) pairs in one call: +shortened = rapi.context.shorten_contexts([ + (context_a, question_a), + (context_b, question_b), +]) +``` + --- ### `media_contexts` diff --git a/src/rapidata/__init__.py b/src/rapidata/__init__.py index 9fb862df2..1177eecac 100644 --- a/src/rapidata/__init__.py +++ b/src/rapidata/__init__.py @@ -61,6 +61,7 @@ DeviceFilter, DeviceType, Datapoint, + ContextManager, FailedUploadException, FailedUpload, rapidata_config, diff --git a/src/rapidata/rapidata_client/__init__.py b/src/rapidata/rapidata_client/__init__.py index e1d30ad7d..07b312c02 100644 --- a/src/rapidata/rapidata_client/__init__.py +++ b/src/rapidata/rapidata_client/__init__.py @@ -20,11 +20,13 @@ EffortSelection, ) from .datapoints import Datapoint +from .context import ContextManager from .datapoints.metadata import ( PrivateTextMetadata, PublicTextMetadata, SelectWordsMetadata, ) + # --- GENERATED SETTINGS IMPORTS START --- from .settings import ( RapidataSettings, @@ -48,6 +50,7 @@ CompareEquirectangularSetting, ClassifyEquirectangularSetting, ) + # --- GENERATED SETTINGS IMPORTS END --- from .filter import ( CountryFilter, diff --git a/src/rapidata/rapidata_client/context/__init__.py b/src/rapidata/rapidata_client/context/__init__.py new file mode 100644 index 000000000..bbc092a98 --- /dev/null +++ b/src/rapidata/rapidata_client/context/__init__.py @@ -0,0 +1,4 @@ +from .context_manager import ContextManager +from ._context_length import MAX_CONTEXT_LENGTH + +__all__ = ["ContextManager", "MAX_CONTEXT_LENGTH"] diff --git a/src/rapidata/rapidata_client/context/_context_length.py b/src/rapidata/rapidata_client/context/_context_length.py new file mode 100644 index 000000000..319b591e2 --- /dev/null +++ b/src/rapidata/rapidata_client/context/_context_length.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from rapidata.rapidata_client.config import logger + +if TYPE_CHECKING: + from rapidata.rapidata_client.datapoints._datapoint import Datapoint + from rapidata.rapidata_client.context.context_manager import ContextManager + +# Mirrors the backend's datapoint/group context validation +# (datasets-service CreateDatapointCommandValidator: `RuleFor(x => x.Context).MaximumLength(400)`). +# Keep in sync if the backend limit changes. +MAX_CONTEXT_LENGTH = 400 + + +def enforce_context_length( + datapoints: list[Datapoint], + question: str | None, + auto_shorten: bool, + context_manager: ContextManager, +) -> None: + """Check datapoint contexts against the backend's maximum length, in place. + + For every datapoint whose context exceeds :data:`MAX_CONTEXT_LENGTH`: + + - if ``auto_shorten`` is True and a ``question`` is available, the context + is shortened for that question (one batched request) and substituted; + - otherwise a warning is logged explaining the backend would reject it. + """ + over_limit = [ + (index, datapoint) + for index, datapoint in enumerate(datapoints) + if datapoint.context is not None and len(datapoint.context) > MAX_CONTEXT_LENGTH + ] + if not over_limit: + return + + if auto_shorten and not question: + # auto_shorten needs the question to tune the context; without it we + # can't shorten, so fall back to warning rather than silently proceed. + logger.warning( + "auto_shorten=True but no question/instruction was available to shorten " + "the context against; leaving %d over-long context(s) unchanged.", + len(over_limit), + ) + + if auto_shorten and question: + pairs = [ + (datapoint.context, question) + for _, datapoint in over_limit + if datapoint.context is not None + ] + shortened = context_manager.shorten_contexts(pairs) + for (index, datapoint), new_context in zip(over_limit, shortened): + if not new_context: + logger.warning( + "Datapoint %d: shorten-context returned an empty result; " + "keeping the original context.", + index, + ) + continue + assert datapoint.context is not None + logger.info( + "Datapoint %d: shortened context from %d to %d characters.", + index, + len(datapoint.context), + len(new_context), + ) + datapoint.context = new_context + return + + for index, datapoint in over_limit: + assert datapoint.context is not None + logger.warning( + "Datapoint %d has a context of %d characters, which exceeds the maximum " + "of %d and would be rejected by the backend. Shorten it, or pass " + "auto_shorten=True to shorten it automatically.", + index, + len(datapoint.context), + MAX_CONTEXT_LENGTH, + ) diff --git a/src/rapidata/rapidata_client/context/context_manager.py b/src/rapidata/rapidata_client/context/context_manager.py new file mode 100644 index 000000000..b4117a789 --- /dev/null +++ b/src/rapidata/rapidata_client/context/context_manager.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from typing import Sequence, TYPE_CHECKING + +from rapidata.rapidata_client.config import logger, tracer + +if TYPE_CHECKING: + from rapidata.service.openapi_service import OpenAPIService + + +class ContextManager: + """Shortens a datapoint's context for the specific question an annotator answers. + + A long, general context (e.g. a full scene description) is often far more + detail than a single question needs. This manager tunes a context down to + what is relevant for the question, which keeps it within the length the + backend accepts and focuses the annotator. Results are cached server-side. + """ + + def __init__(self, openapi_service: OpenAPIService): + self._openapi_service = openapi_service + logger.debug("ContextManager initialized") + + def shorten_context(self, context: str, question: str) -> str: + """Shorten a single context for the given question. + + Args: + context: The (potentially long) context to shorten. + question: The question the context will be shown alongside. The + context is tuned to what this question needs. + + Returns: + The shortened context. + """ + return self.shorten_contexts([(context, question)])[0] + + def shorten_contexts(self, pairs: Sequence[tuple[str, str]]) -> list[str]: + """Shorten a batch of ``(context, question)`` pairs in one request. + + Args: + pairs: The ``(context, question)`` pairs to shorten. + + Returns: + The shortened contexts, in the same order as ``pairs``. + """ + with tracer.start_as_current_span("ContextManager.shorten_contexts"): + return self._openapi_service.context.shorten_contexts(pairs) diff --git a/src/rapidata/rapidata_client/job/rapidata_job_manager.py b/src/rapidata/rapidata_client/job/rapidata_job_manager.py index 76dbf0b2e..84352d714 100644 --- a/src/rapidata/rapidata_client/job/rapidata_job_manager.py +++ b/src/rapidata/rapidata_client/job/rapidata_job_manager.py @@ -18,6 +18,8 @@ from rapidata.rapidata_client.datapoints._datapoints_validator import ( DatapointsValidator, ) +from rapidata.rapidata_client.context.context_manager import ContextManager +from rapidata.rapidata_client.context._context_length import enforce_context_length if TYPE_CHECKING: from rapidata.rapidata_client.job.rapidata_job import RapidataJob @@ -31,6 +33,7 @@ class RapidataJobManager: def __init__(self, openapi_service: OpenAPIService): self._openapi_service = openapi_service + self.__context_manager = ContextManager(openapi_service) self.__priority: int | None = None logger.debug("JobManager initialized") @@ -44,10 +47,18 @@ def _create_general_job_definition( confidence_threshold: float | None = None, quorum_threshold: int | None = None, settings: Sequence[RapidataSetting] | None = None, + auto_shorten: bool = False, ) -> RapidataJobDefinition: if settings is None: settings = [] + enforce_context_length( + datapoints=datapoints, + question=workflow._get_instruction(), + auto_shorten=auto_shorten, + context_manager=self.__context_manager, + ) + if confidence_threshold is not None and quorum_threshold is not None: raise ValueError( "Cannot set both confidence_threshold and quorum_threshold. Choose one stopping strategy." @@ -163,6 +174,7 @@ def create_classification_job_definition( quorum_threshold: int | None = None, settings: Sequence[RapidataSetting] | None = None, private_metadata: list[dict[str, str]] | None = None, + auto_shorten: bool = False, ) -> RapidataJobDefinition: """Create a classification job definition. @@ -192,6 +204,9 @@ def create_classification_job_definition( private_metadata (list[dict[str, str]], optional): Key-value string pairs for each datapoint. Defaults to None. If provided has to be the same length as datapoints.\n This will NOT be shown to the labelers but will be included in the result purely for your own reference. + auto_shorten (bool, optional): Defaults to False. If True, any context longer than the backend's + maximum length is automatically shortened for the instruction before upload. If False, an + over-long context is left unchanged and a warning is logged that the backend would reject it. """ with tracer.start_as_current_span("JobManager.create_classification_job"): if not isinstance(datapoints, list) or not all( @@ -218,6 +233,7 @@ def create_classification_job_definition( confidence_threshold=confidence_threshold, quorum_threshold=quorum_threshold, settings=settings, + auto_shorten=auto_shorten, ) def create_compare_job_definition( @@ -234,6 +250,7 @@ def create_compare_job_definition( quorum_threshold: int | None = None, settings: Sequence[RapidataSetting] | None = None, private_metadata: list[dict[str, str]] | None = None, + auto_shorten: bool = False, ) -> RapidataJobDefinition: """Create a compare job definition. @@ -271,6 +288,9 @@ def create_compare_job_definition( private_metadata (list[dict[str, str]], optional): Key-value string pairs for each datapoint. Defaults to None.\n If provided has to be the same length as datapoints.\n This will NOT be shown to the labelers but will be included in the result purely for your own reference. + auto_shorten (bool, optional): Defaults to False. If True, any context longer than the backend's + maximum length is automatically shortened for the instruction before upload. If False, an + over-long context is left unchanged and a warning is logged that the backend would reject it. """ with tracer.start_as_current_span("JobManager.create_compare_job"): if any(not isinstance(datapoint, list) for datapoint in datapoints): @@ -304,6 +324,7 @@ def create_compare_job_definition( confidence_threshold=confidence_threshold, quorum_threshold=quorum_threshold, settings=settings, + auto_shorten=auto_shorten, ) def _create_ranking_job_definition( @@ -318,6 +339,7 @@ def _create_ranking_job_definition( contexts: list[str] | None = None, media_contexts: list[list[str]] | None = None, settings: Sequence[RapidataSetting] | None = None, + auto_shorten: bool = False, ) -> RapidataJobDefinition: """ Create a ranking job definition. @@ -394,6 +416,7 @@ def _create_ranking_job_definition( datapoints=datapoints_instances, responses_per_datapoint=responses_per_comparison, settings=settings, + auto_shorten=auto_shorten, ) def _create_free_text_job_definition( @@ -407,6 +430,7 @@ def _create_free_text_job_definition( media_contexts: list[list[str]] | None = None, settings: Sequence[RapidataSetting] | None = None, private_metadata: list[dict[str, str]] | None = None, + auto_shorten: bool = False, ) -> RapidataJobDefinition: """Create a free text job definition. @@ -431,6 +455,9 @@ def _create_free_text_job_definition( private_metadata (list[dict[str, str]], optional): Key-value string pairs for each datapoint. Defaults to None.\n If provided has to be the same length as datapoints.\n This will NOT be shown to the labelers but will be included in the result purely for your own reference. + auto_shorten (bool, optional): Defaults to False. If True, any context longer than the backend's + maximum length is automatically shortened for the instruction before upload. If False, an + over-long context is left unchanged and a warning is logged that the backend would reject it. """ with tracer.start_as_current_span("JobManager.create_free_text_job"): from rapidata.rapidata_client.workflow import FreeTextWorkflow @@ -448,6 +475,7 @@ def _create_free_text_job_definition( datapoints=datapoints_instances, responses_per_datapoint=responses_per_datapoint, settings=settings, + auto_shorten=auto_shorten, ) def _create_select_words_job_definition( @@ -459,6 +487,7 @@ def _create_select_words_job_definition( responses_per_datapoint: int = 10, settings: Sequence[RapidataSetting] | None = None, private_metadata: list[dict[str, str]] | None = None, + auto_shorten: bool = False, ) -> RapidataJobDefinition: """Create a select words job definition. @@ -477,6 +506,9 @@ def _create_select_words_job_definition( private_metadata (list[dict[str, str]], optional): Key-value string pairs for each datapoint. Defaults to None.\n If provided has to be the same length as datapoints.\n This will NOT be shown to the labelers but will be included in the result purely for your own reference. + auto_shorten (bool, optional): Defaults to False. If True, any context longer than the backend's + maximum length is automatically shortened for the instruction before upload. If False, an + over-long context is left unchanged and a warning is logged that the backend would reject it. """ with tracer.start_as_current_span("JobManager.create_select_words_job"): from rapidata.rapidata_client.workflow import SelectWordsWorkflow @@ -494,6 +526,7 @@ def _create_select_words_job_definition( datapoints=datapoints_instances, responses_per_datapoint=responses_per_datapoint, settings=settings, + auto_shorten=auto_shorten, ) def create_locate_job_definition( @@ -506,6 +539,7 @@ def create_locate_job_definition( media_contexts: list[list[str]] | None = None, settings: Sequence[RapidataSetting] | None = None, private_metadata: list[dict[str, str]] | None = None, + auto_shorten: bool = False, ) -> RapidataJobDefinition: """Create a locate job definition. @@ -527,6 +561,9 @@ def create_locate_job_definition( private_metadata (list[dict[str, str]], optional): Key-value string pairs for each datapoint. Defaults to None.\n If provided has to be the same length as datapoints.\n This will NOT be shown to the labelers but will be included in the result purely for your own reference. + auto_shorten (bool, optional): Defaults to False. If True, any context longer than the backend's + maximum length is automatically shortened for the instruction before upload. If False, an + over-long context is left unchanged and a warning is logged that the backend would reject it. """ with tracer.start_as_current_span("JobManager.create_locate_job"): from rapidata.rapidata_client.workflow import LocateWorkflow @@ -543,6 +580,7 @@ def create_locate_job_definition( datapoints=datapoints_instances, responses_per_datapoint=responses_per_datapoint, settings=settings, + auto_shorten=auto_shorten, ) def _create_draw_job_definition( @@ -555,6 +593,7 @@ def _create_draw_job_definition( media_contexts: list[list[str]] | None = None, settings: Sequence[RapidataSetting] | None = None, private_metadata: list[dict[str, str]] | None = None, + auto_shorten: bool = False, ) -> RapidataJobDefinition: """Create a draw job definition. @@ -576,6 +615,9 @@ def _create_draw_job_definition( private_metadata (list[dict[str, str]], optional): Key-value string pairs for each datapoint. Defaults to None.\n If provided has to be the same length as datapoints.\n This will NOT be shown to the labelers but will be included in the result purely for your own reference. + auto_shorten (bool, optional): Defaults to False. If True, any context longer than the backend's + maximum length is automatically shortened for the instruction before upload. If False, an + over-long context is left unchanged and a warning is logged that the backend would reject it. """ with tracer.start_as_current_span("JobManager.create_draw_job"): from rapidata.rapidata_client.workflow import DrawWorkflow @@ -592,6 +634,7 @@ def _create_draw_job_definition( datapoints=datapoints_instances, responses_per_datapoint=responses_per_datapoint, settings=settings, + auto_shorten=auto_shorten, ) def _create_timestamp_job_definition( @@ -604,6 +647,7 @@ def _create_timestamp_job_definition( media_contexts: list[list[str]] | None = None, settings: Sequence[RapidataSetting] | None = None, private_metadata: list[dict[str, str]] | None = None, + auto_shorten: bool = False, ) -> RapidataJobDefinition: """Create a timestamp job definition. @@ -628,6 +672,9 @@ def _create_timestamp_job_definition( private_metadata (list[dict[str, str]], optional): Key-value string pairs for each datapoint. Defaults to None.\n If provided has to be the same length as datapoints.\n This will NOT be shown to the labelers but will be included in the result purely for your own reference. + auto_shorten (bool, optional): Defaults to False. If True, any context longer than the backend's + maximum length is automatically shortened for the instruction before upload. If False, an + over-long context is left unchanged and a warning is logged that the backend would reject it. """ with tracer.start_as_current_span("JobManager.create_timestamp_job"): from rapidata.rapidata_client.workflow import TimestampWorkflow @@ -644,6 +691,7 @@ def _create_timestamp_job_definition( datapoints=datapoints_instances, responses_per_datapoint=responses_per_datapoint, settings=settings, + auto_shorten=auto_shorten, ) def get_job_definition_by_id(self, job_definition_id: str) -> RapidataJobDefinition: diff --git a/src/rapidata/rapidata_client/order/rapidata_order_manager.py b/src/rapidata/rapidata_client/order/rapidata_order_manager.py index d5eb97652..e9884b790 100644 --- a/src/rapidata/rapidata_client/order/rapidata_order_manager.py +++ b/src/rapidata/rapidata_client/order/rapidata_order_manager.py @@ -15,6 +15,8 @@ from rapidata.rapidata_client.filter.rapidata_filters import RapidataFilters from rapidata.rapidata_client.settings import RapidataSetting, RapidataSettings from rapidata.rapidata_client.selection.rapidata_selections import RapidataSelections +from rapidata.rapidata_client.context.context_manager import ContextManager +from rapidata.rapidata_client.context._context_length import enforce_context_length from rapidata.service.openapi_service import OpenAPIService if TYPE_CHECKING: @@ -49,6 +51,7 @@ class RapidataOrderManager: def __init__(self, openapi_service: OpenAPIService): self.__openapi_service = openapi_service + self.__context_manager = ContextManager(openapi_service) self.filters = RapidataFilters self.settings = RapidataSettings self.selections = RapidataSelections @@ -78,8 +81,15 @@ def _create_general_order( filters: Sequence[RapidataFilter] | None = None, settings: Sequence[RapidataSetting] | None = None, selections: Sequence[RapidataSelection] | None = None, + auto_shorten: bool = False, ) -> RapidataOrder: self._warn_deprecated() + enforce_context_length( + datapoints=datapoints, + question=workflow._get_instruction(), + auto_shorten=auto_shorten, + context_manager=self.__context_manager, + ) if filters is None: filters = [] if settings is None: @@ -194,6 +204,7 @@ def create_classification_order( settings: Sequence[RapidataSetting] | None = None, selections: Sequence[RapidataSelection] | None = None, private_metadata: list[dict[str, str]] | None = None, + auto_shorten: bool = False, ) -> RapidataOrder: """Create a classification order. @@ -227,6 +238,9 @@ def create_classification_order( private_metadata (list[dict[str, str]], optional): Key-value string pairs for each datapoint. Defaults to None. If provided has to be the same length as datapoints.\n This will NOT be shown to the labelers but will be included in the result purely for your own reference. + auto_shorten (bool, optional): Defaults to False. If True, any context longer than the backend's + maximum length is automatically shortened for the instruction before upload. If False, an + over-long context is left unchanged and a warning is logged that the backend would reject it. Example: ```python @@ -275,6 +289,7 @@ def create_classification_order( filters=filters, selections=selections, settings=settings, + auto_shorten=auto_shorten, ) def create_compare_order( @@ -294,6 +309,7 @@ def create_compare_order( settings: Sequence[RapidataSetting] | None = None, selections: Sequence[RapidataSelection] | None = None, private_metadata: list[dict[str, str]] | None = None, + auto_shorten: bool = False, ) -> RapidataOrder: """Create a compare order. @@ -341,6 +357,9 @@ def create_compare_order( private_metadata (list[dict[str, str]], optional): Key-value string pairs for each datapoint. Defaults to None.\n If provided has to be the same length as datapoints.\n This will NOT be shown to the labelers but will be included in the result purely for your own reference. + auto_shorten (bool, optional): Defaults to False. If True, any context longer than the backend's + maximum length is automatically shortened for the instruction before upload. If False, an + over-long context is left unchanged and a warning is logged that the backend would reject it. Example: ```python @@ -397,6 +416,7 @@ def create_compare_order( filters=filters, selections=selections, settings=settings, + auto_shorten=auto_shorten, ) def create_ranking_order( @@ -414,6 +434,7 @@ def create_ranking_order( filters: Sequence[RapidataFilter] | None = None, settings: Sequence[RapidataSetting] | None = None, selections: Sequence[RapidataSelection] | None = None, + auto_shorten: bool = False, ) -> RapidataOrder: """ Create a ranking order. @@ -442,6 +463,9 @@ def create_ranking_order( filters (Sequence[RapidataFilter], optional): The list of filters for the ranking. Defaults to []. Decides who the tasks should be shown to. settings (Sequence[RapidataSetting], optional): The list of settings for the ranking. Defaults to []. Decides how the tasks should be shown. selections (Sequence[RapidataSelection], optional): The list of selections for the ranking. Defaults to []. Decides in what order the tasks should be shown. + auto_shorten (bool, optional): Defaults to False. If True, any context longer than the backend's + maximum length is automatically shortened for the instruction before upload. If False, an + over-long context is left unchanged and a warning is logged that the backend would reject it. Example: ```python @@ -514,6 +538,7 @@ def create_ranking_order( filters=filters, selections=selections, settings=settings, + auto_shorten=auto_shorten, ) def create_free_text_order( @@ -529,6 +554,7 @@ def create_free_text_order( settings: Sequence[RapidataSetting] | None = None, selections: Sequence[RapidataSelection] | None = None, private_metadata: list[dict[str, str]] | None = None, + auto_shorten: bool = False, ) -> RapidataOrder: """Create a free text order. @@ -555,6 +581,9 @@ def create_free_text_order( private_metadata (list[dict[str, str]], optional): Key-value string pairs for each datapoint. Defaults to None.\n If provided has to be the same length as datapoints.\n This will NOT be shown to the labelers but will be included in the result purely for your own reference. + auto_shorten (bool, optional): Defaults to False. If True, any context longer than the backend's + maximum length is automatically shortened for the instruction before upload. If False, an + over-long context is left unchanged and a warning is logged that the backend would reject it. Example: ```python @@ -592,6 +621,7 @@ def create_free_text_order( filters=filters, selections=selections, settings=settings, + auto_shorten=auto_shorten, ) def create_select_words_order( @@ -607,6 +637,7 @@ def create_select_words_order( settings: Sequence[RapidataSetting] | None = None, selections: Sequence[RapidataSelection] | None = None, private_metadata: list[dict[str, str]] | None = None, + auto_shorten: bool = False, ) -> RapidataOrder: """Create a select words order. @@ -632,6 +663,9 @@ def create_select_words_order( private_metadata (list[dict[str, str]], optional): Key-value string pairs for each datapoint. Defaults to None.\n If provided has to be the same length as datapoints.\n This will NOT be shown to the labelers but will be included in the result purely for your own reference. + auto_shorten (bool, optional): Defaults to False. If True, any context longer than the backend's + maximum length is automatically shortened for the instruction before upload. If False, an + over-long context is left unchanged and a warning is logged that the backend would reject it. Example: ```python @@ -671,6 +705,7 @@ def create_select_words_order( filters=filters, selections=selections, settings=settings, + auto_shorten=auto_shorten, ) def create_locate_order( @@ -686,6 +721,7 @@ def create_locate_order( settings: Sequence[RapidataSetting] | None = None, selections: Sequence[RapidataSelection] | None = None, private_metadata: list[dict[str, str]] | None = None, + auto_shorten: bool = False, ) -> RapidataOrder: """Create a locate order. @@ -711,6 +747,9 @@ def create_locate_order( private_metadata (list[dict[str, str]], optional): Key-value string pairs for each datapoint. Defaults to None.\n If provided has to be the same length as datapoints.\n This will NOT be shown to the labelers but will be included in the result purely for your own reference. + auto_shorten (bool, optional): Defaults to False. If True, any context longer than the backend's + maximum length is automatically shortened for the instruction before upload. If False, an + over-long context is left unchanged and a warning is logged that the backend would reject it. Example: ```python @@ -745,6 +784,7 @@ def create_locate_order( filters=filters, selections=selections, settings=settings, + auto_shorten=auto_shorten, ) def create_draw_order( @@ -760,6 +800,7 @@ def create_draw_order( settings: Sequence[RapidataSetting] | None = None, selections: Sequence[RapidataSelection] | None = None, private_metadata: list[dict[str, str]] | None = None, + auto_shorten: bool = False, ) -> RapidataOrder: """Create a draw order. @@ -785,6 +826,9 @@ def create_draw_order( private_metadata (list[dict[str, str]], optional): Key-value string pairs for each datapoint. Defaults to None.\n If provided has to be the same length as datapoints.\n This will NOT be shown to the labelers but will be included in the result purely for your own reference. + auto_shorten (bool, optional): Defaults to False. If True, any context longer than the backend's + maximum length is automatically shortened for the instruction before upload. If False, an + over-long context is left unchanged and a warning is logged that the backend would reject it. Example: ```python @@ -819,6 +863,7 @@ def create_draw_order( filters=filters, selections=selections, settings=settings, + auto_shorten=auto_shorten, ) def create_timestamp_order( @@ -834,6 +879,7 @@ def create_timestamp_order( settings: Sequence[RapidataSetting] | None = None, selections: Sequence[RapidataSelection] | None = None, private_metadata: list[dict[str, str]] | None = None, + auto_shorten: bool = False, ) -> RapidataOrder: """Create a timestamp order. @@ -862,6 +908,9 @@ def create_timestamp_order( private_metadata (list[dict[str, str]], optional): Key-value string pairs for each datapoint. Defaults to None.\n If provided has to be the same length as datapoints.\n This will NOT be shown to the labelers but will be included in the result purely for your own reference. + auto_shorten (bool, optional): Defaults to False. If True, any context longer than the backend's + maximum length is automatically shortened for the instruction before upload. If False, an + over-long context is left unchanged and a warning is logged that the backend would reject it. """ with tracer.start_as_current_span( @@ -884,6 +933,7 @@ def create_timestamp_order( filters=filters, selections=selections, settings=settings, + auto_shorten=auto_shorten, ) def get_order_by_id(self, order_id: str) -> RapidataOrder: diff --git a/src/rapidata/rapidata_client/rapidata_client.py b/src/rapidata/rapidata_client/rapidata_client.py index dec0b1440..e41937b18 100644 --- a/src/rapidata/rapidata_client/rapidata_client.py +++ b/src/rapidata/rapidata_client/rapidata_client.py @@ -25,6 +25,7 @@ ) from rapidata.rapidata_client.demographic.demographic_manager import DemographicManager +from rapidata.rapidata_client.context.context_manager import ContextManager from rapidata.rapidata_client.config import ( logger, @@ -107,6 +108,8 @@ def __init__( audience (RapidataAudienceManager): The RapidataAudienceManager instance. job (JobManager): The JobManager instance. mri (RapidataBenchmarkManager): The RapidataBenchmarkManager instance. + context (ContextManager): The ContextManager instance for shortening + datapoint contexts against a question. """ tracer.set_session_id( uuid.UUID(int=random.Random().getrandbits(128), version=4).hex @@ -172,6 +175,9 @@ def __init__( openapi_service=self._openapi_service ) + logger.debug("Initializing ContextManager") + self.context = ContextManager(openapi_service=self._openapi_service) + self._check_beta_features() # can't be in the trace for some reason def reset_credentials(self): diff --git a/src/rapidata/service/openapi_service.py b/src/rapidata/service/openapi_service.py index d62c04b0e..8c803b00d 100644 --- a/src/rapidata/service/openapi_service.py +++ b/src/rapidata/service/openapi_service.py @@ -23,6 +23,7 @@ from rapidata.service.services.leaderboard_service import LeaderboardService from rapidata.service.services.rapid_service import RapidService from rapidata.service.services.translation_service import TranslationService + from rapidata.service.services.context_service import ContextService class OpenAPIService: @@ -77,6 +78,7 @@ def __init__( self._leaderboard: LeaderboardService | None = None self._rapid: RapidService | None = None self._translation: TranslationService | None = None + self._context: ContextService | None = None if token: logger.debug("Using token for authentication") @@ -220,6 +222,13 @@ def translation(self) -> TranslationService: self._translation = TranslationService(self.api_client) return self._translation + @property + def context(self) -> ContextService: + if self._context is None: + from rapidata.service.services.context_service import ContextService + self._context = ContextService(self.api_client) + return self._context + def _get_rapidata_package_version(self): """ Returns the version of the currently installed rapidata package. diff --git a/src/rapidata/service/services/context_service.py b/src/rapidata/service/services/context_service.py new file mode 100644 index 000000000..d439b79b1 --- /dev/null +++ b/src/rapidata/service/services/context_service.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING, Sequence, cast + +if TYPE_CHECKING: + from rapidata.rapidata_client.api.rapidata_api_client import RapidataApiClient + + +# TODO: Replace this hand-written wrapper with the generated OpenAPI client once +# `POST /datasets/shorten-context` ships in datasets-service and the contract is +# published. At that point this service should call the generated +# ``DatasetsApi`` method (and the request/response shapes below become typed +# models) instead of crafting the request by hand. The path is kept as a single +# constant so the regeneration is a one-line removal. +_SHORTEN_CONTEXT_PATH = "/datasets/shorten-context" + + +class ContextService: + """Thin client for the datasets context-shortening endpoint. + + The endpoint takes a batch of ``(context, question)`` pairs and returns a + shortened context per item, tuned to the question the annotator answers. + Results are cached server-side, so re-sending the same pair is cheap. + """ + + def __init__(self, api_client: RapidataApiClient) -> None: + self._api_client = api_client + + def shorten_contexts(self, items: Sequence[tuple[str, str]]) -> list[str]: + """Shorten each ``(context, question)`` pair. + + Returns the shortened contexts in the same order as ``items``. + """ + if not items: + return [] + + url = f"{self._api_client.configuration.host}{_SHORTEN_CONTEXT_PATH}" + body = { + "items": [ + {"context": context, "question": question} + for context, question in items + ] + } + + response_data = self._api_client.call_api( + "POST", + url, + header_params={ + "Content-Type": "application/json", + "Accept": "application/json", + }, + body=body, + ) + response_data.read() + + # Reuse the generated deserializer so error responses are converted to + # RapidataError and tracing headers are honoured, exactly like a + # generated endpoint call. "object" yields a plain dict. + result = cast( + "dict[str, Any]", + self._api_client.response_deserialize( + response_data=response_data, + response_types_map={ + "200": "object", + "400": "object", + "401": None, + "403": None, + }, + ).data + or {}, + ) + + returned = result.get("items", []) + if len(returned) != len(items): + raise ValueError( + "shorten-context returned " + f"{len(returned)} item(s) for {len(items)} request item(s)." + ) + + return [item["shortenedContext"] for item in returned]