diff --git a/MIGRATION.md b/MIGRATION.md index 872def6593..ef8463a5c2 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -836,7 +836,7 @@ doc2 = Document(content="Berlin is the capital of Germany.", meta={"lang": "en", assert doc1.id == doc2.id ``` -It is possible to migrate an existing index without rerunning your indexing pipeline, for example to avoid recalculating embeddings. To do that, read stored documents, regenerate their IDs using Haystack 3.0, write the updated documents, and delete the documents stored under their old IDs. +It is possible to migrate an existing index without rerunning your indexing pipeline, for example to avoid recalculating embeddings. To do that, read stored documents, regenerate their IDs using Haystack 3.0, write the updated documents, and delete the documents stored under their old IDs. ```python from dataclasses import replace @@ -870,3 +870,26 @@ store.write_documents(new_documents, policy=DuplicatePolicy.OVERWRITE) new_ids = {doc.id for doc in new_documents} store.delete_documents([doc.id for doc in old_documents if doc.id not in new_ids]) ``` + +### Components now resolve API keys at warm-up + +**What changed:** Components that use external services now create their resources (such as API clients) during `warm_up()` instead of in `__init__`. As a consequence, a missing API key (for example, an unset environment variable behind a `Secret.from_env_var` default) is now reported at warm-up or first run rather than at construction. This affects OpenAI and Azure OpenAI components. + +**Why:** Creating resources in `warm_up()` / `warm_up_async()` and releasing them in `close()` / `close_async()` gives components and pipelines a single, predictable resource lifecycle. + +**How to migrate:** If you relied on construction failing for a missing API key, expect the same error at `warm_up()` (or the first `run`) instead. + +Before (v2.x), with `OPENAI_API_KEY` unset: +```python +from haystack.components.embedders import OpenAITextEmbedder + +embedder = OpenAITextEmbedder() # raised here +``` + +After (v3.0), with `OPENAI_API_KEY` unset: +```python +from haystack.components.embedders import OpenAITextEmbedder + +embedder = OpenAITextEmbedder() # no error at construction +embedder.warm_up() # raised here +``` diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index 8811ec93d9..96aa26ad34 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -487,7 +487,7 @@ def __init__( self.tool_concurrency_limit = tool_concurrency_limit self.tool_streaming_callback_passthrough = tool_streaming_callback_passthrough self._confirmation_strategies = confirmation_strategies or {} - self._is_warmed_up = False + self._tools_warmed_up = False # --- State schema --- # shallow copy is sufficient: we only add a top-level "messages" key, never mutate nested values @@ -574,16 +574,38 @@ def _register_prompt_variables(self) -> None: else: component.set_input_type(self, name=var_name, type=Any, default=None) - def warm_up(self) -> None: - """ - Warm up the Agent. - """ - if not self._is_warmed_up: - if hasattr(self.chat_generator, "warm_up"): - self.chat_generator.warm_up() + def _warm_up_tools(self) -> None: + """Warm up the configured tools once.""" + if not self._tools_warmed_up: if self.tools: warm_up_tools(self.tools) - self._is_warmed_up = True + self._tools_warmed_up = True + + def warm_up(self) -> None: + """Warm up the tools and the underlying chat generator.""" + self._warm_up_tools() + if hasattr(self.chat_generator, "warm_up"): + self.chat_generator.warm_up() + + async def warm_up_async(self) -> None: + """Warm up the tools and the underlying chat generator on the serving event loop.""" + self._warm_up_tools() + if hasattr(self.chat_generator, "warm_up_async"): + await self.chat_generator.warm_up_async() + elif hasattr(self.chat_generator, "warm_up"): + self.chat_generator.warm_up() + + def close(self) -> None: + """Release the underlying chat generator's resources.""" + if hasattr(self.chat_generator, "close"): + self.chat_generator.close() + + async def close_async(self) -> None: + """Release the underlying chat generator's async resources.""" + if hasattr(self.chat_generator, "close_async"): + await self.chat_generator.close_async() + elif hasattr(self.chat_generator, "close"): + self.chat_generator.close() def to_dict(self) -> dict[str, Any]: """ @@ -828,8 +850,7 @@ def run( - Any additional keys defined in the `state_schema`. """ agent_inputs = {"messages": messages, "streaming_callback": streaming_callback, **kwargs} - if not self._is_warmed_up: - self.warm_up() + self.warm_up() exe_context = self._initialize_fresh_execution( messages=messages, @@ -903,8 +924,7 @@ async def run_async( - Any additional keys defined in the `state_schema`. """ agent_inputs = {"messages": messages, "streaming_callback": streaming_callback, **kwargs} - if not self._is_warmed_up: - self.warm_up() + await self.warm_up_async() exe_context = self._initialize_fresh_execution( messages=messages, diff --git a/haystack/components/audio/whisper_remote.py b/haystack/components/audio/whisper_remote.py index af9e255e2e..f22279dfa9 100644 --- a/haystack/components/audio/whisper_remote.py +++ b/haystack/components/audio/whisper_remote.py @@ -98,18 +98,50 @@ def __init__( ) whisper_params["response_format"] = "json" self.whisper_params = whisper_params - self.client = OpenAI( - api_key=api_key.resolve_value(), - organization=organization, - base_url=api_base_url, - http_client=init_http_client(self.http_client_kwargs, async_client=False), - ) - self.async_client = AsyncOpenAI( - api_key=api_key.resolve_value(), - organization=organization, - base_url=api_base_url, - http_client=init_http_client(self.http_client_kwargs, async_client=True), - ) + + self.client: OpenAI | None = None + self.async_client: AsyncOpenAI | None = None + + def _client_kwargs(self) -> dict[str, Any]: + return { + "api_key": self.api_key.resolve_value(), + "organization": self.organization, + "base_url": self.api_base_url, + } + + def warm_up(self) -> None: + """ + Initializes the synchronous OpenAI client. + """ + if self.client is None: + self.client = OpenAI( + http_client=init_http_client(self.http_client_kwargs, async_client=False), **self._client_kwargs() + ) + + async def warm_up_async(self) -> None: # noqa: RUF029 + """ + Initializes the asynchronous OpenAI client on the serving event loop. + """ + if self.async_client is None: + self.async_client = AsyncOpenAI( + http_client=init_http_client(self.http_client_kwargs, async_client=True), **self._client_kwargs() + ) + + def close(self) -> None: + """ + Releases the synchronous OpenAI client. + """ + if self.client is not None: + self.client.close() + self.client = None + + async def close_async(self) -> None: + """ + Releases the asynchronous OpenAI client. + """ + if self.async_client is not None: + await self.async_client.close() + self.async_client = None def to_dict(self) -> dict[str, Any]: """ @@ -152,6 +184,7 @@ def run(self, sources: list[str | Path | ByteStream]) -> dict[str, Any]: - `documents`: A list of documents, one document for each file. The content of each document is the transcribed text. """ + self.warm_up() documents = [] for source in sources: @@ -163,6 +196,7 @@ def run(self, sources: list[str | Path | ByteStream]) -> dict[str, Any]: file = io.BytesIO(source.data) file.name = str(source.meta["file_path"]) if "file_path" in source.meta else "__fallback__.wav" + assert self.client is not None # mypy: client is built by warm_up above content = self.client.audio.transcriptions.create(file=file, model=self.model, **self.whisper_params) doc = Document(content=content.text, meta=source.meta) documents.append(doc) @@ -184,6 +218,7 @@ async def run_async(self, sources: list[str | Path | ByteStream]) -> dict[str, A - `documents`: A list of documents, one document for each file. The content of each document is the transcribed text. """ + await self.warm_up_async() documents = [] for source in sources: @@ -195,6 +230,7 @@ async def run_async(self, sources: list[str | Path | ByteStream]) -> dict[str, A file = io.BytesIO(source.data) file.name = str(source.meta["file_path"]) if "file_path" in source.meta else "__fallback__.wav" + assert self.async_client is not None # mypy: async_client is built by warm_up_async above content = await self.async_client.audio.transcriptions.create( file=file, model=self.model, **self.whisper_params ) diff --git a/haystack/components/embedders/azure_document_embedder.py b/haystack/components/embedders/azure_document_embedder.py index 490a36710b..83153e22b0 100644 --- a/haystack/components/embedders/azure_document_embedder.py +++ b/haystack/components/embedders/azure_document_embedder.py @@ -137,32 +137,67 @@ def __init__( # noqa: PLR0913 (too-many-arguments) self.progress_bar = progress_bar self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator - self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) - self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) + self.timeout = timeout + self.max_retries = max_retries self.default_headers = default_headers or {} self.azure_ad_token_provider = azure_ad_token_provider self.http_client_kwargs = http_client_kwargs self.raise_on_failure = raise_on_failure - client_args: dict[str, Any] = { - "api_version": api_version, - "azure_endpoint": azure_endpoint, - "azure_deployment": azure_deployment, - "azure_ad_token_provider": azure_ad_token_provider, - "api_key": api_key.resolve_value() if api_key is not None else None, - "azure_ad_token": azure_ad_token.resolve_value() if azure_ad_token is not None else None, - "organization": organization, - "timeout": self.timeout, - "max_retries": self.max_retries, + self.client: AzureOpenAI | None = None + self.async_client: AsyncAzureOpenAI | None = None + + def _client_kwargs(self) -> dict[str, Any]: + timeout = self.timeout if self.timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) + max_retries = ( + self.max_retries if self.max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) + ) + return { + "api_version": self.api_version, + "azure_endpoint": self.azure_endpoint, + "azure_deployment": self.azure_deployment, + "azure_ad_token_provider": self.azure_ad_token_provider, + "api_key": self.api_key.resolve_value() if self.api_key is not None else None, + "azure_ad_token": self.azure_ad_token.resolve_value() if self.azure_ad_token is not None else None, + "organization": self.organization, + "timeout": timeout, + "max_retries": max_retries, "default_headers": self.default_headers, } - self.client = AzureOpenAI( - http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_args - ) - self.async_client = AsyncAzureOpenAI( - http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_args - ) + def warm_up(self) -> None: + """ + Initializes the synchronous AzureOpenAI client. + """ + if self.client is None: + self.client = AzureOpenAI( + http_client=init_http_client(self.http_client_kwargs, async_client=False), **self._client_kwargs() + ) + + async def warm_up_async(self) -> None: # noqa: RUF029 + """ + Initializes the asynchronous AzureOpenAI client on the serving event loop. + """ + if self.async_client is None: + self.async_client = AsyncAzureOpenAI( + http_client=init_http_client(self.http_client_kwargs, async_client=True), **self._client_kwargs() + ) + + def close(self) -> None: + """ + Releases the synchronous AzureOpenAI client. + """ + if self.client is not None: + self.client.close() + self.client = None + + async def close_async(self) -> None: + """ + Releases the asynchronous AzureOpenAI client. + """ + if self.async_client is not None: + await self.async_client.close() + self.async_client = None def to_dict(self) -> dict[str, Any]: """ diff --git a/haystack/components/embedders/azure_text_embedder.py b/haystack/components/embedders/azure_text_embedder.py index a17efcb33c..9310b8f9ec 100644 --- a/haystack/components/embedders/azure_text_embedder.py +++ b/haystack/components/embedders/azure_text_embedder.py @@ -117,33 +117,68 @@ def __init__( # noqa: PLR0913 self.model = azure_deployment self.dimensions = dimensions self.organization = organization - self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) - self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) + self.timeout = timeout + self.max_retries = max_retries self.prefix = prefix self.suffix = suffix self.default_headers = default_headers or {} self.azure_ad_token_provider = azure_ad_token_provider self.http_client_kwargs = http_client_kwargs - client_kwargs: dict[str, Any] = { - "api_version": api_version, - "azure_endpoint": azure_endpoint, - "azure_deployment": azure_deployment, - "azure_ad_token_provider": azure_ad_token_provider, - "api_key": api_key.resolve_value() if api_key is not None else None, - "azure_ad_token": azure_ad_token.resolve_value() if azure_ad_token is not None else None, - "organization": organization, - "timeout": self.timeout, - "max_retries": self.max_retries, + self.client: AzureOpenAI | None = None + self.async_client: AsyncAzureOpenAI | None = None + + def _client_kwargs(self) -> dict[str, Any]: + timeout = self.timeout if self.timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) + max_retries = ( + self.max_retries if self.max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) + ) + return { + "api_version": self.api_version, + "azure_endpoint": self.azure_endpoint, + "azure_deployment": self.azure_deployment, + "azure_ad_token_provider": self.azure_ad_token_provider, + "api_key": self.api_key.resolve_value() if self.api_key is not None else None, + "azure_ad_token": self.azure_ad_token.resolve_value() if self.azure_ad_token is not None else None, + "organization": self.organization, + "timeout": timeout, + "max_retries": max_retries, "default_headers": self.default_headers, } - self.client = AzureOpenAI( - http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs - ) - self.async_client = AsyncAzureOpenAI( - http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs - ) + def warm_up(self) -> None: + """ + Initializes the synchronous Azure OpenAI client. + """ + if self.client is None: + self.client = AzureOpenAI( + http_client=init_http_client(self.http_client_kwargs, async_client=False), **self._client_kwargs() + ) + + async def warm_up_async(self) -> None: # noqa: RUF029 + """ + Initializes the asynchronous Azure OpenAI client on the serving event loop. + """ + if self.async_client is None: + self.async_client = AsyncAzureOpenAI( + http_client=init_http_client(self.http_client_kwargs, async_client=True), **self._client_kwargs() + ) + + def close(self) -> None: + """ + Releases the synchronous Azure OpenAI client. + """ + if self.client is not None: + self.client.close() + self.client = None + + async def close_async(self) -> None: + """ + Releases the asynchronous Azure OpenAI client. + """ + if self.async_client is not None: + await self.async_client.close() + self.async_client = None def to_dict(self) -> dict[str, Any]: """ diff --git a/haystack/components/embedders/openai_document_embedder.py b/haystack/components/embedders/openai_document_embedder.py index 1d3626c3c2..9005a7760d 100644 --- a/haystack/components/embedders/openai_document_embedder.py +++ b/haystack/components/embedders/openai_document_embedder.py @@ -122,23 +122,55 @@ def __init__( # noqa: PLR0913 (too-many-arguments) self.http_client_kwargs = http_client_kwargs self.raise_on_failure = raise_on_failure - if timeout is None: - timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0")) - if max_retries is None: - max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5")) - - client_kwargs: dict[str, Any] = { - "api_key": api_key.resolve_value(), - "organization": organization, - "base_url": api_base_url, + self.client: OpenAI | None = None + self.async_client: AsyncOpenAI | None = None + + def _client_kwargs(self) -> dict[str, Any]: + timeout = self.timeout if self.timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) + max_retries = ( + self.max_retries if self.max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) + ) + return { + "api_key": self.api_key.resolve_value(), + "organization": self.organization, + "base_url": self.api_base_url, "timeout": timeout, "max_retries": max_retries, } - self.client = OpenAI(http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs) - self.async_client = AsyncOpenAI( - http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs - ) + def warm_up(self) -> None: + """ + Initializes the synchronous OpenAI client. + """ + if self.client is None: + self.client = OpenAI( + http_client=init_http_client(self.http_client_kwargs, async_client=False), **self._client_kwargs() + ) + + async def warm_up_async(self) -> None: # noqa: RUF029 + """ + Initializes the asynchronous OpenAI client on the serving event loop. + """ + if self.async_client is None: + self.async_client = AsyncOpenAI( + http_client=init_http_client(self.http_client_kwargs, async_client=True), **self._client_kwargs() + ) + + def close(self) -> None: + """ + Releases the synchronous OpenAI client. + """ + if self.client is not None: + self.client.close() + self.client = None + + async def close_async(self) -> None: + """ + Releases the asynchronous OpenAI client. + """ + if self.async_client is not None: + await self.async_client.close() + self.async_client = None def _get_telemetry_data(self) -> dict[str, Any]: """ @@ -218,6 +250,8 @@ def _embed_batch( args["dimensions"] = self.dimensions try: + # this method is invoked after warm_up, so client is not None + assert self.client is not None response = self.client.embeddings.create(**args) except APIError as exc: ids = ", ".join(b[0] for b in batch) @@ -261,6 +295,8 @@ async def _embed_batch_async( args["dimensions"] = self.dimensions try: + # this method is invoked after warm_up_async, so async_client is not None + assert self.async_client is not None response = await self.async_client.embeddings.create(**args) except APIError as exc: ids = ", ".join(b[0] for b in batch) @@ -302,6 +338,8 @@ def run(self, documents: list[Document]) -> dict[str, Any]: "In case you want to embed a string, please use the OpenAITextEmbedder." ) + self.warm_up() + texts_to_embed = self._prepare_texts_to_embed(documents=documents) doc_ids_to_embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size) @@ -334,6 +372,8 @@ async def run_async(self, documents: list[Document]) -> dict[str, Any]: "In case you want to embed a string, please use the OpenAITextEmbedder." ) + await self.warm_up_async() + texts_to_embed = self._prepare_texts_to_embed(documents=documents) doc_ids_to_embeddings, meta = await self._embed_batch_async( diff --git a/haystack/components/embedders/openai_text_embedder.py b/haystack/components/embedders/openai_text_embedder.py index 24672bb050..7a2269012c 100644 --- a/haystack/components/embedders/openai_text_embedder.py +++ b/haystack/components/embedders/openai_text_embedder.py @@ -97,23 +97,55 @@ def __init__( self.max_retries = max_retries self.http_client_kwargs = http_client_kwargs - if timeout is None: - timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0")) - if max_retries is None: - max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5")) - - client_kwargs: dict[str, Any] = { - "api_key": api_key.resolve_value(), - "organization": organization, - "base_url": api_base_url, + self.client: OpenAI | None = None + self.async_client: AsyncOpenAI | None = None + + def _client_kwargs(self) -> dict[str, Any]: + timeout = self.timeout if self.timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) + max_retries = ( + self.max_retries if self.max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) + ) + return { + "api_key": self.api_key.resolve_value(), + "organization": self.organization, + "base_url": self.api_base_url, "timeout": timeout, "max_retries": max_retries, } - self.client = OpenAI(http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs) - self.async_client = AsyncOpenAI( - http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs - ) + def warm_up(self) -> None: + """ + Initializes the synchronous OpenAI client. + """ + if self.client is None: + self.client = OpenAI( + http_client=init_http_client(self.http_client_kwargs, async_client=False), **self._client_kwargs() + ) + + async def warm_up_async(self) -> None: # noqa: RUF029 + """ + Initializes the asynchronous OpenAI client on the serving event loop. + """ + if self.async_client is None: + self.async_client = AsyncOpenAI( + http_client=init_http_client(self.http_client_kwargs, async_client=True), **self._client_kwargs() + ) + + def close(self) -> None: + """ + Releases the synchronous OpenAI client. + """ + if self.client is not None: + self.client.close() + self.client = None + + async def close_async(self) -> None: + """ + Releases the asynchronous OpenAI client. + """ + if self.async_client is not None: + await self.async_client.close() + self.async_client = None def _get_telemetry_data(self) -> dict[str, Any]: """ @@ -184,7 +216,9 @@ def run(self, text: str) -> dict[str, Any]: - `embedding`: The embedding of the input text. - `meta`: Information about the usage of the model. """ + self.warm_up() create_kwargs = self._prepare_input(text=text) + assert self.client is not None # mypy: client is built by warm_up above response = self.client.embeddings.create(**create_kwargs) return self._prepare_output(result=response) @@ -204,6 +238,8 @@ async def run_async(self, text: str) -> dict[str, Any]: - `embedding`: The embedding of the input text. - `meta`: Information about the usage of the model. """ + await self.warm_up_async() create_kwargs = self._prepare_input(text=text) + assert self.async_client is not None # mypy: async_client is built by warm_up_async above response = await self.async_client.embeddings.create(**create_kwargs) return self._prepare_output(result=response) diff --git a/haystack/components/evaluators/llm_evaluator.py b/haystack/components/evaluators/llm_evaluator.py index 3a18e8058d..42e5c7b255 100644 --- a/haystack/components/evaluators/llm_evaluator.py +++ b/haystack/components/evaluators/llm_evaluator.py @@ -111,16 +111,37 @@ def __init__( generation_kwargs = {"response_format": {"type": "json_object"}, "seed": 42} self._chat_generator = OpenAIChatGenerator(generation_kwargs=generation_kwargs) - self._is_warmed_up = False - def warm_up(self) -> None: """ - Warm up the component by warming up the underlying chat generator. + Warm up the underlying chat generator. + """ + if hasattr(self._chat_generator, "warm_up"): + self._chat_generator.warm_up() + + async def warm_up_async(self) -> None: + """ + Warm up the underlying chat generator on the serving event loop. + """ + if hasattr(self._chat_generator, "warm_up_async"): + await self._chat_generator.warm_up_async() + elif hasattr(self._chat_generator, "warm_up"): + self._chat_generator.warm_up() + + def close(self) -> None: + """ + Release the underlying chat generator's resources. + """ + if hasattr(self._chat_generator, "close"): + self._chat_generator.close() + + async def close_async(self) -> None: + """ + Release the underlying chat generator's async resources. """ - if not self._is_warmed_up: - if hasattr(self._chat_generator, "warm_up"): - self._chat_generator.warm_up() - self._is_warmed_up = True + if hasattr(self._chat_generator, "close_async"): + await self._chat_generator.close_async() + elif hasattr(self._chat_generator, "close"): + self._chat_generator.close() @staticmethod def validate_init_parameters( @@ -195,8 +216,7 @@ def run(self, **inputs: Any) -> dict[str, Any]: Only in the case that `raise_on_failure` is set to True and the received inputs are not lists or have different lengths, or if the output is not a valid JSON or doesn't contain the expected keys. """ - if not self._is_warmed_up: - self.warm_up() + self.warm_up() self.validate_input_parameters(dict(self.inputs), inputs) @@ -263,8 +283,7 @@ async def run_async(self, **inputs: Any) -> dict[str, Any]: different lengths, or if the output is not a valid JSON or doesn't contain the expected keys. """ - if not self._is_warmed_up: - self.warm_up() + await self.warm_up_async() self.validate_input_parameters(dict(self.inputs), inputs) diff --git a/haystack/components/extractors/image/llm_document_content_extractor.py b/haystack/components/extractors/image/llm_document_content_extractor.py index f5f78c82e5..d1e09433c5 100644 --- a/haystack/components/extractors/image/llm_document_content_extractor.py +++ b/haystack/components/extractors/image/llm_document_content_extractor.py @@ -169,16 +169,38 @@ def __init__( self._document_to_image_content = DocumentToImageContent( file_path_meta_field=file_path_meta_field, root_path=root_path, detail=detail, size=size ) - self._is_warmed_up = False def warm_up(self) -> None: """ - Warm up the ChatGenerator if it has a warm_up method. + Warm up the underlying chat generator. """ - if not self._is_warmed_up: - if hasattr(self._chat_generator, "warm_up"): - self._chat_generator.warm_up() - self._is_warmed_up = True + if hasattr(self._chat_generator, "warm_up"): + self._chat_generator.warm_up() + + async def warm_up_async(self) -> None: + """ + Warm up the underlying chat generator on the serving event loop. + """ + if hasattr(self._chat_generator, "warm_up_async"): + await self._chat_generator.warm_up_async() + elif hasattr(self._chat_generator, "warm_up"): + self._chat_generator.warm_up() + + def close(self) -> None: + """ + Release the underlying chat generator's resources. + """ + if hasattr(self._chat_generator, "close"): + self._chat_generator.close() + + async def close_async(self) -> None: + """ + Release the underlying chat generator's async resources. + """ + if hasattr(self._chat_generator, "close_async"): + await self._chat_generator.close_async() + elif hasattr(self._chat_generator, "close"): + self._chat_generator.close() def to_dict(self) -> dict[str, Any]: """ @@ -340,8 +362,7 @@ def run(self, documents: list[Document]) -> dict[str, list[Document]]: if not documents: return {"documents": [], "failed_documents": []} - if not self._is_warmed_up: - self.warm_up() + self.warm_up() image_contents = self._document_to_image_content.run(documents=documents)["image_contents"] @@ -376,8 +397,7 @@ async def run_async(self, documents: list[Document]) -> dict[str, list[Document] if not documents: return {"documents": [], "failed_documents": []} - if not self._is_warmed_up: - self.warm_up() + await self.warm_up_async() image_contents = self._document_to_image_content.run(documents=documents)["image_contents"] diff --git a/haystack/components/extractors/llm_metadata_extractor.py b/haystack/components/extractors/llm_metadata_extractor.py index 64fd5c5ae8..abb04e7d8a 100644 --- a/haystack/components/extractors/llm_metadata_extractor.py +++ b/haystack/components/extractors/llm_metadata_extractor.py @@ -197,16 +197,42 @@ def __init__( self.expanded_range = expand_page_range(page_range) if page_range else None self.max_workers = max_workers self._chat_generator = chat_generator - self._is_warmed_up = False def warm_up(self) -> None: """ - Warm up the LLM provider component. + Warm up the underlying chat generator and splitter. """ - if not self._is_warmed_up: - if hasattr(self._chat_generator, "warm_up"): - self._chat_generator.warm_up() - self._is_warmed_up = True + for inner in (self._chat_generator, self.splitter): + if hasattr(inner, "warm_up"): + inner.warm_up() + + async def warm_up_async(self) -> None: + """ + Warm up the underlying chat generator and splitter on the serving event loop. + """ + for inner in (self._chat_generator, self.splitter): + if hasattr(inner, "warm_up_async"): + await inner.warm_up_async() + elif hasattr(inner, "warm_up"): + inner.warm_up() + + def close(self) -> None: + """ + Release the underlying chat generator's and splitter's resources. + """ + for inner in (self._chat_generator, self.splitter): + if hasattr(inner, "close"): + inner.close() + + async def close_async(self) -> None: + """ + Release the underlying chat generator's and splitter's async resources. + """ + for inner in (self._chat_generator, self.splitter): + if hasattr(inner, "close_async"): + await inner.close_async() + elif hasattr(inner, "close"): + inner.close() def to_dict(self) -> dict[str, Any]: """ @@ -376,8 +402,7 @@ def run(self, documents: list[Document], page_range: list[str | int] | None = No logger.warning("No documents provided. Skipping metadata extraction.") return {"documents": [], "failed_documents": []} - if not self._is_warmed_up: - self.warm_up() + self.warm_up() expanded_range = self.expanded_range if page_range: @@ -426,8 +451,7 @@ async def run_async(self, documents: list[Document], page_range: list[str | int] logger.warning("No documents provided. Skipping metadata extraction.") return {"documents": [], "failed_documents": []} - if not self._is_warmed_up: - self.warm_up() + await self.warm_up_async() expanded_range = self.expanded_range if page_range: diff --git a/haystack/components/fetchers/link_content.py b/haystack/components/fetchers/link_content.py index 4bd983e799..2a2959dd1d 100644 --- a/haystack/components/fetchers/link_content.py +++ b/haystack/components/fetchers/link_content.py @@ -149,26 +149,9 @@ def __init__( self.client_kwargs.setdefault("timeout", timeout) self.client_kwargs.setdefault("follow_redirects", True) - # Create httpx clients - client_kwargs = {**self.client_kwargs} - - # Optional HTTP/2 support - if http2: - try: - h2_import.check() - client_kwargs["http2"] = True - except ImportError: - logger.warning( - "HTTP/2 support requested but 'h2' package is not installed. " - "Falling back to HTTP/1.1. Install with `pip install httpx[http2]` to enable HTTP/2 support." - ) - self.http2 = False # Update the setting to match actual capability - - # Initialize synchronous client - self._client = httpx.Client(**client_kwargs) - - # Initialize asynchronous client - self._async_client = httpx.AsyncClient(**client_kwargs) + # httpx clients are built lazily in warm_up / warm_up_async (resource lifecycle) + self._client: httpx.Client | None = None + self._async_client: httpx.AsyncClient | None = None # register default content handlers that extract data from the response self.handlers: dict[str, Callable[[httpx.Response], ByteStream]] = defaultdict(lambda: _text_content_handler) @@ -189,40 +172,76 @@ def __init__( after=self._switch_user_agent, ) def get_response(url: str) -> httpx.Response: + assert self._client is not None # mypy: client is built by warm_up before run response = self._client.get(url, headers=self._get_headers()) response.raise_for_status() return response self._get_response: Callable = get_response + def _build_client_kwargs(self) -> dict[str, Any]: + """ + Build the keyword arguments used to construct the httpx clients. + + Resolves optional HTTP/2 support, downgrading to HTTP/1.1 if the 'h2' package is not installed. + """ + client_kwargs = {**self.client_kwargs} + + # Optional HTTP/2 support + if self.http2: + try: + h2_import.check() + client_kwargs["http2"] = True + except ImportError: + logger.warning( + "HTTP/2 support requested but 'h2' package is not installed. " + "Falling back to HTTP/1.1. Install with `pip install httpx[http2]` to enable HTTP/2 support." + ) + self.http2 = False # Update the setting to match actual capability + + return client_kwargs + + def warm_up(self) -> None: + """ + Initializes the synchronous httpx client. + """ + if self._client is None: + self._client = httpx.Client(**self._build_client_kwargs()) + + async def warm_up_async(self) -> None: # noqa: RUF029 + """ + Initializes the asynchronous httpx client on the serving event loop. + """ + if self._async_client is None: + self._async_client = httpx.AsyncClient(**self._build_client_kwargs()) + + def close(self) -> None: + """ + Releases the synchronous httpx client. + """ + if self._client is not None: + self._client.close() + self._client = None + + async def close_async(self) -> None: + """ + Releases the asynchronous httpx client. + """ + if self._async_client is not None: + await self._async_client.aclose() + self._async_client = None + def _get_headers(self) -> dict[str, str]: """ Build headers with precedence client defaults -> component defaults -> user-provided -> rotating UA """ - base = dict(self._client.headers) + base = dict(self._client.headers) if self._client is not None else {} return _merge_headers( base, REQUEST_HEADERS, self.request_headers, {"User-Agent": self.user_agents[self.current_user_agent_idx]} ) - def __del__(self) -> None: - """ - Clean up resources when the component is deleted. - - Closes both the synchronous and asynchronous HTTP clients to prevent - resource leaks. - """ - try: - # Close the synchronous client if it exists - if hasattr(self, "_client"): - self._client.close() - - # There is no way to close the async client without await - except Exception: - # Suppress any exceptions during cleanup - pass - @component.output_types(streams=list[ByteStream]) def run(self, urls: list[str]) -> dict[str, Any]: """ @@ -241,6 +260,8 @@ def run(self, urls: list[str]) -> dict[str, Any]: In all other scenarios, any retrieval errors are logged, and a list of successfully retrieved `ByteStream` objects is returned. """ + self.warm_up() + streams: list[ByteStream] = [] if not urls: return {"streams": streams} @@ -273,10 +294,13 @@ async def run_async(self, urls: list[str]) -> dict[str, Any]: :param urls: A list of URLs to fetch content from. :returns: `ByteStream` objects representing the extracted content. """ + await self.warm_up_async() + streams: list[ByteStream] = [] if not urls: return {"streams": streams} + assert self._async_client is not None # mypy: async_client is built by warm_up_async above # Create tasks for all URLs using _fetch_async directly tasks = [self._fetch_async(url, self._async_client) for url in urls] diff --git a/haystack/components/generators/chat/azure.py b/haystack/components/generators/chat/azure.py index 89450910ee..a77e93f942 100644 --- a/haystack/components/generators/chat/azure.py +++ b/haystack/components/generators/chat/azure.py @@ -201,14 +201,13 @@ def __init__( # None init parameters. This way we accommodate the use case where env var AZURE_OPENAI_ENDPOINT is set instead # of passing it as a parameter. azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT") - # `azure_endpoint` and `api_version` accept either a plain string or a `Secret`. We keep the original value - # on the instance for serialization and resolve it to a string only when building the client. + # `azure_endpoint` accepts either a plain string or a `Secret`. We keep the original value on the instance for + # serialization and resolve it to a string only to validate that an endpoint was provided. resolved_azure_endpoint = ( azure_endpoint.resolve_value() if isinstance(azure_endpoint, Secret) else azure_endpoint ) if not resolved_azure_endpoint: raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.") - resolved_api_version = api_version.resolve_value() if isinstance(api_version, Secret) else api_version if api_key is None and azure_ad_token is None: raise ValueError("Please provide an API key or an Azure Active Directory token.") @@ -224,8 +223,8 @@ def __init__( self.azure_deployment = azure_deployment self.organization = organization self.model = azure_deployment or "gpt-4.1-mini" - self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) - self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) + self.timeout = timeout + self.max_retries = max_retries self.default_headers = default_headers or {} self.azure_ad_token_provider = azure_ad_token_provider self.http_client_kwargs = http_client_kwargs @@ -233,37 +232,74 @@ def __init__( self.tools = tools self.tools_strict = tools_strict - client_args: dict[str, Any] = { + self.client: AzureOpenAI | None = None + self.async_client: AsyncAzureOpenAI | None = None + self._tools_warmed_up = False + + def _client_kwargs(self) -> dict[str, Any]: + timeout = self.timeout if self.timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) + max_retries = ( + self.max_retries if self.max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) + ) + resolved_azure_endpoint = ( + self.azure_endpoint.resolve_value() if isinstance(self.azure_endpoint, Secret) else self.azure_endpoint + ) + resolved_api_version = ( + self.api_version.resolve_value() if isinstance(self.api_version, Secret) else self.api_version + ) + return { "api_version": resolved_api_version, "azure_endpoint": resolved_azure_endpoint, - "azure_deployment": azure_deployment, - "api_key": api_key.resolve_value() if api_key is not None else None, - "azure_ad_token": azure_ad_token.resolve_value() if azure_ad_token is not None else None, - "organization": organization, - "timeout": self.timeout, - "max_retries": self.max_retries, + "azure_deployment": self.azure_deployment, + "api_key": self.api_key.resolve_value() if self.api_key is not None else None, + "azure_ad_token": self.azure_ad_token.resolve_value() if self.azure_ad_token is not None else None, + "organization": self.organization, + "timeout": timeout, + "max_retries": max_retries, "default_headers": self.default_headers, - "azure_ad_token_provider": azure_ad_token_provider, + "azure_ad_token_provider": self.azure_ad_token_provider, } - self.client = AzureOpenAI( - http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_args - ) - self.async_client = AsyncAzureOpenAI( - http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_args - ) - self._is_warmed_up = False + def _warm_up_tools(self) -> None: + if not self._tools_warmed_up: + warm_up_tools(self.tools) + self._tools_warmed_up = True def warm_up(self) -> None: """ - Warm up the Azure OpenAI chat generator. + Warm up the tools and initialize the synchronous Azure OpenAI client. + """ + self._warm_up_tools() + if self.client is None: + self.client = AzureOpenAI( + http_client=init_http_client(self.http_client_kwargs, async_client=False), **self._client_kwargs() + ) - This will warm up the tools registered in the chat generator. - This method is idempotent and will only warm up the tools once. + async def warm_up_async(self) -> None: # noqa: RUF029 """ - if not self._is_warmed_up: - warm_up_tools(self.tools) - self._is_warmed_up = True + Warm up the tools and initialize the asynchronous Azure OpenAI client on the serving event loop. + """ + self._warm_up_tools() + if self.async_client is None: + self.async_client = AsyncAzureOpenAI( + http_client=init_http_client(self.http_client_kwargs, async_client=True), **self._client_kwargs() + ) + + def close(self) -> None: + """ + Releases the synchronous Azure OpenAI client. + """ + if self.client is not None: + self.client.close() + self.client = None + + async def close_async(self) -> None: + """ + Releases the asynchronous Azure OpenAI client. + """ + if self.async_client is not None: + await self.async_client.close() + self.async_client = None def to_dict(self) -> dict[str, Any]: """ diff --git a/haystack/components/generators/chat/fallback.py b/haystack/components/generators/chat/fallback.py index 29a8b81844..c2083dd9a7 100644 --- a/haystack/components/generators/chat/fallback.py +++ b/haystack/components/generators/chat/fallback.py @@ -59,7 +59,6 @@ def __init__(self, chat_generators: list[ChatGenerator]) -> None: raise ValueError(msg) self.chat_generators = list(chat_generators) - self._is_warmed_up = False def to_dict(self) -> dict[str, Any]: """Serialize the component, including nested chat generators when they support serialization.""" @@ -84,19 +83,32 @@ def from_dict(cls, data: dict[str, Any]) -> FallbackChatGenerator: return default_from_dict(cls, data) def warm_up(self) -> None: - """ - Warm up all underlying chat generators. - - This method calls warm_up() on each underlying generator that supports it. - """ - if self._is_warmed_up: - return + """Warm up all underlying chat generators.""" + for gen in self.chat_generators: + if hasattr(gen, "warm_up"): + gen.warm_up() + async def warm_up_async(self) -> None: + """Warm up all underlying chat generators on the serving event loop.""" for gen in self.chat_generators: - if hasattr(gen, "warm_up") and callable(gen.warm_up): + if hasattr(gen, "warm_up_async"): + await gen.warm_up_async() + elif hasattr(gen, "warm_up"): gen.warm_up() - self._is_warmed_up = True + def close(self) -> None: + """Release the underlying chat generators' resources.""" + for gen in self.chat_generators: + if hasattr(gen, "close"): + gen.close() + + async def close_async(self) -> None: + """Release the underlying chat generators' async resources.""" + for gen in self.chat_generators: + if hasattr(gen, "close_async"): + await gen.close_async() + elif hasattr(gen, "close"): + gen.close() def _run_single_sync( self, @@ -147,8 +159,7 @@ def run( total_attempts, failed_chat_generators, plus any metadata from the successful generator. :raises RuntimeError: If all chat generators fail. """ - if not self._is_warmed_up: - self.warm_up() + self.warm_up() messages = _normalize_messages(messages) @@ -205,8 +216,7 @@ async def run_async( total_attempts, failed_chat_generators, plus any metadata from the successful generator. :raises RuntimeError: If all chat generators fail. """ - if not self._is_warmed_up: - self.warm_up() + await self.warm_up_async() messages = _normalize_messages(messages) diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index e2bb5be2c0..3547f672d0 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -209,35 +209,63 @@ def __init__( # Check for duplicate tool names _check_duplicate_tool_names(flatten_tools_or_toolsets(self.tools)) - if timeout is None: - timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0")) - if max_retries is None: - max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5")) - - client_kwargs: dict[str, Any] = { - "api_key": api_key.resolve_value(), - "organization": organization, - "base_url": api_base_url, + self.client: OpenAI | None = None + self.async_client: AsyncOpenAI | None = None + self._tools_warmed_up = False + + def _client_kwargs(self) -> dict[str, Any]: + timeout = self.timeout if self.timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) + max_retries = ( + self.max_retries if self.max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) + ) + return { + "api_key": self.api_key.resolve_value(), + "organization": self.organization, + "base_url": self.api_base_url, "timeout": timeout, "max_retries": max_retries, } - self.client = OpenAI(http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs) - self.async_client = AsyncOpenAI( - http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs - ) - self._is_warmed_up = False + def _warm_up_tools(self) -> None: + if not self._tools_warmed_up: + warm_up_tools(self.tools) + self._tools_warmed_up = True def warm_up(self) -> None: """ - Warm up the OpenAI chat generator. + Warm up the tools and initialize the synchronous OpenAI client. + """ + self._warm_up_tools() + if self.client is None: + self.client = OpenAI( + http_client=init_http_client(self.http_client_kwargs, async_client=False), **self._client_kwargs() + ) - This will warm up the tools registered in the chat generator. - This method is idempotent and will only warm up the tools once. + async def warm_up_async(self) -> None: # noqa: RUF029 """ - if not self._is_warmed_up: - warm_up_tools(self.tools) - self._is_warmed_up = True + Warm up the tools and initialize the asynchronous OpenAI client on the serving event loop. + """ + self._warm_up_tools() + if self.async_client is None: + self.async_client = AsyncOpenAI( + http_client=init_http_client(self.http_client_kwargs, async_client=True), **self._client_kwargs() + ) + + def close(self) -> None: + """ + Releases the synchronous OpenAI client. + """ + if self.client is not None: + self.client.close() + self.client = None + + async def close_async(self) -> None: + """ + Releases the asynchronous OpenAI client. + """ + if self.async_client is not None: + await self.async_client.close() + self.async_client = None def _get_telemetry_data(self) -> dict[str, Any]: """ @@ -335,8 +363,7 @@ def run( A dictionary with the following key: - `replies`: A list containing the generated responses as ChatMessage instances. """ - if not self._is_warmed_up: - self.warm_up() + self.warm_up() messages = _normalize_messages(messages) @@ -356,6 +383,7 @@ def run( tools_strict=tools_strict, ) openai_endpoint = api_args.pop("openai_endpoint") + assert self.client is not None # mypy: client is built by warm_up above openai_endpoint_method = getattr(self.client.chat.completions, openai_endpoint) chat_completion = openai_endpoint_method(**api_args) @@ -416,8 +444,7 @@ async def run_async( A dictionary with the following key: - `replies`: A list containing the generated responses as ChatMessage instances. """ - if not self._is_warmed_up: - self.warm_up() + await self.warm_up_async() messages = _normalize_messages(messages) @@ -439,6 +466,7 @@ async def run_async( ) openai_endpoint = api_args.pop("openai_endpoint") + assert self.async_client is not None # mypy: async_client is built by warm_up_async above openai_endpoint_method = getattr(self.async_client.chat.completions, openai_endpoint) chat_completion = await openai_endpoint_method(**api_args) diff --git a/haystack/components/generators/chat/openai_responses.py b/haystack/components/generators/chat/openai_responses.py index cb523d6940..492e3c4608 100644 --- a/haystack/components/generators/chat/openai_responses.py +++ b/haystack/components/generators/chat/openai_responses.py @@ -195,40 +195,68 @@ def __init__( self.tools_strict = tools_strict self.http_client_kwargs = http_client_kwargs - if timeout is None: - timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0")) - if max_retries is None: - max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5")) - - resolved_api_key = api_key.resolve_value() if isinstance(api_key, Secret) else api_key - client_kwargs: dict[str, Any] = { + self.client: OpenAI | None = None + self.async_client: AsyncOpenAI | None = None + self._tools_warmed_up = False + + def _client_kwargs(self) -> dict[str, Any]: + timeout = self.timeout if self.timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) + max_retries = ( + self.max_retries if self.max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) + ) + resolved_api_key = self.api_key.resolve_value() if isinstance(self.api_key, Secret) else self.api_key + return { "api_key": resolved_api_key, - "organization": organization, - "base_url": api_base_url, + "organization": self.organization, + "base_url": self.api_base_url, "timeout": timeout, "max_retries": max_retries, } - self.client = OpenAI(http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs) - self.async_client = AsyncOpenAI( - http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs - ) - self._is_warmed_up = False - - def warm_up(self) -> None: - """ - Warm up the OpenAI responses chat generator. - - This will warm up the tools registered in the chat generator. - This method is idempotent and will only warm up the tools once. - """ - if not self._is_warmed_up: + def _warm_up_tools(self) -> None: + if not self._tools_warmed_up: is_openai_tool = isinstance(self.tools, list) and isinstance(self.tools[0], dict) # We only warm up Haystack tools, not OpenAI/MCP tools # The type ignore is needed because mypy cannot infer the type correctly if not is_openai_tool: warm_up_tools(self.tools) # type: ignore[arg-type] - self._is_warmed_up = True + self._tools_warmed_up = True + + def warm_up(self) -> None: + """ + Warm up the tools and initialize the synchronous OpenAI client. + """ + self._warm_up_tools() + if self.client is None: + self.client = OpenAI( + http_client=init_http_client(self.http_client_kwargs, async_client=False), **self._client_kwargs() + ) + + async def warm_up_async(self) -> None: # noqa: RUF029 + """ + Warm up the tools and initialize the asynchronous OpenAI client on the serving event loop. + """ + self._warm_up_tools() + if self.async_client is None: + self.async_client = AsyncOpenAI( + http_client=init_http_client(self.http_client_kwargs, async_client=True), **self._client_kwargs() + ) + + def close(self) -> None: + """ + Releases the synchronous OpenAI client. + """ + if self.client is not None: + self.client.close() + self.client = None + + async def close_async(self) -> None: + """ + Releases the asynchronous OpenAI client. + """ + if self.async_client is not None: + await self.async_client.close() + self.async_client = None def _get_telemetry_data(self) -> dict[str, Any]: """ @@ -349,8 +377,7 @@ def run( A dictionary with the following key: - `replies`: A list containing the generated responses as ChatMessage instances. """ - if not self._is_warmed_up: - self.warm_up() + self.warm_up() messages = _normalize_messages(messages) @@ -370,6 +397,7 @@ def run( tools_strict=tools_strict, ) openai_endpoint = api_args.pop("openai_endpoint") + assert self.client is not None # mypy: client is built by warm_up above openai_endpoint_method = getattr(self.client.responses, openai_endpoint) responses = openai_endpoint_method(**api_args) @@ -423,8 +451,7 @@ async def run_async( A dictionary with the following key: - `replies`: A list containing the generated responses as ChatMessage instances. """ - if not self._is_warmed_up: - self.warm_up() + await self.warm_up_async() messages = _normalize_messages(messages) @@ -446,6 +473,7 @@ async def run_async( ) openai_endpoint = api_args.pop("openai_endpoint") + assert self.async_client is not None # mypy: async_client is built by warm_up_async above openai_endpoint_method = getattr(self.async_client.responses, openai_endpoint) responses = await openai_endpoint_method(**api_args) diff --git a/haystack/components/generators/openai_image_generator.py b/haystack/components/generators/openai_image_generator.py index f65676e0b9..6605f99152 100644 --- a/haystack/components/generators/openai_image_generator.py +++ b/haystack/components/generators/openai_image_generator.py @@ -80,36 +80,60 @@ def __init__( self.api_base_url = api_base_url self.organization = organization - self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) - self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) + self.timeout = timeout + self.max_retries = max_retries self.http_client_kwargs = http_client_kwargs self.client: OpenAI | None = None self.async_client: AsyncOpenAI | None = None + def _client_kwargs(self) -> dict[str, Any]: + timeout = self.timeout if self.timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0")) + max_retries = ( + self.max_retries if self.max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5")) + ) + return { + "api_key": self.api_key.resolve_value(), + "organization": self.organization, + "base_url": self.api_base_url, + "timeout": timeout, + "max_retries": max_retries, + } + def warm_up(self) -> None: """ - Warm up the OpenAI client. + Initializes the synchronous OpenAI client. """ if self.client is None: self.client = OpenAI( - api_key=self.api_key.resolve_value(), - organization=self.organization, - base_url=self.api_base_url, - timeout=self.timeout, - max_retries=self.max_retries, - http_client=init_http_client(self.http_client_kwargs, async_client=False), + http_client=init_http_client(self.http_client_kwargs, async_client=False), **self._client_kwargs() ) + + async def warm_up_async(self) -> None: # noqa: RUF029 + """ + Initializes the asynchronous OpenAI client on the serving event loop. + """ if self.async_client is None: self.async_client = AsyncOpenAI( - api_key=self.api_key.resolve_value(), - organization=self.organization, - base_url=self.api_base_url, - timeout=self.timeout, - max_retries=self.max_retries, - http_client=init_http_client(self.http_client_kwargs, async_client=True), + http_client=init_http_client(self.http_client_kwargs, async_client=True), **self._client_kwargs() ) + def close(self) -> None: + """ + Releases the synchronous OpenAI client. + """ + if self.client is not None: + self.client.close() + self.client = None + + async def close_async(self) -> None: + """ + Releases the asynchronous OpenAI client. + """ + if self.async_client is not None: + await self.async_client.close() + self.async_client = None + @component.output_types(images=list[str], revised_prompt=str) def run( self, @@ -131,8 +155,7 @@ def run( The revised prompt is the prompt that was used to generate the image, if there was any revision to the prompt made by OpenAI. """ - if self.client is None: - self.warm_up() + self.warm_up() # at this point the client is initialized, but mypy doesn't know that assert self.client is not None @@ -173,8 +196,7 @@ async def run_async( The revised prompt is the prompt that was used to generate the image, if there was any revision to the prompt made by OpenAI. """ - if self.async_client is None: - self.warm_up() + await self.warm_up_async() # at this point the client is initialized, but mypy doesn't know that assert self.async_client is not None diff --git a/haystack/components/preprocessors/embedding_based_document_splitter.py b/haystack/components/preprocessors/embedding_based_document_splitter.py index 31d93a619c..4d34298785 100644 --- a/haystack/components/preprocessors/embedding_based_document_splitter.py +++ b/haystack/components/preprocessors/embedding_based_document_splitter.py @@ -118,21 +118,55 @@ def __init__( self.use_split_rules = use_split_rules self.extend_abbreviations = extend_abbreviations self.sentence_splitter: SentenceSplitter | None = None - self._is_warmed_up = False def warm_up(self) -> None: """ - Warm up the component by initializing the sentence splitter. + Warm up the component by initializing the sentence splitter and the document embedder. """ - self.sentence_splitter = SentenceSplitter( - language=self.language, - use_split_rules=self.use_split_rules, - extend_abbreviations=self.extend_abbreviations, - keep_white_spaces=True, - ) + if self.sentence_splitter is None: + self.sentence_splitter = SentenceSplitter( + language=self.language, + use_split_rules=self.use_split_rules, + extend_abbreviations=self.extend_abbreviations, + keep_white_spaces=True, + ) if hasattr(self.document_embedder, "warm_up"): self.document_embedder.warm_up() - self._is_warmed_up = True + + async def warm_up_async(self) -> None: + """ + Warm up the component on the serving event loop. + + Initializes the sentence splitter and warms up the document embedder using its async warm-up path when + available, falling back to the synchronous one otherwise. + """ + if self.sentence_splitter is None: + self.sentence_splitter = SentenceSplitter( + language=self.language, + use_split_rules=self.use_split_rules, + extend_abbreviations=self.extend_abbreviations, + keep_white_spaces=True, + ) + if hasattr(self.document_embedder, "warm_up_async"): + await self.document_embedder.warm_up_async() + elif hasattr(self.document_embedder, "warm_up"): + self.document_embedder.warm_up() + + def close(self) -> None: + """ + Release the document embedder's resources. + """ + if hasattr(self.document_embedder, "close"): + self.document_embedder.close() + + async def close_async(self) -> None: + """ + Release the document embedder's async resources. + """ + if hasattr(self.document_embedder, "close_async"): + await self.document_embedder.close_async() + elif hasattr(self.document_embedder, "close"): + self.document_embedder.close() @component.output_types(documents=list[Document]) def run(self, documents: list[Document]) -> dict[str, list[Document]]: @@ -151,8 +185,7 @@ def run(self, documents: list[Document]) -> dict[str, list[Document]]: :raises TypeError: If the input is not a list of Documents. :raises ValueError: If the document content is None or empty. """ - if not self._is_warmed_up: - self.warm_up() + self.warm_up() if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): raise TypeError("EmbeddingBasedDocumentSplitter expects a List of Documents as input.") @@ -192,8 +225,7 @@ async def run_async(self, documents: list[Document]) -> dict[str, list[Document] :raises TypeError: If the input is not a list of Documents. :raises ValueError: If the document content is None or empty. """ - if not self._is_warmed_up: - self.warm_up() + await self.warm_up_async() if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): raise TypeError("EmbeddingBasedDocumentSplitter expects a List of Documents as input.") diff --git a/haystack/components/preprocessors/recursive_splitter.py b/haystack/components/preprocessors/recursive_splitter.py index ca7da62db1..da22017e74 100644 --- a/haystack/components/preprocessors/recursive_splitter.py +++ b/haystack/components/preprocessors/recursive_splitter.py @@ -99,6 +99,8 @@ def warm_up(self) -> None: """ Warm up the sentence tokenizer and tiktoken tokenizer if needed. """ + if self._is_warmed_up: + return if "sentence" in self.separators: self.nltk_tokenizer = self._get_custom_sentence_tokenizer(self.sentence_splitter_params) if self.split_units == "token": diff --git a/haystack/components/query/query_expander.py b/haystack/components/query/query_expander.py index ab6a420da0..c8285c074f 100644 --- a/haystack/components/query/query_expander.py +++ b/haystack/components/query/query_expander.py @@ -134,7 +134,6 @@ def __init__( else: self.chat_generator = chat_generator - self._is_warmed_up = False self.prompt_template = prompt_template or DEFAULT_PROMPT_TEMPLATE # Check if required variables are present in the template @@ -196,8 +195,7 @@ def run(self, query: str, n_expansions: int | None = None) -> dict[str, list[str :raises ValueError: If n_expansions is not positive (less than or equal to 0). """ - if not self._is_warmed_up: - self.warm_up() + self.warm_up() response = {"queries": [query] if self.include_original_query else []} @@ -264,8 +262,7 @@ async def run_async(self, query: str, n_expansions: int | None = None) -> dict[s :raises ValueError: If n_expansions is not positive (less than or equal to 0). """ - if not self._is_warmed_up: - self.warm_up() + await self.warm_up_async() response = {"queries": [query] if self.include_original_query else []} @@ -316,12 +313,35 @@ async def run_async(self, query: str, n_expansions: int | None = None) -> dict[s def warm_up(self) -> None: """ - Warm up the LLM provider component. + Warm up the underlying chat generator. """ - if not self._is_warmed_up: - if hasattr(self.chat_generator, "warm_up"): - self.chat_generator.warm_up() - self._is_warmed_up = True + if hasattr(self.chat_generator, "warm_up"): + self.chat_generator.warm_up() + + async def warm_up_async(self) -> None: + """ + Warm up the underlying chat generator on the serving event loop. + """ + if hasattr(self.chat_generator, "warm_up_async"): + await self.chat_generator.warm_up_async() + elif hasattr(self.chat_generator, "warm_up"): + self.chat_generator.warm_up() + + def close(self) -> None: + """ + Release the underlying chat generator's resources. + """ + if hasattr(self.chat_generator, "close"): + self.chat_generator.close() + + async def close_async(self) -> None: + """ + Release the underlying chat generator's async resources. + """ + if hasattr(self.chat_generator, "close_async"): + await self.chat_generator.close_async() + elif hasattr(self.chat_generator, "close"): + self.chat_generator.close() @staticmethod def _parse_expanded_queries(generator_response: str) -> list[str]: diff --git a/haystack/components/rankers/llm_ranker.py b/haystack/components/rankers/llm_ranker.py index 03ca7ae981..8f1c85a8fb 100644 --- a/haystack/components/rankers/llm_ranker.py +++ b/haystack/components/rankers/llm_ranker.py @@ -171,16 +171,30 @@ def __init__( self._chat_generator = _default_openai_chat_generator() else: self._chat_generator = chat_generator - self._is_warmed_up = False def warm_up(self) -> None: - """ - Warm up the underlying chat generator. - """ - if not self._is_warmed_up: - if hasattr(self._chat_generator, "warm_up"): - self._chat_generator.warm_up() - self._is_warmed_up = True + """Warm up the underlying chat generator.""" + if hasattr(self._chat_generator, "warm_up"): + self._chat_generator.warm_up() + + async def warm_up_async(self) -> None: + """Warm up the underlying chat generator on the serving event loop.""" + if hasattr(self._chat_generator, "warm_up_async"): + await self._chat_generator.warm_up_async() + elif hasattr(self._chat_generator, "warm_up"): + self._chat_generator.warm_up() + + def close(self) -> None: + """Release the underlying chat generator's resources.""" + if hasattr(self._chat_generator, "close"): + self._chat_generator.close() + + async def close_async(self) -> None: + """Release the underlying chat generator's async resources.""" + if hasattr(self._chat_generator, "close_async"): + await self._chat_generator.close_async() + elif hasattr(self._chat_generator, "close"): + self._chat_generator.close() def to_dict(self) -> dict[str, Any]: """ @@ -242,8 +256,7 @@ def run(self, query: str, documents: list[Document], top_k: int | None = None) - logger.warning("Empty query provided to LLMRanker. Returning documents without reranking.") return {"documents": fallback_documents} - if not self._is_warmed_up: - self.warm_up() + self.warm_up() prompt = self._prompt_builder.run(query=query.strip(), documents=deduplicated_documents) @@ -307,8 +320,7 @@ async def run_async( logger.warning("Empty query provided to LLMRanker. Returning documents without reranking.") return {"documents": fallback_documents} - if not self._is_warmed_up: - self.warm_up() + await self.warm_up_async() prompt = self._prompt_builder.run(query=query.strip(), documents=deduplicated_documents) diff --git a/haystack/components/retrievers/multi_query_embedding_retriever.py b/haystack/components/retrievers/multi_query_embedding_retriever.py index 4075e943d7..7c84157324 100644 --- a/haystack/components/retrievers/multi_query_embedding_retriever.py +++ b/haystack/components/retrievers/multi_query_embedding_retriever.py @@ -85,18 +85,42 @@ def __init__(self, *, retriever: EmbeddingRetriever, query_embedder: TextEmbedde self.retriever = retriever self.query_embedder = query_embedder self.max_workers = max_workers - self._is_warmed_up = False def warm_up(self) -> None: """ - Warm up the query embedder and the retriever if any has a warm_up method. + Warm up the query embedder and the retriever. """ - if not self._is_warmed_up: - if hasattr(self.query_embedder, "warm_up") and callable(self.query_embedder.warm_up): - self.query_embedder.warm_up() - if hasattr(self.retriever, "warm_up") and callable(self.retriever.warm_up): - self.retriever.warm_up() - self._is_warmed_up = True + for inner in (self.query_embedder, self.retriever): + if hasattr(inner, "warm_up"): + inner.warm_up() + + async def warm_up_async(self) -> None: + """ + Warm up the query embedder and the retriever on the serving event loop. + """ + for inner in (self.query_embedder, self.retriever): + if hasattr(inner, "warm_up_async"): + await inner.warm_up_async() + elif hasattr(inner, "warm_up"): + inner.warm_up() + + def close(self) -> None: + """ + Release the query embedder's and the retriever's resources. + """ + for inner in (self.query_embedder, self.retriever): + if hasattr(inner, "close"): + inner.close() + + async def close_async(self) -> None: + """ + Release the query embedder's and the retriever's async resources. + """ + for inner in (self.query_embedder, self.retriever): + if hasattr(inner, "close_async"): + await inner.close_async() + elif hasattr(inner, "close"): + inner.close() @component.output_types(documents=list[Document]) def run(self, queries: list[str], retriever_kwargs: dict[str, Any] | None = None) -> dict[str, list[Document]]: @@ -112,8 +136,7 @@ def run(self, queries: list[str], retriever_kwargs: dict[str, Any] | None = None docs: list[Document] = [] retriever_kwargs = retriever_kwargs or {} - if not self._is_warmed_up: - self.warm_up() + self.warm_up() with ThreadPoolExecutor(max_workers=self.max_workers) as executor: queries_results = executor.map(lambda query: self._run_on_thread(query, retriever_kwargs), queries) @@ -145,8 +168,7 @@ async def run_async( """ retriever_kwargs = retriever_kwargs or {} - if not self._is_warmed_up: - self.warm_up() + await self.warm_up_async() results = await asyncio.gather(*[self._run_one_async(q, retriever_kwargs) for q in queries]) docs: list[Document] = [doc for result in results if result for doc in result] diff --git a/haystack/components/retrievers/multi_query_text_retriever.py b/haystack/components/retrievers/multi_query_text_retriever.py index a8221208e3..a3326057c9 100644 --- a/haystack/components/retrievers/multi_query_text_retriever.py +++ b/haystack/components/retrievers/multi_query_text_retriever.py @@ -67,16 +67,38 @@ def __init__(self, *, retriever: TextRetriever, max_workers: int = 3) -> None: """ self.retriever = retriever self.max_workers = max_workers - self._is_warmed_up = False def warm_up(self) -> None: """ - Warm up the retriever if it has a warm_up method. + Warm up the retriever. """ - if not self._is_warmed_up: - if hasattr(self.retriever, "warm_up") and callable(self.retriever.warm_up): - self.retriever.warm_up() - self._is_warmed_up = True + if hasattr(self.retriever, "warm_up"): + self.retriever.warm_up() + + async def warm_up_async(self) -> None: + """ + Warm up the retriever on the serving event loop. + """ + if hasattr(self.retriever, "warm_up_async"): + await self.retriever.warm_up_async() + elif hasattr(self.retriever, "warm_up"): + self.retriever.warm_up() + + def close(self) -> None: + """ + Release the retriever's resources. + """ + if hasattr(self.retriever, "close"): + self.retriever.close() + + async def close_async(self) -> None: + """ + Release the retriever's async resources. + """ + if hasattr(self.retriever, "close_async"): + await self.retriever.close_async() + elif hasattr(self.retriever, "close"): + self.retriever.close() @component.output_types(documents=list[Document]) def run(self, queries: list[str], retriever_kwargs: dict[str, Any] | None = None) -> dict[str, list[Document]]: @@ -92,8 +114,7 @@ def run(self, queries: list[str], retriever_kwargs: dict[str, Any] | None = None docs: list[Document] = [] retriever_kwargs = retriever_kwargs or {} - if not self._is_warmed_up: - self.warm_up() + self.warm_up() with ThreadPoolExecutor(max_workers=self.max_workers) as executor: queries_results = executor.map(lambda query: self._run_on_thread(query, retriever_kwargs), queries) @@ -125,8 +146,7 @@ async def run_async( """ retriever_kwargs = retriever_kwargs or {} - if not self._is_warmed_up: - self.warm_up() + await self.warm_up_async() results = await asyncio.gather(*[self._run_one_async(q, retriever_kwargs) for q in queries]) docs: list[Document] = [doc for result in results if result for doc in result] diff --git a/haystack/components/retrievers/multi_retriever.py b/haystack/components/retrievers/multi_retriever.py index c918974676..47beefe2fc 100644 --- a/haystack/components/retrievers/multi_retriever.py +++ b/haystack/components/retrievers/multi_retriever.py @@ -108,7 +108,6 @@ def __init__( self.top_k = top_k self.max_workers = max_workers self.join_mode = join_mode - self._is_warmed_up = False def _merge_results(self, document_lists: list[list[Document]]) -> list[Document]: """ @@ -146,14 +145,39 @@ def _resolve_retrievers(self, active_retrievers: list[str] | None) -> dict[str, def warm_up(self) -> None: """ - Warm up the retrievers if any has a warm_up method. + Warm up the retrievers. """ - if self._is_warmed_up: - return for retriever in self.retrievers.values(): - if hasattr(retriever, "warm_up") and callable(retriever.warm_up): + if hasattr(retriever, "warm_up"): retriever.warm_up() - self._is_warmed_up = True + + async def warm_up_async(self) -> None: + """ + Warm up the retrievers on the serving event loop. + """ + for retriever in self.retrievers.values(): + if hasattr(retriever, "warm_up_async"): + await retriever.warm_up_async() + elif hasattr(retriever, "warm_up"): + retriever.warm_up() + + def close(self) -> None: + """ + Release the retrievers' resources. + """ + for retriever in self.retrievers.values(): + if hasattr(retriever, "close"): + retriever.close() + + async def close_async(self) -> None: + """ + Release the retrievers' async resources. + """ + for retriever in self.retrievers.values(): + if hasattr(retriever, "close_async"): + await retriever.close_async() + elif hasattr(retriever, "close"): + retriever.close() @component.output_types(documents=list[Document]) def run( @@ -183,8 +207,7 @@ def run( :raises ValueError: If any name in `active_retrievers` does not match a retriever name. """ - if not self._is_warmed_up: - self.warm_up() + self.warm_up() resolved_top_k = top_k if top_k is not None else self.top_k resolved_filters = filters if filters is not None else self.filters @@ -237,8 +260,7 @@ async def run_async( :raises ValueError: If any name in `active_retrievers` does not match a retriever name. """ - if not self._is_warmed_up: - self.warm_up() + await self.warm_up_async() resolved_top_k = top_k if top_k is not None else self.top_k resolved_filters = filters if filters is not None else self.filters diff --git a/haystack/components/retrievers/text_embedding_retriever.py b/haystack/components/retrievers/text_embedding_retriever.py index cf23248cd6..63a952e23a 100644 --- a/haystack/components/retrievers/text_embedding_retriever.py +++ b/haystack/components/retrievers/text_embedding_retriever.py @@ -67,18 +67,42 @@ def __init__(self, *, retriever: EmbeddingRetriever, text_embedder: TextEmbedder """ self.retriever = retriever self.text_embedder = text_embedder - self._is_warmed_up = False def warm_up(self) -> None: """ - Warm up the text embedder and the retriever if any has a warm_up method. + Warm up the text embedder and the retriever. """ - if not self._is_warmed_up: - if hasattr(self.text_embedder, "warm_up") and callable(self.text_embedder.warm_up): - self.text_embedder.warm_up() - if hasattr(self.retriever, "warm_up") and callable(self.retriever.warm_up): - self.retriever.warm_up() - self._is_warmed_up = True + for inner in (self.text_embedder, self.retriever): + if hasattr(inner, "warm_up"): + inner.warm_up() + + async def warm_up_async(self) -> None: + """ + Warm up the text embedder and the retriever on the serving event loop. + """ + for inner in (self.text_embedder, self.retriever): + if hasattr(inner, "warm_up_async"): + await inner.warm_up_async() + elif hasattr(inner, "warm_up"): + inner.warm_up() + + def close(self) -> None: + """ + Release the text embedder's and the retriever's resources. + """ + for inner in (self.text_embedder, self.retriever): + if hasattr(inner, "close"): + inner.close() + + async def close_async(self) -> None: + """ + Release the text embedder's and the retriever's async resources. + """ + for inner in (self.text_embedder, self.retriever): + if hasattr(inner, "close_async"): + await inner.close_async() + elif hasattr(inner, "close"): + inner.close() @component.output_types(documents=list[Document]) def run( @@ -94,8 +118,7 @@ def run( A dictionary containing: - `documents`: List of retrieved documents sorted by relevance score. """ - if not self._is_warmed_up: - self.warm_up() + self.warm_up() embedding_result = self.text_embedder.run(text=query) result = self.retriever.run(query_embedding=embedding_result["embedding"], filters=filters, top_k=top_k) @@ -122,8 +145,7 @@ async def run_async( A dictionary containing: - `documents`: List of retrieved documents sorted by relevance score. """ - if not self._is_warmed_up: - self.warm_up() + await self.warm_up_async() embedding_result = await _execute_component_async(self.text_embedder, text=query) result = await _execute_component_async( diff --git a/haystack/components/routers/llm_messages_router.py b/haystack/components/routers/llm_messages_router.py index ebf8d56cd9..6094b6562c 100644 --- a/haystack/components/routers/llm_messages_router.py +++ b/haystack/components/routers/llm_messages_router.py @@ -84,20 +84,34 @@ def __init__( self._output_patterns = output_patterns self._compiled_patterns = [re.compile(pattern) for pattern in output_patterns] - self._is_warmed_up = False component.set_output_types( self, **{"chat_generator_text": str, **dict.fromkeys(output_names + ["unmatched"], list[ChatMessage])} ) def warm_up(self) -> None: - """ - Warm up the underlying LLM. - """ - if not self._is_warmed_up: - if hasattr(self._chat_generator, "warm_up"): - self._chat_generator.warm_up() - self._is_warmed_up = True + """Warm up the underlying chat generator.""" + if hasattr(self._chat_generator, "warm_up"): + self._chat_generator.warm_up() + + async def warm_up_async(self) -> None: + """Warm up the underlying chat generator on the serving event loop.""" + if hasattr(self._chat_generator, "warm_up_async"): + await self._chat_generator.warm_up_async() + elif hasattr(self._chat_generator, "warm_up"): + self._chat_generator.warm_up() + + def close(self) -> None: + """Release the underlying chat generator's resources.""" + if hasattr(self._chat_generator, "close"): + self._chat_generator.close() + + async def close_async(self) -> None: + """Release the underlying chat generator's async resources.""" + if hasattr(self._chat_generator, "close_async"): + await self._chat_generator.close_async() + elif hasattr(self._chat_generator, "close"): + self._chat_generator.close() def run(self, messages: list[ChatMessage]) -> dict[str, str | list[ChatMessage]]: """ @@ -121,8 +135,7 @@ def run(self, messages: list[ChatMessage]) -> dict[str, str | list[ChatMessage]] ) raise ValueError(msg) - if not self._is_warmed_up: - self.warm_up() + self.warm_up() messages_for_inference = [] if self._system_prompt: @@ -168,8 +181,7 @@ async def run_async(self, messages: list[ChatMessage]) -> dict[str, str | list[C ) raise ValueError(msg) - if not self._is_warmed_up: - self.warm_up() + await self.warm_up_async() messages_for_inference = [] if self._system_prompt: diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index 663392410e..f5cb777ec1 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -916,15 +916,60 @@ def walk(self) -> Iterator[tuple[str, Component]]: def warm_up(self) -> None: """ - Make sure all nodes are warm. + Make sure all components are warm. - It's the node's responsibility to make sure this method can be called at every `Pipeline.run()` + It's the component's responsibility to make sure this method can be called at every `Pipeline.run()` without re-initializing everything. """ - for node in self.graph.nodes: - if hasattr(self.graph.nodes[node]["instance"], "warm_up"): - logger.info("Warming up component {node}...", node=node) - self.graph.nodes[node]["instance"].warm_up() + for component_name in self.graph.nodes: + if hasattr(self.graph.nodes[component_name]["instance"], "warm_up"): + logger.info("Warming up component {component_name}...", component_name=component_name) + self.graph.nodes[component_name]["instance"].warm_up() + + async def warm_up_async(self) -> None: + """ + Make sure all components are warm, using the async warm-up path where available. + + Each component is warmed up with `warm_up_async` if it has one, otherwise with its sync `warm_up`. + Both run on the event loop, never offloaded to a worker thread. + This ensures that if an async client is created during `warm-up` (residual scenario), it binds to the loop that + `run_async` will use. + """ + for component_name in self.graph.nodes: + instance = self.graph.nodes[component_name]["instance"] + if hasattr(instance, "warm_up_async"): + logger.info("Warming up component {component_name}...", component_name=component_name) + await instance.warm_up_async() + elif hasattr(instance, "warm_up"): + logger.info("Warming up component {component_name}...", component_name=component_name) + instance.warm_up() + + def close(self) -> None: + """ + Release resources held by the pipeline's components by calling each component's `close` method. + + Only the synchronous side of each component is released here; use `close_async` to release async clients. + """ + for component_name in self.graph.nodes: + instance = self.graph.nodes[component_name]["instance"] + if hasattr(instance, "close"): + logger.info("Closing component {component_name}...", component_name=component_name) + instance.close() + + async def close_async(self) -> None: + """ + Release resources held by the pipeline's components, using the async close path where available. + + Each component is closed with `close_async` if it has one, otherwise with its sync `close`. + """ + for component_name in self.graph.nodes: + instance = self.graph.nodes[component_name]["instance"] + if hasattr(instance, "close_async"): + logger.info("Closing component {component_name}...", component_name=component_name) + await instance.close_async() + elif hasattr(instance, "close"): + logger.info("Closing component {component_name}...", component_name=component_name) + instance.close() @staticmethod def _create_component_span( diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index 25e29f1711..f0bd0f36ce 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -861,8 +861,8 @@ async def process_results(): pipeline_running(self) # telemetry - # warm up the pipeline by running each component's warm_up method - self.warm_up() + # warm up the pipeline by running each component's warm_up_async (or warm_up) method + await self.warm_up_async() if include_outputs_from is None: include_outputs_from = set() diff --git a/haystack/core/super_component/super_component.py b/haystack/core/super_component/super_component.py index 287965e49a..8cea698135 100644 --- a/haystack/core/super_component/super_component.py +++ b/haystack/core/super_component/super_component.py @@ -105,6 +105,26 @@ def warm_up(self) -> None: self.pipeline.warm_up() self._warmed_up = True + async def warm_up_async(self) -> None: + """ + Warms up the SuperComponent by warming up the wrapped pipeline on the serving event loop. + """ + if not self._warmed_up: + await self.pipeline.warm_up_async() + self._warmed_up = True + + def close(self) -> None: + """ + Releases the synchronous resources held by the wrapped pipeline's components. + """ + self.pipeline.close() + + async def close_async(self) -> None: + """ + Releases the async resources held by the wrapped pipeline's components. + """ + await self.pipeline.close_async() + def run(self, **kwargs: Any) -> dict[str, Any]: """ Runs the wrapped pipeline with the provided inputs. diff --git a/releasenotes/notes/component-resource-lifecycle-4a47aada1ac026fd.yaml b/releasenotes/notes/component-resource-lifecycle-4a47aada1ac026fd.yaml new file mode 100644 index 0000000000..57817855f3 --- /dev/null +++ b/releasenotes/notes/component-resource-lifecycle-4a47aada1ac026fd.yaml @@ -0,0 +1,16 @@ +--- +features: + - | + Components and pipelines now have a consistent resource lifecycle for acquiring and releasing resources such as + HTTP clients and connections. In addition to the existing ``warm_up``, components can now implement ``warm_up_async`` + to acquire asynchronous resources, and ``close``/``close_async`` to release them. + ``Pipeline`` now implements the same lifecycle methods. In addition to the existing ``warm_up``, ``warm_up_async()`` + runs async warm-up on the serving event loop (falling back to the synchronous ``warm_up``), while + ``close()``/``close_async()`` release synchronous and asynchronous resources respectively. + Closing is always explicit; pipelines are never closed automatically after a run. +upgrade: + - | + Components in Haystack that use external resources now consistently create those resources during warm-up rather + than in ``__init__``. As a result, some components that previously raised errors during initialization (for example, + because of a missing API key) now raise them during ``warm_up`` instead. + Similar changes will be rolled out to components in Haystack Core Integrations in the future. diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py index 26bd610f26..59fb27d569 100644 --- a/test/components/agents/test_agent.py +++ b/test/components/agents/test_agent.py @@ -954,15 +954,6 @@ def run( run_tool_mock.assert_called_once() assert run_tool_mock.call_args.kwargs["tools"] == [weather_tool] - def test_run_not_warmed_up(self, weather_tool): - """Warmup is run automatically on first run""" - chat_generator = MockChatGeneratorWithoutRunAsync() - chat_generator.warm_up = MagicMock() - agent = Agent(chat_generator=chat_generator, tools=[weather_tool], system_prompt="This is a system prompt.") - agent.run([ChatMessage.from_user("What is the weather in Berlin?")]) - assert agent._is_warmed_up is True - assert chat_generator.warm_up.call_count == 1 - def test_run_no_messages(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "fake-key") chat_generator = OpenAIChatGenerator() @@ -1884,7 +1875,6 @@ def document_store_with_docs(self): @pytest.fixture def make_rag_pipeline(self, document_store_with_docs: InMemoryDocumentStore, make_agent): - def _factory(user_prompt: str | None = None): agent = make_agent( user_prompt=user_prompt @@ -1990,7 +1980,6 @@ class TestAgentWaitsForBlockedPredecessor: """ def test_agent_waits_for_messages_when_predecessor_is_blocked(self, weather_tool): - @component class HistoryParser: @component.output_types(messages=list[ChatMessage]) @@ -2295,6 +2284,70 @@ async def test_run_async_warms_up_per_run_toolset(self): assert per_run_tool.was_warmed_up +class TestComponentLifecycle: + def test_warm_up_delegates_to_chat_generator(self, weather_tool): + chat_generator = MockChatGenerator() + chat_generator.warm_up = MagicMock() + agent = Agent(chat_generator=chat_generator, tools=[weather_tool], system_prompt="This is a system prompt.") + + agent.warm_up() + chat_generator.warm_up.assert_called_once() + + chat_generator.warm_up.reset_mock() + agent.run([ChatMessage.from_user("What is the weather in Berlin?")]) + assert agent._tools_warmed_up is True + chat_generator.warm_up.assert_called_once() + + @pytest.mark.asyncio + async def test_warm_up_async_delegates_to_chat_generator(self): + chat_generator = MockChatGenerator() + chat_generator.warm_up_async = AsyncMock() + chat_generator.warm_up = MagicMock() + agent = Agent(chat_generator=chat_generator, tools=[]) + await agent.warm_up_async() + chat_generator.warm_up_async.assert_awaited_once() + chat_generator.warm_up.assert_not_called() + + @pytest.mark.asyncio + async def test_warm_up_async_falls_back_to_sync_warm_up(self): + chat_generator = MockChatGeneratorWithoutRunAsync() + chat_generator.warm_up = MagicMock() + agent = Agent(chat_generator=chat_generator, tools=[]) + await agent.warm_up_async() + chat_generator.warm_up.assert_called_once() + + def test_close_delegates_to_chat_generator(self): + chat_generator = MockChatGenerator() + chat_generator.close = MagicMock() + agent = Agent(chat_generator=chat_generator, tools=[]) + agent.close() + chat_generator.close.assert_called_once() + + @pytest.mark.asyncio + async def test_close_async_delegates_to_chat_generator(self): + chat_generator = MockChatGenerator() + chat_generator.close_async = AsyncMock() + agent = Agent(chat_generator=chat_generator, tools=[]) + await agent.close_async() + chat_generator.close_async.assert_awaited_once() + + @pytest.mark.asyncio + async def test_close_async_falls_back_to_sync_close(self): + chat_generator = MockChatGenerator() + chat_generator.close = MagicMock() + agent = Agent(chat_generator=chat_generator, tools=[]) + await agent.close_async() + chat_generator.close.assert_called_once() + + @pytest.mark.asyncio + async def test_lifecycle_is_safe_when_chat_generator_lacks_methods(self): + agent = Agent(chat_generator=MockChatGeneratorWithoutRunAsync(), tools=[]) + agent.warm_up() + await agent.warm_up_async() + agent.close() + await agent.close_async() + + class TestAgentNotTriggeredByInjectedInput: """ Regression test for https://github.com/deepset-ai/haystack/issues/11109. diff --git a/test/components/audio/test_whisper_remote.py b/test/components/audio/test_whisper_remote.py index 8c433ab12f..4535d97c29 100644 --- a/test/components/audio/test_whisper_remote.py +++ b/test/components/audio/test_whisper_remote.py @@ -3,38 +3,25 @@ # SPDX-License-Identifier: Apache-2.0 import os -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock, MagicMock, Mock import pytest -from openai import AsyncOpenAI +import haystack.components.audio.whisper_remote as whisper_remote_module from haystack.components.audio.whisper_remote import RemoteWhisperTranscriber from haystack.dataclasses import ByteStream from haystack.utils import Secret class TestRemoteWhisperTranscriber: - def test_init_no_key(self, monkeypatch): - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - with pytest.raises(ValueError, match="None of the .* environment variables are set"): - RemoteWhisperTranscriber() - - def test_init_key_env_var(self, monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "test_api_key") - t = RemoteWhisperTranscriber() - assert t.client.api_key == "test_api_key" - - def test_init_key_module_env_and_global_var(self, monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "test_api_key_2") - t = RemoteWhisperTranscriber() - assert t.client.api_key == "test_api_key_2" - def test_init_default(self): transcriber = RemoteWhisperTranscriber(api_key=Secret.from_token("test_api_key")) - assert transcriber.client.api_key == "test_api_key" + assert transcriber.api_key == Secret.from_token("test_api_key") assert transcriber.model == "whisper-1" assert transcriber.organization is None assert transcriber.whisper_params == {"response_format": "json"} + assert transcriber.client is None + assert transcriber.async_client is None def test_init_custom_parameters(self): transcriber = RemoteWhisperTranscriber( @@ -58,6 +45,8 @@ def test_init_custom_parameters(self): "response_format": "json", "temperature": "0.5", } + assert transcriber.client is None + assert transcriber.async_client is None def test_to_dict_default_parameters(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test_api_key") @@ -167,8 +156,10 @@ def test_from_dict_with_default_parameters_no_env_var(self, monkeypatch): }, } - with pytest.raises(ValueError, match="None of the .* environment variables are set"): - RemoteWhisperTranscriber.from_dict(data) + transcriber = RemoteWhisperTranscriber.from_dict(data) + assert transcriber.model == "whisper-1" + assert transcriber.client is None + assert transcriber.async_client is None @pytest.mark.skipif( not os.environ.get("OPENAI_API_KEY", None), @@ -198,12 +189,6 @@ def test_whisper_remote_transcriber(self, test_files_path): class TestRemoteWhisperTranscriberAsync: - def test_init_async_client(self, monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "test_api_key") - transcriber = RemoteWhisperTranscriber() - assert isinstance(transcriber.async_client, AsyncOpenAI) - assert transcriber.async_client.api_key == "test_api_key" - @pytest.mark.asyncio async def test_run_async_with_path(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test_api_key") @@ -268,3 +253,73 @@ async def test_whisper_remote_transcriber_async(self, test_files_path): assert str(test_files_path / "audio" / "the context for this answer is here.wav") == docs[1].meta["file_path"] assert docs[2].content.strip().lower() == "answer." + + +@pytest.fixture +def mock_openai_clients(monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake") + sync_cls = MagicMock(name="OpenAI") + async_cls = MagicMock(name="AsyncOpenAI") + async_cls.return_value.close = AsyncMock() + monkeypatch.setattr(whisper_remote_module, "OpenAI", sync_cls) + monkeypatch.setattr(whisper_remote_module, "AsyncOpenAI", async_cls) + return sync_cls, async_cls + + +class TestComponentLifecycle: + def test_warm_up_resolves_key(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test_api_key") + transcriber = RemoteWhisperTranscriber() + transcriber.warm_up() + assert transcriber.client.api_key == "test_api_key" + + def test_key_resolved_at_warm_up_not_init(self, monkeypatch): + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + transcriber = RemoteWhisperTranscriber() + with pytest.raises(ValueError, match="None of the .* environment variables are set"): + transcriber.warm_up() + + def test_sync_lifecycle(self, mock_openai_clients): + sync_cls, _ = mock_openai_clients + transcriber = RemoteWhisperTranscriber() + assert transcriber.client is None + assert transcriber.async_client is None + + transcriber.warm_up() + assert transcriber.client is sync_cls.return_value + assert transcriber.async_client is None + + transcriber.close() + sync_cls.return_value.close.assert_called_once() + assert transcriber.client is None + + async def test_async_lifecycle(self, mock_openai_clients): + _, async_cls = mock_openai_clients + transcriber = RemoteWhisperTranscriber() + + await transcriber.warm_up_async() + assert transcriber.async_client is async_cls.return_value + assert transcriber.client is None + + await transcriber.close_async() + async_cls.return_value.close.assert_awaited_once() + assert transcriber.async_client is None + + async def test_close_is_safe_without_warm_up(self, mock_openai_clients): + transcriber = RemoteWhisperTranscriber() + transcriber.close() + await transcriber.close_async() + assert transcriber.client is None + assert transcriber.async_client is None + + async def test_close_and_close_async_are_independent(self, mock_openai_clients): + transcriber = RemoteWhisperTranscriber() + transcriber.warm_up() + await transcriber.warm_up_async() + + transcriber.close() + assert transcriber.client is None + assert transcriber.async_client is not None + + await transcriber.close_async() + assert transcriber.async_client is None diff --git a/test/components/embedders/test_azure_document_embedder.py b/test/components/embedders/test_azure_document_embedder.py index 938cbb136a..e38866fbd8 100644 --- a/test/components/embedders/test_azure_document_embedder.py +++ b/test/components/embedders/test_azure_document_embedder.py @@ -3,11 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 import os -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest -from openai import APIError +from openai import APIError, OpenAIError +import haystack.components.embedders.azure_document_embedder as azure_document_embedder_module from haystack import Document from haystack.components.embedders import AzureOpenAIDocumentEmbedder from haystack.utils.auth import Secret @@ -28,9 +29,13 @@ def test_init_default(self, monkeypatch): assert embedder.progress_bar is True assert embedder.meta_fields_to_embed == [] assert embedder.embedding_separator == "\n" + assert embedder.timeout is None + assert embedder.max_retries is None assert embedder.default_headers == {} assert embedder.azure_ad_token_provider is None assert embedder.http_client_kwargs is None + assert embedder.client is None + assert embedder.async_client is None def test_init_with_0_max_retries(self, monkeypatch): """Tests that the max_retries init param is set correctly if equal 0""" @@ -51,6 +56,8 @@ def test_init_with_0_max_retries(self, monkeypatch): assert embedder.default_headers == {} assert embedder.azure_ad_token_provider is None assert embedder.max_retries == 0 + assert embedder.client is None + assert embedder.async_client is None def test_to_dict(self, monkeypatch): monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") @@ -72,8 +79,8 @@ def test_to_dict(self, monkeypatch): "progress_bar": True, "meta_fields_to_embed": [], "embedding_separator": "\n", - "max_retries": 5, - "timeout": 30.0, + "max_retries": None, + "timeout": None, "default_headers": {}, "azure_ad_token_provider": None, "http_client_kwargs": None, @@ -208,6 +215,7 @@ def test_embed_batch_handles_exceptions_gracefully(self, caplog): azure_deployment="text-embedding-ada-002", embedding_separator=" | ", ) + embedder.warm_up() fake_texts_to_embed = {"1": "text1", "2": "text2"} @@ -228,6 +236,7 @@ def test_embed_batch_raises_exception_on_failure(self): azure_deployment="text-embedding-ada-002", raise_on_failure=True, ) + embedder.warm_up() fake_texts_to_embed = {"1": "text1", "2": "text2"} with patch.object( embedder.client.embeddings, @@ -276,3 +285,94 @@ def test_run(self): assert metadata["usage"]["prompt_tokens"] == 15 assert metadata["usage"]["total_tokens"] == 15 assert "ada" in metadata["model"] + + +@pytest.fixture +def mock_azure_clients(monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake") + sync_cls = MagicMock(name="AzureOpenAI") + async_cls = MagicMock(name="AsyncAzureOpenAI") + async_cls.return_value.close = AsyncMock() + monkeypatch.setattr(azure_document_embedder_module, "AzureOpenAI", sync_cls) + monkeypatch.setattr(azure_document_embedder_module, "AsyncAzureOpenAI", async_cls) + return sync_cls, async_cls + + +class TestComponentLifecycle: + def test_warm_up_uses_default_timeout_and_max_retries(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") + embedder = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/") + embedder.warm_up() + assert embedder.client.max_retries == 5 + assert embedder.client.timeout == 30.0 + + def test_warm_up_uses_timeout_and_max_retries_from_parameters(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") + embedder = AzureOpenAIDocumentEmbedder( + azure_endpoint="https://example-resource.azure.openai.com/", timeout=40.0, max_retries=1 + ) + embedder.warm_up() + assert embedder.client.max_retries == 1 + assert embedder.client.timeout == 40.0 + + def test_warm_up_uses_timeout_and_max_retries_from_env_vars(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") + monkeypatch.setenv("OPENAI_TIMEOUT", "100") + monkeypatch.setenv("OPENAI_MAX_RETRIES", "10") + embedder = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/") + embedder.warm_up() + assert embedder.client.max_retries == 10 + assert embedder.client.timeout == 100.0 + + def test_key_resolved_at_warm_up_not_init(self, monkeypatch): + monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False) + monkeypatch.delenv("AZURE_OPENAI_AD_TOKEN", raising=False) + embedder = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/") + assert embedder.client is None + with pytest.raises(OpenAIError): + embedder.warm_up() + + def test_sync_lifecycle(self, mock_azure_clients): + sync_cls, _ = mock_azure_clients + embedder = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/") + assert embedder.client is None + assert embedder.async_client is None + + embedder.warm_up() + assert embedder.client is sync_cls.return_value + assert embedder.async_client is None + + embedder.close() + sync_cls.return_value.close.assert_called_once() + assert embedder.client is None + + async def test_async_lifecycle(self, mock_azure_clients): + _, async_cls = mock_azure_clients + embedder = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/") + + await embedder.warm_up_async() + assert embedder.async_client is async_cls.return_value + assert embedder.client is None + + await embedder.close_async() + async_cls.return_value.close.assert_awaited_once() + assert embedder.async_client is None + + async def test_close_is_safe_without_warm_up(self, mock_azure_clients): + embedder = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/") + embedder.close() + await embedder.close_async() + assert embedder.client is None + assert embedder.async_client is None + + async def test_close_and_close_async_are_independent(self, mock_azure_clients): + embedder = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/") + embedder.warm_up() + await embedder.warm_up_async() + + embedder.close() + assert embedder.client is None + assert embedder.async_client is not None + + await embedder.close_async() + assert embedder.async_client is None diff --git a/test/components/embedders/test_azure_text_embedder.py b/test/components/embedders/test_azure_text_embedder.py index 302193ff41..a1a1f2ebc9 100644 --- a/test/components/embedders/test_azure_text_embedder.py +++ b/test/components/embedders/test_azure_text_embedder.py @@ -3,9 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 import os +from unittest.mock import AsyncMock, MagicMock import pytest +from openai import OpenAIError +import haystack.components.embedders.azure_text_embedder as azure_text_embedder_module from haystack.components.embedders import AzureOpenAITextEmbedder from haystack.utils.azure import default_azure_ad_token_provider @@ -15,23 +18,27 @@ def test_init_default(self, monkeypatch): monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") embedder = AzureOpenAITextEmbedder(azure_endpoint="https://example-resource.azure.openai.com/") - assert embedder.client.api_key == "fake-api-key" + assert embedder.api_key.resolve_value() == "fake-api-key" assert embedder.azure_deployment == "text-embedding-ada-002" assert embedder.model == "text-embedding-ada-002" assert embedder.dimensions is None assert embedder.organization is None assert embedder.prefix == "" assert embedder.suffix == "" + assert embedder.timeout is None + assert embedder.max_retries is None assert embedder.default_headers == {} assert embedder.azure_ad_token_provider is None assert embedder.http_client_kwargs is None + assert embedder.client is None + assert embedder.async_client is None def test_init_with_zero_max_retries(self, monkeypatch): """Tests that the max_retries init param is set correctly if equal 0""" monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") embedder = AzureOpenAITextEmbedder(azure_endpoint="https://example-resource.azure.openai.com/", max_retries=0) - assert embedder.client.api_key == "fake-api-key" + assert embedder.api_key.resolve_value() == "fake-api-key" assert embedder.azure_deployment == "text-embedding-ada-002" assert embedder.model == "text-embedding-ada-002" assert embedder.dimensions is None @@ -41,6 +48,8 @@ def test_init_with_zero_max_retries(self, monkeypatch): assert embedder.default_headers == {} assert embedder.azure_ad_token_provider is None assert embedder.max_retries == 0 + assert embedder.client is None + assert embedder.async_client is None def test_to_dict_default(self, monkeypatch): monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") @@ -56,8 +65,8 @@ def test_to_dict_default(self, monkeypatch): "organization": None, "azure_endpoint": "https://example-resource.azure.openai.com/", "api_version": "2023-05-15", - "max_retries": 5, - "timeout": 30.0, + "max_retries": None, + "timeout": None, "prefix": "", "suffix": "", "default_headers": {}, @@ -189,3 +198,94 @@ def test_run(self): assert all(isinstance(x, float) for x in result["embedding"]) assert result["meta"]["usage"] == {"prompt_tokens": 6, "total_tokens": 6} assert "ada" in result["meta"]["model"] + + +@pytest.fixture +def mock_azure_clients(monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake") + sync_cls = MagicMock(name="AzureOpenAI") + async_cls = MagicMock(name="AsyncAzureOpenAI") + async_cls.return_value.close = AsyncMock() + monkeypatch.setattr(azure_text_embedder_module, "AzureOpenAI", sync_cls) + monkeypatch.setattr(azure_text_embedder_module, "AsyncAzureOpenAI", async_cls) + return sync_cls, async_cls + + +class TestComponentLifecycle: + def test_warm_up_uses_default_timeout_and_max_retries(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") + embedder = AzureOpenAITextEmbedder(azure_endpoint="https://example-resource.azure.openai.com/") + embedder.warm_up() + assert embedder.client.max_retries == 5 + assert embedder.client.timeout == 30.0 + + def test_warm_up_uses_timeout_and_max_retries_from_parameters(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") + embedder = AzureOpenAITextEmbedder( + azure_endpoint="https://example-resource.azure.openai.com/", timeout=40.0, max_retries=1 + ) + embedder.warm_up() + assert embedder.client.max_retries == 1 + assert embedder.client.timeout == 40.0 + + def test_warm_up_uses_timeout_and_max_retries_from_env_vars(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") + monkeypatch.setenv("OPENAI_TIMEOUT", "100") + monkeypatch.setenv("OPENAI_MAX_RETRIES", "10") + embedder = AzureOpenAITextEmbedder(azure_endpoint="https://example-resource.azure.openai.com/") + embedder.warm_up() + assert embedder.client.max_retries == 10 + assert embedder.client.timeout == 100.0 + + def test_key_resolved_at_warm_up_not_init(self, monkeypatch): + monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False) + monkeypatch.delenv("AZURE_OPENAI_AD_TOKEN", raising=False) + embedder = AzureOpenAITextEmbedder(azure_endpoint="https://example-resource.azure.openai.com/") + assert embedder.client is None + with pytest.raises(OpenAIError): + embedder.warm_up() + + def test_sync_lifecycle(self, mock_azure_clients): + sync_cls, _ = mock_azure_clients + embedder = AzureOpenAITextEmbedder(azure_endpoint="https://example-resource.azure.openai.com/") + assert embedder.client is None + assert embedder.async_client is None + + embedder.warm_up() + assert embedder.client is sync_cls.return_value + assert embedder.async_client is None + + embedder.close() + sync_cls.return_value.close.assert_called_once() + assert embedder.client is None + + async def test_async_lifecycle(self, mock_azure_clients): + _, async_cls = mock_azure_clients + embedder = AzureOpenAITextEmbedder(azure_endpoint="https://example-resource.azure.openai.com/") + + await embedder.warm_up_async() + assert embedder.async_client is async_cls.return_value + assert embedder.client is None + + await embedder.close_async() + async_cls.return_value.close.assert_awaited_once() + assert embedder.async_client is None + + async def test_close_is_safe_without_warm_up(self, mock_azure_clients): + embedder = AzureOpenAITextEmbedder(azure_endpoint="https://example-resource.azure.openai.com/") + embedder.close() + await embedder.close_async() + assert embedder.client is None + assert embedder.async_client is None + + async def test_close_and_close_async_are_independent(self, mock_azure_clients): + embedder = AzureOpenAITextEmbedder(azure_endpoint="https://example-resource.azure.openai.com/") + embedder.warm_up() + await embedder.warm_up_async() + + embedder.close() + assert embedder.client is None + assert embedder.async_client is not None + + await embedder.close_async() + assert embedder.async_client is None diff --git a/test/components/embedders/test_openai_document_embedder.py b/test/components/embedders/test_openai_document_embedder.py index 4a97e7c457..ad107e57ec 100644 --- a/test/components/embedders/test_openai_document_embedder.py +++ b/test/components/embedders/test_openai_document_embedder.py @@ -4,11 +4,12 @@ import contextlib import os -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest from openai import APIError +import haystack.components.embedders.openai_document_embedder as openai_document_embedder_module from haystack import Document from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder from haystack.utils.auth import Secret @@ -27,8 +28,10 @@ def test_init_default(self, monkeypatch): assert embedder.progress_bar is True assert embedder.meta_fields_to_embed == [] assert embedder.embedding_separator == "\n" - assert embedder.client.max_retries == 5 - assert embedder.client.timeout == 30.0 + assert embedder.timeout is None + assert embedder.max_retries is None + assert embedder.client is None + assert embedder.async_client is None def test_init_with_parameters(self, monkeypatch): monkeypatch.setenv("OPENAI_TIMEOUT", "100") @@ -55,8 +58,10 @@ def test_init_with_parameters(self, monkeypatch): assert embedder.progress_bar is False assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " - assert embedder.client.max_retries == 1 - assert embedder.client.timeout == 40.0 + assert embedder.timeout == 40.0 + assert embedder.max_retries == 1 + assert embedder.client is None + assert embedder.async_client is None def test_init_with_parameters_and_env_vars(self, monkeypatch): monkeypatch.setenv("OPENAI_TIMEOUT", "100") @@ -81,13 +86,10 @@ def test_init_with_parameters_and_env_vars(self, monkeypatch): assert embedder.progress_bar is False assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " - assert embedder.client.max_retries == 10 - assert embedder.client.timeout == 100.0 - - def test_init_fail_wo_api_key(self, monkeypatch): - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - with pytest.raises(ValueError, match="None of the .* environment variables are set"): - OpenAIDocumentEmbedder() + assert embedder.timeout is None + assert embedder.max_retries is None + assert embedder.client is None + assert embedder.async_client is None def test_to_dict(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") @@ -214,6 +216,7 @@ def test_run_on_empty_list(self): def test_embed_batch_handles_exceptions_gracefully(self, caplog): embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake_api_key")) + embedder.warm_up() fake_texts_to_embed = {"1": "text1", "2": "text2"} with patch.object( embedder.client.embeddings, @@ -227,6 +230,7 @@ def test_embed_batch_handles_exceptions_gracefully(self, caplog): def test_run_handles_exceptions_gracefully(self, caplog): embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake_api_key"), batch_size=1) + embedder.warm_up() docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), @@ -255,6 +259,7 @@ def test_run_handles_exceptions_gracefully(self, caplog): def test_embed_batch_raises_exception_on_failure(self): embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake_api_key"), raise_on_failure=True) + embedder.warm_up() fake_texts_to_embed = {"1": "text1", "2": "text2"} with patch.object( embedder.client.embeddings, @@ -277,6 +282,7 @@ def test_run(self): embedder = OpenAIDocumentEmbedder(model=model, meta_fields_to_embed=["topic"], embedding_separator=" | ") result = embedder.run(documents=docs) + assert embedder.client is not None documents_with_embeddings = result["documents"] assert isinstance(documents_with_embeddings, list) @@ -308,6 +314,7 @@ async def test_run_async(self): ] result = await embedder.run_async(documents=docs) + assert embedder.async_client is not None documents_with_embeddings = result["documents"] assert isinstance(documents_with_embeddings, list) @@ -328,4 +335,89 @@ async def test_run_async(self): # Close async client; suppress RuntimeError if the event loop is already closed with contextlib.suppress(RuntimeError): - await embedder.async_client.close() + await embedder.close_async() + + +@pytest.fixture +def mock_openai_clients(monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake") + sync_cls = MagicMock(name="OpenAI") + async_cls = MagicMock(name="AsyncOpenAI") + async_cls.return_value.close = AsyncMock() + monkeypatch.setattr(openai_document_embedder_module, "OpenAI", sync_cls) + monkeypatch.setattr(openai_document_embedder_module, "AsyncOpenAI", async_cls) + return sync_cls, async_cls + + +class TestComponentLifecycle: + def test_warm_up_uses_default_timeout_and_max_retries(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") + embedder = OpenAIDocumentEmbedder() + embedder.warm_up() + assert embedder.client.max_retries == 5 + assert embedder.client.timeout == 30.0 + + def test_warm_up_uses_timeout_and_max_retries_from_parameters(self): + embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key"), timeout=40.0, max_retries=1) + embedder.warm_up() + assert embedder.client.max_retries == 1 + assert embedder.client.timeout == 40.0 + + def test_warm_up_uses_timeout_and_max_retries_from_env_vars(self, monkeypatch): + monkeypatch.setenv("OPENAI_TIMEOUT", "100") + monkeypatch.setenv("OPENAI_MAX_RETRIES", "10") + embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key")) + embedder.warm_up() + assert embedder.client.max_retries == 10 + assert embedder.client.timeout == 100.0 + + def test_key_resolved_at_warm_up_not_init(self, monkeypatch): + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + embedder = OpenAIDocumentEmbedder() + with pytest.raises(ValueError, match="None of the .* environment variables are set"): + embedder.warm_up() + + def test_sync_lifecycle(self, mock_openai_clients): + sync_cls, _ = mock_openai_clients + embedder = OpenAIDocumentEmbedder() + assert embedder.client is None + assert embedder.async_client is None + + embedder.warm_up() + assert embedder.client is sync_cls.return_value + assert embedder.async_client is None + + embedder.close() + sync_cls.return_value.close.assert_called_once() + assert embedder.client is None + + async def test_async_lifecycle(self, mock_openai_clients): + _, async_cls = mock_openai_clients + embedder = OpenAIDocumentEmbedder() + + await embedder.warm_up_async() + assert embedder.async_client is async_cls.return_value + assert embedder.client is None + + await embedder.close_async() + async_cls.return_value.close.assert_awaited_once() + assert embedder.async_client is None + + async def test_close_is_safe_without_warm_up(self, mock_openai_clients): + embedder = OpenAIDocumentEmbedder() + embedder.close() + await embedder.close_async() + assert embedder.client is None + assert embedder.async_client is None + + async def test_close_and_close_async_are_independent(self, mock_openai_clients): + embedder = OpenAIDocumentEmbedder() + embedder.warm_up() + await embedder.warm_up_async() + + embedder.close() + assert embedder.client is None + assert embedder.async_client is not None + + await embedder.close_async() + assert embedder.async_client is None diff --git a/test/components/embedders/test_openai_text_embedder.py b/test/components/embedders/test_openai_text_embedder.py index 4b386fc6a7..8de3b06416 100644 --- a/test/components/embedders/test_openai_text_embedder.py +++ b/test/components/embedders/test_openai_text_embedder.py @@ -4,10 +4,12 @@ import contextlib import os +from unittest.mock import AsyncMock, MagicMock import pytest from openai.types import CreateEmbeddingResponse, Embedding +import haystack.components.embedders.openai_text_embedder as openai_text_embedder_module from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder from haystack.utils.auth import Secret @@ -17,14 +19,16 @@ def test_init_default(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") embedder = OpenAITextEmbedder() - assert embedder.client.api_key == "fake-api-key" + assert embedder.api_key.resolve_value() == "fake-api-key" assert embedder.model == "text-embedding-ada-002" assert embedder.api_base_url is None assert embedder.organization is None assert embedder.prefix == "" assert embedder.suffix == "" - assert embedder.client.timeout == 30 - assert embedder.client.max_retries == 5 + assert embedder.timeout is None + assert embedder.max_retries is None + assert embedder.client is None + assert embedder.async_client is None def test_init_with_parameters(self, monkeypatch): monkeypatch.setenv("OPENAI_TIMEOUT", "100") @@ -39,14 +43,16 @@ def test_init_with_parameters(self, monkeypatch): timeout=40.0, max_retries=1, ) - assert embedder.client.api_key == "fake-api-key" + assert embedder.api_key.resolve_value() == "fake-api-key" assert embedder.model == "model" assert embedder.api_base_url == "https://my-custom-base-url.com" assert embedder.organization == "fake-organization" assert embedder.prefix == "prefix" assert embedder.suffix == "suffix" - assert embedder.client.timeout == 40.0 - assert embedder.client.max_retries == 1 + assert embedder.timeout == 40.0 + assert embedder.max_retries == 1 + assert embedder.client is None + assert embedder.async_client is None def test_init_with_parameters_and_env_vars(self, monkeypatch): monkeypatch.setenv("OPENAI_TIMEOUT", "100") @@ -59,19 +65,16 @@ def test_init_with_parameters_and_env_vars(self, monkeypatch): prefix="prefix", suffix="suffix", ) - assert embedder.client.api_key == "fake-api-key" + assert embedder.api_key.resolve_value() == "fake-api-key" assert embedder.model == "model" assert embedder.api_base_url == "https://my-custom-base-url.com" assert embedder.organization == "fake-organization" assert embedder.prefix == "prefix" assert embedder.suffix == "suffix" - assert embedder.client.timeout == 100.0 - assert embedder.client.max_retries == 10 - - def test_init_fail_wo_api_key(self, monkeypatch): - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - with pytest.raises(ValueError, match="None of the .* environment variables are set"): - OpenAITextEmbedder() + assert embedder.timeout is None + assert embedder.max_retries is None + assert embedder.client is None + assert embedder.async_client is None def test_to_dict(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") @@ -138,7 +141,7 @@ def test_from_dict(self, monkeypatch): }, } component = OpenAITextEmbedder.from_dict(data) - assert component.client.api_key == "fake-api-key" + assert component.api_key.resolve_value() == "fake-api-key" assert component.model == "text-embedding-ada-002" assert component.api_base_url == "https://my-custom-base-url.com" assert component.organization == "fake-organization" @@ -219,4 +222,89 @@ async def test_run_async(self): # Close async client; suppress RuntimeError if the event loop is already closed with contextlib.suppress(RuntimeError): - await embedder.async_client.close() + await embedder.close_async() + + +@pytest.fixture +def mock_openai_clients(monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake") + sync_cls = MagicMock(name="OpenAI") + async_cls = MagicMock(name="AsyncOpenAI") + async_cls.return_value.close = AsyncMock() + monkeypatch.setattr(openai_text_embedder_module, "OpenAI", sync_cls) + monkeypatch.setattr(openai_text_embedder_module, "AsyncOpenAI", async_cls) + return sync_cls, async_cls + + +class TestComponentLifecycle: + def test_warm_up_uses_default_timeout_and_max_retries(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") + embedder = OpenAITextEmbedder() + embedder.warm_up() + assert embedder.client.max_retries == 5 + assert embedder.client.timeout == 30.0 + + def test_warm_up_uses_timeout_and_max_retries_from_parameters(self): + embedder = OpenAITextEmbedder(api_key=Secret.from_token("fake-api-key"), timeout=40.0, max_retries=1) + embedder.warm_up() + assert embedder.client.max_retries == 1 + assert embedder.client.timeout == 40.0 + + def test_warm_up_uses_timeout_and_max_retries_from_env_vars(self, monkeypatch): + monkeypatch.setenv("OPENAI_TIMEOUT", "100") + monkeypatch.setenv("OPENAI_MAX_RETRIES", "10") + embedder = OpenAITextEmbedder(api_key=Secret.from_token("fake-api-key")) + embedder.warm_up() + assert embedder.client.max_retries == 10 + assert embedder.client.timeout == 100.0 + + def test_key_resolved_at_warm_up_not_init(self, monkeypatch): + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + embedder = OpenAITextEmbedder() + with pytest.raises(ValueError, match="None of the .* environment variables are set"): + embedder.warm_up() + + def test_sync_lifecycle(self, mock_openai_clients): + sync_cls, _ = mock_openai_clients + embedder = OpenAITextEmbedder() + assert embedder.client is None + assert embedder.async_client is None + + embedder.warm_up() + assert embedder.client is sync_cls.return_value + assert embedder.async_client is None + + embedder.close() + sync_cls.return_value.close.assert_called_once() + assert embedder.client is None + + async def test_async_lifecycle(self, mock_openai_clients): + _, async_cls = mock_openai_clients + embedder = OpenAITextEmbedder() + + await embedder.warm_up_async() + assert embedder.async_client is async_cls.return_value + assert embedder.client is None + + await embedder.close_async() + async_cls.return_value.close.assert_awaited_once() + assert embedder.async_client is None + + async def test_close_is_safe_without_warm_up(self, mock_openai_clients): + embedder = OpenAITextEmbedder() + embedder.close() + await embedder.close_async() + assert embedder.client is None + assert embedder.async_client is None + + async def test_close_and_close_async_are_independent(self, mock_openai_clients): + embedder = OpenAITextEmbedder() + embedder.warm_up() + await embedder.warm_up_async() + + embedder.close() + assert embedder.client is None + assert embedder.async_client is not None + + await embedder.close_async() + assert embedder.async_client is None diff --git a/test/components/evaluators/test_context_relevance_evaluator.py b/test/components/evaluators/test_context_relevance_evaluator.py index 1e5f89aff6..e70d4717ac 100644 --- a/test/components/evaluators/test_context_relevance_evaluator.py +++ b/test/components/evaluators/test_context_relevance_evaluator.py @@ -52,13 +52,14 @@ def test_init_default(self, monkeypatch): ] assert isinstance(component._chat_generator, OpenAIChatGenerator) - assert component._chat_generator.client.api_key == "test-api-key" + assert component._chat_generator.api_key.resolve_value() == "test-api-key" assert component._chat_generator.generation_kwargs == {"response_format": {"type": "json_object"}, "seed": 42} - def test_init_fail_wo_openai_api_key(self, monkeypatch): + def test_key_resolved_at_warm_up_not_init(self, monkeypatch): monkeypatch.delenv("OPENAI_API_KEY", raising=False) + component = ContextRelevanceEvaluator() with pytest.raises(ValueError, match="None of the .* environment variables are set"): - ContextRelevanceEvaluator() + component.warm_up() def test_init_with_parameters(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") @@ -75,7 +76,7 @@ def test_init_with_parameters(self, monkeypatch): ] assert isinstance(component._chat_generator, OpenAIChatGenerator) - assert component._chat_generator.client.api_key == "test-api-key" + assert component._chat_generator.api_key.resolve_value() == "test-api-key" assert component._chat_generator.generation_kwargs == {"response_format": {"type": "json_object"}, "seed": 42} def test_init_with_chat_generator(self, monkeypatch): @@ -123,7 +124,7 @@ def test_from_dict(self, monkeypatch): component = ContextRelevanceEvaluator.from_dict(data) assert isinstance(component._chat_generator, OpenAIChatGenerator) - assert component._chat_generator.client.api_key == "test-api-key" + assert component._chat_generator.api_key.resolve_value() == "test-api-key" assert component._chat_generator.generation_kwargs == {"response_format": {"type": "json_object"}, "seed": 42} assert component.examples == [{"inputs": {"questions": "What is football?"}, "outputs": {"score": 0}}] diff --git a/test/components/evaluators/test_faithfulness_evaluator.py b/test/components/evaluators/test_faithfulness_evaluator.py index 43aadfe23d..76feecf319 100644 --- a/test/components/evaluators/test_faithfulness_evaluator.py +++ b/test/components/evaluators/test_faithfulness_evaluator.py @@ -67,13 +67,14 @@ def test_init_default(self, monkeypatch): ] assert isinstance(component._chat_generator, OpenAIChatGenerator) - assert component._chat_generator.client.api_key == "test-api-key" + assert component._chat_generator.api_key.resolve_value() == "test-api-key" assert component._chat_generator.generation_kwargs == {"response_format": {"type": "json_object"}, "seed": 42} - def test_init_fail_wo_openai_api_key(self, monkeypatch): + def test_key_resolved_at_warm_up_not_init(self, monkeypatch): monkeypatch.delenv("OPENAI_API_KEY", raising=False) + component = FaithfulnessEvaluator() with pytest.raises(ValueError, match="None of the .* environment variables are set"): - FaithfulnessEvaluator() + component.warm_up() def test_init_with_parameters(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") @@ -96,7 +97,7 @@ def test_init_with_parameters(self, monkeypatch): ] assert isinstance(component._chat_generator, OpenAIChatGenerator) - assert component._chat_generator.client.api_key == "test-api-key" + assert component._chat_generator.api_key.resolve_value() == "test-api-key" assert component._chat_generator.generation_kwargs == {"response_format": {"type": "json_object"}, "seed": 42} def test_init_with_chat_generator(self, monkeypatch): @@ -150,7 +151,7 @@ def test_from_dict(self, monkeypatch): } component = FaithfulnessEvaluator.from_dict(data) assert isinstance(component._chat_generator, OpenAIChatGenerator) - assert component._chat_generator.client.api_key == "test-api-key" + assert component._chat_generator.api_key.resolve_value() == "test-api-key" assert component._chat_generator.generation_kwargs == {"response_format": {"type": "json_object"}, "seed": 42} assert component.examples == [ {"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}} diff --git a/test/components/evaluators/test_llm_evaluator.py b/test/components/evaluators/test_llm_evaluator.py index 1af418cc37..71724ae99c 100644 --- a/test/components/evaluators/test_llm_evaluator.py +++ b/test/components/evaluators/test_llm_evaluator.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from unittest.mock import AsyncMock, Mock import pytest @@ -30,20 +31,20 @@ def test_init_default(self, monkeypatch): ] assert isinstance(component._chat_generator, OpenAIChatGenerator) - assert component._chat_generator.client.api_key == "test-api-key" assert component._chat_generator.generation_kwargs == {"response_format": {"type": "json_object"}, "seed": 42} - def test_init_fail_wo_openai_api_key(self, monkeypatch): + def test_key_resolved_at_warm_up_not_init(self, monkeypatch): monkeypatch.delenv("OPENAI_API_KEY", raising=False) + component = LLMEvaluator( + instructions="test-instruction", + inputs=[("predicted_answers", list[str])], + outputs=["score"], + examples=[ + {"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}} + ], + ) with pytest.raises(ValueError, match="None of the .* environment variables are set"): - LLMEvaluator( - instructions="test-instruction", - inputs=[("predicted_answers", list[str])], - outputs=["score"], - examples=[ - {"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}} - ], - ) + component.warm_up() def test_init_with_chat_generator(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") @@ -268,7 +269,6 @@ def test_from_dict(self, monkeypatch): component = LLMEvaluator.from_dict(data) assert isinstance(component._chat_generator, OpenAIChatGenerator) - assert component._chat_generator.client.api_key == "test-api-key" assert component._chat_generator.generation_kwargs == {"response_format": {"type": "json_object"}, "seed": 42} assert component.instructions == "test-instruction" assert component.inputs == [("predicted_answers", list[str])] @@ -574,3 +574,63 @@ async def chat_generator_run_async(self, *args, **kwargs): with pytest.raises(ValueError): await component.run_async(questions=["question"], predicted_answers=["answer"]) + + +class TestComponentLifecycle: + @staticmethod + def _make_evaluator(chat_generator): + return LLMEvaluator( + instructions="test-instruction", + inputs=[("predicted_answers", list[str])], + outputs=["score"], + examples=[ + {"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}} + ], + chat_generator=chat_generator, + ) + + def test_warm_up_delegates_to_chat_generator(self): + chat_generator = Mock(spec=["run", "warm_up"]) + evaluator = self._make_evaluator(chat_generator) + evaluator.warm_up() + chat_generator.warm_up.assert_called_once() + + async def test_warm_up_async_delegates_to_chat_generator(self): + chat_generator = Mock(spec=["run", "warm_up_async"]) + chat_generator.warm_up_async = AsyncMock() + evaluator = self._make_evaluator(chat_generator) + await evaluator.warm_up_async() + chat_generator.warm_up_async.assert_awaited_once() + + async def test_warm_up_async_falls_back_to_sync_warm_up(self): + chat_generator = Mock(spec=["run", "warm_up"]) + evaluator = self._make_evaluator(chat_generator) + await evaluator.warm_up_async() + chat_generator.warm_up.assert_called_once() + + def test_close_delegates_to_chat_generator(self): + chat_generator = Mock(spec=["run", "close"]) + evaluator = self._make_evaluator(chat_generator) + evaluator.close() + chat_generator.close.assert_called_once() + + async def test_close_async_delegates_to_chat_generator(self): + chat_generator = Mock(spec=["run", "close_async"]) + chat_generator.close_async = AsyncMock() + evaluator = self._make_evaluator(chat_generator) + await evaluator.close_async() + chat_generator.close_async.assert_awaited_once() + + async def test_close_async_falls_back_to_sync_close(self): + chat_generator = Mock(spec=["run", "close"]) + evaluator = self._make_evaluator(chat_generator) + await evaluator.close_async() + chat_generator.close.assert_called_once() + + async def test_lifecycle_is_safe_when_chat_generator_lacks_methods(self): + chat_generator = Mock(spec=["run"]) + evaluator = self._make_evaluator(chat_generator) + evaluator.warm_up() + await evaluator.warm_up_async() + evaluator.close() + await evaluator.close_async() diff --git a/test/components/extractors/image/test_llm_document_content_extractor.py b/test/components/extractors/image/test_llm_document_content_extractor.py index 383f855814..ecad87e28a 100644 --- a/test/components/extractors/image/test_llm_document_content_extractor.py +++ b/test/components/extractors/image/test_llm_document_content_extractor.py @@ -130,20 +130,6 @@ def test_from_dict_openai(self, monkeypatch): assert extractor.max_workers == 4 assert component_to_dict(extractor._chat_generator, "name") == component_to_dict(chat_generator, "name") - def test_warm_up_with_chat_generator(self, monkeypatch): - mock_chat_generator = Mock() - mock_chat_generator.warm_up = Mock() - extractor = LLMDocumentContentExtractor(chat_generator=mock_chat_generator) - mock_chat_generator.warm_up.assert_not_called() - extractor.warm_up() - mock_chat_generator.warm_up.assert_called_once() - - def test_warm_up_without_warm_up_method(self, monkeypatch): - mock_chat_generator = Mock() - extractor = LLMDocumentContentExtractor(chat_generator=mock_chat_generator) - extractor.warm_up() - assert extractor._is_warmed_up is True - def test_run_no_documents(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") chat_generator = OpenAIChatGenerator() @@ -658,3 +644,51 @@ async def test_live_run_async(self): assert len(result["failed_documents"]) == 0 assert len(result["documents"]) == 1 assert len(result["documents"][0].content) > 0 + + +class TestComponentLifecycle: + def test_warm_up_delegates_to_chat_generator(self): + mock_chat_generator = Mock(spec=["run", "warm_up"]) + extractor = LLMDocumentContentExtractor(chat_generator=mock_chat_generator) + extractor.warm_up() + mock_chat_generator.warm_up.assert_called_once() + + async def test_warm_up_async_delegates_to_chat_generator(self): + mock_chat_generator = Mock(spec=["run", "warm_up_async"]) + mock_chat_generator.warm_up_async = AsyncMock() + extractor = LLMDocumentContentExtractor(chat_generator=mock_chat_generator) + await extractor.warm_up_async() + mock_chat_generator.warm_up_async.assert_awaited_once() + + async def test_warm_up_async_falls_back_to_sync_warm_up(self): + mock_chat_generator = Mock(spec=["run", "warm_up"]) + extractor = LLMDocumentContentExtractor(chat_generator=mock_chat_generator) + await extractor.warm_up_async() + mock_chat_generator.warm_up.assert_called_once() + + def test_close_delegates_to_chat_generator(self): + mock_chat_generator = Mock(spec=["run", "close"]) + extractor = LLMDocumentContentExtractor(chat_generator=mock_chat_generator) + extractor.close() + mock_chat_generator.close.assert_called_once() + + async def test_close_async_delegates_to_chat_generator(self): + mock_chat_generator = Mock(spec=["run", "close_async"]) + mock_chat_generator.close_async = AsyncMock() + extractor = LLMDocumentContentExtractor(chat_generator=mock_chat_generator) + await extractor.close_async() + mock_chat_generator.close_async.assert_awaited_once() + + async def test_close_async_falls_back_to_sync_close(self): + mock_chat_generator = Mock(spec=["run", "close"]) + extractor = LLMDocumentContentExtractor(chat_generator=mock_chat_generator) + await extractor.close_async() + mock_chat_generator.close.assert_called_once() + + async def test_lifecycle_is_safe_when_chat_generator_lacks_methods(self): + mock_chat_generator = Mock(spec=["run"]) + extractor = LLMDocumentContentExtractor(chat_generator=mock_chat_generator) + extractor.warm_up() + await extractor.warm_up_async() + extractor.close() + await extractor.close_async() diff --git a/test/components/extractors/test_llm_metadata_extractor.py b/test/components/extractors/test_llm_metadata_extractor.py index 1e2c8e3916..457f03b537 100644 --- a/test/components/extractors/test_llm_metadata_extractor.py +++ b/test/components/extractors/test_llm_metadata_extractor.py @@ -4,7 +4,7 @@ import asyncio import os -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import pytest @@ -141,13 +141,6 @@ def test_from_dict_openai(self, monkeypatch): assert extractor.prompt == "some prompt that was used with the LLM {{document.content}}" assert extractor._chat_generator.to_dict() == chat_generator.to_dict() - def test_warm_up_with_chat_generator(self, monkeypatch): - mock_chat_generator = Mock() - extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", chat_generator=mock_chat_generator) - mock_chat_generator.warm_up.assert_not_called() - extractor.warm_up() - mock_chat_generator.warm_up.assert_called_once() - def test_extract_metadata(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", chat_generator=OpenAIChatGenerator()) @@ -513,3 +506,66 @@ async def test_live_run_async(self, in_memory_doc_store: InMemoryDocumentStore, assert len(doc_store_docs) == 2 assert "entities" in doc_store_docs[0].meta assert "entities" in doc_store_docs[1].meta + + +class TestComponentLifecycle: + def test_warm_up_delegates_to_inner_components(self): + mock_chat_generator = Mock(spec=["run", "warm_up"]) + extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", chat_generator=mock_chat_generator) + extractor.splitter = Mock(spec=["run", "warm_up"]) + extractor.warm_up() + mock_chat_generator.warm_up.assert_called_once() + extractor.splitter.warm_up.assert_called_once() + + async def test_warm_up_async_delegates_to_inner_components(self): + mock_chat_generator = Mock(spec=["run", "warm_up", "warm_up_async"]) + mock_chat_generator.warm_up_async = AsyncMock() + extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", chat_generator=mock_chat_generator) + extractor.splitter = Mock(spec=["run", "warm_up_async"]) + extractor.splitter.warm_up_async = AsyncMock() + await extractor.warm_up_async() + mock_chat_generator.warm_up_async.assert_awaited_once() + extractor.splitter.warm_up_async.assert_awaited_once() + + async def test_warm_up_async_falls_back_to_sync_warm_up(self): + mock_chat_generator = Mock(spec=["run", "warm_up"]) + extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", chat_generator=mock_chat_generator) + extractor.splitter = Mock(spec=["run", "warm_up"]) + await extractor.warm_up_async() + mock_chat_generator.warm_up.assert_called_once() + extractor.splitter.warm_up.assert_called_once() + + def test_close_delegates_to_inner_components(self): + mock_chat_generator = Mock(spec=["run", "close"]) + extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", chat_generator=mock_chat_generator) + extractor.splitter = Mock(spec=["run", "close"]) + extractor.close() + mock_chat_generator.close.assert_called_once() + extractor.splitter.close.assert_called_once() + + async def test_close_async_delegates_to_inner_components(self): + mock_chat_generator = Mock(spec=["run", "close_async"]) + mock_chat_generator.close_async = AsyncMock() + extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", chat_generator=mock_chat_generator) + extractor.splitter = Mock(spec=["run", "close_async"]) + extractor.splitter.close_async = AsyncMock() + await extractor.close_async() + mock_chat_generator.close_async.assert_awaited_once() + extractor.splitter.close_async.assert_awaited_once() + + async def test_close_async_falls_back_to_sync_close(self): + mock_chat_generator = Mock(spec=["run", "close"]) + extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", chat_generator=mock_chat_generator) + extractor.splitter = Mock(spec=["run", "close"]) + await extractor.close_async() + mock_chat_generator.close.assert_called_once() + extractor.splitter.close.assert_called_once() + + async def test_lifecycle_is_safe_when_inner_lacks_methods(self): + mock_chat_generator = Mock(spec=["run"]) + extractor = LLMMetadataExtractor(prompt="prompt {{document.content}}", chat_generator=mock_chat_generator) + extractor.splitter = Mock(spec=["run"]) + extractor.warm_up() + await extractor.warm_up_async() + extractor.close() + await extractor.close_async() diff --git a/test/components/fetchers/test_link_content_fetcher.py b/test/components/fetchers/test_link_content_fetcher.py index 1462f295b0..62782a3515 100644 --- a/test/components/fetchers/test_link_content_fetcher.py +++ b/test/components/fetchers/test_link_content_fetcher.py @@ -57,8 +57,8 @@ def test_init(self): "video/*": _binary_content_handler, } assert hasattr(fetcher, "_get_response") - assert hasattr(fetcher, "_client") - assert isinstance(fetcher._client, httpx.Client) + assert fetcher._client is None + assert fetcher._async_client is None def test_init_with_params(self): """Test initialization with custom parameters""" @@ -183,6 +183,106 @@ def test_request_headers_merging_and_ua_override(self): assert sent_headers["User-Agent"] == "ua-sync-1" # rotating UA wins +class TestComponentLifecycle: + def test_clients_are_none_after_init(self): + fetcher = LinkContentFetcher() + assert fetcher._client is None + assert fetcher._async_client is None + + def test_sync_lifecycle(self): + with patch("haystack.components.fetchers.link_content.httpx.Client") as ClientMock: + client_instance = ClientMock.return_value + fetcher = LinkContentFetcher() + + fetcher.warm_up() + assert fetcher._client is client_instance + assert fetcher._async_client is None + ClientMock.assert_called_once() + + fetcher.close() + client_instance.close.assert_called_once() + assert fetcher._client is None + + def test_warm_up_is_idempotent(self): + with patch("haystack.components.fetchers.link_content.httpx.Client") as ClientMock: + fetcher = LinkContentFetcher() + fetcher.warm_up() + fetcher.warm_up() + ClientMock.assert_called_once() + + @pytest.mark.asyncio + async def test_async_lifecycle(self): + with patch("haystack.components.fetchers.link_content.httpx.AsyncClient") as AsyncClientMock: + async_client_instance = AsyncClientMock.return_value + async_client_instance.aclose = AsyncMock() + fetcher = LinkContentFetcher() + + await fetcher.warm_up_async() + assert fetcher._async_client is async_client_instance + assert fetcher._client is None + AsyncClientMock.assert_called_once() + + await fetcher.close_async() + async_client_instance.aclose.assert_awaited_once() + assert fetcher._async_client is None + + @pytest.mark.asyncio + async def test_warm_up_async_is_idempotent(self): + with patch("haystack.components.fetchers.link_content.httpx.AsyncClient") as AsyncClientMock: + fetcher = LinkContentFetcher() + await fetcher.warm_up_async() + await fetcher.warm_up_async() + AsyncClientMock.assert_called_once() + + @pytest.mark.asyncio + async def test_close_is_safe_without_warm_up(self): + fetcher = LinkContentFetcher() + fetcher.close() + await fetcher.close_async() + assert fetcher._client is None + assert fetcher._async_client is None + + @pytest.mark.asyncio + async def test_close_and_close_async_are_independent(self): + with ( + patch("haystack.components.fetchers.link_content.httpx.Client") as ClientMock, + patch("haystack.components.fetchers.link_content.httpx.AsyncClient") as AsyncClientMock, + ): + client_instance = ClientMock.return_value + async_client_instance = AsyncClientMock.return_value + async_client_instance.aclose = AsyncMock() + + fetcher = LinkContentFetcher() + fetcher.warm_up() + await fetcher.warm_up_async() + + fetcher.close() + assert fetcher._client is None + assert fetcher._async_client is async_client_instance + async_client_instance.aclose.assert_not_awaited() + + await fetcher.close_async() + assert fetcher._async_client is None + client_instance.close.assert_called_once() + + def test_run_self_heals(self): + with patch("haystack.components.fetchers.link_content.httpx.Client.get") as mock_get: + mock_response = Mock(status_code=200, text="ok", headers={"Content-Type": "text/plain"}) + mock_get.return_value = mock_response + fetcher = LinkContentFetcher() + fetcher.run(urls=["https://www.example.com"]) + assert fetcher._client is not None + + @pytest.mark.asyncio + async def test_run_async_self_heals(self): + with patch("haystack.components.fetchers.link_content.httpx.AsyncClient.get") as mock_get: + mock_response = Mock(status_code=200, text="ok", headers={"Content-Type": "text/plain"}) + mock_get.return_value = mock_response + fetcher = LinkContentFetcher() + await fetcher.run_async(urls=["https://www.example.com"]) + assert fetcher._async_client is not None + + @pytest.mark.flaky(reruns=3, reruns_delay=5) @pytest.mark.integration class TestLinkContentFetcherIntegration: diff --git a/test/components/generators/chat/test_azure.py b/test/components/generators/chat/test_azure.py index 167775cfd7..1c020f1818 100644 --- a/test/components/generators/chat/test_azure.py +++ b/test/components/generators/chat/test_azure.py @@ -2,15 +2,16 @@ # # SPDX-License-Identifier: Apache-2.0 -import contextlib import json import os from typing import Any +from unittest.mock import AsyncMock, MagicMock import pytest from openai import OpenAIError from pydantic import BaseModel +import haystack.components.generators.chat.azure as azure_chat_module from haystack import Pipeline, component from haystack.components.generators.chat import AzureOpenAIChatGenerator from haystack.components.generators.utils import print_streaming_chunk @@ -86,16 +87,19 @@ def test_supported_models(self) -> None: def test_init_default(self, monkeypatch): monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") - assert component.client.api_key == "test-api-key" + assert component.api_key == Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False) assert component.azure_deployment == "gpt-4.1-mini" assert component.streaming_callback is None assert not component.generation_kwargs + assert component.client is None + assert component.async_client is None - def test_init_fail_wo_api_key(self, monkeypatch): + def test_init_does_not_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False) monkeypatch.delenv("AZURE_OPENAI_AD_TOKEN", raising=False) - with pytest.raises(OpenAIError): - AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") + component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") + assert component.client is None + assert component.async_client is None def test_init_with_parameters(self, tools): component = AzureOpenAIChatGenerator( @@ -107,14 +111,16 @@ def test_init_with_parameters(self, tools): tools_strict=True, azure_ad_token_provider=default_azure_ad_token_provider, ) - assert component.client.api_key == "test-api-key" + assert component.api_key == Secret.from_token("test-api-key") assert component.azure_deployment == "gpt-4.1-mini" assert component.streaming_callback is print_streaming_chunk assert component.generation_kwargs == {"max_completion_tokens": 10, "some_test_param": "test-params"} assert component.tools == tools assert component.tools_strict assert component.azure_ad_token_provider is not None - assert component.max_retries == 5 + assert component.max_retries is None + assert component.client is None + assert component.async_client is None def test_init_with_0_max_retries(self, tools): """Tests that the max_retries init param is set correctly if equal 0""" @@ -128,7 +134,7 @@ def test_init_with_0_max_retries(self, tools): azure_ad_token_provider=default_azure_ad_token_provider, max_retries=0, ) - assert component.client.api_key == "test-api-key" + assert component.api_key == Secret.from_token("test-api-key") assert component.azure_deployment == "gpt-4.1-mini" assert component.streaming_callback is print_streaming_chunk assert component.generation_kwargs == {"max_completion_tokens": 10, "some_test_param": "test-params"} @@ -136,6 +142,8 @@ def test_init_with_0_max_retries(self, tools): assert component.tools_strict assert component.azure_ad_token_provider is not None assert component.max_retries == 0 + assert component.client is None + assert component.async_client is None def test_init_with_secret_azure_endpoint_and_api_version(self, monkeypatch): """`azure_endpoint` and `api_version` accept a Secret that is resolved from an environment variable.""" @@ -149,11 +157,6 @@ def test_init_with_secret_azure_endpoint_and_api_version(self, monkeypatch): # The Secret objects are kept on the instance so they can be serialized assert component.azure_endpoint == Secret.from_env_var("AZURE_OPENAI_ENDPOINT") assert component.api_version == Secret.from_env_var("AZURE_OPENAI_API_VERSION") - # The clients receive the resolved string values - assert str(component.client._azure_endpoint) == "https://test-resource.azure.openai.com/" - assert component.client._api_version == "2024-08-01-preview" - assert str(component.async_client._azure_endpoint) == "https://test-resource.azure.openai.com/" - assert component.async_client._api_version == "2024-08-01-preview" def test_init_fail_with_unset_secret_azure_endpoint(self, monkeypatch): """A Secret azure_endpoint that resolves to nothing raises the same error as a missing endpoint.""" @@ -195,6 +198,7 @@ def test_secret_azure_endpoint_and_api_version_roundtrip(self, monkeypatch): deserialized = AzureOpenAIChatGenerator.from_dict(component.to_dict()) assert deserialized.azure_endpoint == Secret.from_env_var("AZURE_OPENAI_ENDPOINT") assert deserialized.api_version == Secret.from_env_var("AZURE_OPENAI_API_VERSION") + deserialized.warm_up() assert str(deserialized.client._azure_endpoint) == "https://test-resource.azure.openai.com/" assert deserialized.client._api_version == "2024-08-01-preview" @@ -214,8 +218,8 @@ def test_from_dict_with_secret_azure_endpoint_and_api_version(self, monkeypatch) "organization": None, "streaming_callback": None, "generation_kwargs": {}, - "timeout": 30.0, - "max_retries": 5, + "timeout": None, + "max_retries": None, "default_headers": {}, "tools": None, "tools_strict": False, @@ -228,6 +232,7 @@ def test_from_dict_with_secret_azure_endpoint_and_api_version(self, monkeypatch) assert generator.azure_endpoint == Secret.from_env_var("AZURE_OPENAI_ENDPOINT") assert generator.api_version == Secret.from_env_var("AZURE_OPENAI_API_VERSION") # And they are resolved to the string values the client expects + generator.warm_up() assert str(generator.client._azure_endpoint) == "https://test-resource.azure.openai.com/" assert generator.client._api_version == "2024-08-01-preview" @@ -246,8 +251,8 @@ def test_to_dict_default(self, monkeypatch): "organization": None, "streaming_callback": None, "generation_kwargs": {}, - "timeout": 30.0, - "max_retries": 5, + "timeout": None, + "max_retries": None, "default_headers": {}, "tools": None, "tools_strict": False, @@ -389,8 +394,8 @@ def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch): "api_version": "2024-12-01-preview", "streaming_callback": None, "generation_kwargs": {}, - "timeout": 30.0, - "max_retries": 5, + "timeout": None, + "max_retries": None, "api_key": {"type": "env_var", "env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False}, "azure_ad_token": {"type": "env_var", "env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False}, "default_headers": {}, @@ -531,113 +536,9 @@ def test_to_dict_with_toolset(self, tools, monkeypatch): } assert data["init_parameters"]["tools"] == expected_tools_data - def test_warm_up_with_tools(self, monkeypatch): - """Test that warm_up() calls warm_up on tools and is idempotent.""" - monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") - - # Create a mock tool that tracks if warm_up() was called - class MockTool(Tool): - warm_up_call_count = 0 # Class variable to track calls - - def __init__(self): - super().__init__( - name="mock_tool", - description="A mock tool for testing", - parameters={"x": {"type": "string"}}, - function=lambda x: x, - ) - - def warm_up(self): - MockTool.warm_up_call_count += 1 - - # Reset the class variable before test - MockTool.warm_up_call_count = 0 - mock_tool = MockTool() - - # Create AzureOpenAIChatGenerator with the mock tool - component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=[mock_tool]) - - # Verify initial state - warm_up not called yet - assert MockTool.warm_up_call_count == 0 - assert not component._is_warmed_up - - # Call warm_up() on the generator - component.warm_up() - - # Assert that the tool's warm_up() was called - assert MockTool.warm_up_call_count == 1 - assert component._is_warmed_up - - # Call warm_up() again and verify it's idempotent (only warms up once) - component.warm_up() - - # The tool's warm_up should still only have been called once - assert MockTool.warm_up_call_count == 1 - assert component._is_warmed_up - - def test_warm_up_with_no_tools(self, monkeypatch): - """Test that warm_up() works when no tools are provided.""" - monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") - - component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") - - # Verify initial state - assert not component._is_warmed_up - assert component.tools is None - - # Call warm_up() - should not raise an error - component.warm_up() - - # Verify the component is warmed up - assert component._is_warmed_up - - # Call warm_up() again - should be idempotent - component.warm_up() - assert component._is_warmed_up - - def test_warm_up_with_multiple_tools(self, monkeypatch): - """Test that warm_up() works with multiple tools.""" - monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") - - # Track warm_up calls - warm_up_calls = [] - - class MockTool(Tool): - def __init__(self, tool_name): - super().__init__( - name=tool_name, - description=f"Mock tool {tool_name}", - parameters={"type": "object", "properties": {"x": {"type": "string"}}, "required": ["x"]}, - function=lambda x: f"{tool_name} result: {x}", - ) - - def warm_up(self): - warm_up_calls.append(self.name) - - mock_tool1 = MockTool("tool1") - mock_tool2 = MockTool("tool2") - - # Use a LIST of tools, not a Toolset - component = AzureOpenAIChatGenerator( - azure_endpoint="some-non-existing-endpoint", tools=[mock_tool1, mock_tool2] - ) - - # Call warm_up() - component.warm_up() - - # Assert that both tools' warm_up() were called - assert "tool1" in warm_up_calls - assert "tool2" in warm_up_calls - assert component._is_warmed_up - - # Test idempotency - warm_up should not call tools again - initial_count = len(warm_up_calls) - component.warm_up() - assert len(warm_up_calls) == initial_count - class TestAzureOpenAIChatGeneratorAsync: - def test_init_should_also_create_async_client_with_same_args(self, tools): + async def test_warm_up_async_builds_async_client(self, tools): component = AzureOpenAIChatGenerator( api_key=Secret.from_token("test-api-key"), azure_endpoint="some-non-existing-endpoint", @@ -646,7 +547,10 @@ def test_init_should_also_create_async_client_with_same_args(self, tools): tools=tools, tools_strict=True, ) + assert component.async_client is None + await component.warm_up_async() assert component.async_client.api_key == "test-api-key" + assert component.client is None assert component.azure_deployment == "gpt-4.1-mini" assert component.streaming_callback is print_streaming_chunk assert component.generation_kwargs == {"max_completion_tokens": 10, "some_test_param": "test-params"} @@ -672,9 +576,7 @@ async def test_live_run_async(self): assert "Paris" in message.text assert "gpt-4.1-mini" in message.meta["model"] assert message.meta["finish_reason"] == "stop" - # Close async client; suppress RuntimeError if the event loop is already closed - with contextlib.suppress(RuntimeError): - await component.async_client.close() + await component.close_async() @pytest.mark.integration @pytest.mark.skipif( @@ -702,8 +604,133 @@ async def test_live_run_with_tools_async(self, tools): assert tool_call.arguments == {"city": "Paris"} assert message.meta["finish_reason"] == "tool_calls" - # Close async client; suppress RuntimeError if the event loop is already closed - with contextlib.suppress(RuntimeError): - await component.async_client.close() + await component.close_async() # additional tests intentionally omitted as they are covered by test_openai.py + + +@pytest.fixture +def mock_azure_clients(monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake") + sync_cls = MagicMock(name="AzureOpenAI") + async_cls = MagicMock(name="AsyncAzureOpenAI") + async_cls.return_value.close = AsyncMock() + monkeypatch.setattr(azure_chat_module, "AzureOpenAI", sync_cls) + monkeypatch.setattr(azure_chat_module, "AsyncAzureOpenAI", async_cls) + return sync_cls, async_cls + + +class TestComponentLifecycle: + def test_warm_up_uses_default_timeout_and_max_retries(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") + generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") + generator.warm_up() + assert generator.client.max_retries == 5 + assert generator.client.timeout == 30.0 + + def test_warm_up_uses_timeout_and_max_retries_from_parameters(self): + generator = AzureOpenAIChatGenerator( + api_key=Secret.from_token("fake-api-key"), + azure_endpoint="some-non-existing-endpoint", + timeout=40.0, + max_retries=1, + ) + generator.warm_up() + assert generator.client.max_retries == 1 + assert generator.client.timeout == 40.0 + + def test_warm_up_uses_timeout_and_max_retries_from_env_vars(self, monkeypatch): + monkeypatch.setenv("OPENAI_TIMEOUT", "100") + monkeypatch.setenv("OPENAI_MAX_RETRIES", "10") + generator = AzureOpenAIChatGenerator( + api_key=Secret.from_token("fake-api-key"), azure_endpoint="some-non-existing-endpoint" + ) + generator.warm_up() + assert generator.client.max_retries == 10 + assert generator.client.timeout == 100.0 + + def test_key_resolved_at_warm_up_not_init(self, monkeypatch): + monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False) + monkeypatch.delenv("AZURE_OPENAI_AD_TOKEN", raising=False) + generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") + with pytest.raises(OpenAIError): + generator.warm_up() + + def test_warm_up_warms_tools_once(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") + warm_up_calls = [] + + class MockTool(Tool): + def __init__(self, tool_name): + super().__init__( + name=tool_name, + description=f"Mock tool {tool_name}", + parameters={"type": "object", "properties": {"x": {"type": "string"}}, "required": ["x"]}, + function=lambda x: x, + ) + + def warm_up(self): + warm_up_calls.append(self.name) + + generator = AzureOpenAIChatGenerator( + azure_endpoint="some-non-existing-endpoint", tools=[MockTool("tool1"), MockTool("tool2")] + ) + assert not generator._tools_warmed_up + + generator.warm_up() + assert sorted(warm_up_calls) == ["tool1", "tool2"] + assert generator._tools_warmed_up + + generator.warm_up() + assert sorted(warm_up_calls) == ["tool1", "tool2"] + + def test_warm_up_with_no_tools_does_not_raise(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key") + generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") + generator.warm_up() + assert generator._tools_warmed_up + + def test_sync_lifecycle(self, mock_azure_clients): + sync_cls, _ = mock_azure_clients + generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") + assert generator.client is None + assert generator.async_client is None + + generator.warm_up() + assert generator.client is sync_cls.return_value + assert generator.async_client is None + + generator.close() + sync_cls.return_value.close.assert_called_once() + assert generator.client is None + + async def test_async_lifecycle(self, mock_azure_clients): + _, async_cls = mock_azure_clients + generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") + + await generator.warm_up_async() + assert generator.async_client is async_cls.return_value + assert generator.client is None + + await generator.close_async() + async_cls.return_value.close.assert_awaited_once() + assert generator.async_client is None + + async def test_close_is_safe_without_warm_up(self, mock_azure_clients): + generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") + generator.close() + await generator.close_async() + assert generator.client is None + assert generator.async_client is None + + async def test_close_and_close_async_are_independent(self, mock_azure_clients): + generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") + generator.warm_up() + await generator.warm_up_async() + + generator.close() + assert generator.client is None + assert generator.async_client is not None + + await generator.close_async() + assert generator.async_client is None diff --git a/test/components/generators/chat/test_azure_responses.py b/test/components/generators/chat/test_azure_responses.py index 636a99356c..db3dcf8214 100644 --- a/test/components/generators/chat/test_azure_responses.py +++ b/test/components/generators/chat/test_azure_responses.py @@ -7,7 +7,6 @@ from typing import Any import pytest -from openai import OpenAIError from pydantic import BaseModel from haystack import Pipeline, component @@ -85,17 +84,12 @@ def test_supported_models(self) -> None: def test_init_default(self, monkeypatch): monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") component = AzureOpenAIResponsesChatGenerator(azure_endpoint="some-non-existing-endpoint") - assert component.client.api_key == "test-api-key" + assert component.client is None + assert component.async_client is None assert component._azure_deployment == "gpt-5-mini" assert component.streaming_callback is None assert not component.generation_kwargs - def test_init_fail_wo_api_key(self, monkeypatch): - monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False) - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - with pytest.raises(OpenAIError): - AzureOpenAIResponsesChatGenerator(azure_endpoint="some-non-existing-endpoint") - def test_init_fail_wo_azure_endpoint(self, monkeypatch): monkeypatch.delenv("AZURE_OPENAI_ENDPOINT", raising=False) with pytest.raises(ValueError): @@ -110,7 +104,8 @@ def test_init_with_parameters(self, tools): tools=tools, tools_strict=True, ) - assert component.client.api_key == "test-api-key" + assert component.client is None + assert component.async_client is None assert component._azure_deployment == "gpt-5-mini" assert component.streaming_callback is print_streaming_chunk assert component.generation_kwargs == {"max_completion_tokens": 10, "some_test_param": "test-params"} @@ -385,96 +380,72 @@ def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch): assert p.to_dict() == q.to_dict() -class TestWarmUp: - def test_warm_up_with_tools(self, monkeypatch): - """Test that warm_up() calls warm_up on tools and is idempotent.""" +class TestComponentLifecycle: + def test_warm_up_warms_tools_once(self, monkeypatch): monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") + warm_up_calls = [] - # Create a mock tool that tracks if warm_up() was called class MockTool(Tool): - warm_up_call_count = 0 # Class variable to track calls - - def __init__(self): + def __init__(self, tool_name): super().__init__( - name="mock_tool", - description="A mock tool for testing", - parameters={"x": {"type": "string"}}, + name=tool_name, + description=f"Mock tool {tool_name}", + parameters={"type": "object", "properties": {"x": {"type": "string"}}, "required": ["x"]}, function=lambda x: x, ) def warm_up(self): - MockTool.warm_up_call_count += 1 - - # Reset the class variable before test - MockTool.warm_up_call_count = 0 - mock_tool = MockTool() - - # Create AzureOpenAIChatGenerator with the mock tool - component = AzureOpenAIResponsesChatGenerator(azure_endpoint="some-non-existing-endpoint", tools=[mock_tool]) + warm_up_calls.append(self.name) - # Verify initial state - warm_up not called yet - assert MockTool.warm_up_call_count == 0 - assert not component._is_warmed_up + component = AzureOpenAIResponsesChatGenerator( + azure_endpoint="some-non-existing-endpoint", tools=[MockTool("tool1"), MockTool("tool2")] + ) + assert not component._tools_warmed_up - # Call warm_up() on the generator component.warm_up() + assert sorted(warm_up_calls) == ["tool1", "tool2"] + assert component._tools_warmed_up - # Assert that the tool's warm_up() was called - assert MockTool.warm_up_call_count == 1 - assert component._is_warmed_up - - # Call warm_up() again and verify it's idempotent (only warms up once) component.warm_up() + assert sorted(warm_up_calls) == ["tool1", "tool2"] - # The tool's warm_up should still only have been called once - assert MockTool.warm_up_call_count == 1 - assert component._is_warmed_up - - def test_warm_up_with_no_tools(self, monkeypatch): - """Test that warm_up() works when no tools are provided.""" + def test_warm_up_with_no_tools_does_not_raise(self, monkeypatch): monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") component = AzureOpenAIResponsesChatGenerator(azure_endpoint="some-non-existing-endpoint") - - # Verify initial state - assert not component._is_warmed_up - assert component.tools is None - - # Verify the component is warmed up component.warm_up() - assert component._is_warmed_up + assert component._tools_warmed_up - def test_warm_up_with_multiple_tools(self, monkeypatch): - """Test that warm_up() works with multiple tools.""" + def test_sync_lifecycle(self, monkeypatch): monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") - warm_up_calls = [] + component = AzureOpenAIResponsesChatGenerator(azure_endpoint="some-non-existing-endpoint") + assert component.client is None + assert component.async_client is None - class MockTool(Tool): - def __init__(self, tool_name): - super().__init__( - name=tool_name, - description=f"Mock tool {tool_name}", - parameters={"type": "object", "properties": {"x": {"type": "string"}}, "required": ["x"]}, - function=lambda x: f"{tool_name} result: {x}", - ) + component.warm_up() + assert component.client is not None + assert component.async_client is None - def warm_up(self): - warm_up_calls.append(self.name) + component.close() + assert component.client is None - # Use a LIST of tools, not a Toolset - component = AzureOpenAIResponsesChatGenerator( - azure_endpoint="some-non-existing-endpoint", tools=[MockTool("tool1"), MockTool("tool2")] - ) + async def test_async_lifecycle(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") + component = AzureOpenAIResponsesChatGenerator(azure_endpoint="some-non-existing-endpoint") - # Assert that both tools' warm_up() were called - component.warm_up() - assert "tool1" in warm_up_calls - assert "tool2" in warm_up_calls - assert component._is_warmed_up + await component.warm_up_async() + assert component.async_client is not None + assert component.client is None - # Test idempotency - warm_up should not call tools again - initial_count = len(warm_up_calls) - component.warm_up() - assert len(warm_up_calls) == initial_count + await component.close_async() + assert component.async_client is None + + async def test_close_is_safe_without_warm_up(self, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") + component = AzureOpenAIResponsesChatGenerator(azure_endpoint="some-non-existing-endpoint") + component.close() + await component.close_async() + assert component.client is None + assert component.async_client is None @pytest.mark.integration @@ -565,7 +536,7 @@ def test_live_run_with_text_format_json_schema(self): class TestAzureOpenAIResponsesChatGeneratorAsync: - def test_init_should_also_create_async_client_with_same_args(self, tools): + async def test_warm_up_async_creates_async_client_with_expected_args(self, tools): component = AzureOpenAIResponsesChatGenerator( api_key=Secret.from_token("test-api-key"), azure_endpoint="some-non-existing-endpoint", @@ -574,6 +545,10 @@ def test_init_should_also_create_async_client_with_same_args(self, tools): tools=tools, tools_strict=True, ) + assert component.async_client is None + + await component.warm_up_async() + assert component.async_client.api_key == "test-api-key" assert component._azure_deployment == "gpt-5-mini" assert component.streaming_callback is print_streaming_chunk diff --git a/test/components/generators/chat/test_fallback.py b/test/components/generators/chat/test_fallback.py index 7b14cbd66d..ceaa023843 100644 --- a/test/components/generators/chat/test_fallback.py +++ b/test/components/generators/chat/test_fallback.py @@ -5,6 +5,7 @@ import asyncio import time from typing import Any +from unittest.mock import AsyncMock, Mock from urllib.error import HTTPError as URLLibHTTPError import pytest @@ -375,69 +376,55 @@ async def test_failover_trigger_401_authentication_async(): assert result["meta"]["failed_chat_generators"] == ["_DummyHTTPErrorGen"] -@component -class _DummyGenWithWarmUp: - """Dummy generator that tracks warm_up calls.""" - - def __init__(self, text: str = "ok"): - self.text = text - self.warm_up_called = False - - def warm_up(self) -> None: - self.warm_up_called = True - - def run( - self, - messages: list[ChatMessage], - generation_kwargs: dict[str, Any] | None = None, - tools: ToolsType | None = None, - streaming_callback: StreamingCallbackT | None = None, - ) -> dict[str, Any]: - return {"replies": [ChatMessage.from_assistant(self.text)], "meta": {}} - - -def test_warm_up_delegates_to_generators(): - """Test that warm_up() is called on each underlying generator.""" - gen1 = _DummyGenWithWarmUp(text="A") - gen2 = _DummyGenWithWarmUp(text="B") - gen3 = _DummyGenWithWarmUp(text="C") - - fallback = FallbackChatGenerator(chat_generators=[gen1, gen2, gen3]) - fallback.warm_up() - - assert gen1.warm_up_called - assert gen2.warm_up_called - assert gen3.warm_up_called - - -def test_warm_up_with_no_warm_up_method(): - """Test that warm_up() handles generators without warm_up() gracefully.""" - gen1 = _DummySuccessGen(text="A") - gen2 = _DummySuccessGen(text="B") - - fallback = FallbackChatGenerator(chat_generators=[gen1, gen2]) - # Should not raise any error - fallback.warm_up() - - # Verify generators still work - result = fallback.run([ChatMessage.from_user("test")]) - assert result["replies"][0].text == "A" - - -def test_warm_up_mixed_generators(): - """Test warm_up() with a mix of generators with and without warm_up().""" - gen1 = _DummyGenWithWarmUp(text="A") - gen2 = _DummySuccessGen(text="B") - gen3 = _DummyGenWithWarmUp(text="C") - gen4 = _DummyFailGen() - - fallback = FallbackChatGenerator(chat_generators=[gen1, gen2, gen3, gen4]) - fallback.warm_up() - - # Only generators with warm_up() should have been called - assert gen1.warm_up_called - assert gen3.warm_up_called - - # Verify the fallback still works correctly - result = fallback.run([ChatMessage.from_user("test")]) - assert result["replies"][0].text == "A" +class TestComponentLifecycle: + def test_warm_up_delegates_to_every_generator(self): + gens = [Mock(spec=["run", "warm_up"]) for _ in range(3)] + fallback = FallbackChatGenerator(chat_generators=gens) + fallback.warm_up() + for gen in gens: + gen.warm_up.assert_called_once() + + async def test_warm_up_async_delegates_to_every_generator(self): + gens = [Mock(spec=["run", "warm_up_async"]) for _ in range(3)] + for gen in gens: + gen.warm_up_async = AsyncMock() + fallback = FallbackChatGenerator(chat_generators=gens) + await fallback.warm_up_async() + for gen in gens: + gen.warm_up_async.assert_awaited_once() + + async def test_warm_up_async_falls_back_to_sync_warm_up(self): + gens = [Mock(spec=["run", "warm_up"]) for _ in range(3)] + fallback = FallbackChatGenerator(chat_generators=gens) + await fallback.warm_up_async() + for gen in gens: + gen.warm_up.assert_called_once() + + def test_close_delegates_to_every_generator(self): + gens = [Mock(spec=["run", "close"]) for _ in range(3)] + fallback = FallbackChatGenerator(chat_generators=gens) + fallback.close() + for gen in gens: + gen.close.assert_called_once() + + async def test_close_async_delegates_to_every_generator(self): + gens = [Mock(spec=["run", "close_async"]) for _ in range(3)] + for gen in gens: + gen.close_async = AsyncMock() + fallback = FallbackChatGenerator(chat_generators=gens) + await fallback.close_async() + for gen in gens: + gen.close_async.assert_awaited_once() + + async def test_close_async_falls_back_to_sync_close(self): + gens = [Mock(spec=["run", "close"]) for _ in range(3)] + fallback = FallbackChatGenerator(chat_generators=gens) + await fallback.close_async() + for gen in gens: + gen.close.assert_called_once() + + def test_lifecycle_is_safe_when_generators_lack_methods(self): + gens = [Mock(spec=["run"]) for _ in range(3)] + fallback = FallbackChatGenerator(chat_generators=gens) + fallback.warm_up() + fallback.close() diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index af9ac47cbd..af36840b72 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -7,7 +7,7 @@ import os from datetime import datetime from typing import Any -from unittest.mock import ANY, MagicMock, patch +from unittest.mock import ANY, AsyncMock, MagicMock, patch import pytest from openai import OpenAIError @@ -29,6 +29,7 @@ from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails from pydantic import BaseModel +import haystack.components.generators.chat.openai as openai_chat_module from haystack import component from haystack.components.generators.chat.openai import ( OpenAIChatGenerator, @@ -197,20 +198,17 @@ def test_supported_models(self): def test_init_default(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = OpenAIChatGenerator() - assert component.client.api_key == "test-api-key" + assert component.api_key.resolve_value() == "test-api-key" assert component.model == "gpt-5-mini" assert component.streaming_callback is None assert not component.generation_kwargs - assert component.client.timeout == 30 - assert component.client.max_retries == 5 + assert component.timeout is None + assert component.max_retries is None assert component.tools is None assert not component.tools_strict assert component.http_client_kwargs is None - - def test_init_fail_wo_api_key(self, monkeypatch): - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - with pytest.raises(ValueError): - OpenAIChatGenerator() + assert component.client is None + assert component.async_client is None def test_init_fail_with_duplicate_tool_names(self, monkeypatch, tools): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") @@ -235,15 +233,17 @@ def test_init_with_parameters(self, monkeypatch): tools_strict=True, http_client_kwargs={"proxy": "http://example.com:8080", "verify": False}, ) - assert component.client.api_key == "test-api-key" + assert component.api_key.resolve_value() == "test-api-key" assert component.model == "gpt-5-mini" assert component.streaming_callback is print_streaming_chunk assert component.generation_kwargs == {"max_completion_tokens": 10, "some_test_param": "test-params"} - assert component.client.timeout == 40.0 - assert component.client.max_retries == 1 + assert component.timeout == 40.0 + assert component.max_retries == 1 assert component.tools == [tool] assert component.tools_strict assert component.http_client_kwargs == {"proxy": "http://example.com:8080", "verify": False} + assert component.client is None + assert component.async_client is None def test_init_with_parameters_and_env_vars(self, monkeypatch): monkeypatch.setenv("OPENAI_TIMEOUT", "100") @@ -254,12 +254,14 @@ def test_init_with_parameters_and_env_vars(self, monkeypatch): api_base_url="test-base-url", generation_kwargs={"max_completion_tokens": 10, "some_test_param": "test-params"}, ) - assert component.client.api_key == "test-api-key" + assert component.api_key.resolve_value() == "test-api-key" assert component.model == "gpt-5-mini" assert component.streaming_callback is print_streaming_chunk assert component.generation_kwargs == {"max_completion_tokens": 10, "some_test_param": "test-params"} - assert component.client.timeout == 100.0 - assert component.client.max_retries == 10 + assert component.timeout is None + assert component.max_retries is None + assert component.client is None + assert component.async_client is None def test_to_dict_default(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") @@ -420,11 +422,11 @@ def test_from_dict(self, monkeypatch): Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) ] assert component.tools_strict - assert component.client.timeout == 100.0 - assert component.client.max_retries == 10 + assert component.timeout == 100.0 + assert component.max_retries == 10 assert component.http_client_kwargs == {"proxy": "http://example.com:8080", "verify": False} - def test_from_dict_fail_wo_env_var(self, monkeypatch): + def test_from_dict_wo_env_var_does_not_fail(self, monkeypatch): monkeypatch.delenv("OPENAI_API_KEY", raising=False) data = { "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", @@ -438,8 +440,9 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): "tools": None, }, } - with pytest.raises(ValueError): - OpenAIChatGenerator.from_dict(data) + component = OpenAIChatGenerator.from_dict(data) + assert component.client is None + assert component.async_client is None def test_run(self, chat_messages, openai_mock_chat_completion): component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) @@ -574,6 +577,7 @@ def streaming_callback(chunk: StreamingChunk) -> None: wrapped_openai_stream.__iter__.return_value = iter([chunk]) component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) + component.warm_up() with patch.object( component.client.chat.completions, "create", return_value=wrapped_openai_stream @@ -1218,112 +1222,6 @@ def test_serde_with_list_of_toolsets(self, monkeypatch, tools): assert len(deserialized.tools) == 2 assert all(isinstance(ts, Toolset) for ts in deserialized.tools) - def test_warm_up_with_tools(self, monkeypatch): - """Test that warm_up() calls warm_up on tools and is idempotent.""" - monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") - - # Create a mock tool that tracks if warm_up() was called - class MockTool(Tool): - warm_up_call_count = 0 # Class variable to track calls - - def __init__(self): - super().__init__( - name="mock_tool", - description="A mock tool for testing", - parameters={"x": {"type": "string"}}, - function=lambda x: x, - ) - - def warm_up(self): - MockTool.warm_up_call_count += 1 - - # Reset the class variable before test - MockTool.warm_up_call_count = 0 - mock_tool = MockTool() - - # Create OpenAIChatGenerator with the mock tool - component = OpenAIChatGenerator(tools=[mock_tool]) - - # Verify initial state - warm_up not called yet - assert MockTool.warm_up_call_count == 0 - assert not component._is_warmed_up - - # Call warm_up() on the generator - component.warm_up() - - # Assert that the tool's warm_up() was called - assert MockTool.warm_up_call_count == 1 - assert component._is_warmed_up - - # Call warm_up() again and verify it's idempotent (only warms up once) - component.warm_up() - - # The tool's warm_up should still only have been called once - assert MockTool.warm_up_call_count == 1 - assert component._is_warmed_up - - def test_warm_up_with_no_tools(self, monkeypatch): - """Test that warm_up() works when no tools are provided.""" - monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") - - component = OpenAIChatGenerator() - - # Verify initial state - assert not component._is_warmed_up - assert component.tools is None - - # Call warm_up() - should not raise an error - component.warm_up() - - # Verify the component is warmed up - assert component._is_warmed_up - - # Call warm_up() again - should be idempotent - component.warm_up() - assert component._is_warmed_up - - def test_warm_up_with_multiple_tools(self, monkeypatch): - """Test that warm_up() works with multiple tools.""" - monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") - - from haystack.tools import Tool - - # Track warm_up calls - warm_up_calls = [] - - class MockTool(Tool): - def __init__(self, tool_name): - super().__init__( - name=tool_name, - description=f"Mock tool {tool_name}", - parameters={"type": "object", "properties": {"x": {"type": "string"}}, "required": ["x"]}, - function=lambda x: f"{tool_name} result: {x}", - ) - - def warm_up(self): - warm_up_calls.append(self.name) - - mock_tool1 = MockTool("tool1") - mock_tool2 = MockTool("tool2") - - # Use a LIST of tools, not a Toolset - component = OpenAIChatGenerator(tools=[mock_tool1, mock_tool2]) - - # Call warm_up() - component.warm_up() - - # Assert that both tools' warm_up() were called - assert "tool1" in warm_up_calls - assert "tool2" in warm_up_calls - assert component._is_warmed_up - - # Track count - call_count = len(warm_up_calls) - - # Verify idempotency - component.warm_up() - assert len(warm_up_calls) == call_count - @pytest.fixture def chat_completion_chunks(): @@ -1761,6 +1659,123 @@ def streaming_chunks(): ] +@pytest.fixture +def mock_openai_clients(monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake") + sync_cls = MagicMock(name="OpenAI") + async_cls = MagicMock(name="AsyncOpenAI") + async_cls.return_value.close = AsyncMock() + monkeypatch.setattr(openai_chat_module, "OpenAI", sync_cls) + monkeypatch.setattr(openai_chat_module, "AsyncOpenAI", async_cls) + return sync_cls, async_cls + + +class TestComponentLifecycle: + def test_warm_up_uses_default_timeout_and_max_retries(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") + generator = OpenAIChatGenerator() + generator.warm_up() + assert generator.client.max_retries == 5 + assert generator.client.timeout == 30.0 + + def test_warm_up_uses_timeout_and_max_retries_from_parameters(self): + generator = OpenAIChatGenerator(api_key=Secret.from_token("fake-api-key"), timeout=40.0, max_retries=1) + generator.warm_up() + assert generator.client.max_retries == 1 + assert generator.client.timeout == 40.0 + + def test_warm_up_uses_timeout_and_max_retries_from_env_vars(self, monkeypatch): + monkeypatch.setenv("OPENAI_TIMEOUT", "100") + monkeypatch.setenv("OPENAI_MAX_RETRIES", "10") + generator = OpenAIChatGenerator(api_key=Secret.from_token("fake-api-key")) + generator.warm_up() + assert generator.client.max_retries == 10 + assert generator.client.timeout == 100.0 + + def test_key_resolved_at_warm_up_not_init(self, monkeypatch): + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + generator = OpenAIChatGenerator() + with pytest.raises(ValueError, match="None of the .* environment variables are set"): + generator.warm_up() + + def test_warm_up_warms_tools_once(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") + warm_up_calls = [] + + class MockTool(Tool): + def __init__(self, tool_name): + super().__init__( + name=tool_name, + description=f"Mock tool {tool_name}", + parameters={"type": "object", "properties": {"x": {"type": "string"}}, "required": ["x"]}, + function=lambda x: x, + ) + + def warm_up(self): + warm_up_calls.append(self.name) + + generator = OpenAIChatGenerator(tools=[MockTool("tool1"), MockTool("tool2")]) + assert not generator._tools_warmed_up + + generator.warm_up() + assert sorted(warm_up_calls) == ["tool1", "tool2"] + assert generator._tools_warmed_up + + generator.warm_up() + assert sorted(warm_up_calls) == ["tool1", "tool2"] + + def test_warm_up_with_no_tools_does_not_raise(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") + generator = OpenAIChatGenerator() + generator.warm_up() + assert generator._tools_warmed_up + + def test_sync_lifecycle(self, mock_openai_clients): + sync_cls, _ = mock_openai_clients + generator = OpenAIChatGenerator() + assert generator.client is None + assert generator.async_client is None + + generator.warm_up() + assert generator.client is sync_cls.return_value + assert generator.async_client is None + + generator.close() + sync_cls.return_value.close.assert_called_once() + assert generator.client is None + + async def test_async_lifecycle(self, mock_openai_clients): + _, async_cls = mock_openai_clients + generator = OpenAIChatGenerator() + + await generator.warm_up_async() + assert generator.async_client is async_cls.return_value + assert generator.client is None + + await generator.close_async() + async_cls.return_value.close.assert_awaited_once() + assert generator.async_client is None + + async def test_close_is_safe_without_warm_up(self, mock_openai_clients): + generator = OpenAIChatGenerator() + generator.close() + await generator.close_async() + assert generator.client is None + assert generator.async_client is None + + async def test_close_and_close_async_are_independent(self, mock_openai_clients): + generator = OpenAIChatGenerator() + generator.warm_up() + await generator.warm_up_async() + + generator.close() + assert generator.client is None + assert generator.async_client is not None + + await generator.close_async() + assert generator.async_client is None + + class TestChatCompletionChunkConversion: def test_convert_chat_completion_chunk_to_streaming_chunk(self, chat_completion_chunks, streaming_chunks): previous_chunks = [] diff --git a/test/components/generators/chat/test_openai_async.py b/test/components/generators/chat/test_openai_async.py index 7b7e39f959..ad0d0bd9c2 100644 --- a/test/components/generators/chat/test_openai_async.py +++ b/test/components/generators/chat/test_openai_async.py @@ -88,7 +88,7 @@ def tools(): class TestOpenAIChatGeneratorAsync: - def test_init_should_also_create_async_client_with_same_args(self, monkeypatch): + async def test_warm_up_async_should_create_async_client_with_same_args(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = OpenAIChatGenerator( api_key=Secret.from_token("test-api-key"), @@ -97,6 +97,7 @@ def test_init_should_also_create_async_client_with_same_args(self, monkeypatch): timeout=30, max_retries=5, ) + await component.warm_up_async() assert isinstance(component.async_client, AsyncOpenAI) assert component.async_client.api_key == "test-api-key" @@ -477,6 +478,7 @@ async def streaming_callback(chunk: StreamingChunk) -> None: wrapped_openai_async_stream.__aiter__.return_value = iter([chunk]) component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) + await component.warm_up_async() # Patch the async client's create method with patch.object( diff --git a/test/components/generators/chat/test_openai_responses.py b/test/components/generators/chat/test_openai_responses.py index 08800a1fbe..724c3b4131 100644 --- a/test/components/generators/chat/test_openai_responses.py +++ b/test/components/generators/chat/test_openai_responses.py @@ -5,12 +5,13 @@ import json import os from typing import Any -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest from openai import AsyncOpenAI, OpenAIError from pydantic import BaseModel +import haystack.components.generators.chat.openai_responses as openai_responses_module from haystack import component from haystack.components.agents import Agent from haystack.components.generators.chat.openai_responses import OpenAIResponsesChatGenerator @@ -112,21 +113,18 @@ def test_supported_models(self): def test_init_default(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = OpenAIResponsesChatGenerator() - assert component.client.api_key == "test-api-key" + assert component.client is None + assert component.async_client is None + assert component.api_key == Secret.from_env_var("OPENAI_API_KEY") assert component.model == "gpt-5-mini" assert component.streaming_callback is None assert not component.generation_kwargs - assert component.client.timeout == 30 - assert component.client.max_retries == 5 + assert component.timeout is None + assert component.max_retries is None assert component.tools is None assert not component.tools_strict assert component.http_client_kwargs is None - def test_init_fail_wo_api_key(self, monkeypatch): - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - with pytest.raises(ValueError): - OpenAIResponsesChatGenerator() - def test_init_with_parameters(self, monkeypatch): tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=lambda x: x) @@ -144,12 +142,14 @@ def test_init_with_parameters(self, monkeypatch): tools_strict=True, http_client_kwargs={"proxy": "http://example.com:8080", "verify": False}, ) - assert component.client.api_key == "test-api-key" + assert component.client is None + assert component.async_client is None + assert component.api_key == Secret.from_token("test-api-key") assert component.model == "gpt-4o-mini" assert component.streaming_callback is print_streaming_chunk assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} - assert component.client.timeout == 40.0 - assert component.client.max_retries == 1 + assert component.timeout == 40.0 + assert component.max_retries == 1 assert component.tools == [tool] assert component.tools_strict assert component.http_client_kwargs == {"proxy": "http://example.com:8080", "verify": False} @@ -164,12 +164,14 @@ def test_init_with_parameters_and_env_vars(self, monkeypatch): api_base_url="test-base-url", generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, ) - assert component.client.api_key == "test-api-key" + assert component.client is None + assert component.async_client is None + assert component.api_key == Secret.from_token("test-api-key") assert component.model == "gpt-4o-mini" assert component.streaming_callback is print_streaming_chunk assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} - assert component.client.timeout == 100.0 - assert component.client.max_retries == 10 + assert component.timeout is None + assert component.max_retries is None def test_init_with_toolset(self, tools, monkeypatch): """Test that the OpenAIChatGenerator can be initialized with a Toolset.""" @@ -310,11 +312,11 @@ def test_from_dict(self, monkeypatch): Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) ] assert component.tools_strict - assert component.client.timeout == 100.0 - assert component.client.max_retries == 10 + assert component.timeout == 100.0 + assert component.max_retries == 10 assert component.http_client_kwargs == {"proxy": "http://example.com:8080", "verify": False} - def test_from_dict_fail_wo_env_var(self, monkeypatch): + def test_from_dict_wo_env_var_does_not_fail(self, monkeypatch): monkeypatch.delenv("OPENAI_API_KEY", raising=False) data = { "type": "haystack.components.generators.chat.openai_responses.OpenAIResponsesChatGenerator", @@ -328,8 +330,10 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): "tools": None, }, } - with pytest.raises(ValueError): - OpenAIResponsesChatGenerator.from_dict(data) + component = OpenAIResponsesChatGenerator.from_dict(data) + assert component.client is None + assert component.async_client is None + assert component.api_key == Secret.from_env_var("OPENAI_API_KEY") def test_from_dict_with_toolset(self, tools, monkeypatch): """Test that the OpenAIChatGenerator can be deserialized from a dictionary with a Toolset.""" @@ -345,68 +349,80 @@ def test_from_dict_with_toolset(self, tools, monkeypatch): assert all(isinstance(tool, Tool) for tool in deserialized_component.tools) -class TestWarmUp: - def test_warm_up_with_tools(self, monkeypatch): - """Test that warm_up() calls warm_up on tools and is idempotent.""" - monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") +@pytest.fixture +def mock_openai_clients(monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake") + sync_cls = MagicMock(name="OpenAI") + async_cls = MagicMock(name="AsyncOpenAI") + async_cls.return_value.close = AsyncMock() + monkeypatch.setattr(openai_responses_module, "OpenAI", sync_cls) + monkeypatch.setattr(openai_responses_module, "AsyncOpenAI", async_cls) + return sync_cls, async_cls + + +class TestComponentLifecycle: + def test_warm_up_uses_default_timeout_and_max_retries(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") + generator = OpenAIResponsesChatGenerator() + generator.warm_up() + assert generator.client.max_retries == 5 + assert generator.client.timeout == 30.0 + + def test_warm_up_uses_timeout_and_max_retries_from_parameters(self): + generator = OpenAIResponsesChatGenerator(api_key=Secret.from_token("fake-api-key"), timeout=40.0, max_retries=1) + generator.warm_up() + assert generator.client.max_retries == 1 + assert generator.client.timeout == 40.0 + + def test_warm_up_uses_timeout_and_max_retries_from_env_vars(self, monkeypatch): + monkeypatch.setenv("OPENAI_TIMEOUT", "100") + monkeypatch.setenv("OPENAI_MAX_RETRIES", "10") + generator = OpenAIResponsesChatGenerator(api_key=Secret.from_token("fake-api-key")) + generator.warm_up() + assert generator.client.max_retries == 10 + assert generator.client.timeout == 100.0 - # Create a mock tool that tracks if warm_up() was called - class MockTool(Tool): - warm_up_call_count = 0 # Class variable to track calls + def test_key_resolved_at_warm_up_not_init(self, monkeypatch): + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + generator = OpenAIResponsesChatGenerator() + with pytest.raises(ValueError, match="None of the .* environment variables are set"): + generator.warm_up() + + def test_warm_up_warms_tools_once(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") + warm_up_calls = [] - def __init__(self): + class MockTool(Tool): + def __init__(self, tool_name): super().__init__( - name="mock_tool", - description="A mock tool for testing", - parameters={"x": {"type": "string"}}, + name=tool_name, + description=f"Mock tool {tool_name}", + parameters={"type": "object", "properties": {"x": {"type": "string"}}, "required": ["x"]}, function=lambda x: x, ) def warm_up(self): - MockTool.warm_up_call_count += 1 - - # Reset the class variable before test - MockTool.warm_up_call_count = 0 - mock_tool = MockTool() - - # Create OpenAIChatGenerator with the mock tool - component = OpenAIResponsesChatGenerator(tools=[mock_tool]) - - # Verify initial state - warm_up not called yet - assert MockTool.warm_up_call_count == 0 - assert not component._is_warmed_up - - # Call warm_up() on the generator - component.warm_up() - - # Assert that the tool's warm_up() was called - assert MockTool.warm_up_call_count == 1 - assert component._is_warmed_up - - component.warm_up() + warm_up_calls.append(self.name) - # The tool's warm_up should still only have been called once - assert MockTool.warm_up_call_count == 1 - assert component._is_warmed_up + generator = OpenAIResponsesChatGenerator(tools=[MockTool("tool1"), MockTool("tool2")]) + assert not generator._tools_warmed_up - def test_warm_up_with_no_tools(self, monkeypatch): - """Test that warm_up() works when no tools are provided.""" - monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") + generator.warm_up() + assert sorted(warm_up_calls) == ["tool1", "tool2"] + assert generator._tools_warmed_up - component = OpenAIResponsesChatGenerator() - - # Verify initial state - assert not component._is_warmed_up - assert component.tools is None + generator.warm_up() + assert sorted(warm_up_calls) == ["tool1", "tool2"] - # Verify the component is warmed up - component.warm_up() - assert component._is_warmed_up - - def test_warm_up_with_openai_tools(self, monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") + def test_warm_up_with_no_tools_does_not_raise(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") + generator = OpenAIResponsesChatGenerator() + generator.warm_up() + assert generator._tools_warmed_up - component = OpenAIResponsesChatGenerator( + def test_warm_up_with_openai_tools_does_not_raise(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") + generator = OpenAIResponsesChatGenerator( tools=[ {"type": "web_search_preview"}, { @@ -418,49 +434,53 @@ def test_warm_up_with_openai_tools(self, monkeypatch): }, ] ) + generator.warm_up() + assert generator._tools_warmed_up - # Make sure the component can still be warmed up even when using openai tools - assert not component._is_warmed_up - component.warm_up() - assert component._is_warmed_up + def test_sync_lifecycle(self, mock_openai_clients): + sync_cls, _ = mock_openai_clients + generator = OpenAIResponsesChatGenerator() + assert generator.client is None + assert generator.async_client is None - def test_warm_up_with_multiple_tools(self, monkeypatch): - """Test that warm_up() works with multiple tools.""" - monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") + generator.warm_up() + assert generator.client is sync_cls.return_value + assert generator.async_client is None - from haystack.tools import Tool + generator.close() + sync_cls.return_value.close.assert_called_once() + assert generator.client is None - # Track warm_up calls - warm_up_calls = [] + async def test_async_lifecycle(self, mock_openai_clients): + _, async_cls = mock_openai_clients + generator = OpenAIResponsesChatGenerator() - class MockTool(Tool): - def __init__(self, tool_name): - super().__init__( - name=tool_name, - description=f"Mock tool {tool_name}", - parameters={"type": "object", "properties": {"x": {"type": "string"}}, "required": ["x"]}, - function=lambda x: f"{tool_name} result: {x}", - ) + await generator.warm_up_async() + assert generator.async_client is async_cls.return_value + assert generator.client is None - def warm_up(self): - warm_up_calls.append(self.name) + await generator.close_async() + async_cls.return_value.close.assert_awaited_once() + assert generator.async_client is None - mock_tool1 = MockTool("tool1") - mock_tool2 = MockTool("tool2") + async def test_close_is_safe_without_warm_up(self, mock_openai_clients): + generator = OpenAIResponsesChatGenerator() + generator.close() + await generator.close_async() + assert generator.client is None + assert generator.async_client is None - # Use a LIST of tools, not a Toolset - component = OpenAIResponsesChatGenerator(tools=[mock_tool1, mock_tool2]) + async def test_close_and_close_async_are_independent(self, mock_openai_clients): + generator = OpenAIResponsesChatGenerator() + generator.warm_up() + await generator.warm_up_async() - # Assert that both tools' warm_up() were called - component.warm_up() - assert "tool1" in warm_up_calls - assert "tool2" in warm_up_calls - assert component._is_warmed_up + generator.close() + assert generator.client is None + assert generator.async_client is not None - # Verify idempotency - call_count = len(warm_up_calls) - component.warm_up() - assert len(warm_up_calls) == call_count + await generator.close_async() + assert generator.async_client is None class TestRun: @@ -948,7 +968,7 @@ def retrieve_image(): class TestOpenAIResponsesChatGeneratorAsync: - def test_init_should_also_create_async_client_with_same_args(self, monkeypatch): + async def test_warm_up_async_creates_async_client_with_expected_args(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = OpenAIResponsesChatGenerator( api_key=Secret.from_token("test-api-key"), @@ -957,6 +977,9 @@ def test_init_should_also_create_async_client_with_same_args(self, monkeypatch): timeout=30, max_retries=5, ) + assert component.async_client is None + + await component.warm_up_async() assert isinstance(component.async_client, AsyncOpenAI) assert component.async_client.api_key == "test-api-key" diff --git a/test/components/generators/test_openai_image_generator.py b/test/components/generators/test_openai_image_generator.py index ce5ea6e09b..65abe223da 100644 --- a/test/components/generators/test_openai_image_generator.py +++ b/test/components/generators/test_openai_image_generator.py @@ -4,13 +4,14 @@ import base64 import os -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest from openai import AsyncOpenAI from openai.types import ImagesResponse from openai.types.image import Image +import haystack.components.generators.openai_image_generator as openai_image_generator_module from haystack.components.generators.openai_image_generator import OpenAIImageGenerator from haystack.utils import Secret @@ -34,9 +35,11 @@ def test_init_default(self, monkeypatch): assert component.api_key == Secret.from_env_var("OPENAI_API_KEY") assert component.api_base_url is None assert component.organization is None - assert pytest.approx(component.timeout) == 30.0 - assert component.max_retries == 5 + assert component.timeout is None + assert component.max_retries is None assert component.http_client_kwargs is None + assert component.client is None + assert component.async_client is None def test_init_with_params(self, monkeypatch): component = OpenAIImageGenerator( @@ -57,11 +60,10 @@ def test_init_with_params(self, monkeypatch): assert component.organization == "test-org" assert pytest.approx(component.timeout) == 60.0 assert component.max_retries == 10 + assert component.client is None + assert component.async_client is None def test_init_max_retries_0(self, monkeypatch): - """ - Test that the max_retries parameter is taken into account even if it is 0. - """ component = OpenAIImageGenerator(max_retries=0) assert component.max_retries == 0 @@ -74,14 +76,6 @@ def test_init_non_default_response_format_warns(self, caplog): OpenAIImageGenerator(response_format="url") # type: ignore[arg-type] assert "response_format is ignored" in caplog.text - def test_warm_up(self, monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") - component = OpenAIImageGenerator() - component.warm_up() - assert component.client.api_key == "test-api-key" - assert component.client.timeout == 30 - assert component.client.max_retries == 5 - def test_to_dict(self): generator = OpenAIImageGenerator() data = generator.to_dict() @@ -156,13 +150,14 @@ def test_from_dict_default_params(self): assert generator.api_key.to_dict() == {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True} assert generator.api_base_url is None assert generator.organization is None - assert pytest.approx(generator.timeout) == 30.0 - assert generator.max_retries == 5 + assert generator.timeout is None + assert generator.max_retries is None assert generator.http_client_kwargs is None def test_run(self, mock_image_response): generator = OpenAIImageGenerator(api_key=Secret.from_token("test-api-key")) response = generator.run("Show me a picture of a black cat.") + assert generator.client is not None assert isinstance(response, dict) assert "images" in response and "revised_prompt" in response assert response["images"] == ["test-b64-json"] @@ -193,17 +188,17 @@ def test_async_client_none_before_warm_up(self, monkeypatch): component = OpenAIImageGenerator() assert component.async_client is None - def test_async_client_after_warm_up(self, monkeypatch): + @pytest.mark.asyncio + async def test_async_client_after_warm_up_async(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") component = OpenAIImageGenerator() - component.warm_up() + await component.warm_up_async() assert isinstance(component.async_client, AsyncOpenAI) assert component.async_client.api_key == "test-api-key" @pytest.mark.asyncio async def test_run_async(self): generator = OpenAIImageGenerator(api_key=Secret.from_token("test-api-key")) - generator.warm_up() image_response = ImagesResponse( created=1630000000, data=[Image(b64_json="test-b64-json", revised_prompt="test-prompt")] @@ -254,3 +249,88 @@ async def test_live_run_async(self): decoded = base64.b64decode(image_str, validate=True) assert decoded.startswith(b"\x89PNG\r\n\x1a\n") + + +@pytest.fixture +def mock_openai_clients(monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake") + sync_cls = MagicMock(name="OpenAI") + async_cls = MagicMock(name="AsyncOpenAI") + async_cls.return_value.close = AsyncMock() + monkeypatch.setattr(openai_image_generator_module, "OpenAI", sync_cls) + monkeypatch.setattr(openai_image_generator_module, "AsyncOpenAI", async_cls) + return sync_cls, async_cls + + +class TestComponentLifecycle: + def test_warm_up_uses_default_timeout_and_max_retries(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") + generator = OpenAIImageGenerator() + generator.warm_up() + assert generator.client.max_retries == 5 + assert generator.client.timeout == 30.0 + + def test_warm_up_uses_timeout_and_max_retries_from_parameters(self): + generator = OpenAIImageGenerator(api_key=Secret.from_token("fake-api-key"), timeout=40.0, max_retries=1) + generator.warm_up() + assert generator.client.max_retries == 1 + assert generator.client.timeout == 40.0 + + def test_warm_up_uses_timeout_and_max_retries_from_env_vars(self, monkeypatch): + monkeypatch.setenv("OPENAI_TIMEOUT", "100") + monkeypatch.setenv("OPENAI_MAX_RETRIES", "10") + generator = OpenAIImageGenerator(api_key=Secret.from_token("fake-api-key")) + generator.warm_up() + assert generator.client.max_retries == 10 + assert generator.client.timeout == 100.0 + + def test_key_resolved_at_warm_up_not_init(self, monkeypatch): + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + generator = OpenAIImageGenerator() + with pytest.raises(ValueError, match="None of the .* environment variables are set"): + generator.warm_up() + + def test_sync_lifecycle(self, mock_openai_clients): + sync_cls, _ = mock_openai_clients + generator = OpenAIImageGenerator() + assert generator.client is None + assert generator.async_client is None + + generator.warm_up() + assert generator.client is sync_cls.return_value + assert generator.async_client is None + + generator.close() + sync_cls.return_value.close.assert_called_once() + assert generator.client is None + + async def test_async_lifecycle(self, mock_openai_clients): + _, async_cls = mock_openai_clients + generator = OpenAIImageGenerator() + + await generator.warm_up_async() + assert generator.async_client is async_cls.return_value + assert generator.client is None + + await generator.close_async() + async_cls.return_value.close.assert_awaited_once() + assert generator.async_client is None + + async def test_close_is_safe_without_warm_up(self, mock_openai_clients): + generator = OpenAIImageGenerator() + generator.close() + await generator.close_async() + assert generator.client is None + assert generator.async_client is None + + async def test_close_and_close_async_are_independent(self, mock_openai_clients): + generator = OpenAIImageGenerator() + generator.warm_up() + await generator.warm_up_async() + + generator.close() + assert generator.client is None + assert generator.async_client is not None + + await generator.close_async() + assert generator.async_client is None diff --git a/test/components/preprocessors/test_embedding_based_document_splitter.py b/test/components/preprocessors/test_embedding_based_document_splitter.py index 846618e163..96ea195f4f 100644 --- a/test/components/preprocessors/test_embedding_based_document_splitter.py +++ b/test/components/preprocessors/test_embedding_based_document_splitter.py @@ -53,45 +53,10 @@ def test_init_invalid_max_length(self): with pytest.raises(ValueError, match="max_length must be greater than min_length"): EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder, min_length=100, max_length=50) - def test_warm_up(self): - mock_embedder = Mock() - splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) - - with patch( - "haystack.components.preprocessors.embedding_based_document_splitter.SentenceSplitter" - ) as mock_splitter_class: - mock_splitter = Mock() - mock_splitter_class.return_value = mock_splitter - - splitter.warm_up() - - assert splitter.sentence_splitter == mock_splitter - mock_splitter_class.assert_called_once() - - def test_run_not_warmed_up(self): - mock_embedder = Mock() - splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) - - with patch.object(splitter, "warm_up", wraps=splitter.warm_up) as mock_warm_up: - splitter.run(documents=[]) - assert splitter._is_warmed_up - mock_warm_up.assert_called_once() - - @pytest.mark.asyncio - async def test_run_not_warmed_up_async(self) -> None: - mock_embedder = AsyncMock() - splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) - - with patch.object(splitter, "warm_up", wraps=splitter.warm_up) as mock_warm_up: - await splitter.run_async(documents=[]) - assert splitter._is_warmed_up - mock_warm_up.assert_called_once() - def test_run_invalid_input(self): mock_embedder = Mock() splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) splitter.sentence_splitter = Mock() - splitter._is_warmed_up = True with pytest.raises(TypeError, match="expects a List of Documents"): splitter.run(documents="not a list") @@ -101,7 +66,6 @@ async def test_run_invalid_input_async(self) -> None: mock_embedder = AsyncMock() splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) splitter.sentence_splitter = AsyncMock() - splitter._is_warmed_up = True with pytest.raises(TypeError, match="expects a List of Documents"): await splitter.run_async(documents="not a list") @@ -110,7 +74,6 @@ def test_run_document_with_none_content(self): mock_embedder = Mock() splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) splitter.sentence_splitter = Mock() - splitter._is_warmed_up = True with pytest.raises(ValueError, match="content for document ID"): splitter.run(documents=[Document(content=None)]) @@ -120,7 +83,6 @@ async def test_run_document_with_none_content_async(self) -> None: mock_embedder = AsyncMock() splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) splitter.sentence_splitter = AsyncMock() - splitter._is_warmed_up = True with pytest.raises(ValueError, match="content for document ID"): await splitter.run_async(documents=[Document(content=None)]) @@ -129,7 +91,6 @@ def test_run_empty_document(self): mock_embedder = Mock() splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) splitter.sentence_splitter = Mock() - splitter._is_warmed_up = True result = splitter.run(documents=[Document(content="")]) assert result["documents"] == [] @@ -139,7 +100,6 @@ async def test_run_empty_document_async(self) -> None: mock_embedder = AsyncMock() splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) splitter.sentence_splitter = AsyncMock() - splitter._is_warmed_up = True result = await splitter.run_async(documents=[Document(content="")]) assert result["documents"] == [] @@ -777,3 +737,97 @@ async def test_split_large_splits_actually_splits_async(self) -> None: assert split_doc.meta["page_number"] == 3 if i in [9, 10]: assert split_doc.meta["page_number"] == 4 + + +class TestComponentLifecycle: + def test_warm_up_builds_splitter_and_delegates_to_embedder(self): + mock_embedder = Mock() + splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) + + with patch( + "haystack.components.preprocessors.embedding_based_document_splitter.SentenceSplitter" + ) as mock_splitter_class: + splitter.warm_up() + + assert splitter.sentence_splitter is mock_splitter_class.return_value + mock_splitter_class.assert_called_once() + mock_embedder.warm_up.assert_called_once() + + def test_warm_up_builds_splitter_once(self): + mock_embedder = Mock() + splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) + + with patch( + "haystack.components.preprocessors.embedding_based_document_splitter.SentenceSplitter" + ) as mock_splitter_class: + splitter.warm_up() + first_splitter = splitter.sentence_splitter + splitter.warm_up() + + mock_splitter_class.assert_called_once() + assert splitter.sentence_splitter is first_splitter + + @pytest.mark.asyncio + async def test_warm_up_async_delegates_to_embedder_async(self) -> None: + mock_embedder = Mock() + mock_embedder.warm_up_async = AsyncMock() + splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) + + with patch("haystack.components.preprocessors.embedding_based_document_splitter.SentenceSplitter"): + await splitter.warm_up_async() + + mock_embedder.warm_up_async.assert_awaited_once() + + @pytest.mark.asyncio + async def test_warm_up_async_falls_back_to_sync_warm_up(self) -> None: + mock_embedder = Mock(spec=["warm_up"]) + splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) + + with patch("haystack.components.preprocessors.embedding_based_document_splitter.SentenceSplitter"): + await splitter.warm_up_async() + + mock_embedder.warm_up.assert_called_once() + + def test_close_delegates_to_embedder(self): + mock_embedder = Mock() + splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) + + splitter.close() + + mock_embedder.close.assert_called_once() + + @pytest.mark.asyncio + async def test_close_async_delegates_to_embedder(self) -> None: + mock_embedder = Mock() + mock_embedder.close_async = AsyncMock() + splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) + + await splitter.close_async() + + mock_embedder.close_async.assert_awaited_once() + + @pytest.mark.asyncio + async def test_close_async_falls_back_to_sync_close(self) -> None: + mock_embedder = Mock(spec=["close"]) + splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) + + await splitter.close_async() + + mock_embedder.close.assert_called_once() + + def test_lifecycle_is_safe_when_embedder_lacks_methods(self): + mock_embedder = Mock(spec=[]) + splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) + + with patch("haystack.components.preprocessors.embedding_based_document_splitter.SentenceSplitter"): + splitter.warm_up() + splitter.close() + + @pytest.mark.asyncio + async def test_lifecycle_is_safe_when_embedder_lacks_methods_async(self) -> None: + mock_embedder = Mock(spec=[]) + splitter = EmbeddingBasedDocumentSplitter(document_embedder=mock_embedder) + + with patch("haystack.components.preprocessors.embedding_based_document_splitter.SentenceSplitter"): + await splitter.warm_up_async() + await splitter.close_async() diff --git a/test/components/preprocessors/test_recursive_splitter.py b/test/components/preprocessors/test_recursive_splitter.py index d91ac883bd..bf95695d25 100644 --- a/test/components/preprocessors/test_recursive_splitter.py +++ b/test/components/preprocessors/test_recursive_splitter.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import re +from unittest.mock import Mock import pytest from pytest import LogCaptureFixture @@ -981,3 +982,39 @@ def test_recursive_splitter_generates_unique_ids_and_correct_meta(): for idx, chunk in enumerate(chunks): assert chunk.meta["parent_id"] == source_doc.id assert chunk.meta["split_id"] == idx + + +def test_warm_up_is_idempotent_sentence(monkeypatch): + splitter = RecursiveDocumentSplitter(separators=["sentence", " "]) + + calls = [] + original = RecursiveDocumentSplitter._get_custom_sentence_tokenizer + + def spy(params): + calls.append(params) + return original(params) + + monkeypatch.setattr(RecursiveDocumentSplitter, "_get_custom_sentence_tokenizer", staticmethod(spy)) + + splitter.warm_up() + first_tokenizer = splitter.nltk_tokenizer + splitter.warm_up() + + assert len(calls) == 1 + assert splitter.nltk_tokenizer is first_tokenizer + + +def test_warm_up_is_idempotent_token(monkeypatch): + import haystack.components.preprocessors.recursive_splitter as mod + + sentinel = object() + get_encoding = Mock(return_value=sentinel) + monkeypatch.setattr(mod.tiktoken, "get_encoding", get_encoding) + + splitter = RecursiveDocumentSplitter(split_unit="token", split_length=10) + + splitter.warm_up() + splitter.warm_up() + + assert get_encoding.call_count == 1 + assert splitter.tiktoken_tokenizer is sentinel diff --git a/test/components/query/test_query_expander.py b/test/components/query/test_query_expander.py index 7a7d8363ae..3bd7ded786 100644 --- a/test/components/query/test_query_expander.py +++ b/test/components/query/test_query_expander.py @@ -18,13 +18,6 @@ def mock_chat_generator(): return Mock(spec=OpenAIChatGenerator) -@pytest.fixture -def mock_chat_generator_with_warm_up(): - mock_generator = Mock(spec=OpenAIChatGenerator) - mock_generator.warm_up = lambda: None - return mock_generator - - class TestQueryExpander: def test_init_default_generator(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-key-12345") @@ -42,20 +35,13 @@ def test_init_custom_generator(self, mock_chat_generator): assert expander.n_expansions == 3 assert expander.chat_generator is mock_chat_generator - def test_run_warm_up(self, mock_chat_generator_with_warm_up): - expander = QueryExpander(chat_generator=mock_chat_generator_with_warm_up) - mock_chat_generator_with_warm_up.run.return_value = {"queries": ["test query"]} + def test_run_warms_up_chat_generator(self, mock_chat_generator): + expander = QueryExpander(chat_generator=mock_chat_generator) + mock_chat_generator.run.return_value = {"replies": [ChatMessage.from_assistant("1. test query")]} - expander.warm_up() expander.run("test query") - assert expander._is_warmed_up is True - assert expander.run("test query") == {"queries": ["test query"]} - - def test_warm_up(self, mock_chat_generator): - expander = QueryExpander(chat_generator=mock_chat_generator) - expander.warm_up() - assert expander._is_warmed_up is True + mock_chat_generator.warm_up.assert_called() def test_init_negative_expansions_raises_error(self): with pytest.raises(ValueError, match="n_expansions must be positive"): @@ -492,3 +478,45 @@ def test_different_domains(self, chat_generator): # Should be different from original assert query not in result["queries"] + + +class TestComponentLifecycle: + def test_warm_up_delegates_to_chat_generator(self, mock_chat_generator): + expander = QueryExpander(chat_generator=mock_chat_generator) + expander.warm_up() + mock_chat_generator.warm_up.assert_called_once() + + async def test_warm_up_async_delegates_to_chat_generator(self, mock_chat_generator): + mock_chat_generator.warm_up_async = AsyncMock() + expander = QueryExpander(chat_generator=mock_chat_generator) + await expander.warm_up_async() + mock_chat_generator.warm_up_async.assert_awaited_once() + + async def test_warm_up_async_falls_back_to_sync_warm_up(self): + chat_generator = Mock(spec=["run", "warm_up"]) + expander = QueryExpander(chat_generator=chat_generator) + await expander.warm_up_async() + chat_generator.warm_up.assert_called_once() + + def test_close_delegates_to_chat_generator(self, mock_chat_generator): + expander = QueryExpander(chat_generator=mock_chat_generator) + expander.close() + mock_chat_generator.close.assert_called_once() + + async def test_close_async_delegates_to_chat_generator(self, mock_chat_generator): + mock_chat_generator.close_async = AsyncMock() + expander = QueryExpander(chat_generator=mock_chat_generator) + await expander.close_async() + mock_chat_generator.close_async.assert_awaited_once() + + async def test_close_async_falls_back_to_sync_close(self): + chat_generator = Mock(spec=["run", "close"]) + expander = QueryExpander(chat_generator=chat_generator) + await expander.close_async() + chat_generator.close.assert_called_once() + + def test_lifecycle_is_safe_when_chat_generator_lacks_methods(self): + chat_generator = Mock(spec=["run"]) + expander = QueryExpander(chat_generator=chat_generator) + expander.warm_up() + expander.close() diff --git a/test/components/rankers/test_llm_ranker.py b/test/components/rankers/test_llm_ranker.py index ab97e9fbc4..077999133f 100644 --- a/test/components/rankers/test_llm_ranker.py +++ b/test/components/rankers/test_llm_ranker.py @@ -437,3 +437,45 @@ async def test_live_run_async_ranks_berlin_first_for_germany_query(self): assert result["documents"] assert result["documents"][0].id == "doc-berlin" assert len(result["documents"]) <= 2 + + +class TestComponentLifecycle: + def test_warm_up_delegates_to_chat_generator(self, mock_chat_generator): + ranker = LLMRanker(chat_generator=mock_chat_generator) + ranker.warm_up() + mock_chat_generator.warm_up.assert_called_once() + + async def test_warm_up_async_delegates_to_chat_generator(self, mock_chat_generator): + mock_chat_generator.warm_up_async = AsyncMock() + ranker = LLMRanker(chat_generator=mock_chat_generator) + await ranker.warm_up_async() + mock_chat_generator.warm_up_async.assert_awaited_once() + + async def test_warm_up_async_falls_back_to_sync_warm_up(self): + chat_generator = Mock(spec=["run", "warm_up"]) + ranker = LLMRanker(chat_generator=chat_generator) + await ranker.warm_up_async() + chat_generator.warm_up.assert_called_once() + + def test_close_delegates_to_chat_generator(self, mock_chat_generator): + ranker = LLMRanker(chat_generator=mock_chat_generator) + ranker.close() + mock_chat_generator.close.assert_called_once() + + async def test_close_async_delegates_to_chat_generator(self, mock_chat_generator): + mock_chat_generator.close_async = AsyncMock() + ranker = LLMRanker(chat_generator=mock_chat_generator) + await ranker.close_async() + mock_chat_generator.close_async.assert_awaited_once() + + async def test_close_async_falls_back_to_sync_close(self): + chat_generator = Mock(spec=["run", "close"]) + ranker = LLMRanker(chat_generator=chat_generator) + await ranker.close_async() + chat_generator.close.assert_called_once() + + def test_lifecycle_is_safe_when_chat_generator_lacks_methods(self): + chat_generator = Mock(spec=["run"]) + ranker = LLMRanker(chat_generator=chat_generator) + ranker.warm_up() + ranker.close() diff --git a/test/components/retrievers/test_multi_query_embedding_retriever.py b/test/components/retrievers/test_multi_query_embedding_retriever.py index 6e2eb0966c..161728e690 100644 --- a/test/components/retrievers/test_multi_query_embedding_retriever.py +++ b/test/components/retrievers/test_multi_query_embedding_retriever.py @@ -4,7 +4,7 @@ import os from typing import Any -from unittest.mock import ANY +from unittest.mock import ANY, AsyncMock, Mock import numpy as np import pytest @@ -280,3 +280,66 @@ def test_pipeline_integration(self, document_store_with_embeddings): # assert there are not duplicates ids = [doc.id for doc in results["multiquery_retriever"]["documents"]] assert len(ids) == len(set(ids)) + + +class TestComponentLifecycle: + def test_warm_up_delegates_to_inner_components(self): + query_embedder = Mock(spec=["run", "warm_up"]) + retriever = Mock(spec=["run", "warm_up"]) + component = MultiQueryEmbeddingRetriever(retriever=retriever, query_embedder=query_embedder) + component.warm_up() + query_embedder.warm_up.assert_called_once() + retriever.warm_up.assert_called_once() + + async def test_warm_up_async_delegates_to_inner_components(self): + query_embedder = Mock(spec=["run", "warm_up_async"]) + query_embedder.warm_up_async = AsyncMock() + retriever = Mock(spec=["run", "warm_up_async"]) + retriever.warm_up_async = AsyncMock() + component = MultiQueryEmbeddingRetriever(retriever=retriever, query_embedder=query_embedder) + await component.warm_up_async() + query_embedder.warm_up_async.assert_awaited_once() + retriever.warm_up_async.assert_awaited_once() + + async def test_warm_up_async_falls_back_to_sync_warm_up(self): + query_embedder = Mock(spec=["run", "warm_up"]) + retriever = Mock(spec=["run", "warm_up"]) + component = MultiQueryEmbeddingRetriever(retriever=retriever, query_embedder=query_embedder) + await component.warm_up_async() + query_embedder.warm_up.assert_called_once() + retriever.warm_up.assert_called_once() + + def test_close_delegates_to_inner_components(self): + query_embedder = Mock(spec=["run", "close"]) + retriever = Mock(spec=["run", "close"]) + component = MultiQueryEmbeddingRetriever(retriever=retriever, query_embedder=query_embedder) + component.close() + query_embedder.close.assert_called_once() + retriever.close.assert_called_once() + + async def test_close_async_delegates_to_inner_components(self): + query_embedder = Mock(spec=["run", "close_async"]) + query_embedder.close_async = AsyncMock() + retriever = Mock(spec=["run", "close_async"]) + retriever.close_async = AsyncMock() + component = MultiQueryEmbeddingRetriever(retriever=retriever, query_embedder=query_embedder) + await component.close_async() + query_embedder.close_async.assert_awaited_once() + retriever.close_async.assert_awaited_once() + + async def test_close_async_falls_back_to_sync_close(self): + query_embedder = Mock(spec=["run", "close"]) + retriever = Mock(spec=["run", "close"]) + component = MultiQueryEmbeddingRetriever(retriever=retriever, query_embedder=query_embedder) + await component.close_async() + query_embedder.close.assert_called_once() + retriever.close.assert_called_once() + + async def test_lifecycle_is_safe_when_inner_components_lack_methods(self): + query_embedder = Mock(spec=["run"]) + retriever = Mock(spec=["run"]) + component = MultiQueryEmbeddingRetriever(retriever=retriever, query_embedder=query_embedder) + component.warm_up() + await component.warm_up_async() + component.close() + await component.close_async() diff --git a/test/components/retrievers/test_multi_query_text_retriever.py b/test/components/retrievers/test_multi_query_text_retriever.py index 2b22ebf02a..eeb8e7a832 100644 --- a/test/components/retrievers/test_multi_query_text_retriever.py +++ b/test/components/retrievers/test_multi_query_text_retriever.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import os -from unittest.mock import ANY +from unittest.mock import ANY, AsyncMock, Mock import pytest @@ -192,3 +192,51 @@ def test_pipeline_integration(self, document_store_with_docs): # assert there are not duplicates contents = [doc.content for doc in results["multiquery_retriever"]["documents"]] assert len(contents) == len(set(contents)) + + +class TestComponentLifecycle: + def test_warm_up_delegates_to_retriever(self): + retriever = Mock(spec=["run", "warm_up"]) + component = MultiQueryTextRetriever(retriever=retriever) + component.warm_up() + retriever.warm_up.assert_called_once() + + async def test_warm_up_async_delegates_to_retriever(self): + retriever = Mock(spec=["run", "warm_up_async"]) + retriever.warm_up_async = AsyncMock() + component = MultiQueryTextRetriever(retriever=retriever) + await component.warm_up_async() + retriever.warm_up_async.assert_awaited_once() + + async def test_warm_up_async_falls_back_to_sync_warm_up(self): + retriever = Mock(spec=["run", "warm_up"]) + component = MultiQueryTextRetriever(retriever=retriever) + await component.warm_up_async() + retriever.warm_up.assert_called_once() + + def test_close_delegates_to_retriever(self): + retriever = Mock(spec=["run", "close"]) + component = MultiQueryTextRetriever(retriever=retriever) + component.close() + retriever.close.assert_called_once() + + async def test_close_async_delegates_to_retriever(self): + retriever = Mock(spec=["run", "close_async"]) + retriever.close_async = AsyncMock() + component = MultiQueryTextRetriever(retriever=retriever) + await component.close_async() + retriever.close_async.assert_awaited_once() + + async def test_close_async_falls_back_to_sync_close(self): + retriever = Mock(spec=["run", "close"]) + component = MultiQueryTextRetriever(retriever=retriever) + await component.close_async() + retriever.close.assert_called_once() + + async def test_lifecycle_is_safe_when_retriever_lacks_methods(self): + retriever = Mock(spec=["run"]) + component = MultiQueryTextRetriever(retriever=retriever) + component.warm_up() + await component.warm_up_async() + component.close() + await component.close_async() diff --git a/test/components/retrievers/test_multi_retriever.py b/test/components/retrievers/test_multi_retriever.py index 068619150e..8a0d1f071c 100644 --- a/test/components/retrievers/test_multi_retriever.py +++ b/test/components/retrievers/test_multi_retriever.py @@ -4,7 +4,7 @@ import os from typing import Any -from unittest.mock import ANY +from unittest.mock import ANY, AsyncMock, Mock import pytest @@ -491,3 +491,66 @@ def test_emits_experimental_warning_on_init(self): @pytest.mark.filterwarnings("always::haystack.utils.experimental.ExperimentalWarning") def test_experimental_attribute_is_set(self): assert getattr(MultiRetriever, "__experimental__", False) is True + + +class TestComponentLifecycle: + def test_warm_up_delegates_to_all_retrievers(self): + a = Mock(spec=["run", "warm_up"]) + b = Mock(spec=["run", "warm_up"]) + retriever = MultiRetriever(retrievers={"a": a, "b": b}) + retriever.warm_up() + a.warm_up.assert_called_once() + b.warm_up.assert_called_once() + + async def test_warm_up_async_delegates_to_all_retrievers(self): + a = Mock(spec=["run", "warm_up_async"]) + a.warm_up_async = AsyncMock() + b = Mock(spec=["run", "warm_up_async"]) + b.warm_up_async = AsyncMock() + retriever = MultiRetriever(retrievers={"a": a, "b": b}) + await retriever.warm_up_async() + a.warm_up_async.assert_awaited_once() + b.warm_up_async.assert_awaited_once() + + async def test_warm_up_async_falls_back_to_sync_warm_up(self): + a = Mock(spec=["run", "warm_up"]) + b = Mock(spec=["run", "warm_up"]) + retriever = MultiRetriever(retrievers={"a": a, "b": b}) + await retriever.warm_up_async() + a.warm_up.assert_called_once() + b.warm_up.assert_called_once() + + def test_close_delegates_to_all_retrievers(self): + a = Mock(spec=["run", "close"]) + b = Mock(spec=["run", "close"]) + retriever = MultiRetriever(retrievers={"a": a, "b": b}) + retriever.close() + a.close.assert_called_once() + b.close.assert_called_once() + + async def test_close_async_delegates_to_all_retrievers(self): + a = Mock(spec=["run", "close_async"]) + a.close_async = AsyncMock() + b = Mock(spec=["run", "close_async"]) + b.close_async = AsyncMock() + retriever = MultiRetriever(retrievers={"a": a, "b": b}) + await retriever.close_async() + a.close_async.assert_awaited_once() + b.close_async.assert_awaited_once() + + async def test_close_async_falls_back_to_sync_close(self): + a = Mock(spec=["run", "close"]) + b = Mock(spec=["run", "close"]) + retriever = MultiRetriever(retrievers={"a": a, "b": b}) + await retriever.close_async() + a.close.assert_called_once() + b.close.assert_called_once() + + async def test_lifecycle_is_safe_when_retrievers_lack_methods(self): + a = Mock(spec=["run"]) + b = Mock(spec=["run"]) + retriever = MultiRetriever(retrievers={"a": a, "b": b}) + retriever.warm_up() + await retriever.warm_up_async() + retriever.close() + await retriever.close_async() diff --git a/test/components/retrievers/test_text_embedding_retriever.py b/test/components/retrievers/test_text_embedding_retriever.py index a015b3f26a..3880be0266 100644 --- a/test/components/retrievers/test_text_embedding_retriever.py +++ b/test/components/retrievers/test_text_embedding_retriever.py @@ -4,7 +4,7 @@ import os from typing import Any -from unittest.mock import ANY +from unittest.mock import ANY, AsyncMock, Mock import numpy as np import pytest @@ -189,3 +189,66 @@ def test_run_with_top_k(self, document_store_with_embeddings): result = retriever.run(query="energy", top_k=2) assert "documents" in result assert len(result["documents"]) <= 2 + + +class TestComponentLifecycle: + def test_warm_up_delegates_to_inner_components(self): + text_embedder = Mock(spec=["run", "warm_up"]) + retriever = Mock(spec=["run", "warm_up"]) + component = TextEmbeddingRetriever(retriever=retriever, text_embedder=text_embedder) + component.warm_up() + text_embedder.warm_up.assert_called_once() + retriever.warm_up.assert_called_once() + + async def test_warm_up_async_delegates_to_inner_components(self): + text_embedder = Mock(spec=["run", "warm_up_async"]) + text_embedder.warm_up_async = AsyncMock() + retriever = Mock(spec=["run", "warm_up_async"]) + retriever.warm_up_async = AsyncMock() + component = TextEmbeddingRetriever(retriever=retriever, text_embedder=text_embedder) + await component.warm_up_async() + text_embedder.warm_up_async.assert_awaited_once() + retriever.warm_up_async.assert_awaited_once() + + async def test_warm_up_async_falls_back_to_sync_warm_up(self): + text_embedder = Mock(spec=["run", "warm_up"]) + retriever = Mock(spec=["run", "warm_up"]) + component = TextEmbeddingRetriever(retriever=retriever, text_embedder=text_embedder) + await component.warm_up_async() + text_embedder.warm_up.assert_called_once() + retriever.warm_up.assert_called_once() + + def test_close_delegates_to_inner_components(self): + text_embedder = Mock(spec=["run", "close"]) + retriever = Mock(spec=["run", "close"]) + component = TextEmbeddingRetriever(retriever=retriever, text_embedder=text_embedder) + component.close() + text_embedder.close.assert_called_once() + retriever.close.assert_called_once() + + async def test_close_async_delegates_to_inner_components(self): + text_embedder = Mock(spec=["run", "close_async"]) + text_embedder.close_async = AsyncMock() + retriever = Mock(spec=["run", "close_async"]) + retriever.close_async = AsyncMock() + component = TextEmbeddingRetriever(retriever=retriever, text_embedder=text_embedder) + await component.close_async() + text_embedder.close_async.assert_awaited_once() + retriever.close_async.assert_awaited_once() + + async def test_close_async_falls_back_to_sync_close(self): + text_embedder = Mock(spec=["run", "close"]) + retriever = Mock(spec=["run", "close"]) + component = TextEmbeddingRetriever(retriever=retriever, text_embedder=text_embedder) + await component.close_async() + text_embedder.close.assert_called_once() + retriever.close.assert_called_once() + + async def test_lifecycle_is_safe_when_inner_components_lack_methods(self): + text_embedder = Mock(spec=["run"]) + retriever = Mock(spec=["run"]) + component = TextEmbeddingRetriever(retriever=retriever, text_embedder=text_embedder) + component.warm_up() + await component.warm_up_async() + component.close() + await component.close_async() diff --git a/test/components/routers/test_llm_messages_router.py b/test/components/routers/test_llm_messages_router.py index 86ba5e6e3f..e9669a9714 100644 --- a/test/components/routers/test_llm_messages_router.py +++ b/test/components/routers/test_llm_messages_router.py @@ -47,7 +47,6 @@ def test_init(self): assert router._output_names == ["safe", "unsafe"] assert router._output_patterns == ["safe", "unsafe"] assert router._compiled_patterns == [re.compile(pattern) for pattern in ["safe", "unsafe"]] - assert router._is_warmed_up is False def test_init_errors(self): chat_generator = MockChatGenerator() @@ -63,23 +62,6 @@ def test_init_errors(self): chat_generator=chat_generator, output_names=["name1", "name2"], output_patterns=["pattern1"] ) - def test_warm_up_with_unwarmable_chat_generator(self): - chat_generator = MockChatGenerator() - router = LLMMessagesRouter( - chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"] - ) - router.warm_up() - assert router._is_warmed_up is True - - def test_warm_up_with_warmable_chat_generator(self): - chat_generator = Mock() - router = LLMMessagesRouter( - chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"] - ) - router.warm_up() - assert router._is_warmed_up is True - assert router._chat_generator.warm_up.call_count == 1 - def test_run_input_errors(self): router = LLMMessagesRouter( chat_generator=MockChatGenerator(), output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"] @@ -99,8 +81,6 @@ def test_run_no_warm_up_with_unwarmable_chat_generator(self): router.run([ChatMessage.from_user("Hello")]) def test_run_no_warm_up_with_warmable_chat_generator(self): - """Warm up is run automatically if not done before.""" - def mock_run(messages): return {"replies": [ChatMessage.from_assistant("safe")]} @@ -111,7 +91,6 @@ def mock_run(messages): ) router.run([ChatMessage.from_user("Hello")]) assert chat_generator.warm_up.call_count == 1 - assert router._is_warmed_up is True def test_run(self): router = LLMMessagesRouter( @@ -315,3 +294,54 @@ async def test_live_run_async(self): assert result["chat_generator_text"].lower() == "safe" assert "unsafe" not in result assert "unmatched" not in result + + +class TestComponentLifecycle: + def _make_router(self, chat_generator): + return LLMMessagesRouter( + chat_generator=chat_generator, output_names=["safe", "unsafe"], output_patterns=["safe", "unsafe"] + ) + + def test_warm_up_delegates_to_chat_generator(self): + chat_generator = Mock() + router = self._make_router(chat_generator) + router.warm_up() + chat_generator.warm_up.assert_called_once() + + async def test_warm_up_async_delegates_to_chat_generator(self): + chat_generator = Mock() + chat_generator.warm_up_async = AsyncMock() + router = self._make_router(chat_generator) + await router.warm_up_async() + chat_generator.warm_up_async.assert_awaited_once() + + async def test_warm_up_async_falls_back_to_sync_warm_up(self): + chat_generator = Mock(spec=["run", "warm_up"]) + router = self._make_router(chat_generator) + await router.warm_up_async() + chat_generator.warm_up.assert_called_once() + + def test_close_delegates_to_chat_generator(self): + chat_generator = Mock() + router = self._make_router(chat_generator) + router.close() + chat_generator.close.assert_called_once() + + async def test_close_async_delegates_to_chat_generator(self): + chat_generator = Mock() + chat_generator.close_async = AsyncMock() + router = self._make_router(chat_generator) + await router.close_async() + chat_generator.close_async.assert_awaited_once() + + async def test_close_async_falls_back_to_sync_close(self): + chat_generator = Mock(spec=["run", "close"]) + router = self._make_router(chat_generator) + await router.close_async() + chat_generator.close.assert_called_once() + + def test_lifecycle_is_safe_when_chat_generator_lacks_methods(self): + chat_generator = Mock(spec=["run"]) + router = self._make_router(chat_generator) + router.warm_up() + router.close() diff --git a/test/core/pipeline/test_pipeline_lifecycle.py b/test/core/pipeline/test_pipeline_lifecycle.py new file mode 100644 index 0000000000..e4a677acf2 --- /dev/null +++ b/test/core/pipeline/test_pipeline_lifecycle.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import asyncio + +import pytest + +from haystack import Pipeline, component + + +@component +class LifecycleRecorder: + """Records every lifecycle method called, so tests can assert which ones the pipeline picks.""" + + def __init__(self): + self.events = [] + + def warm_up(self): + self.events.append("warm_up") + + async def warm_up_async(self): + self.events.append("warm_up_async") + + @component.output_types(value=int) + def run(self): + self.events.append("run") + return {"value": 1} + + @component.output_types(value=int) + async def run_async(self): + self.events.append("run_async") + return {"value": 1} + + def close(self): + self.events.append("close") + + async def close_async(self): + self.events.append("close_async") + + +@component +class SyncOnlyRecorder: + """Implements only the synchronous warm_up and close, to exercise the async fallbacks.""" + + def __init__(self): + self.events = [] + + def warm_up(self): + self.events.append("warm_up") + + @component.output_types(value=int) + def run(self): + return {"value": 1} + + @component.output_types(value=int) + async def run_async(self): + return {"value": 1} + + def close(self): + self.events.append("close") + + +@component +class BareComponent: + """Implements no lifecycle method at all.""" + + @component.output_types(value=int) + def run(self): + return {"value": 1} + + @component.output_types(value=int) + async def run_async(self): + return {"value": 1} + + +class LoopBoundAsyncClient: + """Mimics a real async client (aiohttp): binds to the loop it is created on and refuses any other.""" + + def __init__(self): + self._loop = asyncio.get_running_loop() + + async def use(self): + if asyncio.get_running_loop() is not self._loop: + raise RuntimeError("async client used on a different event loop than the one it was created on") + + +@component +class AsyncClientComponent: + """Creates a loop-bound async client in warm_up_async and uses it in run_async.""" + + def __init__(self): + self.client: LoopBoundAsyncClient | None = None + + async def warm_up_async(self): + if self.client is None: + self.client = LoopBoundAsyncClient() + + @component.output_types(value=int) + async def run_async(self): + assert self.client is not None + await self.client.use() + return {"value": 1} + + @component.output_types(value=int) + def run(self): + return {"value": 1} + + +async def test_run_async_uses_warm_up_async(): + """When a component implements warm_up_async, run_async uses it and does not also call its sync warm_up.""" + rec = LifecycleRecorder() + pipe = Pipeline() + pipe.add_component("rec", rec) + await pipe.run_async({"rec": {}}) + assert "warm_up_async" in rec.events + assert "warm_up" not in rec.events + + +async def test_warm_up_async_falls_back_to_sync_warm_up(): + """A component with only the sync warm_up is still warmed by run_async through that method.""" + rec = SyncOnlyRecorder() + pipe = Pipeline() + pipe.add_component("rec", rec) + await pipe.run_async({"rec": {}}) + assert rec.events == ["warm_up"] + + +def test_sync_run_uses_sync_warm_up(): + """The sync run path warms components via the sync warm_up, never warm_up_async.""" + rec = LifecycleRecorder() + pipe = Pipeline() + pipe.add_component("rec", rec) + pipe.run({"rec": {}}) + assert "warm_up" in rec.events + assert "warm_up_async" not in rec.events + + +def test_pipeline_close_calls_sync_close_only(): + """Pipeline.close() calls each component's sync close, never close_async.""" + rec = LifecycleRecorder() + pipe = Pipeline() + pipe.add_component("rec", rec) + pipe.close() + assert "close" in rec.events + assert "close_async" not in rec.events + + +async def test_pipeline_close_async_calls_async_close_only(): + """When a component implements close_async, Pipeline.close_async() uses it and does not also call its sync close.""" + rec = LifecycleRecorder() + pipe = Pipeline() + pipe.add_component("rec", rec) + await pipe.close_async() + assert "close_async" in rec.events + assert "close" not in rec.events + + +async def test_close_async_falls_back_to_sync_close(): + """A component with only the sync close is still released by close_async through that method.""" + rec = SyncOnlyRecorder() + pipe = Pipeline() + pipe.add_component("rec", rec) + await pipe.close_async() + assert rec.events == ["close"] + + +async def test_run_does_not_auto_close(): + """Running a pipeline (sync or async) never closes components; closing is always explicit.""" + rec = LifecycleRecorder() + pipe = Pipeline() + pipe.add_component("rec", rec) + pipe.run({"rec": {}}) + await pipe.run_async({"rec": {}}) + assert "close" not in rec.events + assert "close_async" not in rec.events + + +async def test_lifecycle_methods_are_optional(): + """A component without lifecycle methods works: every call is hasattr-guarded and skipped.""" + pipe = Pipeline() + pipe.add_component("bare", BareComponent()) + await pipe.warm_up_async() + pipe.close() + await pipe.close_async() + await pipe.run_async({"bare": {}}) + + +def test_loop_bound_client_rejects_other_loop(): + """The fake client raises when used from a different loop. + + This ensures the affinity test below enforces loop binding. + """ + + async def _make_loop_bound_client(): + return LoopBoundAsyncClient() + + client = asyncio.run(_make_loop_bound_client()) + with pytest.raises(RuntimeError): + asyncio.run(client.use()) + + +async def test_async_client_bound_to_run_loop(): + """warm_up_async creates the async client on the loop run_async uses, so it stays usable there.""" + pipe = Pipeline() + pipe.add_component("client_component", AsyncClientComponent()) + await pipe.warm_up_async() + # Would raise if warm_up_async had bound the client to a different loop than run_async + await pipe.run_async({"client_component": {}}) diff --git a/test/core/super_component/test_super_component.py b/test/core/super_component/test_super_component.py index 7edde1bf9d..7f7e3a2fd2 100644 --- a/test/core/super_component/test_super_component.py +++ b/test/core/super_component/test_super_component.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Any -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest @@ -467,3 +467,33 @@ async def run_async(self): result = await deserialized_super_component.run_async() assert result == {"output": "Hello world"} + + +class TestSuperComponentLifecycle: + def test_warm_up_delegates_to_pipeline(self, sample_super_component): + with patch.object(sample_super_component.pipeline, "warm_up") as mock_warm_up: + sample_super_component.warm_up() + mock_warm_up.assert_called_once() + + def test_warm_up_is_idempotent(self, sample_super_component): + with patch.object(sample_super_component.pipeline, "warm_up") as mock_warm_up: + sample_super_component.warm_up() + sample_super_component.warm_up() + mock_warm_up.assert_called_once() + + @pytest.mark.asyncio + async def test_warm_up_async_delegates_to_pipeline(self, sample_super_component): + with patch.object(sample_super_component.pipeline, "warm_up_async", new=AsyncMock()) as mock_warm_up_async: + await sample_super_component.warm_up_async() + mock_warm_up_async.assert_awaited_once() + + def test_close_delegates_to_pipeline(self, sample_super_component): + with patch.object(sample_super_component.pipeline, "close") as mock_close: + sample_super_component.close() + mock_close.assert_called_once() + + @pytest.mark.asyncio + async def test_close_async_delegates_to_pipeline(self, sample_super_component): + with patch.object(sample_super_component.pipeline, "close_async", new=AsyncMock()) as mock_close_async: + await sample_super_component.close_async() + mock_close_async.assert_awaited_once()