-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
256 lines (214 loc) · 7.87 KB
/
main.py
File metadata and controls
256 lines (214 loc) · 7.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""
主入口 - 基于 LLM 的自然语言转 SQL + 执行 + 导出工具
工作流程:
1. 用户输入自然语言问题
2. 系统获取数据库 Schema 并组装 Prompt
3. 调用大模型生成 SQL
4. 展示 SQL,用户确认后执行
5. 导出查询结果为 Excel
"""
import sys
import traceback
import argparse
from config import (
ENABLE_AUTO_FIX_SQL,
EXPORT_DIR,
list_llm_providers,
set_llm,
)
from db_utils import (
execute_sql,
export_to_excel,
get_schema,
refresh_schema,
test_connection,
)
from llm_utils import generate_sql, fix_sql, test_llm_connection
def print_banner():
"""打印启动横幅"""
banner = """
╔══════════════════════════════════════════════╗
║ 🤖 DM AI Analyst - 智能数据分析工具 ║
║ 自然语言 → SQL → 执行 → Excel ║
║ 目标数据库: 达梦8 (DM8) ║
╚══════════════════════════════════════════════╝
"""
print(banner)
print(f"📂 Excel 导出目录: {EXPORT_DIR}")
print("输入 'exit' 退出程序,输入 '!schema' 查看当前 Schema,输入 '!refresh' 刷新 Schema\n")
def handle_schema_command():
"""处理 !schema 命令:显示当前 Schema"""
print("\n📋 正在获取数据库 Schema...")
try:
schema = get_schema()
print("\n" + "=" * 60)
print(schema)
print("=" * 60 + "\n")
except Exception as e:
print(f"❌ 获取 Schema 失败: {e}\n")
def handle_refresh_command():
"""处理 !refresh 命令:刷新 Schema 缓存"""
print("\n🔄 正在刷新 Schema 缓存...")
refresh_schema()
print("✅ Schema 缓存已刷新\n")
def execute_query_flow(user_question: str, schema_desc: str):
"""
完整的查询工作流:生成 SQL → 确认 → 执行 → 导出
Args:
user_question: 用户的问题
schema_desc: Schema 描述
"""
# Step 1: 生成 SQL
print("\n🤔 正在分析问题,生成 SQL...")
try:
sql = generate_sql(user_question, schema_desc)
except Exception as e:
print(f"❌ 调用大模型失败: {e}")
return
print(f"\n📝 生成的 SQL:")
print("-" * 60)
print(sql)
print("-" * 60)
# Step 2: 用户确认
confirm = input("\n❓ 是否执行该 SQL?(y=执行 / n=取消 / e=手动编辑): ").strip().lower()
if confirm == "n":
print("⏭️ 已取消执行\n")
return
elif confirm == "e":
print("📝 请输入修改后的 SQL(输入 'END' 单独一行结束):")
lines = []
while True:
line = input()
if line.strip() == "END":
break
lines.append(line)
sql = "\n".join(lines)
print("✅ SQL 已更新\n")
# Step 3: 执行 SQL
print("⚡ 正在执行 SQL...")
max_fix_attempts = 3 if ENABLE_AUTO_FIX_SQL else 1
fix_count = 0
while fix_count < max_fix_attempts:
try:
df = execute_sql(sql)
break # 执行成功,跳出循环
except Exception as e:
fix_count += 1
error_msg = str(e)
if fix_count < max_fix_attempts:
print(f"⚠️ SQL 执行出错 (尝试 {fix_count}/{max_fix_attempts - 1}): {error_msg}")
print("🔄 正在请求大模型修正 SQL...")
try:
sql = fix_sql(sql, error_msg, schema_desc, user_question)
print(f"📝 修正后的 SQL:")
print("-" * 60)
print(sql)
print("-" * 60)
confirm_fix = input("❓ 是否执行修正后的 SQL?(y/n): ").strip().lower()
if confirm_fix != "y":
print("⏭️ 已跳过修正\n")
return
except Exception as fix_err:
print(f"❌ SQL 修正失败: {fix_err}")
return
else:
print(f"❌ SQL 执行失败 (已达最大修正次数): {error_msg}")
print("💡 提示:您可以修复 SQL 后使用 !sql 命令手动执行")
return
# Step 4: 显示结果
print(f"\n✅ 查询成功!共 {len(df)} 条记录,{len(df.columns)} 列")
print("\n📊 前 10 条结果预览:")
print("-" * 60)
print(df.head(10).to_string(index=False))
print("-" * 60)
# Step 5: 导出 Excel
export_confirm = input("\n❓ 是否导出为 Excel?(y/n): ").strip().lower()
if export_confirm == "y":
try:
filepath = export_to_excel(df)
print(f"✅ Excel 已导出: {filepath}\n")
except Exception as e:
print(f"❌ 导出失败: {e}\n")
else:
print("⏭️ 跳过导出\n")
def run_startup_checks():
"""启动时检查数据库和 LLM 连接"""
print("🔍 正在进行启动检查...\n")
# 检查数据库连接
db_ok, db_msg = test_connection()
if db_ok:
print(f" ✅ {db_msg}")
else:
print(f" ❌ {db_msg}")
print(" 💡 请检查 config.py 中的数据库配置,或设置环境变量 DM_USER/DM_PASSWORD 等\n")
# 检查 LLM 连接
llm_ok, llm_msg = test_llm_connection()
if llm_ok:
print(f" ✅ {llm_msg}")
else:
print(f" ❌ {llm_msg}")
print(" 💡 请检查 config.py 中的 LLM 配置,或设置环境变量 LLM_API_KEY\n")
if not db_ok or not llm_ok:
proceed = input("⚠️ 部分检查未通过,是否继续?(y/n): ").strip().lower()
if proceed != "y":
print("👋 已退出")
sys.exit(0)
print("")
def main():
"""主入口函数"""
parser = argparse.ArgumentParser(description="DM AI Analyst - 智能数据分析工具")
parser.add_argument(
"--llm",
default=None,
choices=list_llm_providers(),
help="LLM 平台选择(缺省使用 .env 中 LLM_DEFAULT 配置)",
)
args, _ = parser.parse_known_args()
if args.llm:
set_llm(args.llm)
print(f"🔀 已切换 LLM → {args.llm}")
print_banner()
# 启动检查
run_startup_checks()
# 预加载 Schema
print("📦 正在加载数据库 Schema...")
try:
schema_desc = get_schema()
print(f"✅ Schema 加载完成\n")
except Exception as e:
print(f"⚠️ Schema 加载失败: {e}")
print("💡 程序仍可运行,但需要先使用 !refresh 命令重新加载\n")
schema_desc = ""
# 主循环
while True:
try:
user_input = input("💬 请输入您的数据查询需求 (或输入命令): ").strip()
if not user_input:
continue
# 处理退出
if user_input.lower() in ("exit", "quit", "q"):
print("👋 感谢使用,再见!")
break
# 处理命令
if user_input.lower() == "!schema":
handle_schema_command()
continue
elif user_input.lower() == "!refresh":
handle_refresh_command()
schema_desc = get_schema()
continue
# 执行查询流程
if not schema_desc:
print("⚠️ Schema 未加载,请先使用 !refresh 命令加载\n")
continue
execute_query_flow(user_input, schema_desc)
except KeyboardInterrupt:
print("\n\n👋 检测到中断,退出程序")
break
except Exception as e:
print(f"\n❌ 发生未预期的错误: {e}")
print("📄 错误详情:")
traceback.print_exc()
print("")
if __name__ == "__main__":
main()