08_实战项目 - 智能TextToSql应用

综合实战!构建一个完整的智能TextToSql系统,用户用自然语言提问,AI自动生成SQL查询并返回结果。项目涵盖意图识别、实体抽取、知识检索、SQL生成与验证等完整流程,并提供FastAPI接口和Web界面。


🎯 项目目标

构建一个企业级的TextToSql系统,包含:

  1. 意图识别:理解用户想查询什么
  2. 实体拆解:提取关键信息(时间、指标、维度等)
  3. 知识检索:查询字段映射和示例SQL
  4. SQL生成:根据意图生成准确的SQL
  5. SQL验证:检查SQL的正确性和安全性
  6. 结果呈现:格式化查询结果
  7. FastAPI接口:RESTful API
  8. Web界面:用户友好的交互界面

📁 项目结构

完整代码地址:https://gitee.com/uyynot/texttosqlproject.git

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
TextToSqlProject/
├── README.md # 项目文档
├── app/
│ ├── __init__.py
│ ├── main.py # FastAPI应用入口
│ ├── config.py # 配置文件
│ ├── models.py # 数据模型
│ └── api/
│ └── routes.py # API路由
├── core/
│ ├── __init__.py
│ ├── llm.py # LLM封装
│ ├── graph.py # LangGraph工作流
│ └── nodes/
│ ├── __init__.py
│ ├── intent.py # 意图识别节点
│ ├── entity.py # 实体拆解节点
│ ├── knowledge.py # 知识检索节点
│ ├── sql_gen.py # SQL生成节点
│ ├── sql_validate.py # SQL验证节点
│ └── result.py # 结果呈现节点
├── knowledge/
│ ├── schema.json # 数据库Schema
│ ├── examples.json # SQL示例
│ └── entities.json # 实体映射
├── database/
│ ├── __init__.py
│ └── sample_data.db # 示例数据库
├── static/
│ └── index.html # Web界面
├── tests/
│ └── test_graph.py # 测试用例
├── requirements.txt # 依赖包
└── run.py # 启动脚本

🚀 快速开始

1. 安装依赖

1
2
cd demo
pip install -r requirements.txt

2. 配置环境

创建 .env 文件:

1
2
DASHSCOPE_API_KEY=your_api_key_here
DATABASE_PATH=./database/sample_data.db

3. 初始化数据库

1
python -m database.init_db

4. 启动服务

1
python run.py

访问:http://localhost:8000


💡 核心流程

整体流程图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
用户输入

意图识别 ──→ 实体拆解
↓ ↓
知识检索 ←──────┘

SQL生成

SQL验证 ──→ 验证失败 ──→ 重新生成(最多3次)

验证通过

执行查询

结果呈现

状态定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from typing import TypedDict, List, Dict, Optional

class TextToSqlState(TypedDict):
# 输入
user_query: str # 用户输入

# 意图识别
intent: Optional[str] # 查询意图
intent_confidence: Optional[float] # 意图置信度

# 实体拆解
entities: Optional[Dict] # 提取的实体
time_range: Optional[Dict] # 时间范围
metrics: Optional[List[str]] # 指标
dimensions: Optional[List[str]] # 维度
filters: Optional[Dict] # 过滤条件

# 知识检索
schema_info: Optional[Dict] # 表结构信息
similar_examples: Optional[List] # 相似SQL示例
field_mapping: Optional[Dict] # 字段映射

# SQL生成
generated_sql: Optional[str] # 生成的SQL
sql_explanation: Optional[str] # SQL说明

# SQL验证
is_valid: bool # SQL是否有效
validation_errors: Optional[List] # 验证错误
retry_count: int # 重试次数

# 执行结果
query_result: Optional[List[Dict]] # 查询结果
result_count: Optional[int] # 结果数量

# 最终输出
final_response: Optional[str] # 最终回复
visualization: Optional[Dict] # 可视化配置

📝 节点实现

节点1:意图识别

core/nodes/intent.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
意图识别节点
识别用户查询的意图类型
"""
from typing import Dict
from ..llm import get_llm


INTENT_TYPES = {
"aggregation": "聚合查询(如:统计、求和、平均)",
"comparison": "对比查询(如:同比、环比、对比)",
"ranking": "排名查询(如:TOP N、排行榜)",
"filtering": "筛选查询(如:查找、筛选)",
"trend": "趋势查询(如:增长、变化趋势)"
}


def intent_recognition_node(state: Dict) -> Dict:
"""
识别用户查询意图

