Generating SQL from natural language and querying a database with the AgentExecutor from LangChain's SQL Database Toolkits
LangChain's SQL Database Toolkits provide functionality for generating SQL and accessing databases.
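create_sql_agent() returns an LLM AgentExecutor that registers the following tools from SQLDatabaseToolkit.get_tools() as its tools and embeds them into the prompt, among other things:
- query_sql_database_tool (sql_db_query): executes SQL
- info_sql_database_tool (sql_db_schema): gets table schemas, along with a few sample rows
- list_sql_database_tool (sql_db_list_tables): gets the list of table names
- query_sql_checker_tool (sql_db_query_checker): has the LLM check the SQL and fixes it if it is wrong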
Building LLM agents from graphs containing cycles with LangGraph, part of the LangChain ecosystem - sambaiz-net
As in the documentation, I run it against the Chinook database.
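To make the division of labor between the non-LLM tools concrete, here is a stdlib-only sketch of roughly what the list, schema and query tools hand back to the agent. It is loosely modeled on the output format visible in the log below (a CREATE TABLE statement followed by a `/* N rows from ... table: */` comment), not on LangChain's actual implementation. The function names list_tables, table_info and run_query and the toy tables are hypothetical, and the checker tool is omitted because it calls the LLM.

```python
import sqlite3

# Hypothetical toy database standing in for Chinook (two tiny tables)
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT);
CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, Total NUMERIC);
INSERT INTO Customer VALUES (1, 'Brazil'), (2, 'Germany');
INSERT INTO Invoice VALUES (1, 2, 1.98), (2, 1, 3.96);
""")


def list_tables(conn):
    """Like list_sql_database_tool: comma-separated table names."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    return ", ".join(name for (name,) in rows)


def table_info(conn, table_names, sample_rows=3):
    """Like info_sql_database_tool: CREATE TABLE plus a few sample rows."""
    parts = []
    for name in (t.strip() for t in table_names.split(",")):
        (ddl,) = conn.execute(
            "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?",
            (name,),
        ).fetchone()
        cur = conn.execute(f'SELECT * FROM "{name}" LIMIT {sample_rows}')
        header = "\t".join(col[0] for col in cur.description)
        body = "\n".join("\t".join(map(str, row)) for row in cur.fetchall())
        parts.append(
            f"{ddl}\n/*\n{sample_rows} rows from {name} table:\n{header}\n{body}\n*/"
        )
    return "\n\n".join(parts)


def run_query(conn, query):
    """Like query_sql_database_tool: the result rows rendered as a string."""
    return str(conn.execute(query).fetchall())


print(list_tables(conn))  # Customer, Invoice
print(table_info(conn, "Customer, Invoice"))
print(run_query(conn, "SELECT COUNT(*) FROM Invoice"))  # [(2,)]
```

The agent only ever sees these strings, which is why it first asks for table names and then requests schemas for just the tables it judges relevant.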
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI
from langchain_community.utilities.sql_database import SQLDatabase

# Chinook sample database, downloaded as a local SQLite file
db = SQLDatabase.from_uri("sqlite:///Chinook_Sqlite.sqlite")
# Reads the API key from the OPENAI_API_KEY environment variable
llm = ChatOpenAI(model="gpt-4-turbo", temperature=0)
# verbose=True prints every tool call the agent makes
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)

result = agent_executor.invoke(
    "List the total sales per country. Which country's customers spent the most?"
)
print(result)
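Because it runs with verbose=True, the intermediate steps are printed as well. You can see that it gets the list of table names, fetches the schemas of the Customer and Invoice tables, generates a query that JOINs them, and retrieves the data. It pinpoints exactly the tables it needs without any metadata, but I think this relies on the naming and table structure being fairly conventional.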
$ curl -LO https://github.com/lerocha/chinook-database/releases/download/v1.4.5/Chinook_Sqlite.sqlite
$ poetry run python main.py
> Entering new SQL Agent Executor chain...
Invoking: `sql_db_list_tables` with `{}`
Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Invoking: `sql_db_schema` with `{'table_names': 'Customer, Invoice'}`
CREATE TABLE "Customer" (
"CustomerId" INTEGER NOT NULL,
"FirstName" NVARCHAR(40) NOT NULL,
"LastName" NVARCHAR(20) NOT NULL,
"Company" NVARCHAR(80),
"Address" NVARCHAR(70),
"City" NVARCHAR(40),
"State" NVARCHAR(40),
"Country" NVARCHAR(40),
"PostalCode" NVARCHAR(10),
"Phone" NVARCHAR(24),
"Fax" NVARCHAR(24),
"Email" NVARCHAR(60) NOT NULL,
"SupportRepId" INTEGER,
PRIMARY KEY ("CustomerId"),
FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)
/*
3 rows from Customer table:
CustomerId FirstName LastName Company Address City State Country PostalCode Phone Fax Email SupportRepId
1 Luís Gonçalves Embraer - Empresa Brasileira de Aeronáutica S.A. Av. Brigadeiro Faria Lima, 2170 São José dos Campos SP Brazil 12227-000 +55 (12) 3923-5555 +55 (12) 3923-5566 luisg@embraer.com.br 3
2 Leonie Köhler None Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 +49 0711 2842222 None leonekohler@surfeu.de 5
3 François Tremblay None 1498 rue Bélanger Montréal QC Canada H2G 1A7 +1 (514) 721-4711 None ftremblay@gmail.com 3
*/
CREATE TABLE "Invoice" (
"InvoiceId" INTEGER NOT NULL,
"CustomerId" INTEGER NOT NULL,
"InvoiceDate" DATETIME NOT NULL,
"BillingAddress" NVARCHAR(70),
"BillingCity" NVARCHAR(40),
"BillingState" NVARCHAR(40),
"BillingCountry" NVARCHAR(40),
"BillingPostalCode" NVARCHAR(10),
"Total" NUMERIC(10, 2) NOT NULL,
PRIMARY KEY ("InvoiceId"),
FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)
/*
3 rows from Invoice table:
InvoiceId CustomerId InvoiceDate BillingAddress BillingCity BillingState BillingCountry BillingPostalCode Total
1 2 2009-01-01 00:00:00 Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 1.98
2 4 2009-01-02 00:00:00 Ullevålsveien 14 Oslo None Norway 0171 3.96
3 8 2009-01-03 00:00:00 Grétrystraat 63 Brussels None Belgium 1000 5.94
*/
Invoking: `sql_db_query_checker` with `{'query': 'SELECT c.Country, SUM(i.Total) AS TotalSales FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC'}`
\```sql
SELECT c.Country, SUM(i.Total) AS TotalSales FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC
\```
Invoking: `sql_db_query` with `{'query': 'SELECT c.Country, SUM(i.Total) AS TotalSales FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC'}`
[('USA', 523.06), ('Canada', 303.96), ('France', 195.1), ('Brazil', 190.1), ('Germany', 156.48), ('United Kingdom', 112.86), ('Czech Republic', 90.24), ('Portugal', 77.24), ('India', 75.26), ('Chile', 46.62), ('Ireland', 45.62), ('Hungary', 45.62), ('Austria', 42.62), ('Finland', 41.62), ('Netherlands', 40.62), ('Norway', 39.62), ('Sweden', 38.62), ('Spain', 37.62), ('Poland', 37.62), ('Italy', 37.62), ('Denmark', 37.62), ('Belgium', 37.62), ('Australia', 37.62), ('Argentina', 37.62)]The total sales per country are as follows:
1. USA: $523.06
2. Canada: $303.96
3. France: $195.10
4. Brazil: $190.10
5. Germany: $156.48
6. United Kingdom: $112.86
7. Czech Republic: $90.24
8. Portugal: $77.24
9. India: $75.26
10. Chile: $46.62
The country whose customers spent the most is the USA.
> Finished chain.
{'input': "List the total sales per country. Which country's customers spent the most?", 'output': 'The total sales per country are as follows:\n\n1. USA: $523.06\n2. Canada: $303.96\n3. France: $195.10\n4. Brazil: $190.10\n5. Germany: $156.48\n6. United Kingdom: $112.86\n7. Czech Republic: $90.24\n8. Portugal: $77.24\n9. India: $75.26\n10. Chile: $46.62\n\nThe country whose customers spent the most is the USA.'}
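To see why the query the agent settled on answers the question, here is the same JOIN + GROUP BY run against a tiny in-memory SQLite database with Python's stdlib sqlite3. The rows are made up for illustration and are not Chinook data; only the query string is taken from the log.

```python
import sqlite3

# Hypothetical toy data (not Chinook): three customers, four invoices
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT);
CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, Total NUMERIC(10, 2));
INSERT INTO Customer VALUES (1, 'USA'), (2, 'Canada'), (3, 'USA');
INSERT INTO Invoice VALUES (1, 1, 10.50), (2, 2, 7.50), (3, 3, 4.25), (4, 2, 1.25);
""")

# The query the agent generated, copied verbatim from the log
query = (
    "SELECT c.Country, SUM(i.Total) AS TotalSales FROM Customer c "
    "JOIN Invoice i ON c.CustomerId = i.CustomerId "
    "GROUP BY c.Country ORDER BY TotalSales DESC"
)
rows = conn.execute(query).fetchall()
print(rows)        # [('USA', 14.75), ('Canada', 8.75)]
print(rows[0][0])  # USA: first row after ORDER BY ... DESC is the top spender
```

Joining through Customer is what lets invoices from different customers in the same country be summed together, and the descending sort means the second half of the question is answered by simply reading the first row.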
Even with a vaguer, more complex question, it interpreted "the most valuable customer" as "the customer who has paid the most" and returned a plausible answer.
# Who is the most valuable customer in this company? Why? What did they buy the most?
Helena Holý's most frequently purchased item is the track "New Year's Day," which she bought once.
In summary:
- The most valuable customer is Helena Holý, who has spent a total of $49.62.
- The product she bought the most is the track titled "New Year's Day."
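The verbose log for this question is not shown, so the following is only a hypothetical sketch of a two-step query matching the interpretation described: rank customers by total spend, then rank that customer's tracks by summed InvoiceLine.Quantity. It uses a made-up subset of the Chinook tables in an in-memory SQLite database. Since the answer says the most purchased track was bought only once, every track she bought was likely tied, so the sketch adds an explicit tiebreak by name; without one, which tied track comes first is up to the database.

```python
import sqlite3

# Hypothetical toy subset of the Chinook schema (Customer, Invoice, InvoiceLine, Track)
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, FirstName TEXT, LastName TEXT);
CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, Total NUMERIC);
CREATE TABLE Track (TrackId INTEGER PRIMARY KEY, Name TEXT);
CREATE TABLE InvoiceLine (InvoiceLineId INTEGER PRIMARY KEY, InvoiceId INTEGER,
                          TrackId INTEGER, UnitPrice NUMERIC, Quantity INTEGER);
INSERT INTO Customer VALUES (1, 'Ann', 'Lee'), (2, 'Bo', 'Kim');
INSERT INTO Invoice VALUES (1, 1, 1.98), (2, 2, 0.99), (3, 1, 0.99);
INSERT INTO Track VALUES (1, 'Track A'), (2, 'Track B');
INSERT INTO InvoiceLine VALUES (1, 1, 1, 0.99, 1), (2, 1, 2, 0.99, 1),
                               (3, 2, 2, 0.99, 1), (4, 3, 1, 0.99, 1);
""")

# Step 1: "most valuable" interpreted as "largest total spend"
(customer_id, name, spent) = conn.execute("""
    SELECT c.CustomerId, c.FirstName || ' ' || c.LastName, SUM(i.Total) AS Spent
    FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId
    GROUP BY c.CustomerId ORDER BY Spent DESC LIMIT 1
""").fetchone()

# Step 2: that customer's most purchased track, with a deterministic tiebreak
(track, qty) = conn.execute("""
    SELECT t.Name, SUM(il.Quantity) AS Qty
    FROM Invoice i
    JOIN InvoiceLine il ON i.InvoiceId = il.InvoiceId
    JOIN Track t ON il.TrackId = t.TrackId
    WHERE i.CustomerId = ?
    GROUP BY t.TrackId ORDER BY Qty DESC, t.Name LIMIT 1
""", (customer_id,)).fetchone()

print(name, round(spent, 2), track, qty)  # Ann Lee 2.97 Track A 2
```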
SQLDatabase uses SQLAlchemy internally, so installing a library such as PyAthena, which provides a SQLAlchemy dialect (sqlalchemy.dialects), lets it connect to a variety of data sources.
from langchain_community.utilities.sql_database import SQLDatabase

# S3 location where Athena writes query results
s3_staging_dir = "s3://<bucket>"
database = "<database>"
# No credentials before "@", so PyAthena relies on boto3's default credential lookup
conn_str = f"awsathena+rest://@athena.ap-northeast-1.amazonaws.com:443/{database}?s3_staging_dir={s3_staging_dir}"
db = SQLDatabase.from_uri(conn_str)

print('execute', db.run("SELECT * FROM <database>.<table> LIMIT 10"))
print('get_usable_table_names', db.get_usable_table_names())
print('get_table_info', db.get_table_info(table_names=['test']))
'''
CREATE EXTERNAL TABLE test (
...
)
PARTITIONED BY (
...
)
...
/*
3 rows from test table:
...
'''
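The connection string above is an ordinary SQLAlchemy URL, so its query parameters follow URL rules. PyAthena's documentation passes values such as s3_staging_dir through urllib.parse.quote_plus; this becomes necessary once an S3 path (or a credential) contains characters that are reserved in URLs. The sketch below is a hypothetical helper, athena_url, and the region, database and bucket values are placeholders; it only builds and checks the string with the stdlib and does not connect to AWS.

```python
from urllib.parse import parse_qs, quote_plus, urlsplit


def athena_url(region, database, s3_staging_dir):
    """Hypothetical helper: PyAthena SQLAlchemy URL with no credentials before '@'."""
    return (
        f"awsathena+rest://@athena.{region}.amazonaws.com:443/{database}"
        f"?s3_staging_dir={quote_plus(s3_staging_dir)}"
    )


url = athena_url("ap-northeast-1", "my_database", "s3://my-bucket/athena-results/")
print(url)
# Round-trip check: the query parameter decodes back to the original S3 path
print(parse_qs(urlsplit(url).query)["s3_staging_dir"][0])  # s3://my-bucket/athena-results/
```

The resulting string can be passed to SQLDatabase.from_uri() in place of conn_str.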