Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 51 additions & 39 deletions apps/local_model/serializers/model_apply_serializers.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,32 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: model_apply_serializers.py
@date:2024/8/20 20:39
@desc:
@project: MaxKB
@Author:虎
@file: model_apply_serializers.py
@date:2024/8/20 20:39
@desc:
"""

import json
import threading
import time

from common.cache.mem_cache import MemCache
from django.db import connection
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from langchain_core.documents import Document
from rest_framework import serializers

from local_model.models import Model
from local_model.serializers.rsa_util import rsa_long_decrypt
from models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider

from common.cache.mem_cache import MemCache
from rest_framework import serializers

_lock = threading.Lock()
locks = {}


class ModelManage:
cache = MemCache('model', {})
cache = MemCache("model", {})
up_clear_time = time.time()

@staticmethod
Expand Down Expand Up @@ -74,87 +73,100 @@ def delete_key(_id):


def get_local_model(model, **kwargs):
return LocalModelProvider().get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(model.credential)),
model_id=model.id,
streaming=True, **kwargs)
return LocalModelProvider().get_model(
model.model_type,
model.model_name,
json.loads(rsa_long_decrypt(model.credential)),
model_id=model.id,
streaming=True,
**kwargs,
)


def get_embedding_model(model_id):
model = QuerySet(Model).filter(id=model_id).first()
# 手动关闭数据库连接
connection.close()
embedding_model = ModelManage.get_model(model_id,
lambda _id: get_local_model(model, use_local=True))
embedding_model = ModelManage.get_model(model_id, lambda _id: get_local_model(model, use_local=True))
return embedding_model


class EmbedDocuments(serializers.Serializer):
texts = serializers.ListField(required=True,
child=serializers.CharField(required=True, label=_('vector text')),
label=_('vector text list'))
texts = serializers.ListField(
required=True, child=serializers.CharField(required=True, label=_("vector text")), label=_("vector text list")
)


class EmbedQuery(serializers.Serializer):
text = serializers.CharField(required=True, label=_('vector text'))
text = serializers.CharField(required=True, label=_("vector text"))


class CompressDocument(serializers.Serializer):
page_content = serializers.CharField(required=True, label=_('text'))
metadata = serializers.DictField(required=False, label=_('metadata'))
page_content = serializers.CharField(required=True, label=_("text"))
metadata = serializers.DictField(required=False, label=_("metadata"))


class CompressDocuments(serializers.Serializer):
documents = CompressDocument(required=True, many=True)
query = serializers.CharField(required=True, label=_('query'))
query = serializers.CharField(required=True, label=_("query"))


class ValidateModelSerializers(serializers.Serializer):
model_name = serializers.CharField(required=True, label=_('model_name'))
model_name = serializers.CharField(required=True, label=_("model_name"))

model_type = serializers.CharField(required=True, label=_('model_type'))
model_type = serializers.CharField(required=True, label=_("model_type"))

model_credential = serializers.DictField(required=True, label="credential")

def validate_model(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
LocalModelProvider().is_valid_credential(self.data.get('model_type'), self.data.get('model_name'),
self.data.get('model_credential'), model_params={},
raise_exception=True)
LocalModelProvider().is_valid_credential(
self.data.get("model_type"),
self.data.get("model_name"),
self.data.get("model_credential"),
model_params={},
raise_exception=True,
)


class ModelApplySerializers(serializers.Serializer):
model_id = serializers.UUIDField(required=True, label=_('model id'))
model_id = serializers.UUIDField(required=True, label=_("model id"))

def embed_documents(self, instance, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
EmbedDocuments(data=instance).is_valid(raise_exception=True)

model = get_embedding_model(self.data.get('model_id'))
return model.embed_documents(instance.getlist('texts'))
model = get_embedding_model(self.data.get("model_id"))
return model.embed_documents(instance.getlist("texts"))

def embed_query(self, instance, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
EmbedQuery(data=instance).is_valid(raise_exception=True)

model = get_embedding_model(self.data.get('model_id'))
return model.embed_query(instance.get('text'))
model = get_embedding_model(self.data.get("model_id"))
return model.embed_query(instance.get("text"))

def compress_documents(self, instance, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
CompressDocuments(data=instance).is_valid(raise_exception=True)
model = get_embedding_model(self.data.get('model_id'))
return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents(
[Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in
instance.get('documents')], instance.get('query'))]
model = get_embedding_model(self.data.get("model_id"))
return [
{"page_content": d.page_content, "metadata": d.metadata}
for d in model.compress_documents(
[
Document(page_content=document.get("page_content"), metadata=document.get("metadata"))
for document in instance.get("documents")
],
instance.get("query"),
)
]

def unload(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
ModelManage.delete_key(self.data.get('model_id'))
ModelManage.delete_key(self.data.get("model_id"))
return True
49 changes: 24 additions & 25 deletions apps/local_model/serializers/rsa_util.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: rsa_util.py
@date:2023/11/3 11:13
@desc:
@project: maxkb
@Author:虎
@file: rsa_util.py
@date:2023/11/3 11:13
@desc:
"""

import base64
import threading

from common.constants.cache_version import Cache_Version
from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher
from Crypto.PublicKey import RSA
from django.core import cache
from django.db.models import QuerySet

from common.constants.cache_version import Cache_Version
from local_model.models.system_setting import SystemSetting, SettingType
from local_model.models.system_setting import SettingType, SystemSetting

