mirror of
https://github.com/lanqian528/chat2api.git
synced 2026-06-13 21:02:46 +08:00
fix memery leak
This commit is contained in:
parent
6ccbb56a85
commit
783168dbe1
23
app.py
23
app.py
@ -22,8 +22,12 @@ app.add_middleware(
|
||||
|
||||
async def to_send_conversation(request_data, access_token):
|
||||
chat_service = ChatService(request_data, access_token)
|
||||
await chat_service.get_chat_requirements()
|
||||
return chat_service
|
||||
try:
|
||||
await chat_service.get_chat_requirements()
|
||||
return chat_service
|
||||
except Exception:
|
||||
if chat_service.s.session:
|
||||
await chat_service.close_client()
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
@ -37,13 +41,16 @@ async def send_conversation(request: Request, token=Depends(verify_token)):
|
||||
raise HTTPException(status_code=400, detail={"error": "Invalid JSON body"})
|
||||
|
||||
chat_service = await async_retry(to_send_conversation, request_data, access_token)
|
||||
await chat_service.prepare_send_conversation()
|
||||
try:
|
||||
await chat_service.prepare_send_conversation()
|
||||
|
||||
res = await chat_service.send_conversation()
|
||||
if isinstance(res, types.AsyncGeneratorType):
|
||||
return StreamingResponse(res, media_type="text/event-stream")
|
||||
else:
|
||||
return JSONResponse(res, media_type="application/json")
|
||||
res = await chat_service.send_conversation()
|
||||
if isinstance(res, types.AsyncGeneratorType):
|
||||
return StreamingResponse(res, media_type="text/event-stream")
|
||||
else:
|
||||
return JSONResponse(res, media_type="application/json")
|
||||
finally:
|
||||
await chat_service.close_client()
|
||||
|
||||
|
||||
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH", "TRACE"])
|
||||
|
||||
@ -12,6 +12,7 @@ from chatgpt.proofofwork import calc_proof_token
|
||||
from utils.Client import Client
|
||||
from utils.Logger import Logger
|
||||
from utils.config import proxy_url_list, chatgpt_base_url_list, arkose_token_url_list, history_disabled
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
|
||||
class ChatService:
|
||||
@ -93,7 +94,7 @@ class ChatService:
|
||||
if proofofwork_required:
|
||||
proofofwork_seed = proofofwork.get("seed")
|
||||
proofofwork_diff = proofofwork.get("difficulty")
|
||||
self.proof_token = calc_proof_token(proofofwork_seed, proofofwork_diff)
|
||||
self.proof_token = await run_in_threadpool(calc_proof_token, proofofwork_seed, proofofwork_diff)
|
||||
|
||||
turnstile_required = turnstile.get('required')
|
||||
if turnstile_required:
|
||||
@ -184,7 +185,7 @@ class ChatService:
|
||||
try:
|
||||
self.prompt_tokens = await num_tokens_from_messages(self.api_messages, self.model)
|
||||
stream = self.data.get("stream", False)
|
||||
r = await self.s.post(url, headers=self.headers, json=self.chat_request, timeout=600, stream=True)
|
||||
r = await self.s.post_stream(url, headers=self.headers, json=self.chat_request, timeout=600, stream=True)
|
||||
if r.status_code != 200:
|
||||
if r.status_code == 403:
|
||||
detail = "cf-please-wait"
|
||||
@ -207,13 +208,15 @@ class ChatService:
|
||||
try:
|
||||
async with websockets.connect(wss_url, ping_interval=None, subprotocols=subprotocols) as websocket:
|
||||
wss_r = wss_stream_response(websocket)
|
||||
if stream and isinstance(wss_r, types.AsyncGeneratorType):
|
||||
return stream_response(self, wss_r, model, self.max_tokens)
|
||||
else:
|
||||
return await format_not_stream_response(
|
||||
stream_response(self, wss_r, model, self.max_tokens), self.prompt_tokens,
|
||||
self.max_tokens, model)
|
||||
except websockets.exceptions.InvalidStatusCode as e:
|
||||
Logger.error(f"Invalid status code: {str(e)}")
|
||||
raise HTTPException(status_code=e.status_code, detail=str(e))
|
||||
if stream and isinstance(wss_r, types.AsyncGeneratorType):
|
||||
return stream_response(self, wss_r, model, self.max_tokens)
|
||||
else:
|
||||
return await format_not_stream_response(stream_response(self, wss_r, model, self.max_tokens), self.prompt_tokens, self.max_tokens, model)
|
||||
else:
|
||||
raise HTTPException(status_code=r.status_code, detail="Unsupported Content-Type")
|
||||
except HTTPException as e:
|
||||
@ -305,3 +308,7 @@ class ChatService:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def close_client(self):
|
||||
await self.s.session.close()
|
||||
self.s.session = None
|
||||
@ -11,47 +11,31 @@ class Client:
|
||||
}
|
||||
self.timeout = timeout
|
||||
self.verify = verify
|
||||
self.headers = None
|
||||
self.cookies = None
|
||||
self.impersonate = random.choice(["chrome", "safari", "safari_ios"])
|
||||
self.session = AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify)
|
||||
|
||||
async def post(self, *args, headers=None, cookies=None, **kwargs):
|
||||
if not headers:
|
||||
headers = self.headers
|
||||
if not cookies:
|
||||
cookies = self.cookies
|
||||
s = AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify)
|
||||
r = await s.post(*args, headers=headers, cookies=cookies, impersonate=self.impersonate, **kwargs)
|
||||
self.cookies = r.cookies
|
||||
async def post(self, *args, **kwargs):
|
||||
r = await self.session.post(*args, impersonate=self.impersonate, **kwargs)
|
||||
return r
|
||||
|
||||
async def get(self, *args, headers=None, cookies=None, **kwargs):
|
||||
if not headers:
|
||||
headers = self.headers
|
||||
if not cookies:
|
||||
cookies = self.cookies
|
||||
async with AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify) as s:
|
||||
r = await s.get(*args, headers=headers, cookies=cookies, impersonate=self.impersonate, **kwargs)
|
||||
self.cookies = r.cookies
|
||||
return r
|
||||
async def post_stream(self, *args, headers=None, cookies=None, **kwargs):
|
||||
if self.session:
|
||||
headers = headers or self.session.headers
|
||||
cookies = cookies or self.session.cookies
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
self.session = AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify)
|
||||
r = await self.session.post(*args, headers=headers, cookies=cookies, impersonate=self.impersonate, **kwargs)
|
||||
return r
|
||||
|
||||
async def request(self, *args, headers=None, cookies=None, **kwargs):
|
||||
if not headers:
|
||||
headers = self.headers
|
||||
if not cookies:
|
||||
cookies = self.cookies
|
||||
async with AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify) as s:
|
||||
r = await s.request(*args, headers=headers, cookies=cookies, impersonate=self.impersonate, **kwargs)
|
||||
self.cookies = r.cookies
|
||||
return r
|
||||
async def get(self, *args, **kwargs):
|
||||
r = await self.session.get(*args, impersonate=self.impersonate, **kwargs)
|
||||
return r
|
||||
|
||||
async def request(self, *args, **kwargs):
|
||||
r = await self.session.request(*args, impersonate=self.impersonate, **kwargs)
|
||||
return r
|
||||
|
||||
async def put(self, *args, headers=None, cookies=None, **kwargs):
|
||||
if not headers:
|
||||
headers = self.headers
|
||||
if not cookies:
|
||||
cookies = self.cookies
|
||||
async with AsyncSession(proxies=self.proxies, timeout=self.timeout, verify=self.verify) as s:
|
||||
r = await s.put(*args, headers=headers, cookies=cookies, impersonate=self.impersonate, **kwargs)
|
||||
self.cookies = r.cookies
|
||||
return r
|
||||
|
||||
r = await self.session.put(*args, impersonate=self.impersonate, **kwargs)
|
||||
return r
|
||||
|
||||
Loading…
Reference in New Issue
Block a user