From 783168dbe1e3a4d8b14db7633fe3ff219cb4dc93 Mon Sep 17 00:00:00 2001 From: LanQian <5499636+LanQian528@users.noreply.github.com> Date: Tue, 30 Apr 2024 04:05:35 +0800 Subject: [PATCH] fix memery leak --- app.py | 23 +++++++++++------ chatgpt/ChatService.py | 19 +++++++++----- utils/Client.py | 58 +++++++++++++++--------------------------- 3 files changed, 49 insertions(+), 51 deletions(-) diff --git a/app.py b/app.py index ab68f3c..714338f 100644 --- a/app.py +++ b/app.py @@ -22,8 +22,12 @@ app.add_middleware( async def to_send_conversation(request_data, access_token): chat_service = ChatService(request_data, access_token) - await chat_service.get_chat_requirements() - return chat_service + try: + await chat_service.get_chat_requirements() + return chat_service + except Exception: + if chat_service.s.session: + await chat_service.close_client() @app.post("/v1/chat/completions") @@ -37,13 +41,16 @@ async def send_conversation(request: Request, token=Depends(verify_token)): raise HTTPException(status_code=400, detail={"error": "Invalid JSON body"}) chat_service = await async_retry(to_send_conversation, request_data, access_token) - await chat_service.prepare_send_conversation() + try: + await chat_service.prepare_send_conversation() - res = await chat_service.send_conversation() - if isinstance(res, types.AsyncGeneratorType): - return StreamingResponse(res, media_type="text/event-stream") - else: - return JSONResponse(res, media_type="application/json") + res = await chat_service.send_conversation() + if isinstance(res, types.AsyncGeneratorType): + return StreamingResponse(res, media_type="text/event-stream") + else: + return JSONResponse(res, media_type="application/json") + finally: + await chat_service.close_client() @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH", "TRACE"]) diff --git a/chatgpt/ChatService.py b/chatgpt/ChatService.py index ad64b1a..22198bc 100644 --- a/chatgpt/ChatService.py +++ b/chatgpt/ChatService.py @@ -12,6 +12,7 @@ from chatgpt.proofofwork import calc_proof_token from utils.Client import Client from utils.Logger import Logger from utils.config import proxy_url_list, chatgpt_base_url_list, arkose_token_url_list, history_disabled +from starlette.concurrency import run_in_threadpool class ChatService: @@ -93,7 +94,7 @@ class ChatService: if proofofwork_required: proofofwork_seed = proofofwork.get("seed") proofofwork_diff = proofofwork.get("difficulty") - self.proof_token = calc_proof_token(proofofwork_seed, proofofwork_diff) + self.proof_token = await run_in_threadpool(calc_proof_token, proofofwork_seed, proofofwork_diff) turnstile_required = turnstile.get('required') if turnstile_required: @@ -184,7 +185,7 @@ class ChatService: try: self.prompt_tokens = await num_tokens_from_messages(self.api_messages, self.model) stream = self.data.get("stream", False) - r = await self.s.post(url, headers=self.headers, json=self.chat_request, timeout=600, stream=True) + r = await self.s.post_stream(url, headers=self.headers, json=self.chat_request, timeout=600, stream=True) if r.status_code != 200: if r.status_code == 403: detail = "cf-please-wait" @@ -207,13 +208,15 @@ class ChatService: try: async with websockets.connect(wss_url, ping_interval=None, subprotocols=subprotocols) as websocket: wss_r = wss_stream_response(websocket) + if stream and isinstance(wss_r, types.AsyncGeneratorType): + return stream_response(self, wss_r, model, self.max_tokens) + else: + return await format_not_stream_response( + stream_response(self, wss_r, model, self.max_tokens), self.prompt_tokens, + self.max_tokens, model) except websockets.exceptions.InvalidStatusCode as e: Logger.error(f"Invalid status code: {str(e)}") raise HTTPException(status_code=e.status_code, detail=str(e)) - if stream and isinstance(wss_r, types.AsyncGeneratorType): - return stream_response(self, wss_r, model, self.max_tokens) - else: - return await format_not_stream_response(stream_response(self, wss_r, model, self.max_tokens), self.prompt_tokens, self.max_tokens, model) else: raise HTTPException(status_code=r.status_code, detail="Unsupported Content-Type") except HTTPException as e: @@ -305,3 +308,7 @@ class ChatService: return False except Exception: return False + + async def close_client(self): + await self.s.session.close() + self.s.session = None \ No newline at end of file diff --git a/utils/Client.py b/utils/Client.py index bfbd0ac..f8d5cde 100644 --- a/utils/Client.py +++ b/utils/Client.py @@ -11,47 +11,31 @@ class Client: } self.timeout = timeout self.verify = verify - self.headers = None - self.cookies = None self.impersonate = random.choice(["chrome", "safari", "safari_ios"]) + self.session = AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify) - async def post(self, *args, headers=None, cookies=None, **kwargs): - if not headers: - headers = self.headers - if not cookies: - cookies = self.cookies - s = AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify) - r = await s.post(*args, headers=headers, cookies=cookies, impersonate=self.impersonate, **kwargs) - self.cookies = r.cookies + async def post(self, *args, **kwargs): + r = await self.session.post(*args, impersonate=self.impersonate, **kwargs) return r - async def get(self, *args, headers=None, cookies=None, **kwargs): - if not headers: - headers = self.headers - if not cookies: - cookies = self.cookies - async with AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify) as s: - r = await s.get(*args, headers=headers, cookies=cookies, impersonate=self.impersonate, **kwargs) - self.cookies = r.cookies - return r + async def post_stream(self, *args, headers=None, cookies=None, **kwargs): + if self.session: + headers = headers or self.session.headers + cookies = cookies or self.session.cookies + await self.session.close() + self.session = None + self.session = AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify) + r = await self.session.post(*args, headers=headers, cookies=cookies, impersonate=self.impersonate, **kwargs) + return r - async def request(self, *args, headers=None, cookies=None, **kwargs): - if not headers: - headers = self.headers - if not cookies: - cookies = self.cookies - async with AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify) as s: - r = await s.request(*args, headers=headers, cookies=cookies, impersonate=self.impersonate, **kwargs) - self.cookies = r.cookies - return r + async def get(self, *args, **kwargs): + r = await self.session.get(*args, impersonate=self.impersonate, **kwargs) + return r + + async def request(self, *args, **kwargs): + r = await self.session.request(*args, impersonate=self.impersonate, **kwargs) + return r async def put(self, *args, headers=None, cookies=None, **kwargs): - if not headers: - headers = self.headers - if not cookies: - cookies = self.cookies - async with AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify) as s: - r = await s.put(*args, headers=headers, cookies=cookies, impersonate=self.impersonate, **kwargs) - self.cookies = r.cookies - return r - + r = await self.session.put(*args, impersonate=self.impersonate, **kwargs) + return r