support keeping conversation history, we can do multi-round conversation now; support websocket stream now.

This commit is contained in:
Yuanzhang Hu 2024-04-23 16:00:48 -04:00
parent 7630b42805
commit e2f8a0f1d3
No known key found for this signature in database
GPG Key ID: 83B3E45AFE1D32EF
4 changed files with 92 additions and 15 deletions

8
app.py
View File

@ -37,7 +37,11 @@ async def send_conversation(request: Request, token=Depends(verify_token)):
raise HTTPException(status_code=400, detail={"error": "Invalid JSON body"})
chat_service = await async_retry(to_send_conversation, request_data, access_token)
chat_service.prepare_send_conversation()
chat_service.prepare_send_conversation(
parent_message_id=request_data.get('parent_message_id'),
conversation_id=request_data.get('conversation_id'),
history_disabled= request_data.get('history_disabled', True) # default is True
)
res = await chat_service.send_conversation()
if isinstance(res, types.AsyncGeneratorType):
@ -60,4 +64,4 @@ if __name__ == "__main__":
log_config["formatters"]["default"]["fmt"] = default_format
log_config["formatters"]["access"]["fmt"] = access_format
uvicorn.run("app:app", host="0.0.0.0", port=5005)
uvicorn.run("app:app", host="0.0.0.0", port=5205)

View File

