Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions cmd/bricksllm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ func main() {
log.Sugar().Fatalf("error creating user id for users table: %v", err)
}

err = store.InitializeSecondaryKeyTable()
if err != nil {
log.Sugar().Fatalf("error initializing secondary key table: %v", err)
}

go store.PrepareEventsIndexes(log)

cpMemStore, err := memdb.NewCustomProvidersMemDb(store, log, cfg.InMemoryDbUpdateInterval)
Expand Down Expand Up @@ -190,79 +195,79 @@ func main() {
rateLimitRedisCache := redis.NewClient(defaultRedisOption(cfg, 0))
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := rateLimitRedisCache.Ping(ctx).Err(); err != nil {
if err = rateLimitRedisCache.Ping(ctx).Err(); err != nil {
log.Sugar().Fatalf("error connecting to rate limit redis cache: %v", err)
}

costLimitRedisCache := redis.NewClient(defaultRedisOption(cfg, 1))

ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := costLimitRedisCache.Ping(ctx).Err(); err != nil {
if err = costLimitRedisCache.Ping(ctx).Err(); err != nil {
log.Sugar().Fatalf("error connecting to cost limit redis cache: %v", err)
}

costRedisStorage := redis.NewClient(defaultRedisOption(cfg, 2))

ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := costRedisStorage.Ping(ctx).Err(); err != nil {
if err = costRedisStorage.Ping(ctx).Err(); err != nil {
log.Sugar().Fatalf("error connecting to cost limit redis storage: %v", err)
}

apiRedisCache := redis.NewClient(defaultRedisOption(cfg, 3))

ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := apiRedisCache.Ping(ctx).Err(); err != nil {
if err = apiRedisCache.Ping(ctx).Err(); err != nil {
log.Sugar().Fatalf("error connecting to api redis cache: %v", err)
}

accessRedisCache := redis.NewClient(defaultRedisOption(cfg, 4))

ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := accessRedisCache.Ping(ctx).Err(); err != nil {
if err = accessRedisCache.Ping(ctx).Err(); err != nil {
log.Sugar().Fatalf("error connecting to api redis cache: %v", err)
}

userRateLimitRedisCache := redis.NewClient(defaultRedisOption(cfg, 5))

ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := userRateLimitRedisCache.Ping(ctx).Err(); err != nil {
if err = userRateLimitRedisCache.Ping(ctx).Err(); err != nil {
log.Sugar().Fatalf("error connecting to user rate limit redis cache: %v", err)
}

userCostLimitRedisCache := redis.NewClient(defaultRedisOption(cfg, 6))

ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := userCostLimitRedisCache.Ping(ctx).Err(); err != nil {
if err = userCostLimitRedisCache.Ping(ctx).Err(); err != nil {
log.Sugar().Fatalf("error connecting to user cost limit redis cache: %v", err)
}

userCostRedisStorage := redis.NewClient(defaultRedisOption(cfg, 7))

ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := userCostRedisStorage.Ping(ctx).Err(); err != nil {
if err = userCostRedisStorage.Ping(ctx).Err(); err != nil {
log.Sugar().Fatalf("error connecting to user cost redis cache: %v", err)
}

userAccessRedisCache := redis.NewClient(defaultRedisOption(cfg, 8))

ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := userAccessRedisCache.Ping(ctx).Err(); err != nil {
if err = userAccessRedisCache.Ping(ctx).Err(); err != nil {
log.Sugar().Fatalf("error connecting to user access redis storage: %v", err)
}

providerSettingsRedisCache := redis.NewClient(defaultRedisOption(cfg, 9))

ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := providerSettingsRedisCache.Ping(ctx).Err(); err != nil {
if err = providerSettingsRedisCache.Ping(ctx).Err(); err != nil {
log.Sugar().Fatalf("error connecting to provider settings redis storage: %v", err)
}

Expand All @@ -278,10 +283,18 @@ func main() {

ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := requestsLimitRedisStorage.Ping(ctx).Err(); err != nil {
if err = requestsLimitRedisStorage.Ping(ctx).Err(); err != nil {
log.Sugar().Fatalf("error connecting to requests limit redis storage: %v", err)
}

secondaryKeysRedisCache := redis.NewClient(defaultRedisOption(cfg, 12))

ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err = secondaryKeysRedisCache.Ping(ctx).Err(); err != nil {
log.Sugar().Fatalf("error connecting to secondary keys redis storage: %v", err)
}

rateLimitCache := redisStorage.NewCache(rateLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
costLimitCache := redisStorage.NewCache(costLimitRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
costStorage := redisStorage.NewStore(costRedisStorage, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
Expand All @@ -295,6 +308,7 @@ func main() {

psCache := redisStorage.NewProviderSettingsCache(providerSettingsRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
keysCache := redisStorage.NewKeysCache(keysRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
secondaryKeysCache := redisStorage.NewSecondaryKeysCache(secondaryKeysRedisCache, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)
requestsLimitStorage := redisStorage.NewStore(requestsLimitRedisStorage, cfg.RedisWriteTimeout, cfg.RedisReadTimeout)

encryptor, err := encryptor.NewEncryptor(cfg.DecryptionEndpoint, cfg.EncryptionEndpoint, cfg.EnableEncrytion, cfg.EncryptionTimeout, cfg.Audience)
Expand All @@ -303,7 +317,7 @@ func main() {
}
v := validator.NewValidator(costLimitCache, rateLimitCache, costStorage, requestsLimitStorage)

m := manager.NewManager(store, costLimitCache, rateLimitCache, accessCache, keysCache, requestsLimitStorage)
m := manager.NewManager(store, costLimitCache, rateLimitCache, accessCache, keysCache, secondaryKeysCache, requestsLimitStorage)
krm := manager.NewReportingManager(costStorage, store, store, v)
psm := manager.NewProviderSettingsManager(store, psCache, encryptor)
cpm := manager.NewCustomProvidersManager(store, cpMemStore)
Expand Down
41 changes: 28 additions & 13 deletions internal/authenticator/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type routesManager interface {

type keysCache interface {
GetKeyViaCache(hash string) (*key.ResponseKey, error)
GetKeyHashBySecondary(kHash string) (string, error)
}

type keyStorage interface {
Expand Down Expand Up @@ -206,6 +207,19 @@ func anonymize(input string) string {
return string(input[0:5]) + "**********************************************"
}

const secondaryPrefix = "secondary_"

func (a *Authenticator) getHashViaSecondary(rawKey string) (string, error) {
if len(rawKey) > 0 && rawKey[0] != secondaryPrefix[0] {
return hasher.Hash(rawKey), nil
}
if !strings.HasPrefix(rawKey, secondaryPrefix) {
return hasher.Hash(rawKey), nil
}
hash := hasher.Hash(rawKey)
return a.kc.GetKeyHashBySecondary(hash)
}

func (a *Authenticator) AuthenticateHttpRequest(req *http.Request, xCustomProviderId string) (*key.ResponseKey, []*provider.Setting, error) {
var raw string
var err error
Expand All @@ -224,31 +238,32 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request, xCustomProvid
return nil, nil, err
}

hash := hasher.Hash(raw)
hash, err := a.getHashViaSecondary(raw)

key, err := a.kc.GetKeyViaCache(hash)
if key != nil {
rKey, err := a.kc.GetKeyViaCache(hash)
if rKey != nil {
telemetry.Incr(metricname.COUNTER_AUTHENTICATOR_FOUND_KEY_FROM_MEMDB, nil, 1)
}

if key == nil {
key, err = a.kc.GetKeyViaCache(raw)
if rKey == nil {
rKey, err = a.kc.GetKeyViaCache(raw)
}

if err != nil {
_, ok := err.(notFoundError)
var nFoundError notFoundError
ok := errors.As(err, &nFoundError)
if ok {
return nil, nil, internal_errors.NewAuthError(fmt.Sprintf("key %s is not found", anonymize(raw)))
}

return nil, nil, err
}

if key == nil {
if rKey == nil {
return nil, nil, internal_errors.NewAuthError(fmt.Sprintf("key %s is not found", anonymize(raw)))
}

if key.Revoked {
if rKey.Revoked {
return nil, nil, internal_errors.NewAuthError(fmt.Sprintf("key %s has been revoked", anonymize(raw)))
}

Expand All @@ -271,17 +286,17 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request, xCustomProvid
default:
return nil, nil, errors.New("invalid xCustomAuth location")
}
return key, settings, nil
return rKey, settings, nil
}

if strings.HasPrefix(req.URL.Path, "/api/routes") {
err = a.canKeyAccessCustomRoute(req.URL.Path, key.KeyId)
err = a.canKeyAccessCustomRoute(req.URL.Path, rKey.KeyId)
if err != nil {
return nil, nil, err
}
}

settingIds := key.GetSettingIds()
settingIds := rKey.GetSettingIds()
allSettings := []*provider.Setting{}
selected := []*provider.Setting{}
for _, settingId := range settingIds {
Expand All @@ -308,7 +323,7 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request, xCustomProvid

if len(selected) != 0 {
used := selected[0]
if key.RotationEnabled {
if rKey.RotationEnabled {
used = selected[rand.Intn(len(selected))]
}

Expand Down Expand Up @@ -337,7 +352,7 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request, xCustomProvid
return nil, nil, err
}

return key, selected, nil
return rKey, selected, nil
}

return nil, nil, internal_errors.NewAuthError(fmt.Sprintf("provider setting not found for key %s", raw))
Expand Down
91 changes: 78 additions & 13 deletions internal/manager/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/bricks-cloud/bricksllm/internal/key"
"github.com/bricks-cloud/bricksllm/internal/policy"
"github.com/bricks-cloud/bricksllm/internal/provider"
secondarykey "github.com/bricks-cloud/bricksllm/internal/secondary-key"
"github.com/bricks-cloud/bricksllm/internal/telemetry"
"github.com/bricks-cloud/bricksllm/internal/util"
)
Expand All @@ -26,6 +27,9 @@ type Storage interface {
GetProviderSettings(withSecret bool, ids []string) ([]*provider.Setting, error)
GetKey(keyId string) (*key.ResponseKey, error)
GetKeyByHash(hash string) (*key.ResponseKey, error)
GetKeyHashBySecondary(sHash string) (string, error)
CreateSecondaryKey(secondaryHash string) error
UpdateSecondaryKey(secondaryHash, keyHash string) error
}

type costLimitCache interface {
Expand All @@ -46,27 +50,35 @@ type keyCache interface {
Get(keyId string) (*key.ResponseKey, error)
}

type secondaryKeyCache interface {
Set(sHash string, value string, ttl time.Duration) error
Delete(sHash string) error
Get(sHash string) (string, error)
}

type requestsLimitStorage interface {
DeleteCounter(keyId string) error
}

type Manager struct {
s Storage
clc costLimitCache
rlc rateLimitCache
ac accessCache
kc keyCache
rqls requestsLimitStorage
s Storage
clc costLimitCache
rlc rateLimitCache
ac accessCache
kc keyCache
secondaryKC secondaryKeyCache
rqls requestsLimitStorage
}

func NewManager(s Storage, clc costLimitCache, rlc rateLimitCache, ac accessCache, kc keyCache, rqls requestsLimitStorage) *Manager {
func NewManager(s Storage, clc costLimitCache, rlc rateLimitCache, ac accessCache, kc keyCache, secondaryKC secondaryKeyCache, rqls requestsLimitStorage) *Manager {
return &Manager{
s: s,
clc: clc,
rlc: rlc,
ac: ac,
kc: kc,
rqls: rqls,
s: s,
clc: clc,
rlc: rlc,
ac: ac,
kc: kc,
secondaryKC: secondaryKC,
rqls: rqls,
}
}

Expand Down Expand Up @@ -241,6 +253,59 @@ func (m *Manager) GetKeyViaCache(raw string) (*key.ResponseKey, error) {
return k, nil
}

func (m *Manager) GetKeyHashBySecondary(sHash string) (string, error) {
h, _ := m.secondaryKC.Get(sHash)
if h == "" {
telemetry.Incr("bricksllm.manager.get_key_hash_by_secondary.cache_miss", nil, 1)
stored, err := m.s.GetKeyHashBySecondary(sHash)
if err != nil {
return "", err
}
if stored == "" {
return "", errors.New("key hash not found")
}
err = m.secondaryKC.Set(sHash, stored, 24*time.Hour)
if err != nil {
telemetry.Incr("bricksllm.manager.get_key_hash_by_secondary.set_error", nil, 1)
}
h = stored
}
telemetry.Incr("bricksllm.manager.get_key_hash_by_secondary.cache_hit", nil, 1)
return h, nil
}

func (m *Manager) CreateSecondaryKey(keyCreate secondarykey.SecondaryKeyCreate) error {
if keyCreate.Key == "" {
return errors.New("key is required for creating secondary key")
}
return m.s.CreateSecondaryKey(hasher.Hash(keyCreate.Key))
}

func (m *Manager) UpdateSecondaryKey(keyUpdate secondarykey.SecondaryKeyUpdate) error {
if keyUpdate.Key == "" {
return errors.New("key is required for updating secondary key")
}
if keyUpdate.LinkedKeyId == "" {
return errors.New("linkedKeyId is required for updating secondary key")
}
rKey, err := m.s.GetKey(keyUpdate.LinkedKeyId)
if err != nil {
return err
}
if rKey == nil {
return errors.New("linked key not found for updating secondary key")
}
err = m.s.UpdateSecondaryKey(hasher.Hash(keyUpdate.Key), rKey.Key)
if err != nil {
return err
}
err = m.secondaryKC.Set(hasher.Hash(keyUpdate.Key), rKey.Key, 24*time.Hour)
if err != nil {
telemetry.Incr("bricksllm.manager.update_secondary_key.set_cache_error", nil, 1)
}
return nil
}

func (m *Manager) DeleteKey(id string) error {
return m.s.DeleteKey(id)
}
10 changes: 10 additions & 0 deletions internal/secondary-key/secondary_key.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package secondary_key

type SecondaryKeyCreate struct {
Key string `json:"key"`
}

type SecondaryKeyUpdate struct {
Key string `json:"key"`
LinkedKeyId string `json:"linkedKeyId"`
}
Loading
Loading