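A capstone project: build a complete, intelligent TextToSql system in which users ask questions in natural language and the AI generates the SQL query and returns the result. The project covers the full pipeline (intent recognition, entity extraction, knowledge retrieval, SQL generation and validation) and ships with a FastAPI interface and a web UI.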
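🎯 Project goals

Build an enterprise-grade TextToSql system that includes:

- Intent recognition: understand what the user wants to query
- Entity extraction: pull out the key information (time, metrics, dimensions, and so on)
- Knowledge retrieval: look up field mappings and example SQL
- SQL generation: produce accurate SQL from the intent
- SQL validation: check the SQL for correctness and safety
- Result presentation: format the query results
- FastAPI interface: a RESTful API
- Web UI: a user-friendly interactive front end

📁 Project structure

Full source code: https://gitee.com/uyynot/texttosqlproject.git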
```
TextToSqlProject/
├── README.md               # Project documentation
├── app/
│   ├── __init__.py
│   ├── main.py             # FastAPI application entry point
│   ├── config.py           # Configuration
│   ├── models.py           # Data models
│   └── api/
│       └── routes.py       # API routes
├── core/
│   ├── __init__.py
│   ├── llm.py              # LLM wrapper
│   ├── graph.py            # LangGraph workflow
│   └── nodes/
│       ├── __init__.py
│       ├── intent.py       # Intent recognition node
│       ├── entity.py       # Entity extraction node
│       ├── knowledge.py    # Knowledge retrieval node
│       ├── sql_gen.py      # SQL generation node
│       ├── sql_validate.py # SQL validation node
│       └── result.py       # Result presentation node
├── knowledge/
│   ├── schema.json         # Database schema
│   ├── examples.json       # SQL examples
│   └── entities.json       # Entity mapping
├── database/
│   ├── __init__.py
│   └── sample_data.db      # Sample database
├── static/
│   └── index.html          # Web UI
├── tests/
│   └── test_graph.py       # Test cases
├── requirements.txt        # Dependencies
└── run.py                  # Startup script
```
🚀 Quick start

1. Install dependencies

```bash
cd demo
pip install -r requirements.txt
```
2. Configure the environment

Create a .env file:
```
DASHSCOPE_API_KEY=your_api_key_here
DATABASE_PATH=./database/sample_data.db
```
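The project tree lists app/config.py, but its contents are not shown here. The following is a minimal sketch of how the two variables above might be read, assuming they have already been exported into the process environment (for example by a .env loader). `Settings` and `load_settings` are hypothetical names, not taken from the repository.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class Settings:
    """Hypothetical settings container; fields mirror the .env keys."""
    dashscope_api_key: str
    database_path: str


def load_settings(env: Optional[Mapping[str, str]] = None) -> Settings:
    """Read settings from a mapping (defaults to os.environ)."""
    env = os.environ if env is None else env
    api_key = env.get("DASHSCOPE_API_KEY", "")
    if not api_key:
        # Fail early instead of sending unauthenticated requests later.
        raise RuntimeError("DASHSCOPE_API_KEY is not set")
    return Settings(
        dashscope_api_key=api_key,
        database_path=env.get("DATABASE_PATH", "./database/sample_data.db"),
    )
```

Passing the mapping explicitly, as in `load_settings({"DASHSCOPE_API_KEY": "..."})`, keeps the function easy to test without touching the real environment.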
3. Initialize the database

```bash
python -m database.init_db
```
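The initialization module itself is not reproduced in this article. As a rough idea of what such a script could do, here is a sketch that creates one table and seeds it. The `orders` table, its columns, and the sample rows are all invented for illustration and are not the repository's actual schema.

```python
import sqlite3

# Hypothetical schema and rows, for illustration only.
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS orders (
    id INTEGER PRIMARY KEY,
    product TEXT NOT NULL,
    region TEXT NOT NULL,
    amount REAL NOT NULL,
    order_date TEXT NOT NULL
);
"""

SAMPLE_ROWS = [
    ("Laptop", "East", 5200.0, "2024-01-05"),
    ("Phone", "North", 3100.0, "2024-01-06"),
    ("Laptop", "North", 4800.0, "2024-01-07"),
]


def init_db(path: str = ":memory:") -> sqlite3.Connection:
    """Create the table if needed and seed it once; return the connection."""
    conn = sqlite3.connect(path)
    conn.executescript(SCHEMA_SQL)
    # Seed only an empty table so the script can be re-run safely.
    if conn.execute("SELECT COUNT(*) FROM orders").fetchone()[0] == 0:
        conn.executemany(
            "INSERT INTO orders (product, region, amount, order_date) "
            "VALUES (?, ?, ?, ?)",
            SAMPLE_ROWS,
        )
        conn.commit()
    return conn
```

Defaulting to an in-memory database makes the sketch safe to run anywhere; a real script would pass the path configured in DATABASE_PATH.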
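4. Start the service

Start the application (the project tree lists run.py as the startup script), then open http://localhost:8000 in a browser.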
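💡 Core flow

Overall flowchart

```
User input
    ↓
Intent recognition ──→ Entity extraction
    ↓                        ↓
Knowledge retrieval ←────────┘
    ↓
SQL generation
    ↓
SQL validation ──→ validation failed ──→ regenerate (up to 3 times)
    ↓ validation passed
    ↓
Run query
    ↓
Result presentation
```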
State definition

```python
from typing import TypedDict, List, Dict, Optional


class TextToSqlState(TypedDict):
    # Input
    user_query: str

    # Intent recognition
    intent: Optional[str]
    intent_confidence: Optional[float]

    # Entity extraction
    entities: Optional[Dict]
    time_range: Optional[Dict]
    metrics: Optional[List[str]]
    dimensions: Optional[List[str]]
    filters: Optional[Dict]

    # Knowledge retrieval
    schema_info: Optional[Dict]
    similar_examples: Optional[List]
    field_mapping: Optional[Dict]

    # SQL generation
    generated_sql: Optional[str]
    sql_explanation: Optional[str]

    # SQL validation
    is_valid: bool
    validation_errors: Optional[List]
    retry_count: int

    # Query execution
    query_result: Optional[List[Dict]]
    result_count: Optional[int]

    # Final output
    final_response: Optional[str]
    visualization: Optional[Dict]
```
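📝 Node implementations

Node 1: Intent recognition

core/nodes/intent.py:

```python
"""
Intent recognition node.
Identifies the intent type of the user's query.
"""
import json
from typing import Dict

from ..llm import get_llm

INTENT_TYPES = {
    "aggregation": "Aggregation query (e.g. count, sum, average)",
    "comparison": "Comparison query (e.g. year-over-year, period-over-period)",
    "ranking": "Ranking query (e.g. TOP N, leaderboard)",
    "filtering": "Filtering query (e.g. find, filter)",
    "trend": "Trend query (e.g. growth, change over time)",
}


def intent_recognition_node(state: Dict) -> Dict:
    """
    Identify the intent of the user's query.

    Args:
        state: State containing user_query.

    Returns:
        The updated state, with intent and intent_confidence set.
    """
    user_query = state["user_query"]
    llm = get_llm()

    intent_list = "\n".join(f"- {k}: {v}" for k, v in INTENT_TYPES.items())
    prompt = f"""
You are an expert at recognizing the intent of SQL queries.

User query: {user_query}

Identify the query intent by choosing one of the following types:
{intent_list}

Return JSON in this format:
{{
    "intent": "intent type",
    "confidence": 0.95,
    "reasoning": "why this intent was chosen"
}}
"""

    response = llm.invoke(prompt)
    # get_llm() is not shown in this article. If it returns a chat model,
    # invoke() yields a message object whose text is in .content;
    # if it returns a plain string, that string is used directly.
    text = getattr(response, "content", response)

    try:
        result = json.loads(text)
        state["intent"] = result["intent"]
        state["intent_confidence"] = result["confidence"]
    except (json.JSONDecodeError, KeyError, TypeError):
        # Fall back to a sentinel intent when the reply is not usable JSON.
        state["intent"] = "unknown"
        state["intent_confidence"] = 0.0

    print(f"[Intent recognition] {state['intent']} (confidence: {state['intent_confidence']})")
    return state
```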
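The node above calls json.loads on the raw model reply, which fails whenever the model wraps its JSON in a Markdown fence or adds a sentence around it. A tolerant extraction helper could look like the sketch below; `extract_json_object` is a hypothetical name and not part of the project.

```python
import json
import re
from typing import Any, Dict, Optional

FENCE = "`" * 3  # a Markdown code fence, built indirectly to keep this listing clean
_FENCED_RE = re.compile(FENCE + r"(?:json)?\s*(.*?)" + FENCE, re.DOTALL | re.IGNORECASE)


def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Parse the JSON object embedded in an LLM reply, or return None."""
    match = _FENCED_RE.search(text)
    candidate = match.group(1) if match else text
    # Keep only the span from the first "{" to the last "}".
    start, end = candidate.find("{"), candidate.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        parsed = json.loads(candidate[start:end + 1])
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None
```

Returning None rather than raising lets the calling node keep its existing fallback to the "unknown" intent.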
Node 2: Entity extraction

core/nodes/entity.py:

```python
"""
Entity extraction node.
Extracts the key entities from the user's query.
"""
import json
from typing import Dict

from ..llm import get_llm


def entity_extraction_node(state: Dict) -> Dict:
    """
    Extract the entities in the query.

    Extracted items:
    - Time range: today, this month, 2023, ...
    - Metrics: sales, user count, growth rate, ...
    - Dimensions: region, department, product category, ...
    - Filters: age > 25, amount > 1000, ...
    """
    user_query = state["user_query"]
    llm = get_llm()

    prompt = f"""
Extract the key entities from the query.

Query: {user_query}

Extract the following information and return it as JSON:
{{
    "time_range": {{
        "type": "relative/absolute",
        "value": "2023-01-01 TO 2023-12-31"
    }},
    "metrics": ["sales", "order count"],
    "dimensions": ["region", "product category"],
    "filters": {{
        "age": ">25",
        "amount": ">1000"
    }}
}}

If an item is not present, use null.
"""

    response = llm.invoke(prompt)
    # Accept either a plain string or a message object with .content.
    text = getattr(response, "content", response)

    try:
        entities = json.loads(text)
        state["entities"] = entities
        state["time_range"] = entities.get("time_range")
        # The prompt asks for null when an item is missing,
        # so fall back to empty containers here.
        state["metrics"] = entities.get("metrics") or []
        state["dimensions"] = entities.get("dimensions") or []
        state["filters"] = entities.get("filters") or {}
    except (json.JSONDecodeError, AttributeError, TypeError):
        state["entities"] = {}

    print(f"[Entity extraction] metrics: {state.get('metrics')}, dimensions: {state.get('dimensions')}")
    return state
```
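The prompt above asks for absolute time ranges in the form "START TO END". Before such a value is used in a WHERE clause it is worth checking that it really parses. The sketch below does this for the absolute case only; `parse_absolute_range` is a hypothetical helper, and resolving relative ranges is left out.

```python
from datetime import date
from typing import Dict, Optional, Tuple


def parse_absolute_range(
    time_range: Optional[Dict[str, str]],
) -> Optional[Tuple[date, date]]:
    """Turn {"type": "absolute", "value": "A TO B"} into (start, end) dates.

    Returns None for relative ranges, missing values, or malformed input.
    """
    if not time_range or time_range.get("type") != "absolute":
        return None
    parts = (time_range.get("value") or "").split(" TO ")
    if len(parts) != 2:
        return None
    try:
        start, end = (date.fromisoformat(p.strip()) for p in parts)
    except ValueError:
        return None  # at least one side is not a valid ISO date
    return (start, end) if start <= end else None
```

For example, `parse_absolute_range({"type": "absolute", "value": "2023-01-01 TO 2023-12-31"})` yields a pair of date objects that can be bound as query parameters instead of being pasted into the SQL text.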
Node 3: Knowledge retrieval

core/nodes/knowledge.py:

```python
"""
Knowledge retrieval node.
Retrieves the database schema and similar SQL examples.
"""
import json
from typing import Dict


def knowledge_retrieval_node(state: Dict) -> Dict:
    """
    Retrieve from the knowledge base.

    Retrieved items:
    1. Table structure
    2. Field mapping (Chinese display name → English column name)
    3. Similar SQL examples
    """
    # Paths are relative to the working directory, so start the app from the project root.
    with open("knowledge/schema.json", "r", encoding="utf-8") as f:
        schema = json.load(f)

    with open("knowledge/entities.json", "r", encoding="utf-8") as f:
        field_mapping = json.load(f)

    with open("knowledge/examples.json", "r", encoding="utf-8") as f:
        examples = json.load(f)

    # "Similar" here simply means "same intent"; keep at most three examples.
    intent = state.get("intent", "")
    similar_examples = [
        ex for ex in examples
        if ex.get("intent") == intent
    ][:3]

    state["schema_info"] = schema
    state["field_mapping"] = field_mapping
    state["similar_examples"] = similar_examples

    print(f"[Knowledge retrieval] found {len(similar_examples)} similar examples")
    return state
```
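The node above picks the first three examples that share the query's intent, regardless of how close they are to the question. One possible refinement is to rank the candidates by word overlap with the query, as sketched below. The `question` key on each example is an assumption about the shape of examples.json, and the `\w+` tokenizer only suits whitespace-delimited languages; Chinese queries would need a proper word segmenter.

```python
import re
from typing import Dict, List, Set


def _tokens(text: str) -> Set[str]:
    """Lowercased word tokens of a string."""
    return set(re.findall(r"\w+", text.lower()))


def rank_examples(query: str, examples: List[Dict], intent: str, k: int = 3) -> List[Dict]:
    """Return up to k same-intent examples, most word overlap first."""
    query_tokens = _tokens(query)
    candidates = [ex for ex in examples if ex.get("intent") == intent]

    def overlap(ex: Dict) -> int:
        # "question" is an assumed field holding the example's natural-language text.
        return len(query_tokens & _tokens(ex.get("question", "")))

    # sorted() is stable, so ties keep their original file order.
    return sorted(candidates, key=overlap, reverse=True)[:k]
```

This keeps the intent filter from the original node and only changes which of the matching examples reach the prompt.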
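Node 4: SQL generation

core/nodes/sql_gen.py:

```python
"""
SQL generation node.
Generates SQL from the intent, entities, and knowledge base.
"""
from typing import Dict

from ..llm import get_llm


def sql_generation_node(state: Dict) -> Dict:
    """Generate the SQL query."""
    llm = get_llm()

    # Validation errors are only present when the validator routed back here,
    # so count the regeneration. Without this the retry limit is never reached.
    previous_errors = state.get("validation_errors") or []
    if previous_errors:
        state["retry_count"] = state.get("retry_count", 0) + 1

    prompt = f"""
You are a SQL generation expert. Generate a SQL query from the information below.

User query: {state['user_query']}

Intent: {state.get('intent')}

Entities:
- Metrics: {state.get('metrics')}
- Dimensions: {state.get('dimensions')}
- Time range: {state.get('time_range')}
- Filters: {state.get('filters')}

Database schema:
{state.get('schema_info')}

Field mapping:
{state.get('field_mapping')}

Similar SQL examples:
{state.get('similar_examples')}

Errors from the previous attempt (if any):
{previous_errors}

Requirements:
1. Generate standard SQL
2. Use the correct table and column names
3. Add appropriate comments
4. Return only the SQL, with no other explanation

SQL:
"""

    response = llm.invoke(prompt)
    # Accept either a plain string or a message object with .content.
    sql = getattr(response, "content", response).strip()

    # Strip a Markdown code fence if the model added one.
    if sql.startswith("```sql"):
        sql = sql[6:]
    if sql.endswith("```"):
        sql = sql[:-3]
    sql = sql.strip()

    state["generated_sql"] = sql
    print(f"[SQL generation]\n{sql}")
    return state
```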
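The fence stripping in this node only works when the reply starts and ends with the fence. Models sometimes add a sentence before or after the block, so a regular expression that finds the fenced block anywhere in the reply is a little more forgiving. This is a sketch; `extract_sql` is a hypothetical helper.

```python
import re

FENCE = "`" * 3  # a Markdown code fence, built indirectly to keep this listing clean
_SQL_BLOCK_RE = re.compile(
    FENCE + r"(?:sql)?\s*(.*?)" + FENCE,
    re.DOTALL | re.IGNORECASE,
)


def extract_sql(response_text: str) -> str:
    """Return the SQL inside the first fenced block, or the whole reply if unfenced."""
    match = _SQL_BLOCK_RE.search(response_text)
    sql = match.group(1) if match else response_text
    return sql.strip()
```

Only the first fenced block is used, which matches the prompt's request for a single SQL statement.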
Node 5: SQL validation

core/nodes/sql_validate.py:

```python
"""
SQL validation node.
Checks the SQL for correctness and safety.
"""
import re
from typing import Dict

import sqlparse

DANGEROUS_KEYWORDS = ["DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE"]


def sql_validation_node(state: Dict) -> Dict:
    """
    Validate the SQL.

    Checks:
    1. Syntax (whether the statement can be parsed at all)
    2. Safety (reject statements that modify data or schema)
    3. Whether table and column names exist (not implemented in this version)
    """
    sql = state.get("generated_sql") or ""
    errors = []

    # sqlparse is a non-validating parser, so this step mainly catches
    # empty input rather than real syntax errors.
    try:
        parsed = sqlparse.parse(sql)
        if not parsed:
            errors.append("SQL syntax error: could not parse")
    except Exception as e:
        errors.append(f"SQL syntax error: {e}")

    # Match whole words so a column such as updated_at does not trip the UPDATE check.
    for keyword in DANGEROUS_KEYWORDS:
        if re.search(rf"\b{keyword}\b", sql, re.IGNORECASE):
            errors.append(f"Safety check failed: contains dangerous keyword {keyword}")

    # The schema is fetched for the table/column existence check,
    # which is left unimplemented here.
    schema = state.get("schema_info") or {}
    tables = schema.get("tables", [])

    if errors:
        state["is_valid"] = False
        state["validation_errors"] = errors
        print(f"[SQL validation] ❌ validation failed: {errors}")
    else:
        state["is_valid"] = True
        state["validation_errors"] = []
        print("[SQL validation] ✅ validation passed")

    return state
```
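The docstring lists a third check, whether the referenced tables and columns exist, that the node does not perform. Since the project runs on SQLite, one way to approximate it is to let SQLite compile the statement with EXPLAIN QUERY PLAN, which reports unknown tables, unknown columns, and syntax errors without returning the query's rows. This is a sketch specific to SQLite, `dry_run_errors` is a hypothetical name, and it complements the keyword check rather than replacing it.

```python
import sqlite3
from typing import List


def dry_run_errors(conn: sqlite3.Connection, sql: str) -> List[str]:
    """Compile sql with EXPLAIN QUERY PLAN and return any error messages."""
    statement = sql.strip().rstrip(";")
    if not statement:
        return ["empty SQL statement"]
    try:
        conn.execute("EXPLAIN QUERY PLAN " + statement)
    except (sqlite3.Error, sqlite3.Warning) as exc:
        # sqlite3.Warning is caught too: older Python versions raise it
        # when the text contains more than one statement.
        return [f"SQLite rejected the statement: {exc}"]
    return []
```

The returned messages could be appended to validation_errors so that the next generation attempt sees exactly what SQLite objected to.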
Node 6: Result presentation

core/nodes/result.py:

```python
"""
Result presentation node.
Runs the SQL and formats the result.
"""
import sqlite3
from contextlib import closing
from typing import Dict


def result_presentation_node(state: Dict) -> Dict:
    """Run the SQL and present the result."""
    sql = state.get("generated_sql") or ""

    try:
        db_path = state.get("database_path", "database/sample_data.db")
        # closing() releases the connection even if the query raises.
        with closing(sqlite3.connect(db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(sql)
            rows = cursor.fetchall()
            columns = [desc[0] for desc in cursor.description]

        # Convert each row to a dict keyed by column name.
        results = [dict(zip(columns, row)) for row in rows]

        state["query_result"] = results
        state["result_count"] = len(results)
        state["final_response"] = format_response(state)

        print(f"[Result presentation] fetched {len(results)} rows")

    except Exception as e:
        state["final_response"] = f"Query execution failed: {e}"
        print(f"[Result presentation] ❌ error: {e}")

    return state


def format_response(state: Dict) -> str:
    """Format the reply text."""
    results = state.get("query_result") or []

    if not results:
        return "No matching data found."

    response = f"Found {len(results)} rows:\n\n"

    # Show only the first five rows in the text reply.
    for i, row in enumerate(results[:5], 1):
        response += f"{i}. "
        response += ", ".join(f"{k}: {v}" for k, v in row.items())
        response += "\n"

    if len(results) > 5:
        response += f"\n... and {len(results) - 5} more rows"

    return response
```
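Because the SQL comes from a model, it can help to add a second line of defense at execution time. The sketch below opens the SQLite file in read-only mode through a URI and caps the number of rows fetched. `open_read_only`, `run_select`, and the default of 100 rows are illustrative choices, not part of the project.

```python
import sqlite3
from contextlib import closing
from pathlib import Path
from typing import Any, Dict, List


def open_read_only(db_path: str) -> sqlite3.Connection:
    """Open an existing SQLite file so that any write attempt fails."""
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    return sqlite3.connect(uri, uri=True)


def run_select(conn: sqlite3.Connection, sql: str, max_rows: int = 100) -> List[Dict[str, Any]]:
    """Execute sql and return at most max_rows rows as dicts."""
    with closing(conn.cursor()) as cursor:
        cursor.execute(sql)
        if cursor.description is None:
            # The statement produced no result set.
            return []
        columns = [col[0] for col in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchmany(max_rows)]
```

With a read-only connection, a destructive statement that slips past the keyword check is rejected by SQLite itself instead of modifying the sample database.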
🔧 LangGraph workflow

core/graph.py:

```python
"""
TextToSql LangGraph workflow.
"""
from typing import Dict

from langgraph.graph import StateGraph, END

from .nodes.intent import intent_recognition_node
from .nodes.entity import entity_extraction_node
from .nodes.knowledge import knowledge_retrieval_node
from .nodes.sql_gen import sql_generation_node
from .nodes.sql_validate import sql_validation_node
from .nodes.result import result_presentation_node


def should_retry_sql(state: Dict) -> str:
    """Decide whether the SQL needs to be regenerated."""
    if state.get("is_valid", False):
        return "execute"
    elif state.get("retry_count", 0) < 3:
        return "regenerate"
    else:
        return "failed"


def create_text_to_sql_graph():
    """Create the TextToSql workflow."""
    # A plain dict is used as the state schema here; the TextToSqlState
    # TypedDict from the state definition documents the expected keys.
    graph = StateGraph(dict)

    # Register the nodes.
    graph.add_node("intent", intent_recognition_node)
    graph.add_node("entity", entity_extraction_node)
    graph.add_node("knowledge", knowledge_retrieval_node)
    graph.add_node("sql_gen", sql_generation_node)
    graph.add_node("validate", sql_validation_node)
    graph.add_node("result", result_presentation_node)

    # Linear part of the pipeline.
    graph.set_entry_point("intent")
    graph.add_edge("intent", "entity")
    graph.add_edge("entity", "knowledge")
    graph.add_edge("knowledge", "sql_gen")
    graph.add_edge("sql_gen", "validate")

    # After validation: run the query, regenerate, or give up.
    graph.add_conditional_edges(
        "validate",
        should_retry_sql,
        {
            "execute": "result",
            "regenerate": "sql_gen",
            "failed": END,
        },
    )

    graph.add_edge("result", END)

    return graph.compile()
```
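The conditional edge only terminates if something increments retry_count between attempts. The plain-Python stand-in below reproduces the sql_gen → validate cycle without LangGraph so the routing rule can be exercised in isolation. `run_retry_loop` and the `outcome` key are hypothetical and exist only for this illustration.

```python
from typing import Any, Callable, Dict

State = Dict[str, Any]
MAX_RETRIES = 3


def should_retry_sql(state: State) -> str:
    """Same decision rule as the router in core/graph.py."""
    if state.get("is_valid", False):
        return "execute"
    if state.get("retry_count", 0) < MAX_RETRIES:
        return "regenerate"
    return "failed"


def run_retry_loop(
    state: State,
    generate: Callable[[State], State],
    validate: Callable[[State], State],
) -> State:
    """Repeat generate -> validate until the router stops asking for a retry."""
    while True:
        state = validate(generate(state))
        decision = should_retry_sql(state)
        if decision != "regenerate":
            state["outcome"] = decision  # "execute" or "failed"
            return state
        # The counter must advance on every regeneration, or the loop never ends.
        state["retry_count"] = state.get("retry_count", 0) + 1
```

With a generator that never produces valid SQL, this loop makes one initial attempt plus three regenerations and then ends with the "failed" outcome, matching the "up to 3 times" limit in the flowchart.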
🌐 FastAPI interface

app/main.py:

```python
"""
FastAPI application.
"""
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel

from core.graph import create_text_to_sql_graph

app = FastAPI(title="Intelligent TextToSql System")

# Serve the static front end.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Build the workflow once at startup and reuse it for every request.
graph = create_text_to_sql_graph()


class QueryRequest(BaseModel):
    query: str


class QueryResponse(BaseModel):
    sql: str
    result: list
    explanation: str


@app.get("/")
async def read_root():
    """Return the web UI."""
    return FileResponse("static/index.html")


# Declared with def rather than async def: graph.invoke() is blocking,
# and FastAPI runs plain def endpoints in a worker thread.
@app.post("/api/query", response_model=QueryResponse)
def process_query(request: QueryRequest):
    """Process a text query and return the SQL and the result."""
    try:
        result = graph.invoke({
            "user_query": request.query,
            "retry_count": 0,
            "is_valid": False,
        })

        # "or" covers keys that are present but None as well as missing keys.
        return QueryResponse(
            sql=result.get("generated_sql") or "",
            result=result.get("query_result") or [],
            explanation=result.get("final_response") or "",
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    """Health check."""
    return {"status": "ok"}
```
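To call the /api/query endpoint from a script rather than the web page, a standard-library client is enough. The sketch below builds the POST request separately from sending it, so the request can be inspected without a running server. `build_query_request` and `ask` are hypothetical names, and the base URL assumes the default address from the quick start.

```python
import json
import urllib.request
from typing import Any, Dict

DEFAULT_BASE_URL = "http://localhost:8000"


def build_query_request(question: str, base_url: str = DEFAULT_BASE_URL) -> urllib.request.Request:
    """Build the POST request expected by /api/query."""
    payload = json.dumps({"query": question}).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/api/query",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def ask(question: str, base_url: str = DEFAULT_BASE_URL, timeout: float = 30.0) -> Dict[str, Any]:
    """Send the question and return the decoded JSON body (requires a running server)."""
    request = build_query_request(question, base_url)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

With the service running, `ask("count orders by region")` should return a dict with the `sql`, `result`, and `explanation` fields defined by QueryResponse.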
🎨 Web UI

static/index.html (simplified):

```html
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>Intelligent TextToSql System</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            max-width: 1200px;
            margin: 0 auto;
            padding: 20px;
        }
        .container {
            display: flex;
            flex-direction: column;
            gap: 20px;
        }
        .input-box {
            width: 100%;
            padding: 10px;
            font-size: 16px;
        }
        .btn {
            padding: 10px 20px;
            font-size: 16px;
            background: #007bff;
            color: white;
            border: none;
            cursor: pointer;
        }
        .result-box {
            background: #f5f5f5;
            padding: 15px;
            border-radius: 5px;
        }
        pre {
            background: #272822;
            color: #f8f8f2;
            padding: 15px;
            border-radius: 5px;
            overflow-x: auto;
        }
        /* Keep the line breaks of the plain-text reply. */
        #resultContent {
            white-space: pre-wrap;
        }
    </style>
</head>
<body>
    <h1>🤖 Intelligent TextToSql System</h1>

    <div class="container">
        <div>
            <input type="text" class="input-box" id="queryInput"
                   placeholder="Enter a query, e.g. top 10 products by sales this month">
            <button class="btn" onclick="submitQuery()">Query</button>
        </div>

        <div id="sqlBox" class="result-box" style="display:none">
            <h3>Generated SQL:</h3>
            <pre id="sqlContent"></pre>
        </div>

        <div id="resultBox" class="result-box" style="display:none">
            <h3>Query result:</h3>
            <div id="resultContent"></div>
        </div>
    </div>

    <script>
        async function submitQuery() {
            const query = document.getElementById('queryInput').value;
            const resultContent = document.getElementById('resultContent');

            const response = await fetch('/api/query', {
                method: 'POST',
                headers: {'Content-Type': 'application/json'},
                body: JSON.stringify({query: query})
            });
            const data = await response.json();

            document.getElementById('resultBox').style.display = 'block';
            if (!response.ok) {
                // FastAPI puts the error message in "detail".
                resultContent.textContent = data.detail || 'Request failed';
                return;
            }

            document.getElementById('sqlContent').textContent = data.sql;
            document.getElementById('sqlBox').style.display = 'block';

            // textContent, not innerHTML, so database values are never parsed as markup.
            resultContent.textContent = data.explanation;
        }
    </script>
</body>
</html>
```
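🧪 Test cases

Test a range of query scenarios:

```python
test_queries = [
    "Top 10 products by sales this month",
    "Compare the user growth rate this year with the same period last year",
    "Count orders by region",
    "Find active users older than 25",
    "Analyze the sales trend over the last 7 days",
]
```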
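The list above gives the queries but not how they are checked. One simple approach is a smoke test that invokes the compiled graph on each query and reports which runs ended without SQL or without a reply. The sketch below uses a stub in place of the real graph so it runs without an LLM; `smoke_check`, `REQUIRED_KEYS`, and `StubGraph` are hypothetical names.

```python
from typing import Any, Iterable, List

# Keys that a successful run is expected to fill in.
REQUIRED_KEYS = ("generated_sql", "final_response")


def smoke_check(graph: Any, queries: Iterable[str]) -> List[str]:
    """Invoke graph on each query and return human-readable failure messages."""
    failures = []
    for query in queries:
        try:
            result = graph.invoke(
                {"user_query": query, "retry_count": 0, "is_valid": False}
            )
        except Exception as exc:
            failures.append(f"{query!r}: raised {exc!r}")
            continue
        missing = [key for key in REQUIRED_KEYS if not result.get(key)]
        if missing:
            failures.append(f"{query!r}: missing {missing}")
    return failures


class StubGraph:
    """Stand-in for the compiled graph so the harness runs without an LLM."""

    def invoke(self, state):
        if "trend" in state["user_query"]:
            # Simulate a run that gave up before producing SQL.
            return {**state, "generated_sql": ""}
        return {**state, "generated_sql": "SELECT 1", "final_response": "1 row"}
```

In a real test module the stub would be replaced by the object returned from create_text_to_sql_graph(), and the test would assert that the returned list of failures is empty.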
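📚 Summary

This hands-on project brings together material from all of the previous lessons:

- ✅ Lesson 01-02: LangGraph basics and node definitions
- ✅ Lesson 03: Complex state management
- ✅ Lesson 04: Conditional routing and retry loops
- ✅ Lesson 05: Tool integration (database queries)
- ✅ Lesson 06: Knowledge retrieval (schema and examples)
- ✅ Lesson 07: Error handling and retry mechanisms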
Congratulations on completing the whole course! 🎉
📖 Further reading

- LangGraph official documentation
- TextToSql best practices
- FastAPI documentation