diff --git a/.github/workflows/build_docker.yml b/.github/workflows/build_docker.yml index d4eb932..3fe8a8d 100644 --- a/.github/workflows/build_docker.yml +++ b/.github/workflows/build_docker.yml @@ -37,7 +37,7 @@ jobs: images: lanqian528/chat2api tags: | type=raw,value=latest,enable={{is_default_branch}} - type=raw,value=v1.3.7 + type=raw,value=v1.4.0 - name: Build and push uses: docker/build-push-action@v5 diff --git a/README.md b/README.md index ea6fedd..c43c7b7 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ ## 功能 -### 最新版 v1.3.7 +### 最新版 v1.4.0 > 已完成 > - [x] 流式、非流式传输 diff --git a/chat2api.py b/chat2api.py index 332c81d..66da6dd 100644 --- a/chat2api.py +++ b/chat2api.py @@ -58,17 +58,20 @@ async def to_send_conversation(request_data, req_token): raise HTTPException(status_code=500, detail="Server error") +async def process(request_data, req_token): + chat_service = await to_send_conversation(request_data, req_token) + await chat_service.prepare_send_conversation() + res = await chat_service.send_conversation() + return chat_service, res + @app.post(f"/{api_prefix}/v1/chat/completions" if api_prefix else "/v1/chat/completions") async def send_conversation(request: Request, req_token: str = Depends(oauth2_scheme)): try: request_data = await request.json() except Exception: raise HTTPException(status_code=400, detail={"error": "Invalid JSON body"}) - - chat_service = await async_retry(to_send_conversation, request_data, req_token) + chat_service, res = await async_retry(process, request_data, req_token) try: - await chat_service.prepare_send_conversation() - res = await chat_service.send_conversation() if isinstance(res, types.AsyncGeneratorType): background = BackgroundTask(chat_service.close_client) return StreamingResponse(res, media_type="text/event-stream", background=background) diff --git a/chatgpt/ChatService.py b/chatgpt/ChatService.py index dc02ed6..dd6927d 100644 --- a/chatgpt/ChatService.py +++ b/chatgpt/ChatService.py @@ -1,21 +1,18 @@ import asyncio import json import random -import types import uuid -import websockets from fastapi import HTTPException from starlette.concurrency import run_in_threadpool from api.files import get_image_size, get_file_extension, determine_file_use_case from api.models import model_proxy from chatgpt.authorization import get_req_token, verify_token -from chatgpt.chatFormat import api_messages_to_chat, stream_response, wss_stream_response, format_not_stream_response +from chatgpt.chatFormat import api_messages_to_chat, stream_response, format_not_stream_response, head_process_response from chatgpt.chatLimit import check_is_limit, handle_request_limit from chatgpt.proofofWork import get_config, get_dpl, get_answer_token, get_requirements_token from chatgpt.turnstile import process_turnstile -from chatgpt.wssClient import token2wss, set_wss from utils.Client import Client from utils.Logger import logger from utils.config import proxy_url_list, chatgpt_base_url_list, arkose_token_url_list, history_disabled, pow_difficulty, \ @@ -68,12 +65,6 @@ class ChatService: self.arkose_token_url = random.choice(arkose_token_url_list) if arkose_token_url_list else None self.s = Client(proxy=self.proxy_url) - self.ws = None - if conversation_only: - self.wss_mode = False - self.wss_url = None - else: - self.wss_mode, self.wss_url = await token2wss(self.req_token) self.oai_device_id = str(uuid.uuid4()) self.persona = None @@ -133,21 +124,6 @@ class ChatService: else: self.req_model = "text-davinci-002-render-sha" - async def get_wss_url(self): - url = f'{self.base_url}/register-websocket' - headers = self.base_headers.copy() - r = await self.s.post(url, headers=headers, data='', timeout=5) - try: - if r.status_code == 200: - resp = r.json() - logger.info(f'register-websocket response:{resp}') - wss_url = resp.get('wss_url') - return wss_url - raise Exception(r.text) - except Exception as e: - logger.error(f"get_wss_url error: {str(e)}") - raise HTTPException(status_code=r.status_code, detail=f"Failed to get wss url: {str(e)}") - async def get_chat_requirements(self): if conversation_only: return None @@ -308,14 +284,6 @@ class ChatService: async def send_conversation(self): try: - try: - if self.wss_mode: - if not self.wss_url: - self.wss_url = await self.get_wss_url() - self.ws = await websockets.connect(self.wss_url, ping_interval=None, subprotocols=["json.reliable.webpubsub.azure.v1"]) - except Exception as e: - logger.error(f"Failed to connect to wss: {str(e)}", ) - raise HTTPException(status_code=502, detail="Failed to connect to wss") url = f'{self.base_url}/conversation' stream = self.data.get("stream", False) r = await self.s.post_stream(url, headers=self.chat_headers, json=self.chat_request, timeout=10, @@ -338,40 +306,23 @@ class ChatService: raise HTTPException(status_code=r.status_code, detail=detail) content_type = r.headers.get("Content-Type", "") - if "text/event-stream" in content_type and stream: - await set_wss(self.req_token, False) - return stream_response(self, r.aiter_lines(), self.resp_model, self.max_tokens) - elif "text/event-stream" in content_type and not stream: - await set_wss(self.req_token, False) - return await format_not_stream_response( - stream_response(self, r.aiter_lines(), self.resp_model, self.max_tokens), self.prompt_tokens, - self.max_tokens, self.resp_model) + if "text/event-stream" in content_type: + res, start = await head_process_response(r.aiter_lines()) + if not start: + raise HTTPException(status_code=403, detail="Our systems have detected unusual activity coming from your system. Please try again later.") + if stream: + return stream_response(self, res, self.resp_model, self.max_tokens) + else: + return await format_not_stream_response( + stream_response(self, res, self.resp_model, self.max_tokens), self.prompt_tokens, + self.max_tokens, self.resp_model) elif "application/json" in content_type: rtext = await r.atext() resp = json.loads(rtext) - self.wss_url = resp.get('wss_url') - conversation_id = resp.get('conversation_id') - await set_wss(self.req_token, True, self.wss_url) - logger.info(f"next wss_url: {self.wss_url}") - if not self.ws: - try: - self.ws = await websockets.connect(self.wss_url, ping_interval=None, subprotocols=["json.reliable.webpubsub.azure.v1"]) - except Exception as e: - logger.error(f"Failed to connect to wss: {str(e)}", ) - raise HTTPException(status_code=502, detail="Failed to connect to wss") - wss_r = wss_stream_response(self.ws, conversation_id) - try: - if stream and isinstance(wss_r, types.AsyncGeneratorType): - return stream_response(self, wss_r, self.resp_model, self.max_tokens) - else: - return await format_not_stream_response( - stream_response(self, wss_r, self.resp_model, self.max_tokens), self.prompt_tokens, - self.max_tokens, self.resp_model) - finally: - if not isinstance(wss_r, types.AsyncGeneratorType): - await self.ws.close() + raise HTTPException(status_code=r.status_code, detail=resp) else: - raise HTTPException(status_code=r.status_code, detail="Unsupported Content-Type") + rtext = await r.atext() + raise HTTPException(status_code=r.status_code, detail=rtext) except HTTPException as e: raise HTTPException(status_code=e.status_code, detail=e.detail) except Exception as e: diff --git a/chatgpt/chatFormat.py b/chatgpt/chatFormat.py index 58d4ec1..affe308 100644 --- a/chatgpt/chatFormat.py +++ b/chatgpt/chatFormat.py @@ -105,6 +105,25 @@ async def wss_stream_response(websocket, conversation_id): continue +async def head_process_response(response): + async for chunk in response: + chunk = chunk.decode("utf-8") + if chunk.startswith("data: {"): + chunk_old_data = json.loads(chunk[6:]) + message = chunk_old_data.get("message", {}) + if not message and "error" in chunk_old_data: + return response, False + role = message.get('author', {}).get('role') + if role == 'user' or role == 'system': + continue + + status = message.get("status") + if status == "in_progress": + return response, True + return response, False + + + async def stream_response(service, response, model, max_tokens): chat_id = f"chatcmpl-{''.join(random.choice(string.ascii_letters + string.digits) for _ in range(29))}" system_fingerprint_list = model_system_fingerprint.get(model, None) @@ -116,9 +135,27 @@ async def stream_response(service, response, model, max_tokens): last_message_id = None last_content_type = None last_recipient = None - start = False + start = True end = False + chunk_new_data = { + "id": chat_id, + "object": "chat.completion.chunk", + "created": created_time, + "model": model, + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""}, + "logprobs": None, + "finish_reason": None + } + ] + } + if system_fingerprint: + chunk_new_data["system_fingerprint"] = system_fingerprint + yield f"data: {json.dumps(chunk_new_data)}\n\n" + async for chunk in response: chunk = chunk.decode("utf-8") # chunk = 'data: {"message": null, "conversation_id": "38b8bfcf-9912-45db-a48e-b62fb585c855", "error": "Our systems have detected unusual activity coming from your system. Please try again later."}' @@ -237,21 +274,8 @@ async def stream_response(service, response, model, max_tokens): last_message_id = message_id if not end and not delta.get("content"): delta = {"role": "assistant", "content": ""} - chunk_new_data = { - "id": chat_id, - "object": "chat.completion.chunk", - "created": created_time, - "model": model, - "choices": [ - { - "index": 0, - "delta": delta, - "logprobs": None, - "finish_reason": finish_reason - } - ], - "system_fingerprint": system_fingerprint - } + chunk_new_data["choices"][0]["delta"] = delta + chunk_new_data["choices"][0]["finish_reason"] = finish_reason if not service.history_disabled: chunk_new_data.update({ "message_id": message_id, diff --git a/utils/config.py b/utils/config.py index 0cc3b5e..88f4530 100644 --- a/utils/config.py +++ b/utils/config.py @@ -46,7 +46,7 @@ proxy_url_list = proxy_url.split(',') if proxy_url else [] user_agents_list = user_agents.split(',') if user_agents else [] logger.info("-" * 60) -logger.info("Chat2Api v1.3.7 | https://github.com/lanqian528/chat2api") +logger.info("Chat2Api v1.4.0 | https://github.com/lanqian528/chat2api") logger.info("-" * 60) logger.info("Environment variables:") logger.info("API_PREFIX: " + str(api_prefix))