Python 文件代码如下:
import os
import logging
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS
import boto3
from dotenv import load_dotenv
# Load environment variables (python-dotenv).
load_dotenv()

# Initialise the Flask application and apply CORS to it.
app = Flask(__name__)
CORS(app)

# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialise the AWS session. Credentials and region are read from the
# environment; the region falls back to us-east-1 when unset.
session = boto3.Session(
    aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
    region_name=os.getenv("AWS_DEFAULT_REGION", "us-east-1"),
)

# Initialise the Bedrock clients.
bedrock_runtime = session.client("bedrock-runtime")
bedrock_agent = session.client("bedrock-agent-runtime")

# Configuration parameters.
# Both values can be supplied through the environment (or the .env file), so a
# deployment does not require editing this source file. The defaults are the
# placeholder strings and must be replaced before the service can work.
# MODEL_ARN: ARN of the cross-region inference profile for the DeepSeek model
# (may be replaced with another model).
MODEL_ARN = os.getenv("BEDROCK_MODEL_ARN", "你的 model ARN")
# KNOWLEDGE_BASE_ID: the Knowledge Base ID shown in the console.
KNOWLEDGE_BASE_ID = os.getenv("BEDROCK_KNOWLEDGE_BASE_ID", "你的知识库ID")
@app.route("/")
def home():
    """Render and return the ``index.html`` template for the root URL."""
    page = render_template("index.html")
    return page
# Maximum number of characters of a retrieved passage returned as an excerpt.
_EXCERPT_MAX_CHARS = 100

# Prompt template passed to retrieve_and_generate. Written with explicit "\n"
# escapes so the exact text does not depend on source-file indentation.
# $search_results$ and $input$ are left verbatim for the service to fill in.
_PROMPT_TEMPLATE = (
    "\n\nHuman: 请根据以下知识库内容用中文回答问题:\n"
    "<context>\n"
    "$search_results$\n"
    "</context>\n"
    "问题: $input$\n"
    "\n\nAssistant:"
)


def _extract_sources(citations):
    """Flatten citation entries into a list of ``{"title", "excerpt"}`` dicts.

    The title is the last path segment of the reference's S3 URI, or
    "未知文档" when no S3 URI is present. The excerpt is the first
    ``_EXCERPT_MAX_CHARS`` characters of the referenced text; an ellipsis is
    appended only when the text was actually cut.
    """
    sources = []
    for citation in citations:
        for ref in citation.get("retrievedReferences", []):
            s3_uri = ref.get("location", {}).get("s3Location", {}).get("uri", "")
            text = ref.get("content", {}).get("text", "")
            if len(text) > _EXCERPT_MAX_CHARS:
                excerpt = text[:_EXCERPT_MAX_CHARS] + "..."
            else:
                excerpt = text
            sources.append({
                "title": s3_uri.split("/")[-1] if s3_uri else "未知文档",
                "excerpt": excerpt,
            })
    return sources


@app.route("/chat", methods=["POST"])
def chat():
    """Answer a user message using the Bedrock knowledge base.

    Expects a JSON object body with a string ``"message"`` field.

    Returns:
        200 with ``{"message": <answer>, "sources": [...]}`` on success.
        400 with ``{"message": "消息不能为空"}`` when the body is missing, is
        not a JSON object, or the message is not a non-blank string.
        500 with ``{"message": "请求失败: ..."}`` when the Bedrock call or the
        response handling raises.
    """
    # silent=True yields None instead of raising on a missing/invalid JSON
    # body, so malformed requests are answered with 400 rather than 500.
    data = request.get_json(silent=True)
    raw_message = data.get("message", "") if isinstance(data, dict) else ""
    user_message = raw_message.strip() if isinstance(raw_message, str) else ""
    if not user_message:
        return jsonify({"message": "消息不能为空"}), 400

    logger.info("用户输入: %s", user_message)
    logger.info("开始连接知识库并生成回答...")

    try:
        # Retrieve from the knowledge base and generate the answer in one call.
        response = bedrock_agent.retrieve_and_generate(
            input={"text": user_message},
            retrieveAndGenerateConfiguration={
                "type": "KNOWLEDGE_BASE",
                "knowledgeBaseConfiguration": {
                    "knowledgeBaseId": KNOWLEDGE_BASE_ID,
                    "modelArn": MODEL_ARN,
                    "retrievalConfiguration": {
                        "vectorSearchConfiguration": {
                            "numberOfResults": 3,
                        },
                    },
                    "generationConfiguration": {
                        "inferenceConfig": {
                            "textInferenceConfig": {
                                "temperature": 0.3,
                                "maxTokens": 1024,
                                "topP": 0.9,
                            },
                        },
                        "promptTemplate": {
                            "textPromptTemplate": _PROMPT_TEMPLATE,
                        },
                    },
                },
            },
        )
        logger.info("成功连接知识库,生成回答完成。")

        # Extract the generated answer, with a fallback when none is present.
        bot_reply = response.get("output", {}).get("text", "未能生成回答")
        # Extract the cited sources.
        sources = _extract_sources(response.get("citations", []))

        return jsonify({
            "message": bot_reply,
            "sources": sources,
        })
    except Exception as e:
        # Top-level request boundary: log with traceback and report a 500.
        logger.exception("请求失败: %s", e)
        # NOTE(review): the exception text is echoed to the client, which may
        # reveal internal/AWS error details; consider a generic message.
        return jsonify({"message": f"请求失败: {str(e)}"}), 500
if __name__ == "__main__":
    # Debug mode is opt-in through FLASK_DEBUG: the server binds to all
    # interfaces, and Flask's interactive debugger must not be reachable from
    # the network because it allows arbitrary code execution.
    debug_enabled = os.getenv("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host="0.0.0.0", port=5000, debug=debug_enabled)