Create an LLM agent from a graph containing a cycle with LangGraph, part of the LangChain ecosystem


LangGraph is part of the LangChain ecosystem. You create a Runnable by adding nodes to a StateGraph (or a MessageGraph, whose state is List[BaseMessage]), connecting them with add_edge() or add_conditional_edges(), and calling compile(). LangChain Expression Language (LCEL) can also compose Runnables, but only into a DAG; LangGraph can additionally express cycles.

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from langgraph.graph import END, StateGraph
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from typing import TypedDict, Annotated
import json

def add_messages(left: list, right: list):
    """Reducer: merge each node's partial state update by appending messages."""
    return left + right

class AgentState(TypedDict):
    messages: Annotated[list, add_messages]

# export OPENAI_API_KEY=your-api-key
model = ChatOpenAI(model="gpt-4-turbo", temperature=0)
# graph = MessageGraph()
graph = StateGraph(AgentState)

@tool
def multiply(left: int, right: int):
    """Multiplies two numbers together."""
    return left * right

model_with_tools = model.bind_tools([multiply])

def call_model(state: AgentState) -> AgentState:
    messages = state['messages']
    return {'messages': [model_with_tools.invoke(messages)]}

NODE_AGENT = 'agent'
NODE_MULTIPLY = 'multiply'

graph.add_node(NODE_AGENT, call_model)
graph.add_node(NODE_MULTIPLY, ToolNode([multiply]))

graph.set_entry_point(NODE_AGENT)

def router(state: AgentState) -> str:
    # Return the name of the next node; here the tool name doubles as the node name.
    messages = state['messages']
    last_message = messages[-1]
    if hasattr(last_message, 'tool_calls') and len(last_message.tool_calls) > 0:
        return last_message.tool_calls[0]['name']
    else:
        return END

graph.add_conditional_edges(NODE_AGENT, router)
graph.add_edge(NODE_MULTIPLY, NODE_AGENT) # cycle

runnable = graph.compile()

print("> ", end='')
question = input()

initial_state: AgentState = {'messages': [HumanMessage(content=question)]}
result = runnable.invoke(initial_state)
for message in result['messages']:
    print(f"### {type(message).__name__}")
    if len(message.content) > 0:
        print(f"content: {message.content}")
    if hasattr(message, 'tool_calls'):
        print(f"tool_calls: {json.dumps(message.tool_calls)}")

When executed, you can see that the expected arguments are passed to the ToolNode and that the results are returned to the agent node. ChatOpenAI selects the tool and its parameters via OpenAI's function calling feature; the docstring of each tool passed to bind_tools() is used as its description.

> A baseball team has 9 players. Each player needs a bat, a glove, and a helmet. There are 6 teams in a league now. How many items are there in a league?
### HumanMessage
content: A baseball team has 9 players. Each player needs a bat, a glove, and a helmet. There are 6 teams in a league now. How many items are there in a league?
### AIMessage
tool_calls: [{"name": "multiply", "args": {"left": 9, "right": 3}, "id": "call_iOcQoW9cCuiMFVt1js4BL5mt"}, {"name": "multiply", "args": {"left": 6, "right": 27}, "id": "call_DNLbzD25ODxwqdott3VPQbJx"}]
### ToolMessage
content: 27
### ToolMessage
content: 162
### AIMessage
content: Each player needs 3 items (a bat, a glove, and a helmet). For a team of 9 players, this amounts to \(9 \times 3 = 27\) items per team. With 6 teams in the league, the total number of items is \(6 \times 27 = 162\).

Therefore, there are 162 items in total in the league.
tool_calls: []
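As a closing illustration, the control flow that compile() produces can be sketched in plain Python. This is a toy simulation — Msg, fake_model, and fake_tools are hypothetical stand-ins, not LangChain APIs — showing why the agent → multiply → agent cycle terminates: the loop exits as soon as the model replies without tool calls.

```python
class Msg:
    """Stand-in for a LangChain message."""
    def __init__(self, content, tool_calls=None):
        self.content = content
        self.tool_calls = tool_calls or []

def fake_model(state):
    # First call: request the multiply tool; second call: final answer.
    if not any(m.tool_calls for m in state["messages"]):
        return {"messages": [Msg("", [{"name": "multiply",
                                       "args": {"left": 2, "right": 3}}])]}
    return {"messages": [Msg("2 * 3 = 6")]}

def fake_tools(state):
    # Execute the requested tool call and wrap the result in a message.
    call = state["messages"][-1].tool_calls[0]
    return [Msg(str(call["args"]["left"] * call["args"]["right"]))]

def run_agent_loop(state):
    node = "agent"
    while node != "END":
        if node == "agent":
            state["messages"] += fake_model(state)["messages"]
            # The router: go to the tool node while tools are requested.
            node = "multiply" if state["messages"][-1].tool_calls else "END"
        else:  # the tool node; its edge always leads back to the agent
            state["messages"] += fake_tools(state)
            node = "agent"
    return state

final = run_agent_loop({"messages": [Msg("What is 2 * 3?")]})
print(len(final["messages"]))  # question, tool call, tool result, answer
print(final["messages"][-1].content)
```

The compiled LangGraph Runnable implements the same shape of loop, with real model and tool invocations at each node.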