mirror of
https://github.com/lanqian528/chat2api.git
synced 2026-06-13 21:02:46 +08:00
fix wss
This commit is contained in:
parent
911430f78b
commit
089b3ed88a
@ -3,7 +3,6 @@ import random
|
||||
import types
|
||||
import uuid
|
||||
|
||||
import asyncio
|
||||
import websockets
|
||||
from fastapi import HTTPException
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
@ -248,11 +247,12 @@ class ChatService:
|
||||
rtext = await r.atext()
|
||||
resp = json.loads(rtext)
|
||||
self.wss_url = resp.get('wss_url')
|
||||
conversation_id = resp.get('conversation_id')
|
||||
await set_wss(self.access_token, self.wss_url)
|
||||
logger.info(f"next wss_url: {self.wss_url}")
|
||||
if not self.ws:
|
||||
self.ws = await websockets.connect(self.wss_url, ping_interval=None, subprotocols=subprotocols)
|
||||
wss_r = wss_stream_response(self.ws)
|
||||
wss_r = wss_stream_response(self.ws, conversation_id)
|
||||
try:
|
||||
if stream and isinstance(wss_r, types.AsyncGeneratorType):
|
||||
return stream_response(self, wss_r, self.target_model, self.max_tokens)
|
||||
@ -373,12 +373,5 @@ class ChatService:
|
||||
async def close_client(self):
|
||||
await self.s.close()
|
||||
if self.ws:
|
||||
while not self.ws.closed:
|
||||
try:
|
||||
await asyncio.wait_for(self.ws.recv(), timeout=3)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Closing websocket error: {str(e)}")
|
||||
await self.ws.close()
|
||||
del self.ws
|
||||
|
||||
@ -64,17 +64,19 @@ async def format_not_stream_response(response, prompt_tokens, max_tokens, model)
|
||||
}
|
||||
|
||||
|
||||
async def wss_stream_response(websocket):
|
||||
async def wss_stream_response(websocket, conversation_id):
|
||||
while not websocket.closed:
|
||||
try:
|
||||
message = await asyncio.wait_for(websocket.recv(), timeout=15)
|
||||
message = await asyncio.wait_for(websocket.recv(), timeout=10)
|
||||
if message:
|
||||
resultObj = json.loads(message)
|
||||
sequenceId = resultObj.get("sequenceId", None)
|
||||
if not sequenceId:
|
||||
continue
|
||||
result = resultObj.get("data", {}).get("body", None)
|
||||
decoded_bytes = base64.b64decode(result)
|
||||
data = resultObj.get("data", {})
|
||||
if conversation_id != data.get("conversation_id", ""):
|
||||
continue
|
||||
decoded_bytes = base64.b64decode(data.get("body", None))
|
||||
yield decoded_bytes
|
||||
else:
|
||||
print("No message received within the specified time.")
|
||||
@ -108,22 +110,19 @@ async def stream_response(service, response, model, max_tokens):
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
try:
|
||||
if chunk.startswith("data: [DONE]"):
|
||||
yield "data: [DONE]\n\n"
|
||||
elif not chunk.startswith("data: {"):
|
||||
continue
|
||||
else:
|
||||
if chunk.startswith("data: {"):
|
||||
chunk_old_data = json.loads(chunk[6:])
|
||||
finish_reason = None
|
||||
message = chunk_old_data.get("message", {})
|
||||
role = message.get('author', {}).get('role')
|
||||
if role == 'user':
|
||||
continue
|
||||
|
||||
conversation_id = chunk_old_data.get("conversation_id")
|
||||
status = message.get("status")
|
||||
content = message.get("content", {})
|
||||
recipient = message.get("recipient", "")
|
||||
current_message_id = message.get('id')
|
||||
conversation_id = chunk_old_data.get('conversation_id')
|
||||
if not message and chunk_old_data.get("type") == "moderation":
|
||||
delta = {"role": "assistant", "content": moderation_message}
|
||||
finish_reason = "stop"
|
||||
@ -217,6 +216,10 @@ async def stream_response(service, response, model, max_tokens):
|
||||
})
|
||||
completion_tokens += 1
|
||||
yield f"data: {json.dumps(chunk_new_data)}\n\n"
|
||||
elif chunk.startswith("data: [DONE]"):
|
||||
yield "data: [DONE]\n\n"
|
||||
else:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {chunk}, details: {str(e)}")
|
||||
continue
|
||||
|
||||
Loading…
Reference in New Issue
Block a user