Args:
state: 包含user_query的状态

Returns:
更新后的状态,包含intent和intent_confidence
"""
user_query = state["user_query"]
llm = get_llm()

prompt = f"""
你是一个SQL查询意图识别专家。

用户查询:{user_query}

请识别查询意图,从以下类型中选择:
{chr(10).join(f'- {k}: {v}' for k, v in INTENT_TYPES.items())}

以JSON格式返回:
{{
"intent": "意图类型",
"confidence": 0.95,
"reasoning": "判断理由"
}}
"""

response = llm.invoke(prompt)

# 解析响应(简化版,实际应该用结构化输出)
import json
try:
result = json.loads(response)
state["intent"] = result["intent"]
state["intent_confidence"] = result["confidence"]
except:
state["intent"] = "unknown"
state["intent_confidence"] = 0.0

print(f"[意图识别] {state['intent']} (置信度: {state['intent_confidence']})")

return state

节点2:实体拆解

core/nodes/entity.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""
实体拆解节点
从用户查询中提取关键实体
"""
from typing import Dict
from ..llm import get_llm


def entity_extraction_node(state: Dict) -> Dict:
"""
提取查询中的实体

提取内容:
- 时间范围:今天、本月、2023年等
- 指标:销售额、用户数、增长率等
- 维度:地区、部门、产品类别等
- 过滤条件:年龄>25、金额>1000等
"""
user_query = state["user_query"]
llm = get_llm()

prompt = f"""
从查询中提取关键实体。

查询:{user_query}

提取以下信息,以JSON格式返回:
{{
"time_range": {{
"type": "relative/absolute",
"value": "2023-01-01 TO 2023-12-31"
}},
"metrics": ["销售额", "订单数"],
"dimensions": ["地区", "产品类别"],
"filters": {{
"age": ">25",
"amount": ">1000"
}}
}}

如果某项不存在,使用null。
"""

response = llm.invoke(prompt)

import json
try:
entities = json.loads(response)
state["entities"] = entities
state["time_range"] = entities.get("time_range")
state["metrics"] = entities.get("metrics", [])
state["dimensions"] = entities.get("dimensions", [])
state["filters"] = entities.get("filters", {})
except:
state["entities"] = {}

print(f"[实体拆解] 指标: {state.get('metrics')}, 维度: {state.get('dimensions')}")

return state

节点3:知识检索

core/nodes/knowledge.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""
知识检索节点
检索数据库Schema和相似SQL示例
"""
from typing import Dict
import json


def knowledge_retrieval_node(state: Dict) -> Dict:
"""
检索知识库

检索内容:
1. 表结构信息
2. 字段映射(中文名→英文字段名)
3. 相似SQL示例
"""
# 加载Schema(实际应该从向量数据库检索)
with open("knowledge/schema.json", "r", encoding="utf-8") as f:
schema = json.load(f)

# 加载字段映射
with open("knowledge/entities.json", "r", encoding="utf-8") as f:
field_mapping = json.load(f)

# 加载SQL示例
with open("knowledge/examples.json", "r", encoding="utf-8") as f:
examples = json.load(f)

# 根据意图筛选相关示例
intent = state.get("intent", "")
similar_examples = [
ex for ex in examples
if ex.get("intent") == intent
][:3] # 取前3个

state["schema_info"] = schema
state["field_mapping"] = field_mapping
state["similar_examples"] = similar_examples

print(f"[知识检索] 找到{len(similar_examples)}个相似示例")

return state

节点4:SQL生成

core/nodes/sql_gen.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""
SQL生成节点
根据意图、实体和知识库生成SQL
"""
from typing import Dict
from ..llm import get_llm


def sql_generation_node(state: Dict) -> Dict:
"""
生成SQL查询语句
"""
llm = get_llm()

# 构造提示词
prompt = f"""
你是SQL生成专家。根据以下信息生成SQL查询。

用户查询:{state['user_query']}

意图:{state.get('intent')}
实体:
- 指标:{state.get('metrics')}
- 维度:{state.get('dimensions')}
- 时间范围:{state.get('time_range')}
- 过滤条件:{state.get('filters')}

数据库Schema:
{state.get('schema_info')}

字段映射:
{state.get('field_mapping')}

相似SQL示例:
{state.get('similar_examples')}

要求:
1. 生成标准的SQL查询
2. 使用正确的表名和字段名
3. 添加适当的注释
4. 只返回SQL,不要有其他说明

SQL:
"""

