Skip to content

Commit 8761607

Browse files
committed
feat: migrate agent to LangGraph and bump version to 2.1.0
- Migrate the agent implementation from LangChain's legacy `AgentExecutor` to LangGraph's `create_react_agent`.
- Update dependencies in `pyproject.toml` to include `langgraph` and newer versions of `langchain`.
- Refactor `create_commit_agent` to handle system prompt injection within the graph-based architecture.
- Set a default value for the `deep` flag in the CLI.
- Update unit tests to support the new agent execution flow and response structure.
1 parent 6ff78fd commit 8761607

6 files changed

Lines changed: 73 additions & 33 deletions

File tree

commitai/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
# This __version__ string is read by hatchling during the build process
55
# Make sure to update it for new releases.
6-
__version__ = "2.0.0"
6+
__version__ = "2.1.0"
77

88
# The importlib.metadata approach is generally for reading the version
99
# of an *already installed* package at runtime. We don't need it here

commitai/agent.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import subprocess
44
from typing import Any, Dict, Type
55

6-
from langchain.agents import AgentExecutor, create_tool_calling_agent
6+
# from langchain.agents import AgentExecutor, create_tool_calling_agent # Removed
77
from langchain_core.language_models import BaseChatModel
8-
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
98
from langchain_core.runnables import Runnable
109
from langchain_core.tools import BaseTool
10+
from langgraph.prebuilt import create_react_agent
1111
from pydantic import BaseModel, Field
1212

