From f22c21dd913aa8da65d447d7d439f6fac570e254 Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Tue, 30 Jun 2026 16:15:26 +0800 Subject: [PATCH] style: improve code formatting and consistency across multiple files --- .../serializers/model_apply_serializers.py | 90 ++-- apps/local_model/serializers/rsa_util.py | 49 +- apps/local_model/views/model_apply.py | 29 +- apps/models_provider/api/model.py | 82 ++-- apps/models_provider/api/provide.py | 100 ++-- apps/models_provider/base_model_provider.py | 8 +- .../constants/model_provider_constants.py | 12 +- .../serializers/model_apply_serializers.py | 61 +-- .../serializers/model_serializer.py | 401 +++++++-------- apps/models_provider/tools.py | 57 +-- apps/models_provider/views/model.py | 462 ++++++++++-------- apps/models_provider/views/model_apply.py | 73 ++- apps/models_provider/views/provide.py | 133 ++--- 13 files changed, 837 insertions(+), 720 deletions(-) diff --git a/apps/local_model/serializers/model_apply_serializers.py b/apps/local_model/serializers/model_apply_serializers.py index 76c8792bbfa..a69e1a2a7ef 100644 --- a/apps/local_model/serializers/model_apply_serializers.py +++ b/apps/local_model/serializers/model_apply_serializers.py @@ -1,33 +1,32 @@ # coding=utf-8 """ - @project: MaxKB - @Author:虎 - @file: model_apply_serializers.py - @date:2024/8/20 20:39 - @desc: +@project: MaxKB +@Author:虎 +@file: model_apply_serializers.py +@date:2024/8/20 20:39 +@desc: """ + import json import threading import time +from common.cache.mem_cache import MemCache from django.db import connection from django.db.models import QuerySet from django.utils.translation import gettext_lazy as _ from langchain_core.documents import Document -from rest_framework import serializers - from local_model.models import Model from local_model.serializers.rsa_util import rsa_long_decrypt from models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider - -from common.cache.mem_cache import MemCache +from rest_framework import serializers _lock = threading.Lock() locks = {} class ModelManage: - cache = MemCache('model', {}) + cache = MemCache("model", {}) up_clear_time = time.time() @staticmethod @@ -74,87 +73,100 @@ def delete_key(_id): def get_local_model(model, **kwargs): - return LocalModelProvider().get_model(model.model_type, model.model_name, - json.loads( - rsa_long_decrypt(model.credential)), - model_id=model.id, - streaming=True, **kwargs) + return LocalModelProvider().get_model( + model.model_type, + model.model_name, + json.loads(rsa_long_decrypt(model.credential)), + model_id=model.id, + streaming=True, + **kwargs, + ) def get_embedding_model(model_id): model = QuerySet(Model).filter(id=model_id).first() # 手动关闭数据库连接 connection.close() - embedding_model = ModelManage.get_model(model_id, - lambda _id: get_local_model(model, use_local=True)) + embedding_model = ModelManage.get_model(model_id, lambda _id: get_local_model(model, use_local=True)) return embedding_model class EmbedDocuments(serializers.Serializer): - texts = serializers.ListField(required=True, - child=serializers.CharField(required=True, label=_('vector text')), - label=_('vector text list')) + texts = serializers.ListField( + required=True, child=serializers.CharField(required=True, label=_("vector text")), label=_("vector text list") + ) class EmbedQuery(serializers.Serializer): - text = serializers.CharField(required=True, label=_('vector text')) + text = serializers.CharField(required=True, label=_("vector text")) class CompressDocument(serializers.Serializer): - page_content = serializers.CharField(required=True, label=_('text')) - metadata = serializers.DictField(required=False, label=_('metadata')) + page_content = serializers.CharField(required=True, label=_("text")) + metadata = serializers.DictField(required=False, label=_("metadata")) class CompressDocuments(serializers.Serializer): documents = CompressDocument(required=True, many=True) - query = serializers.CharField(required=True, label=_('query')) + query = serializers.CharField(required=True, label=_("query")) class ValidateModelSerializers(serializers.Serializer): - model_name = serializers.CharField(required=True, label=_('model_name')) + model_name = serializers.CharField(required=True, label=_("model_name")) - model_type = serializers.CharField(required=True, label=_('model_type')) + model_type = serializers.CharField(required=True, label=_("model_type")) model_credential = serializers.DictField(required=True, label="credential") def validate_model(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - LocalModelProvider().is_valid_credential(self.data.get('model_type'), self.data.get('model_name'), - self.data.get('model_credential'), model_params={}, - raise_exception=True) + LocalModelProvider().is_valid_credential( + self.data.get("model_type"), + self.data.get("model_name"), + self.data.get("model_credential"), + model_params={}, + raise_exception=True, + ) class ModelApplySerializers(serializers.Serializer): - model_id = serializers.UUIDField(required=True, label=_('model id')) + model_id = serializers.UUIDField(required=True, label=_("model id")) def embed_documents(self, instance, with_valid=True): if with_valid: self.is_valid(raise_exception=True) EmbedDocuments(data=instance).is_valid(raise_exception=True) - model = get_embedding_model(self.data.get('model_id')) - return model.embed_documents(instance.getlist('texts')) + model = get_embedding_model(self.data.get("model_id")) + return model.embed_documents(instance.getlist("texts")) def embed_query(self, instance, with_valid=True): if with_valid: self.is_valid(raise_exception=True) EmbedQuery(data=instance).is_valid(raise_exception=True) - model = get_embedding_model(self.data.get('model_id')) - return model.embed_query(instance.get('text')) + model = get_embedding_model(self.data.get("model_id")) + return model.embed_query(instance.get("text")) def compress_documents(self, instance, with_valid=True): if with_valid: self.is_valid(raise_exception=True) CompressDocuments(data=instance).is_valid(raise_exception=True) - model = get_embedding_model(self.data.get('model_id')) - return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents( - [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in - instance.get('documents')], instance.get('query'))] + model = get_embedding_model(self.data.get("model_id")) + return [ + {"page_content": d.page_content, "metadata": d.metadata} + for d in model.compress_documents( + [ + Document(page_content=document.get("page_content"), metadata=document.get("metadata")) + for document in instance.get("documents") + ], + instance.get("query"), + ) + ] def unload(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - ModelManage.delete_key(self.data.get('model_id')) + ModelManage.delete_key(self.data.get("model_id")) return True diff --git a/apps/local_model/serializers/rsa_util.py b/apps/local_model/serializers/rsa_util.py index df2cedba736..1bedd9c2994 100644 --- a/apps/local_model/serializers/rsa_util.py +++ b/apps/local_model/serializers/rsa_util.py @@ -1,21 +1,21 @@ # coding=utf-8 """ - @project: maxkb - @Author:虎 - @file: rsa_util.py - @date:2023/11/3 11:13 - @desc: +@project: maxkb +@Author:虎 +@file: rsa_util.py +@date:2023/11/3 11:13 +@desc: """ + import base64 import threading +from common.constants.cache_version import Cache_Version from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher from Crypto.PublicKey import RSA from django.core import cache from django.db.models import QuerySet - -from common.constants.cache_version import Cache_Version -from local_model.models.system_setting import SystemSetting, SettingType +from local_model.models.system_setting import SettingType, SystemSetting lock = threading.Lock() rsa_cache = cache.cache @@ -33,9 +33,8 @@ def generate(): key = RSA.generate(2048) # 获取私钥 - encrypted_key = key.export_key(passphrase=secret_code, pkcs=8, - protection="scryptAndAES128-CBC") - return {'key': key.publickey().export_key(), 'value': encrypted_key} + encrypted_key = key.export_key(passphrase=secret_code, pkcs=8, protection="scryptAndAES128-CBC") + return {"key": key.publickey().export_key(), "value": encrypted_key} def get_key_pair(): @@ -47,7 +46,7 @@ def get_key_pair(): return rsa_value rsa_value = get_key_pair_by_sql() version, get_key = Cache_Version.SYSTEM.value - rsa_cache.set(get_key(key='rsa_key'), rsa_value, timeout=None, version=version) + rsa_cache.set(get_key(key="rsa_key"), rsa_value, timeout=None, version=version) return rsa_value @@ -55,8 +54,9 @@ def get_key_pair_by_sql(): system_setting = QuerySet(SystemSetting).filter(type=SettingType.RSA.value).first() if system_setting is None: kv = generate() - system_setting = SystemSetting(type=SettingType.RSA.value, - meta={'key': kv.get('key').decode(), 'value': kv.get('value').decode()}) + system_setting = SystemSetting( + type=SettingType.RSA.value, meta={"key": kv.get("key").decode(), "value": kv.get("value").decode()} + ) system_setting.save() return system_setting.meta @@ -69,7 +69,7 @@ def encrypt(msg, public_key: str | None = None): :return: 加密后的数据 """ if public_key is None: - public_key = get_key_pair().get('key') + public_key = get_key_pair().get("key") cipher = PKCS1_cipher.new(RSA.importKey(public_key)) encrypt_msg = cipher.encrypt(msg.encode("utf-8")) return base64.b64encode(encrypt_msg).decode() @@ -83,7 +83,7 @@ def decrypt(msg, pri_key: str | None = None): :return: 解密后数据 """ if pri_key is None: - pri_key = get_key_pair().get('value') + pri_key = get_key_pair().get("value") cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) decrypt_data = cipher.decrypt(base64.b64decode(msg), 0) return decrypt_data.decode("utf-8") @@ -100,22 +100,21 @@ def rsa_long_encrypt(message, public_key: str | None = None, length=200): """ # 读取公钥 if public_key is None: - public_key = get_key_pair().get('key') - cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key, - passphrase=secret_code)) + public_key = get_key_pair().get("key") + cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key, passphrase=secret_code)) # 处理:Plaintext is too long. 分段加密 if len(message) <= length: # 对编码的数据进行加密,并通过base64进行编码 - result = base64.b64encode(cipher.encrypt(message.encode('utf-8'))) + result = base64.b64encode(cipher.encrypt(message.encode("utf-8"))) else: rsa_text = [] # 对编码后的数据进行切片,原因:加密长度不能过长 for i in range(0, len(message), length): - cont = message[i:i + length] + cont = message[i : i + length] # 对切片后的数据进行加密,并新增到text后面 - rsa_text.append(cipher.encrypt(cont.encode('utf-8'))) + rsa_text.append(cipher.encrypt(cont.encode("utf-8"))) # 加密完进行拼接 - cipher_text = b''.join(rsa_text) + cipher_text = b"".join(rsa_text) # base64进行编码 result = base64.b64encode(cipher_text) return result.decode() @@ -130,10 +129,10 @@ def rsa_long_decrypt(message, pri_key: str | None = None, length=256): :return: 解密后的数据 """ if pri_key is None: - pri_key = get_key_pair().get('value') + pri_key = get_key_pair().get("value") cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) base64_de = base64.b64decode(message) res = [] for i in range(0, len(base64_de), length): - res.append(cipher.decrypt(base64_de[i:i + length], 0)) + res.append(cipher.decrypt(base64_de[i : i + length], 0)) return b"".join(res).decode() diff --git a/apps/local_model/views/model_apply.py b/apps/local_model/views/model_apply.py index 98c07dd7493..4259c695db1 100644 --- a/apps/local_model/views/model_apply.py +++ b/apps/local_model/views/model_apply.py @@ -1,42 +1,35 @@ # coding=utf-8 """ - @project: MaxKB - @Author:虎 - @file: model_apply.py - @date:2024/8/20 20:38 - @desc: +@project: MaxKB +@Author:虎 +@file: model_apply.py +@date:2024/8/20 20:38 +@desc: """ -from urllib.request import Request -from rest_framework.views import APIView +from urllib.request import Request from common.result import result from local_model.serializers.model_apply_serializers import ModelApplySerializers, ValidateModelSerializers +from rest_framework.views import APIView class LocalModelApply(APIView): class EmbedDocuments(APIView): - def post(self, request: Request, model_id): - return result.success( - ModelApplySerializers(data={'model_id': model_id}).embed_documents(request.data)) + return result.success(ModelApplySerializers(data={"model_id": model_id}).embed_documents(request.data)) class EmbedQuery(APIView): - def post(self, request: Request, model_id): - return result.success( - ModelApplySerializers(data={'model_id': model_id}).embed_query(request.data)) + return result.success(ModelApplySerializers(data={"model_id": model_id}).embed_query(request.data)) class CompressDocuments(APIView): - def post(self, request: Request, model_id): - return result.success( - ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data)) + return result.success(ModelApplySerializers(data={"model_id": model_id}).compress_documents(request.data)) class Unload(APIView): def post(self, request: Request, model_id): - return result.success( - ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data)) + return result.success(ModelApplySerializers(data={"model_id": model_id}).compress_documents(request.data)) class Validate(APIView): def post(self, request: Request): diff --git a/apps/models_provider/api/model.py b/apps/models_provider/api/model.py index d79849f4ab2..5d7fb2721c2 100644 --- a/apps/models_provider/api/model.py +++ b/apps/models_provider/api/model.py @@ -1,13 +1,12 @@ # coding=utf-8 +from common.mixins.api_mixin import APIMixin +from common.result import DefaultResultSerializer, ResultSerializer +from django.utils.translation import gettext_lazy as _ from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter +from models_provider.serializers.model_serializer import ModelCreateRequest, ModelModelSerializer from rest_framework import serializers -from common.mixins.api_mixin import APIMixin -from common.result import ResultSerializer, DefaultResultSerializer -from models_provider.serializers.model_serializer import ModelModelSerializer, ModelCreateRequest -from django.utils.translation import gettext_lazy as _ - class ModelCreateResponse(ResultSerializer): def get_data(self): @@ -25,48 +24,49 @@ def get_data(self): @staticmethod def get_parameters(): - return [OpenApiParameter( - name="workspace_id", - description=_("workspace id"), - type=OpenApiTypes.STR, - location=OpenApiParameter.PATH, - required=True, - ), + return [ + OpenApiParameter( + name="workspace_id", + description=_("workspace id"), + type=OpenApiTypes.STR, + location=OpenApiParameter.PATH, # type: ignore + required=True, + ), OpenApiParameter( name="name", description=_("model name"), type=OpenApiTypes.STR, - location=OpenApiParameter.QUERY, + location=OpenApiParameter.QUERY, # type: ignore required=False, ), OpenApiParameter( name="model_type", description=_("model type"), type=OpenApiTypes.STR, - location=OpenApiParameter.QUERY, + location=OpenApiParameter.QUERY, # type: ignore required=False, ), OpenApiParameter( name="model_name", description=_("base model"), type=OpenApiTypes.STR, - location=OpenApiParameter.QUERY, + location=OpenApiParameter.QUERY, # type: ignore required=False, ), OpenApiParameter( name="provider", description=_("provider"), type=OpenApiTypes.STR, - location=OpenApiParameter.QUERY, + location=OpenApiParameter.QUERY, # type: ignore required=False, ), OpenApiParameter( name="create_user", description=_("create user"), type=OpenApiTypes.STR, - location=OpenApiParameter.QUERY, + location=OpenApiParameter.QUERY, # type: ignore required=False, - ) + ), ] @@ -81,33 +81,37 @@ def get_response(): @classmethod def get_parameters(cls): - return [OpenApiParameter( - name="workspace_id", - description=_("workspace id"), - type=OpenApiTypes.STR, - location=OpenApiParameter.PATH, - required=True, - )] + return [ + OpenApiParameter( + name="workspace_id", + description=_("workspace id"), + type=OpenApiTypes.STR, + location=OpenApiParameter.PATH, # type: ignore + required=True, + ) + ] class GetModelApi(APIMixin): - @staticmethod def get_query_params_api(): - return [OpenApiParameter( - name="workspace_id", - description=_("workspace id"), - type=OpenApiTypes.STR, - location=OpenApiParameter.PATH, - required=True, - ), OpenApiParameter( - name="model_id", - description=_("model id"), - type=OpenApiTypes.STR, - location=OpenApiParameter.PATH, - required=True, - ) + return [ + OpenApiParameter( + name="workspace_id", + description=_("workspace id"), + type=OpenApiTypes.STR, + location=OpenApiParameter.PATH, # type: ignore + required=True, + ), + OpenApiParameter( + name="model_id", + description=_("model id"), + type=OpenApiTypes.STR, + location=OpenApiParameter.PATH, # type: ignore + required=True, + ), ] + @staticmethod def get_request(): return [] diff --git a/apps/models_provider/api/provide.py b/apps/models_provider/api/provide.py index 83d81a10603..9e0921dbbd3 100644 --- a/apps/models_provider/api/provide.py +++ b/apps/models_provider/api/provide.py @@ -1,11 +1,10 @@ # coding=utf-8 -from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import OpenApiParameter - from common.mixins.api_mixin import APIMixin from common.result import ResultSerializer -from rest_framework import serializers from django.utils.translation import gettext_lazy as _ +from drf_spectacular.types import OpenApiTypes +from drf_spectacular.utils import OpenApiParameter +from rest_framework import serializers class ProvideResponse(ResultSerializer): @@ -65,25 +64,28 @@ class ProvideApi(APIMixin): class ModelParamsForm(APIMixin): @staticmethod def get_query_params_api(): - return [OpenApiParameter( - name="model_type", - description=_("model type"), - type=OpenApiTypes.STR, - location=OpenApiParameter.QUERY, - required=True, - ), OpenApiParameter( - name="provider", - description=_("provider"), - type=OpenApiTypes.STR, - location=OpenApiParameter.QUERY, - required=True, - ), OpenApiParameter( - name="model_name", - description=_("model name"), - type=OpenApiTypes.STR, - location=OpenApiParameter.QUERY, - required=True, - ) + return [ + OpenApiParameter( + name="model_type", + description=_("model type"), + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, # type: ignore + required=True, + ), + OpenApiParameter( + name="provider", + description=_("provider"), + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, # type: ignore + required=True, + ), + OpenApiParameter( + name="model_name", + description=_("model name"), + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, # type: ignore + required=True, + ), ] @staticmethod @@ -93,19 +95,21 @@ def get_response(): class ModelList(APIMixin): @staticmethod def get_query_params_api(): - return [OpenApiParameter( - name="model_type", - description=_("model type"), - type=OpenApiTypes.STR, - location=OpenApiParameter.QUERY, - required=True, - ), OpenApiParameter( - name="provider", - description=_("provider"), - type=OpenApiTypes.STR, - location=OpenApiParameter.QUERY, - required=True, - ) + return [ + OpenApiParameter( + name="model_type", + description=_("model type"), + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, # type: ignore + required=True, + ), + OpenApiParameter( + name="provider", + description=_("provider"), + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, # type: ignore + required=True, + ), ] @staticmethod @@ -119,17 +123,19 @@ def get_response(): class ModelTypeList(APIMixin): @staticmethod def get_query_params_api(): - return [OpenApiParameter( - # 参数的名称是done - name="provider", - # 对参数的备注 - description=_("provider"), - # 指定参数的类型 - type=OpenApiTypes.STR, - location=OpenApiParameter.QUERY, - # 指定必须给 - required=True, - )] + return [ + OpenApiParameter( + # 参数的名称是done + name="provider", + # 对参数的备注 + description=_("provider"), + # 指定参数的类型 + type=OpenApiTypes.STR, + location=OpenApiParameter.QUERY, # type: ignore + # 指定必须给 + required=True, + ) + ] @staticmethod def get_response(): diff --git a/apps/models_provider/base_model_provider.py b/apps/models_provider/base_model_provider.py index 77f5275cad2..312f5a43729 100644 --- a/apps/models_provider/base_model_provider.py +++ b/apps/models_provider/base_model_provider.py @@ -3,14 +3,12 @@ from abc import ABC, abstractmethod from enum import Enum from functools import reduce -from typing import Dict, Iterator, Type, List - -from pydantic import BaseModel +from typing import Dict, Iterator, List, Type from common.exception.app_exception import AppApiException -from django.utils.translation import gettext_lazy as _ - from common.utils.common import encryption +from django.utils.translation import gettext_lazy as _ +from pydantic import BaseModel class DownModelChunkStatus(Enum): diff --git a/apps/models_provider/constants/model_provider_constants.py b/apps/models_provider/constants/model_provider_constants.py index 4533e0fca9b..d49343eea5d 100644 --- a/apps/models_provider/constants/model_provider_constants.py +++ b/apps/models_provider/constants/model_provider_constants.py @@ -1,8 +1,9 @@ # coding=utf-8 from enum import Enum -from models_provider.impl.aliyun_bai_lian_model_provider.aliyun_bai_lian_model_provider import \ - AliyunBaiLianModelProvider +from models_provider.impl.aliyun_bai_lian_model_provider.aliyun_bai_lian_model_provider import ( + AliyunBaiLianModelProvider, +) from models_provider.impl.anthropic_model_provider.anthropic_model_provider import AnthropicModelProvider from models_provider.impl.aws_bedrock_model_provider.aws_bedrock_model_provider import BedrockModelProvider from models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider @@ -11,6 +12,7 @@ from models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider from models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider from models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider +from models_provider.impl.minimax_model_provider.minimax_model_provider import MiniMaxModelProvider from models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider from models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider from models_provider.impl.regolo_model_provider.regolo_model_provider import RegoloModelProvider @@ -18,12 +20,12 @@ from models_provider.impl.tencent_cloud_model_provider.tencent_cloud_model_provider import TencentCloudModelProvider from models_provider.impl.tencent_model_provider.tencent_model_provider import TencentModelProvider from models_provider.impl.vllm_model_provider.vllm_model_provider import VllmModelProvider -from models_provider.impl.volcanic_engine_model_provider.volcanic_engine_model_provider import \ - VolcanicEngineModelProvider +from models_provider.impl.volcanic_engine_model_provider.volcanic_engine_model_provider import ( + VolcanicEngineModelProvider, +) from models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider from models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider from models_provider.impl.xinference_model_provider.xinference_model_provider import XinferenceModelProvider -from models_provider.impl.minimax_model_provider.minimax_model_provider import MiniMaxModelProvider from models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider diff --git a/apps/models_provider/serializers/model_apply_serializers.py b/apps/models_provider/serializers/model_apply_serializers.py index 30c33147f1e..c069cb56a96 100644 --- a/apps/models_provider/serializers/model_apply_serializers.py +++ b/apps/models_provider/serializers/model_apply_serializers.py @@ -1,76 +1,81 @@ # coding=utf-8 """ - @project: MaxKB - @Author:虎 - @file: model_apply_serializers.py - @date:2024/8/20 20:39 - @desc: +@project: MaxKB +@Author:虎 +@file: model_apply_serializers.py +@date:2024/8/20 20:39 +@desc: """ -from django.db import connection -from django.db.models import QuerySet -from langchain_core.documents import Document -from rest_framework import serializers from common.config.embedding_config import ModelManage +from django.db import connection +from django.db.models import QuerySet from django.utils.translation import gettext_lazy as _ - +from langchain_core.documents import Document from models_provider.models import Model from models_provider.tools import get_model +from rest_framework import serializers def get_embedding_model(model_id): model = QuerySet(Model).filter(id=model_id).first() # 手动关闭数据库连接 connection.close() - embedding_model = ModelManage.get_model(model_id, - lambda _id: get_model(model, use_local=True)) + embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, use_local=True)) return embedding_model class EmbedDocuments(serializers.Serializer): - texts = serializers.ListField(required=True, - child=serializers.CharField(required=True, label=_('vector text')), - label=_('vector text list')) + texts = serializers.ListField( + required=True, child=serializers.CharField(required=True, label=_("vector text")), label=_("vector text list") + ) class EmbedQuery(serializers.Serializer): - text = serializers.CharField(required=True, label=_('vector text')) + text = serializers.CharField(required=True, label=_("vector text")) class CompressDocument(serializers.Serializer): - page_content = serializers.CharField(required=True, label=_('text')) - metadata = serializers.DictField(required=False, label=_('metadata')) + page_content = serializers.CharField(required=True, label=_("text")) + metadata = serializers.DictField(required=False, label=_("metadata")) class CompressDocuments(serializers.Serializer): documents = CompressDocument(required=True, many=True) - query = serializers.CharField(required=True, label=_('query')) + query = serializers.CharField(required=True, label=_("query")) class ModelApplySerializers(serializers.Serializer): - model_id = serializers.UUIDField(required=True, label=_('model id')) + model_id = serializers.UUIDField(required=True, label=_("model id")) def embed_documents(self, instance, with_valid=True): if with_valid: self.is_valid(raise_exception=True) EmbedDocuments(data=instance).is_valid(raise_exception=True) - model = get_embedding_model(self.data.get('model_id')) - return model.embed_documents(instance.getlist('texts')) + model = get_embedding_model(self.data.get("model_id")) + return model.embed_documents(instance.getlist("texts")) def embed_query(self, instance, with_valid=True): if with_valid: self.is_valid(raise_exception=True) EmbedQuery(data=instance).is_valid(raise_exception=True) - model = get_embedding_model(self.data.get('model_id')) - return model.embed_query(instance.get('text')) + model = get_embedding_model(self.data.get("model_id")) + return model.embed_query(instance.get("text")) def compress_documents(self, instance, with_valid=True): if with_valid: self.is_valid(raise_exception=True) CompressDocuments(data=instance).is_valid(raise_exception=True) - model = get_embedding_model(self.data.get('model_id')) - return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents( - [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in - instance.get('documents')], instance.get('query'))] + model = get_embedding_model(self.data.get("model_id")) + return [ + {"page_content": d.page_content, "metadata": d.metadata} + for d in model.compress_documents( + [ + Document(page_content=document.get("page_content"), metadata=document.get("metadata")) + for document in instance.get("documents") + ], + instance.get("query"), + ) + ] diff --git a/apps/models_provider/serializers/model_serializer.py b/apps/models_provider/serializers/model_serializer.py index 7e866eda128..0c158b2250b 100644 --- a/apps/models_provider/serializers/model_serializer.py +++ b/apps/models_provider/serializers/model_serializer.py @@ -6,26 +6,25 @@ from typing import Dict import uuid_utils.compat as uuid -from django.core.cache import cache -from django.db import transaction -from django.db.models import QuerySet -from django.utils.translation import gettext_lazy as _ -from rest_framework import serializers - from common.config.embedding_config import ModelManage from common.constants.cache_version import Cache_Version -from common.constants.permission_constants import ResourcePermission, ResourceAuthType +from common.constants.permission_constants import ResourceAuthType, ResourcePermission from common.database_model_manage.database_model_manage import DatabaseModelManage from common.db.search import native_search from common.exception.app_exception import AppApiException from common.utils.common import get_file_content -from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt +from common.utils.rsa_util import rsa_long_decrypt, rsa_long_encrypt +from django.core.cache import cache +from django.db import transaction +from django.db.models import QuerySet +from django.utils.translation import gettext_lazy as _ from maxkb.conf import PROJECT_DIR -from models_provider.base_model_provider import ValidCode, DownModelChunkStatus +from models_provider.base_model_provider import DownModelChunkStatus, ValidCode from models_provider.constants.model_provider_constants import ModelProvideConstants from models_provider.models import Model, Status from models_provider.tools import get_model_credential -from system_manage.models import WorkspaceUserResourcePermission, AuthTargetType +from rest_framework import serializers +from system_manage.models import AuthTargetType, WorkspaceUserResourcePermission from system_manage.models.resource_mapping import ResourceMapping from system_manage.serializers.resource_mapping_serializers import ResourceMappingSerializer from system_manage.serializers.user_resource_permission import UserResourcePermissionSerializer @@ -44,9 +43,19 @@ class ModelModelSerializer(serializers.ModelSerializer): class Meta: model = Model fields = [ - 'id', 'name', 'status', 'model_type', 'model_name', - 'user', 'provider', 'credential', 'meta', - 'model_params_form', 'workspace_id', 'create_time', 'update_time' + "id", + "name", + "status", + "model_type", + "model_name", + "user", + "provider", + "credential", + "meta", + "model_params_form", + "workspace_id", + "create_time", + "update_time", ] @@ -83,19 +92,15 @@ def pull(model: Model, credential: Dict): status = Status.ERROR message = "" for chunk in down_model_chunk.values(): - if chunk.get('status') == DownModelChunkStatus.success.value: + if chunk.get("status") == DownModelChunkStatus.success.value: status = Status.SUCCESS - elif chunk.get('status') == DownModelChunkStatus.error.value: + elif chunk.get("status") == DownModelChunkStatus.error.value: message = chunk.get("digest") - QuerySet(Model).filter(id=model.id).update( - meta={"down_model_chunk": [], "message": message}, - status=status - ) + QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": [], "message": message}, status=status) except Exception as e: QuerySet(Model).filter(id=model.id).update( - meta={"down_model_chunk": [], "message": str(e)}, - status=Status.ERROR + meta={"down_model_chunk": [], "message": str(e)}, status=Status.ERROR ) @@ -104,19 +109,19 @@ class ModelSerializer(serializers.Serializer): def model_to_dict(model: Model): credential = json.loads(rsa_long_decrypt(model.credential)) return { - 'id': str(model.id), - 'provider': model.provider, - 'name': model.name, - 'model_type': model.model_type, - 'model_name': model.model_name, - 'status': model.status, - 'meta': model.meta, - 'credential': ModelProvideConstants[model.provider].value.get_model_credential( - model.model_type, model.model_name - ).encryption_dict(credential), - 'workspace_id': model.workspace_id, - 'nick_name': model.user.nick_name if model.user else '', - 'username': model.user.username if model.user else '' + "id": str(model.id), + "provider": model.provider, + "name": model.name, + "model_type": model.model_type, + "model_name": model.model_name, + "status": model.status, + "meta": model.meta, + "credential": ModelProvideConstants[model.provider] + .value.get_model_credential(model.model_type, model.model_name) + .encryption_dict(credential), + "workspace_id": model.workspace_id, + "nick_name": model.user.nick_name if model.user else "", + "username": model.user.username if model.user else "", } class Operate(serializers.Serializer): @@ -132,44 +137,49 @@ def is_valid(self, *, raise_exception=False): model_query = model_query.filter(workspace_id=workspace_id) model = model_query.first() if model is None: - raise AppApiException(500, _('Model does not exist')) - if model.workspace_id == 'None': - raise AppApiException(500, _('Shared models cannot be deleted or modified')) + raise AppApiException(500, _("Model does not exist")) + if model.workspace_id == "None": + raise AppApiException(500, _("Shared models cannot be deleted or modified")) def one(self, with_valid=False): if with_valid: super().is_valid(raise_exception=True) - model = QuerySet(Model).get( - id=self.data.get('id'), workspace_id=self.data.get('workspace_id', 'None') - ) + model = QuerySet(Model).get(id=self.data.get("id"), workspace_id=self.data.get("workspace_id", "None")) return ModelSerializer.model_to_dict(model) def one_meta(self, with_valid=False): model = None if with_valid: super().is_valid(raise_exception=True) - model = QuerySet(Model).filter(id=self.data.get("id"), - workspace_id=self.data.get('workspace_id', 'None')).first() + model = ( + QuerySet(Model) + .filter(id=self.data.get("id"), workspace_id=self.data.get("workspace_id", "None")) + .first() + ) if model is None: - raise AppApiException(500, _('Model does not exist')) - return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, - 'model_name': model.model_name, - 'status': model.status, - 'meta': model.meta, - 'workspace_id': model.workspace_id, - } + raise AppApiException(500, _("Model does not exist")) + return { + "id": str(model.id), + "provider": model.provider, + "name": model.name, + "model_type": model.model_type, + "model_name": model.model_name, + "status": model.status, + "meta": model.meta, + "workspace_id": model.workspace_id, + } def pause_download(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - QuerySet(Model).filter(id=self.data.get('id')).update(status=Status.PAUSE_DOWNLOAD) + QuerySet(Model).filter(id=self.data.get("id")).update(status=Status.PAUSE_DOWNLOAD) return True @transaction.atomic def delete(self, with_valid=True): if with_valid: super().is_valid(raise_exception=True) - model_id = self.data.get('id') + model_id = self.data.get("id") model = Model.objects.filter(id=model_id).first() if model is None: return True @@ -198,30 +208,28 @@ def delete(self, with_valid=True): def edit(self, instance: Dict, user_id: str, with_valid=True): if with_valid: super().is_valid(raise_exception=True) - model = QuerySet(Model).filter(id=self.data.get('id')).first() + model = QuerySet(Model).filter(id=self.data.get("id")).first() - credential, model_credential, provider_handler = ModelSerializer.Edit( - data={**instance}).is_valid( - model=model) + credential, model_credential, provider_handler = ModelSerializer.Edit(data={**instance}).is_valid( + model=model + ) try: model.status = Status.SUCCESS - default_params = {item['field']: item['default_value'] for item in model.model_params_form} + default_params = {item["field"]: item["default_value"] for item in model.model_params_form} # 校验模型认证数据 - provider_handler.is_valid_credential(model.model_type, - instance.get("model_name"), - credential, - default_params, - raise_exception=True) + provider_handler.is_valid_credential( + model.model_type, instance.get("model_name"), credential, default_params, raise_exception=True + ) except AppApiException as e: if e.code == ValidCode.model_not_fount: model.status = Status.DOWNLOAD else: raise e - update_keys = ['credential', 'name', 'model_type', 'model_name'] + update_keys = ["credential", "name", "model_type", "model_name"] for update_key in update_keys: if update_key in instance and instance.get(update_key) is not None: - if update_key == 'credential': + if update_key == "credential": model_credential_str = json.dumps(credential) model.__setattr__(update_key, rsa_long_encrypt(model_credential_str)) else: @@ -235,38 +243,35 @@ def edit(self, instance: Dict, user_id: str, with_valid=True): return self.one(with_valid=False) class Edit(serializers.Serializer): - user_id = serializers.CharField(required=False, label=(_('user id'))) + user_id = serializers.CharField(required=False, label=(_("user id"))) - name = serializers.CharField(required=False, max_length=64, - label=(_("model name"))) + name = serializers.CharField(required=False, max_length=64, label=(_("model name"))) model_type = serializers.CharField(required=False, label=(_("model type"))) model_name = serializers.CharField(required=False, label=(_("base model"))) - credential = serializers.DictField(required=False, - label=(_("certification information"))) + credential = serializers.DictField(required=False, label=(_("certification information"))) workspace_id = serializers.CharField(required=False, label=(_("workspace id"))) def is_valid(self, model=None, raise_exception=False): super().is_valid(raise_exception=True) - filter_params = {'workspace_id': model.workspace_id} - if 'name' in self.data and self.data.get('name') is not None: - filter_params['name'] = self.data.get('name') + filter_params = {"workspace_id": model.workspace_id} + if "name" in self.data and self.data.get("name") is not None: + filter_params["name"] = self.data.get("name") if QuerySet(Model).exclude(id=model.id).filter(**filter_params).exists(): - raise AppApiException(500, _('base model【{model_name}】already exists').format( - model_name=self.data.get("name"))) + raise AppApiException( + 500, _("base model【{model_name}】already exists").format(model_name=self.data.get("name")) + ) ModelSerializer.model_to_dict(model) provider = model.provider - model_type = self.data.get('model_type') - model_name = self.data.get( - 'model_name') - credential = self.data.get('credential') + model_type = self.data.get("model_type") + model_name = self.data.get("model_name") + credential = self.data.get("credential") provider_handler = ModelProvideConstants[provider].value - model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type, - model_name) + model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type, model_name) source_model_credential = json.loads(rsa_long_decrypt(model.credential)) source_encryption_model_credential = model_credential.encryption_dict(source_model_credential) if credential is not None: @@ -276,7 +281,7 @@ def is_valid(self, model=None, raise_exception=False): return credential, model_credential, provider_handler class Create(serializers.Serializer): - user_id = serializers.UUIDField(required=True, label=_('user id')) + user_id = serializers.UUIDField(required=True, label=_("user id")) name = serializers.CharField(required=True, max_length=64, label=_("model name")) provider = serializers.CharField(required=True, label=_("provider")) model_type = serializers.CharField(required=True, label=_("model type")) @@ -287,21 +292,21 @@ class Create(serializers.Serializer): def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) - if QuerySet(Model).filter( - name=self.data.get('name'), - workspace_id=self.data.get('workspace_id', 'None') - ).exists(): + if ( + QuerySet(Model) + .filter(name=self.data.get("name"), workspace_id=self.data.get("workspace_id", "None")) + .exists() + ): raise AppApiException( - 500, - _('base model【{model_name}】already exists').format(model_name=self.data.get("name")) + 500, _("base model【{model_name}】already exists").format(model_name=self.data.get("name")) ) - default_params = {item['field']: item['default_value'] for item in self.data.get('model_params_form')} - ModelProvideConstants[self.data.get('provider')].value.is_valid_credential( - self.data.get('model_type'), - self.data.get('model_name'), - self.data.get('credential'), + default_params = {item["field"]: item["default_value"] for item in self.data.get("model_params_form")} + ModelProvideConstants[self.data.get("provider")].value.is_valid_credential( + self.data.get("model_type"), + self.data.get("model_name"), + self.data.get("credential"), default_params, - raise_exception=True + raise_exception=True, ) def insert(self, workspace_id, with_valid=True): @@ -315,28 +320,30 @@ def insert(self, workspace_id, with_valid=True): else: raise e - credential = self.data.get('credential') + credential = self.data.get("credential") model_data = { - 'id': uuid.uuid7(), - 'status': status, - 'user_id': self.data.get('user_id'), - 'name': self.data.get('name'), - 'credential': rsa_long_encrypt(json.dumps(credential)), - 'provider': self.data.get('provider'), - 'model_type': self.data.get('model_type'), - 'model_name': self.data.get('model_name'), - 'model_params_form': self.data.get('model_params_form'), - 'workspace_id': workspace_id + "id": uuid.uuid7(), + "status": status, + "user_id": self.data.get("user_id"), + "name": self.data.get("name"), + "credential": rsa_long_encrypt(json.dumps(credential)), + "provider": self.data.get("provider"), + "model_type": self.data.get("model_type"), + "model_name": self.data.get("model_name"), + "model_params_form": self.data.get("model_params_form"), + "workspace_id": workspace_id, } model = Model(**model_data) try: model.save() - if workspace_id != 'None': - UserResourcePermissionSerializer(data={ - 'workspace_id': workspace_id, - 'user_id': self.data.get('user_id'), - 'auth_target_type': AuthTargetType.MODEL.value - }).auth_resource(str(model.id)) + if workspace_id != "None": + UserResourcePermissionSerializer( + data={ + "workspace_id": workspace_id, + "user_id": self.data.get("user_id"), + "auth_target_type": AuthTargetType.MODEL.value, + } + ).auth_resource(str(model.id)) except Exception as save_error: # 可添加日志记录 raise AppApiException(500, _("Model saving failed")) from save_error @@ -349,12 +356,12 @@ def insert(self, workspace_id, with_valid=True): class Query(serializers.Serializer): user_id = serializers.CharField(required=True, label=_("User ID")) - name = serializers.CharField(required=False, max_length=64, label=_('model name')) - model_type = serializers.CharField(required=False, label=_('model type')) - model_name = serializers.CharField(required=False, label=_('base model')) - provider = serializers.CharField(required=False, label=_('provider')) - create_user = serializers.CharField(required=False, label=_('create user')) - workspace_id = serializers.CharField(required=False, label=_('workspace id')) + name = serializers.CharField(required=False, max_length=64, label=_("model name")) + model_type = serializers.CharField(required=False, label=_("model type")) + model_name = serializers.CharField(required=False, label=_("base model")) + provider = serializers.CharField(required=False, label=_("provider")) + create_user = serializers.CharField(required=False, label=_("create user")) + workspace_id = serializers.CharField(required=False, label=_("workspace id")) @staticmethod def is_x_pack_ee(): @@ -366,15 +373,23 @@ def list(self, workspace_id, with_valid): if with_valid: self.is_valid(raise_exception=True) user_id = self.data.get("user_id") - workspace_manage = is_workspace_manage_permission_read(user_id, workspace_id, 'MODEL:READ') + workspace_manage = is_workspace_manage_permission_read(user_id, workspace_id, "MODEL:READ") query_params = self._build_query_params(workspace_id, workspace_manage, user_id) is_x_pack_ee = self.is_x_pack_ee() - result = native_search(query_params, - select_string=get_file_content( - os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql', - 'list_model.sql' if workspace_manage else ( - 'list_model_user_ee.sql' if is_x_pack_ee else 'list_model_user.sql') - ))) + result = native_search( + query_params, + select_string=get_file_content( + os.path.join( + PROJECT_DIR, + "apps", + "models_provider", + "sql", + "list_model.sql" + if workspace_manage + else ("list_model_user_ee.sql" if is_x_pack_ee else "list_model_user.sql"), + ) + ), + ) return ResourceMappingSerializer().get_resource_count(result) def share_list(self, workspace_id, with_valid=True): @@ -382,24 +397,20 @@ def share_list(self, workspace_id, with_valid=True): self.is_valid(raise_exception=True) user_id = self.data.get("user_id") query_params = self._build_query_params(workspace_id, False, user_id) - result = [ - self._build_model_data( - model - ) for model in query_params.get('model_query_set') - ] + result = [self._build_model_data(model) for model in query_params.get("model_query_set")] return ResourceMappingSerializer().get_resource_count(result) def model_list(self, workspace_id, with_valid=True): if with_valid: self.is_valid(raise_exception=True) user_id = self.data.get("user_id") - workspace_manage = is_workspace_manage_permission_read(user_id, workspace_id, 'MODEL:READ') + workspace_manage = is_workspace_manage_permission_read(user_id, workspace_id, "MODEL:READ") queryset = self._build_query_params(workspace_id, workspace_manage, user_id) get_authorized_model = DatabaseModelManage.get_model("get_authorized_model") shared_queryset = QuerySet(Model).none() if get_authorized_model is not None: - shared_queryset = self._build_query_params('None', False, user_id)['model_query_set'] + shared_queryset = self._build_query_params("None", False, user_id)["model_query_set"] shared_queryset = get_authorized_model(shared_queryset, workspace_id) # 构建共享模型和普通模型列表 @@ -410,60 +421,64 @@ def model_list(self, workspace_id, with_valid=True): queryset, select_string=get_file_content( os.path.join( - PROJECT_DIR, "apps", "models_provider", 'sql', - 'list_model.sql' if workspace_manage else ( - 'list_model_user_ee.sql' if is_x_pack_ee else 'list_model_user.sql') + PROJECT_DIR, + "apps", + "models_provider", + "sql", + "list_model.sql" + if workspace_manage + else ("list_model_user_ee.sql" if is_x_pack_ee else "list_model_user.sql"), ) - ) + ), ) - return { - "shared_model": shared_model, - "model": normal_model - } + return {"shared_model": shared_model, "model": normal_model} def _build_query_params(self, workspace_id, workspace_manage: bool, user_id): queryset = QuerySet(Model) if workspace_id: queryset = queryset.filter(workspace_id=workspace_id) - for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']: + for field in ["name", "model_type", "model_name", "provider", "create_user"]: value = self.data.get(field) if value is not None: - if field == 'name': - queryset = queryset.filter(**{f'{field}__icontains': value}) - elif field == 'create_user': + if field == "name": + queryset = queryset.filter(**{f"{field}__icontains": value}) + elif field == "create_user": queryset = queryset.filter(user_id=value) else: queryset = queryset.filter(**{field: value}) queryset = queryset.order_by("-create_time") - return { - 'model_query_set': queryset, - 'workspace_user_resource_permission_query_set': QuerySet(WorkspaceUserResourcePermission).filter( - auth_target_type="MODEL", - workspace_id=workspace_id, - user_id=user_id)} if ( - not workspace_manage) else { - 'model_query_set': queryset, - } + return ( + { + "model_query_set": queryset, + "workspace_user_resource_permission_query_set": QuerySet(WorkspaceUserResourcePermission).filter( + auth_target_type="MODEL", workspace_id=workspace_id, user_id=user_id + ), + } + if (not workspace_manage) + else { + "model_query_set": queryset, + } + ) def _build_model_data(self, model): return { - 'id': str(model.id), - 'provider': model.provider, - 'name': model.name, - 'model_type': model.model_type, - 'model_name': model.model_name, - 'status': model.status, - 'meta': model.meta, - 'user_id': model.user_id, - 'username': model.user.username, - 'nick_name': model.user.nick_name, + "id": str(model.id), + "provider": model.provider, + "name": model.name, + "model_type": model.model_type, + "model_name": model.model_name, + "status": model.status, + "meta": model.meta, + "user_id": model.user_id, + "username": model.user.username, + "nick_name": model.user.nick_name, } def page(self, current_page, page_size): pass class ModelParams(serializers.Serializer): - id = serializers.UUIDField(required=True, label=_('model id')) + id = serializers.UUIDField(required=True, label=_("model id")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -474,7 +489,7 @@ def is_valid(self, *, raise_exception=False): def get_model_params(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - model_id = self.data.get('id') + model_id = self.data.get("id") model = QuerySet(Model).filter(id=model_id).first() return model.model_params_form @@ -483,22 +498,26 @@ def save_model_params_form(self, model_params_form, with_valid=True): self.is_valid(raise_exception=True) if model_params_form is None: model_params_form = [] - model_id = self.data.get('id') + model_id = self.data.get("id") model = QuerySet(Model).filter(id=model_id).first() if not isinstance(model_params_form, list): - raise AppApiException(500, _('model_params_form must be a list')) + raise AppApiException(500, _("model_params_form must be a list")) # 还需要校验几个字段:label required default_value # 校验每个配置项的必要字段 for index, param in enumerate(model_params_form): if not isinstance(param, dict): - raise AppApiException(500, _('The {index}th item in model_params_form must be a dictionary').format( - index=index)) + raise AppApiException( + 500, _("The {index}th item in model_params_form must be a dictionary").format(index=index) + ) # 校验 label 字段 - if 'label' not in param or param['label'] is None: - raise AppApiException(500, - _('The label field is required for the {index}th item in model_params_form').format( - index=index)) + if "label" not in param or param["label"] is None: + raise AppApiException( + 500, + _("The label field is required for the {index}th item in model_params_form").format( + index=index + ), + ) model.model_params_form = model_params_form model.save() @@ -506,31 +525,31 @@ def save_model_params_form(self, model_params_form, with_valid=True): class WorkspaceSharedModelSerializer(serializers.Serializer): - workspace_id = serializers.CharField(required=True, label=_('workspace id')) - name = serializers.CharField(required=False, max_length=64, label=_('model name')) - model_type = serializers.CharField(required=False, label=_('model type')) - model_name = serializers.CharField(required=False, label=_('base model')) - provider = serializers.CharField(required=False, label=_('provider')) - create_user = serializers.CharField(required=False, label=_('create user')) + workspace_id = serializers.CharField(required=True, label=_("workspace id")) + name = serializers.CharField(required=False, max_length=64, label=_("model name")) + model_type = serializers.CharField(required=False, label=_("model type")) + model_name = serializers.CharField(required=False, label=_("base model")) + provider = serializers.CharField(required=False, label=_("provider")) + create_user = serializers.CharField(required=False, label=_("create user")) def get_share_model_list(self): self.is_valid(raise_exception=True) - workspace_id = self.data.get('workspace_id') + workspace_id = self.data.get("workspace_id") queryset = self._build_queryset(workspace_id) return [ { - 'id': str(model.id), - 'provider': model.provider, - 'name': model.name, - 'model_type': model.model_type, - 'model_name': model.model_name, - 'status': model.status, - 'meta': model.meta, - 'user_id': model.user_id, - 'nick_name': model.user.nick_name, - 'username': model.user.username + "id": str(model.id), + "provider": model.provider, + "name": model.name, + "model_type": model.model_type, + "model_name": model.model_name, + "status": model.status, + "meta": model.meta, + "user_id": model.user_id, + "nick_name": model.user.nick_name, + "username": model.user.username, } for model in queryset.order_by("-create_time") ] @@ -542,12 +561,12 @@ def _build_queryset(self, workspace_id): if get_authorized_model is not None: queryset = get_authorized_model(queryset, workspace_id) - for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']: + for field in ["name", "model_type", "model_name", "provider", "create_user"]: value = self.data.get(field) if value is not None: - if field == 'name': - queryset = queryset.filter(**{f'{field}__icontains': value}) - elif field == 'create_user': + if field == "name": + queryset = queryset.filter(**{f"{field}__icontains": value}) + elif field == "create_user": queryset = queryset.filter(user_id=value) else: queryset = queryset.filter(**{field: value}) diff --git a/apps/models_provider/tools.py b/apps/models_provider/tools.py index 9bdb8b8d261..4ec31303a03 100644 --- a/apps/models_provider/tools.py +++ b/apps/models_provider/tools.py @@ -1,25 +1,24 @@ # coding=utf-8 """ - @project: MaxKB - @Author:虎 - @file: tools.py - @date:2024/7/22 11:18 - @desc: +@project: MaxKB +@Author:虎 +@file: tools.py +@date:2024/7/22 11:18 +@desc: """ -from django.db import connection -from django.db.models import QuerySet - -from common.config.embedding_config import ModelManage -from common.database_model_manage.database_model_manage import DatabaseModelManage -from models_provider.base_model_provider import ModelTypeConst -from models_provider.models import Model -from django.utils.translation import gettext_lazy as _ import json from typing import Dict +from common.config.embedding_config import ModelManage +from common.database_model_manage.database_model_manage import DatabaseModelManage from common.utils.rsa_util import rsa_long_decrypt +from django.db import connection +from django.db.models import QuerySet +from django.utils.translation import gettext_lazy as _ +from models_provider.base_model_provider import ModelTypeConst from models_provider.constants.model_provider_constants import ModelProvideConstants +from models_provider.models import Model def get_model_(provider, model_type, model_name, credential, model_id, use_local=False, **kwargs): @@ -33,12 +32,15 @@ def get_model_(provider, model_type, model_name, credential, model_id, use_local @param use_local: 是否调用本地模型 只适用于本地供应商 @return: 模型实例 """ - model = get_provider(provider).get_model(model_type, model_name, - json.loads( - rsa_long_decrypt(credential)), - model_id=model_id, - use_local=use_local, - streaming=True, **kwargs) + model = get_provider(provider).get_model( + model_type, + model_name, + json.loads(rsa_long_decrypt(credential)), + model_id=model_id, + use_local=use_local, + streaming=True, + **kwargs, + ) return model @@ -90,8 +92,9 @@ def get_model_type_list(provider): return get_provider(provider).get_model_type_list() -def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], model_params, - raise_exception=False): +def is_valid_credential( + provider, model_type, model_name, model_credential: Dict[str, object], model_params, raise_exception=False +): """ 校验模型认证参数 @param provider: 供应商字符串 @@ -101,8 +104,9 @@ def is_valid_credential(provider, model_type, model_name, model_credential: Dict @param raise_exception: 是否抛出错误 @return: True|False """ - return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, model_params, - raise_exception) + return get_provider(provider).is_valid_credential( + model_type, model_name, model_credential, model_params, raise_exception + ) def get_model_by_id(_id, workspace_id): @@ -126,10 +130,7 @@ def convert_to_int(value): return value return value - return { - p.get('field'): convert_to_int(p.get('default_value')) - for p in model.model_params_form - } + return {p.get("field"): convert_to_int(p.get("default_value")) for p in model.model_params_form} def reset_model_params(default_model_params, **kwargs): @@ -151,6 +152,6 @@ def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs): model = get_model_by_id(model_id, workspace_id) default_model_params = get_model_default_params(model) if model.model_type == ModelTypeConst.RERANKER.name: - default_model_params.setdefault('top_n', 3) + default_model_params.setdefault("top_n", 3) model_params = reset_model_params(default_model_params, **kwargs) return ModelManage.get_model(model_id, lambda _id: get_model(model, **model_params)) diff --git a/apps/models_provider/views/model.py b/apps/models_provider/views/model.py index fe498b20ac1..a2c60796f6c 100644 --- a/apps/models_provider/views/model.py +++ b/apps/models_provider/views/model.py @@ -1,28 +1,27 @@ # coding=utf-8 """ - @project: MaxKB - @Author:虎虎 - @file: user.py - @date:2025/4/14 19:25 - @desc: +@project: MaxKB +@Author:虎虎 +@file: user.py +@date:2025/4/14 19:25 +@desc: """ -from django.db.models import QuerySet -from drf_spectacular.utils import extend_schema -from rest_framework.views import APIView -from django.utils.translation import gettext_lazy as _ -from rest_framework.request import Request from common.auth import TokenAuth from common.auth.authentication import has_permissions -from common.constants.permission_constants import PermissionConstants, RoleConstants, ViewPermission, CompareConstants +from common.constants.permission_constants import CompareConstants, PermissionConstants, RoleConstants, ViewPermission from common.log.log import log from common.result import result from common.utils.common import query_params_to_single_dict -from models_provider.api.model import ModelCreateAPI, GetModelApi, ModelEditApi, ModelListResponse, DefaultModelResponse +from django.db.models import QuerySet +from django.utils.translation import gettext_lazy as _ +from drf_spectacular.utils import extend_schema +from models_provider.api.model import DefaultModelResponse, GetModelApi, ModelCreateAPI, ModelEditApi, ModelListResponse from models_provider.api.provide import ProvideApi from models_provider.models import Model -from models_provider.serializers.model_serializer import ModelSerializer, \ - WorkspaceSharedModelSerializer +from models_provider.serializers.model_serializer import ModelSerializer, WorkspaceSharedModelSerializer +from rest_framework.request import Request +from rest_framework.views import APIView from system_manage.views import encryption_str @@ -36,47 +35,49 @@ def get_edit_model_details(request): path = request.path body = request.data query = request.query_params - credential = body.get('credential', {}) + credential = body.get("credential", {}) credential_encryption_ed = encryption_credential(credential) - return { - 'path': path, - 'body': {**body, 'credential': credential_encryption_ed}, - 'query': query - } + return {"path": path, "body": {**body, "credential": credential_encryption_ed}, "query": query} def get_model_operation_object(model_id): model_model = QuerySet(model=Model).filter(id=model_id).first() if model_model is not None: - return { - "name": model_model.name - } + return {"name": model_model.name} return {} class ModelSetting(APIView): authentication_classes = [TokenAuth] - @extend_schema(methods=['POST'], - summary=_("Create model"), - description=_("Create model"), - operation_id=_("Create model"), # type: ignore - tags=[_("Model")], # type: ignore - parameters=ModelCreateAPI.get_parameters(), - request=ModelCreateAPI.get_request(), - responses=ModelCreateAPI.get_response()) - @has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission(), - PermissionConstants.MODEL_EDIT.get_workspace_permission_workspace_manage_role(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) - @log(menu='model', operate='Create model', - get_operation_object=lambda r, k: {'name': r.date.get('name')}, - get_details=get_edit_model_details, - ) + @extend_schema( + methods=["POST"], + summary=_("Create model"), + description=_("Create model"), + operation_id=_("Create model"), # type: ignore + tags=[_("Model")], # type: ignore + parameters=ModelCreateAPI.get_parameters(), + request=ModelCreateAPI.get_request(), + responses=ModelCreateAPI.get_response(), + ) + @has_permissions( + PermissionConstants.MODEL_CREATE.get_workspace_permission(), + PermissionConstants.MODEL_EDIT.get_workspace_permission_workspace_manage_role(), + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), + RoleConstants.USER.get_workspace_role(), + ) + @log( + menu="model", + operate="Create model", + get_operation_object=lambda r, k: {"name": r.date.get("name")}, + get_details=get_edit_model_details, + ) def post(self, request: Request, workspace_id: str): return result.success( ModelSerializer.Create( - data={**request.data, 'user_id': request.user.id, 'workspace_id': workspace_id}).insert(workspace_id, - with_valid=True)) + data={**request.data, "user_id": request.user.id, "workspace_id": workspace_id} + ).insert(workspace_id, with_valid=True) + ) # @extend_schema(methods=['PUT'], # summary=_('Update model'), @@ -90,193 +91,247 @@ def post(self, request: Request, workspace_id: str): # ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id, # with_valid=True)) - @extend_schema(methods=['GET'], - summary=_('Query model list'), - description=_('Query model list'), - operation_id=_('Query model list'), # type: ignore - parameters=ModelListResponse.get_parameters(), - responses=ModelListResponse.get_response(), - tags=[_('Model')]) # type: ignore - @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(), - PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) + @extend_schema( + methods=["GET"], + summary=_("Query model list"), + description=_("Query model list"), + operation_id=_("Query model list"), # type: ignore + parameters=ModelListResponse.get_parameters(), + responses=ModelListResponse.get_response(), + tags=[_("Model")], # type: ignore + ) + @has_permissions( + PermissionConstants.MODEL_READ.get_workspace_permission(), + PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(), + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), + RoleConstants.USER.get_workspace_role(), + ) def get(self, request: Request, workspace_id: str): return result.success( ModelSerializer.Query( - data={**query_params_to_single_dict(request.query_params), 'user_id': str(request.user.id)}).list( - workspace_id=workspace_id, - with_valid=True)) + data={**query_params_to_single_dict(request.query_params), "user_id": str(request.user.id)} + ).list(workspace_id=workspace_id, with_valid=True) + ) class Operate(APIView): authentication_classes = [TokenAuth] - @extend_schema(methods=['PUT'], - summary=_('Update model'), - description=_('Update model'), - operation_id=_('Update model'), # type: ignore - request=ModelEditApi.get_request(), - parameters=GetModelApi.get_parameters(), - responses=ModelEditApi.get_response(), - tags=[_('Model')]) # type: ignore - @has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_model_permission(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), - PermissionConstants.MODEL_EDIT.get_workspace_permission_workspace_manage_role(), - ViewPermission([RoleConstants.USER.get_workspace_role()], - [PermissionConstants.MODEL.get_workspace_model_permission()], - CompareConstants.AND), ) - @log(menu='model', operate='Update model', - get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')), - get_details=get_edit_model_details, - ) + @extend_schema( + methods=["PUT"], + summary=_("Update model"), + description=_("Update model"), + operation_id=_("Update model"), # type: ignore + request=ModelEditApi.get_request(), + parameters=GetModelApi.get_parameters(), + responses=ModelEditApi.get_response(), + tags=[_("Model")], # type: ignore + ) + @has_permissions( + PermissionConstants.MODEL_EDIT.get_workspace_model_permission(), + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), + PermissionConstants.MODEL_EDIT.get_workspace_permission_workspace_manage_role(), + ViewPermission( + [RoleConstants.USER.get_workspace_role()], + [PermissionConstants.MODEL.get_workspace_model_permission()], + CompareConstants.AND, + ), + ) + @log( + menu="model", + operate="Update model", + get_operation_object=lambda r, k: get_model_operation_object(k.get("model_id")), + get_details=get_edit_model_details, + ) def put(self, request: Request, workspace_id, model_id: str): return result.success( ModelSerializer.Operate( - data={'id': model_id, 'user_id': request.user.id, 'workspace_id': workspace_id}).edit(request.data, - str(request.user.id))) + data={"id": model_id, "user_id": request.user.id, "workspace_id": workspace_id} + ).edit(request.data, str(request.user.id)) + ) - @extend_schema(methods=['DELETE'], - summary=_('Delete model'), - description=_('Delete model'), - operation_id=_('Delete model'), # type: ignore - parameters=GetModelApi.get_parameters(), - responses=DefaultModelResponse.get_response(), - tags=[_('Model')]) # type: ignore - @has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_model_permission(), - PermissionConstants.MODEL_DELETE.get_workspace_permission_workspace_manage_role(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), - ViewPermission([RoleConstants.USER.get_workspace_role()], - [PermissionConstants.MODEL.get_workspace_model_permission()], - CompareConstants.AND), ) - @log(menu='model', operate='Delete model', - get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')), - ) + @extend_schema( + methods=["DELETE"], + summary=_("Delete model"), + description=_("Delete model"), + operation_id=_("Delete model"), # type: ignore + parameters=GetModelApi.get_parameters(), + responses=DefaultModelResponse.get_response(), + tags=[_("Model")], # type: ignore + ) + @has_permissions( + PermissionConstants.MODEL_DELETE.get_workspace_model_permission(), + PermissionConstants.MODEL_DELETE.get_workspace_permission_workspace_manage_role(), + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), + ViewPermission( + [RoleConstants.USER.get_workspace_role()], + [PermissionConstants.MODEL.get_workspace_model_permission()], + CompareConstants.AND, + ), + ) + @log( + menu="model", + operate="Delete model", + get_operation_object=lambda r, k: get_model_operation_object(k.get("model_id")), + ) def delete(self, request: Request, workspace_id: str, model_id: str): return result.success( ModelSerializer.Operate( - data={'id': model_id, 'user_id': request.user.id, 'workspace_id': workspace_id}).delete()) + data={"id": model_id, "user_id": request.user.id, "workspace_id": workspace_id} + ).delete() + ) - @extend_schema(methods=['GET'], - summary=_('Query model details'), - description=_('Query model details'), - operation_id=_('Query model details'), # type: ignore - parameters=GetModelApi.get_parameters(), - responses=GetModelApi.get_response(), - tags=[_('Model')]) # type: ignore - @has_permissions(PermissionConstants.MODEL_READ.get_workspace_model_permission(), - PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), - ViewPermission([RoleConstants.USER.get_workspace_role()], - [PermissionConstants.MODEL.get_workspace_model_permission()], - CompareConstants.AND), ) + @extend_schema( + methods=["GET"], + summary=_("Query model details"), + description=_("Query model details"), + operation_id=_("Query model details"), # type: ignore + parameters=GetModelApi.get_parameters(), + responses=GetModelApi.get_response(), + tags=[_("Model")], # type: ignore + ) + @has_permissions( + PermissionConstants.MODEL_READ.get_workspace_model_permission(), + PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(), + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), + ViewPermission( + [RoleConstants.USER.get_workspace_role()], + [PermissionConstants.MODEL.get_workspace_model_permission()], + CompareConstants.AND, + ), + ) def get(self, request: Request, workspace_id: str, model_id: str): return result.success( ModelSerializer.Operate( - data={'id': model_id, 'user_id': request.user.id, 'workspace_id': workspace_id}).one( - with_valid=True)) + data={"id": model_id, "user_id": request.user.id, "workspace_id": workspace_id} + ).one(with_valid=True) + ) class ModelParamsForm(APIView): authentication_classes = [TokenAuth] - @extend_schema(methods=['GET'], - summary=_('Get model parameter form'), - description=_('Get model parameter form'), - operation_id=_('Get model parameter form'), # type: ignore - parameters=GetModelApi.get_parameters(), - responses=ProvideApi.ModelParamsForm.get_response(), - tags=[_('Model')]) # type: ignore - @has_permissions(PermissionConstants.MODEL_READ.get_workspace_model_permission(), - PermissionConstants.KNOWLEDGE_READ.get_workspace_permission(), - PermissionConstants.APPLICATION_READ.get_workspace_permission(), - PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(), - PermissionConstants.KNOWLEDGE_READ.get_workspace_permission_workspace_manage_role(), - PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(), - PermissionConstants.MODEL_READ.get_workspace_permission(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), - RoleConstants.USER.get_workspace_role(),) + @extend_schema( + methods=["GET"], + summary=_("Get model parameter form"), + description=_("Get model parameter form"), + operation_id=_("Get model parameter form"), # type: ignore + parameters=GetModelApi.get_parameters(), + responses=ProvideApi.ModelParamsForm.get_response(), + tags=[_("Model")], # type: ignore + ) + @has_permissions( + PermissionConstants.MODEL_READ.get_workspace_model_permission(), + PermissionConstants.KNOWLEDGE_READ.get_workspace_permission(), + PermissionConstants.APPLICATION_READ.get_workspace_permission(), + PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(), + PermissionConstants.KNOWLEDGE_READ.get_workspace_permission_workspace_manage_role(), + PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(), + PermissionConstants.MODEL_READ.get_workspace_permission(), + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), + RoleConstants.USER.get_workspace_role(), + ) def get(self, request: Request, workspace_id: str, model_id: str): - return result.success( - ModelSerializer.ModelParams(data={'id': model_id}).get_model_params()) + return result.success(ModelSerializer.ModelParams(data={"id": model_id}).get_model_params()) - @extend_schema(methods=['PUT'], - summary=_('Save model parameter form'), - description=_('Save model parameter form'), - operation_id=_('Save model parameter form'), # type: ignore - parameters=GetModelApi.get_parameters(), - request=GetModelApi.get_request(), - responses=ProvideApi.ModelParamsForm.get_response(), - tags=[_('Model')]) # type: ignore - @has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_model_permission(), - PermissionConstants.MODEL_EDIT.get_workspace_permission_workspace_manage_role(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), - PermissionConstants.MODEL_READ.get_workspace_permission(), - ViewPermission([RoleConstants.USER.get_workspace_role()], - [PermissionConstants.MODEL.get_workspace_model_permission()], - CompareConstants.AND), ) - @log(menu='model', operate='Save model parameter form', - get_operation_object=lambda r, k: get_model_operation_object(k.get('model_id')), - ) + @extend_schema( + methods=["PUT"], + summary=_("Save model parameter form"), + description=_("Save model parameter form"), + operation_id=_("Save model parameter form"), # type: ignore + parameters=GetModelApi.get_parameters(), + request=GetModelApi.get_request(), + responses=ProvideApi.ModelParamsForm.get_response(), + tags=[_("Model")], # type: ignore + ) + @has_permissions( + PermissionConstants.MODEL_EDIT.get_workspace_model_permission(), + PermissionConstants.MODEL_EDIT.get_workspace_permission_workspace_manage_role(), + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), + PermissionConstants.MODEL_READ.get_workspace_permission(), + ViewPermission( + [RoleConstants.USER.get_workspace_role()], + [PermissionConstants.MODEL.get_workspace_model_permission()], + CompareConstants.AND, + ), + ) + @log( + menu="model", + operate="Save model parameter form", + get_operation_object=lambda r, k: get_model_operation_object(k.get("model_id")), + ) def put(self, request: Request, workspace_id: str, model_id: str): return result.success( - ModelSerializer.ModelParams(data={'id': model_id}).save_model_params_form(request.data)) + ModelSerializer.ModelParams(data={"id": model_id}).save_model_params_form(request.data) + ) class ModelMeta(APIView): authentication_classes = [TokenAuth] - @extend_schema(methods=['GET'], - summary=_( - 'Query model meta information, this interface does not carry authentication information'), - description=_( - 'Query model meta information, this interface does not carry authentication information'), - operation_id=_( - 'Query model meta information, this interface does not carry authentication information'), - parameters=GetModelApi.get_parameters(), - responses=GetModelApi.get_response(), - tags=[_('Model')]) # type: ignore - @has_permissions(PermissionConstants.MODEL_READ.get_workspace_model_permission(), - PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), - PermissionConstants.MODEL_READ.get_workspace_permission(), - ViewPermission([RoleConstants.USER.get_workspace_role()], - [PermissionConstants.MODEL.get_workspace_model_permission()], - CompareConstants.AND), ) + @extend_schema( + methods=["GET"], + summary=_("Query model meta information, this interface does not carry authentication information"), + description=_("Query model meta information, this interface does not carry authentication information"), + operation_id=_("Query model meta information, this interface does not carry authentication information"), # type: ignore + parameters=GetModelApi.get_parameters(), + responses=GetModelApi.get_response(), + tags=[_("Model")], # type: ignore + ) + @has_permissions( + PermissionConstants.MODEL_READ.get_workspace_model_permission(), + PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(), + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), + PermissionConstants.MODEL_READ.get_workspace_permission(), + ViewPermission( + [RoleConstants.USER.get_workspace_role()], + [PermissionConstants.MODEL.get_workspace_model_permission()], + CompareConstants.AND, + ), + ) def get(self, request: Request, workspace_id: str, model_id: str): return result.success( - ModelSerializer.Operate(data={'id': model_id, 'workspace_id': workspace_id}).one_meta(with_valid=True)) + ModelSerializer.Operate(data={"id": model_id, "workspace_id": workspace_id}).one_meta(with_valid=True) + ) class PauseDownload(APIView): authentication_classes = [TokenAuth] - @extend_schema(methods=['PUT'], - summary=_('Pause model download'), - description=_('Pause model download'), - operation_id=_('Pause model download'), # type: ignore - parameters=GetModelApi.get_parameters(), - request=GetModelApi.get_request(), - responses=DefaultModelResponse.get_response(), - tags=[_('Model')]) # type: ignore - @has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_model_permission(), - PermissionConstants.MODEL_CREATE.get_workspace_permission_workspace_manage_role(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), - ViewPermission([RoleConstants.USER.get_workspace_role()], - [PermissionConstants.MODEL.get_workspace_model_permission()], - CompareConstants.AND), ) + @extend_schema( + methods=["PUT"], + summary=_("Pause model download"), + description=_("Pause model download"), + operation_id=_("Pause model download"), # type: ignore + parameters=GetModelApi.get_parameters(), + request=GetModelApi.get_request(), + responses=DefaultModelResponse.get_response(), + tags=[_("Model")], # type: ignore + ) + @has_permissions( + PermissionConstants.MODEL_CREATE.get_workspace_model_permission(), + PermissionConstants.MODEL_CREATE.get_workspace_permission_workspace_manage_role(), + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), + ViewPermission( + [RoleConstants.USER.get_workspace_role()], + [PermissionConstants.MODEL.get_workspace_model_permission()], + CompareConstants.AND, + ), + ) def put(self, request: Request, workspace_id: str, model_id: str): return result.success( - ModelSerializer.Operate(data={'id': model_id, 'workspace_id': workspace_id}).pause_download()) + ModelSerializer.Operate(data={"id": model_id, "workspace_id": workspace_id}).pause_download() + ) class WorkspaceSharedModelSetting(APIView): authentication_classes = [TokenAuth] @extend_schema( - methods=['Get'], - summary=_('Get Share model by workspace id'), - description=_('Get Share model by workspace id'), - operation_id=_('Get Share model by workspace id'), # type: ignore + methods=["Get"], + summary=_("Get Share model by workspace id"), + description=_("Get Share model by workspace id"), + operation_id=_("Get Share model by workspace id"), # type: ignore parameters=ModelListResponse.get_parameters(), responses=DefaultModelResponse.get_response(), - tags=[_('Shared Model')] - ) # type: ignore + tags=[_("Shared Model")], # type: ignore + ) @has_permissions( PermissionConstants.MODEL_READ.get_workspace_permission(), PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(), @@ -285,30 +340,37 @@ class WorkspaceSharedModelSetting(APIView): ) def get(self, request: Request, workspace_id: str): return result.success( - WorkspaceSharedModelSerializer(data={**query_params_to_single_dict(request.query_params), - 'workspace_id': workspace_id}).get_share_model_list()) + WorkspaceSharedModelSerializer( + data={**query_params_to_single_dict(request.query_params), "workspace_id": workspace_id} + ).get_share_model_list() + ) class ModelList(APIView): authentication_classes = [TokenAuth] - @extend_schema(methods=['GET'], - summary=_('Query all model list'), - description=_('Query all model list'), - operation_id=_('Query all model list'), # type: ignore - parameters=ModelListResponse.get_parameters(), - responses=ModelListResponse.get_response(), - tags=[_('Model')]) # type: ignore - @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission(), - PermissionConstants.KNOWLEDGE_READ.get_workspace_permission(), - PermissionConstants.APPLICATION_READ.get_workspace_permission(), - PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(), - PermissionConstants.KNOWLEDGE_READ.get_workspace_permission_workspace_manage_role(), - PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role()) + @extend_schema( + methods=["GET"], + summary=_("Query all model list"), + description=_("Query all model list"), + operation_id=_("Query all model list"), # type: ignore + parameters=ModelListResponse.get_parameters(), + responses=ModelListResponse.get_response(), + tags=[_("Model")], # type: ignore + ) + @has_permissions( + PermissionConstants.MODEL_READ.get_workspace_permission(), + PermissionConstants.KNOWLEDGE_READ.get_workspace_permission(), + PermissionConstants.APPLICATION_READ.get_workspace_permission(), + PermissionConstants.MODEL_READ.get_workspace_permission_workspace_manage_role(), + PermissionConstants.KNOWLEDGE_READ.get_workspace_permission_workspace_manage_role(), + PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(), + RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), + RoleConstants.USER.get_workspace_role(), + ) def get(self, request: Request, workspace_id: str): return result.success( ModelSerializer.Query( - data={**query_params_to_single_dict(request.query_params), 'user_id': str(request.user.id)}).model_list( - workspace_id=workspace_id, - with_valid=True)) + data={**query_params_to_single_dict(request.query_params), "user_id": str(request.user.id)} + ).model_list(workspace_id=workspace_id, with_valid=True) + ) diff --git a/apps/models_provider/views/model_apply.py b/apps/models_provider/views/model_apply.py index d7e691c336f..d9231457541 100644 --- a/apps/models_provider/views/model_apply.py +++ b/apps/models_provider/views/model_apply.py @@ -1,57 +1,54 @@ # coding=utf-8 """ - @project: MaxKB - @Author:虎 - @file: model_apply.py - @date:2024/8/20 20:38 - @desc: +@project: MaxKB +@Author:虎 +@file: model_apply.py +@date:2024/8/20 20:38 +@desc: """ -from urllib.request import Request +from urllib.request import Request +from common.result import result from django.utils.translation import gettext_lazy as _ from drf_spectacular.utils import extend_schema -from rest_framework.views import APIView - -from common.auth.authentication import has_permissions -from common.constants.permission_constants import PermissionConstants -from common.result import result from models_provider.api.model import DefaultModelResponse from models_provider.serializers.model_apply_serializers import ModelApplySerializers +from rest_framework.views import APIView class ModelApply(APIView): class EmbedDocuments(APIView): - @extend_schema(methods=['POST'], - summary=_('Vectorization documentation'), - description=_('Vectorization documentation'), - operation_id=_('Vectorization documentation'), # type: ignore - responses=DefaultModelResponse.get_response(), - tags=[_('Model')] # type: ignore - ) + @extend_schema( + methods=["POST"], + summary=_("Vectorization documentation"), + description=_("Vectorization documentation"), + operation_id=_("Vectorization documentation"), # type: ignore + responses=DefaultModelResponse.get_response(), + tags=[_("Model")], # type: ignore + ) def post(self, request: Request, model_id): - return result.success( - ModelApplySerializers(data={'model_id': model_id}).embed_documents(request.data)) + return result.success(ModelApplySerializers(data={"model_id": model_id}).embed_documents(request.data)) class EmbedQuery(APIView): - @extend_schema(methods=['POST'], - summary=_('Vectorization documentation'), - description=_('Vectorization documentation'), - operation_id=_('Vectorization documentation'), # type: ignore - responses=DefaultModelResponse.get_response(), - tags=[_('Model')] # type: ignore - ) + @extend_schema( + methods=["POST"], + summary=_("Vectorization documentation"), + description=_("Vectorization documentation"), + operation_id=_("Vectorization documentation"), # type: ignore + responses=DefaultModelResponse.get_response(), + tags=[_("Model")], # type: ignore + ) def post(self, request: Request, model_id): - return result.success( - ModelApplySerializers(data={'model_id': model_id}).embed_query(request.data)) + return result.success(ModelApplySerializers(data={"model_id": model_id}).embed_query(request.data)) class CompressDocuments(APIView): - @extend_schema(methods=['POST'], - summary=_('Reorder documents'), - description=_('Reorder documents'), - operation_id=_('Reorder documents'), # type: ignore - responses=DefaultModelResponse.get_response(), - tags=[_('Model')] # type: ignore - ) + @extend_schema( + methods=["POST"], + summary=_("Reorder documents"), + description=_("Reorder documents"), + operation_id=_("Reorder documents"), # type: ignore + responses=DefaultModelResponse.get_response(), + tags=[_("Model")], # type: ignore + ) def post(self, request: Request, model_id): - return result.success( - ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data)) + return result.success(ModelApplySerializers(data={"model_id": model_id}).compress_documents(request.data)) diff --git a/apps/models_provider/views/provide.py b/apps/models_provider/views/provide.py index 70b916ca19d..882f3fca5c0 100644 --- a/apps/models_provider/views/provide.py +++ b/apps/models_provider/views/provide.py @@ -1,103 +1,122 @@ # coding=utf-8 -from django.utils.translation import gettext_lazy as _ -from drf_spectacular.utils import extend_schema -from rest_framework.request import Request -from rest_framework.views import APIView - from common import result from common.auth import TokenAuth from common.auth.authentication import has_permissions from common.constants.permission_constants import PermissionConstants +from django.utils.translation import gettext_lazy as _ +from drf_spectacular.utils import extend_schema from models_provider.api.provide import ProvideApi from models_provider.constants.model_provider_constants import ModelProvideConstants from models_provider.serializers.model_serializer import get_default_model_params_setting +from rest_framework.request import Request +from rest_framework.views import APIView class Provide(APIView): authentication_classes = [TokenAuth] - @extend_schema(methods=['GET'], - summary=_('Get a list of model suppliers'), - description=_('Get a list of model suppliers'), - operation_id=_('Get a list of model suppliers'), # type: ignore - responses=ProvideApi.get_response(), - tags=[_('Model')]) # type: ignore + @extend_schema( + methods=["GET"], + summary=_("Get a list of model suppliers"), + description=_("Get a list of model suppliers"), + operation_id=_("Get a list of model suppliers"), # type: ignore + responses=ProvideApi.get_response(), + tags=[_("Model")], # type: ignore + ) def get(self, request: Request): - model_type = request.query_params.get('model_type') + model_type = request.query_params.get("model_type") if model_type: providers = [] for key in ModelProvideConstants.__members__: - if len([item for item in ModelProvideConstants[key].value.get_model_type_list() if - item['value'] == model_type]) > 0: + if ( + len( + [ + item + for item in ModelProvideConstants[key].value.get_model_type_list() + if item["value"] == model_type + ] + ) + > 0 + ): providers.append(ModelProvideConstants[key].value.get_model_provide_info().to_dict()) return result.success(providers) return result.success( - [ModelProvideConstants[key].value.get_model_provide_info().to_dict() for key in - ModelProvideConstants.__members__]) + [ + ModelProvideConstants[key].value.get_model_provide_info().to_dict() + for key in ModelProvideConstants.__members__ + ] + ) class ModelTypeList(APIView): authentication_classes = [TokenAuth] - @extend_schema(methods=['GET'], - summary=_('Get a list of model types'), - description=_('Get a list of model types'), - operation_id=_('Get a list of model types'), # type: ignore - parameters=ProvideApi.ModelTypeList.get_query_params_api(), - responses=ProvideApi.ModelTypeList.get_response(), - tags=[_('Model')]) # type: ignore + @extend_schema( + methods=["GET"], + summary=_("Get a list of model types"), + description=_("Get a list of model types"), + operation_id=_("Get a list of model types"), # type: ignore + parameters=ProvideApi.ModelTypeList.get_query_params_api(), + responses=ProvideApi.ModelTypeList.get_response(), + tags=[_("Model")], # type: ignore + ) def get(self, request: Request): - provider = request.query_params.get('provider') + provider = request.query_params.get("provider") return result.success(ModelProvideConstants[provider].value.get_model_type_list()) class ModelList(APIView): authentication_classes = [TokenAuth] - @extend_schema(methods=['GET'], - summary=_('Example of obtaining model list'), - description=_('Example of obtaining model list'), - operation_id=_('Example of obtaining model list'), # type: ignore - parameters=ProvideApi.ModelList.get_query_params_api(), - responses=ProvideApi.ModelList.get_response(), - tags=[_('Model')]) # type: ignore + @extend_schema( + methods=["GET"], + summary=_("Example of obtaining model list"), + description=_("Example of obtaining model list"), + operation_id=_("Example of obtaining model list"), # type: ignore + parameters=ProvideApi.ModelList.get_query_params_api(), + responses=ProvideApi.ModelList.get_response(), + tags=[_("Model")], # type: ignore + ) def get(self, request: Request): - provider = request.query_params.get('provider') - model_type = request.query_params.get('model_type') + provider = request.query_params.get("provider") + model_type = request.query_params.get("model_type") - return result.success( - ModelProvideConstants[provider].value.get_model_list( - model_type)) + return result.success(ModelProvideConstants[provider].value.get_model_list(model_type)) class ModelParamsForm(APIView): authentication_classes = [TokenAuth] - @extend_schema(methods=['GET'], - summary=_('Get model default parameters'), - description=_('Get model default parameters'), - operation_id=_('Get model default parameters'), # type: ignore - parameters=ProvideApi.ModelParamsForm.get_query_params_api(), - responses=ProvideApi.ModelParamsForm.get_response(), - tags=[_('Model')]) # type: ignore + @extend_schema( + methods=["GET"], + summary=_("Get model default parameters"), + description=_("Get model default parameters"), + operation_id=_("Get model default parameters"), # type: ignore + parameters=ProvideApi.ModelParamsForm.get_query_params_api(), + responses=ProvideApi.ModelParamsForm.get_response(), + tags=[_("Model")], # type: ignore + ) def get(self, request: Request): - provider = request.query_params.get('provider') - model_type = request.query_params.get('model_type') - model_name = request.query_params.get('model_name') + provider = request.query_params.get("provider") + model_type = request.query_params.get("model_type") + model_name = request.query_params.get("model_name") return result.success(get_default_model_params_setting(provider, model_type, model_name)) class ModelForm(APIView): authentication_classes = [TokenAuth] - @extend_schema(methods=['GET'], - summary=_('Get the model creation form'), - description=_('Get the model creation form'), - operation_id=_('Get the model creation form'), # type: ignore - parameters=ProvideApi.ModelParamsForm.get_query_params_api(), - responses=ProvideApi.ModelParamsForm.get_response(), - tags=[_('Model')]) # type: ignore + @extend_schema( + methods=["GET"], + summary=_("Get the model creation form"), + description=_("Get the model creation form"), + operation_id=_("Get the model creation form"), # type: ignore + parameters=ProvideApi.ModelParamsForm.get_query_params_api(), + responses=ProvideApi.ModelParamsForm.get_response(), + tags=[_("Model")], # type: ignore + ) def get(self, request: Request): - provider = request.query_params.get('provider') - model_type = request.query_params.get('model_type') - model_name = request.query_params.get('model_name') + provider = request.query_params.get("provider") + model_type = request.query_params.get("model_type") + model_name = request.query_params.get("model_name") return result.success( - ModelProvideConstants[provider].value.get_model_credential(model_type, model_name).to_form_list()) + ModelProvideConstants[provider].value.get_model_credential(model_type, model_name).to_form_list() + )