From 308f2aea78595af604ad757833b359cef1869738 Mon Sep 17 00:00:00 2001 From: Yanyutin753 <132346501+Yanyutin753@users.noreply.github.com> Date: Thu, 23 May 2024 11:55:57 +0800 Subject: [PATCH] =?UTF-8?q?=E2=AD=90=E6=94=AF=E6=8C=81=E5=AF=B9=E4=BA=8E42?= =?UTF-8?q?9=E4=B9=8B=E5=90=8E=E6=A0=B9=E6=8D=AE=E8=BF=94=E5=9B=9E?= =?UTF-8?q?=E6=97=B6=E9=97=B4=EF=BC=8C=E9=99=90=E5=88=B6token=EF=BC=8C?= =?UTF-8?q?=E5=87=8F=E5=B0=91=E5=B0=81=E5=8F=B7=E5=8F=AF=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.py | 12 ++++++++ chat2api.py | 12 ++++++-- chatgpt/chatLimit.py | 66 ++++++++++++++++++++++++++++++++++++++++++++ utils/config.py | 2 ++ 4 files changed, 89 insertions(+), 3 deletions(-) create mode 100644 chatgpt/chatLimit.py diff --git a/app.py b/app.py index ce1d13e..b67f330 100644 --- a/app.py +++ b/app.py @@ -1,9 +1,21 @@ +import asyncio + import uvicorn +from chat2api import app +from chatgpt.chatLimit import clean_background_task + log_config = uvicorn.config.LOGGING_CONFIG default_format = "%(asctime)s | %(levelname)s | %(message)s" access_format = r'%(asctime)s | %(levelname)s | %(client_addr)s: %(request_line)s %(status_code)s' log_config["formatters"]["default"]["fmt"] = default_format log_config["formatters"]["access"]["fmt"] = access_format + +# 在应用启动时启动后台任务 +@app.on_event("startup") +async def startup_event(): + clean_task = asyncio.create_task(clean_background_task()) + + uvicorn.run("chat2api:app", host="0.0.0.0", port=5005) diff --git a/chat2api.py b/chat2api.py index 711beb5..5b937ca 100644 --- a/chat2api.py +++ b/chat2api.py @@ -3,12 +3,13 @@ import warnings from fastapi import FastAPI, Request, Depends, HTTPException, Form from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import HTMLResponse from fastapi.responses import StreamingResponse, JSONResponse from fastapi.templating import Jinja2Templates -from fastapi.responses import HTMLResponse from starlette.background import BackgroundTask from chatgpt.ChatService import ChatService +from chatgpt.chatLimit import check_isLimit, handle_request_limit from chatgpt.reverseProxy import chatgpt_reverse_proxy from utils.Logger import logger from utils.authorization import verify_token, token_list @@ -30,13 +31,17 @@ app.add_middleware( async def to_send_conversation(request_data, access_token): - chat_service = ChatService(access_token) try: + limit_response = await handle_request_limit(request_data, access_token) + if limit_response: + raise HTTPException(status_code=429, detail=limit_response) + chat_service = ChatService(access_token) await chat_service.set_dynamic_data(request_data) await chat_service.get_chat_requirements() return chat_service except HTTPException as e: await chat_service.close_client() + await check_isLimit(e, access_token) raise HTTPException(status_code=e.status_code, detail=e.detail) except Exception as e: await chat_service.close_client() @@ -73,7 +78,8 @@ async def send_conversation(request: Request, token=Depends(verify_token)): @app.get(f"/{api_prefix}/tokens" if api_prefix else "/tokens", response_class=HTMLResponse) async def upload_html(request: Request): tokens_count = len(token_list) - return templates.TemplateResponse("tokens.html", {"request": request, "api_prefix": api_prefix, "tokens_count": tokens_count}) + return templates.TemplateResponse("tokens.html", + {"request": request, "api_prefix": api_prefix, "tokens_count": tokens_count}) @app.post(f"/{api_prefix}/tokens/upload" if api_prefix else "/tokens/upload") diff --git a/chatgpt/chatLimit.py b/chatgpt/chatLimit.py new file mode 100644 index 0000000..193750d --- /dev/null +++ b/chatgpt/chatLimit.py @@ -0,0 +1,66 @@ +import asyncio +import threading +import time +from datetime import datetime + +from utils.Logger import logger + +# 开启线程锁 +lock = threading.Lock() +# 对于密钥的限制 +limit_access_token = {} + + +async def check_isLimit(e, access_token): + if e.status_code == 429: + clearTime = time.time() + e.detail.get('clears_in') + clearDate = datetime.fromtimestamp(clearTime) + initial_access_list(access_token, clearTime, clearDate) + + +async def initial_access_list(key, clear_time): + with lock: + limit_access_token[key] = clear_time + logger.info( + f"{key}: Reached 429 limit, will be cleared at {datetime.fromtimestamp(clear_time).replace(second=0, microsecond=0)}") + + +async def remove_refresh_list(key): + with lock: + if key in limit_access_token: + limit_access_token.pop(key) + logger.info(f"Removed limit for {key[:40]}.") + else: + logger.info(f"Access token {key[:40]} not found in limit_access_token.") + + +async def handle_request_limit(request_data, access_token): + if "gpt-4" in request_data.get("model") and access_token in limit_access_token: + is_clear = limit_access_token[access_token] < time.time() + if is_clear: + await remove_refresh_list(access_token) + else: + clear_date = datetime.fromtimestamp(limit_access_token[access_token]) + logger.info(f"Request limit exceeded. " + f"You can continue with the default model now, " + f"or try again after {clear_date.replace(second=0, microsecond=0)}") + return f"Request limit exceeded. You can continue with the default model now, or try again after {clear_date.replace(second=0, microsecond=0)}" + return None + + +# 清理函数,用于删除长时间未使用的键值对 +def clean_dict(): + logger.info("开始执行清理过期的limit_access_token......") + current_time = time.time() + keys_to_remove = [key for key, clear_time in limit_access_token.items() if clear_time < current_time] + for key in keys_to_remove: + with lock: + del limit_access_token[key] + + +# 后台任务,定期执行清理函数 +async def clean_background_task(): + while True: + clean_dict() + # 一天执行一次,防止limit_access_token过多 + await asyncio.sleep(3600 * 24) diff --git a/utils/config.py b/utils/config.py index 6bbed86..b114b68 100644 --- a/utils/config.py +++ b/utils/config.py @@ -28,6 +28,7 @@ pow_difficulty = os.getenv('POW_DIFFICULTY', '000032') retry_times = int(os.getenv('RETRY_TIMES', 3)) enable_gateway = is_true(os.getenv('ENABLE_GATEWAY', True)) conversation_only = is_true(os.getenv('CONVERSATION_ONLY', False)) +enable_limit = is_true(os.getenv('ENABLE_LIMIT', False)) authorization_list = authorization.split(',') if authorization else [] chatgpt_base_url_list = chatgpt_base_url.split(',') if chatgpt_base_url else [] @@ -48,4 +49,5 @@ logger.info("POW_DIFFICULTY: " + str(pow_difficulty)) logger.info("RETRY_TIMES: " + str(retry_times)) logger.info("ENABLE_GATEWAY: " + str(enable_gateway)) logger.info("CONVERSATION_ONLY: " + str(conversation_only)) +logger.info("ENABLE_LIMIT: " + str(enable_limit)) logger.info("-" * 60)