From 089b3ed88a4b366b5a2d9bf2ae4cb5e6aab8eb9e Mon Sep 17 00:00:00 2001 From: LanQian <5499636+LanQian528@users.noreply.github.com> Date: Sat, 18 May 2024 06:19:33 +0800 Subject: [PATCH] fix wss --- chatgpt/ChatService.py | 11 ++--------- chatgpt/chatFormat.py | 23 +++++++++++++---------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/chatgpt/ChatService.py b/chatgpt/ChatService.py index 7a88a84..780bf84 100644 --- a/chatgpt/ChatService.py +++ b/chatgpt/ChatService.py @@ -3,7 +3,6 @@ import random import types import uuid -import asyncio import websockets from fastapi import HTTPException from starlette.concurrency import run_in_threadpool @@ -248,11 +247,12 @@ class ChatService: rtext = await r.atext() resp = json.loads(rtext) self.wss_url = resp.get('wss_url') + conversation_id = resp.get('conversation_id') await set_wss(self.access_token, self.wss_url) logger.info(f"next wss_url: {self.wss_url}") if not self.ws: self.ws = await websockets.connect(self.wss_url, ping_interval=None, subprotocols=subprotocols) - wss_r = wss_stream_response(self.ws) + wss_r = wss_stream_response(self.ws, conversation_id) try: if stream and isinstance(wss_r, types.AsyncGeneratorType): return stream_response(self, wss_r, self.target_model, self.max_tokens) @@ -373,12 +373,5 @@ class ChatService: async def close_client(self): await self.s.close() if self.ws: - while not self.ws.closed: - try: - await asyncio.wait_for(self.ws.recv(), timeout=3) - except asyncio.TimeoutError: - break - except Exception as e: - logger.error(f"Closing websocket error: {str(e)}") await self.ws.close() del self.ws diff --git a/chatgpt/chatFormat.py b/chatgpt/chatFormat.py index 0abf3d3..e677e9c 100644 --- a/chatgpt/chatFormat.py +++ b/chatgpt/chatFormat.py @@ -64,17 +64,19 @@ async def format_not_stream_response(response, prompt_tokens, max_tokens, model) } -async def wss_stream_response(websocket): +async def wss_stream_response(websocket, conversation_id): while not websocket.closed: try: - message = await asyncio.wait_for(websocket.recv(), timeout=15) + message = await asyncio.wait_for(websocket.recv(), timeout=10) if message: resultObj = json.loads(message) sequenceId = resultObj.get("sequenceId", None) if not sequenceId: continue - result = resultObj.get("data", {}).get("body", None) - decoded_bytes = base64.b64decode(result) + data = resultObj.get("data", {}) + if conversation_id != data.get("conversation_id", ""): + continue + decoded_bytes = base64.b64decode(data.get("body", None)) yield decoded_bytes else: print("No message received within the specified time.") @@ -108,22 +110,19 @@ async def stream_response(service, response, model, max_tokens): yield "data: [DONE]\n\n" break try: - if chunk.startswith("data: [DONE]"): - yield "data: [DONE]\n\n" - elif not chunk.startswith("data: {"): - continue - else: + if chunk.startswith("data: {"): chunk_old_data = json.loads(chunk[6:]) finish_reason = None message = chunk_old_data.get("message", {}) role = message.get('author', {}).get('role') if role == 'user': continue + + conversation_id = chunk_old_data.get("conversation_id") status = message.get("status") content = message.get("content", {}) recipient = message.get("recipient", "") current_message_id = message.get('id') - conversation_id = chunk_old_data.get('conversation_id') if not message and chunk_old_data.get("type") == "moderation": delta = {"role": "assistant", "content": moderation_message} finish_reason = "stop" @@ -217,6 +216,10 @@ async def stream_response(service, response, model, max_tokens): }) completion_tokens += 1 yield f"data: {json.dumps(chunk_new_data)}\n\n" + elif chunk.startswith("data: [DONE]"): + yield "data: [DONE]\n\n" + else: + continue except Exception as e: logger.error(f"Error: {chunk}, details: {str(e)}") continue