This commit is contained in:
LanQian 2024-05-18 06:19:33 +08:00
parent 911430f78b
commit 089b3ed88a
2 changed files with 15 additions and 19 deletions

View File

@ -3,7 +3,6 @@ import random
import types
import uuid
import asyncio
import websockets
from fastapi import HTTPException
from starlette.concurrency import run_in_threadpool
@ -248,11 +247,12 @@ class ChatService:
rtext = await r.atext()
resp = json.loads(rtext)
self.wss_url = resp.get('wss_url')
conversation_id = resp.get('conversation_id')
await set_wss(self.access_token, self.wss_url)
logger.info(f"next wss_url: {self.wss_url}")
if not self.ws:
self.ws = await websockets.connect(self.wss_url, ping_interval=None, subprotocols=subprotocols)
wss_r = wss_stream_response(self.ws)
wss_r = wss_stream_response(self.ws, conversation_id)
try:
if stream and isinstance(wss_r, types.AsyncGeneratorType):
return stream_response(self, wss_r, self.target_model, self.max_tokens)
@ -373,12 +373,5 @@ class ChatService:
async def close_client(self):
await self.s.close()
if self.ws:
while not self.ws.closed:
try:
await asyncio.wait_for(self.ws.recv(), timeout=3)
except asyncio.TimeoutError:
break
except Exception as e:
logger.error(f"Closing websocket error: {str(e)}")
await self.ws.close()
del self.ws

View File

@ -64,17 +64,19 @@ async def format_not_stream_response(response, prompt_tokens, max_tokens, model)
}
async def wss_stream_response(websocket):
async def wss_stream_response(websocket, conversation_id):
while not websocket.closed:
try:
message = await asyncio.wait_for(websocket.recv(), timeout=15)
message = await asyncio.wait_for(websocket.recv(), timeout=10)
if message:
resultObj = json.loads(message)
sequenceId = resultObj.get("sequenceId", None)
if not sequenceId:
continue
result = resultObj.get("data", {}).get("body", None)
decoded_bytes = base64.b64decode(result)
data = resultObj.get("data", {})
if conversation_id != data.get("conversation_id", ""):
continue
decoded_bytes = base64.b64decode(data.get("body", None))
yield decoded_bytes
else:
print("No message received within the specified time.")
@ -108,22 +110,19 @@ async def stream_response(service, response, model, max_tokens):
yield "data: [DONE]\n\n"
break
try:
if chunk.startswith("data: [DONE]"):
yield "data: [DONE]\n\n"
elif not chunk.startswith("data: {"):
continue
else:
if chunk.startswith("data: {"):
chunk_old_data = json.loads(chunk[6:])
finish_reason = None
message = chunk_old_data.get("message", {})
role = message.get('author', {}).get('role')
if role == 'user':
continue
conversation_id = chunk_old_data.get("conversation_id")
status = message.get("status")
content = message.get("content", {})
recipient = message.get("recipient", "")
current_message_id = message.get('id')
conversation_id = chunk_old_data.get('conversation_id')
if not message and chunk_old_data.get("type") == "moderation":
delta = {"role": "assistant", "content": moderation_message}
finish_reason = "stop"
@ -217,6 +216,10 @@ async def stream_response(service, response, model, max_tokens):
})
completion_tokens += 1
yield f"data: {json.dumps(chunk_new_data)}\n\n"
elif chunk.startswith("data: [DONE]"):
yield "data: [DONE]\n\n"
else:
continue
except Exception as e:
logger.error(f"Error: {chunk}, details: {str(e)}")
continue