lock = threading.Lock()
rsa_cache = cache.cache
Expand All @@ -33,9 +33,8 @@ def generate():
key = RSA.generate(2048)

# 获取私钥
encrypted_key = key.export_key(passphrase=secret_code, pkcs=8,
protection="scryptAndAES128-CBC")
return {'key': key.publickey().export_key(), 'value': encrypted_key}
encrypted_key = key.export_key(passphrase=secret_code, pkcs=8, protection="scryptAndAES128-CBC")
return {"key": key.publickey().export_key(), "value": encrypted_key}


def get_key_pair():
Expand All @@ -47,16 +46,17 @@ def get_key_pair():
return rsa_value
rsa_value = get_key_pair_by_sql()
version, get_key = Cache_Version.SYSTEM.value
rsa_cache.set(get_key(key='rsa_key'), rsa_value, timeout=None, version=version)
rsa_cache.set(get_key(key="rsa_key"), rsa_value, timeout=None, version=version)
return rsa_value


def get_key_pair_by_sql():
system_setting = QuerySet(SystemSetting).filter(type=SettingType.RSA.value).first()
if system_setting is None:
kv = generate()
system_setting = SystemSetting(type=SettingType.RSA.value,
meta={'key': kv.get('key').decode(), 'value': kv.get('value').decode()})
system_setting = SystemSetting(
type=SettingType.RSA.value, meta={"key": kv.get("key").decode(), "value": kv.get("value").decode()}
)
system_setting.save()
return system_setting.meta

Expand All @@ -69,7 +69,7 @@ def encrypt(msg, public_key: str | None = None):
:return: 加密后的数据
"""
if public_key is None:
public_key = get_key_pair().get('key')
public_key = get_key_pair().get("key")
cipher = PKCS1_cipher.new(RSA.importKey(public_key))
encrypt_msg = cipher.encrypt(msg.encode("utf-8"))
return base64.b64encode(encrypt_msg).decode()
Expand All @@ -83,7 +83,7 @@ def decrypt(msg, pri_key: str | None = None):
:return: 解密后数据
"""
if pri_key is None:
pri_key = get_key_pair().get('value')
pri_key = get_key_pair().get("value")
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
return decrypt_data.decode("utf-8")
Expand All @@ -100,22 +100,21 @@ def rsa_long_encrypt(message, public_key: str | None = None, length=200):
"""
# 读取公钥
if public_key is None:
public_key = get_key_pair().get('key')
cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key,
passphrase=secret_code))
public_key = get_key_pair().get("key")
cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key, passphrase=secret_code))
# 处理:Plaintext is too long. 分段加密
if len(message) <= length:
# 对编码的数据进行加密,并通过base64进行编码
result = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
result = base64.b64encode(cipher.encrypt(message.encode("utf-8")))
else:
rsa_text = []
# 对编码后的数据进行切片,原因:加密长度不能过长
for i in range(0, len(message), length):
cont = message[i:i + length]
cont = message[i : i + length]
# 对切片后的数据进行加密,并新增到text后面
rsa_text.append(cipher.encrypt(cont.encode('utf-8')))
rsa_text.append(cipher.encrypt(cont.encode("utf-8")))
# 加密完进行拼接
cipher_text = b''.join(rsa_text)
cipher_text = b"".join(rsa_text)
# base64进行编码
result = base64.b64encode(cipher_text)
return result.decode()
Expand All @@ -130,10 +129,10 @@ def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
:return: 解密后的数据
"""
if pri_key is None:
pri_key = get_key_pair().get('value')
pri_key = get_key_pair().get("value")
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
base64_de = base64.b64decode(message)
res = []
for i in range(0, len(base64_de), length):
res.append(cipher.decrypt(base64_de[i:i + length], 0))
res.append(cipher.decrypt(base64_de[i : i + length], 0))
return b"".join(res).decode()
29 changes: 11 additions & 18 deletions apps/local_model/views/model_apply.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,35 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: model_apply.py
@date:2024/8/20 20:38
@desc:
@project: MaxKB
@Author:虎
@file: model_apply.py
@date:2024/8/20 20:38
@desc:
"""
from urllib.request import Request

from rest_framework.views import APIView
from urllib.request import Request

from common.result import result
from local_model.serializers.model_apply_serializers import ModelApplySerializers, ValidateModelSerializers
from rest_framework.views import APIView


class LocalModelApply(APIView):
class EmbedDocuments(APIView):

def post(self, request: Request, model_id):
return result.success(
ModelApplySerializers(data={'model_id': model_id}).embed_documents(request.data))
return result.success(ModelApplySerializers(data={"model_id": model_id}).embed_documents(request.data))

class EmbedQuery(APIView):

def post(self, request: Request, model_id):
return result.success(
ModelApplySerializers(data={'model_id': model_id}).embed_query(request.data))
return result.success(ModelApplySerializers(data={"model_id": model_id}).embed_query(request.data))

class CompressDocuments(APIView):

def post(self, request: Request, model_id):
return result.success(
ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))
return result.success(ModelApplySerializers(data={"model_id": model_id}).compress_documents(request.data))

class Unload(APIView):
def post(self, request: Request, model_id):
return result.success(
ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))
return result.success(ModelApplySerializers(data={"model_id": model_id}).compress_documents(request.data))

class Validate(APIView):
def post(self, request: Request):
Expand Down
Loading
Loading