mirror of
https://github.com/lanqian528/chat2api.git
synced 2026-06-13 21:02:46 +08:00
194 lines
6.6 KiB
Python
194 lines
6.6 KiB
Python
import base64
|
|
import io
|
|
import math
|
|
|
|
import tiktoken
|
|
from PIL import Image
|
|
|
|
from utils.Client import Client
|
|
|
|
# Maps a model alias (key) to the dated model identifier (value) it stands for.
model_proxy = {
    # GPT-3.5 family
    "gpt-3.5-turbo": "gpt-3.5-turbo-0125",
    "gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",

    # GPT-4 family
    "gpt-4": "gpt-4-0613",
    "gpt-4-32k": "gpt-4-32k-0613",
    "gpt-4-turbo-preview": "gpt-4-0125-preview",
    "gpt-4-vision-preview": "gpt-4-1106-vision-preview",
    "gpt-4-turbo": "gpt-4-turbo-2024-04-09",

    # Claude 3 family
    "claude-3-opus": "claude-3-opus-20240229",
    "claude-3-sonnet": "claude-3-sonnet-20240229",
    "claude-3-haiku": "claude-3-haiku-20240307",
}
|
|
|
|
# For each dated model identifier, the list of "fp_..." fingerprint strings
# known for it.
# NOTE(review): how a fingerprint is picked from a list is decided by the
# callers, which are not visible in this module.
model_system_fingerprint = {
    "gpt-3.5-turbo-0125": ["fp_b28b39ffa8"],
    "gpt-3.5-turbo-1106": ["fp_592ef5907d"],
    "gpt-4-0125-preview": [
        "fp_f38f4d6482",
        "fp_2f57f81c11",
        "fp_a7daf7c51e",
        "fp_a865e8ede4",
        "fp_13c70b9f70",
        "fp_b77cb481ed",
    ],
    "gpt-4-1106-preview": [
        "fp_e467c31c3d",
        "fp_d986a8d1ba",
        "fp_99a5a401bb",
        "fp_123d5a9f90",
        "fp_0d1affc7a6",
        "fp_5c95a4634e",
    ],
    "gpt-4-turbo-2024-04-09": ["fp_d1bac968b4"],
}
|
|
|
|
|
|
async def get_file_content(url):
    """Return the raw bytes and MIME type behind a ``data:`` URL or a remote URL.

    Args:
        url: Either a ``data:`` URL of the form
            ``data:<mime>[;param...][;base64],<payload>`` or a URL that
            ``Client.get`` can retrieve.

    Returns:
        A ``(file_content, mime_type)`` tuple. ``mime_type`` is the media type
        with any ``;parameter`` suffix removed; it is ``""`` when the source
        does not declare one.

    Raises:
        ValueError: If a ``data:`` URL has no ``,`` separating header and
            payload, or (as ``binascii.Error``, a ``ValueError`` subclass) if
            its base64 payload is malformed.
        Exception: Whatever ``client.get`` / ``r.raise_for_status`` raise for
            the remote branch is propagated unchanged.
    """
    if url.startswith("data:"):
        from urllib.parse import unquote_to_bytes

        # Split on the FIRST comma only: everything before it is the header,
        # everything after it is the payload (which may itself contain commas
        # when it is not base64).
        header, separator, payload = url[len("data:"):].partition(",")
        if not separator:
            raise ValueError("Malformed data URL: missing ',' before the payload")

        # The media type is the first ';'-separated field of the header; the
        # remaining fields are parameters such as "charset=..." or "base64".
        mime_type, *params = header.split(";")
        is_base64 = any(param.strip().lower() == "base64" for param in params)

        if is_base64:
            file_content = base64.b64decode(payload)
        else:
            # Without the base64 flag the payload is percent-encoded text.
            file_content = unquote_to_bytes(payload)
        return file_content, mime_type.strip()

    client = Client()
    try:
        r = await client.get(url)
        r.raise_for_status()
        file_content = r.content
        # Drop parameters such as "; charset=utf-8" from the header value.
        mime_type = r.headers.get('Content-Type', '').split(';')[0].strip()
        return file_content, mime_type
    finally:
        # Always release the client, whether the request succeeded or not.
        await client.close()
