diff --git a/ABB-Manual-Assistant/README.md b/ABB-Manual-Assistant/README.md new file mode 100644 index 0000000..916d7e2 --- /dev/null +++ b/ABB-Manual-Assistant/README.md @@ -0,0 +1,3 @@ +# Linamar-Vector-Bootcamp + +## Applying agentic ai to robot troubleshooting. diff --git a/ABB-Manual-Assistant/agent_utils/__init__.py b/ABB-Manual-Assistant/agent_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ABB-Manual-Assistant/agent_utils/memory_store.py b/ABB-Manual-Assistant/agent_utils/memory_store.py new file mode 100644 index 0000000..a3f4c8c --- /dev/null +++ b/ABB-Manual-Assistant/agent_utils/memory_store.py @@ -0,0 +1,67 @@ +import datetime +import os +import sqlite3 +import uuid + + +class MemoryStore: + def __init__(self, db_path="data/memory.db"): + os.makedirs(os.path.dirname(db_path), exist_ok=True) + self.conn = sqlite3.connect(db_path, check_same_thread=False) + self._init_db() + + def _init_db(self): + cur = self.conn.cursor() + cur.execute(""" + CREATE TABLE IF NOT EXISTS conversations ( + id TEXT PRIMARY KEY, + title TEXT, + created_at TEXT + ) + """) + cur.execute(""" + CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + conversation_id TEXT, + role TEXT, + content TEXT, + timestamp TEXT + ) + """) + self.conn.commit() + + def create_conversation(self, title=None): + cid = str(uuid.uuid4()) + if not title: + title = f"Chat {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}" + self.conn.execute( + "INSERT INTO conversations VALUES (?, ?, ?)", (cid, title, datetime.datetime.now().isoformat()) + ) + self.conn.commit() + return cid + + def list_conversations(self): + return self.conn.execute("SELECT id, title, created_at FROM conversations ORDER BY created_at DESC").fetchall() + + def get_history(self, conversation_id, limit=100): + rows = self.conn.execute( + "SELECT role, content FROM messages WHERE conversation_id=? ORDER BY id ASC LIMIT ?", + (conversation_id, limit), + ).fetchall() + return [{"role": r[0], "content": r[1]} for r in rows] + + def log_message(self, conversation_id, role, content): + self.conn.execute( + "INSERT INTO messages (conversation_id, role, content, timestamp) VALUES (?, ?, ?, ?)", + (conversation_id, role, content, datetime.datetime.now().isoformat()), + ) + self.conn.commit() + + def rename_conversation(self, conversation_id, new_title): + self.conn.execute("UPDATE conversations SET title=? WHERE id=?", (new_title, conversation_id)) + self.conn.commit() + + def delete_conversation(self, conversation_id): + self.conn.execute("DELETE FROM messages WHERE conversation_id=?", (conversation_id,)) + self.conn.execute("DELETE FROM conversations WHERE id=?", (conversation_id,)) + self.conn.commit() diff --git a/ABB-Manual-Assistant/conversation_manager.py b/ABB-Manual-Assistant/conversation_manager.py new file mode 100644 index 0000000..e1da969 --- /dev/null +++ b/ABB-Manual-Assistant/conversation_manager.py @@ -0,0 +1,73 @@ +# agents/conversation_manager.py +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from agent_utils.memory_store import MemoryStore + + +class ConversationManagerAgent: + """ + Minimal 'agent' wrapper around MemoryStore so other agents/tools + can call a stable interface. Keeps things simple: + - make sure a conversation exists + - append user/assistant turns + - read history + - rename/delete/list conversations + - optional: keep assistant partials in RAM; persist only on finalize + """ + + def __init__(self, store: Optional[MemoryStore] = None): + self.store = store or MemoryStore() + # simple in-RAM buffer for current assistant partials per conversation + self._partials: Dict[str, str] = {} + + # ---- conversation mgmt ---- + def ensure_conversation(self, conversation_id: Optional[str]) -> str: + rows = self.store.list_conversations() + if not rows: + return self.store.create_conversation() + if conversation_id: + return conversation_id + # default to newest (list_conversations should return DESC by created_at) + return rows[0][0] + + def list_conversations(self) -> List[Dict[str, Any]]: + rows = self.store.list_conversations() + return [{"id": r[0], "title": r[1], "created_at": r[2]} for r in rows] + + def rename(self, conversation_id: str, title: str) -> Dict[str, Any]: + self.store.rename_conversation(conversation_id, title) + return {"ok": True, "conversation_id": conversation_id, "title": title} + + def delete(self, conversation_id: str) -> Dict[str, Any]: + self.store.delete_conversation(conversation_id) + self._partials.pop(conversation_id, None) + return {"ok": True} + + def create(self, title: Optional[str] = None) -> Dict[str, Any]: + cid = self.store.create_conversation(title) + return {"ok": True, "conversation_id": cid} + + # ---- messages ---- + def save_user(self, conversation_id: str, content: str) -> Dict[str, Any]: + cid = self.ensure_conversation(conversation_id) + self.store.log_message(cid, "user", content) + return {"ok": True, "conversation_id": cid} + + def set_assistant_partial(self, conversation_id: str, partial_text: str) -> Dict[str, Any]: + # Keep partials in RAM for simplicity/quickness + cid = self.ensure_conversation(conversation_id) + self._partials[cid] = partial_text + return {"ok": True, "conversation_id": cid} + + def finalize_assistant(self, conversation_id: str) -> Dict[str, Any]: + cid = self.ensure_conversation(conversation_id) + final_text = self._partials.pop(cid, "") + # Only persist once at the end of the stream + self.store.log_message(cid, "assistant", final_text) + return {"ok": True, "conversation_id": cid, "content_len": len(final_text)} + + def get_history_messages(self, conversation_id: str, limit: int = 1000) -> List[Dict[str, Any]]: + cid = self.ensure_conversation(conversation_id) + return self.store.get_history(cid, limit=limit) diff --git a/ABB-Manual-Assistant/gitignore-2025.txt b/ABB-Manual-Assistant/gitignore-2025.txt new file mode 100644 index 0000000..cad5723 --- /dev/null +++ b/ABB-Manual-Assistant/gitignore-2025.txt @@ -0,0 +1,212 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +scrape_with_LLM/ +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock +#poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. +# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +#pdm.lock +#pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +#pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Max Added +.venv_temp +scrape_with_LLM/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Cursor +# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to +# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data +# refer to https://docs.cursor.com/context/ignore-files +.cursorignore +.cursorindexingignore + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ diff --git a/ABB-Manual-Assistant/orchestrator_agent.py b/ABB-Manual-Assistant/orchestrator_agent.py new file mode 100644 index 0000000..c0062b1 --- /dev/null +++ b/ABB-Manual-Assistant/orchestrator_agent.py @@ -0,0 +1,62 @@ +import os + +import agents +from dotenv import load_dotenv +from openai import AsyncOpenAI +from openai.types.responses import ResponseTextDeltaEvent +from search_agent import SearchAgent +from workorder_agent import WorkorderAgent + + +load_dotenv() + + +class Orchestrator: + def __init__(self): + self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL")) + + search_agent_instance = SearchAgent() + self.search_agent_tool = search_agent_instance.search_agent.as_tool( + tool_name="search_knowledge_base", + tool_description="Agent searches ABB robot manuals for repair, troubleshooting, maintenance instructions, etc.", + ) + + workorder_agent_instance = WorkorderAgent() + self.workorder_agent_tool = workorder_agent_instance.workorder_agent.as_tool( + tool_name="workorder_agent", + tool_description="Given a conversation between the user and the orchestrator agent, the workorder agent will create a workorder.", + ) + + self.main_agent = agents.Agent( + name="Orchestrator Agent", + instructions=""" + You are a helpful assistant and organizer. + If the search agent doesn't find anything, use your own knowledge. + Always present the search agent's findings at the bottom of your output inside a collapsible section. + + If the user asks you to create a workorder, then call the workorder_agent. + """, + model=agents.OpenAIChatCompletionsModel(model="gemini-2.5-pro", openai_client=self.client), + model_settings=agents.ModelSettings(tool_choice="required", temperature=0.5), + tools=[self.search_agent_tool, self.workorder_agent_tool], + ) + + async def run(self, prompt: str, history) -> str: + context = "" + + # reconstruct conversation history into a string + if history: + for i in range(0, len(history), 2): + if i + 1 < len(history): + user_msg = history[i]["content"] + bot_msg = history[i + 1]["content"] + context += f"User: {user_msg}\nAssistant: {bot_msg}\n" + + # combine history with the new prompt + full_prompt = f"{context} User: {prompt}" + + result_stream = agents.Runner.run_streamed(self.main_agent, input=full_prompt) + async for event in result_stream.stream_events(): + if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent): + print(event.data.delta, end="", flush=True) + yield event.data.delta diff --git a/ABB-Manual-Assistant/req_temp.txt b/ABB-Manual-Assistant/req_temp.txt new file mode 100644 index 0000000..fcd9eca --- /dev/null +++ b/ABB-Manual-Assistant/req_temp.txt @@ -0,0 +1,49 @@ +annotated-types==0.7.0 +anyio==4.10.0 +attrs==25.3.0 +Authlib==1.6.1 +certifi==2025.8.3 +cffi==1.17.1 +charset-normalizer==3.4.2 +click==8.2.1 +colorama==0.4.6 +cryptography==45.0.5 +deprecation==2.1.0 +distro==1.9.0 +griffe==1.9.0 +grpcio==1.74.0 +grpcio-health-checking==1.74.0 +h11==0.16.0 +httpcore==1.0.9 +httpx==0.28.1 +httpx-sse==0.4.1 +idna==3.10 +jiter==0.10.0 +jsonschema==4.25.0 +jsonschema-specifications==2025.4.1 +mcp==1.12.3 +openai==1.99.0 +openai-agents==0.2.4 +packaging==25.0 +protobuf==6.31.1 +pycparser==2.22 +pydantic==2.11.7 +pydantic-settings==2.10.1 +pydantic_core==2.33.2 +python-dotenv==1.1.1 +python-multipart==0.0.20 +referencing==0.36.2 +requests==2.32.4 +rpds-py==0.26.0 +sniffio==1.3.1 +sse-starlette==3.0.2 +starlette==0.47.2 +tqdm==4.67.1 +types-requests==2.32.4.20250611 +typing-inspection==0.4.1 +typing_extensions==4.14.1 +urllib3==2.5.0 +uvicorn==0.35.0 +validators==0.35.0 +weaviate==0.1.2 +weaviate-client==4.16.5 diff --git a/ABB-Manual-Assistant/run_eval.py b/ABB-Manual-Assistant/run_eval.py new file mode 100644 index 0000000..051e1ec --- /dev/null +++ b/ABB-Manual-Assistant/run_eval.py @@ -0,0 +1,151 @@ +""" +This script runs the ABB Manual assistent agent on a Langfuse dataset and evaluates it;s responses using an LLM as a judge. Results are uploaded to langfuse for traceability +Include the following when you run this script: +--langfuse_dataset_name: Name of the dataset in Langfuse +--run_name: Label for this evaluation run +--limit: (Optional) Number of items to evaluate +Example +python run_eval.py --langfuse_dataset_name LLM_Judge_Errors --run_name ABB_Eval_Run_01 --limit 10 +""" + +import argparse +import asyncio +import os + +from agents import Agent, OpenAIChatCompletionsModel, Runner +from dotenv import load_dotenv +from langfuse import get_client +from langfuse._client.datasets import DatasetItemClient +from openai import AsyncOpenAI +from orchestrator_agent import Orchestrator +from pydantic import BaseModel +from rich.progress import track +from utils import setup_langfuse_tracer +from utils.langfuse.shared_client import flush_langfuse, langfuse_client + + +# --- Load environment and Langfuse --- +load_dotenv() +langfuse = get_client() +setup_langfuse_tracer() + +# Load your OpenAI API key from .env or environment +openai_api_key = os.getenv("OPENAI_API_KEY") + +# Create the async client +async_openai_client = AsyncOpenAI(api_key=openai_api_key) + + +# --- Evaluation Prompt Templates --- +EVALUATOR_INSTRUCTIONS = "Evaluate whether the 'Proposed Answer' to the given 'Question' matches the 'Ground Truth'." +EVALUATOR_TEMPLATE = """\ +# Question +{question} + +# Ground Truth +{ground_truth} + +# Proposed Answer +{proposed_response} +""" + + +# --- Data Models --- +class LangFuseTracedResponse(BaseModel): + answer: str | None + trace_id: str | None + + +class EvaluatorQuery(BaseModel): + question: str + ground_truth: str + proposed_response: str + + def get_query(self) -> str: + return EVALUATOR_TEMPLATE.format(**self.model_dump()) + + +class EvaluatorResponse(BaseModel): + explanation: str + is_answer_correct: bool + + +# --- Agent Execution --- +async def run_agent_with_trace(orchestrator: Orchestrator, query: str) -> LangFuseTracedResponse: + try: + result = await orchestrator.run(query) + answer = getattr(result, "final_output", str(result)) + except Exception: + answer = None + + return LangFuseTracedResponse(answer=answer, trace_id=langfuse_client.get_current_trace_id()) + + +# --- Evaluation Agent --- +async def run_evaluator_agent(evaluator_query: EvaluatorQuery) -> EvaluatorResponse: + evaluator_agent = Agent( + name="ABB Evaluator", + instructions=EVALUATOR_INSTRUCTIONS, + output_type=EvaluatorResponse, + model=OpenAIChatCompletionsModel(model="gemini-2.5-flash", openai_client=async_openai_client), + ) + result = await Runner.run(evaluator_agent, input=evaluator_query.get_query()) + return result.final_output_as(EvaluatorResponse) + + +# --- Main Evaluation Loop --- +async def run_and_evaluate(run_name: str, orchestrator: Orchestrator, item: DatasetItemClient): + expected_output = item.expected_output + assert expected_output is not None + + with item.run(run_name=run_name) as span: + span.update(input=item.input["text"]) + traced_response = await run_agent_with_trace(orchestrator, item.input["text"]) + span.update(output=traced_response.answer) + + print(f"Running query: {item.input['text']}") + print(f"Agent response: {traced_response.answer}") + + if traced_response.answer is None: + return traced_response, None + + evaluator_response = await run_evaluator_agent( + EvaluatorQuery( + question=item.input["text"], ground_truth=expected_output["text"], proposed_response=traced_response.answer + ) + ) + + return traced_response, evaluator_response + + +# --- CLI Entrypoint --- +parser = argparse.ArgumentParser() +parser.add_argument("--langfuse_dataset_name", required=True) +parser.add_argument("--run_name", required=True) +parser.add_argument("--limit", type=int) + +if __name__ == "__main__": + args = parser.parse_args() + + items = langfuse.get_dataset(args.langfuse_dataset_name).items + if args.limit: + items = items[: args.limit] + + orchestrator = Orchestrator() + coros = [run_and_evaluate(args.run_name, orchestrator, item) for item in items] + + async def main(): + return await asyncio.gather(*coros) + + results = asyncio.run(main()) + + for traced_response, eval_output in track(results, total=len(results), description="Uploading scores"): + if eval_output: + langfuse_client.create_score( + name="is_answer_correct", + value=eval_output.is_answer_correct, + comment=eval_output.explanation, + trace_id=traced_response.trace_id, + ) + + flush_langfuse() diff --git a/ABB-Manual-Assistant/search_agent.py b/ABB-Manual-Assistant/search_agent.py new file mode 100644 index 0000000..f43a162 --- /dev/null +++ b/ABB-Manual-Assistant/search_agent.py @@ -0,0 +1,68 @@ +import os + +import agents +from dotenv import load_dotenv +from openai import AsyncOpenAI +from search_tool import Weaviate + + +load_dotenv() + + +class SearchAgent: + def __init__(self): + self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL")) + + self.knowledge_tool = agents.function_tool( + self.search_knowledgebase, + name_override="knowledge_search", + description_override="Searches the ABB robot manual vector database for the most relevant technical sections.", + ) + + self.search_agent = agents.Agent( + name="Search Agent", + instructions=""" + You are a Search Agent specialized in retrieving exact, relevant information from ABB robot manuals stored in a vector database. + + Your ONLY purpose is to: + 1. Use the provided "knowledge_search" tool to query the database. + 2. Return the most relevant technical excerpts that directly answer the query. + 3. Preserve the original wording — do not paraphrase, summarize, or add explanations. + + Rules: + - Do NOT generate answers from your own knowledge. + - Do NOT guess or fill in missing details. + - Do NOT include unrelated or generic text. + + Output Format: + [ + { + "source": "", + "url": "< a clickable url to go to the source of the information, include the page number >" + "excerpt": "", + "confidence": + }, + ... + ] + + Guidelines: + - Always select the top results that are most technically accurate and useful for a technician repairing ABB robots. + - Avoid redundancy — do not return overlapping excerpts. + - If no relevant information is found, return an empty list []. + """, + model=agents.OpenAIChatCompletionsModel( + model="gemini-2.5-flash-lite-preview-06-17", openai_client=self.client + ), + model_settings=agents.ModelSettings(tool_choice="required", temperature=0), + tools=[self.knowledge_tool], + ) + + @staticmethod + async def search_knowledgebase(query: str): + weaviate = Weaviate() + + return await weaviate.get_knowledge(query) + + async def run(self, prompt: str) -> str: + response = await agents.Runner.run(self.search_agent, input=prompt) + return response diff --git a/ABB-Manual-Assistant/search_tool.py b/ABB-Manual-Assistant/search_tool.py new file mode 100644 index 0000000..92be19e --- /dev/null +++ b/ABB-Manual-Assistant/search_tool.py @@ -0,0 +1,69 @@ +import json +import os + +import openai +import weaviate +from dotenv import load_dotenv +from weaviate.classes.init import Auth + + +load_dotenv() + + +class Weaviate: + def __init__(self, data_name=os.getenv("COLLECTION_NAME")): + self.client = None + self.data_name = data_name + + # Setup OpenAI client via Cloudflare + self.openai_client = openai.OpenAI( + api_key=os.getenv("EMBEDDING_API_KEY"), base_url=os.getenv("EMBEDDING_BASE_URL") + ) + + async def create_client(self) -> weaviate.WeaviateClient: + cluster_url = os.getenv("WEAVIATE_HTTP_HOST") + api_key = os.getenv("WEAVIATE_API_KEY") + + client = weaviate.connect_to_weaviate_cloud(cluster_url=cluster_url, auth_credentials=Auth.api_key(api_key)) + + return client + + async def ensure_connected(self): + if self.client is None: + self.client = await self.create_client() + + async def get_knowledge(self, query: str) -> str: + await self.ensure_connected() + + try: + # Generate embedding + embedding = self.openai_client.embeddings.create(model=os.getenv("EMBEDDING_MODEL_NAME"), input=query) + + # Perform hybrid search + collection = self.client.collections.get(self.data_name) + response = collection.query.hybrid( + query=query, vector=embedding.data[0].embedding, return_metadata=["score"] + ) + + if not response.objects: + return "No results found." + + # Format results + formatted_results = [ + { + "Document Name": obj.properties.get("document_Name", ""), + "URL": obj.properties.get("uRL", ""), + "Page Number": obj.properties.get("page_number", ""), + "Full Text": obj.properties.get("full_text", ""), + } + for obj in response.objects + ] + + return json.dumps(formatted_results, indent=2, sort_keys=True) + + except Exception as e: + return f"Search error: {e}" + + async def close(self): + if self.client: + self.client.close() diff --git a/ABB-Manual-Assistant/test_scripts/test_agent_search.py b/ABB-Manual-Assistant/test_scripts/test_agent_search.py new file mode 100644 index 0000000..89b358e --- /dev/null +++ b/ABB-Manual-Assistant/test_scripts/test_agent_search.py @@ -0,0 +1,82 @@ +# generated test file - agent using search tool + +import os +import sys + + +# Add parent directory to sys.path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +import asyncio +import os + +import agents +from dotenv import load_dotenv +from openai import AsyncOpenAI +from search_tool import Weaviate # your tool + + +load_dotenv() + +# Initialize OpenAI client +client = AsyncOpenAI( + api_key=os.getenv("OPENAI_API_KEY"), + base_url=os.getenv("OPENAI_BASE_URL"), +) + +# Initialize Weaviate search class +weaviate_search = Weaviate() + + +async def search_knowledgebase(query: str) -> str: + print(f"[TOOL] Called with query: {query}") + try: + result = await weaviate_search.get_knowledge(query) + print(f"[TOOL] Result length: {len(result) if result else 'None'}") + print(f"[TOOL] Full text: {result}") + return result + except Exception as e: + print(f"[TOOL] Exception: {e}") + return f"Error during search: {e}" + + +knowledge_tool = agents.function_tool(search_knowledgebase) + +agent = agents.Agent( + name="Debug Knowledge Agent", + instructions="You are a helpful assistant. Use the tool to get knowledge.", + tools=[knowledge_tool], + model=agents.OpenAIChatCompletionsModel( + model="gemini-2.5-pro", + openai_client=client, + ), + model_settings=agents.ModelSettings(tool_choice="required"), +) + + +async def main(): + test_queries = [ + # "10077, FTP server down", + # "What is FTP?", + # "Explain HTTP protocol", + # "What is the largest ABB robot", + # "What are the specs for the IRB 140?", + # "What is \"EN ISO 12100 -1\"", + # "What colour is the sky", + "Spot application weld error reported" + ] + + for q in test_queries: + print("\n===============================") + print(f"Running agent with prompt: {q}") + try: + response = await agents.Runner.run(agent, input=q) + print("[AGENT FINAL OUTPUT]:\n", response.final_output) + except Exception as e: + print("[AGENT] Exception during run:", e) + + await weaviate_search.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/ABB-Manual-Assistant/test_scripts/test_orch.py b/ABB-Manual-Assistant/test_scripts/test_orch.py new file mode 100644 index 0000000..cbdac54 --- /dev/null +++ b/ABB-Manual-Assistant/test_scripts/test_orch.py @@ -0,0 +1,35 @@ +# generated test file - orchestrator agent using search tool + +import os +import sys + + +# Add parent directory to sys.path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +import asyncio +import logging + +#### +import warnings + +from orchestrator_agent import Orchestrator + + +# Suppress all warnings (UserWarning, DeprecationWarning, etc.) +warnings.filterwarnings("ignore") +# Suppress log messages from all libraries +logging.basicConfig(level=logging.CRITICAL) +for name in logging.root.manager.loggerDict: + logging.getLogger(name).setLevel(logging.CRITICAL) +#### + + +async def main(): + orchestrator = Orchestrator() + response = await orchestrator.run("How to integrate IRC5?") + print(response) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/ABB-Manual-Assistant/test_scripts/test_search.py b/ABB-Manual-Assistant/test_scripts/test_search.py new file mode 100644 index 0000000..d11c38b --- /dev/null +++ b/ABB-Manual-Assistant/test_scripts/test_search.py @@ -0,0 +1,27 @@ +# generated test file - only search tool + +import os +import sys + + +# Add parent directory to sys.path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +import asyncio + +from search_tool import Weaviate + + +async def test(): + try: + weaviate_search = Weaviate() + test_query = "Spot application weld error reported" + result = await weaviate_search.get_knowledge(test_query) + print("[TEST RESULT]") + print(result if result else "No result returned.") + except Exception as e: + print(f"[TEST ERROR] {e}") + + +if __name__ == "__main__": + asyncio.run(test()) diff --git a/ABB-Manual-Assistant/test_searchagent.py b/ABB-Manual-Assistant/test_searchagent.py new file mode 100644 index 0000000..75458cc --- /dev/null +++ b/ABB-Manual-Assistant/test_searchagent.py @@ -0,0 +1,18 @@ +import asyncio + +from search_agent import SearchAgent # Replace with actual import path + + +async def test_error_code_query(): + agent = SearchAgent() + query = "What does error code 10039 mean in ABB robot manuals?" + + print("Running test query...") + result = await agent.run(query) + + print("\n=== Test Result ===") + print(result) + + +if __name__ == "__main__": + asyncio.run(test_error_code_query()) diff --git a/ABB-Manual-Assistant/test_searchtool.py b/ABB-Manual-Assistant/test_searchtool.py new file mode 100644 index 0000000..66ec2ec --- /dev/null +++ b/ABB-Manual-Assistant/test_searchtool.py @@ -0,0 +1,13 @@ +import asyncio + +from search_tool import Weaviate # Ensure this file exists in the same directory + + +async def main(): + search_tool = Weaviate() + result = await search_tool.get_knowledge("error code 10039") + print(result) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/ABB-Manual-Assistant/ui.py b/ABB-Manual-Assistant/ui.py new file mode 100644 index 0000000..7e07208 --- /dev/null +++ b/ABB-Manual-Assistant/ui.py @@ -0,0 +1,274 @@ +from __future__ import annotations + +# Gradio powers the UI +import gradio as gr + +# Your local persistence layer (SQLite wrapper you already have) +from agent_utils.memory_store import MemoryStore + +# Our tiny “agent” that owns conversation state and history (lives in project root) +from conversation_manager import ConversationManagerAgent + +# Orchestrator remains the single “brain” that decides how to answer and streams tokens/chunks +from orchestrator_agent import Orchestrator + +# Your tracing setup (left intact so nothing breaks) +from utils import setup_langfuse_tracer +from utils.langfuse.shared_client import langfuse_client # noqa: F401 (imported for tracer wiring) + + +# ----------------------------------------------------------------------------- +# Storage + manager agent +# ----------------------------------------------------------------------------- +store = MemoryStore() +conv_agent = ConversationManagerAgent(store) # The app talks to the conversation manager, not raw DB + + +# ----------------------------------------------------------------------------- +# Small helpers for the dropdown and history mapping +# ----------------------------------------------------------------------------- +def _label_for(id_: str, title: str) -> str: + """Pretty labels for the chat selector dropdown: 'Title · abc123'.""" + return f"{title} · {id_[:6]}" + + +def _choices_and_maps(): + """ + Build dropdown choices and mapping between labels and conversation IDs. + Ensures at least one conversation exists. + """ + rows = store.list_conversations() + if not rows: + store.create_conversation() + rows = store.list_conversations() + choices = [_label_for(r[0], r[1]) for r in rows] + id_by_label = {_label_for(r[0], r[1]): r[0] for r in rows} + label_by_id = {r[0]: _label_for(r[0], r[1]) for r in rows} + return choices, id_by_label, label_by_id, choices[0] # default = first + + +def _history_messages(cid: str): + """ + Read messages from the ConversationManagerAgent and convert to + Chatbot(type='messages') format: [{role, content}, ...] + """ + rows = conv_agent.get_history_messages(cid, limit=1000) + return [{"role": r["role"], "content": r["content"]} for r in rows] + + +# ----------------------------------------------------------------------------- +# The Gradio App +# ----------------------------------------------------------------------------- +class GradioApp: + def __init__(self): + # One Orchestrator instance for the whole app + self.orchestrator = Orchestrator() + + async def run_search_stream(self, message: str, chat_history): + """ + Delegate to Orchestrator and stream chunks back. + This is intentionally simple: the Orchestrator decides how to answer. + """ + async for chunk in self.orchestrator.run(message, chat_history): + # Each 'chunk' is a piece of the assistant's text (token/phrase/etc) + yield chunk + + def launch(self): + """ + Build and launch the Gradio UI. + """ + with gr.Blocks(title="ABB Knowledgebase Search") as demo: + # Top title + gr.Markdown("### ABB Knowledgebase Search — Multi-Chat (per-chat history)") + + # Global UI state: we store the currently selected conversation_id here + state = gr.State({}) + + with gr.Row(): + # ----------------------- Sidebar: Conversation management ----------------------- + with gr.Column(scale=1, min_width=320): + gr.Markdown("**Chats**") + + # Dropdown for selecting which conversation is active + chat_dropdown = gr.Dropdown(choices=[], value=None, label="Select chat", interactive=True) + + # Buttons/inputs for CRUD on conversations + new_title = gr.Textbox(label="New chat title", placeholder="Optional (auto if blank)") + btn_new = gr.Button("New chat") + + rename_to = gr.Textbox(label="Rename to") + btn_rename = gr.Button("Rename") + + btn_delete = gr.Button("Delete chat") + btn_refresh = gr.Button("Refresh list") + + # ----------------------- Main chat panel ----------------------- + with gr.Column(scale=3): + # Chatbot uses OpenAI-style message dicts when type='messages' + chatbot = gr.Chatbot(label="Chat", height=520, type="messages") + + # Input row: one-line textbox + Send button + with gr.Row(): + msg = gr.Textbox( + placeholder="Ask me something about ABB errors...", scale=9, label="Your Question", lines=1 + ) + send = gr.Button("Send", variant="primary", scale=1) + + # Some convenience example prompts + examples = [ + "What is Error code 10039 and possible solution?", + "What is a reference error?", + "We are getting a Motor phase short circuit. Where should we look?", + "How to set up an IRC5, and what is it?", + ] + gr.Examples(examples=examples, inputs=msg, label="Try one of these example questions:") + + # ----------------------- Callbacks ----------------------- + + # Initialize the UI on page load: + # - Populate dropdown + # - Set default conversation_id in state + # - Load that conversation's message history + def _init(): + choices, id_by_label, label_by_id, default_label = _choices_and_maps() + cid = id_by_label[default_label] + return ( + gr.update(choices=choices, value=default_label), # dropdown options + selection + {"conversation_id": cid}, # state + _history_messages(cid), # initial chat messages + ) + + demo.load(_init, inputs=None, outputs=[chat_dropdown, state, chatbot]) + + # When the user switches the dropdown, update the active conversation and show its history + def _select_chat(label, current_state): + choices, id_by_label, label_by_id, default_label = _choices_and_maps() + if not label or label not in id_by_label: + label = default_label + cid = id_by_label[label] + current_state = current_state or {} + current_state["conversation_id"] = cid + return current_state, _history_messages(cid) + + chat_dropdown.change(_select_chat, inputs=[chat_dropdown, state], outputs=[state, chatbot]) + + # The main send handler — this is an async generator that yields intermediate UI updates, + # so you see the assistant's reply grow inside the *same* chat bubble (streaming). + async def _send(user_text, current_state, current_msgs): + # Guard: ignore empty messages but keep UI outputs consistent + if not user_text or not user_text.strip(): + yield current_msgs, "", (current_state or {}) + return + + # Resolve the active conversation ID (default to first if none) + choices, id_by_label, label_by_id, default_label = _choices_and_maps() + cid = (current_state or {}).get("conversation_id") or id_by_label[default_label] + current_state = {"conversation_id": cid} + + # Persist the user's message immediately + conv_agent.save_user(cid, user_text) + + # Start from canonical history for this conversation + messages = _history_messages(cid) + + # Append user's new message + messages.append({"role": "user", "content": user_text}) + + chat_history = messages + + # Append an *empty* assistant bubble that we'll fill as chunks arrive + # Seed it with a quick typing indicator so the user sees a response is coming + messages.append({"role": "assistant", "content": "…"}) + # Push user + typing indicator to UI immediately + yield messages, "", current_state + + # Now stream the assistant reply into that last message in-place + partial = "" + try: + async for chunk in self.run_search_stream(user_text, chat_history): + partial += chunk # grow the assistant's partial reply + messages[-1]["content"] = partial # update the *same* assistant bubble + conv_agent.set_assistant_partial(cid, partial) # keep partial in RAM (not DB) + yield messages, "", current_state # push incremental update to the UI + finally: + # When stream ends (or errors), persist the final assistant message once + conv_agent.finalize_assistant(cid) + + # Reload canonical history from storage (ensures what you see is exactly what we saved) + messages = _history_messages(cid) + yield messages, "", current_state + + # Wire the Send button and Enter key to the same streaming handler + send.click(_send, inputs=[msg, state, chatbot], outputs=[chatbot, msg, state]) + msg.submit(_send, inputs=[msg, state, chatbot], outputs=[chatbot, msg, state]) + + # Create a new conversation; select it and show an empty history + def _new_chat(title): + res = conv_agent.create(title or None) + cid = res["conversation_id"] + choices, id_by_label, label_by_id, _ = _choices_and_maps() + label = label_by_id[cid] + return ( + gr.update(choices=choices, value=label), # select new chat in dropdown + {"conversation_id": cid}, # set state + _history_messages(cid), # show empty (or fresh) history + ) + + btn_new.click(_new_chat, inputs=new_title, outputs=[chat_dropdown, state, chatbot]) + + # Rename the current conversation; update the dropdown label + def _rename_chat(new_title, current_state): + choices, id_by_label, label_by_id, default_label = _choices_and_maps() + cid = (current_state or {}).get("conversation_id") or id_by_label[default_label] + if new_title and new_title.strip(): + conv_agent.rename(cid, new_title.strip()) + choices, id_by_label, label_by_id, _ = _choices_and_maps() + label = label_by_id[cid] + return gr.update(choices=choices, value=label) + + btn_rename.click(_rename_chat, inputs=rename_to, outputs=chat_dropdown) + + # Delete the current conversation; ensure one remains and switch to it + def _delete_chat(current_state): + choices, id_by_label, label_by_id, default_label = _choices_and_maps() + cid = (current_state or {}).get("conversation_id") or id_by_label[default_label] + conv_agent.delete(cid) + + # Recompute choices; pick default; load its history + choices, id_by_label, label_by_id, default_label = _choices_and_maps() + new_cid = id_by_label[default_label] + return ( + gr.update(choices=choices, value=default_label), + {"conversation_id": new_cid}, + _history_messages(new_cid), + ) + + btn_delete.click(_delete_chat, inputs=state, outputs=[chat_dropdown, state, chatbot]) + + # Just refresh the dropdown list, keeping the same selection if still valid + def _refresh_list(current_state): + choices, id_by_label, label_by_id, default_label = _choices_and_maps() + cid = (current_state or {}).get("conversation_id") + if cid and cid in label_by_id: + label = label_by_id[cid] + else: + label = default_label + return gr.update(choices=choices, value=label) + + btn_refresh.click(_refresh_list, inputs=state, outputs=chat_dropdown) + + # Start the Gradio server + demo.launch(server_name="0.0.0.0") + + +# ----------------------------------------------------------------------------- +# Entry point (kept as-is for your tracing + app start) +# ----------------------------------------------------------------------------- +def main(): + setup_langfuse_tracer() + app = GradioApp() + app.launch() + + +if __name__ == "__main__": + main() diff --git a/ABB-Manual-Assistant/upload_test_data.py b/ABB-Manual-Assistant/upload_test_data.py new file mode 100644 index 0000000..badb529 --- /dev/null +++ b/ABB-Manual-Assistant/upload_test_data.py @@ -0,0 +1,74 @@ +import pandas as pd +from dotenv import load_dotenv +from langfuse import get_client +from rich.progress import track + + +# Load environment variables from .env file +load_dotenv() + +# Initialize Langfuse client +langfuse = get_client() + +# Define the dataset name +dataset_name = "LLM_Judge_Errors" + +# Define the question-answer pairs +qa_pairs = [ + ( + "What is Error code 10039 and possible solution?", + "During startup, the system has found that data in the Serial Measurement Board (SMB) memory is not OK. All data must be OK before automatic operation is possible. Manually jogging the robot is possible. There are differences between the data stored on the SMB and the data stored in the controller. This may be due to replacement of SMB, controller or both. Possible solution is to update the Serial Measurement Board data.", + ), + ("How to fix SMB memory is not OK", "Update the Serial Measurement Board data."), + ( + "How to recover if axis computer has lost communication.", + "1) Check cable between the axis computer and the Safety System is intact and correctly connected.\n2) Check power supply connected to the Safety System.\n3) Make sure no extreme levels of electromagnetic interference are emitted close to the robot cabling.", + ), + ( + "What does error code 40038 mean?", + "It is a LOCAL illegal in routine variable declaration. Only program data declarations may have the LOCAL attribute. Remove the LOCAL attribute.", + ), + ( + "What is a reference error.", + "System should ask to specify what reference error number they are getting to better answer the question. There are many reference error", + ), + ( + "Why am I getting a programmed forced reduced error.", + "Programmed tip force too high for tool arg. Requested motor torque (Nm)= arg. Force was reduced to max motor torque.", + ), + ( + "SMB Data is missing. What should I do?", + "If proper data exists in cabinet - transfer the data to SMB-memory. If still problem - check communication cable to SMB-board. Replace SMB-board.", + ), + ( + "We are getting a Motor phase short circuit. Where should we look?", + "You have a short circuit in cables or connectors between the phases or to Ground or a Short circuit in motor between the phases or to ground. Check/replace cables and connectors. Check/replace motor.", + ), + ( + "Why am I getting a singularity problem", + "Depending on exact error number the problem is either in joint 4 or joint 6.", + ), + ( + "Why am I getting a joint not synchronized error and how to fix it.", + "The speed of joint arg before power down/failure was too high. Make a new update of the revolution counter.", + ), +] + +# Convert to DataFrame +df = pd.DataFrame(qa_pairs, columns=["question", "expected_answer"]) + +# Create the dataset in Langfuse +langfuse.create_dataset( + name=dataset_name, + description="Robot error troubleshooting Q&A dataset", + metadata={"type": "benchmark", "source": "manual_upload"}, +) + +# Upload each item +for idx, row in track(df.iterrows(), total=len(df), description="Uploading to Langfuse"): + langfuse.create_dataset_item( + dataset_name=dataset_name, + input={"text": row["question"]}, + expected_output={"text": row["expected_answer"]}, + id=f"llmjudge-{idx:03}", + ) diff --git a/ABB-Manual-Assistant/utils/__init__.py b/ABB-Manual-Assistant/utils/__init__.py new file mode 100644 index 0000000..d9a32f6 --- /dev/null +++ b/ABB-Manual-Assistant/utils/__init__.py @@ -0,0 +1,3 @@ +"""Shared toolings for reference implementations.""" + +from .langfuse.oai_sdk_setup import setup_langfuse_tracer diff --git a/ABB-Manual-Assistant/utils/async_utils.py b/ABB-Manual-Assistant/utils/async_utils.py new file mode 100644 index 0000000..707f571 --- /dev/null +++ b/ABB-Manual-Assistant/utils/async_utils.py @@ -0,0 +1,54 @@ +"""Utils for async workflows.""" + +import asyncio +import types +from typing import Any, Awaitable, Callable, Coroutine, Sequence, TypeVar + +from rich.progress import Progress + + +T = TypeVar("T") + + +async def indexed(index: int, coro: Coroutine[None, None, T]) -> tuple[int, T]: + """Return (index, await coro).""" + return index, (await coro) + + +async def rate_limited(_fn: Callable[[], Awaitable[T]], semaphore: asyncio.Semaphore) -> T: + """Run _fn with semaphore rate limit.""" + async with semaphore: + return await _fn() + + +async def gather_with_progress( + coros: "list[types.CoroutineType[Any, Any, T]]", + description: str = "Running tasks", +) -> Sequence[T]: + """ + Run a list of coroutines concurrently, display a rich.Progress bar as each finishes. + + Returns the results in the same order as the input list. + + :param coros: List of coroutines to run. + :return: List of results, ordered to match the input coroutines. + """ + # Wrap each coroutine in a Task and remember its original index + tasks = [asyncio.create_task(indexed(index=index, coro=coro)) for index, coro in enumerate(coros)] + + # Pre‐allocate a results list; we'll fill in each slot as its Task completes + results: list[T | None] = [None] * len(tasks) + + # Create and start a Progress bar with a total equal to the number of tasks + with Progress() as progress: + progress_task = progress.add_task(description, total=len(tasks)) + + # as_completed yields each Task as soon as it finishes + for finished in asyncio.as_completed(tasks): + index, result = await finished + results[index] = result + progress.update(progress_task, advance=1) + + # At this point, every slot in `results` is guaranteed to be non‐None + # so we can safely cast it back to List[T] + return results # type: ignore diff --git a/ABB-Manual-Assistant/utils/data/__init__.py b/ABB-Manual-Assistant/utils/data/__init__.py new file mode 100644 index 0000000..99eee29 --- /dev/null +++ b/ABB-Manual-Assistant/utils/data/__init__.py @@ -0,0 +1,4 @@ +from .load_dataset import get_dataset, get_dataset_url_hash + + +__all__ = ["get_dataset", "get_dataset_url_hash"] diff --git a/ABB-Manual-Assistant/utils/data/batching.py b/ABB-Manual-Assistant/utils/data/batching.py new file mode 100644 index 0000000..41f91a3 --- /dev/null +++ b/ABB-Manual-Assistant/utils/data/batching.py @@ -0,0 +1,38 @@ +"""Utils for creating batches of data for performance.""" + +from typing import TypeVar + + +V = TypeVar("V") + + +def create_batches( + items: list[V], + batch_size: int, + limit: int | None = None, + keep_trailing: bool = True, +) -> list[list[V]]: + """Transform the list of items into batches. + + Params: + limit: number of items to include in total + keep_trailing: if False, the last few items that + does not fit in a full batch will not be returned. + + Return: + List of batches. + """ + batches: list[list[V]] = [[]] + for _index, _item in enumerate(items): + if (limit is not None) and (_index >= limit): + break + + batches[-1].append(_item) + if len(batches[-1]) == batch_size: + batches.append([]) + + # Discard trailing batch if empty or required + if (len(batches[-1]) == 0) or ((not keep_trailing) and (len(batches[-1]) < batch_size)): + batches.pop(-1) + + return batches diff --git a/ABB-Manual-Assistant/utils/data/load_dataset.py b/ABB-Manual-Assistant/utils/data/load_dataset.py new file mode 100644 index 0000000..a40fb75 --- /dev/null +++ b/ABB-Manual-Assistant/utils/data/load_dataset.py @@ -0,0 +1,83 @@ +"""Logic for loading datasets.""" + +import hashlib +import os.path +import re + +import datasets +import pandas as pd +import pydantic + + +PATTERN = re.compile( + r"(?P[^:]+)://" + r"(?P[^:@]+)?" + r"(@(?P[a-f\d]+))?" + r"(\[(?P\w+)\])?" + r"(:(?P\w+))?$" +) + + +class _SourceInfo(pydantic.BaseModel): + provider: str + repo: str + version: str | None = None + subset: str | None = None + split: str = "train" + + @staticmethod + def _from_url(dataset_url: str) -> "_SourceInfo": + """Parse URL.""" + url_match = PATTERN.match(dataset_url) + dataset_info = _SourceInfo(**url_match.groupdict()) if url_match else None + if dataset_info is None: + raise ValueError("Invalid URL pattern. Should be {provider}://{path}[@{commit}]:{split}") + + return dataset_info + + +def get_dataset(dataset_url: str, limit: int | None = None) -> pd.DataFrame: + """Load dataset from the given URL. + + Params + ------ + dataset_url: in the following format: + {provider}://{path}[@{commit}][[subset]]:{split} + limit: optional; max number of items to include. + + Returns + ------- + Huggingface dataset instance. + """ + dataset_info = _SourceInfo._from_url(dataset_url) + if dataset_info.provider == "hf": + return _load_hf(dataset_info, limit=limit).to_pandas() # type: ignore + + raise ValueError(f"Dataset provider not supported: {dataset_info.provider}. Available options: hf") + + +def get_dataset_url_hash(dataset_url: str) -> str: + """Hash dataset url for attribution.""" + return hashlib.sha256(dataset_url.encode()).hexdigest()[:6] + + +def _load_hf(dataset_info: _SourceInfo, limit: int | None = None) -> datasets.Dataset: + """Load HF dataset.""" + # Prefer load_from_disk locally. + # If not possible, load from Hub or local snapshot using load_dataset. + if (dataset_info.version is None) and (os.path.exists(dataset_info.repo)): + try: + dataset_or_dict = datasets.load_from_disk(dataset_info.repo) + if dataset_info.split is not None: + return dataset_or_dict[dataset_info.split] # type: ignore + + return dataset_or_dict # type: ignore + + except FileNotFoundError: + pass # type: ignore + + split_name = dataset_info.split + if limit is not None: + split_name += f"[0:{limit}]" + + return datasets.load_dataset(dataset_info.repo, name=dataset_info.subset, split=split_name) # type: ignore diff --git a/ABB-Manual-Assistant/utils/env_vars.py b/ABB-Manual-Assistant/utils/env_vars.py new file mode 100644 index 0000000..6867bad --- /dev/null +++ b/ABB-Manual-Assistant/utils/env_vars.py @@ -0,0 +1,56 @@ +"""Interface for storing and accessing config env vars.""" + +from os import environ + +import pydantic + + +class Configs(pydantic.BaseModel): + """Type-friendly collection of env var configs.""" + + # Embeddings + embedding_base_url: str + embedding_api_key: str + + # Weaviate + weaviate_http_host: str + weaviate_grpc_host: str + weaviate_api_key: str + weaviate_http_port: int = 443 + weaviate_grpc_port: int = 443 + weaviate_http_secure: bool = True + weaviate_grpc_secure: bool = True + + # Langfuse + langfuse_public_key: str + langfuse_secret_key: str + langfuse_host: str = "https://us.cloud.langfuse.com" + + def _check_langfuse(self): + """Ensure that Langfuse pk and sk are in the right place.""" + if not self.langfuse_public_key.startswith("pk-lf-"): + raise ValueError("LANGFUSE_PUBLIC_KEY should start with pk-lf-") + + if not self.langfuse_secret_key.startswith("sk-lf-"): + raise ValueError("LANGFUSE_SECRET_KEY should start with sk-lf-") + + @staticmethod + def from_env_var() -> "Configs": + """Initialize from env vars.""" + # Add only config line items defined in Configs. + data: dict[str, str] = {} + for k, v in environ.items(): + _key = k.lower() + data[_key] = v + + try: + config = Configs(**data) + config._check_langfuse() + return config + + except pydantic.ValidationError as e: + raise ValueError( + "Some ENV VARs are missing. See above for details. " + "Try to load your .env file as follows: \n" + "```\nuv run --env-file .env -m ...\n```" + ) from e diff --git a/ABB-Manual-Assistant/utils/gradio/__init__.py b/ABB-Manual-Assistant/utils/gradio/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ABB-Manual-Assistant/utils/gradio/messages.py b/ABB-Manual-Assistant/utils/gradio/messages.py new file mode 100644 index 0000000..2fb2cb6 --- /dev/null +++ b/ABB-Manual-Assistant/utils/gradio/messages.py @@ -0,0 +1,146 @@ +"""Tools for integrating with the Gradio chatbot UI.""" + +from typing import TYPE_CHECKING + +from agents import StreamEvent, stream_events +from agents.items import MessageOutputItem, RunItem, ToolCallItem, ToolCallOutputItem +from gradio.components.chatbot import ChatMessage +from openai.types.responses import ( + ResponseCompletedEvent, + ResponseFunctionToolCall, + ResponseOutputMessage, + ResponseOutputText, +) + +from ..pretty_printing import pretty_print + + +if TYPE_CHECKING: + from openai.types.chat import ChatCompletionMessageParam + + +def gradio_messages_to_oai_chat( + messages: list[ChatMessage | dict], +) -> list["ChatCompletionMessageParam"]: + """Translate Gradio chat message history to OpenAI format.""" + output: list["ChatCompletionMessageParam"] = [] + for message in messages: + if isinstance(message, dict): + output.append(message) # type: ignore[arg-type] + continue + + message_content = message.content + assert isinstance(message_content, str), message_content + output.append({"role": message.role, "content": message_content}) # type: ignore[arg-type,misc] + + return output + + +def _oai_response_output_item_to_gradio(item: RunItem) -> list[ChatMessage] | None: + """Map OAI SDK new RunItem (response.new_items) to gr messages. + + Returns None if message is of unknown/unsupported type. + """ + print(type(item)) + pretty_print(item) + + if isinstance(item, ToolCallItem): + raw_item = item.raw_item + + if isinstance(raw_item, ResponseFunctionToolCall): + return [ + ChatMessage( + role="assistant", + content=f"```\n{raw_item.arguments}\n```\n`{raw_item.call_id}`", + metadata={ + "title": f"Used tool `{raw_item.name}`", + }, + ) + ] + + if isinstance(item, ToolCallOutputItem): + function_output = item.raw_item["output"] + call_id = item.raw_item.get("call_id", None) + + if isinstance(function_output, str): + return [ + ChatMessage( + role="assistant", + content=f"> {function_output}\n\n`{call_id}`", + metadata={ + "title": "Tool response", + }, + ) + ] + + if isinstance(item, MessageOutputItem): + message_content = item.raw_item + + output_texts: list[str] = [] + for response_text in message_content.content: + if isinstance(response_text, ResponseOutputText): + output_texts.append(response_text.text) + + return [ChatMessage(role="assistant", content=_text) for _text in output_texts] + + return None + + +def oai_agent_items_to_gradio_messages( + new_items: list[RunItem], +) -> list[ChatMessage]: + """Parse agent sdk "new items" into a list of gr messages. + + Adds extra data for tool use to make the gradio display informative. + """ + output: list[ChatMessage] = [] + for item in new_items: + maybe_messages = _oai_response_output_item_to_gradio(item) + if maybe_messages is not None: + output.extend(maybe_messages) + + return output + + +def oai_agent_stream_to_gradio_messages( + stream_event: StreamEvent, +) -> list[ChatMessage]: + """Parse agent sdk "stream event" into a list of gr messages. + + Adds extra data for tool use to make the gradio display informative. + """ + output: list[ChatMessage] = [] + if isinstance(stream_event, stream_events.RawResponsesStreamEvent): + data = stream_event.data + if isinstance(data, ResponseCompletedEvent): + for message in data.response.output: + if isinstance(message, ResponseOutputMessage): + for _item in message.content: + if isinstance(_item, ResponseOutputText): + output.append(ChatMessage(role="assistant", content=_item.text)) + + elif isinstance(message, ResponseFunctionToolCall): + output.append( + ChatMessage( + role="assistant", + content=f"```\n{message.arguments}\n```", + metadata={ + "title": f"Used tool `{message.name}`", + }, + ) + ) + elif isinstance(stream_event, stream_events.RunItemStreamEvent): + name = stream_event.name + item = stream_event.item + if name == "tool_output" and isinstance(item, ToolCallOutputItem): + output.append( + ChatMessage( + role="assistant", + content=f"```\n{item.output}\n```", + metadata={ + "title": "*Tool call output*", + }, + ) + ) + + return output diff --git a/ABB-Manual-Assistant/utils/langfuse/oai_sdk_setup.py b/ABB-Manual-Assistant/utils/langfuse/oai_sdk_setup.py new file mode 100644 index 0000000..8432cc3 --- /dev/null +++ b/ABB-Manual-Assistant/utils/langfuse/oai_sdk_setup.py @@ -0,0 +1,42 @@ +"""Utils for redirecting OpenAI Agent SDK traces to LangFuse via OpenTelemetry. + +Full documentation: +langfuse.com/docs/integrations/openaiagentssdk/openai-agents +""" + +import logfire +import nest_asyncio +from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor + +from .otlp_env_setup import set_up_langfuse_otlp_env_vars + + +def configure_oai_agents_sdk(service_name: str) -> None: + """Register Langfuse as tracing provider for OAI Agents SDK.""" + nest_asyncio.apply() + logfire.configure(service_name=service_name, send_to_logfire=False, scrubbing=False) + logfire.instrument_openai_agents() + + +def setup_langfuse_tracer(service_name: str = "agents_sdk") -> "trace.Tracer": + """Register Langfuse as the default tracing provider and return tracer. + + Returns + ------- + tracer: OpenTelemetry Tracer + """ + set_up_langfuse_otlp_env_vars() + configure_oai_agents_sdk(service_name) + + # Create a TracerProvider for OpenTelemetry + trace_provider = TracerProvider() + + # Add a SimpleSpanProcessor with the OTLPSpanExporter to send traces + trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter())) + + # Set the global default tracer provider + trace.set_tracer_provider(trace_provider) + return trace.get_tracer(__name__) diff --git a/ABB-Manual-Assistant/utils/langfuse/otlp_env_setup.py b/ABB-Manual-Assistant/utils/langfuse/otlp_env_setup.py new file mode 100644 index 0000000..6adb35d --- /dev/null +++ b/ABB-Manual-Assistant/utils/langfuse/otlp_env_setup.py @@ -0,0 +1,27 @@ +"""Set up environment variables for LangFuse integration.""" + +import base64 +import logging +import os + +from ..env_vars import Configs + + +def set_up_langfuse_otlp_env_vars(): + """Set up environment variables for Langfuse OpenTelemetry integration. + + OTLP = OpenTelemetry Protocol. + + This function updates environment variables. + + Also refer to: + langfuse.com/docs/integrations/openaiagentssdk/openai-agents + """ + configs = Configs.from_env_var() + + langfuse_auth = base64.b64encode(f"{configs.langfuse_public_key}:{configs.langfuse_secret_key}".encode()).decode() + + os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = configs.langfuse_host + "/api/public/otel" + os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"Authorization=Basic {langfuse_auth}" + + logging.info(f"Langfuse host: {configs.langfuse_host}") diff --git a/ABB-Manual-Assistant/utils/langfuse/shared_client.py b/ABB-Manual-Assistant/utils/langfuse/shared_client.py new file mode 100644 index 0000000..0366328 --- /dev/null +++ b/ABB-Manual-Assistant/utils/langfuse/shared_client.py @@ -0,0 +1,30 @@ +"""Shared instance of langfuse client.""" + +from os import getenv + +from langfuse import Langfuse +from rich.progress import Progress, SpinnerColumn, TextColumn + +from ..env_vars import Configs + + +__all__ = ["langfuse_client"] + + +config = Configs.from_env_var() +assert getenv("LANGFUSE_PUBLIC_KEY") is not None +langfuse_client = Langfuse(public_key=config.langfuse_public_key, secret_key=config.langfuse_secret_key) + + +def flush_langfuse(client: "Langfuse | None" = None): + """Flush shared LangFuse Client. Rich Progress included.""" + if client is None: + client = langfuse_client + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + transient=True, + ) as progress: + progress.add_task("Finalizing Langfuse annotations...", total=None) + langfuse_client.flush() diff --git a/ABB-Manual-Assistant/utils/langfuse/trace_id.py b/ABB-Manual-Assistant/utils/langfuse/trace_id.py new file mode 100644 index 0000000..e1b93fa --- /dev/null +++ b/ABB-Manual-Assistant/utils/langfuse/trace_id.py @@ -0,0 +1,11 @@ +""" +Obtain trace_id, required for linking trace to dataset row. + +Full documentation: +langfuse.com/docs/integrations/openaiagentssdk/example-evaluating-openai-agents +running-the-agent-on-the-dataset +""" + + +def get_langfuse_trace_id(): + """Obtain "formatted" trace_id for LangFuse.""" diff --git a/ABB-Manual-Assistant/utils/logging.py b/ABB-Manual-Assistant/utils/logging.py new file mode 100644 index 0000000..492d9e7 --- /dev/null +++ b/ABB-Manual-Assistant/utils/logging.py @@ -0,0 +1,35 @@ +"""Set up logging, warning, etc.""" + +import logging +import warnings + + +class IgnoreOpenAI401Filter(logging.Filter): + """ + A logging filter that excludes specific OpenAI client error messages. + + Filters out: 'ERROR:openai.agents:[non-fatal] Tracing client error 401' + """ + + def filter(self, record: logging.LogRecord) -> bool: + """Define filter logic.""" + msg = record.getMessage() + return not ( + record.levelname == "ERROR" + and record.name == "openai.agents" + and "[non-fatal] Tracing client error 401" in msg + ) + + +def set_up_logging(): + """Set up Logging and Warning levels.""" + root_logger = logging.getLogger() + filter_ = IgnoreOpenAI401Filter() + + if not root_logger.handlers: + logging.basicConfig() + + for handler in root_logger.handlers: + handler.addFilter(filter_) + + warnings.filterwarnings("ignore", category=ResourceWarning) diff --git a/ABB-Manual-Assistant/utils/pretty_printing.py b/ABB-Manual-Assistant/utils/pretty_printing.py new file mode 100644 index 0000000..76d1554 --- /dev/null +++ b/ABB-Manual-Assistant/utils/pretty_printing.py @@ -0,0 +1,34 @@ +"""Pretty-Print Utils.""" + +import json +from typing import Any + +import pydantic + + +def _serializer(item: Any) -> dict[str, Any] | str: + """Serialize using heuristics.""" + if isinstance(item, pydantic.BaseModel): + return item.model_dump() + return str(item) + + +def pretty_print(data: Any) -> str: + """Extract and JSON-dump only the 'properties' field from Weaviate result objects.""" + try: + if isinstance(data, list): + properties_list = [] + for obj in data: + if hasattr(obj, "properties"): + properties_list.append(obj.properties) + else: + properties_list.append(_serializer(obj)) + else: + properties_list = [getattr(data, "properties", _serializer(data))] + + output = json.dumps(properties_list, indent=2) + print(output) + return output + + except Exception as e: + return f"Error during pretty print: {e}" diff --git a/ABB-Manual-Assistant/utils/tools/README.md b/ABB-Manual-Assistant/utils/tools/README.md new file mode 100644 index 0000000..55666d2 --- /dev/null +++ b/ABB-Manual-Assistant/utils/tools/README.md @@ -0,0 +1,8 @@ +# Tools for Agents + +This module contains various tools for LLM agents. + +```bash +# Tool for getting a list of recent news headlines from enwiki +uv run -m src.utils.tools.news_events +``` diff --git a/ABB-Manual-Assistant/utils/tools/__init__.py b/ABB-Manual-Assistant/utils/tools/__init__.py new file mode 100644 index 0000000..b1514d1 --- /dev/null +++ b/ABB-Manual-Assistant/utils/tools/__init__.py @@ -0,0 +1,2 @@ +from .kb_weaviate import AsyncWeaviateKnowledgeBase, get_weaviate_async_client +from .news_events import get_news_events diff --git a/ABB-Manual-Assistant/utils/tools/code_interpreter.py b/ABB-Manual-Assistant/utils/tools/code_interpreter.py new file mode 100644 index 0000000..957d9c4 --- /dev/null +++ b/ABB-Manual-Assistant/utils/tools/code_interpreter.py @@ -0,0 +1,121 @@ +"""Code interpreter tool.""" + +from pathlib import Path +from typing import Sequence + +from e2b_code_interpreter import AsyncSandbox +from pydantic import BaseModel + +from ..async_utils import gather_with_progress + + +class _CodeInterpreterOutputError(BaseModel): + """Error from code interpreter.""" + + name: str + value: str + traceback: str + + +class CodeInterpreterOutput(BaseModel): + """Output from code interpreter.""" + + stdout: list[str] + stderr: list[str] + error: _CodeInterpreterOutputError | None = None + + def __init__(self, stdout: list[str], stderr: list[str], **kwargs): + """Split lines in stdout and stderr.""" + stdout_processed = [] + for _line in stdout: + stdout_processed.extend(_line.splitlines()) + + stderr_processed = [] + for _line in stderr: + stderr_processed.extend(_line.splitlines()) + + super().__init__(stdout=stdout_processed, stderr=stderr_processed, **kwargs) + + +async def _upload_file(sandbox: "AsyncSandbox", local_path: "str | Path") -> str: + """Upload file to sandbox. + + Returns + ------- + str, denoting the remote path. + """ + path = Path(local_path) + remote_path = f"{path.name}" + with open(local_path, "rb") as file: + await sandbox.files.write(remote_path, file) + + return remote_path + + +async def _upload_files(sandbox: "AsyncSandbox", paths: Sequence[Path | str]) -> list[str]: + """Upload files to the sandbox. + + Parameters + ---------- + paths: Sequence[pathlib.Path | str] + Files to upload to the sandbox. + + Returns + ------- + list[str] + List of remote paths, one per file. + """ + if not paths: + return [] + + file_upload_coros = [_upload_file(sandbox, _path) for _path in paths] + remote_paths = await gather_with_progress(file_upload_coros, description=f"Uploading {len(paths)} to sandbox") + return list(remote_paths) + + +class CodeInterpreter: + """Code Interpreter tool for the agent.""" + + def __init__( + self, + local_files: "Sequence[Path | str]| None" = None, + timeout_seconds: int = 30, + ): + """Configure your Code Interpreter session. + + Note that the sandbox is not persistent, and each run_code will + execute in a fresh sandbox! (e.g., variables need to be re-declared each time.) + + Parameters + ---------- + local_files : list[pathlib.Path | str] | None + Optionally, specify a list of local files (as paths) + to upload to sandbox working directory. + timeout_seconds : int + Limit executions to this duration. + """ + self.timeout_seconds = timeout_seconds + self.local_files = local_files if local_files else [] + + async def run_code(self, code: str) -> str: + """Run the given Python code in a sandbox environment. + + Parameters + ---------- + code : str + Python logic to execute. + """ + sbx = await AsyncSandbox.create(timeout=self.timeout_seconds) + await _upload_files(sbx, self.local_files) + + try: + result = await sbx.run_code(code, on_error=lambda error: print(error.traceback)) + response = CodeInterpreterOutput.model_validate_json(result.logs.to_json()) + + error = result.error + if error is not None: + response.error = _CodeInterpreterOutputError.model_validate_json(error.to_json()) + + return response.model_dump_json() + finally: + await sbx.kill() diff --git a/ABB-Manual-Assistant/utils/tools/kb_weaviate.py b/ABB-Manual-Assistant/utils/tools/kb_weaviate.py new file mode 100644 index 0000000..e69861d --- /dev/null +++ b/ABB-Manual-Assistant/utils/tools/kb_weaviate.py @@ -0,0 +1,206 @@ +"""Implements knowledge retrieval tool for Weaviate.""" + +import asyncio +import logging +import os + +import backoff +import openai +import pydantic +import weaviate +from weaviate import WeaviateAsyncClient +from weaviate.config import AdditionalConfig + +from ..async_utils import rate_limited + + +class _Source(pydantic.BaseModel): + """Type hints for the "_source" field in ES Search Results.""" + + title: str + section: str | None = None + + +class _Highlight(pydantic.BaseModel): + """Type hints for the "highlight" field in ES Search Results.""" + + text: list[str] + + +class _SearchResult(pydantic.BaseModel): + """Type hints for knowledge base search result.""" + + source: _Source = pydantic.Field(alias="_source") + highlight: _Highlight + + def __repr__(self) -> str: + return self.model_dump_json(indent=2) + + +SearchResults = list[_SearchResult] + + +class AsyncWeaviateKnowledgeBase: + """Configurable search tools for Weaviate knowledge base.""" + + def __init__( + self, + async_client: WeaviateAsyncClient, + collection_name: str, + num_results: int = 5, + snippet_length: int = 1000, + max_concurrency: int = 3, + embedding_model_name: str = "@cf/baai/bge-m3", + embedding_api_key: str | None = None, + embedding_base_url: str | None = None, + ) -> None: + self.async_client = async_client + self.collection_name = collection_name + self.num_results = num_results + self.snippet_length = snippet_length + self.logger = logging.getLogger(__name__) + self.semaphore = asyncio.Semaphore(max_concurrency) + + self.embedding_model_name = embedding_model_name + self.embedding_api_key = embedding_api_key + self.embedding_base_url = embedding_base_url + + self._embed_client = openai.OpenAI( + api_key=self.embedding_api_key or os.getenv("EMBEDDING_API_KEY"), + base_url=self.embedding_base_url or os.getenv("EMBEDDING_BASE_URL"), + max_retries=5, + ) + + @backoff.on_exception(backoff.expo, exception=asyncio.CancelledError) # type: ignore + async def search_knowledgebase(self, keyword: str) -> SearchResults: + """Search knowledge base. + + Parameters + ---------- + keyword : str + The search keyword to query the knowledge base. + + Returns + ------- + SearchResults + A list of search results. Each result contains source and highlight. + If no results are found, returns an empty list. + + Raises + ------ + Exception + If Weaviate is not ready to accept requests (HTTP 503). + + """ + async with self.async_client: + if not await self.async_client.is_ready(): + raise Exception("Weaviate is not ready to accept requests (HTTP 503).") + + collection = self.async_client.collections.get(self.collection_name) + vector = self._vectorize(keyword) + response = await rate_limited( + lambda: collection.query.hybrid(keyword, vector=vector, limit=self.num_results), + semaphore=self.semaphore, + ) + + self.logger.info(f"Query: {keyword}; Returned matches: {len(response.objects)}") + + hits = [] + for obj in response.objects: + hit = { + "_source": { + "title": obj.properties.get("title", ""), + "section": obj.properties.get("section", None), + }, + "highlight": {"text": [obj.properties.get("text", "")[: self.snippet_length]]}, + } + hits.append(hit) + + return [_SearchResult.model_validate(_hit) for _hit in hits] + + def _vectorize(self, text: str) -> list[float]: + """Vectorize text using the embedding client. + + Parameters + ---------- + text : str + The text to be vectorized. + + Returns + ------- + list[float] + A list of floats representing the vectorized text. + """ + response = self._embed_client.embeddings.create(input=text, model=self.embedding_model_name) + return response.data[0].embedding + + +def get_weaviate_async_client( + http_host: str | None = None, + http_port: int | None = None, + http_secure: bool = False, + grpc_host: str | None = None, + grpc_port: int | None = None, + grpc_secure: bool = False, + api_key: str | None = None, + headers: dict[str, str] | None = None, + additional_config: AdditionalConfig | None = None, + skip_init_checks: bool = False, +) -> WeaviateAsyncClient: + """Get an async Weaviate client. + + If no parameters are provided, the function will attempt to connect to a local + Weaviate instance using environment variables. + + Parameters + ---------- + http_host : str, optional, default=None + The HTTP host for the Weaviate instance. If not provided, defaults to the + `WEAVIATE_HTTP_HOST` environment variable or "localhost" if the environment + variable is not set. + http_port : int, optional, default=None + The HTTP port for the Weaviate instance. If not provided, defaults to the + `WEAVIATE_HTTP_PORT` environment variable or 8080 if the environment variable + is not set. + http_secure : bool, optional, default=False + Whether to use HTTPS for the HTTP connection. Defaults to the + `WEAVIATE_HTTP_SECURE` environment variable or `False` if the environment + variable is not set. + grpc_host : str, optional, default=None + The gRPC host for the Weaviate instance. If not provided, defaults to the + `WEAVIATE_GRPC_HOST` environment variable or "localhost" if the environment + variable is not set. + grpc_port : int, optional, default=None + The gRPC port for the Weaviate instance. If not provided, defaults to the + `WEAVIATE_GRPC_PORT` environment variable or 50051 if the environment variable + is not set. + grpc_secure : bool, optional, default=False + Whether to use secure gRPC. Defaults to the `WEAVIATE_GRPC_SECURE` environment + variable or `False` if the environment variable is not set. + api_key : str, optional, default=None + The API key for authentication with Weaviate. If not provided, defaults to the + `WEAVIATE_API_KEY` environment variable. + headers : dict[str, str], optional, default=None + Additional headers to include in the request. + additional_config : AdditionalConfig, optional, default=None + Additional configuration for the Weaviate client. + skip_init_checks : bool, optional, default=False + Whether to skip initialization checks. + + Returns + ------- + WeaviateAsyncClient + An asynchronous Weaviate client configured with the provided parameters. + """ + return weaviate.use_async_with_custom( + http_host=http_host or os.getenv("WEAVIATE_HTTP_HOST", "localhost"), + http_port=http_port or int(os.getenv("WEAVIATE_HTTP_PORT", "8080")), + http_secure=http_secure or os.getenv("WEAVIATE_HTTP_SECURE", "false").lower() == "true", + grpc_host=grpc_host or os.getenv("WEAVIATE_GRPC_HOST", "localhost"), + grpc_port=grpc_port or int(os.getenv("WEAVIATE_GRPC_PORT", "50051")), + grpc_secure=grpc_secure or os.getenv("WEAVIATE_GRPC_SECURE", "false").lower() == "true", + auth_credentials=api_key or os.getenv("WEAVIATE_API_KEY"), + headers=headers, + additional_config=additional_config, + skip_init_checks=skip_init_checks, + ) diff --git a/ABB-Manual-Assistant/utils/tools/news_events.py b/ABB-Manual-Assistant/utils/tools/news_events.py new file mode 100644 index 0000000..fe363bc --- /dev/null +++ b/ABB-Manual-Assistant/utils/tools/news_events.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +"""Fetch and parse Wikipedia Current Events into structured data using Pydantic.""" + +from __future__ import annotations + +import argparse +import asyncio +import random +from collections import defaultdict +from datetime import date, timedelta +from typing import Any + +import httpx +from bs4 import BeautifulSoup +from pydantic import BaseModel, RootModel +from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn + + +class NewsEvent(BaseModel): + """Represents a single current event item.""" + + date: date + category: str + description: str + + +class CurrentEvents(RootModel): + """Mapping of event category to a list of Event items.""" + + root: dict[str, list[NewsEvent]] + + +async def _fetch_current_events_html() -> str: + """ + Retrieve the HTML for the Wikipedia Current Events page for a given date. + + Returns + ------- + Raw HTML string of the parsed page. + """ + # pick a random month between January and May + # (the knowledge base is not updated after May 30, 2025) + # and a random day in that month + random.seed(42) + random_date = date(2025, 1, 1) + timedelta(days=random.randint(0, (date(2025, 5, 20) - date(2025, 1, 1)).days)) + # convert to Year_Month_day format (example: 2025_May_6) + date_str = random_date.strftime("%Y_%B_%d") + + api_url = "https://en.wikipedia.org/w/api.php" + params = { + "action": "parse", + "page": f"Portal:Current_events/{date_str}", + "prop": "text", + "format": "json", + } + client = httpx.AsyncClient() + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + TimeElapsedColumn(), + ) as progress: + progress.add_task("GET wikipedia/Portal:Current_events...") + resp = await client.get(api_url, params=params) + + resp.raise_for_status() + data = resp.json() + return data["parse"]["text"]["*"] + + +def _parse_current_events(html: str) -> dict[str, list[NewsEvent]]: + """ + Parse the HTML of the Wikipedia Current Events portal and extract a list of events. + + Args: + html: The HTML content of the portal or date subpage. + + Returns + ------- + A dict mapping category -> list of Events + """ + soup = BeautifulSoup(html, "lxml") + events_by_category: dict[str, list[NewsEvent]] = defaultdict(list) + # Find each date block + date_divs = soup.find_all("div", class_="current-events-main vevent") + + for date_div in date_divs: + date_div: Any + # Extract ISO date + date_span = date_div.find("span", class_="bday") + date_str = date_span.get_text(strip=True) if date_span else "" + + # Find the content section + content_div = date_div.find("div", class_="current-events-content") + if not content_div: + continue + + # Iterate through each category heading and its events + for p_tag in content_div.find_all("p"): + b_tag = p_tag.find("b") + if not b_tag: + continue + category = b_tag.get_text(strip=True) + + # The next sibling
    contains the list of events for this category + ul = p_tag.find_next_sibling(lambda tag: tag.name == "ul") + if not ul: + continue + + # Iterate top-level list items as individual events + for li in ul.find_all("li", recursive=False): + # Join all text fragments for a clean description + description = " ".join(li.stripped_strings) + events_by_category[category].append( + NewsEvent( + date=date.fromisoformat(date_str), + category=category, + description=description, + ) + ) + + return events_by_category + + +async def get_news_events() -> CurrentEvents: + """Return a list of current news events from the English Wikipedia. + + Returns + ------- + dict mapping category of news events to list of news headlines. + """ + html = await _fetch_current_events_html() + events_dict = _parse_current_events(html) + + return CurrentEvents.model_validate(events_dict) + + +async def main() -> None: + """Fetch, parse, and output events as JSON.""" + parser = argparse.ArgumentParser(description="Fetch and parse Wikipedia Current Events into structured JSON.") + parser.add_argument("--output", "-o", help="Output JSON file path (default: stdout)") + args = parser.parse_args() + + news_events = await get_news_events() + output = news_events.model_dump_json(indent=2) + + if args.output: + with open(args.output, "w", encoding="utf-8") as f: + f.write(output) + else: + print(output) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/ABB-Manual-Assistant/utils/trees.py b/ABB-Manual-Assistant/utils/trees.py new file mode 100644 index 0000000..e5dade0 --- /dev/null +++ b/ABB-Manual-Assistant/utils/trees.py @@ -0,0 +1,24 @@ +"""Utils for handling nested dict.""" + +from typing import Any, Callable, TypeVar + + +Tree = TypeVar("Tree", bound=dict) + + +def tree_filter( + data: Tree, + criteria_fn: Callable[[Any], bool] = lambda x: x is not None, +) -> Tree: + """Keep only leaves for which criteria is True. + + Filters out None leaves if criteria is not specified. + """ + output: Tree = {} # type: ignore[reportAssignType] + for k, v in data.items(): + if isinstance(v, dict): + output[k] = tree_filter(v, criteria_fn=criteria_fn) + elif criteria_fn(v): + output[k] = v + + return output diff --git a/ABB-Manual-Assistant/workorder_agent.py b/ABB-Manual-Assistant/workorder_agent.py new file mode 100644 index 0000000..5700c6a --- /dev/null +++ b/ABB-Manual-Assistant/workorder_agent.py @@ -0,0 +1,35 @@ +import os + +import agents +from dotenv import load_dotenv +from openai import AsyncOpenAI + + +load_dotenv() + + +class WorkorderAgent: + def __init__(self): + self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL")) + + self.workorder_agent = agents.Agent( + name="Workorder Agent", + instructions=""" + Given a conversation with an ABB robot manual assistant agent you should create a workorder that states: + + 1) Workorder Title: A descriptive title for the workorder. + 2) Error/Issue: the error or issue that occured. + 3) Work completed: the work/action that the user has taken or will take to resolve the error. + + Only use information from the user's conversation with the ABB robot assistant agent to complete the workorder. + Your response should include only the created workorder and nothing else. Provide as much detail as possible. + """, + model=agents.OpenAIChatCompletionsModel( + model="gemini-2.5-flash-lite-preview-06-17", openai_client=self.client + ), + model_settings=agents.ModelSettings(temperature=0.5), + ) + + async def run(self, prompt: str) -> str: + response = await agents.Runner.run(self.workorder_agent, input=prompt) + return response