response = llm.invoke(prompt)

# 提取SQL(去除可能的markdown标记)
sql = response.strip()
if sql.startswith("```sql"):
sql = sql[6:]
if sql.endswith("```"):
sql = sql[:-3]
sql = sql.strip()

state["generated_sql"] = sql
print(f"[SQL生成]\n{sql}")

return state

节点5:SQL验证

core/nodes/sql_validate.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""
SQL验证节点
验证SQL的正确性和安全性
"""
from typing import Dict
import sqlparse


def sql_validation_node(state: Dict) -> Dict:
"""
验证SQL

检查:
1. 语法正确性
2. 安全性(防止SQL注入)
3. 表名和字段名是否存在
"""
sql = state.get("generated_sql", "")
errors = []

# 1. 语法检查
try:
parsed = sqlparse.parse(sql)
if not parsed:
errors.append("SQL语法错误:无法解析")
except Exception as e:
errors.append(f"SQL语法错误:{e}")

# 2. 安全检查
dangerous_keywords = ["DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE"]
sql_upper = sql.upper()
for keyword in dangerous_keywords:
if keyword in sql_upper:
errors.append(f"安全检查失败:包含危险关键字 {keyword}")

# 3. 表名检查(简化版)
schema = state.get("schema_info", {})
tables = schema.get("tables", [])
# 实际应该做更详细的检查

# 设置验证结果
if errors:
state["is_valid"] = False
state["validation_errors"] = errors
print(f"[SQL验证] ❌ 验证失败:{errors}")
else:
state["is_valid"] = True
state["validation_errors"] = []
print(f"[SQL验证] ✅ 验证通过")

return state

节点6:结果呈现

core/nodes/result.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
结果呈现节点
执行SQL并格式化结果
"""
from typing import Dict
import sqlite3


def result_presentation_node(state: Dict) -> Dict:
"""
执行SQL并呈现结果
"""
sql = state.get("generated_sql", "")

try:
# 连接数据库
conn = sqlite3.connect(state.get("database_path", "database/sample_data.db"))
cursor = conn.cursor()

# 执行查询
cursor.execute(sql)
rows = cursor.fetchall()
columns = [desc[0] for desc in cursor.description]

# 转换为字典列表
results = [
dict(zip(columns, row))
for row in rows
]

state["query_result"] = results
state["result_count"] = len(results)

# 生成友好的回复
state["final_response"] = format_response(state)

print(f"[结果呈现] 查询到{len(results)}条记录")

conn.close()

except Exception as e:
state["final_response"] = f"查询执行失败:{e}"
print(f"[结果呈现] ❌ 错误:{e}")

return state


def format_response(state: Dict) -> str:
"""格式化回复"""
results = state.get("query_result", [])

if not results:
return "未查询到相关数据。"

# 构造表格
response = f"查询到{len(results)}条记录:\n\n"

# 取前5条展示
for i, row in enumerate(results[:5], 1):
response += f"{i}. "
response += ", ".join(f"{k}: {v}" for k, v in row.items())
response += "\n"

if len(results) > 5:
response += f"\n... 还有{len(results) - 5}条记录"

return response

🔧 LangGraph工作流

core/graph.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
TextToSql LangGraph工作流
"""
from typing import Dict
from langgraph.graph import StateGraph, END
from .nodes.intent import intent_recognition_node
from .nodes.entity import entity_extraction_node
from .nodes.knowledge import knowledge_retrieval_node
from .nodes.sql_gen import sql_generation_node
from .nodes.sql_validate import sql_validation_node
from .nodes.result import result_presentation_node


def should_retry_sql(state: Dict) -> str:
"""判断是否需要重新生成SQL"""
if state.get("is_valid", False):
return "execute" # 验证通过,执行查询
elif state.get("retry_count", 0) < 3:
return "regenerate" # 重试
else:
return "failed" # 超过重试次数


def create_text_to_sql_graph():
"""创建TextToSql工作流"""
graph = StateGraph(dict)

# 添加节点
graph.add_node("intent", intent_recognition_node)
graph.add_node("entity", entity_extraction_node)
graph.add_node("knowledge", knowledge_retrieval_node)
graph.add_node("sql_gen", sql_generation_node)
graph.add_node("validate", sql_validation_node)
graph.add_node("result", result_presentation_node)

# 设置流程
graph.set_entry_point("intent")

# 顺序流程
graph.add_edge("intent", "entity")
graph.add_edge("entity", "knowledge")
graph.add_edge("knowledge", "sql_gen")
graph.add_edge("sql_gen", "validate")

# 验证后的条件路由
graph.add_conditional_edges(
"validate",
should_retry_sql,
{
"execute": "result", # 执行查询
"regenerate": "sql_gen", # 重新生成
"failed": END # 失败结束
}
)

graph.add_edge("result", END)

return graph.compile()

🌐 FastAPI接口

app/main.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
FastAPI应用
"""
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from core.graph import create_text_to_sql_graph


