From 7f45621bceedcc354a246d4f160c100952be8fb0 Mon Sep 17 00:00:00 2001 From: Yanyutin753 <132346501+Yanyutin753@users.noreply.github.com> Date: Thu, 23 May 2024 14:49:51 +0800 Subject: [PATCH] =?UTF-8?q?feat=20=E6=96=B0=E5=A2=9E=E4=B8=A4=E4=B8=AA?= =?UTF-8?q?=E7=8E=AF=E5=A2=83=E5=8F=98=E9=87=8F=E7=94=A8=E6=9D=A5=E6=8E=A7?= =?UTF-8?q?=E5=88=B6=E6=98=AF=E5=90=A6=E5=AF=B94=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E8=BF=9B=E8=A1=8C=E9=99=90=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.py | 15 +++++++----- chat2api.py | 12 +++++---- chatgpt/ChatService.py | 24 ++++++++++++------ chatgpt/chatLimit.py | 55 +++++++++++++++++++----------------------- requirements.txt | 3 ++- utils/config.py | 2 ++ 6 files changed, 62 insertions(+), 49 deletions(-) diff --git a/app.py b/app.py index b67f330..8f082a1 100644 --- a/app.py +++ b/app.py @@ -1,9 +1,8 @@ -import asyncio - import uvicorn +from apscheduler.schedulers.background import BackgroundScheduler from chat2api import app -from chatgpt.chatLimit import clean_background_task +from chatgpt.chatLimit import clean_dict log_config = uvicorn.config.LOGGING_CONFIG default_format = "%(asctime)s | %(levelname)s | %(message)s" @@ -11,11 +10,15 @@ access_format = r'%(asctime)s | %(levelname)s | %(client_addr)s: %(request_line) log_config["formatters"]["default"]["fmt"] = default_format log_config["formatters"]["access"]["fmt"] = access_format +scheduler = BackgroundScheduler() -# 在应用启动时启动后台任务 + +# 用于自动更新限制access_token的字典 +# 避免过多access_token没有销毁,而堆积 @app.on_event("startup") -async def startup_event(): - clean_task = asyncio.create_task(clean_background_task()) +async def app_start(): + scheduler.add_job(id='updateLimit_run', func=clean_dict, trigger='cron', hour=3, minute=0) + scheduler.start() uvicorn.run("chat2api:app", host="0.0.0.0", port=5005) diff --git a/chat2api.py b/chat2api.py index 5b937ca..808a2db 100644 --- a/chat2api.py +++ b/chat2api.py @@ -9,11 +9,12 @@ from fastapi.templating import Jinja2Templates from starlette.background import BackgroundTask from chatgpt.ChatService import ChatService -from chatgpt.chatLimit import check_isLimit, handle_request_limit +from chatgpt.chatLimit import handle_request_limit from chatgpt.reverseProxy import chatgpt_reverse_proxy from utils.Logger import logger from utils.authorization import verify_token, token_list from utils.config import api_prefix +from utils.config import enable_limit, limit_status_code from utils.retry import async_retry warnings.filterwarnings("ignore") @@ -32,16 +33,12 @@ app.add_middleware( async def to_send_conversation(request_data, access_token): try: - limit_response = await handle_request_limit(request_data, access_token) - if limit_response: - raise HTTPException(status_code=429, detail=limit_response) chat_service = ChatService(access_token) await chat_service.set_dynamic_data(request_data) await chat_service.get_chat_requirements() return chat_service except HTTPException as e: await chat_service.close_client() - await check_isLimit(e, access_token) raise HTTPException(status_code=e.status_code, detail=e.detail) except Exception as e: await chat_service.close_client() @@ -56,6 +53,11 @@ async def send_conversation(request: Request, token=Depends(verify_token)): except Exception: raise HTTPException(status_code=400, detail={"error": "Invalid JSON body"}) + if enable_limit: + limit_response = await handle_request_limit(request_data, token) + if limit_response: + raise HTTPException(status_code=int(limit_status_code), detail=limit_response) + chat_service = await async_retry(to_send_conversation, request_data, token) try: await chat_service.prepare_send_conversation() diff --git a/chatgpt/ChatService.py b/chatgpt/ChatService.py index ba3f53e..51f2711 100644 --- a/chatgpt/ChatService.py +++ b/chatgpt/ChatService.py @@ -10,6 +10,7 @@ from starlette.concurrency import run_in_threadpool from api.files import get_image_size, get_file_extension, determine_file_use_case from api.models import model_proxy from chatgpt.chatFormat import api_messages_to_chat, stream_response, wss_stream_response, format_not_stream_response +from chatgpt.chatLimit import check_isLimit from chatgpt.proofofWork import get_config, get_dpl, get_answer_token, get_requirements_token from chatgpt.wssClient import ac2wss, set_wss from utils.Client import Client @@ -79,7 +80,8 @@ class ChatService: self.base_url = self.host_url + "/backend-anon" await get_dpl(self) - self.s.session.cookies.set("__Secure-next-auth.callback-url", "https%3A%2F%2Fchatgpt.com;", domain=self.host_url.split("://")[1], secure=True) + self.s.session.cookies.set("__Secure-next-auth.callback-url", "https%3A%2F%2Fchatgpt.com;", + domain=self.host_url.split("://")[1], secure=True) async def get_wss_url(self): url = f'{self.base_url}/register-websocket' @@ -124,9 +126,11 @@ class ChatService: if proofofwork_required: proofofwork_diff = proofofwork.get("difficulty") if proofofwork_diff <= pow_difficulty: - raise HTTPException(status_code=403, detail=f"Proof of work difficulty too high: {proofofwork_diff}") + raise HTTPException(status_code=403, + detail=f"Proof of work difficulty too high: {proofofwork_diff}") proofofwork_seed = proofofwork.get("seed") - self.proof_token, solved = await run_in_threadpool(get_answer_token, proofofwork_seed, proofofwork_diff, config) + self.proof_token, solved = await run_in_threadpool(get_answer_token, proofofwork_seed, + proofofwork_diff, config) if not solved: raise HTTPException(status_code=403, detail="Failed to solve proof of work") @@ -170,7 +174,6 @@ class ChatService: if r.status_code == 429: raise HTTPException(status_code=r.status_code, detail="rate-limit") raise HTTPException(status_code=r.status_code, detail=detail) - except HTTPException as e: raise HTTPException(status_code=e.status_code, detail=e.detail) except Exception as e: @@ -243,11 +246,14 @@ class ChatService: raise HTTPException(status_code=e.status_code, detail=str(e)) url = f'{self.base_url}/conversation' stream = self.data.get("stream", False) - r = await self.s.post_stream(url, headers=self.chat_headers, json=self.chat_request, timeout=10, stream=True) + r = await self.s.post_stream(url, headers=self.chat_headers, json=self.chat_request, timeout=10, + stream=True) if r.status_code != 200: rtext = await r.atext() if "application/json" == r.headers.get("Content-Type", ""): detail = json.loads(rtext).get("detail", json.loads(rtext)) + if r.status_code == 429: + check_isLimit(detail, access_token=self.access_token) else: if "cf-please-wait" in rtext: raise HTTPException(status_code=r.status_code, detail="cf-please-wait") @@ -260,7 +266,9 @@ class ChatService: if "text/event-stream" in content_type and stream: return stream_response(self, r.aiter_lines(), self.target_model, self.max_tokens) elif "text/event-stream" in content_type and not stream: - return await format_not_stream_response(stream_response(self, r.aiter_lines(), self.target_model, self.max_tokens), self.prompt_tokens, self.max_tokens, self.target_model) + return await format_not_stream_response( + stream_response(self, r.aiter_lines(), self.target_model, self.max_tokens), self.prompt_tokens, + self.max_tokens, self.target_model) elif "application/json" in content_type: rtext = await r.atext() resp = json.loads(rtext) @@ -275,7 +283,9 @@ class ChatService: if stream and isinstance(wss_r, types.AsyncGeneratorType): return stream_response(self, wss_r, self.target_model, self.max_tokens) else: - return await format_not_stream_response(stream_response(self, wss_r, self.target_model, self.max_tokens), self.prompt_tokens, self.max_tokens, self.target_model) + return await format_not_stream_response( + stream_response(self, wss_r, self.target_model, self.max_tokens), self.prompt_tokens, + self.max_tokens, self.target_model) finally: if not isinstance(wss_r, types.AsyncGeneratorType): await self.ws.close() diff --git a/chatgpt/chatLimit.py b/chatgpt/chatLimit.py index 193750d..f5995ed 100644 --- a/chatgpt/chatLimit.py +++ b/chatgpt/chatLimit.py @@ -1,8 +1,6 @@ -import asyncio import threading import time from datetime import datetime - from utils.Logger import logger # 开启线程锁 @@ -11,56 +9,53 @@ lock = threading.Lock() limit_access_token = {} -async def check_isLimit(e, access_token): - if e.status_code == 429: - clearTime = time.time() + e.detail.get('clears_in') - clearDate = datetime.fromtimestamp(clearTime) - initial_access_list(access_token, clearTime, clearDate) +def check_isLimit(detail, access_token): + if detail.get('clears_in'): + clearTime = time.time() + detail.get('clears_in') + initial_access_list(access_token, clearTime) -async def initial_access_list(key, clear_time): +def initial_access_list(key, clear_time): with lock: limit_access_token[key] = clear_time logger.info( - f"{key}: Reached 429 limit, will be cleared at {datetime.fromtimestamp(clear_time).replace(second=0, microsecond=0)}") + f"{key[:40]}: Reached 429 limit, will be cleared at {datetime.fromtimestamp(clear_time).replace(second=0, microsecond=0)}") -async def remove_refresh_list(key): +def remove_refresh_list(key): with lock: if key in limit_access_token: limit_access_token.pop(key) logger.info(f"Removed limit for {key[:40]}.") else: - logger.info(f"Access token {key[:40]} not found in limit_access_token.") + logger.info(f"Access token: {key[:40]} not found in limit_access_token.") async def handle_request_limit(request_data, access_token): - if "gpt-4" in request_data.get("model") and access_token in limit_access_token: - is_clear = limit_access_token[access_token] < time.time() - if is_clear: - await remove_refresh_list(access_token) - else: - clear_date = datetime.fromtimestamp(limit_access_token[access_token]) - logger.info(f"Request limit exceeded. " - f"You can continue with the default model now, " - f"or try again after {clear_date.replace(second=0, microsecond=0)}") - return f"Request limit exceeded. You can continue with the default model now, or try again after {clear_date.replace(second=0, microsecond=0)}" - return None + try: + if "gpt-4" in request_data.get("model", "gpt-3.5-turbo") and access_token in limit_access_token: + limit_time = limit_access_token[access_token] + is_clear = limit_time < time.time() + if is_clear: + remove_refresh_list(access_token) + else: + clear_date = datetime.fromtimestamp(limit_time).replace(second=0, microsecond=0) + logger.info(f"Request limit exceeded. " + f"You can continue with the default model now, " + f"or try again after {clear_date}") + return f"Request limit exceeded. You can continue with the default model now, or try again after {clear_date}" + return None + except Exception as e: + logger.error(e) + return None # 清理函数,用于删除长时间未使用的键值对 def clean_dict(): + logger.info(f"==========================================") logger.info("开始执行清理过期的limit_access_token......") current_time = time.time() keys_to_remove = [key for key, clear_time in limit_access_token.items() if clear_time < current_time] for key in keys_to_remove: with lock: del limit_access_token[key] - - -# 后台任务,定期执行清理函数 -async def clean_background_task(): - while True: - clean_dict() - # 一天执行一次,防止limit_access_token过多 - await asyncio.sleep(3600 * 24) diff --git a/requirements.txt b/requirements.txt index d1b1ff6..c77f816 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ python-dotenv websockets pillow pybase64 -jinja2 \ No newline at end of file +jinja2 +APScheduler diff --git a/utils/config.py b/utils/config.py index b114b68..cdfeef7 100644 --- a/utils/config.py +++ b/utils/config.py @@ -29,6 +29,7 @@ retry_times = int(os.getenv('RETRY_TIMES', 3)) enable_gateway = is_true(os.getenv('ENABLE_GATEWAY', True)) conversation_only = is_true(os.getenv('CONVERSATION_ONLY', False)) enable_limit = is_true(os.getenv('ENABLE_LIMIT', False)) +limit_status_code = os.getenv('LIMIT_STATUS_CODE', 429) authorization_list = authorization.split(',') if authorization else [] chatgpt_base_url_list = chatgpt_base_url.split(',') if chatgpt_base_url else [] @@ -50,4 +51,5 @@ logger.info("RETRY_TIMES: " + str(retry_times)) logger.info("ENABLE_GATEWAY: " + str(enable_gateway)) logger.info("CONVERSATION_ONLY: " + str(conversation_only)) logger.info("ENABLE_LIMIT: " + str(enable_limit)) +logger.info("LIMIT_STATUS_CODE: " + str(limit_status_code)) logger.info("-" * 60)