@ -3,6 +3,9 @@ import random
import string
import time
import uuid
import websockets
import asyncio
import base64
from fastapi import HTTPException
@ -10,7 +13,7 @@ from api.chat_completions import num_tokens_from_messages, model_system_fingerpr
from chatgpt.proofofwork import calc_proof_token
from utils.Client import Client
from utils.Logger import Logger
from utils.config import history_disabled, proxy_url_list, chatgpt_base_url_list, arkose_token_url_list
from utils.config import proxy_url_list, chatgpt_base_url_list, arkose_token_url_list
moderation_message = "I'm sorry, I cannot provide or engage in any content related to pornography, violence, or any unethical material. If you have any other questions or need assistance, please feel free to let me know. I'll do my best to provide support and assistance."
@ -26,13 +29,18 @@ async def stream_response(service, response, model, max_tokens):
last_recipient = None
end = False
message_id = None
async for chunk in response.aiter_lines():
current_message_id = None
conversation_id = None
all_text = ""
async for chunk in response:
chunk = chunk.decode("utf-8")
print(f"chunk:{chunk}")
if end:
yield "data: [DONE]\n\n"
break
try:
if chunk == "data: [DONE]":
# to ignore afterwards \n\n for websocket message.
if chunk.startswith("data: [DONE]"):
yield "data: [DONE]\n\n"
elif not chunk.startswith("data: "):
continue
@ -43,6 +51,8 @@ async def stream_response(service, response, model, max_tokens):
status = message.get("status")
content = message.get("content", {})
recipient = message.get("recipient", "")
current_message_id = message.get('id')
conversation_id = chunk_old_data.get('conversation_id')
if not message and chunk_old_data.get("type") == "moderation":
delta = {"role": "assistant", "content": moderation_message}
finish_reason = "stop"
@ -55,7 +65,8 @@ async def stream_response(service, response, model, max_tokens):
message_id = message.get("id")
new_text = ""
else:
if message_id != message.get("id"):
# for wss message, first valid text, message_id is None
if message_id and message_id != message.get("id"):
continue
new_text = part[len_last_content:]
len_last_content = len(part)
@ -111,6 +122,8 @@ async def stream_response(service, response, model, max_tokens):
delta = {"role": "assistant", "content": ""}
chunk_new_data = {
"id": chat_id,
"message_id": current_message_id,
"conversation_id": conversation_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model,
@ -125,9 +138,10 @@ async def stream_response(service, response, model, max_tokens):
"system_fingerprint": system_fingerprint
}
completion_tokens += 1
# print(f'chunk_new_data:{chunk_new_data}')
yield f"data: {json.dumps(chunk_new_data)}\n\n"
except Exception:
Logger.error(f"Error: {chunk}")
except Exception as e:
Logger.error(f"Error: {chunk}, error: {str(e)}")
continue
@ -149,7 +163,7 @@ async def chat_response(service, response, prompt_tokens, model, max_tokens):
if end:
break
try:
if chunk == "data: [DONE]":
if chunk.startswith("data: [DONE]"): # ignore afterwards \n\n for websocket.
break
elif not chunk.startswith("data: "):
continue
@ -226,8 +240,8 @@ async def chat_response(service, response, prompt_tokens, model, max_tokens):
continue
all_text += delta.get("content", "")
completion_tokens += 1
except Exception:
Logger.error(f"Error: {chunk}")
except Exception as e:
Logger.error(f"Error: {chunk}, error: {str(e)}")
continue
completion_tokens = num_tokens_from_content(all_text, model)
@ -374,7 +388,7 @@ class ChatService:
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
def prepare_send_conversation(self):
def prepare_send_conversation(self, parent_message_id=None, conversation_id=None, history_disabled=True):
self.headers = {
'Accept': 'text/event-stream',
'Accept-Encoding': 'gzip, deflate, br',
@ -402,7 +416,8 @@ class ChatService:
model = "gpt-4"
else:
model = "text-davinci-002-render-sha"
parent_message_id = f"{uuid.uuid4()}"
parent_message_id = parent_message_id if parent_message_id else f"{uuid.uuid4()}"
print(f"input conversation_id: {conversation_id}")
websocket_request_id = f"{uuid.uuid4()}"
self.chat_request = {
"action": "next",
@ -411,6 +426,7 @@ class ChatService:
"model": model,
"timezone_offset_min": -480,
"suggestions": [],
# let user decide whether or not we need to keep conversation history.
"history_and_training_disabled": history_disabled,
"conversation_mode": {"kind": "primary_assistant"},
"force_paragen": False,
@ -419,8 +435,64 @@ class ChatService:
"force_rate_limit": False,
"websocket_request_id": websocket_request_id,
}
if conversation_id:
self.chat_request['conversation_id'] = conversation_id
# print(f"chat_request:{self.chat_request}")
return self.chat_request
async def wss_response_stream(self, detail):
wss_url = detail.get('wss_url')
subprotocols = ["json.reliable.webpubsub.azure.v1"]
async with websockets.connect(wss_url, ping_interval=None, subprotocols=subprotocols) as websocket:
while True:
message = None
try:
message = await asyncio.wait_for(websocket.recv(), timeout=30)
if message:
# print(f'wss messsage:{message}')
resultObj = json.loads(message)
sequenceId = resultObj.get("sequenceId", None)
if not sequenceId:
continue
result = resultObj.get("data", {}).get("body", None)
# {"type": "http.response.body", "body": "ZGF0YTogeyJjb252ZXJzYXRpb25faWQiOiAiNDJhNjNiNWItNDVhZC00Nzg0LWIwNGMtZWNlMDIyNDZhY2Q4IiwgIm1lc3NhZ2VfaWQiOiAiZGNiMTUxM2ItZmU1NC00ZWM5LWJiY2MtZmRhYTJkZWRhMzI0IiwgImlzX2NvbXBsZXRpb24iOiB0cnVlLCAibW9kZXJhdGlvbl9yZXNwb25zZSI6IHsiZmxhZ2dlZCI6IGZhbHNlLCAiYmxvY2tlZCI6IGZhbHNlLCAibW9kZXJhdGlvbl9pZCI6ICJtb2RyLThublA5UGpZVEZLUHl6Z0hDRlFtNHZTOUNESTU1In19Cgo=", "more_body": true, "response_id": "84f29a22d9ac8c4b-EWR", "conversation_id": "42a63b5b-45ad-4784-b04c-ece02246acd8"}
decoded_bytes = base64.b64decode(result)
# result = decoded_bytes.decode("utf-8")
yield decoded_bytes
else:
return
# print(f"Message from server: {message}")
except asyncio.TimeoutError:
# Handle timeout, e.g., by breaking the loop or doing something else
print("Timeout! No message received within the specified time.")
return
async def send_conversation_for_stream(self):
url = f'{self.base_url}/conversation'
model = model_proxy.get(self.model, self.model)
try:
r = await self.s.post(url, headers=self.headers, json=self.chat_request, timeout=600, stream=True)
if r.status_code != 200:
if r.status_code == 403:
detail = "cf-please-wait"
else:
rtext = await r.atext()
detail = rtext
detail = json.loads(rtext).get("detail", rtext)
raise HTTPException(status_code=r.status_code, detail=detail)
content_type = r.headers.get("Content-Type", "")
if "text/event-stream" in content_type:
return stream_response(self, r.aiter_lines(), model, self.max_tokens)
elif "application/json" in content_type:
rtext = await r.atext()
detail = json.loads(rtext).get("detail", json.loads(rtext))
wss_r = self.wss_response_stream(detail)
return stream_response(self, wss_r, model, self.max_tokens)
else:
raise HTTPException(status_code=r.status_code, detail="Unsupported Content-Type")
except HTTPException as e:
raise HTTPException(status_code=e.status_code, detail=str(e))
async def send_conversation(self):
url = f'{self.base_url}/conversation'
if "gpt-4" in self.model and self.persona != "chatgpt-paid":

View File

@ -2,4 +2,5 @@ fastapi
curl_cffi
uvicorn
tiktoken
python-dotenv
python-dotenv
websockets

View File

@ -29,7 +29,7 @@ chatgpt_base_url_list = chatgpt_base_url.split(',') if chatgpt_base_url else []
arkose_token_url_list = arkose_token_url.split(',') if arkose_token_url else []
proxy_url_list = proxy_url.split(',') if proxy_url else []
logger.info("Environment variables (no AUTHORIZATION):")
logger.info("Environment variables (AUTHORIZATION not displayed):")
logger.info("CHATGPT_BASE_URL: " + str(chatgpt_base_url_list))
logger.info("ARKOSE_TOKEN_URL: " + str(arkose_token_url_list))
logger.info("PROXY_URL: " + str(proxy_url_list))