v1.4.0 renew 'retry' logic

This commit is contained in:
lanqian528 2024-08-10 18:15:46 +08:00
parent c064b39dec
commit 992ded1ecc
6 changed files with 64 additions and 86 deletions

View File

@ -37,7 +37,7 @@ jobs:
images: lanqian528/chat2api
tags: |
type=raw,value=latest,enable={{is_default_branch}}
type=raw,value=v1.3.7
type=raw,value=v1.4.0
- name: Build and push
uses: docker/build-push-action@v5

View File

@ -26,7 +26,7 @@
## 功能
### 最新版 v1.3.7
### 最新版 v1.4.0
> 已完成
> - [x] 流式、非流式传输

View File

@ -58,17 +58,20 @@ async def to_send_conversation(request_data, req_token):
raise HTTPException(status_code=500, detail="Server error")
async def process(request_data, req_token):
chat_service = await to_send_conversation(request_data, req_token)
await chat_service.prepare_send_conversation()
res = await chat_service.send_conversation()
return chat_service, res
@app.post(f"/{api_prefix}/v1/chat/completions" if api_prefix else "/v1/chat/completions")
async def send_conversation(request: Request, req_token: str = Depends(oauth2_scheme)):
try:
request_data = await request.json()
except Exception:
raise HTTPException(status_code=400, detail={"error": "Invalid JSON body"})
chat_service = await async_retry(to_send_conversation, request_data, req_token)
chat_service, res = await async_retry(process, request_data, req_token)
try:
await chat_service.prepare_send_conversation()
res = await chat_service.send_conversation()
if isinstance(res, types.AsyncGeneratorType):
background = BackgroundTask(chat_service.close_client)
return StreamingResponse(res, media_type="text/event-stream", background=background)

View File

