# vanna-ai 实现 text2SQL

# 简介

Vanna AI 是一款基于 MIT 许可的开源 Python RAG(检索增强生成)框架,专为通过自然语言生成精准的 SQL 查询而设计。简单来说,它能让用户用日常说话的方式向数据库提问,并自动得到答案。

Vanna 的核心工作流程分为两步:

  1. 训练阶段:基于用户的数据库结构(DDL)、业务文档和历史 SQL 查询,训练一个 RAG 模型。
  2. 推理阶段:用户输入自然语言问题,Vanna 自动生成对应的 SQL 查询,并可配置为直接在数据库上执行。

Vanna 并非 “即插即用” 的工具 —— 它需要先通过训练来认识你的表结构、理解字段含义和业务口径,才能准确生成 SQL。

# 核心特点

| 特点 | 说明 |
| --- | --- |
| 自然语言转 SQL | 用户无需掌握 SQL 语法,用日常语言即可完成复杂查询 |
| 支持任意 SQL 数据库 | 兼容 MySQL、PostgreSQL、Snowflake、SQLite 等主流数据库,也支持国产数据库,例如达梦 |
| 高度可扩展 | 可自由替换底层的大语言模型(LLM)和向量数据库组件 |
| 多种交互界面 | 支持 Jupyter Notebook、Streamlit、Flask、Slackbot 等多种前端 |
| RAG 架构优势 | 训练数据可动态更新,推理成本较传统微调方案降低 60% 以上 |

Vanna 2.0 版本已从简单的 SQL 生成库演进为生产就绪的、具备用户感知能力的智能数据分析代理框架,内置了 SQL 执行、数据可视化和 RAG 记忆等工具。

# 开始使用

# 核心代码

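"""
Vanna standard configuration + SiliconFlow + multiple databases + AMap (Gaode) map tool (production edition)
- No test tables/data: connects directly to the production database
- Enumerates all tables of the configured database at startup
- Documentation RAG: a single file may hold the descriptions of all tables (with effectiveness checks)
- Table-comment fallback: use database comments when present, documentation otherwise
- Full log tracing: shows whether RAG took effect and what content was injected
- Fixes the Flask server initialisation problem
- Compatible with older configparser versions
- New: AMap (Gaode) map tool (geocoding / reverse geocoding / POI keyword search)
"""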
from vanna import Agent
from vanna.core.registry import ToolRegistry
from vanna.core.user import UserResolver, User, RequestContext
from vanna.tools import RunSqlTool, VisualizeDataTool
from vanna.tools.agent_memory import SaveQuestionToolArgsTool, SearchSavedCorrectToolUsesTool, SaveTextMemoryTool
from vanna.servers.flask import VannaFlaskServer
from vanna.integrations.openai import OpenAILlmService
from vanna.integrations.sqlite import SqliteRunner
from vanna.integrations.local.agent_memory import DemoAgentMemory
from vanna.capabilities.sql_runner import SqlRunner, RunSqlToolArgs
from vanna.core.tool import ToolContext
from vanna.core.enhancer import LlmContextEnhancer, DefaultLlmContextEnhancer
from vanna.core.llm import LlmMessage
from vanna.core.agent import AgentConfig
import os
import logging
import configparser
import pandas as pd
import pymysql
import glob
import re
import json
from flask import Flask, jsonify
from sentence_transformers import SentenceTransformer
import numpy as np
import httpx
from typing import Type, Optional
from pydantic import BaseModel, Field
from vanna.core.tool import Tool, ToolResult
from vanna.components import UiComponent, NotificationComponent, ComponentType
import jwt  # JWT 认证所需
from jwt import InvalidTokenError  # JWT 异常处理
try:
    import dmPython
except ImportError:
    # Optional driver: DM (Dameng) support is disabled when it is missing.
    # The warning is emitted further down, once logging has been configured.
    dmPython = None
# ==================== Logging configuration ====================
# Configure logging BEFORE emitting any record.  Calling logging.warning()
# first would implicitly install a default root handler, after which
# basicConfig() silently does nothing (no log file, default level/format).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('vanna_rag_info.log', encoding='utf-8')
    ]
)
logger = logging.getLogger(__name__)
if dmPython is None:
    logger.warning("⚠️ 达梦数据库驱动未安装,DM数据库功能将不可用")
# Module-level debug state, overwritten by the enhancers / search service and
# read by the /debug/* endpoints.
LAST_ENHANCED_PROMPT = ""
LAST_MATCHED_TABLES = []
LAST_RAG_CONTENT = ""
# Agent configuration shared by main(); caps the tool-call iterations at 20.
_agent_config: AgentConfig = AgentConfig(
    max_tool_iterations=20
)
# ==================== 兼容工具 ====================
def get_config_value(config, section, option, default=None):
    """Read ``option`` from ``section`` of ``config``, falling back to ``default``.

    The fallback is used when either the section or the option does not
    exist; any other error raised by ``config.get`` propagates unchanged.
    """
    lookup_errors = (configparser.NoSectionError, configparser.NoOptionError)
    try:
        value = config.get(section, option)
    except lookup_errors:
        value = default
    return value
# ==================== 1. 文档 RAG 服务 ====================
class LocalDocSearchService:
    """In-memory semantic search over per-table documentation files.

    On construction the service reads every ``*.md`` / ``*.txt`` file found in
    ``doc_dir``, splits each file into per-table sections and embeds every
    section through the configured embeddings endpoint.  ``search`` ranks the
    sections against a query by cosine similarity.
    """

    def __init__(self, doc_dir: str = "./docs", api_key: str = ""):
        """Load the documents and build their embeddings.

        Args:
            doc_dir: Directory scanned for ``*.md`` and ``*.txt`` files; it is
                created when it does not exist.
            api_key: Bearer token for the embeddings endpoint.  With an empty
                key no embeddings are built and ``search`` returns ``[]``.
        """
        self.doc_dir = doc_dir
        self.api_key = api_key
        self.table_docs = []
        self.embeddings = []
        # Kept as an attribute for compatibility; the methods of this class
        # send their requests through the synchronous ``httpx.post``.
        self.client = httpx.AsyncClient(timeout=30.0)
        self.embedding_model = "BAAI/bge-m3"
        self.embedding_url = "https://api.siliconflow.cn/v1/embeddings"
        self._load_and_split_docs()
        self._build_embeddings_sync()

    def _load_and_split_docs(self):
        """Read all documentation files and fill ``self.table_docs``.

        A file that cannot be read or parsed is logged and skipped so that one
        bad document does not prevent the others from loading.
        """
        if not os.path.exists(self.doc_dir):
            os.makedirs(self.doc_dir)
            logger.warning(f"⚠️ 文档目录 {self.doc_dir} 不存在,已自动创建")
            return
        doc_patterns = [
            os.path.join(self.doc_dir, "*.md"),
            os.path.join(self.doc_dir, "*.txt")
        ]
        for pattern in doc_patterns:
            for file_path in glob.glob(pattern):
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        content = f.read().strip()
                    if not content:
                        logger.warning(f"⚠️ 文档 {file_path} 内容为空")
                        continue
                    table_sections = self._split_table_sections(content, file_path)
                    self.table_docs.extend(table_sections)
                    logger.info(f"✅ 加载 {file_path},识别 {len(table_sections)} 张表")
                except Exception as e:
                    logger.error(f"❌ 加载文档失败:{str(e)}", exc_info=True)
        if self.table_docs:
            table_names = [d['table_name'] for d in self.table_docs]
            logger.info(f"🔍 向量库表清单:{', '.join(table_names)}")
        else:
            logger.warning("⚠️ 未加载任何表说明")

    def _split_table_sections(self, content, file_path):
        """Split one document into per-table sections.

        A section starts at a heading of the form ``<name>表`` (optionally
        prefixed by ``## `` / ``### ``) followed by a colon, whitespace or a
        dash, and runs until the next such heading.  When no heading is found
        the whole document becomes a single generic section.

        Returns:
            A list of dicts with the keys ``table_name``, ``content`` and
            ``source_file``.  Headings with an empty body are dropped.
        """
        # The separator class accepts both the ASCII and the full-width colon.
        pattern = r'(?:## |### |^)([\u4e00-\u9fa5a-zA-Z0-9_]+表)[:\uff1a\s-]+'
        matches = list(re.finditer(pattern, content, re.MULTILINE))
        if not matches:
            return [{
                "table_name": "通用说明",
                "content": content,
                "source_file": file_path
            }]
        sections = []
        for i, match in enumerate(matches):
            # Drop only the trailing "表" suffix.  Replacing every "表" would
            # corrupt names that contain the character (e.g. "报表汇总表").
            table_name = match.group(1)[:-1].strip()
            start_idx = match.end()
            end_idx = matches[i + 1].start() if i + 1 < len(matches) else len(content)
            table_content = content[start_idx:end_idx].strip()
            if table_content:
                sections.append({
                    "table_name": table_name,
                    "content": f"表名:{table_name},说明:{table_content}",
                    "source_file": file_path
                })
        return sections

    def _get_embedding_sync(self, text: str):
        """Return the embedding of ``text`` as a NumPy array.

        Performs a blocking HTTP request.  A 1024-element zero vector is
        returned when no API key is configured or when the request fails
        (the failure is logged, not raised).
        """
        # NOTE(review): 1024 presumably matches the model's embedding width — confirm.
        if not self.api_key:
            return np.zeros(1024)
        try:
            resp = httpx.post(
                url=self.embedding_url,
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": self.embedding_model,
                    "input": text
                },
                timeout=20
            )
            resp.raise_for_status()
            data = resp.json()
            return np.array(data["data"][0]["embedding"])
        except Exception as e:
            logger.error(f"❌ 硅基流动向量生成失败:{e}")
            return np.zeros(1024)

    def _build_embeddings_sync(self):
        """Embed every loaded section; ``self.embeddings`` parallels ``self.table_docs``."""
        if not self.table_docs or not self.api_key:
            self.embeddings = []
            return
        logger.info("🚀 正在通过硅基流动生成表结构向量...")
        embs = [self._get_embedding_sync(doc["content"]) for doc in self.table_docs]
        self.embeddings = embs
        logger.info(f"✅ 完成 {len(embs)} 张表的向量构建(BGE-M3)")

    def _cos_sim(self, a, b):
        """Cosine similarity of two vectors; ``0`` when either has zero norm."""
        dot = np.dot(a, b)
        norm = np.linalg.norm(a) * np.linalg.norm(b)
        return dot / norm if norm != 0 else 0

    async def search(self, query: str, limit: int = 3):
        """Return the sections most similar to ``query``.

        Args:
            query: Natural-language question to match.
            limit: Maximum number of sections returned.

        Returns:
            A list of ``{"title", "content"}`` dicts ordered by descending
            similarity, restricted to similarity >= 0.2.  Empty when nothing
            is loaded, the query is empty, or no API key is configured.

        Side effects:
            Resets and then fills the module-level ``LAST_MATCHED_TABLES`` and
            ``LAST_RAG_CONTENT`` used by the debug endpoint.
        """
        global LAST_MATCHED_TABLES, LAST_RAG_CONTENT
        LAST_MATCHED_TABLES = []
        LAST_RAG_CONTENT = ""
        logger.debug(f"🔍 硅基向量检索:{query}")
        if not self.table_docs or not query or not self.api_key:
            return []
        # NOTE(review): this is a blocking HTTP call inside a coroutine.
        q_vec = self._get_embedding_sync(query)
        scored = [
            (self._cos_sim(q_vec, emb), doc)
            for doc, emb in zip(self.table_docs, self.embeddings)
        ]
        scored.sort(reverse=True, key=lambda x: x[0])
        threshold = 0.2
        relevant = [item for item in scored if item[0] >= threshold]
        if not relevant:
            logger.warning("⚠️ 未匹配到相关表")
            return []
        relevant_docs = []
        for sim, doc in relevant[:limit]:
            logger.info(f"✅ 匹配 {doc['table_name']} 相似度:{sim:.2f}")
            relevant_docs.append({
                "title": f"{doc['table_name']}表说明",
                "content": f"### 表名:{doc['table_name']}\n{doc['content']}"
            })
            LAST_MATCHED_TABLES.append(doc['table_name'])
        LAST_RAG_CONTENT = "\n---\n".join([
            f"{d['title']}\n{d['content'][:300]}" for d in relevant_docs
        ])
        return relevant_docs
# ==================== 2. 文档 RAG 增强器 ====================
class DocumentationEnhancer(LlmContextEnhancer):
    """Context enhancer that appends matching table documentation to the system prompt."""

    def __init__(self, doc_search_service):
        # Search backend; it must expose an awaitable ``search(query, limit=...)``
        # returning dicts with "title" and "content".
        self.doc_search_service = doc_search_service

    async def enhance_system_prompt(
        self,
        system_prompt: str,
        user_message: str,
        user: User
    ) -> str:
        """Return ``system_prompt`` extended with up to three matching documents.

        Each document body is cut to 1000 characters (with a trailing "...").
        The prompt is returned unchanged when the search yields nothing.
        """
        logger.debug("\n========== RAG增强开始 ==========")
        logger.debug(f"原始系统提示词长度:{len(system_prompt)}")
        logger.debug(f"用户问题:{user_message}")
        matches = await self.doc_search_service.search(user_message, limit=3)
        if not matches:
            logger.debug("❌ 无相关文档,跳过RAG增强")
            logger.debug("========== RAG增强结束 ==========\n")
            return system_prompt
        pieces = ["\n\n## Relevant Table Documentation (RAG)\n\n"]
        for position, entry in enumerate(matches, 1):
            body = entry["content"]
            if len(body) > 1000:
                body = body[:1000] + "..."
            pieces.append(f"### {position}. {entry['title']}\n")
            pieces.append(f"{body}\n\n")
        docs_section = "".join(pieces)
        enhanced_prompt = system_prompt + docs_section
        logger.info(f"\n========== 注入的RAG内容 ==========\n{docs_section}\n")
        logger.debug(f"增强后系统提示词长度:{len(enhanced_prompt)}")
        logger.debug("========== RAG增强结束 ==========\n")
        return enhanced_prompt

    async def enhance_user_messages(
        self,
        messages,
        user: User
    ):
        """Pass the user messages through untouched."""
        return messages
# ==================== 3. 表注释增强器 ====================
class SchemaCommentEnhancer(LlmContextEnhancer):
    """Context enhancer that injects table/column comments read from the database.

    Comments are loaded lazily from ``information_schema`` on first use and
    cached on the instance.  Only MySQL is supported; for any other database
    type the enhancer leaves the prompt unchanged.
    """

    def __init__(self, sql_runner, db_type, db_name=None):
        """
        Args:
            sql_runner: Runner exposing an awaitable ``run_sql(args, context)``
                whose result is iterated with ``iterrows()``.
            db_type: Database type string; comments are loaded only for "mysql".
            db_name: Schema whose tables are inspected.
        """
        self.sql_runner = sql_runner
        self.db_type = db_type
        self.db_name = db_name
        # None = not loaded yet; afterwards {"tables": [...], "columns": {...}}.
        self.schema_cache = None

    async def _load_schema_with_comments(self):
        """Fill ``self.schema_cache`` with the commented tables and columns."""
        if self.db_type != "mysql" or not self.db_name:
            # Cache an empty schema instead of None so the load is attempted
            # (and this warning logged) once, not again on every request.
            self.schema_cache = {"tables": [], "columns": {}}
            logger.warning("⚠️ 非MySQL数据库,跳过表注释加载")
            return
        # NOTE(review): db_name is interpolated into the SQL text; it comes
        # from the configuration file, not from request input.
        table_sql = f"""
            SELECT TABLE_NAME, TABLE_COMMENT 
            FROM information_schema.TABLES 
            WHERE TABLE_SCHEMA = '{self.db_name}' AND TABLE_TYPE = 'BASE TABLE'
        """
        df_tables = await self.sql_runner.run_sql(RunSqlToolArgs(sql=table_sql), None)
        column_sql = f"""
            SELECT TABLE_NAME, COLUMN_NAME, COLUMN_COMMENT, DATA_TYPE, IS_NULLABLE
            FROM information_schema.COLUMNS 
            WHERE TABLE_SCHEMA = '{self.db_name}'
        """
        df_columns = await self.sql_runner.run_sql(RunSqlToolArgs(sql=column_sql), None)
        # Build locally and publish at the end so a failure half-way does not
        # leave a partially filled cache behind.
        cache = {"tables": [], "columns": {}}
        for _, row in df_tables.iterrows():
            table_name = row["TABLE_NAME"]
            table_comment = row["TABLE_COMMENT"] if row["TABLE_COMMENT"] else ""
            # Only tables that carry a comment are registered.
            if table_comment and table_comment != "NULL":
                cache["tables"].append({
                    "name": table_name,
                    "comment": table_comment
                })
                cache["columns"][table_name] = []
                logger.debug(f"📌 加载表注释:{table_name} - {table_comment}")
        for _, row in df_columns.iterrows():
            table_name = row["TABLE_NAME"]
            col_name = row["COLUMN_NAME"]
            col_comment = row["COLUMN_COMMENT"] if row["COLUMN_COMMENT"] else ""
            # Columns are kept only for registered tables and only when commented.
            if table_name in cache["columns"] and col_comment and col_comment != "NULL":
                cache["columns"][table_name].append({
                    "name": col_name,
                    "comment": col_comment,
                    "type": row["DATA_TYPE"],
                    "nullable": row["IS_NULLABLE"]
                })
                logger.debug(f"   字段注释:{col_name} - {col_comment}")
        self.schema_cache = cache
        logger.info(f"✅ 共加载 {len(self.schema_cache['tables'])} 张有注释的表")

    def _find_relevant_tables(self, user_message):
        """Return the cached tables whose name occurs (case-insensitively) in the message."""
        if not self.schema_cache or not self.schema_cache["tables"]:
            return []
        message_lower = user_message.lower()
        relevant_tables = []
        for table in self.schema_cache["tables"]:
            if table["name"].lower() in message_lower:
                relevant_tables.append(table)
                logger.info(f"✅ 匹配到有注释的表:{table['name']}")
        if not relevant_tables:
            logger.warning("⚠️ 未匹配到有注释的表")
        return relevant_tables

    async def enhance_system_prompt(
        self,
        system_prompt: str,
        user_message: str,
        user: User
    ) -> str:
        """Append the comments of every table named in ``user_message``.

        Returns ``system_prompt`` unchanged when no commented table matches.
        When a section is appended, the result is also stored in the
        module-level ``LAST_ENHANCED_PROMPT``.
        """
        if self.schema_cache is None:
            await self._load_schema_with_comments()
        relevant_tables = self._find_relevant_tables(user_message)
        if not relevant_tables:
            return system_prompt
        schema_section = "\n\n## Relevant Database Schema (With Comments)\n\n"
        for table in relevant_tables:
            schema_section += f"### 表:{table['name']}\n"
            schema_section += f"- 表注释:{table['comment']}\n"
            schema_section += "- 字段信息:\n"
            for col in self.schema_cache["columns"][table["name"]]:
                schema_section += f"  - `{col['name']}` ({col['type']}):{col['comment']}(是否为空:{col['nullable']})\n"
            schema_section += "\n"
        logger.info(f"\n========== 注入的表注释内容 ==========\n{schema_section}\n")
        global LAST_ENHANCED_PROMPT
        LAST_ENHANCED_PROMPT = system_prompt + schema_section
        return system_prompt + schema_section

    async def enhance_user_messages(
        self,
        messages,
        user: User
    ):
        """Pass the user messages through untouched."""
        return messages
# ==================== 4. 组合增强器 ====================
class CombinedEnhancer(LlmContextEnhancer):
    """Chains several enhancers and appends the fixed business rules to the prompt."""

    def __init__(self, enhancers):
        # Enhancers are applied in the given order; each one receives the
        # output of the previous one.
        self.enhancers = enhancers
        # Rule text appended verbatim after all enhancers have run.
        self.BUSINESS_RULES = """
======== 【数据库业务规则 - 必须严格遵守】 ========
🚨 严格禁止的操作:
1. 禁止执行任何 DELETE 语句
2. 禁止执行任何 UPDATE 语句
3. 禁止执行任何 INSERT 语句
4. 禁止执行任何 DROP、ALTER、CREATE 等DDL语句
5. 禁止执行任何 TRUNCATE 语句
⚠️ 违规后果:如果尝试执行上述操作,必须立即停止并提示用户这是禁止的操作。
✅ 允许的操作:
- SELECT 查询
- SHOW 语句
- DESCRIBE 语句
- EXPLAIN 语句
返回结果不返回相关的sql语句
======== 【规则结束】 ========
"""

    async def enhance_system_prompt(
        self,
        system_prompt: str,
        user_message: str,
        user: User
    ) -> str:
        """Run every enhancer in turn, then append the business rules.

        The final prompt is also stored in the module-level
        ``LAST_ENHANCED_PROMPT`` for the debug endpoint.
        """
        global LAST_ENHANCED_PROMPT
        prompt = system_prompt

        logger.debug("\n========== 组合增强开始 ==========")
        logger.debug(f"初始提示词:{prompt[:200]}...")
        for step in self.enhancers:
            step_name = type(step).__name__
            logger.debug(f"\n🔧 执行增强器:{step_name}")
            prompt = await step.enhance_system_prompt(prompt, user_message, user)
            logger.debug(f"增强后长度:{len(prompt)}")
        prompt = prompt + "\n\n" + self.BUSINESS_RULES
        LAST_ENHANCED_PROMPT = prompt

        logger.info(f"\n========== 最终注入LLM的完整上下文(前2000字符) ==========\n{prompt[:2000]}\n")
        logger.debug("========== 组合增强结束 ==========\n")
        return prompt

    async def enhance_user_messages(
        self,
        messages,
        user: User
    ):
        """Feed the messages through every enhancer in order and return the result."""
        current = messages
        for step in self.enhancers:
            current = await step.enhance_user_messages(current, user)
        return current
# ==================== 最终版:高德地图 API 工具(稳定无报错) ====================
# ==================== 最终完美修复版:高德地图 API 工具 ====================
class MapToolArgs(BaseModel):
    """Arguments accepted by the map tool.

    The field descriptions are part of the schema returned by the tool's
    ``get_args_schema``, so they must describe the formats the AMap API
    actually expects.
    """
    query_type: str = Field(
        description="""地图操作类型,必须选其一:
        - geocode: 地址转换经纬度
        - regeocode: 经纬度转换地址
        - search: 周边搜索(餐厅、银行等)
        """
    )
    address: Optional[str] = Field(description="地址", default=None)
    # AMap expects longitude first, then latitude; the previous description
    # ("lat,lng") asked for the swapped order.
    location: Optional[str] = Field(description="经纬度 格式:lng,lat(经度在前,纬度在后)", default=None)
    keywords: Optional[str] = Field(description="搜索关键词", default=None)
    city: Optional[str] = Field(description="城市", default="北京")
class GaodeMapTool(Tool[MapToolArgs]):
    """Agent tool wrapping three AMap (Gaode) web-service endpoints.

    Supported ``query_type`` values: ``geocode`` (address -> coordinates),
    ``regeocode`` (coordinates -> address) and ``search`` (keyword POI search).
    """

    _BASE_URL = "https://restapi.amap.com/v3"
    # Legacy embedded key, used only when no key is supplied any other way.
    _LEGACY_KEY = "8880f90917273111d7ac1841f3d"

    def __init__(self, api_key: Optional[str] = None):
        """
        Args:
            api_key: AMap web-service key.  When omitted, the ``AMAP_API_KEY``
                environment variable is used, then the legacy embedded key.
        """
        self.key = api_key or os.getenv("AMAP_API_KEY") or self._LEGACY_KEY
        self.client = httpx.AsyncClient(timeout=10)

    @property
    def name(self) -> str:
        return "gaode_map_service"

    @property
    def description(self) -> str:
        return """高德地图服务,支持:
        1. 地址解析经纬度(geocode)
        2. 经纬度解析地址(regeocode)
        3. 关键词POI搜索(search)
        """

    def get_args_schema(self) -> Type[MapToolArgs]:
        return MapToolArgs

    async def _call_api(self, path: str, params: dict) -> dict:
        """GET ``path`` with the API key and return the decoded JSON body.

        Raises:
            httpx.HTTPStatusError: on a non-2xx HTTP response.
            RuntimeError: when the API reports a failure (``status`` != "1"),
                so that key/quota/parameter errors are not mistaken for an
                empty result.
        """
        query = {"key": self.key}
        query.update(params)
        query["output"] = "json"
        response = await self.client.get(f"{self._BASE_URL}/{path}", params=query)
        response.raise_for_status()
        data = response.json()
        if data.get("status") != "1":
            raise RuntimeError(
                f"AMap API error: {data.get('info', 'unknown')} (infocode={data.get('infocode')})"
            )
        return data

    async def execute(self, context: ToolContext, args: MapToolArgs) -> ToolResult:
        """Run the requested map query.

        Returns:
            A successful ``ToolResult`` carrying a text summary (a fixed
            "nothing found" message when the query produced no data or the
            required argument for ``query_type`` is missing), or a failed
            ``ToolResult`` when the request or the API call errors.
        """
        try:
            result_data = None
            text_result = ""
            # ---- 1. address -> coordinates ----
            if args.query_type == "geocode" and args.address:
                data = await self._call_api(
                    "geocode/geo",
                    {"address": args.address, "city": args.city},
                )
                if data.get("geocodes"):
                    geo = data["geocodes"][0]
                    text_result = f"地址:{geo['formatted_address']}\n经纬度:{geo['location']}"
                    result_data = {"geo": geo}
            # ---- 2. coordinates -> address ----
            elif args.query_type == "regeocode" and args.location:
                data = await self._call_api(
                    "geocode/regeo",
                    {"location": args.location},
                )
                if data.get("regeocode"):
                    addr = data["regeocode"]["formatted_address"]
                    text_result = f"经纬度:{args.location}\n解析地址:{addr}"
                    result_data = {"regeo": data["regeocode"]}
            # ---- 3. keyword POI search ----
            elif args.query_type == "search" and args.keywords:
                data = await self._call_api(
                    "place/text",
                    {"keywords": args.keywords, "city": args.city},
                )
                # Only the first five POIs are reported.
                pois = (data.get("pois") or [])[:5]
                names = [f"{p.get('name','未知')}({p.get('address','无地址')})" for p in pois]
                text_result = f"搜索:{args.keywords}({args.city})\n结果:\n" + "\n".join(names)
                # metadata must be a dict, so the list is wrapped.
                result_data = {"pois": pois, "count": len(pois)}
            # Nothing found (or no branch matched).
            if not text_result:
                text_result = "地图服务调用成功,但未查询到相关信息"
            if result_data is None:
                result_data = {}
            logger.info(f"✅ 高德地图调用成功:{args.query_type}")
            return ToolResult(
                success=True,
                result_for_llm=text_result,
                ui_component=UiComponent(
                    rich_component=NotificationComponent(
                        type=ComponentType.NOTIFICATION,
                        level="info",
                        message=text_result
                    )
                ),
                metadata=result_data
            )
        except Exception as e:
            # Tool boundary: report the failure to the agent instead of raising.
            logger.error(f"❌ 高德地图调用失败:{str(e)}", exc_info=True)
            return ToolResult(
                success=False,
                result_for_llm=f"地图服务调用失败:{str(e)}",
                error=str(e),
                metadata={}
            )
# JWT configuration.  In production the secret MUST come from the
# JWT_SECRET_KEY environment variable.
_DEFAULT_JWT_SECRET = "your-strong-secret-key-here"
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", _DEFAULT_JWT_SECRET)
JWT_ALGORITHM = "HS256"
if JWT_SECRET_KEY == _DEFAULT_JWT_SECRET:
    # The placeholder is public knowledge: tokens signed with it can be forged
    # by anyone, so make the fallback visible instead of silent.
    logging.getLogger(__name__).warning(
        "⚠️ 未设置环境变量 JWT_SECRET_KEY,正在使用默认占位密钥;启用JWT认证前必须更换为强密钥"
    )
# ==================== 自定义认证异常(核心改进) ====================
class AuthenticationError(Exception):
    """Authentication failure that carries the HTTP status code to respond with."""

    def __init__(self, message: str, status_code: int = 401):
        # Initialise the base exception first so str(exc) yields the message.
        super().__init__(message)
        self.message = message
        self.status_code = status_code
# ==================== JWT 用户认证解析器(核心修改) ====================
class JwtUserResolver(UserResolver):
    """Resolve the request user from a mandatory ``Authorization: Bearer <JWT>`` header.

    Every failure raises ``AuthenticationError``: 401 for missing or invalid
    credentials, 403 for a valid token whose subject is not permitted.
    """

    def __init__(self, secret_key: str, algorithm: str = "HS256"):
        """
        Args:
            secret_key: Key used to verify the token signature.
            algorithm: The only signing algorithm accepted.
        """
        self.secret_key = secret_key
        self.algorithm = algorithm

    async def resolve_user(self, request_context: RequestContext) -> User:
        """Validate the bearer token and build the ``User`` from its claims.

        Raises:
            AuthenticationError: 401 when the header is missing/malformed, the
                token fails verification, or the ``sub`` claim is absent;
                403 when ``sub`` is not ``"admin"``.
        """
        # 1. Read and validate the Authorization header.
        auth_header = request_context.get_header('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            raise AuthenticationError("拒绝访问:未提供有效的Authorization头", 401)

        # 2. Everything after the scheme is the token; strip() tolerates extra
        #    spaces, which a plain split(' ')[1] would turn into an empty token.
        token = auth_header[len('Bearer '):].strip()
        if not token:
            raise AuthenticationError("拒绝访问:未提供有效的Authorization头", 401)
        logger.info(f"收到JWT令牌(部分):{token[:20]}...")

        # 3. Verify signature and expiry.
        try:
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                options={"verify_exp": True}
            )
        except InvalidTokenError as e:
            # Expired, bad signature, malformed token, ...
            raise AuthenticationError(f"JWT令牌验证失败:{str(e)}", 401) from e

        # 4. A token without a subject is an invalid credential -> 401.
        if 'sub' not in payload:
            raise AuthenticationError("拒绝访问:JWT令牌缺少sub字段", 401)

        # 5. A valid identity that is not allowed in -> 403.
        #    Only the subject "admin" is accepted.
        if payload['sub'] != 'admin':
            raise AuthenticationError("拒绝访问:用户没有权限", 403)

        # 6. Build the authenticated user from the claims.
        user = User(
            id=payload['sub'],
            username=payload.get('username', payload['sub']),
            email=payload.get('email', ''),
            group_memberships=payload.get('groups', ['user']),  # defaults to the "user" group
            permissions=payload.get('permissions', [])
        )
        logger.info(f"用户认证成功:{user.username} (ID: {user.id})")
        return user
            
# ================= 不限制权限 ===============================
class SimpleUserResolver(UserResolver):
    """Resolver that performs no authentication: every request maps to one fixed user."""

    async def resolve_user(self, request_context):
        # The request is not inspected at all.
        fixed_identity = dict(
            id="prod_user",
            email="prod@example.com",
            group_memberships=['user'],
        )
        return User(**fixed_identity)
# ==================== 全局异常处理注册函数(兼容所有 Vanna 版本) ====================
def register_auth_error_handler(server: VannaFlaskServer):
    """Register JSON error handlers on the Flask app behind ``server``.

    The app is looked up as ``server.app``, ``server._app`` or
    ``server.get_app()``, in that order.  When none yields an app, a new
    Flask instance is created so the call never fails, and a warning is
    logged because handlers on that instance do not affect ``server``.

    Returns:
        The Flask app the handlers were registered on.
    """
    flask_app = None
    if hasattr(server, 'app'):
        flask_app = server.app
    elif hasattr(server, '_app'):
        flask_app = server._app
    elif hasattr(server, 'get_app') and callable(server.get_app):
        flask_app = server.get_app()

    if flask_app is None:
        logger.warning(
            "⚠️ 未能从VannaFlaskServer获取Flask app,异常处理器将注册到独立的Flask实例上,对主服务不生效"
        )
        flask_app = Flask(__name__)

    @flask_app.errorhandler(AuthenticationError)
    def handle_auth_error(e):
        """Turn an authentication failure into a JSON body with its status code."""
        response = {
            "success": False,
            "error": {
                "message": e.message,
                "status_code": e.status_code
            }
        }
        logger.error(f"认证失败:{e.message}")
        return jsonify(response), e.status_code

    @flask_app.errorhandler(Exception)
    def handle_generic_error(e):
        """Turn any other unhandled exception into a JSON 500 response."""
        # A handler registered for Exception also receives HTTP errors
        # (404, 405, ...).  Pass those through unchanged so they keep their
        # status code.  They are detected by duck typing (get_response + code)
        # to avoid importing werkzeug directly.
        if callable(getattr(e, "get_response", None)) and getattr(e, "code", None):
            return e
        response = {
            "success": False,
            "error": {
                "message": f"服务器内部错误:{str(e)}",
                "status_code": 500
            }
        }
        logger.error(f"通用异常:{str(e)}", exc_info=True)
        return jsonify(response), 500

    logger.info("✅ 认证异常处理器注册成功")
    return flask_app
# ==================== Debug endpoints ====================
# Flask application hosting the /debug/* routes defined below.
app = Flask(__name__)
@app.route('/debug/rag', methods=['GET'])
def debug_rag():
    """Report which tables the most recent RAG search matched and what was injected."""
    rag_text = LAST_RAG_CONTENT
    details = {
        "matched_tables": LAST_MATCHED_TABLES,
        "rag_content": rag_text,
        "has_rag_content": len(rag_text) > 0,
        "tips": "如果matched_tables为空,说明未匹配到任何表说明;如果rag_content为空,说明RAG未生效"
    }
    return jsonify({"status": "success", "data": details})
@app.route('/debug/prompt', methods=['GET'])
def debug_prompt():
    """Report the length and a preview (first 2000 characters) of the last enhanced prompt."""
    prompt_text = LAST_ENHANCED_PROMPT
    size = len(prompt_text)
    if size > 2000:
        preview = prompt_text[:2000] + "..."
    else:
        preview = prompt_text
    details = {
        "prompt_length": size,
        "prompt_preview": preview,
        "has_enhanced_content": size > 1000,
        "tips": "如果prompt_length很小,说明上下文增强未生效"
    }
    return jsonify({"status": "success", "data": details})
# ==================== 基础配置 & 执行器 ====================
def load_config(config_path="vanna_config.ini"):
    """Load the INI configuration, creating it with defaults on first run.

    When ``config_path`` exists it is parsed as UTF-8.  Otherwise a file with
    the default settings is written to that path and those defaults are
    returned.
    """
    parser = configparser.ConfigParser()
    if os.path.exists(config_path):
        parser.read(config_path, encoding='utf-8')
        return parser

    # First run: seed the DEFAULT section and persist it.
    defaults = {
        'DATABASE_TYPE': 'mysql',
        'SQLITE_PATH': './sales.db',
        'MYSQL_HOST': 'localhost',
        'MYSQL_PORT': '3306',
        'MYSQL_USER': 'root',
        'MYSQL_PASSWORD': '123456',
        'MYSQL_DATABASE': 'your_db',
        'DM_HOST': 'localhost',
        'DM_PORT': '5236',
        'DM_USER': 'SYSDBA',
        'DM_PASSWORD': 'SYSDBA',
        'DM_SCHEMA': 'SYSDBA',
        'LLM_MODEL': 'Qwen/Qwen3.5-122B-A10B',
        'LLM_BASE_URL': 'https://api.siliconflow.cn/v1',
        'FLASK_PORT': '5000',
        'DOC_DIR': './docs'
    }
    parser['DEFAULT'] = defaults
    with open(config_path, 'w', encoding='utf-8') as handle:
        parser.write(handle)
    return parser
class MySQLRunner(SqlRunner):
    """SQL runner that executes each statement on a fresh PyMySQL connection."""

    def __init__(self, host, port, user, password, database):
        """Store the connection parameters; ``port`` may be a string or an int."""
        self.connection_params = {
            'host': host, 'port': int(port), 'user': user,
            'password': password, 'database': database, 'charset': 'utf8mb4'
        }

    async def run_sql(self, args, context):
        """Execute ``args.sql`` and return the result as a DataFrame.

        Returns:
            A DataFrame of the fetched rows when the statement produces a
            result set; otherwise the transaction is committed and an empty
            DataFrame is returned.
        """
        sql = args.sql
        # Log the statement that is actually executed.  (A schema-qualified
        # rewrite used to be logged here although it was never run.)
        logger.info(f"📊 MySql: {sql[:150]}")
        # NOTE(review): these are blocking driver calls inside a coroutine.
        conn = pymysql.connect(**self.connection_params)
        try:
            with conn.cursor() as cursor:
                cursor.execute(sql)
                if cursor.description:
                    df = pd.DataFrame(cursor.fetchall(), columns=[d[0] for d in cursor.description])
                else:
                    conn.commit()
                    df = pd.DataFrame()
            return df
        finally:
            conn.close()

    def _qualify(self, s, database):
        """Prefix unqualified table names after FROM/JOIN with ``database``.

        Names that are already qualified (followed by a dot, such as
        ``information_schema.TABLES``) are left untouched.
        """
        # The trailing \b stops the regex from backtracking into a shorter
        # identifier in order to satisfy the "not followed by a dot" lookahead.
        s = re.sub(r"\bFROM\s+(\w+)\b(?!\s*\.)", f"FROM {database}.\\1", s, flags=re.I)
        s = re.sub(r"\bJOIN\s+(\w+)\b(?!\s*\.)", f"JOIN {database}.\\1", s, flags=re.I)
        return s
class DMRunner(SqlRunner):
    """SQL runner for DM (Dameng) databases, one dmPython connection per statement."""

    def __init__(self, host, port, user, password, schema):
        """Store the connection parameters.

        Raises:
            Exception: when the optional ``dmPython`` driver is not installed.
        """
        if not dmPython:
            raise Exception("请安装 dmPython")
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.schema = schema

    async def run_sql(self, args, context):
        """Execute ``args.sql`` and return the result as a DataFrame.

        Backticks are replaced with double quotes before execution.

        Returns:
            A DataFrame of the fetched rows when the statement produces a
            result set; otherwise the transaction is committed and an empty
            DataFrame is returned.
        """
        sql = args.sql.replace('`', '"')
        # Log the statement that is actually executed.  (A schema-qualified
        # rewrite used to be logged here although it was never run.)
        logger.info(f"📊 DM: {sql[:150]}")
        # The port read from the INI file is a string; convert it the same way
        # MySQLRunner does.
        # NOTE(review): these are blocking driver calls inside a coroutine.
        conn = dmPython.connect(
            user=self.user,
            password=self.password,
            server=self.host,
            port=int(self.port),
            schema=self.schema,
        )
        try:
            cursor = conn.cursor()
            cursor.execute(sql)
            if cursor.description:
                df = pd.DataFrame(cursor.fetchall(), columns=[d[0] for d in cursor.description])
            else:
                conn.commit()
                df = pd.DataFrame()
            return df
        finally:
            conn.close()

    def _qualify(self, s, schema):
        """Prefix unqualified table names after FROM/JOIN with ``schema``.

        Statements that reference the dictionary views matched below are
        returned unchanged, and names that are already qualified (followed by
        a dot) are left untouched.
        """
        if re.search(r"(ALL_TAB_COMMENTS|ALL_COL_COMMENTS|USER_TABLES|DBA_)", s, re.I):
            return s
        # The trailing \b stops the regex from backtracking into a shorter
        # identifier in order to satisfy the "not followed by a dot" lookahead.
        s = re.sub(r"\bFROM\s+(\w+)\b(?!\s*\.)", f"FROM {schema}.\\1", s, flags=re.I)
        s = re.sub(r"\bJOIN\s+(\w+)\b(?!\s*\.)", f"JOIN {schema}.\\1", s, flags=re.I)
        return s
async def train_all_tables(agent, sql_runner, db_type, db_name=None):
    """List the tables of the configured database and log them.

    As written, this only enumerates table names ("mysql" via
    information_schema, "dm" via USER_TABLES); ``agent`` is not used and
    nothing is stored.  Other database types are skipped with a warning.
    """
    logger.info("🚀 开始训练全库所有表...")
    if db_type not in ("mysql", "dm"):
        logger.warning("不支持自动训练,跳过")
        return
    if db_type == "mysql":
        listing_sql = f"SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = '{db_name}'"
    else:
        listing_sql = "SELECT TABLE_NAME FROM USER_TABLES"
    result = await sql_runner.run_sql(RunSqlToolArgs(sql=listing_sql), None)
    # The first column of the result holds the table names.
    tables = result.iloc[:, 0].tolist()
    logger.info(f"✅ 共找到 {len(tables)} 张表: {tables}")
# ==================== 主程序 ====================
def main():
    """Build the agent from the configuration file and start the web service.

    Steps: load the config, create the LLM service and the SQL runner for the
    configured database, choose the user resolver, register the tools, chain
    the context enhancers, enumerate the tables, then start the Flask server.
    When ``VannaFlaskServer`` cannot be built around the module-level debug
    ``app``, the main server and the debug app are started separately (debug
    app on port 5001).

    Raises:
        ValueError: when DATABASE_TYPE is not one of sqlite / mysql / dm.
    """
    import asyncio
    import threading

    config = load_config()
    port = int(get_config_value(config, 'DEFAULT', 'FLASK_PORT', 5000))

    # The banner is printed after loading the config so it can show the
    # configured port.
    print("\n" + "="*60)
    print(" Vanna 正式生产版 + 高德地图工具")
    print("="*60)
    print("🔧 调试接口:")
    print(f"   - http://localhost:{port}/debug/rag    查看RAG匹配情况")
    print(f"   - http://localhost:{port}/debug/prompt 查看完整上下文")
    print("🗺️  地图功能:地址解析/逆地理编码/POI检索")
    print("📝 日志文件:vanna_rag_info.log")
    print("="*60 + "\n")

    db_type = get_config_value(config, 'DEFAULT', 'DATABASE_TYPE', 'mysql').lower()
    db_name = get_config_value(config, 'DEFAULT', 'MYSQL_DATABASE') if db_type == 'mysql' else None
    doc_dir = get_config_value(config, 'DEFAULT', 'DOC_DIR', './docs')
    # Accept the usual truthy spellings; a case-sensitive compare with "true"
    # would silently disable authentication for "True" or "TRUE".
    auth_setting = str(get_config_value(config, 'DEFAULT', 'IS_AUTHENTICATION', 'false'))
    auth_enabled = auth_setting.strip().lower() in {"true", "1", "yes", "on"}
    # NOTE(review): the fallback key is embedded in source; prefer setting
    # the SILICONFLOW_API_KEY environment variable.
    api_key = os.getenv("SILICONFLOW_API_KEY", "sk-ntznt1111lonavfuhpvixycysjoesnis")
    llm_model = get_config_value(config, 'DEFAULT', 'LLM_MODEL', 'Qwen/Qwen3.5-122B-A10B')
    llm_base_url = get_config_value(config, 'DEFAULT', 'LLM_BASE_URL', 'https://api.siliconflow.cn/v1')
    llm = OpenAILlmService(
        model=llm_model,
        api_key=api_key,
        base_url=llm_base_url
    )

    if db_type == "sqlite":
        sqlite_path = get_config_value(config, 'DEFAULT', 'SQLITE_PATH', './sales.db')
        sql_runner = SqliteRunner(sqlite_path)
    elif db_type == "mysql":
        sql_runner = MySQLRunner(
            host=get_config_value(config, 'DEFAULT', 'MYSQL_HOST', 'localhost'),
            port=get_config_value(config, 'DEFAULT', 'MYSQL_PORT', 3306),
            user=get_config_value(config, 'DEFAULT', 'MYSQL_USER', 'root'),
            password=get_config_value(config, 'DEFAULT', 'MYSQL_PASSWORD', '123456'),
            database=db_name
        )
    elif db_type == "dm":
        sql_runner = DMRunner(
            host=get_config_value(config, 'DEFAULT', 'DM_HOST', 'localhost'),
            port=get_config_value(config, 'DEFAULT', 'DM_PORT', 5236),
            user=get_config_value(config, 'DEFAULT', 'DM_USER', 'SYSDBA'),
            password=get_config_value(config, 'DEFAULT', 'DM_PASSWORD', 'SYSDBA'),
            schema=get_config_value(config, 'DEFAULT', 'DM_SCHEMA', 'SYSDBA')
        )
    else:
        # Fail fast instead of registering tools around a None runner.
        raise ValueError(f"不支持的数据库类型: {db_type}(可选:sqlite / mysql / dm)")

    # User resolver: JWT when enabled, otherwise a fixed anonymous user.
    if auth_enabled:
        logger.info("配置JWT用户认证...")
        user_resolver = JwtUserResolver(
            secret_key=JWT_SECRET_KEY,
            algorithm=JWT_ALGORITHM
        )
    else:
        logger.info("配置普通用户证...")
        user_resolver = SimpleUserResolver()

    # Tool registration (includes the map tool).
    tools = ToolRegistry()
    tools.register_local_tool(RunSqlTool(sql_runner=sql_runner), access_groups=['admin', 'user'])
    tools.register_local_tool(VisualizeDataTool(), access_groups=['admin', 'user'])
    tools.register_local_tool(SaveQuestionToolArgsTool(), access_groups=['admin'])
    tools.register_local_tool(SearchSavedCorrectToolUsesTool(), access_groups=['admin', 'user'])
    tools.register_local_tool(SaveTextMemoryTool(), access_groups=['admin', 'user'])
    tools.register_local_tool(GaodeMapTool(), access_groups=['admin', 'user'])

    doc_search_service = LocalDocSearchService(doc_dir=doc_dir, api_key=api_key)
    doc_enhancer = DocumentationEnhancer(doc_search_service)
    schema_enhancer = SchemaCommentEnhancer(sql_runner, db_type, db_name)
    # One memory instance shared by the agent and the default enhancer.  With
    # two separate instances the enhancer would read a memory the agent never
    # writes to.
    agent_memory = DemoAgentMemory()
    combined_enhancer = CombinedEnhancer([
        schema_enhancer,
        doc_enhancer,
        DefaultLlmContextEnhancer(agent_memory)
    ])
    agent = Agent(
        llm_service=llm,
        tool_registry=tools,
        user_resolver=user_resolver,
        agent_memory=agent_memory,
        llm_context_enhancer=combined_enhancer,
        config=_agent_config
    )
    asyncio.run(train_all_tables(agent, sql_runner, db_type, db_name))

    # Only the construction is guarded.  Wrapping run() as well would turn
    # Ctrl+C or a runtime failure of the main server into a spurious start of
    # the fallback servers.
    try:
        server = VannaFlaskServer(agent, app=app)
    except Exception:
        logger.warning("⚠️ Vanna不支持自定义Flask app,将启动独立调试服务(端口5001)", exc_info=True)
        server = None

    if server is not None:
        server.run(host="0.0.0.0", port=port, debug=False)
        return

    def start_vanna_server():
        fallback_server = VannaFlaskServer(agent)
        # Handlers must be registered BEFORE run(): run() blocks, so anything
        # placed after it never executes while the server is up.
        register_auth_error_handler(fallback_server)
        fallback_server.run(host="0.0.0.0", port=port, debug=False)

    def start_debug_server():
        app.run(host="0.0.0.0", port=5001, debug=False, use_reloader=False)

    vanna_thread = threading.Thread(target=start_vanna_server, daemon=True)
    debug_thread = threading.Thread(target=start_debug_server, daemon=True)
    vanna_thread.start()
    debug_thread.start()

    print(f"\n🎉 Vanna主服务启动:http://localhost:{port}")
    print(f"🔧 调试服务启动:http://localhost:5001/debug/rag")
    print("="*60)
    # Keep the process alive for as long as the main server thread runs; both
    # threads are daemons, so the process ends when this returns.
    vanna_thread.join()
            
# ==================== 生成测试 JWT 令牌的辅助函数(可选) ====================
def generate_test_jwt():
    """Build, print and return a signed JWT for the ``admin`` test user.

    The claims place the user in the ``admin`` group and set the expiry to
    one hour after the moment of the call.  The token is signed with the
    module-level ``JWT_SECRET_KEY`` using ``JWT_ALGORITHM``.

    Returns:
        The encoded token exactly as returned by ``jwt.encode``.
    """
    from time import time as _now

    claims = dict(
        sub="admin",
        username="admin",
        email="admin@example.com",
        groups=["admin"],
        exp=int(_now()) + 3600,  # expires in one hour
    )
    encoded = jwt.encode(claims, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
    print("🔑 测试用有效JWT令牌:")
    print(encoded)
    return encoded
if __name__ == "__main__":
    # Optional: print a test JWT first (comment this call out to start the service directly).
    # NOTE(review): this prints an admin-group token on every launch — confirm
    # that is intended anywhere other than local testing.
    generate_test_jwt()
    main()

注意:请将 api-key 更换成自己的

可以通过配置项(见下方配置文件中的 IS_AUTHENTICATION)控制是否启用 JWT 认证校验。

可以自己注册工具,详细看:

tools.register_local_tool(GaodeMapTool(), access_groups=['admin', 'user'])

配置文件

[DEFAULT]
DATABASE_TYPE = dm
SQLITE_PATH = ./sales.db
MYSQL_HOST = host.docker.internal
MYSQL_PORT = 3306
MYSQL_USER = root
MYSQL_PASSWORD = 123456
MYSQL_DATABASE = your_database
DM_HOST = 192.168.0.102
DM_PORT = 5236
DM_USER = ZCJY_HPQ
DM_PASSWORD = your_dm_password
DM_SCHEMA = ZCJY_HPQ
LLM_MODEL = Qwen/Qwen3.5-122B-A10B
LLM_BASE_URL = https://api.siliconflow.cn/v1
FLASK_PORT = 5000
DOC_DIR = ./docs
IS_AUTHENTICATION = true

# 对应配置文件

docs 目录文件

# 业务规则文件

# 核心业务规则手册
   说明:本手册统一维护数据库查询相关业务规则,AI将严格遵守;后续新增规则,按对应模块格式追加即可,无需修改代码。
   ## 一、地区编码层级规则
   ### 1.1 编码规则
   dist 表存储地区编码,编码长度对应层级,具体规则如下:
   - 4位编码:区级(示例:0102)
   - 6位编码:街道/镇级(示例:010207 → 对应九佛街)
   - 9位编码:村级(示例:010207001)
   - 12位编码:账套/单位级(存储于zt表,属于村级下级)
   ### 1.2 查询规范(必须严格遵守)
      1. 查询某地区下的数据时,必须先从 dist 表查询对应地区名称的 distNo;
      2. 获取 distNo 后,使用 LIKE 前缀匹配查询该地区及下级所有数据(示例:distNo LIKE '010207%');
      3. 禁止直接使用「distName = 'xxx'」的条件查询 ht_contract 合同表。
   ## 二、表关联关系规则
   ### 2.1 核心表关联
   - ht_contract(合同表) ←→ dist(地区表):通过 distNo 字段关联
   - ht_contract(合同表) ←→ zt(账套/单位表):通过 distNo 字段关联
   - zt(账套/单位表) ←→ dist(地区表):通过 distNo 字段关联
   ### 2.2 关联查询要求
   查询某地区下的合同、账套数据时,必须先查询 dist 表获取对应 distNo,再通过 distNo 模糊匹配关联其他表,禁止单表直接查询。
   
   ## 三、字段规则
   ht_land.zcdl:土地资产类型,枚举值及查询要求如下:
   - 资源性资产类型
   - 物业资产类型
   - 其他资产类型
   - 实物资产类型
     查询某类资产(如物业资产)时,需添加条件:zcdl = '物业资产类型'
   ## 四、统计口径规则
格式:统计场景 + 统计规则 + 示例SQL
### 示例1:合同数量统计口径(核心场景)
统计场景:查询某地区(如九佛街)生效合同总数量
统计规则:
- 仅统计 ht_contract 表中 status = 1(生效)的合同;
- 需关联 dist 表获取地区编码,用 distNo LIKE 前缀匹配;
- 排除 status = 0(草稿)、2(终止)、9(删除)的合同;
- 统计结果需标注统计地区、统计时间范围。
示例SQL:
-- 第一步:获取九佛街地区编码
SELECT distno FROM zcjy_hpq.dist WHERE distname = '九佛街';
-- 第二步:按口径统计生效合同数量
SELECT COUNT(*) as valid_contract_count, '九佛街' as 统计地区, CURDATE() as 统计日期
FROM zcjy_hpq.ht_contract
WHERE distNo LIKE '查询到的distno%' AND status = 1;
### 示例2:土地资产面积统计口径
统计场景:查询某地区物业资产类型的土地总面积
统计规则:
- 仅统计 ht_land 表中 zcdl = '物业资产类型' 的土地数据;
- 关联 dist 表获取地区编码,仅统计该地区及下级土地;
- 面积字段取 area 字段,单位为“平方米”,保留2位小数;
- 排除 area 为空、0或负数的无效数据。
示例SQL:
-- 第一步:获取九佛街地区编码
SELECT distno FROM zcjy_hpq.dist WHERE distname = '九佛街';
-- 第二步:按口径统计物业资产土地总面积
SELECT ROUND(SUM(area), 2) as total_property_land_area, '平方米' as 单位, '九佛街' as 统计地区
FROM zcjy_hpq.ht_land
JOIN zcjy_hpq.dist ON ht_land.distNo LIKE CONCAT(dist.distno, '%')
WHERE dist.distname = '九佛街' AND ht_land.zcdl = '物业资产类型' AND ht_land.area > 0;
### 示例3:多表关联统计口径(合同+土地关联)
统计场景:查询某地区物业资产对应的生效合同总金额
统计规则:
- 关联 ht_contract(合同表)、ht_land(土地表)、dist(地区表);
- 合同需满足 status = 1(生效),土地需满足 zcdl = '物业资产类型';
- 金额取 ht_contract.Amount_tot(总标的)字段,汇总后保留2位小数;
- 按地区编码模糊匹配,确保统计范围包含该地区所有下级。
示例SQL:
SELECT ROUND(SUM(c.Amount_tot), 2) as total_contract_amount, '元' as 单位, d.distname as 统计地区
FROM zcjy_hpq.ht_contract c
JOIN zcjy_hpq.ht_land l ON c.distNo = l.distNo
JOIN zcjy_hpq.dist d ON c.distNo LIKE CONCAT(d.distno, '%')
WHERE d.distname = '九佛街' AND c.status = 1 AND l.zcdl = '物业资产类型';
## 五、预留新增规则模块(后续添加用)
### 5.1 新增表关联关系
 格式:【表A名称】(表A用途) ←→ 【表B名称】(表B用途):通过 【关联字段1】、【关联字段2】 关联
 示例:prj_project(项目表) ←→ ht_contract(合同表):通过 prj_no 字段关联
### 5.2 新增字段规则
 格式:【表名】.【字段名】:字段含义、枚举值(如有)、查询要求
 示例:ht_contract.status:合同状态,0=草稿、1=生效、2=终止、9=删除;查询有效合同时需加条件 status = 1
### 5.3 新增统计口径规则
 格式:统计场景 + 统计规则 + 示例SQL
### 5.4 新增权限/数据过滤规则
 格式:规则说明 + 强制要求 + 禁止操作

# 表说明文件

## dist表
表说明:地区数据表
字段说明:
- id:int,主键
- distno:varchar(50),地区编号
- distname:varchar(100),地区名称
- parentDistNo:varchar(50),所属地区编号
- parentDistName:varchar(100),所属地区名称
- distInfo:varchar(100),地区信息
## zt表
表说明:账套/单位信息表,存储各单位(经联社、经济社等)的基本信息、财务信息及系统配置
字段说明:
- ztid:varchar,账套编号
- ztName:varchar,账套全称
- ztJc:varchar,账套简称
- distNo:varchar,地区编号
- distName:varchar,地区名称
- ztTypeId:int,账套类型(1是经济社,2是经联社,3是公司)
- ztTypeName:varchar,账套类型(如:经联社、经济社、公司)
- sqdwdbr:varchar,单位代表人
- sqdwdz:varchar,单位地址
- tel:varchar,联系电话
- zmsbh:varchar,证明书编号
- zbr:varchar,制表人
- dwfzr:varchar,单位负责人
- kjxm:varchar,会计姓名
- cwfzr:varchar,财务负责人
- cnxm:varchar,出纳姓名
- sp_head:varchar,业务审批号前缀
- cno_head:varchar,合同编号前缀
- jdxz:varchar,监督小组
- BS_ztNo:varchar,BS财务账套号
- BS_dbName:varchar,BS财务数据库
- cj_amount1:decimal,财监报警额度
- cj_amount2:decimal,使用现金超额限度
- cj_amount3:decimal,日超额提现额度
- zdkxNo:varchar,第三方资产对接编号
- sqdwdbrID:varchar,法人身份证号码
## ht_land表
表说明:资产主表,存储所有类型的资产信息(资源、物业、实物等)
字段说明:
- landName:varchar,资产名称
- landNo:varchar,资产编号
- distNo:varchar,地区编号
- distName:varchar,地区名称
- ztName:varchar,单位名称
- ztJc:varchar,单位简称
- address:varchar,地址(土名)
- area:decimal,面积(数量)
- unit:varchar,面积单位
- zcsl:decimal,出租宗数
- jykflx:varchar,经营开发类型
- description:varchar,摘要(四至描述)
- zyxx:varchar,所有权属性
- yt:varchar,资产用途
- detailUse:varchar,详细用途
- cqxz:varchar,产权性质
- zyfxm:varchar,争议相对方姓名
- fszch:varchar,附属资源号
- zdxlh:varchar,宗地系列号
- fcsyr:varchar,房产所有权人
- fcsyqzh:varchar,房产所有权证号
- dyqr:varchar,抵押权人
- dytimes:decimal,抵押期限(年)
- zclb:varchar,资产性质
- longitude:varchar,经度
- latitude:varchar,纬度
- writeDate:date,录入日期
- writer:varchar,录入人
- landinfo:varchar,资产信息
- f1:varchar,资产图片
- enddate:date,终止日期
- enduser:varchar,终止人
- sfqq:bit,是否确权
- tdsuoyouzh:varchar,土地所有权证号
- tdsuoyour:varchar,土地所有权人
- dh:varchar,地号
- th:varchar,图号
- jyq_tdzmj:decimal,土地总面积(公顷)
- jyq_nyd:decimal,农用地(公顷)
- jyq_gd:decimal,耕地(公顷)
- jyq_yd:decimal,园地(公顷)
- jyq_ld:decimal,林地(公顷)
- jyq_mcd:decimal,牧草地(公顷)
- jyq_qt:decimal,其它(公顷)
- jyq_jsyd:decimal,建设用地(公顷)
- jyq_wlyd:decimal,未利用地(公顷)
- zcdlx:varchar,资产类别
- zcdl:varchar,资产大类
- landType:varchar,资产类型
- distNo_2:varchar,地区编号
- distName_2:varchar,地区名称
- ztName_2:varchar,单位名称
- distNo_3:varchar,地区编号
- distName_3:varchar,地区名称
- ztName_3:varchar,单位名称
- distNo_4:varchar,地区编号
- distName_4:varchar,地区名称
- ztName_4:varchar,单位名称
- distNo_5:varchar,地区编号
- distName_5:varchar,地区名称
- ztName_5:varchar,单位名称
## ht_contract表
表说明:合同主表,存储所有类型的合同信息(承包、租赁、工程、采购、借款等)
字段说明:
- cno:varchar,合同号
- htdl:varchar,合同大类(如:承包合同、租赁合同、工程合同、借款合同)
- contractType:varchar,合同类型
- prj_no:varchar,立项号
- prj_jiaoyi_no:varchar,交易号
- distNo:varchar,地区编号
- distName:varchar,地区名称
- ztName:varchar,单位名称
- ztJc:varchar,单位简称
- owner:varchar,合同甲方
- sqdwdbr:varchar,单位代表人
- sqdwdbrTel:varchar,代表人联系电话
- customer:varchar,合同乙方
- lxr:varchar,乙方联系人
- tel:varchar,乙方联系方式
- address:varchar,乙方地址
- area:decimal,面积/数量
- unit:varchar,面积/数量单位
- price:decimal,单价
- priceUnit:varchar,价格单位
- yearamount:decimal,首年金额(元)
- Amount_tot:decimal,总标的(元)
- years:decimal,年限(年)
- years_y:decimal,年限(年)
- years_m:decimal,年限(月)
- years_d:decimal,年限(天)
- bdate:date,开始日期
- edate:date,结束日期
- setdate:date,签订日期
- enddate:date,终止日期
- amountStartDate:date,计租/计提开始日期
- zjsqzq:varchar,租金/收款/付款周期
- bbznx:decimal,不变租年限
- bzzq:decimal,每几年变租一次
- bzfs:varchar,升租方式(比例/固定值)
- bzbl:decimal,升租比例(%)/升租金额
- cashPledge:decimal,合同履约金(押金)
- description:text,合同描述
- writer:varchar,录入人
- writeDate:date,录入日期
- f1:varchar,合同附件

# docker 部署

# 大体目录结构

# Dockerfile

# Dockerfile
FROM registry.cn-guangzhou.aliyuncs.com/tzting/python:3.11.14-slim
# Set the working directory
WORKDIR /app
# Environment: unbuffered Python output, container time zone, pip cache disabled
ENV PYTHONUNBUFFERED=1 \
    TZ=Asia/Shanghai \
    PIP_NO_CACHE_DIR=1
# Install system dependencies (build toolchain, curl for the health check,
# ODBC/FreeTDS, MySQL and PostgreSQL client libraries), then set the time zone
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    g++ \
    make \
    libffi-dev \
    libssl-dev \
    curl \
    # ODBC core dependencies
    unixodbc \
    unixodbc-dev \
    freetds-dev \
    freetds-bin \
    tdsodbc \
    # MySQL support
    default-libmysqlclient-dev \
    # PostgreSQL support
    libpq-dev \
    && rm -rf /var/lib/apt/lists/* \
    && ln -snf /usr/share/zoneinfo/$TZ /etc/localtime \
    && echo $TZ > /etc/timezone
# Register the FreeTDS driver with ODBC
# NOTE(review): the driver path below is the x86_64 library directory —
# confirm it if this image is ever built for another architecture.
RUN echo '[FreeTDS]' > /etc/odbcinst.ini && \
    echo 'Description = FreeTDS Driver' >> /etc/odbcinst.ini && \
    echo 'Driver = /usr/lib/x86_64-linux-gnu/odbc/libtdsodbc.so' >> /etc/odbcinst.ini
# Copy the dependency list
COPY requirements.txt .
# Install Python dependencies (via the Tsinghua PyPI mirror for faster downloads)
RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
# Copy the project files
COPY . .
# Create the directories the application needs
RUN mkdir -p /app/docs /app/logs
# Expose the main (5000) and debug (5001) ports
EXPOSE 5000 5001
# Health check: probe /health on port 5000 every 30s
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:5000/health || exit 1
# Start command
CMD ["python", "vanna_app.py"]

# requirements.txt

# requirements.txt
# 基础依赖
annotated-types==0.7.0
anyio==4.12.1
certifi==2026.2.25
charset-normalizer==3.4.4
click==8.3.1
idna==3.11
packaging==25.0
pip==26.0.1
setuptools==80.10.2
six==1.17.0
typing_extensions==4.15.0
typing-inspection==0.4.2
wheel==0.46.3
# Web框架
Flask==3.1.3
flask-cors==6.0.2
Werkzeug==3.1.6
itsdangerous==2.2.0
Jinja2==3.1.6
MarkupSafe==3.0.3
blinker==1.9.0
# FastAPI相关(如果需要)
fastapi==0.135.1
starlette==0.52.1
pydantic==2.12.5
pydantic_core==2.41.5
# Vanna核心
vanna==2.0.0
faiss-cpu==1.13.2
docstring_parser==0.17.0
tabulate==0.10.0
tqdm==4.67.3
PyYAML==6.0.3
python-dotenv==1.2.2
sqlparse==0.5.5
# LLM集成
openai==2.26.0
anthropic==0.84.0
httpx==0.28.1
httpcore==1.0.9
distro==1.9.0
h11==0.16.0
jiter==0.13.0
sniffio==1.3.1
# 数据库驱动
PyMySQL==1.1.2
psycopg2-binary==2.9.11
pymssql==2.3.13
pyodbc==5.3.0
dmPython==2.5.30
SQLAlchemy==2.0.48
# 数据处理
pandas==2.3.3
numpy==2.4.2
matplotlib==3.10.8
seaborn==0.13.2
plotly==6.6.0
pillow==12.1.1
pyarrow==23.0.1
db-dtypes==1.5.0
greenlet==3.3.2
# 数据可视化
contourpy==1.3.3
cycler==0.12.1
fonttools==4.61.1
kiwisolver==1.4.9
pyparsing==3.3.2
# 工具库
python-dateutil==2.9.0.post0
pytz==2026.1.post1
tzdata==2025.3
urllib3==2.6.3
requests==2.32.5
cryptography==46.0.5
cffi==2.0.0
pycparser==3.0
annotated-doc==0.0.4
narwhals==2.17.0
pyjwt==2.12.1
pyodbc
# 语义搜索和向量化
sentence-transformers==3.4.1

# 查看日志脚本

#!/bin/bash
# docker-logs.sh - follow the container log, starting from the last 100 lines
readonly CONTAINER_NAME="vanna-prod"
echo "📋 查看容器日志 (Ctrl+C 退出)"
docker logs -f --tail 100 "$CONTAINER_NAME"

# 启动脚本

#!/bin/bash
# docker-run.sh - build the image and deploy the container with plain Docker commands
set -e
# ANSI colours used for status output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
echo -e "${GREEN}=========================================${NC}"
echo -e "${GREEN}Vanna Docker 部署脚本(纯Docker命令)${NC}"
echo -e "${GREEN}=========================================${NC}"
# Deployment parameters
IMAGE_NAME="vanna-app"
CONTAINER_NAME="vanna-prod"
NETWORK_NAME="vanna-network"
PORT_MAIN=5000
PORT_DEBUG=5001
# The API key must be exported by the caller; it is forwarded to the container below
if [ -z "$SILICONFLOW_API_KEY" ]; then
    echo -e "${RED}❌ 错误: 未设置 SILICONFLOW_API_KEY 环境变量${NC}"
    echo -e "${YELLOW}请执行: export SILICONFLOW_API_KEY='sk-your-api-key'${NC}"
    exit 1
fi
# Create the host directories that are bind-mounted into the container
echo -e "${YELLOW}📁 创建必要目录...${NC}"
mkdir -p docs logs
echo -e "${GREEN}✅ 目录创建完成${NC}"
# Write a default config file, but never overwrite an existing one
if [ ! -f vanna_config.ini ]; then
    echo -e "${YELLOW}📝 创建默认配置文件 vanna_config.ini...${NC}"
    cat > vanna_config.ini << 'EOF'
[DEFAULT]
DATABASE_TYPE = mysql
SQLITE_PATH = ./sales.db
MYSQL_HOST = host.docker.internal
MYSQL_PORT = 3306
MYSQL_USER = root
MYSQL_PASSWORD = 123456
MYSQL_DATABASE = your_database
DM_HOST = localhost
DM_PORT = 5236
DM_USER = SYSDBA
DM_PASSWORD = SYSDBA
DM_SCHEMA = SYSDBA
LLM_MODEL = Qwen/Qwen3.5-122B-A10B
LLM_BASE_URL = https://api.siliconflow.cn/v1
FLASK_PORT = 5000
DOC_DIR = ./docs
EOF
    echo -e "${GREEN}✅ 配置文件创建完成${NC}"
fi
# Stop and remove a previous container with the same name, if any
if [ -n "$(docker ps -aq -f "name=$CONTAINER_NAME")" ]; then
    echo -e "${YELLOW}🛑 停止并删除已存在的容器...${NC}"
    docker stop "$CONTAINER_NAME" 2>/dev/null || true
    docker rm "$CONTAINER_NAME" 2>/dev/null || true
    echo -e "${GREEN}✅ 旧容器已清理${NC}"
fi
# Create the Docker network if it does not exist yet
if ! docker network inspect "$NETWORK_NAME" >/dev/null 2>&1; then
    echo -e "${YELLOW}🌐 创建Docker网络...${NC}"
    docker network create "$NETWORK_NAME"
    echo -e "${GREEN}✅ 网络创建完成${NC}"
fi
# Build the image
echo -e "${YELLOW}🐳 构建Docker镜像...${NC}"
docker build -t "$IMAGE_NAME" .
echo -e "${GREEN}✅ 镜像构建完成${NC}"
# Start the container.
# - Bind-mount sources are quoted so a working directory containing spaces
#   does not split the -v argument.
# - host.docker.internal is mapped to the host gateway so that the default
#   MYSQL_HOST written above also resolves on a plain Linux Docker engine
#   (the host-gateway value requires Docker 20.10 or newer).
echo -e "${YELLOW}🚀 启动容器...${NC}"
docker run -d \
    --name "$CONTAINER_NAME" \
    --network "$NETWORK_NAME" \
    --add-host host.docker.internal:host-gateway \
    -p "$PORT_MAIN:5000" \
    -p "$PORT_DEBUG:5001" \
    -e SILICONFLOW_API_KEY="$SILICONFLOW_API_KEY" \
    -e TZ=Asia/Shanghai \
    -v "$(pwd)/vanna_config.ini:/app/vanna_config.ini:ro" \
    -v "$(pwd)/docs:/app/docs" \
    -v "$(pwd)/logs:/app/logs" \
    --restart unless-stopped \
    "$IMAGE_NAME"
# Give the container a moment to start
echo -e "${YELLOW}⏳ 等待容器启动...${NC}"
sleep 5
# Report success only if the container is listed by docker ps
if [ -n "$(docker ps -q -f "name=$CONTAINER_NAME")" ]; then
    echo -e "${GREEN}=========================================${NC}"
    echo -e "${GREEN}🎉 部署成功!${NC}"
    echo -e "${GREEN}=========================================${NC}"
    echo -e "容器名称: ${YELLOW}$CONTAINER_NAME${NC}"
    echo -e "Vanna主服务: ${YELLOW}http://localhost:$PORT_MAIN${NC}"
    echo -e "调试接口:"
    echo -e "  - RAG状态: ${YELLOW}http://localhost:$PORT_MAIN/debug/rag${NC}"
    echo -e "  - 完整上下文: ${YELLOW}http://localhost:$PORT_MAIN/debug/prompt${NC}"
    echo -e "  - 健康检查: ${YELLOW}http://localhost:$PORT_MAIN/health${NC}"
    echo -e ""
    echo -e "查看日志: ${YELLOW}docker logs -f $CONTAINER_NAME${NC}"
    echo -e "停止服务: ${YELLOW}docker stop $CONTAINER_NAME${NC}"
    echo -e "启动服务: ${YELLOW}docker start $CONTAINER_NAME${NC}"
    echo -e "进入容器: ${YELLOW}docker exec -it $CONTAINER_NAME /bin/bash${NC}"
    echo -e "${GREEN}=========================================${NC}"

    # Show the most recent log lines
    echo -e "\n${YELLOW}📋 最近日志:${NC}"
    docker logs --tail 20 "$CONTAINER_NAME"
else
    echo -e "${RED}❌ 容器启动失败!${NC}"
    echo -e "${YELLOW}查看错误日志:${NC}"
    docker logs "$CONTAINER_NAME"
    exit 1
fi

# 关闭脚本

#!/bin/bash
# docker-stop.sh - stop and remove the container, then optionally remove the image
readonly CONTAINER_NAME="vanna-prod"
readonly IMAGE_NAME="vanna-app"
echo "🛑 停止容器..."
docker stop "$CONTAINER_NAME" 2>/dev/null || true
echo "🗑️  删除容器..."
docker rm "$CONTAINER_NAME" 2>/dev/null || true
echo "❓ 是否删除镜像? (y/n)"
read -r answer
# Only an exact "y" removes the image; any other reply keeps it
case "$answer" in
    y)
        echo "🗑️  删除镜像..."
        docker rmi "$IMAGE_NAME" 2>/dev/null || true
        ;;
esac
echo "✅ 清理完成"

# 重启脚本

#!/bin/bash
# docker-restart.sh - restart the container, then follow its log from the last 30 lines
readonly CONTAINER_NAME="vanna-prod"
echo "🔄 重启容器..."
docker restart "$CONTAINER_NAME"
echo "📋 查看日志..."
docker logs --tail 30 -f "$CONTAINER_NAME"

# 运行效果图

图中的数据均为测试数据

数据库查询结果:

# 接入 dify

dsl 文件:

app:
  description: ''
  icon: 🤖
  icon_background: '#FFEAD5'
  mode: advanced-chat
  name: text2sql
  use_icon_as_answer_icon: false
kind: app
version: 0.1.5
workflow:
  conversation_variables: []
  environment_variables: []
  features:
    file_upload:
      allowed_file_extensions:
      - .JPG
      - .JPEG
      - .PNG
      - .GIF
      - .WEBP
      - .SVG
      allowed_file_types:
      - image
      allowed_file_upload_methods:
      - local_file
      - remote_url
      enabled: false
      fileUploadConfig:
        audio_file_size_limit: 50
        batch_count_limit: 5
        file_size_limit: 15
        image_file_size_limit: 10
        video_file_size_limit: 100
        workflow_file_upload_limit: 10
      image:
        enabled: false
        number_limits: 3
        transfer_methods:
        - local_file
        - remote_url
      number_limits: 3
    opening_statement: ''
    retriever_resource:
      enabled: true
    sensitive_word_avoidance:
      enabled: false
    speech_to_text:
      enabled: false
    suggested_questions: []
    suggested_questions_after_answer:
      enabled: false
    text_to_speech:
      enabled: false
      language: ''
      voice: ''
  graph:
    edges:
    - data:
        isInIteration: false
        isInLoop: false
        sourceType: http-request
        targetType: code
      id: 1743835170032-source-1743836222004-target
      selected: false
      source: '1743835170032'
      sourceHandle: source
      target: '1743836222004'
      targetHandle: target
      type: custom
      zIndex: 0
    - data:
        isInIteration: false
        sourceType: start
        targetType: http-request
      id: 1743835151179-source-1743835170032-target
      source: '1743835151179'
      sourceHandle: source
      target: '1743835170032'
      targetHandle: target
      type: custom
      zIndex: 0
    - data:
        isInIteration: false
        sourceType: code
        targetType: llm
      id: 1743836222004-source-1743835483751-target
      source: '1743836222004'
      sourceHandle: source
      target: '1743835483751'
      targetHandle: target
      type: custom
      zIndex: 0
    - data:
        isInIteration: false
        sourceType: llm
        targetType: code
      id: 1743835483751-source-1777541698296-target
      source: '1743835483751'
      sourceHandle: source
      target: '1777541698296'
      targetHandle: target
      type: custom
      zIndex: 0
    - data:
        isInIteration: false
        sourceType: code
        targetType: if-else
      id: 1777541698296-source-1777544490969-target
      selected: false
      source: '1777541698296'
      sourceHandle: source
      target: '1777544490969'
      targetHandle: target
      type: custom
      zIndex: 0
    - data:
        isInIteration: false
        sourceType: if-else
        targetType: answer
      id: 1777544490969-true-answer-target
      source: '1777544490969'
      sourceHandle: 'true'
      target: answer
      targetHandle: target
      type: custom
      zIndex: 0
    - data:
        isInIteration: false
        sourceType: if-else
        targetType: answer
      id: 1777544490969-false-1777544518961-target
      source: '1777544490969'
      sourceHandle: 'false'
      target: '1777544518961'
      targetHandle: target
      type: custom
      zIndex: 0
    nodes:
    - data:
        desc: ''
        selected: false
        title: 开始
        type: start
        variables: []
      height: 54
      id: '1743835151179'
      position:
        x: 78.63206469609747
        y: 282
      positionAbsolute:
        x: 78.63206469609747
        y: 282
      selected: false
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 244
    - data:
        authorization:
          config: null
          type: no-auth
        body:
          data:
          - id: key-value-5
            key: ''
            type: text
            value: '{
              "message": "{{#sys.query#}}"
              }'
          type: json
        desc: ''
        headers: Authorization:Bearer <YOUR_JWT_TOKEN>
        method: post
        params: ''
        retry_config:
          max_retries: 3
          retry_enabled: false
          retry_interval: 100
        selected: false
        timeout:
          max_connect_timeout: 0
          max_read_timeout: 0
          max_write_timeout: 0
        title: HTTP 请求
        type: http-request
        url: http://192.168.0.102:5000/api/vanna/v2/chat_sse
        variables: []
      height: 110
      id: '1743835170032'
      position:
        x: 680
        y: 282
      positionAbsolute:
        x: 680
        y: 282
      selected: true
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 244
    - data:
        answer: '{{#1743835483751.text#}}
          📊 **图表展示**
          ```echarts
          {{#1777541698296.echarts_json#}}
          ```
          '
        desc: ''
        selected: false
        title: 直接回复
        type: answer
        variables: []
      height: 121
      id: answer
      position:
        x: 2655.940055790228
        y: 116.47982822779872
      positionAbsolute:
        x: 2655.940055790228
        y: 116.47982822779872
      selected: false
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 244
    - data:
        context:
          enabled: true
          variable_selector:
          - '1743836222004'
          - answer
        desc: ''
        model:
          completion_params: {}
          mode: chat
          name: deepseek-chat
          provider: deepseek
        prompt_template:
        - id: 5db31ac3-92e2-4e9f-9338-d2dfc7304611
          role: system
          text: '帮我把以下内容转换为markdown表格输出,切记不要嵌套在```markdown```里面。
            {{#context#}}'
        - id: 223ab65b-dc04-4346-aa50-7f423d71fdf5
          role: user
          text: 需要对数据做分析,再返回分析结果
        selected: false
        title: LLM
        type: llm
        variables: []
        vision:
          enabled: false
      height: 98
      id: '1743835483751'
      position:
        x: 1441.2607027512033
        y: 282
      positionAbsolute:
        x: 1441.2607027512033
        y: 282
      selected: false
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 244
    - data:
        code: "def main(response_data: str) -> dict:\n    \"\"\"\n    Dify 代码节点处理函数\n\
          \    \"\"\"\n    import json\n    \n    result = {\n        \"answer\":\
          \ \"\",\n        \"chart_data\": [],\n        \"chart_config\": {}\n   \
          \ }\n    \n    # 解析 SSE 响应 \n    lines = response_data.strip ().split ('\\\
          n')\n    \n    for line in lines:\n        line = line.strip()\n       \
          \ if not line.startswith('data:'):\n            continue\n            \n\
          \        json_str = line[5:].strip()\n        if json_str == '[DONE]':\n\
          \            break\n            \n        try:\n            event = json.loads(json_str)\n\
          \            rich_data = event.get('rich', {})\n            event_type =\
          \ rich_data.get('type', '')\n            \n            # 提取最终回答 \n      \
          \      if event_type == 'text':\n                result[\"answer\"] = event.get('simple',\
          \ {}).get('text', '')\n            \n            # 提取数据表格(用于图表)\n      \
          \      elif event_type == 'dataframe':\n                df_data = rich_data.get('data',\
          \ {})\n                if df_data.get('data'):\n                    # 转换为 \
          \ ECharts 所需格式\n                    columns = df_data.get('columns', [])\n\
          \                    rows = df_data.get('data', [])\n                  \
          \  \n                    # 自动检测图表类型 \n                    chart_type = 'bar'\
          \  # 默认柱状图 \n                    if len (columns) == 2 and all (isinstance (r.get (columns [1]),\
          \ (int, float)) for r in rows):\n                        # 两列数据,适合图表展示 \n\
          \                        chart_data = {\n                            'xAxis':\
          \ [str(r.get(columns[0])) for r in rows],\n                            'series':\
          \ [{\n                                'name': columns[1],\n            \
          \                    'data': [r.get(columns[1]) for r in rows]\n       \
          \                     }]\n                        }\n                  \
          \  else:\n                        # 多列数据,使用表格展示 \n                      \
          \  chart_data = {\n                            'type': 'table',\n      \
          \                      'columns': columns,\n                           \
          \ 'rows': rows\n                        }\n                    \n      \
          \              result[\"chart_data\"].append(chart_data)\n             \
          \       \n        except Exception as e:\n            print(f\"解析错误: {e}\"\
          )\n    \n    # 生成图表配置 \n    if result [\"chart_data\"]:\n        result [\"\
          chart_config\"] = {\n            \"type\": \"bar\",\n            \"title\"\
          : \"查询结果可视化\",\n            \"xAxis\": result[\"chart_data\"][0].get('xAxis',\
          \ []),\n            \"series\": result[\"chart_data\"][0].get('series',\
          \ [])\n        }\n    \n    return result"
        code_language: python3
        desc: ''
        outputs:
          answer:
            children: null
            type: string
          chart_config:
            children: null
            type: object
          chart_data:
            children: null
            type: array[object]
        selected: false
        title: 代码执行
        type: code
        variables:
        - value_selector:
          - '1743835170032'
          - body
          variable: response_data
      height: 54
      id: '1743836222004'
      position:
        x: 1051.1326358029296
        y: 282
      positionAbsolute:
        x: 1051.1326358029296
        y: 282
      selected: false
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 244
    - data:
        code: "def main(chart_data: list, chart_config: dict) -> dict:\n    import\
          \ json\n    \n    # 查找图表数据(包含 xAxis 和 series 的项)\n    chart_item = None\n\
          \    tables = []\n    \n    for item in chart_data:\n        if \"xAxis\"\
          \ in item and \"series\" in item:\n            chart_item = item\n     \
          \   elif item.get(\"type\") == \"table\":\n            tables.append(item)\n\
          \    \n    # 情况 1:有图表数据,渲染图表 \n    if chart_item:\n        chart_type = chart_config.get (\"\
          type\", \"bar\")\n        title = chart_config.get(\"title\", \"数据统计图表\"\
          )\n        \n        option = {\n            \"title\": {\n            \
          \    \"text\": title,\n                \"left\": \"center\",\n         \
          \       \"top\": 0,\n                \"textStyle\": {\"fontSize\": 16}\n\
          \            },\n            \"tooltip\": {\"trigger\": \"axis\", \"axisPointer\"\
          : {\"type\": \"shadow\"}},\n            \"grid\": {\n                \"\
          left\": \"3%\",\n                \"right\": \"4%\",\n                \"\
          bottom\": \"3%\",\n                \"top\": \"15%\",\n                \"\
          containLabel\": True\n            },\n            \"xAxis\": {\n       \
          \         \"type\": \"category\",\n                \"data\": chart_item[\"\
          xAxis\"],\n                \"axisLabel\": {\"rotate\": 15, \"interval\"\
          : 0}\n            },\n            \"yAxis\": {\"type\": \"value\"},\n  \
          \          \"series\": [{\n                \"name\": series[\"name\"],\n\
          \                \"type\": chart_type,\n                \"data\": series[\"\
          data\"],\n                \"label\": {\"show\": True, \"position\": \"top\"\
          }\n            } for series in chart_item[\"series\"]]\n        }\n    \
          \    \n        return {\n            \"echarts_json\": json.dumps(option,\
          \ ensure_ascii=False),\n            \"show_type\": \"chart\",  # 标记显示图表 \n\
          \            \"table_html\": \"\",      # 表格为空 \n            \"total_assets\"\
          : None,\n            \"districts\": []\n        }\n    \n    # 情况 2:只有表格数据,生成 \
          \ HTML 表格\n    elif tables:\n        # 构建 HTML 表格 \n        html_parts =\
          \ ['<div style=\"overflow-x: auto;\"><table border=\"1\" style=\"border-collapse:\
          \ collapse; width: 100%;\">']\n        \n        for table in tables:\n\
          \            columns = table.get(\"columns\", [])\n            rows = table.get(\"\
          rows\", [])\n            \n            if columns and rows:\n          \
          \      # 表头 \n                html_parts.append ('<thead><tr>')\n        \
          \        for col in columns:\n                    html_parts.append(f'<th\
          \ style=\"padding: 8px; background-color: #f2f2f2;\">{col}</th>')\n    \
          \            html_parts.append('</tr></thead>')\n                \n    \
          \            # 表体 \n                html_parts.append ('<tbody>')\n      \
          \          for row in rows:\n                    html_parts.append('<tr>')\n\
          \                    for col in columns:\n                        html_parts.append(f'<td\
          \ style=\"padding: 8px; border: 1px solid #ddd;\">{row.get(col, \"\")}</td>')\n\
          \                    html_parts.append('</tr>')\n                html_parts.append('</tbody>')\n\
          \        \n        html_parts.append('</table></div>')\n        table_html\
          \ = ''.join(html_parts)\n        \n        # 提取统计数据 \n        total_assets\
          \ = None\n        districts = []\n        for table in tables:\n       \
          \     if \"total_contracts\" in table.get(\"columns\", []):\n          \
          \      if table.get(\"rows\"):\n                    total_assets = table[\"\
          rows\"][0].get(\"total_contracts\")\n            if \"distNo\" in table.get(\"\
          columns\", []):\n                for row in table.get(\"rows\", []):\n \
          \                   districts.append(row.get(\"distNo\"))\n        \n  \
          \      # 返回一个占位图表(显示表格提示)\n        option = {\n            \"title\": {\n\
          \                \"text\": \"\U0001F4CB 数据以表格形式展示\",\n                \"\
          left\": \"center\",\n                \"top\": \"center\",\n            \
          \    \"textStyle\": {\"color\": \"#666\", \"fontSize\": 14}\n          \
          \  },\n            \"xAxis\": {\"show\": False},\n            \"yAxis\"\
          : {\"show\": False},\n            \"series\": [{\"type\": \"bar\", \"data\"\
          : []}]\n        }\n        \n        return {\n            \"echarts_json\"\
          : json.dumps(option, ensure_ascii=False),\n            \"show_type\": \"\
          table\",      # 标记显示表格 \n            \"table_html\": table_html,  # HTML\
          \ 表格内容\n            \"total_assets\": total_assets,\n            \"districts\"\
          : districts\n        }\n    \n    # 情况 3:没有任何数据 \n    else:\n        option\
          \ = {\n            \"title\": {\"text\": \"暂无数据\", \"left\": \"center\"\
          , \"top\": \"center\"},\n            \"xAxis\": {\"show\": False},\n   \
          \         \"yAxis\": {\"show\": False},\n            \"series\": [{\"type\"\
          : \"bar\", \"data\": []}]\n        }\n        \n        return {\n     \
          \       \"echarts_json\": json.dumps(option, ensure_ascii=False),\n    \
          \        \"show_type\": \"empty\",\n            \"table_html\": \"<p>暂无数据</p>\"\
          ,\n            \"total_assets\": None,\n            \"districts\": []\n\
          \        }"
        code_language: python3
        desc: ''
        outputs:
          districts:
            children: null
            type: array[string]
          echarts_json:
            children: null
            type: string
          show_type:
            children: null
            type: string
          table_html:
            children: null
            type: string
          total_assets:
            children: null
            type: number
        selected: false
        title: echat
        type: code
        variables:
        - value_selector:
          - '1743836222004'
          - chart_data
          variable: chart_data
        - value_selector:
          - '1743836222004'
          - chart_config
          variable: chart_config
      height: 54
      id: '1777541698296'
      position:
        x: 1775.5133184902174
        y: 282
      positionAbsolute:
        x: 1775.5133184902174
        y: 282
      selected: false
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 244
    - data:
        cases:
        - case_id: 'true'
          conditions:
          - comparison_operator: contains
            id: 9e9cafc5-525e-4405-a7b3-dfa452b7bb65
            value: chart
            varType: string
            variable_selector:
            - '1777541698296'
            - show_type
          id: 'true'
          logical_operator: and
        desc: ''
        selected: false
        title: 条件分支
        type: if-else
      height: 126
      id: '1777544490969'
      position:
        x: 2191.7878193219294
        y: 282
      positionAbsolute:
        x: 2191.7878193219294
        y: 282
      selected: false
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 244
    - data:
        answer: '{{#1743835483751.text#}}
          '
        desc: ''
        selected: false
        title: 直接回复 2
        type: answer
        variables: []
      height: 103
      id: '1777544518961'
      position:
        x: 2655.940055790228
        y: 546
      positionAbsolute:
        x: 2655.940055790228
        y: 546
      selected: false
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 244
    viewport:
      x: 211.27707997379594
      y: 163.48625731711581
      zoom: 0.8705505632961248

# 效果图

# 遇到的问题与解决方法

# 遇到 Tool Execution Limit 如何处理

> Tool Execution Limit Reached
>
> The agent stopped after executing 10 tools (the configured maximum). The task may not be fully complete. You can:
>
> - Ask me to continue where I left off
> - Adjust the max_tool_iterations setting if you need more tool calls
> - Break the task into smaller steps

解决方法:

调整最大调用次数:
# Allow up to 20 tool calls per task; the message above was raised at a limit of 10.
_agent_config: AgentConfig = AgentConfig(max_tool_iterations=20)