app = FastAPI(title="智能TextToSql系统")

# 挂载静态文件
app.mount("/static", StaticFiles(directory="static"), name="static")

# 创建工作流
graph = create_text_to_sql_graph()


class QueryRequest(BaseModel):
query: str


class QueryResponse(BaseModel):
sql: str
result: list
explanation: str


@app.get("/")
async def read_root():
"""返回Web界面"""
return FileResponse("static/index.html")


@app.post("/api/query", response_model=QueryResponse)
async def process_query(request: QueryRequest):
"""
处理文本查询,返回SQL和结果
"""
try:
# 执行工作流
result = graph.invoke({
"user_query": request.query,
"retry_count": 0,
"is_valid": False
})

return QueryResponse(
sql=result.get("generated_sql", ""),
result=result.get("query_result", []),
explanation=result.get("final_response", "")
)

except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
"""健康检查"""
return {"status": "ok"}

🎨 Web界面

static/index.html(简化版):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
<!DOCTYPE html>
<html>
<head>
<title>智能TextToSql系统</title>
<style>
body {
font-family: Arial, sans-serif;
max-width: 1200px;
margin: 0 auto;
padding: 20px;
}
.container {
display: flex;
flex-direction: column;
gap: 20px;
}
.input-box {
width: 100%;
padding: 10px;
font-size: 16px;
}
.btn {
padding: 10px 20px;
font-size: 16px;
background: #007bff;
color: white;
border: none;
cursor: pointer;
}
.result-box {
background: #f5f5f5;
padding: 15px;
border-radius: 5px;
}
pre {
background: #272822;
color: #f8f8f2;
padding: 15px;
border-radius: 5px;
overflow-x: auto;
}
</style>
</head>
<body>
<h1>🤖 智能TextToSql系统</h1>
<div class="container">
<div>
<input type="text" class="input-box" id="queryInput"
placeholder="输入查询,例如:查询本月销售额最高的前10个产品">
<button class="btn" onclick="submitQuery()">查询</button>
</div>

<div id="sqlBox" class="result-box" style="display:none">
<h3>生成的SQL:</h3>
<pre id="sqlContent"></pre>
</div>

<div id="resultBox" class="result-box" style="display:none">
<h3>查询结果:</h3>
<div id="resultContent"></div>
</div>
</div>

<script>
async function submitQuery() {
const query = document.getElementById('queryInput').value;

const response = await fetch('/api/query', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({query: query})
});

const data = await response.json();

// 显示SQL
document.getElementById('sqlContent').textContent = data.sql;
document.getElementById('sqlBox').style.display = 'block';

// 显示结果
document.getElementById('resultContent').innerHTML = data.explanation;
document.getElementById('resultBox').style.display = 'block';
}
</script>
</body>
</html>

🧪 测试用例

测试各种查询场景:

1
2
3
4
5
6
7
test_queries = [
"查询本月销售额TOP10的产品",
"对比去年同期和今年的用户增长率",
"统计各地区的订单数量",
"查找年龄大于25岁的活跃用户",
"分析最近7天的销售趋势"
]

📚 总结

这个实战项目综合运用了前面所有课程的知识:

Lesson 01-02:LangGraph基础和节点定义
Lesson 03:复杂状态管理
Lesson 04:条件路由和循环重试
Lesson 05:工具集成(数据库查询)
Lesson 06:知识检索(Schema和示例)
Lesson 07:错误处理和重试机制

恭喜你完成了整个课程!🎉


📖 扩展阅读

  • LangGraph官方文档
  • TextToSql最佳实践
  • FastAPI文档