@ -1,21 +1,18 @@
import asyncio
import json
import random
import types
import uuid
import websockets
from fastapi import HTTPException
from starlette.concurrency import run_in_threadpool
from api.files import get_image_size, get_file_extension, determine_file_use_case
from api.models import model_proxy
from chatgpt.authorization import get_req_token, verify_token
from chatgpt.chatFormat import api_messages_to_chat, stream_response, wss_stream_response, format_not_stream_response
from chatgpt.chatFormat import api_messages_to_chat, stream_response, format_not_stream_response, head_process_response
from chatgpt.chatLimit import check_is_limit, handle_request_limit
from chatgpt.proofofWork import get_config, get_dpl, get_answer_token, get_requirements_token
from chatgpt.turnstile import process_turnstile
from chatgpt.wssClient import token2wss, set_wss
from utils.Client import Client
from utils.Logger import logger
from utils.config import proxy_url_list, chatgpt_base_url_list, arkose_token_url_list, history_disabled, pow_difficulty, \
@ -68,12 +65,6 @@ class ChatService:
self.arkose_token_url = random.choice(arkose_token_url_list) if arkose_token_url_list else None
self.s = Client(proxy=self.proxy_url)
self.ws = None
if conversation_only:
self.wss_mode = False
self.wss_url = None
else:
self.wss_mode, self.wss_url = await token2wss(self.req_token)
self.oai_device_id = str(uuid.uuid4())
self.persona = None
@ -133,21 +124,6 @@ class ChatService:
else:
self.req_model = "text-davinci-002-render-sha"
async def get_wss_url(self):
url = f'{self.base_url}/register-websocket'
headers = self.base_headers.copy()
r = await self.s.post(url, headers=headers, data='', timeout=5)
try:
if r.status_code == 200:
resp = r.json()
logger.info(f'register-websocket response:{resp}')
wss_url = resp.get('wss_url')
return wss_url
raise Exception(r.text)
except Exception as e:
logger.error(f"get_wss_url error: {str(e)}")
raise HTTPException(status_code=r.status_code, detail=f"Failed to get wss url: {str(e)}")
async def get_chat_requirements(self):
if conversation_only:
return None
@ -308,14 +284,6 @@ class ChatService:
async def send_conversation(self):
try:
try:
if self.wss_mode:
if not self.wss_url:
self.wss_url = await self.get_wss_url()
self.ws = await websockets.connect(self.wss_url, ping_interval=None, subprotocols=["json.reliable.webpubsub.azure.v1"])
except Exception as e:
logger.error(f"Failed to connect to wss: {str(e)}", )
raise HTTPException(status_code=502, detail="Failed to connect to wss")
url = f'{self.base_url}/conversation'
stream = self.data.get("stream", False)
r = await self.s.post_stream(url, headers=self.chat_headers, json=self.chat_request, timeout=10,
@ -338,40 +306,23 @@ class ChatService:
raise HTTPException(status_code=r.status_code, detail=detail)
content_type = r.headers.get("Content-Type", "")
if "text/event-stream" in content_type and stream:
await set_wss(self.req_token, False)
return stream_response(self, r.aiter_lines(), self.resp_model, self.max_tokens)
elif "text/event-stream" in content_type and not stream:
await set_wss(self.req_token, False)
return await format_not_stream_response(
stream_response(self, r.aiter_lines(), self.resp_model, self.max_tokens), self.prompt_tokens,
self.max_tokens, self.resp_model)
if "text/event-stream" in content_type:
res, start = await head_process_response(r.aiter_lines())
if not start:
raise HTTPException(status_code=403, detail="Our systems have detected unusual activity coming from your system. Please try again later.")
if stream:
return stream_response(self, res, self.resp_model, self.max_tokens)
else:
return await format_not_stream_response(
stream_response(self, res, self.resp_model, self.max_tokens), self.prompt_tokens,
self.max_tokens, self.resp_model)
elif "application/json" in content_type:
rtext = await r.atext()
resp = json.loads(rtext)
self.wss_url = resp.get('wss_url')
conversation_id = resp.get('conversation_id')
await set_wss(self.req_token, True, self.wss_url)
logger.info(f"next wss_url: {self.wss_url}")
if not self.ws:
try:
self.ws = await websockets.connect(self.wss_url, ping_interval=None, subprotocols=["json.reliable.webpubsub.azure.v1"])
except Exception as e:
logger.error(f"Failed to connect to wss: {str(e)}", )
raise HTTPException(status_code=502, detail="Failed to connect to wss")
wss_r = wss_stream_response(self.ws, conversation_id)
try:
if stream and isinstance(wss_r, types.AsyncGeneratorType):
return stream_response(self, wss_r, self.resp_model, self.max_tokens)
else:
return await format_not_stream_response(
stream_response(self, wss_r, self.resp_model, self.max_tokens), self.prompt_tokens,
self.max_tokens, self.resp_model)
finally:
if not isinstance(wss_r, types.AsyncGeneratorType):
await self.ws.close()
raise HTTPException(status_code=r.status_code, detail=resp)
else:
raise HTTPException(status_code=r.status_code, detail="Unsupported Content-Type")
rtext = await r.atext()
raise HTTPException(status_code=r.status_code, detail=rtext)
except HTTPException as e:
raise HTTPException(status_code=e.status_code, detail=e.detail)
except Exception as e:

View File

@ -105,6 +105,25 @@ async def wss_stream_response(websocket, conversation_id):
continue
async def head_process_response(response):
async for chunk in response:
chunk = chunk.decode("utf-8")
if chunk.startswith("data: {"):
chunk_old_data = json.loads(chunk[6:])
message = chunk_old_data.get("message", {})
if not message and "error" in chunk_old_data:
return response, False
role = message.get('author', {}).get('role')
if role == 'user' or role == 'system':
continue
status = message.get("status")
if status == "in_progress":
return response, True
return response, False
async def stream_response(service, response, model, max_tokens):
chat_id = f"chatcmpl-{''.join(random.choice(string.ascii_letters + string.digits) for _ in range(29))}"
system_fingerprint_list = model_system_fingerprint.get(model, None)
@ -116,9 +135,27 @@ async def stream_response(service, response, model, max_tokens):
last_message_id = None
last_content_type = None
last_recipient = None
start = False
start = True
end = False
chunk_new_data = {
"id": chat_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": ""},
"logprobs": None,
"finish_reason": None
}
]
}
if system_fingerprint:
chunk_new_data["system_fingerprint"] = system_fingerprint
yield f"data: {json.dumps(chunk_new_data)}\n\n"
async for chunk in response:
chunk = chunk.decode("utf-8")
# chunk = 'data: {"message": null, "conversation_id": "38b8bfcf-9912-45db-a48e-b62fb585c855", "error": "Our systems have detected unusual activity coming from your system. Please try again later."}'
@ -237,21 +274,8 @@ async def stream_response(service, response, model, max_tokens):
last_message_id = message_id
if not end and not delta.get("content"):
delta = {"role": "assistant", "content": ""}
chunk_new_data = {
"id": chat_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model,
"choices": [
{
"index": 0,
"delta": delta,
"logprobs": None,
"finish_reason": finish_reason
}
],
"system_fingerprint": system_fingerprint
}
chunk_new_data["choices"][0]["delta"] = delta
chunk_new_data["choices"][0]["finish_reason"] = finish_reason
if not service.history_disabled:
chunk_new_data.update({
"message_id": message_id,

View File

@ -46,7 +46,7 @@ proxy_url_list = proxy_url.split(',') if proxy_url else []
user_agents_list = user_agents.split(',') if user_agents else []
logger.info("-" * 60)
logger.info("Chat2Api v1.3.7 | https://github.com/lanqian528/chat2api")
logger.info("Chat2Api v1.4.0 | https://github.com/lanqian528/chat2api")
logger.info("-" * 60)
logger.info("Environment variables:")
logger.info("API_PREFIX: " + str(api_prefix))