merge send_conversation and send_conversation_for_stream

This commit is contained in:
Yuanzhang Hu 2024-04-23 17:15:27 -04:00
parent 54ece4c709
commit c986776f6d
No known key found for this signature in database
GPG Key ID: 83B3E45AFE1D32EF

View File

@ -34,7 +34,7 @@ async def stream_response(service, response, model, max_tokens):
all_text = ""
async for chunk in response:
chunk = chunk.decode("utf-8")
print(f"chunk:{chunk}")
# print(f"chunk:{chunk}")
if end:
yield "data: [DONE]\n\n"
break
@ -48,6 +48,9 @@ async def stream_response(service, response, model, max_tokens):
chunk_old_data = json.loads(chunk[6:])
finish_reason = None
message = chunk_old_data.get("message", {})
role = message.get('author', {}).get('role')
if role == 'user':
continue
status = message.get("status")
content = message.get("content", {})
recipient = message.get("recipient", "")
@ -467,34 +470,9 @@ class ChatService:
print("Timeout! No message received within the specified time.")
return
async def send_conversation_for_stream(self):
url = f'{self.base_url}/conversation'
model = model_proxy.get(self.model, self.model)
try:
r = await self.s.post(url, headers=self.headers, json=self.chat_request, timeout=600, stream=True)
if r.status_code != 200:
if r.status_code == 403:
detail = "cf-please-wait"
else:
rtext = await r.atext()
detail = rtext
detail = json.loads(rtext).get("detail", rtext)
raise HTTPException(status_code=r.status_code, detail=detail)
content_type = r.headers.get("Content-Type", "")
if "text/event-stream" in content_type:
return stream_response(self, r.aiter_lines(), model, self.max_tokens)
elif "application/json" in content_type:
rtext = await r.atext()
detail = json.loads(rtext).get("detail", json.loads(rtext))
wss_r = self.wss_response_stream(detail)
return stream_response(self, wss_r, model, self.max_tokens)
else:
raise HTTPException(status_code=r.status_code, detail="Unsupported Content-Type")
except HTTPException as e:
raise HTTPException(status_code=e.status_code, detail=str(e))
async def send_conversation(self):
url = f'{self.base_url}/conversation'
# Check for model access or existence
if "gpt-4" in self.model and self.persona != "chatgpt-paid":
raise HTTPException(status_code=404, detail={
"message": f"The model `{self.model}` does not exist or you do not have access to it.",
@ -506,23 +484,30 @@ class ChatService:
try:
stream = self.data.get("stream", False)
r = await self.s.post(url, headers=self.headers, json=self.chat_request, timeout=600, stream=True)
content_type = r.headers.get("Content-Type", "")
if r.status_code == 200 and stream is False and "stream" in content_type:
return await chat_response(self, r, self.prompt_tokens, model, self.max_tokens)
elif r.status_code == 200 and stream is True and "stream" in content_type:
return stream_response(self, r, model, self.max_tokens)
else:
rtext = await r.atext()
if "application/json" in content_type:
detail = json.loads(rtext).get("detail", json.loads(rtext))
else:
detail = rtext
if r.status_code != 200:
if r.status_code == 403:
raise HTTPException(status_code=r.status_code, detail="cf-please-wait")
detail = "cf-please-wait"
else:
rtext = await r.atext()
detail = json.loads(rtext).get("detail", rtext)
raise HTTPException(status_code=r.status_code, detail=detail)
content_type = r.headers.get("Content-Type", "")
if "text/event-stream" in content_type and stream:
return stream_response(self, r.aiter_lines(), model, self.max_tokens)
elif "application/json" in content_type:
rtext = await r.atext()
detail = json.loads(rtext).get("detail", json.loads(rtext))
if "wss_url" in detail:
Logger.error(f"Websockets: {detail}")
raise HTTPException(status_code=403, detail="Wss not supported")
raise HTTPException(status_code=r.status_code, detail=detail)
if stream:
wss_r = self.wss_response_stream(detail)
return stream_response(self, wss_r, model, self.max_tokens)
else:
return chat_response(self, r, self.prompt_tokens, model, self.max_tokens)
else:
raise HTTPException(status_code=r.status_code, detail="Unsupported Content-Type")
except HTTPException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)