|
|
|
|
|
|
async def calculate_image_tokens(width, height, detail):
    """Estimate the token cost of an image from its pixel dimensions.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        detail: Detail level; ``"low"`` yields a flat cost, any other value
            takes the tiled calculation below.

    Returns:
        ``85`` for low detail; otherwise ``170`` per 512-pixel tile of the
        rescaled image plus a base of ``85``.

    NOTE(review): the constants appear to mirror OpenAI's published vision
    token accounting -- confirm against the current pricing guide.
    """
    if detail == "low":
        return 85

    # Step 1: shrink so the longer side is at most 2048 pixels.
    longest_side = max(width, height)
    if longest_side > 2048:
        ratio = 2048 / longest_side
        width, height = int(width * ratio), int(height * ratio)

    # Step 2: shrink again so the shorter side is at most 768 pixels.
    shortest_side = min(width, height)
    if shortest_side > 768:
        ratio = 768 / shortest_side
        width, height = int(width * ratio), int(height * ratio)

    # Step 3: count the 512x512 tiles needed to cover the rescaled image.
    tile_count = math.ceil(width / 512) * math.ceil(height / 512)
    return tile_count * 170 + 85
|
|
|
|
|
|
async def num_tokens_from_messages(messages, model=''):
    """Estimate how many prompt tokens a list of chat messages consumes.

    Every message costs a fixed overhead plus the encoded length of each of
    its string values (role, content, ...). When a value is a list of content
    parts, only parts whose ``type`` is ``"text"`` are counted; ``image_url``
    parts contribute nothing. A final constant of 3 is added to the total.

    Args:
        messages: Iterable of message dicts.
        model: Model name used to pick the tokenizer; unknown or empty names
            fall back to the ``cl100k_base`` encoding.

    Returns:
        The estimated token count as an ``int``.
    """
    try:
        # Normalise a falsy model to '' so the lookup fails with the KeyError
        # handled below instead of some other error type.
        encoding = tiktoken.encoding_for_model(model or '')
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")

    # Only the legacy 0301 snapshot uses a 4-token per-message overhead.
    tokens_per_message = 4 if model == "gpt-3.5-turbo-0301" else 3

    def count_tokens(text):
        # Non-string values (e.g. a ``None`` content) cannot be encoded, so
        # they count as zero rather than aborting the whole estimate.
        if not isinstance(text, str):
            return 0
        # disallowed_special=() makes literals such as "<|endoftext|>" in user
        # text encode as ordinary text; by default tiktoken raises ValueError
        # when it meets them.
        return len(encoding.encode(text, disallowed_special=()))

    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for value in message.values():
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, dict) and item.get("type") == "text":
                        num_tokens += count_tokens(item.get("text"))
                    # "image_url" parts are intentionally not counted here.
            else:
                num_tokens += count_tokens(value)

    num_tokens += 3
    return num_tokens
|
|
|
|
|
|
async def num_tokens_from_content(content, model=None):
    """Return the number of tokens ``content`` encodes to.

    Args:
        content: Text to measure. Empty or ``None`` content counts as 0.
        model: Model name used to pick the tokenizer; ``None``, empty, or
            unknown names fall back to the ``cl100k_base`` encoding.

    Returns:
        The token count as an ``int``.
    """
    if not content:
        return 0

    try:
        # The default model is None; pass '' instead so the lookup fails with
        # the KeyError handled below.
        # NOTE(review): tiktoken appears to call string methods on the model
        # name, so a raw None would likely raise something other than
        # KeyError -- confirm against the installed version.
        encoding = tiktoken.encoding_for_model(model or "")
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")

    # disallowed_special=() makes special-token literals such as
    # "<|endoftext|>" encode as ordinary text instead of raising ValueError.
    return len(encoding.encode(content, disallowed_special=()))
