diff --git a/.github/workflows/build_docker.yml b/.github/workflows/build_docker.yml index f3f684f..fd463d9 100644 --- a/.github/workflows/build_docker.yml +++ b/.github/workflows/build_docker.yml @@ -37,7 +37,7 @@ jobs: images: lanqian528/chat2api tags: | type=raw,value=latest,enable={{is_default_branch}} - type=raw,value=v1.2.8 + type=raw,value=v1.3.0 - name: Build and push uses: docker/build-push-action@v5 diff --git a/README.md b/README.md index 3d24d35..5142104 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ ## 功能 -### 最新版 v1.2.8 +### 最新版 v1.3.0 > 已完成 > - [x] 流式、非流式传输 diff --git a/chat2api.py b/chat2api.py index 8b584d5..332c81d 100644 --- a/chat2api.py +++ b/chat2api.py @@ -12,9 +12,10 @@ from fastapi.templating import Jinja2Templates from starlette.background import BackgroundTask from chatgpt.ChatService import ChatService +from chatgpt.authorization import refresh_all_tokens +import chatgpt.globals as globals from chatgpt.reverseProxy import chatgpt_reverse_proxy from utils.Logger import logger -from utils.authorization import token_list, error_token_list, refresh_all_tokens from utils.config import api_prefix, scheduled_refresh from utils.retry import async_retry @@ -88,7 +89,7 @@ async def send_conversation(request: Request, req_token: str = Depends(oauth2_sc @app.get(f"/{api_prefix}/tokens" if api_prefix else "/tokens", response_class=HTMLResponse) async def upload_html(request: Request): - tokens_count = len(set(token_list) - set(error_token_list)) + tokens_count = len(set(globals.token_list) - set(globals.error_token_list)) return templates.TemplateResponse("tokens.html", {"request": request, "api_prefix": api_prefix, "tokens_count": tokens_count}) @@ -98,28 +99,28 @@ async def upload_post(text: str = Form(...)): lines = text.split("\n") for line in lines: if line.strip() and not line.startswith("#"): - token_list.append(line.strip()) + globals.token_list.append(line.strip()) with open("data/token.txt", "a", encoding="utf-8") as f: f.write(line.strip() + "\n") - logger.info(f"Token list count: {len(token_list)}") - tokens_count = len(set(token_list) - set(error_token_list)) + logger.info(f"Token count: {len(globals.token_list)}, Error token count: {len(globals.error_token_list)}") + tokens_count = len(set(globals.token_list) - set(globals.error_token_list)) return {"status": "success", "tokens_count": tokens_count} @app.post(f"/{api_prefix}/tokens/clear" if api_prefix else "/tokens/clear") async def upload_post(): - token_list.clear() - error_token_list.clear() + globals.token_list.clear() + globals.error_token_list.clear() with open("data/token.txt", "w", encoding="utf-8") as f: pass - logger.info(f"Token list count: {len(token_list)}") - tokens_count = len(set(token_list) - set(error_token_list)) + logger.info(f"Token count: {len(globals.token_list)}, Error token count: {len(globals.error_token_list)}") + tokens_count = len(set(globals.token_list) - set(globals.error_token_list)) return {"status": "success", "tokens_count": tokens_count} @app.post(f"/{api_prefix}/tokens/error" if api_prefix else "/tokens/error") async def error_tokens(): - error_tokens_list = list(set(error_token_list)) + error_tokens_list = list(set(globals.error_token_list)) return {"status": "success", "error_tokens": error_tokens_list} diff --git a/chatgpt/ChatService.py b/chatgpt/ChatService.py index f6eda49..3c235bc 100644 --- a/chatgpt/ChatService.py +++ b/chatgpt/ChatService.py @@ -10,13 +10,13 @@ from starlette.concurrency import run_in_threadpool from api.files import get_image_size, get_file_extension, determine_file_use_case from api.models import model_proxy +from chatgpt.authorization import get_req_token, verify_token from chatgpt.chatFormat import api_messages_to_chat, stream_response, wss_stream_response, format_not_stream_response from chatgpt.chatLimit import check_is_limit, handle_request_limit from chatgpt.proofofWork import get_config, get_dpl, get_answer_token, get_requirements_token from chatgpt.wssClient import token2wss, set_wss from utils.Client import Client from utils.Logger import logger -from utils.authorization import verify_token, get_req_token from utils.config import proxy_url_list, chatgpt_base_url_list, arkose_token_url_list, history_disabled, pow_difficulty, \ conversation_only, enable_limit, upload_by_url, check_model, auth_key diff --git a/chatgpt/authorization.py b/chatgpt/authorization.py new file mode 100644 index 0000000..c11f015 --- /dev/null +++ b/chatgpt/authorization.py @@ -0,0 +1,56 @@ +import asyncio + +from fastapi import HTTPException + +from chatgpt.refreshToken import rt2ac +from utils.Logger import logger +from utils.config import authorization_list +import chatgpt.globals as globals + + +def get_req_token(req_token): + if req_token in authorization_list: + if globals.token_list: + global count + count += 1 + count %= len(globals.token_list) + while globals.token_list[count] in globals.error_token_list: + count += 1 + count %= len(globals.token_list) + return globals.token_list[count] + else: + return None + else: + return req_token + + +async def verify_token(req_token): + if not req_token: + if authorization_list: + logger.error("Unauthorized with empty token.") + raise HTTPException(status_code=401) + else: + return None + else: + if req_token.startswith("eyJhbGciOi") or req_token.startswith("fk-"): + access_token = req_token + return access_token + elif len(req_token) == 45: + try: + access_token = await rt2ac(req_token, force_refresh=False) + return access_token + except HTTPException as e: + raise HTTPException(status_code=e.status_code, detail=e.detail) + else: + return req_token + + +async def refresh_all_tokens(force_refresh=False): + for token in globals.token_list: + if len(token) == 45: + try: + await asyncio.sleep(2) + await rt2ac(token, force_refresh=force_refresh) + except HTTPException: + pass + logger.info("All tokens refreshed.") diff --git a/chatgpt/globals.py b/chatgpt/globals.py new file mode 100644 index 0000000..82e1cdb --- /dev/null +++ b/chatgpt/globals.py @@ -0,0 +1,54 @@ +import json +import os + +from utils.Logger import logger + +DATA_FOLDER = "data" +TOKENS_FILE = os.path.join(DATA_FOLDER, "token.txt") +REFRESH_MAP_FILE = os.path.join(DATA_FOLDER, "refresh_map.json") +ERROR_TOKENS_FILE = os.path.join(DATA_FOLDER, "error_token.txt") +WSS_MAP_FILE = os.path.join(DATA_FOLDER, "wss_map.json") + +count = 0 +token_list = [] +error_token_list = [] +refresh_map = {} +wss_map = {} + + +if not os.path.exists(DATA_FOLDER): + os.makedirs(DATA_FOLDER) + +if os.path.exists(REFRESH_MAP_FILE): + with open(REFRESH_MAP_FILE, "r") as file: + refresh_map = json.load(file) +else: + refresh_map = {} + +if os.path.exists(WSS_MAP_FILE): + with open(WSS_MAP_FILE, "r") as file: + wss_map = json.load(file) +else: + wss_map = {} + + +if os.path.exists(TOKENS_FILE): + with open(TOKENS_FILE, "r", encoding="utf-8") as f: + for line in f: + if line.strip() and not line.startswith("#"): + token_list.append(line.strip()) +else: + with open(TOKENS_FILE, "w", encoding="utf-8") as f: + pass + +if os.path.exists(ERROR_TOKENS_FILE): + with open(ERROR_TOKENS_FILE, "r", encoding="utf-8") as f: + for line in f: + if line.strip() and not line.startswith("#"): + error_token_list.append(line.strip()) +else: + with open(ERROR_TOKENS_FILE, "w", encoding="utf-8") as f: + pass + +if token_list: + logger.info(f"Token list count: {len(token_list)}, Error token list count: {len(error_token_list)}") \ No newline at end of file diff --git a/chatgpt/refreshToken.py b/chatgpt/refreshToken.py index 63eb4f2..d10167e 100644 --- a/chatgpt/refreshToken.py +++ b/chatgpt/refreshToken.py @@ -1,5 +1,4 @@ import json -import os import random import time @@ -8,35 +7,24 @@ from fastapi import HTTPException from utils.Client import Client from utils.Logger import logger from utils.config import proxy_url_list - -DATA_FOLDER = "data" -REFRESH_MAP_FILE = os.path.join(DATA_FOLDER, "refresh_map.json") - -if not os.path.exists(DATA_FOLDER): - os.makedirs(DATA_FOLDER) - -if os.path.exists(REFRESH_MAP_FILE): - with open(REFRESH_MAP_FILE, "r") as file: - refresh_map = json.load(file) -else: - refresh_map = {} +import chatgpt.globals as globals def save_refresh_map(refresh_map): - with open(REFRESH_MAP_FILE, "w") as file: + with open(globals.REFRESH_MAP_FILE, "w") as file: json.dump(refresh_map, file) async def rt2ac(refresh_token, force_refresh=False): - if not force_refresh and (refresh_token in refresh_map and int(time.time()) - refresh_map.get(refresh_token, {}).get("timestamp", 0) < 5 * 24 * 60 * 60): - access_token = refresh_map[refresh_token]["token"] + if not force_refresh and (refresh_token in globals.refresh_map and int(time.time()) - globals.refresh_map.get(refresh_token, {}).get("timestamp", 0) < 5 * 24 * 60 * 60): + access_token = globals.refresh_map[refresh_token]["token"] logger.info(f"refresh_token -> access_token from cache") return access_token else: try: access_token = await chat_refresh(refresh_token) - refresh_map[refresh_token] = {"token": access_token, "timestamp": int(time.time())} - save_refresh_map(refresh_map) + globals.refresh_map[refresh_token] = {"token": access_token, "timestamp": int(time.time())} + save_refresh_map(globals.refresh_map) logger.info(f"refresh_token -> access_token with openai: {access_token}") return access_token except HTTPException as e: @@ -57,6 +45,10 @@ async def chat_refresh(refresh_token): access_token = r.json()['access_token'] return access_token else: + with open(globals.ERROR_TOKENS_FILE, "a", encoding="utf-8") as f: + f.write(refresh_token + "\n") + if refresh_token not in globals.error_token_list: + globals.error_token_list.append(refresh_token) raise Exception(r.text[:100]) except Exception as e: logger.error(f"Failed to refresh access_token `{refresh_token}`: {str(e)}") diff --git a/chatgpt/wssClient.py b/chatgpt/wssClient.py index 171f221..0c60f82 100644 --- a/chatgpt/wssClient.py +++ b/chatgpt/wssClient.py @@ -1,37 +1,23 @@ import json -import os import time from utils.Logger import logger - -DATA_FOLDER = "data" -WSS_MAP_FILE = os.path.join(DATA_FOLDER, "wss_map.json") - -wss_map = {} - -if not os.path.exists(DATA_FOLDER): - os.makedirs(DATA_FOLDER) - -if os.path.exists(WSS_MAP_FILE): - with open(WSS_MAP_FILE, "r") as file: - wss_map = json.load(file) -else: - wss_map = {} +import chatgpt.globals as globals def save_wss_map(wss_map): - with open(WSS_MAP_FILE, "w") as file: + with open(globals.WSS_MAP_FILE, "w") as file: json.dump(wss_map, file) async def token2wss(token): if not token: return False, None - if token in wss_map: - wss_mode = wss_map[token]["wss_mode"] + if token in globals.wss_map: + wss_mode = globals.wss_map[token]["wss_mode"] if wss_mode: - if int(time.time()) - wss_map.get(token, {}).get("timestamp", 0) < 60 * 60: - wss_url = wss_map[token]["wss_url"] + if int(time.time()) - globals.wss_map.get(token, {}).get("timestamp", 0) < 60 * 60: + wss_url = globals.wss_map[token]["wss_url"] logger.info(f"token -> wss_url from cache") return wss_mode, wss_url else: @@ -45,6 +31,6 @@ async def token2wss(token): async def set_wss(token, wss_mode, wss_url=None): if not token: return True - wss_map[token] = {"timestamp": int(time.time()), "wss_url": wss_url, "wss_mode": wss_mode} - save_wss_map(wss_map) + globals.wss_map[token] = {"timestamp": int(time.time()), "wss_url": wss_url, "wss_mode": wss_mode} + save_wss_map(globals.wss_map) return True diff --git a/utils/authorization.py b/utils/authorization.py deleted file mode 100644 index 1ff30d9..0000000 --- a/utils/authorization.py +++ /dev/null @@ -1,92 +0,0 @@ -import asyncio -import os - -from fastapi import HTTPException - -from chatgpt.refreshToken import rt2ac, save_refresh_map -from utils.Logger import logger -from utils.config import authorization_list - -count = 0 -token_list = [] -error_token_list = [] - -DATA_FOLDER = "data" -TOKENS_FILE = os.path.join(DATA_FOLDER, "token.txt") -ERROR_TOKENS_FILE = os.path.join(DATA_FOLDER, "error_token.txt") - -if not os.path.exists(DATA_FOLDER): - os.makedirs(DATA_FOLDER) - -if os.path.exists(TOKENS_FILE): - with open(TOKENS_FILE, "r", encoding="utf-8") as f: - for line in f: - if line.strip() and not line.startswith("#"): - token_list.append(line.strip()) -else: - with open(TOKENS_FILE, "w", encoding="utf-8") as f: - pass - -if os.path.exists(ERROR_TOKENS_FILE): - with open(ERROR_TOKENS_FILE, "r", encoding="utf-8") as f: - for line in f: - if line.strip() and not line.startswith("#"): - error_token_list.append(line.strip()) -else: - with open(ERROR_TOKENS_FILE, "w", encoding="utf-8") as f: - pass - -if token_list: - logger.info(f"Token list count: {len(token_list)}") - - -def get_req_token(req_token): - if req_token in authorization_list: - if token_list: - global count - count += 1 - count %= len(token_list) - while token_list[count] in error_token_list: - count += 1 - count %= len(token_list) - return token_list[count] - else: - return None - else: - return req_token - - -async def verify_token(req_token): - if not req_token: - if authorization_list: - logger.error("Unauthorized with empty token.") - raise HTTPException(status_code=401) - else: - return None - else: - if req_token.startswith("eyJhbGciOi") or req_token.startswith("fk-"): - access_token = req_token - return access_token - elif len(req_token) == 45: - try: - access_token = await rt2ac(req_token, force_refresh=False) - return access_token - except HTTPException as e: - logger.error(f"{e.detail}: {req_token}") - raise HTTPException(status_code=e.status_code, detail=e.detail) - else: - return req_token - - -async def refresh_all_tokens(force_refresh=False): - for token in token_list: - if len(token) == 45: - try: - await asyncio.sleep(2) - await rt2ac(token, force_refresh=force_refresh) - except HTTPException: - with open(ERROR_TOKENS_FILE, "a", encoding="utf-8") as f: - f.write(token + "\n") - if token not in error_token_list: - error_token_list.append(token) - logger.info("All tokens refreshed.") diff --git a/utils/config.py b/utils/config.py index 20877f7..eb259e5 100644 --- a/utils/config.py +++ b/utils/config.py @@ -42,7 +42,7 @@ arkose_token_url_list = arkose_token_url.split(',') if arkose_token_url else [] proxy_url_list = proxy_url.split(',') if proxy_url else [] logger.info("-" * 60) -logger.info("Chat2Api v1.2.8 | https://github.com/lanqian528/chat2api") +logger.info("Chat2Api v1.3.0 | https://github.com/lanqian528/chat2api") logger.info("-" * 60) logger.info("Environment variables:") logger.info("API_PREFIX: " + str(api_prefix))