From e2f8a0f1d3e6f627c4db2cb5b27f011aefaad369 Mon Sep 17 00:00:00 2001 From: Yuanzhang Hu Date: Tue, 23 Apr 2024 16:00:48 -0400 Subject: [PATCH] support keeping conversation history, we can do multi-round conversation now; support websocket stream now. --- app.py | 8 +++- chatgpt/ChatService.py | 94 +++++++++++++++++++++++++++++++++++++----- requirements.txt | 3 +- utils/config.py | 2 +- 4 files changed, 92 insertions(+), 15 deletions(-) diff --git a/app.py b/app.py index 66126aa..0b14525 100644 --- a/app.py +++ b/app.py @@ -37,7 +37,11 @@ async def send_conversation(request: Request, token=Depends(verify_token)): raise HTTPException(status_code=400, detail={"error": "Invalid JSON body"}) chat_service = await async_retry(to_send_conversation, request_data, access_token) - chat_service.prepare_send_conversation() + chat_service.prepare_send_conversation( + parent_message_id=request_data.get('parent_message_id'), + conversation_id=request_data.get('conversation_id'), + history_disabled= request_data.get('history_disabled', True) # default is True + ) res = await chat_service.send_conversation() if isinstance(res, types.AsyncGeneratorType): @@ -60,4 +64,4 @@ if __name__ == "__main__": log_config["formatters"]["default"]["fmt"] = default_format log_config["formatters"]["access"]["fmt"] = access_format - uvicorn.run("app:app", host="0.0.0.0", port=5005) + uvicorn.run("app:app", host="0.0.0.0", port=5205) diff --git a/chatgpt/ChatService.py b/chatgpt/ChatService.py index c4a4f88..f185bf2 100644 --- a/chatgpt/ChatService.py +++ b/chatgpt/ChatService.py @@ -3,6 +3,9 @@ import random import string import time import uuid +import websockets +import asyncio +import base64 from fastapi import HTTPException @@ -10,7 +13,7 @@ from api.chat_completions import num_tokens_from_messages, model_system_fingerpr from chatgpt.proofofwork import calc_proof_token from utils.Client import Client from utils.Logger import Logger -from utils.config import history_disabled, proxy_url_list, 
chatgpt_base_url_list, arkose_token_url_list +from utils.config import proxy_url_list, chatgpt_base_url_list, arkose_token_url_list moderation_message = "I'm sorry, I cannot provide or engage in any content related to pornography, violence, or any unethical material. If you have any other questions or need assistance, please feel free to let me know. I'll do my best to provide support and assistance." @@ -26,13 +29,18 @@ async def stream_response(service, response, model, max_tokens): last_recipient = None end = False message_id = None - async for chunk in response.aiter_lines(): + current_message_id = None + conversation_id = None + all_text = "" + async for chunk in response: chunk = chunk.decode("utf-8") + print(f"chunk:{chunk}") if end: yield "data: [DONE]\n\n" break try: - if chunk == "data: [DONE]": + # match by prefix so the trailing \n\n of a websocket message is ignored. + if chunk.startswith("data: [DONE]"): yield "data: [DONE]\n\n" elif not chunk.startswith("data: "): continue @@ -43,6 +51,8 @@ async def stream_response(service, response, model, max_tokens): status = message.get("status") content = message.get("content", {}) recipient = message.get("recipient", "") + current_message_id = message.get('id') + conversation_id = chunk_old_data.get('conversation_id') if not message and chunk_old_data.get("type") == "moderation": delta = {"role": "assistant", "content": moderation_message} finish_reason = "stop" @@ -55,7 +65,8 @@ async def stream_response(service, response, model, max_tokens): message_id = message.get("id") new_text = "" else: - if message_id != message.get("id"): + # for a wss message, message_id is still None at the first valid text + if message_id and message_id != message.get("id"): continue new_text = part[len_last_content:] len_last_content = len(part) @@ -111,6 +122,8 @@ async def stream_response(service, response, model, max_tokens): delta = {"role": "assistant", "content": ""} chunk_new_data = { "id": chat_id, + "message_id": current_message_id, + "conversation_id": 
conversation_id, "object": "chat.completion.chunk", "created": created_time, "model": model, @@ -125,9 +138,10 @@ async def stream_response(service, response, model, max_tokens): "system_fingerprint": system_fingerprint } completion_tokens += 1 + # print(f'chunk_new_data:{chunk_new_data}') yield f"data: {json.dumps(chunk_new_data)}\n\n" - except Exception: - Logger.error(f"Error: {chunk}") + except Exception as e: + Logger.error(f"Error: {chunk}, error: {str(e)}") continue @@ -149,7 +163,7 @@ async def chat_response(service, response, prompt_tokens, model, max_tokens): if end: break try: - if chunk == "data: [DONE]": + if chunk.startswith("data: [DONE]"): # prefix match ignores the trailing \n\n of a websocket message. break elif not chunk.startswith("data: "): continue @@ -226,8 +240,8 @@ async def chat_response(service, response, prompt_tokens, model, max_tokens): continue all_text += delta.get("content", "") completion_tokens += 1 - except Exception: - Logger.error(f"Error: {chunk}") + except Exception as e: + Logger.error(f"Error: {chunk}, error: {str(e)}") continue completion_tokens = num_tokens_from_content(all_text, model) @@ -374,7 +388,7 @@ class ChatService: except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - def prepare_send_conversation(self): + def prepare_send_conversation(self, parent_message_id=None, conversation_id=None, history_disabled=True): self.headers = { 'Accept': 'text/event-stream', 'Accept-Encoding': 'gzip, deflate, br', @@ -402,7 +416,8 @@ class ChatService: model = "gpt-4" else: model = "text-davinci-002-render-sha" - parent_message_id = f"{uuid.uuid4()}" + parent_message_id = parent_message_id if parent_message_id else f"{uuid.uuid4()}" + print(f"input conversation_id: {conversation_id}") websocket_request_id = f"{uuid.uuid4()}" self.chat_request = { "action": "next", @@ -411,6 +426,7 @@ class ChatService: "model": model, "timezone_offset_min": -480, "suggestions": [], + # let user decide whether or not we need to keep conversation 
history. "history_and_training_disabled": history_disabled, "conversation_mode": {"kind": "primary_assistant"}, "force_paragen": False, @@ -419,8 +435,64 @@ class ChatService: "force_rate_limit": False, "websocket_request_id": websocket_request_id, } + if conversation_id: + self.chat_request['conversation_id'] = conversation_id + # print(f"chat_request:{self.chat_request}") return self.chat_request + async def wss_response_stream(self, detail): + wss_url = detail.get('wss_url') + subprotocols = ["json.reliable.webpubsub.azure.v1"] + async with websockets.connect(wss_url, ping_interval=None, subprotocols=subprotocols) as websocket: + while True: + message = None + try: + message = await asyncio.wait_for(websocket.recv(), timeout=30) + if message: + # print(f'wss messsage:{message}') + resultObj = json.loads(message) + sequenceId = resultObj.get("sequenceId", None) + if not sequenceId: + continue + result = resultObj.get("data", {}).get("body", None) + # {"type": "http.response.body", "body": "ZGF0YTogeyJjb252ZXJzYXRpb25faWQiOiAiNDJhNjNiNWItNDVhZC00Nzg0LWIwNGMtZWNlMDIyNDZhY2Q4IiwgIm1lc3NhZ2VfaWQiOiAiZGNiMTUxM2ItZmU1NC00ZWM5LWJiY2MtZmRhYTJkZWRhMzI0IiwgImlzX2NvbXBsZXRpb24iOiB0cnVlLCAibW9kZXJhdGlvbl9yZXNwb25zZSI6IHsiZmxhZ2dlZCI6IGZhbHNlLCAiYmxvY2tlZCI6IGZhbHNlLCAibW9kZXJhdGlvbl9pZCI6ICJtb2RyLThublA5UGpZVEZLUHl6Z0hDRlFtNHZTOUNESTU1In19Cgo=", "more_body": true, "response_id": "84f29a22d9ac8c4b-EWR", "conversation_id": "42a63b5b-45ad-4784-b04c-ece02246acd8"} + decoded_bytes = base64.b64decode(result) + # result = decoded_bytes.decode("utf-8") + yield decoded_bytes + else: + return + # print(f"Message from server: {message}") + except asyncio.TimeoutError: + # Handle timeout, e.g., by breaking the loop or doing something else + print("Timeout! 
No message received within the specified time.") + return + + async def send_conversation_for_stream(self): + url = f'{self.base_url}/conversation' + model = model_proxy.get(self.model, self.model) + try: + r = await self.s.post(url, headers=self.headers, json=self.chat_request, timeout=600, stream=True) + if r.status_code != 200: + if r.status_code == 403: + detail = "cf-please-wait" + else: + rtext = await r.atext() + detail = rtext + detail = json.loads(rtext).get("detail", rtext) + raise HTTPException(status_code=r.status_code, detail=detail) + content_type = r.headers.get("Content-Type", "") + if "text/event-stream" in content_type: + return stream_response(self, r.aiter_lines(), model, self.max_tokens) + elif "application/json" in content_type: + rtext = await r.atext() + detail = json.loads(rtext).get("detail", json.loads(rtext)) + wss_r = self.wss_response_stream(detail) + return stream_response(self, wss_r, model, self.max_tokens) + else: + raise HTTPException(status_code=r.status_code, detail="Unsupported Content-Type") + except HTTPException as e: + raise HTTPException(status_code=e.status_code, detail=str(e)) + async def send_conversation(self): url = f'{self.base_url}/conversation' if "gpt-4" in self.model and self.persona != "chatgpt-paid": diff --git a/requirements.txt b/requirements.txt index 33bd5a2..b72fd03 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ fastapi curl_cffi uvicorn tiktoken -python-dotenv \ No newline at end of file +python-dotenv +websockets \ No newline at end of file diff --git a/utils/config.py b/utils/config.py index 2b2a2bf..95e85ad 100644 --- a/utils/config.py +++ b/utils/config.py @@ -29,7 +29,7 @@ chatgpt_base_url_list = chatgpt_base_url.split(',') if chatgpt_base_url else [] arkose_token_url_list = arkose_token_url.split(',') if arkose_token_url else [] proxy_url_list = proxy_url.split(',') if proxy_url else [] -logger.info("Environment variables (no AUTHORIZATION):") +logger.info("Environment 
variables (AUTHORIZATION not displayed):") logger.info("CHATGPT_BASE_URL: " + str(chatgpt_base_url_list)) logger.info("ARKOSE_TOKEN_URL: " + str(arkose_token_url_list)) logger.info("PROXY_URL: " + str(proxy_url_list))