|
|
|
|
|
|
async def split_tokens_from_content(content, max_tokens, model=None):
    """Truncate ``content`` to at most ``max_tokens`` tokens.

    Args:
        content: Text to measure and, if necessary, truncate.
        max_tokens: Maximum number of tokens to keep.
        model: Model name used to pick the tokenizer; ``None``, empty, or
            unknown names fall back to the ``cl100k_base`` encoding.

    Returns:
        A ``(content, token_count, finish_reason)`` tuple. When the encoded
        length reaches ``max_tokens`` the text is cut to the first
        ``max_tokens`` tokens and the reason is ``"length"``; otherwise the
        text is returned unchanged with its real token count and ``"stop"``.
    """
    try:
        # The default model is None; pass '' instead so the lookup fails with
        # the KeyError handled below.
        # NOTE(review): tiktoken appears to call string methods on the model
        # name, so a raw None would likely raise something other than
        # KeyError -- confirm against the installed version.
        encoding = tiktoken.encoding_for_model(model or "")
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")

    # disallowed_special=() makes special-token literals such as
    # "<|endoftext|>" encode as ordinary text instead of raising ValueError.
    encoded_content = encoding.encode(content, disallowed_special=())
    len_encoded_content = len(encoded_content)

    if len_encoded_content >= max_tokens:
        content = encoding.decode(encoded_content[:max_tokens])
        return content, max_tokens, "length"
    return content, len_encoded_content, "stop"
|
|
|
|
|
|
async def determine_file_use_case(mime_type):
    """Classify a MIME type into one of three use-case labels.

    Args:
        mime_type: MIME type string to classify.

    Returns:
        ``"multimodal"`` for the listed image types, ``"my_files"`` for the
        listed document and source-code types, and ``"ace_upload"`` for
        anything else.
    """
    image_mimes = ("image/jpeg", "image/webp", "image/png", "image/gif")
    if mime_type in image_mimes:
        return "multimodal"

    document_mimes = (
        "text/x-php",
        "application/msword",
        "text/x-c",
        "text/html",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/json",
        "text/javascript",
        "application/pdf",
        "text/x-java",
        "text/x-tex",
        "text/x-typescript",
        "text/x-sh",
        "text/x-csharp",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        "text/x-c++",
        "application/x-latext",
        "text/markdown",
        "text/plain",
        "text/x-ruby",
        "text/x-script.python",
    )
    if mime_type in document_mimes:
        return "my_files"

    # Anything not recognised above falls through to the generic label.
    return "ace_upload"
|
|
|
|
|
|
async def get_image_size(file_content):
    """Return the ``(width, height)`` of an image given its encoded bytes.

    Args:
        file_content: Raw bytes of the image file.

    Returns:
        A ``(width, height)`` tuple read from the opened image.
    """
    buffer = io.BytesIO(file_content)
    with Image.open(buffer) as picture:
        dimensions = (picture.width, picture.height)
    return dimensions
|
|
|
|
|
|
async def get_file_extension(mime_type):
    """Return the file extension (with leading dot) for a MIME type.

    Args:
        mime_type: MIME type string to look up.

    Returns:
        The mapped extension such as ``".png"``, or ``""`` when the MIME type
        is not in the table.
    """
    known_extensions = {
        # Images
        "image/jpeg": ".jpg",
        "image/png": ".png",
        "image/gif": ".gif",
        "image/webp": ".webp",
        # Documents and source code
        "text/x-php": ".php",
        "application/msword": ".doc",
        "text/x-c": ".c",
        "text/html": ".html",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
        "application/json": ".json",
        "text/javascript": ".js",
        "application/pdf": ".pdf",
        "text/x-java": ".java",
        "text/x-tex": ".tex",
        "text/x-typescript": ".ts",
        "text/x-sh": ".sh",
        "text/x-csharp": ".cs",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
        "text/x-c++": ".cpp",
        "application/x-latext": ".latex",
        "text/markdown": ".md",
        "text/plain": ".txt",
        "text/x-ruby": ".rb",
        "text/x-script.python": ".py",
        # Other MIME types and extensions...
    }
    try:
        return known_extensions[mime_type]
    except KeyError:
        # Unknown MIME types map to an empty extension.
        return ""