1313
# --- TOOLS ---
@@ -157,6 +157,9 @@ def process(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
157157
# --- AGENT ---
158158

159159

160+
# --- AGENT ---
161+
162+
160163
def create_commit_agent(llm: BaseChatModel) -> Runnable:
161164
# 1. Init Tools
162165
tools = [ReadOnlyShellTool(), FileSearchTool(), FileReadTool()]
@@ -186,18 +189,22 @@ def create_commit_agent(llm: BaseChatModel) -> Runnable:
186189
3. If clarification is needed, explore files.
187190
4. Final Answer MUST be ONLY the commit message.
188191
"""
189-
prompt = ChatPromptTemplate.from_messages(
190-
[
191-
("system", system_prompt),
192-
MessagesPlaceholder("chat_history", optional=True),
193-
("human", "Generate the commit message."),
194-
MessagesPlaceholder("agent_scratchpad"),
195-
]
196-
)
192+
# Note: create_react_agent handles the prompt internally or via state_modifier.
193+
# We can pass a system string or a function. Since our prompt depends on dynamic
194+
# variables (diff, explanation, etc.), we need to inject them. LangGraph's
195+
# prebuilt agent usually takes a static system message. However, we can use the
196+
# 'messages' state. But to keep it simple and compatible with existing 'invoke'
197+
# interface: We will format the system prompt in the wrapper and pass it as the
198+
# first message.
199+
200+
# Actually, create_react_agent supports 'state_modifier'.
201+
# If we pass a formatted string, it works as system prompt.
202+
203+
# 4. Construct Graph
204+
# We don't construct the graph with ALL variables pre-bound if they change per run.
205+
# Instead, we'll format the prompt in the pipeline and pass it to the agent.
197206

198-
# 4. Construct Agent
199-
agent = create_tool_calling_agent(llm, tools, prompt)
200-
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=False)
207+
agent_graph = create_react_agent(llm, tools)
201208

202209
# 5. Pipeline with Middleware
203210
def run_pipeline(inputs: Dict[str, Any]) -> str:
@@ -210,11 +217,34 @@ def run_pipeline(inputs: Dict[str, Any]) -> str:
210217
state.setdefault("explanation", "None")
211218
state.setdefault("summary", "None")
212219
state.setdefault("todo_str", "None")
213-
state.setdefault("chat_history", [])
220+
221+
# Format System Prompt
222+
formatted_system_prompt = system_prompt.format(
223+
explanation=state["explanation"],
224+
todo_str=state["todo_str"],
225+
summary=state["summary"],
226+
diff=state.get("diff", ""),
227+
)
214228

215229
# Run Agent
216-
result = agent_executor.invoke(state)
217-
return str(result["output"])
230+
# LangGraph inputs: {"messages": [{"role": "user", "content": ...}]}
231+
# We inject the system prompt as a SystemMessage or just update the state.
232+
# create_react_agent primarily looks at 'messages'.
233+
234+
from langchain_core.messages import HumanMessage, SystemMessage
235+
236+
messages = [
237+
SystemMessage(content=formatted_system_prompt),
238+
HumanMessage(content="Generate the commit message."),
239+
]
240+
241+
# Invoke graph
242+
# result is a dict with 'messages'
243+
result = agent_graph.invoke({"messages": messages})
244+
245+
# Extract last message content
246+
last_message = result["messages"][-1]
247+
return str(last_message.content)
218248

219249
# Wrap in RunnableLambda to expose 'invoke'
220250
from langchain_core.runnables import RunnableLambda

commitai/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def generate_message(
197197
template: Optional[str],
198198
add: bool,
199199
model: str,
200-
deep: bool,
200+
deep: bool = False,
201201
) -> None:
202202
explanation = " ".join(description)
203203

pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "hatchling.build"
77
name = "commitai"
88
# Make sure to update version in commitai/__init__.py as well
99

10-
version = "2.0.0"
10+
version = "2.1.0"
1111

1212
description = "Commitai helps you generate git commit messages using AI"
1313
readme = "README.md"
@@ -30,10 +30,11 @@ classifiers = [
3030
]
3131
dependencies = [
3232
"click>=8.1",
33-
"langchain>=0.2",
33+
"langchain>=1.0",
34+
"langchain-community>=0.4.0",
3435
"langchain-core>=1.0",
35-
"langchain-community>=0.2",
3636
"langchain-google-genai>=1.0",
37+
"langgraph>=1.0.0",
3738
"pydantic>=2.0",
3839
]
3940

tests/test_agent_chains.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,25 @@ def test_todo_scanner_middleware(mock_llm):
6262

6363
def test_create_commit_agent(mock_llm):
6464
# This tests the factory function
65-
agent_executor = create_commit_agent(mock_llm)
66-
assert agent_executor is not None
67-
# We can try to invoke it if we mock enough stuff
68-
# But just creating it covers the definition lines.
65+
# Mocking create_react_agent to avoid actual graph compilation
66+
with patch("commitai.agent.create_react_agent") as mock_create_graph:
67+
mock_graph = MagicMock()
68+
mock_create_graph.return_value = mock_graph
69+
agent_executor = create_commit_agent(mock_llm)
70+
assert agent_executor is not None
6971

7072

7173
def test_agent_run(mock_llm):
7274
# E2E-ish test of the agent logic with mocks
73-
with patch("commitai.agent.AgentExecutor") as MockExecutor:
74-
mock_executor_instance = MockExecutor.return_value
75-
mock_executor_instance.invoke.return_value = {"output": "Final Commit Message"}
75+
with patch("commitai.agent.create_react_agent") as mock_create_graph:
76+
mock_graph = MagicMock()
77+
mock_create_graph.return_value = mock_graph
78+
79+
# Determine strict return structure for LangGraph invoke
80+
# It yields a dict with "messages" list
81+
last_message = MagicMock()
82+
last_message.content = "Final Commit Message"
83+
mock_graph.invoke.return_value = {"messages": [last_message]}
7684

7785
agent_runnable = create_commit_agent(mock_llm)
7886
result = agent_runnable.invoke({"diff": "diff", "explanation": "expl"})

tests/test_cli.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def mock_generate_deps(tmp_path):
4141
"commitai.cli.get_current_branch_name", return_value="main"
4242
) as mock_branch,
4343
patch("commitai.cli.create_commit") as mock_commit,
44+
# Update mock target for agent creation
4445
patch("commitai.cli.create_commit_agent") as mock_create_agent,
4546
patch("click.edit") as mock_edit,
4647
patch("click.clear"),
@@ -57,10 +58,10 @@ def mock_generate_deps(tmp_path):
5758

5859
mock_google_instance = mock_google_class_in_cli.return_value
5960

60-
# Agent Mock
61-
mock_agent_instance = MagicMock()
62-
mock_agent_instance.invoke.return_value = "Generated commit message"
63-
mock_create_agent.return_value = mock_agent_instance
61+
# Agent Mock (RunnableLambda now)
62+
mock_agent_runnable = MagicMock()
63+
mock_agent_runnable.invoke.return_value = "Generated commit message"
64+
mock_create_agent.return_value = mock_agent_runnable
6465

6566
if mock_google_class_in_cli is not None:
6667
mock_google_instance.spec = ActualChatGoogleGenerativeAI
@@ -93,7 +94,7 @@ def getenv_side_effect(key, default=None):
9394
"path_exists": mock_path_exists,
9495
"commit_msg_path": fake_commit_msg_path,
9596
"create_agent": mock_create_agent,
96-
"agent_instance": mock_agent_instance,
97+
"agent_instance": mock_agent_runnable, # Still useful alias for tests
9798
"confirm": mock_confirm,
9899
}
99100

0 commit comments

Comments
 (0)