mirror of
https://github.com/lanqian528/chat2api.git
synced 2026-06-16 21:10:55 +08:00
Support keeping conversation history, which enables multi-round conversations; add support for websocket streaming.
This commit is contained in:
parent
7630b42805
commit
e2f8a0f1d3
8
app.py
8
app.py
@ -37,7 +37,11 @@ async def send_conversation(request: Request, token=Depends(verify_token)):
|
||||
raise HTTPException(status_code=400, detail={"error": "Invalid JSON body"})
|
||||
|
||||
chat_service = await async_retry(to_send_conversation, request_data, access_token)
|
||||
chat_service.prepare_send_conversation()
|
||||
chat_service.prepare_send_conversation(
|
||||
parent_message_id=request_data.get('parent_message_id'),
|
||||
conversation_id=request_data.get('conversation_id'),
|
||||
history_disabled= request_data.get('history_disabled', True) # default is True
|
||||
)
|
||||
|
||||
res = await chat_service.send_conversation()
|
||||
if isinstance(res, types.AsyncGeneratorType):
|
||||
@ -60,4 +64,4 @@ if __name__ == "__main__":
|
||||
log_config["formatters"]["default"]["fmt"] = default_format
|
||||
log_config["formatters"]["access"]["fmt"] = access_format
|
||||
|
||||
uvicorn.run("app:app", host="0.0.0.0", port=5005)
|
||||
uvicorn.run("app:app", host="0.0.0.0", port=5205)
|
||||
|
||||
@ -3,6 +3,9 @@ import random
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
import websockets
|
||||
import asyncio
|
||||
import base64
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
@ -10,7 +13,7 @@ from api.chat_completions import num_tokens_from_messages, model_system_fingerpr
|
||||
from chatgpt.proofofwork import calc_proof_token
|
||||
from utils.Client import Client
|
||||
from utils.Logger import Logger
|
||||
from utils.config import history_disabled, proxy_url_list, chatgpt_base_url_list, arkose_token_url_list
|
||||
from utils.config import proxy_url_list, chatgpt_base_url_list, arkose_token_url_list
|
||||
|
||||
moderation_message = "I'm sorry, I cannot provide or engage in any content related to pornography, violence, or any unethical material. If you have any other questions or need assistance, please feel free to let me know. I'll do my best to provide support and assistance."
|
||||
|
||||
@ -26,13 +29,18 @@ async def stream_response(service, response, model, max_tokens):
|
||||
last_recipient = None
|
||||
end = False
|
||||
message_id = None
|
||||
async for chunk in response.aiter_lines():
|
||||
current_message_id = None
|
||||
conversation_id = None
|
||||
all_text = ""
|
||||
async for chunk in response:
|
||||
chunk = chunk.decode("utf-8")
|
||||
print(f"chunk:{chunk}")
|
||||
if end:
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
try:
|
||||
if chunk == "data: [DONE]":
|
||||
# to ignore afterwards \n\n for websocket message.
|
||||
if chunk.startswith("data: [DONE]"):
|
||||
yield "data: [DONE]\n\n"
|
||||
elif not chunk.startswith("data: "):
|
||||
continue
|
||||
@ -43,6 +51,8 @@ async def stream_response(service, response, model, max_tokens):
|
||||
status = message.get("status")
|
||||
content = message.get("content", {})
|
||||
recipient = message.get("recipient", "")
|
||||
current_message_id = message.get('id')
|
||||
conversation_id = chunk_old_data.get('conversation_id')
|
||||
if not message and chunk_old_data.get("type") == "moderation":
|
||||
delta = {"role": "assistant", "content": moderation_message}
|
||||
finish_reason = "stop"
|
||||
@ -55,7 +65,8 @@ async def stream_response(service, response, model, max_tokens):
|
||||
message_id = message.get("id")
|
||||
new_text = ""
|
||||
else:
|
||||
if message_id != message.get("id"):
|
||||
# for wss message, first valid text, message_id is None
|
||||
if message_id and message_id != message.get("id"):
|
||||
continue
|
||||
new_text = part[len_last_content:]
|
||||
len_last_content = len(part)
|
||||
@ -111,6 +122,8 @@ async def stream_response(service, response, model, max_tokens):
|
||||
delta = {"role": "assistant", "content": ""}
|
||||
chunk_new_data = {
|
||||
"id": chat_id,
|
||||
"message_id": current_message_id,
|
||||
"conversation_id": conversation_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
@ -125,9 +138,10 @@ async def stream_response(service, response, model, max_tokens):
|
||||
"system_fingerprint": system_fingerprint
|
||||
}
|
||||
completion_tokens += 1
|
||||
# print(f'chunk_new_data:{chunk_new_data}')
|
||||
yield f"data: {json.dumps(chunk_new_data)}\n\n"
|
||||
except Exception:
|
||||
Logger.error(f"Error: {chunk}")
|
||||
except Exception as e:
|
||||
Logger.error(f"Error: {chunk}, error: {str(e)}")
|
||||
continue
|
||||
|
||||
|
||||
@ -149,7 +163,7 @@ async def chat_response(service, response, prompt_tokens, model, max_tokens):
|
||||
if end:
|
||||
break
|
||||
try:
|
||||
if chunk == "data: [DONE]":
|
||||
if chunk.startswith("data: [DONE]"): # ignore afterwards \n\n for websocket.
|
||||
break
|
||||
elif not chunk.startswith("data: "):
|
||||
continue
|
||||
@ -226,8 +240,8 @@ async def chat_response(service, response, prompt_tokens, model, max_tokens):
|
||||
continue
|
||||
all_text += delta.get("content", "")
|
||||
completion_tokens += 1
|
||||
except Exception:
|
||||
Logger.error(f"Error: {chunk}")
|
||||
except Exception as e:
|
||||
Logger.error(f"Error: {chunk}, error: {str(e)}")
|
||||
continue
|
||||
|
||||
completion_tokens = num_tokens_from_content(all_text, model)
|
||||
@ -374,7 +388,7 @@ class ChatService:
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def prepare_send_conversation(self):
|
||||
def prepare_send_conversation(self, parent_message_id=None, conversation_id=None, history_disabled=True):
|
||||
self.headers = {
|
||||
'Accept': 'text/event-stream',
|
||||
'Accept-Encoding': 'gzip, deflate, br',
|
||||
@ -402,7 +416,8 @@ class ChatService:
|
||||
model = "gpt-4"
|
||||
else:
|
||||
model = "text-davinci-002-render-sha"
|
||||
parent_message_id = f"{uuid.uuid4()}"
|
||||
parent_message_id = parent_message_id if parent_message_id else f"{uuid.uuid4()}"
|
||||
print(f"input conversation_id: {conversation_id}")
|
||||
websocket_request_id = f"{uuid.uuid4()}"
|
||||
self.chat_request = {
|
||||
"action": "next",
|
||||
@ -411,6 +426,7 @@ class ChatService:
|
||||
"model": model,
|
||||
"timezone_offset_min": -480,
|
||||
"suggestions": [],
|
||||
# let user decide whether or not we need to keep conversation history.
|
||||
"history_and_training_disabled": history_disabled,
|
||||
"conversation_mode": {"kind": "primary_assistant"},
|
||||
"force_paragen": False,
|
||||
@ -419,8 +435,64 @@ class ChatService:
|
||||
"force_rate_limit": False,
|
||||
"websocket_request_id": websocket_request_id,
|
||||
}
|
||||
if conversation_id:
|
||||
self.chat_request['conversation_id'] = conversation_id
|
||||
# print(f"chat_request:{self.chat_request}")
|
||||
return self.chat_request
|
||||
|
||||
async def wss_response_stream(self, detail):
    """Relay a conversation response delivered over a websocket as byte chunks.

    Args:
        detail: Mapping that carries the websocket endpoint under ``wss_url``.

    Yields:
        bytes: The base64-decoded ``data.body`` payload of each frame that
        has a truthy ``sequenceId``. Payloads are yielded as raw bytes; the
        consumer decodes them.

    The generator ends when an empty message arrives, when the peer closes
    the connection, or when no frame is received for 30 seconds.
    """
    wss_url = detail.get('wss_url')
    subprotocols = ["json.reliable.webpubsub.azure.v1"]
    async with websockets.connect(wss_url, ping_interval=None, subprotocols=subprotocols) as websocket:
        while True:
            try:
                message = await asyncio.wait_for(websocket.recv(), timeout=30)
            except asyncio.TimeoutError:
                # Treat a silent connection as the end of the stream.
                Logger.error("Timeout! No message received within the specified time.")
                return
            except websockets.ConnectionClosed:
                # recv() signals a closed socket by raising, not by returning
                # an empty message, so end the stream here instead of
                # propagating the exception into the response.
                return
            if not message:
                return

            frame = json.loads(message)
            # Frames without a sequence id carry no response data.
            if not frame.get("sequenceId"):
                continue
            # Expected frame data: {"type": "http.response.body",
            #   "body": "<base64 of an SSE line>", "more_body": ..., ...}
            body = (frame.get("data") or {}).get("body")
            if not body:
                # b64decode(None) would raise TypeError; skip body-less frames.
                continue
            yield base64.b64decode(body)
|
||||
|
||||
async def send_conversation_for_stream(self):
    """Post the prepared chat request and return a streaming response generator.

    Reads ``self.base_url``, ``self.model``, ``self.headers``,
    ``self.chat_request``, ``self.max_tokens`` and the HTTP client ``self.s``.

    Returns:
        The async generator produced by ``stream_response``. It is fed by the
        HTTP response lines for a ``text/event-stream`` reply, or by
        ``wss_response_stream`` when the reply is ``application/json``.

    Raises:
        HTTPException: With the upstream status code when the request is not
            answered with 200 (``"cf-please-wait"`` for 403, otherwise the
            upstream ``detail`` field or the raw body), or when the reply has
            an unsupported Content-Type.
    """
    url = f'{self.base_url}/conversation'
    model = model_proxy.get(self.model, self.model)
    try:
        r = await self.s.post(url, headers=self.headers, json=self.chat_request, timeout=600, stream=True)
        if r.status_code != 200:
            if r.status_code == 403:
                detail = "cf-please-wait"
            else:
                rtext = await r.atext()
                try:
                    detail = json.loads(rtext).get("detail", rtext)
                except (ValueError, AttributeError):
                    # Body is not a JSON object (e.g. an HTML error page):
                    # fall back to the raw text so the upstream status is
                    # still reported instead of a decode error.
                    detail = rtext
            raise HTTPException(status_code=r.status_code, detail=detail)

        content_type = r.headers.get("Content-Type", "")
        if "text/event-stream" in content_type:
            return stream_response(self, r.aiter_lines(), model, self.max_tokens)
        if "application/json" in content_type:
            # A JSON reply points at a websocket that carries the stream.
            payload = json.loads(await r.atext())
            detail = payload.get("detail", payload)
            wss_r = self.wss_response_stream(detail)
            return stream_response(self, wss_r, model, self.max_tokens)
        # NOTE(review): this branch is only reached with status 200, so the
        # error is raised with a 200 status code — confirm that is intended.
        raise HTTPException(status_code=r.status_code, detail="Unsupported Content-Type")
    except HTTPException as e:
        # Re-raise with the original detail; str(e) would stringify it and
        # drop its structure.
        raise HTTPException(status_code=e.status_code, detail=e.detail)
|
||||
|
||||
async def send_conversation(self):
|
||||
url = f'{self.base_url}/conversation'
|
||||
if "gpt-4" in self.model and self.persona != "chatgpt-paid":
|
||||
|
||||
@ -2,4 +2,5 @@ fastapi
|
||||
curl_cffi
|
||||
uvicorn
|
||||
tiktoken
|
||||
python-dotenv
|
||||
python-dotenv
|
||||
websockets
|
||||
@ -29,7 +29,7 @@ chatgpt_base_url_list = chatgpt_base_url.split(',') if chatgpt_base_url else []
|
||||
arkose_token_url_list = arkose_token_url.split(',') if arkose_token_url else []
|
||||
proxy_url_list = proxy_url.split(',') if proxy_url else []
|
||||
|
||||
logger.info("Environment variables (no AUTHORIZATION):")
|
||||
logger.info("Environment variables (AUTHORIZATION not displayed):")
|
||||
logger.info("CHATGPT_BASE_URL: " + str(chatgpt_base_url_list))
|
||||
logger.info("ARKOSE_TOKEN_URL: " + str(arkose_token_url_list))
|
||||
logger.info("PROXY_URL: " + str(proxy_url_list))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user