fix memery leak

This commit is contained in:
LanQian 2024-04-30 04:05:35 +08:00
parent 6ccbb56a85
commit 783168dbe1
3 changed files with 49 additions and 51 deletions

23
app.py
View File

@ -22,8 +22,12 @@ app.add_middleware(
async def to_send_conversation(request_data, access_token):
chat_service = ChatService(request_data, access_token)
await chat_service.get_chat_requirements()
return chat_service
try:
await chat_service.get_chat_requirements()
return chat_service
except Exception:
if chat_service.s.session:
await chat_service.close_client()
@app.post("/v1/chat/completions")
@ -37,13 +41,16 @@ async def send_conversation(request: Request, token=Depends(verify_token)):
raise HTTPException(status_code=400, detail={"error": "Invalid JSON body"})
chat_service = await async_retry(to_send_conversation, request_data, access_token)
await chat_service.prepare_send_conversation()
try:
await chat_service.prepare_send_conversation()
res = await chat_service.send_conversation()
if isinstance(res, types.AsyncGeneratorType):
return StreamingResponse(res, media_type="text/event-stream")
else:
return JSONResponse(res, media_type="application/json")
res = await chat_service.send_conversation()
if isinstance(res, types.AsyncGeneratorType):
return StreamingResponse(res, media_type="text/event-stream")
else:
return JSONResponse(res, media_type="application/json")
finally:
await chat_service.close_client()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH", "TRACE"])

View File

@ -12,6 +12,7 @@ from chatgpt.proofofwork import calc_proof_token
from utils.Client import Client
from utils.Logger import Logger
from utils.config import proxy_url_list, chatgpt_base_url_list, arkose_token_url_list, history_disabled
from starlette.concurrency import run_in_threadpool
class ChatService:
@ -93,7 +94,7 @@ class ChatService:
if proofofwork_required:
proofofwork_seed = proofofwork.get("seed")
proofofwork_diff = proofofwork.get("difficulty")
self.proof_token = calc_proof_token(proofofwork_seed, proofofwork_diff)
self.proof_token = await run_in_threadpool(calc_proof_token, proofofwork_seed, proofofwork_diff)
turnstile_required = turnstile.get('required')
if turnstile_required:
@ -184,7 +185,7 @@ class ChatService:
try:
self.prompt_tokens = await num_tokens_from_messages(self.api_messages, self.model)
stream = self.data.get("stream", False)
r = await self.s.post(url, headers=self.headers, json=self.chat_request, timeout=600, stream=True)
r = await self.s.post_stream(url, headers=self.headers, json=self.chat_request, timeout=600, stream=True)
if r.status_code != 200:
if r.status_code == 403:
detail = "cf-please-wait"
@ -207,13 +208,15 @@ class ChatService:
try:
async with websockets.connect(wss_url, ping_interval=None, subprotocols=subprotocols) as websocket:
wss_r = wss_stream_response(websocket)
if stream and isinstance(wss_r, types.AsyncGeneratorType):
return stream_response(self, wss_r, model, self.max_tokens)
else:
return await format_not_stream_response(
stream_response(self, wss_r, model, self.max_tokens), self.prompt_tokens,
self.max_tokens, model)
except websockets.exceptions.InvalidStatusCode as e:
Logger.error(f"Invalid status code: {str(e)}")
raise HTTPException(status_code=e.status_code, detail=str(e))
if stream and isinstance(wss_r, types.AsyncGeneratorType):
return stream_response(self, wss_r, model, self.max_tokens)
else:
return await format_not_stream_response(stream_response(self, wss_r, model, self.max_tokens), self.prompt_tokens, self.max_tokens, model)
else:
raise HTTPException(status_code=r.status_code, detail="Unsupported Content-Type")
except HTTPException as e:
@ -305,3 +308,7 @@ class ChatService:
return False
except Exception:
return False
async def close_client(self):
await self.s.session.close()
self.s.session = None

View File

@ -11,47 +11,31 @@ class Client:
}
self.timeout = timeout
self.verify = verify
self.headers = None
self.cookies = None
self.impersonate = random.choice(["chrome", "safari", "safari_ios"])
self.session = AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify)
async def post(self, *args, headers=None, cookies=None, **kwargs):
if not headers:
headers = self.headers
if not cookies:
cookies = self.cookies
s = AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify)
r = await s.post(*args, headers=headers, cookies=cookies, impersonate=self.impersonate, **kwargs)
self.cookies = r.cookies
async def post(self, *args, **kwargs):
r = await self.session.post(*args, impersonate=self.impersonate, **kwargs)
return r
async def get(self, *args, headers=None, cookies=None, **kwargs):
if not headers:
headers = self.headers
if not cookies:
cookies = self.cookies
async with AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify) as s:
r = await s.get(*args, headers=headers, cookies=cookies, impersonate=self.impersonate, **kwargs)
self.cookies = r.cookies
return r
async def post_stream(self, *args, headers=None, cookies=None, **kwargs):
if self.session:
headers = headers or self.session.headers
cookies = cookies or self.session.cookies
await self.session.close()
self.session = None
self.session = AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify)
r = await self.session.post(*args, headers=headers, cookies=cookies, impersonate=self.impersonate, **kwargs)
return r
async def request(self, *args, headers=None, cookies=None, **kwargs):
if not headers:
headers = self.headers
if not cookies:
cookies = self.cookies
async with AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify) as s:
r = await s.request(*args, headers=headers, cookies=cookies, impersonate=self.impersonate, **kwargs)
self.cookies = r.cookies
return r
async def get(self, *args, **kwargs):
r = await self.session.get(*args, impersonate=self.impersonate, **kwargs)
return r
async def request(self, *args, **kwargs):
r = await self.session.request(*args, impersonate=self.impersonate, **kwargs)
return r
async def put(self, *args, headers=None, cookies=None, **kwargs):
if not headers:
headers = self.headers
if not cookies:
cookies = self.cookies
async with AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify) as s:
r = await s.put(*args, headers=headers, cookies=cookies, impersonate=self.impersonate, **kwargs)
self.cookies = r.cookies
return r
r = await self.session.put(*args, impersonate=self.impersonate, **kwargs)
return r