Building an LLM agent from a graph with cycles using LangGraph, part of the LangChain ecosystem


LangGraph is part of the LangChain ecosystem. As in the following code, you add nodes to a StateGraph, or to a MessageGraph whose state is a List[BaseMessage], connect them with add_edge() and add_conditional_edges(), and call compile() to obtain a Runnable. LangChain Expression Language (LCEL) can also chain Runnables into a DAG, but LangGraph can additionally express cycles.

```python
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from langgraph.graph import END, StateGraph
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from typing import TypedDict, Annotated
import json

# Reducer: how the 'messages' a node returns are merged into the state (here: appended)
def add_messages(left: list, right: list) -> list:
    return left + right

class AgentState(TypedDict):
    messages: Annotated[list, add_messages]

# export OPENAI_API_KEY=your-api-key
model = ChatOpenAI(model="gpt-4-turbo", temperature=0)
# graph = MessageGraph()  # alternative: the state is a plain List[BaseMessage]
graph = StateGraph(AgentState)

@tool
def multiply(left: int, right: int) -> int:
    """Multiplies two numbers together."""
    return left * right

# Make the tool available to the model via function calling
model_with_tools = model.bind_tools([multiply])

# Agent node: send the conversation so far to the model and append its reply
def call_model(state: AgentState) -> AgentState:
    messages = state['messages']
    return {'messages': [model_with_tools.invoke(messages)]}

NODE_AGENT = 'agent'
NODE_MULTIPLY = 'multiply'  # must equal the tool name, because router() returns the tool name

graph.add_node(NODE_AGENT, call_model)
# ToolNode runs the tool calls in the last AIMessage and returns ToolMessages
graph.add_node(NODE_MULTIPLY, ToolNode([multiply]))

graph.set_entry_point(NODE_AGENT)

# Go to the requested tool's node if the model asked for a tool, otherwise finish
def router(state: AgentState) -> str:
    messages = state['messages']
    last_message = messages[-1]
    if hasattr(last_message, 'tool_calls') and len(last_message.tool_calls) > 0:
        return last_message.tool_calls[0]['name']
    else:
        return END

graph.add_conditional_edges(NODE_AGENT, router)
graph.add_edge(NODE_MULTIPLY, NODE_AGENT)  # cycle: feed tool results back to the agent

runnable = graph.compile()

print("> ", end='')
question = input()

initial_state: AgentState = {'messages': [HumanMessage(content=question)]}
result = runnable.invoke(initial_state)
# Print every message accumulated in the final state
for message in result['messages']:
    print(f"### {type(message).__name__}")
    if len(message.content) > 0:
        print(f"content: {message.content}")
    if hasattr(message, 'tool_calls'):
        print(f"tool_calls: {json.dumps(message.tool_calls)}")
```
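The compiled graph keeps alternating between the agent and multiply nodes until router() returns END. The toy, stdlib-only sketch below runs that same loop without an API key. It is not LangGraph's implementation: ToyGraph, its END sentinel, the fake agent() that always asks for 9 × 3, and the max_steps guard (loosely analogous to LangGraph's recursion limit) are all hypothetical. The merge step just appends each node's messages, which is the role of the add_messages reducer declared via Annotated in the code above.

```python
from typing import Callable, Dict

END = "__end__"  # hypothetical sentinel meaning "stop"

class ToyGraph:
    """Hypothetical minimal state graph; NOT LangGraph's implementation."""

    def __init__(self) -> None:
        self.nodes: Dict[str, Callable[[dict], dict]] = {}
        self.edges: Dict[str, str] = {}
        self.routers: Dict[str, Callable[[dict], str]] = {}
        self.entry = ""

    def add_node(self, name: str, fn: Callable[[dict], dict]) -> None:
        self.nodes[name] = fn

    def add_edge(self, src: str, dst: str) -> None:
        self.edges[src] = dst

    def add_conditional_edges(self, src: str, router: Callable[[dict], str]) -> None:
        self.routers[src] = router

    def set_entry_point(self, name: str) -> None:
        self.entry = name

    def invoke(self, state: dict, max_steps: int = 25) -> dict:
        current = self.entry
        for _ in range(max_steps):
            update = self.nodes[current](state)
            # Append-only merge, like the add_messages reducer
            state = {"messages": state["messages"] + update["messages"]}
            # A conditional edge picks the next node from the state; otherwise follow the fixed edge
            if current in self.routers:
                current = self.routers[current](state)
            else:
                current = self.edges.get(current, END)
            if current == END:
                return state
        raise RuntimeError("step limit reached: the cycle never reached END")

# Fake agent (hypothetical): asks for 9 * 3 until a tool result is present, then answers
def agent(state: dict) -> dict:
    last = state["messages"][-1]
    if last["role"] == "tool":
        return {"messages": [{"role": "ai", "content": f"answer: {last['content']}"}]}
    return {"messages": [{"role": "ai", "tool_call": {"left": 9, "right": 3}}]}

# Fake tool node: multiplies the arguments of the last tool call
def multiply_node(state: dict) -> dict:
    args = state["messages"][-1]["tool_call"]
    return {"messages": [{"role": "tool", "content": args["left"] * args["right"]}]}

def router(state: dict) -> str:
    return "multiply" if "tool_call" in state["messages"][-1] else END

g = ToyGraph()
g.add_node("agent", agent)
g.add_node("multiply", multiply_node)
g.set_entry_point("agent")
g.add_conditional_edges("agent", router)
g.add_edge("multiply", "agent")  # the cycle
result = g.invoke({"messages": [{"role": "human", "content": "What is 9 x 3?"}]})
print(result["messages"][-1]["content"])  # answer: 27
```

The run visits agent, multiply, agent and then stops, leaving four messages (human, ai, tool, ai), the same shape as the real trace, only with a single tool call.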

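Running it shows that the intended arguments are passed to the ToolNode and that its results flow back to the agent node. With ChatOpenAI, the tool to call and its arguments are selected by OpenAI's Function calling feature; the docstring of each tool passed to bind_tools() is sent as that tool's description.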
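As a rough illustration of what "the docstring is sent as the description" means, the sketch below hand-builds an OpenAI-style tool definition (a name, a description, and a JSON Schema for the parameters) from a function's signature and docstring. to_tool_schema() and its small type table are hypothetical; LangChain performs this conversion itself inside bind_tools(), and its exact output may differ.

```python
import inspect
import json

# Hypothetical mapping from Python annotations to JSON Schema type names
_JSON_TYPES = {int: "integer", float: "number", str: "string", bool: "boolean"}

def to_tool_schema(fn) -> dict:
    """Build an OpenAI-style tool definition from fn's signature and docstring (hypothetical helper)."""
    sig = inspect.signature(fn)
    properties = {
        name: {"type": _JSON_TYPES[param.annotation]}
        for name, param in sig.parameters.items()
    }
    required = [name for name, param in sig.parameters.items()
                if param.default is inspect.Parameter.empty]
    return {
        "type": "function",
        "function": {
            "name": fn.__name__,
            "description": inspect.getdoc(fn),  # the docstring becomes the description
            "parameters": {"type": "object", "properties": properties, "required": required},
        },
    }

def multiply(left: int, right: int) -> int:
    """Multiplies two numbers together."""
    return left * right

print(json.dumps(to_tool_schema(multiply), indent=2))
```

Because the model only sees this description and the parameter schema, a clear docstring directly affects whether it picks the tool and fills in left and right sensibly.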

```
> A baseball team has 9 players. Each player needs a bat, a glove, and a helmet. There are 6 teams in a league now. How many items are there in a league?
### HumanMessage
content: A baseball team has 9 players. Each player needs a bat, a glove, and a helmet. There are 6 teams in a league now. How many items are there in a league?
### AIMessage
tool_calls: [{"name": "multiply", "args": {"left": 9, "right": 3}, "id": "call_iOcQoW9cCuiMFVt1js4BL5mt"}, {"name": "multiply", "args": {"left": 6, "right": 27}, "id": "call_DNLbzD25ODxwqdott3VPQbJx"}]
### ToolMessage
content: 27
### ToolMessage
content: 162
### AIMessage
content: Each player needs 3 items (a bat, a glove, and a helmet). For a team of 9 players, this amounts to \(9 \times 3 = 27\) items per team. With 6 teams in the league, the total number of items is \(6 \times 27 = 162\).

Therefore, there are 162 items in total in the league.
tool_calls: []
```
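In this trace both multiply calls arrive in a single AIMessage, so the model supplied the 27 in the second call itself before any tool result came back, and the multiply node then ran both calls in one pass, producing two ToolMessages. The sketch below replays that tool_calls payload with a hypothetical run_tool_calls() helper that dispatches each call by name and tags the result with its id. It only mimics the idea behind ToolNode and is not its implementation.

```python
import json

def multiply(left: int, right: int) -> int:
    """Multiplies two numbers together."""
    return left * right

# Hypothetical registry mapping tool names to Python functions
TOOLS = {"multiply": multiply}

def run_tool_calls(tool_calls: list) -> list:
    """Run every tool call by name; return one result per call, tagged with the call's id."""
    return [
        {"tool_call_id": call["id"], "content": str(TOOLS[call["name"]](**call["args"]))}
        for call in tool_calls
    ]

# The tool_calls payload copied from the AIMessage in the trace above
tool_calls = json.loads(
    '[{"name": "multiply", "args": {"left": 9, "right": 3}, "id": "call_iOcQoW9cCuiMFVt1js4BL5mt"},'
    ' {"name": "multiply", "args": {"left": 6, "right": 27}, "id": "call_DNLbzD25ODxwqdott3VPQbJx"}]'
)
for tool_result in run_tool_calls(tool_calls):
    print(tool_result)
```

The id on each result is what lets the model match every ToolMessage to the call that produced it when the results are sent back to the agent node.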