From 1e4f2a8ec4bc1f75e9083b04e53d3a88417fa856 Mon Sep 17 00:00:00 2001 From: Raymond Jacobson Date: Thu, 26 Mar 2026 10:27:53 -0700 Subject: [PATCH 1/5] Add Solana wallet middleware for coin-gated stream access New middleware verifies ed25519 signatures from X-Solana-Wallet, X-Solana-Message, and X-Solana-Signature headers. When standard access check fails, the stream endpoint falls back to real-time on-chain token balance verification via Solana RPC for wallets that hold the required artist coin. Co-Authored-By: Claude Opus 4.6 --- api/dbv1/media_link.go | 5 ++ api/server.go | 1 + api/solana_wallet_access.go | 88 +++++++++++++++++++++++++++++++++ api/solana_wallet_middleware.go | 61 +++++++++++++++++++++++ api/v1_track_stream.go | 37 ++++++++++++-- 5 files changed, 189 insertions(+), 3 deletions(-) create mode 100644 api/solana_wallet_access.go create mode 100644 api/solana_wallet_middleware.go diff --git a/api/dbv1/media_link.go b/api/dbv1/media_link.go index e782c416..b075e8e8 100644 --- a/api/dbv1/media_link.go +++ b/api/dbv1/media_link.go @@ -24,6 +24,11 @@ type Id3Tags struct { Artist string `json:"artist"` } +// BuildMediaLink generates a signed media link for streaming a track CID from content nodes. +func BuildMediaLink(cid string, trackId int32, userId int32, id3Tags *Id3Tags) (*MediaLink, error) { + return mediaLink(cid, trackId, userId, id3Tags) +} + func mediaLink(cid string, trackId int32, userId int32, id3Tags *Id3Tags) (*MediaLink, error) { first, rest := rendezvous.GlobalHasher.Select(cid) diff --git a/api/server.go b/api/server.go index e0f7224a..55b4deef 100644 --- a/api/server.go +++ b/api/server.go @@ -364,6 +364,7 @@ func NewApiServer(config config.Config) *ApiServer { app.Use(app.isFullMiddleware) app.Use(app.resolveMyIdMiddleware) app.Use(app.authMiddleware) + app.Use(app.solanaWalletMiddleware) v1 := app.Group("/v1") v1Full := app.Group("/v1/full") diff --git a/api/solana_wallet_access.go b/api/solana_wallet_access.go new file mode 100644 index 00000000..f07aea71 --- /dev/null +++ b/api/solana_wallet_access.go @@ -0,0 +1,88 @@ +package api + +import ( + "context" + "math" + + "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" + "go.uber.org/zap" +) + +// checkSolanaWalletTokenAccess checks whether a Solana wallet holds sufficient tokens +// for a coin-gated track by doing a real-time on-chain balance lookup. +func (app *ApiServer) checkSolanaWalletTokenAccess( + ctx context.Context, + solanaWallet string, + tokenMint string, + requiredAmount int64, +) (bool, error) { + walletPubkey, err := solana.PublicKeyFromBase58(solanaWallet) + if err != nil { + return false, err + } + + mintPubkey, err := solana.PublicKeyFromBase58(tokenMint) + if err != nil { + return false, err + } + + // Derive the associated token account address + ata, _, err := solana.FindAssociatedTokenAddress(walletPubkey, mintPubkey) + if err != nil { + return false, err + } + + // Get token balance from chain + balanceResult, err := app.solanaRpcClient.GetTokenAccountBalance(ctx, ata, rpc.CommitmentConfirmed) + if err != nil { + // Account doesn't exist means zero balance + app.logger.Debug("checkSolanaWalletTokenAccess: token account not found", + zap.String("wallet", solanaWallet), + zap.String("mint", tokenMint), + zap.Error(err), + ) + return false, nil + } + + if balanceResult == nil || balanceResult.Value == nil { + return false, nil + } + + rawBalance := balanceResult.Value.Amount + if rawBalance == "" { + return false, nil + } + + // Parse the raw balance (string of lamports/smallest unit) + var balance uint64 + for _, c := range rawBalance { + if c < '0' || c > '9' { + return false, nil + } + balance = balance*10 + uint64(c-'0') + } + + // Look up coin decimals from the artist_coins table + var decimals int32 + err = app.pool.QueryRow(ctx, ` + SELECT decimals FROM artist_coins WHERE mint = $1 + `, tokenMint).Scan(&decimals) + if err != nil { + // If coin not found, try using the decimals from the RPC response + decimals = int32(balanceResult.Value.Decimals) + } + + // Scale required amount by decimals + scaledRequired := requiredAmount * int64(math.Pow10(int(decimals))) + + app.logger.Debug("checkSolanaWalletTokenAccess", + zap.String("wallet", solanaWallet), + zap.String("mint", tokenMint), + zap.Uint64("balance", balance), + zap.Int64("scaledRequired", scaledRequired), + zap.Bool("hasAccess", int64(balance) >= scaledRequired), + ) + + return int64(balance) >= scaledRequired, nil +} diff --git a/api/solana_wallet_middleware.go b/api/solana_wallet_middleware.go new file mode 100644 index 00000000..15bd4c57 --- /dev/null +++ b/api/solana_wallet_middleware.go @@ -0,0 +1,61 @@ +package api + +import ( + "crypto/ed25519" + + "github.com/gofiber/fiber/v2" + "github.com/mr-tron/base58" + "go.uber.org/zap" +) + +// solanaWalletMiddleware verifies Solana wallet signatures from request headers. +// If the X-Solana-Wallet, X-Solana-Message, and X-Solana-Signature headers are +// present and valid, the verified wallet public key is stored in c.Locals("solanaWallet"). +// If headers are absent, the middleware is a no-op. If present but invalid, returns 401. +func (app *ApiServer) solanaWalletMiddleware(c *fiber.Ctx) error { + wallet := c.Get("X-Solana-Wallet") + message := c.Get("X-Solana-Message") + signature := c.Get("X-Solana-Signature") + + // No Solana headers — skip silently + if wallet == "" && message == "" && signature == "" { + return c.Next() + } + + // Partial headers — reject + if wallet == "" || message == "" || signature == "" { + return fiber.NewError(fiber.StatusUnauthorized, "incomplete Solana wallet headers: X-Solana-Wallet, X-Solana-Message, and X-Solana-Signature are all required") + } + + // Decode the base58 public key (32 bytes for ed25519) + pubkeyBytes, err := base58.Decode(wallet) + if err != nil || len(pubkeyBytes) != ed25519.PublicKeySize { + app.logger.Warn("solanaWalletMiddleware: invalid wallet public key", zap.String("wallet", wallet), zap.Error(err)) + return fiber.NewError(fiber.StatusUnauthorized, "invalid Solana wallet public key") + } + + // Decode the base58 signature (64 bytes for ed25519) + sigBytes, err := base58.Decode(signature) + if err != nil || len(sigBytes) != ed25519.SignatureSize { + app.logger.Warn("solanaWalletMiddleware: invalid signature", zap.Error(err)) + return fiber.NewError(fiber.StatusUnauthorized, "invalid Solana signature") + } + + // Verify the ed25519 signature + if !ed25519.Verify(pubkeyBytes, []byte(message), sigBytes) { + app.logger.Warn("solanaWalletMiddleware: signature verification failed", zap.String("wallet", wallet)) + return fiber.NewError(fiber.StatusUnauthorized, "Solana signature verification failed") + } + + app.logger.Debug("solanaWalletMiddleware: verified", zap.String("wallet", wallet)) + c.Locals("solanaWallet", wallet) + return c.Next() +} + +// tryGetSolanaWallet returns the verified Solana wallet from context, or "" if not set. +func (app *ApiServer) tryGetSolanaWallet(c *fiber.Ctx) string { + if w, ok := c.Locals("solanaWallet").(string); ok { + return w + } + return "" +} diff --git a/api/v1_track_stream.go b/api/v1_track_stream.go index bbae56dc..032968ff 100644 --- a/api/v1_track_stream.go +++ b/api/v1_track_stream.go @@ -26,11 +26,42 @@ func (app *ApiServer) v1TrackStream(c *fiber.Ctx) error { } track := tracks[0] - if !track.Access.Stream { - return fiber.NewError(fiber.StatusForbidden, "track not streamable") + + // If standard access check passes, use the pre-built stream URL + if track.Access.Stream { + return app.redirectToStream(c, track.Stream) + } + + // Fallback: check if a verified Solana wallet has sufficient token balance + solWallet := app.tryGetSolanaWallet(c) + if solWallet != "" && track.StreamConditions != nil && track.StreamConditions.TokenGate != nil { + hasAccess, err := app.checkSolanaWalletTokenAccess( + c.Context(), + solWallet, + track.StreamConditions.TokenGate.TokenMint, + track.StreamConditions.TokenGate.TokenAmount, + ) + if err == nil && hasAccess { + // Build a stream URL on the fly since the track didn't include one + // (stream URLs are only pre-built when standard access is granted) + stream, err := dbv1.BuildMediaLink( + track.TrackCid.String, + track.TrackID, + myId, + &dbv1.Id3Tags{Title: track.Title.String, Artist: track.User.Name.String}, + ) + if err != nil { + return err + } + return app.redirectToStream(c, stream) + } } - streamURL := tryFindWorkingUrl(track.Stream) + return fiber.NewError(fiber.StatusForbidden, "track not streamable") +} + +func (app *ApiServer) redirectToStream(c *fiber.Ctx, stream *dbv1.MediaLink) error { + streamURL := tryFindWorkingUrl(stream) if skipPlayCount := c.Query("skip_play_count"); skipPlayCount != "" { q := streamURL.Query() From 2c6f70ead7210e42db79ca8a2544af050c63322b Mon Sep 17 00:00:00 2001 From: Raymond Jacobson Date: Thu, 26 Mar 2026 12:05:00 -0700 Subject: [PATCH 2/5] Skip on-chain grant check for PKCE-authenticated requests A valid PKCE access token already proves the user authorized the client app during the OAuth consent flow. The on-chain grant check is redundant for these requests and fails for read-only apps that don't have a grant registered in the grants table. Co-Authored-By: Claude Opus 4.6 --- api/auth_middleware.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/api/auth_middleware.go b/api/auth_middleware.go index c86e6320..6f303edc 100644 --- a/api/auth_middleware.go +++ b/api/auth_middleware.go @@ -321,8 +321,11 @@ func (app *ApiServer) authMiddleware(c *fiber.Ctx) error { c.Locals("authedWallet", wallet) + // A valid PKCE access token already proves the user authorized this client + _, pkceAuthed := c.Locals("oauthScope").(string) + // Not authorized to act on behalf of myId - if myId != 0 && !app.isAuthorizedRequest(c.Context(), myId, wallet) { + if myId != 0 && !pkceAuthed && !app.isAuthorizedRequest(c.Context(), myId, wallet) { return fiber.NewError( fiber.StatusForbidden, fmt.Sprintf( From 7e17d221ea3b57d94295e9db8c1b1428a57819fa Mon Sep 17 00:00:00 2001 From: Raymond Jacobson Date: Thu, 26 Mar 2026 12:28:38 -0700 Subject: [PATCH 3/5] Refactor Solana wallet access into bulk track access layer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move token gate checks for Solana wallets from the stream endpoint into GetBulkTrackAccess via a TokenBalanceFetcher callback. Balances are looked up from the indexed sol_token_account_balances table instead of Solana RPC, removing the RPC dependency entirely. - Add TokenBalanceFetcher type and optional param to GetBulkTrackAccess - Replace RPC-based checkSolanaWalletTokenAccess with DB-backed fetcher - Remove inline fallback in v1TrackStream — access is now pre-calculated - Remove unused BuildMediaLink export - Add middleware test for Solana wallet signature verification Co-Authored-By: Claude Opus 4.6 --- api/dbv1/access.go | 55 +++++++++----- api/dbv1/media_link.go | 5 -- api/dbv1/tracks.go | 3 +- api/solana_wallet_access.go | 103 +++++++-------------------- api/solana_wallet_middleware_test.go | 83 +++++++++++++++++++++ api/v1_track_stream.go | 38 +++------- 6 files changed, 157 insertions(+), 130 deletions(-) create mode 100644 api/solana_wallet_middleware_test.go diff --git a/api/dbv1/access.go b/api/dbv1/access.go index 253e773c..35a6ea72 100644 --- a/api/dbv1/access.go +++ b/api/dbv1/access.go @@ -13,6 +13,10 @@ type Access struct { Download bool `json:"download"` } +// TokenBalanceFetcher fetches on-chain token balances for a set of mints. +// Returns a map of mint → raw balance (smallest unit, already scaled by decimals). +type TokenBalanceFetcher func(ctx context.Context, mints []string) (map[string]int64, error) + func (q *Queries) GetPlaylistAccess( ctx context.Context, myId int32, @@ -86,6 +90,7 @@ func (q *Queries) GetBulkTrackAccess( myId int32, tracks []*GetTracksRow, users map[int32]*User, + tokenBalanceFetcher TokenBalanceFetcher, ) (map[int32]Access, error) { // Initialize result map result := make(map[int32]Access) @@ -280,26 +285,40 @@ func (q *Queries) GetBulkTrackAccess( // Query for token balances if len(tokenGateTokenMintsSlice) > 0 { - g.Go(func() error { - rows, err := q.db.Query(ctx, ` - SELECT mint, COALESCE(balance, 0) - FROM sol_user_balances - WHERE user_id = $1 - AND mint = ANY($2) - `, myId, tokenGateTokenMintsSlice) - if err != nil { - return err - } - defer rows.Close() - for rows.Next() { - var mint string - var balance int64 - if err := rows.Scan(&mint, &balance); err == nil { + if tokenBalanceFetcher != nil { + // Use the provided fetcher (e.g. on-chain RPC for Solana wallet auth) + g.Go(func() error { + balances, err := tokenBalanceFetcher(ctx, tokenGateTokenMintsSlice) + if err != nil { + return err + } + for mint, balance := range balances { userTokenBalances[mint] = balance } - } - return rows.Err() - }) + return nil + }) + } else { + g.Go(func() error { + rows, err := q.db.Query(ctx, ` + SELECT mint, COALESCE(balance, 0) + FROM sol_user_balances + WHERE user_id = $1 + AND mint = ANY($2) + `, myId, tokenGateTokenMintsSlice) + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var mint string + var balance int64 + if err := rows.Scan(&mint, &balance); err == nil { + userTokenBalances[mint] = balance + } + } + return rows.Err() + }) + } // Query for coin decimals g.Go(func() error { diff --git a/api/dbv1/media_link.go b/api/dbv1/media_link.go index b075e8e8..e782c416 100644 --- a/api/dbv1/media_link.go +++ b/api/dbv1/media_link.go @@ -24,11 +24,6 @@ type Id3Tags struct { Artist string `json:"artist"` } -// BuildMediaLink generates a signed media link for streaming a track CID from content nodes. -func BuildMediaLink(cid string, trackId int32, userId int32, id3Tags *Id3Tags) (*MediaLink, error) { - return mediaLink(cid, trackId, userId, id3Tags) -} - func mediaLink(cid string, trackId int32, userId int32, id3Tags *Id3Tags) (*MediaLink, error) { first, rest := rendezvous.GlobalHasher.Select(cid) diff --git a/api/dbv1/tracks.go b/api/dbv1/tracks.go index e2a0f57e..f6a97ec0 100644 --- a/api/dbv1/tracks.go +++ b/api/dbv1/tracks.go @@ -10,6 +10,7 @@ import ( type TracksParams struct { GetTracksParams + TokenBalanceFetcher TokenBalanceFetcher } // Track is the standard track type containing all track data @@ -84,7 +85,7 @@ func (q *Queries) TracksKeyed(ctx context.Context, arg TracksParams) (map[int32] } // Get bulk access for all tracks - accessMap, err := q.GetBulkTrackAccess(ctx, arg.MyID.(int32), trackPtrs, userPtrMap) + accessMap, err := q.GetBulkTrackAccess(ctx, arg.MyID.(int32), trackPtrs, userPtrMap, arg.TokenBalanceFetcher) if err != nil { return nil, err } diff --git a/api/solana_wallet_access.go b/api/solana_wallet_access.go index f07aea71..63660fa2 100644 --- a/api/solana_wallet_access.go +++ b/api/solana_wallet_access.go @@ -2,87 +2,34 @@ package api import ( "context" - "math" - "github.com/gagliardetto/solana-go" - "github.com/gagliardetto/solana-go/rpc" - "go.uber.org/zap" + "api.audius.co/api/dbv1" ) -// checkSolanaWalletTokenAccess checks whether a Solana wallet holds sufficient tokens -// for a coin-gated track by doing a real-time on-chain balance lookup. -func (app *ApiServer) checkSolanaWalletTokenAccess( - ctx context.Context, - solanaWallet string, - tokenMint string, - requiredAmount int64, -) (bool, error) { - walletPubkey, err := solana.PublicKeyFromBase58(solanaWallet) - if err != nil { - return false, err - } - - mintPubkey, err := solana.PublicKeyFromBase58(tokenMint) - if err != nil { - return false, err - } - - // Derive the associated token account address - ata, _, err := solana.FindAssociatedTokenAddress(walletPubkey, mintPubkey) - if err != nil { - return false, err - } - - // Get token balance from chain - balanceResult, err := app.solanaRpcClient.GetTokenAccountBalance(ctx, ata, rpc.CommitmentConfirmed) - if err != nil { - // Account doesn't exist means zero balance - app.logger.Debug("checkSolanaWalletTokenAccess: token account not found", - zap.String("wallet", solanaWallet), - zap.String("mint", tokenMint), - zap.Error(err), - ) - return false, nil - } - - if balanceResult == nil || balanceResult.Value == nil { - return false, nil - } - - rawBalance := balanceResult.Value.Amount - if rawBalance == "" { - return false, nil - } - - // Parse the raw balance (string of lamports/smallest unit) - var balance uint64 - for _, c := range rawBalance { - if c < '0' || c > '9' { - return false, nil +// newSolanaWalletTokenBalanceFetcher returns a TokenBalanceFetcher that looks up +// on-chain token balances for the given Solana wallet from the indexed +// sol_token_account_balances table. The returned balances are raw (smallest unit) +// matching the format used by GetBulkTrackAccess. +func (app *ApiServer) newSolanaWalletTokenBalanceFetcher(wallet string) dbv1.TokenBalanceFetcher { + return func(ctx context.Context, mints []string) (map[string]int64, error) { + balances := make(map[string]int64, len(mints)) + rows, err := app.pool.Query(ctx, ` + SELECT mint, COALESCE(balance, 0) + FROM sol_token_account_balances + WHERE owner = $1 + AND mint = ANY($2) + `, wallet, mints) + if err != nil { + return nil, err } - balance = balance*10 + uint64(c-'0') - } - - // Look up coin decimals from the artist_coins table - var decimals int32 - err = app.pool.QueryRow(ctx, ` - SELECT decimals FROM artist_coins WHERE mint = $1 - `, tokenMint).Scan(&decimals) - if err != nil { - // If coin not found, try using the decimals from the RPC response - decimals = int32(balanceResult.Value.Decimals) + defer rows.Close() + for rows.Next() { + var mint string + var balance int64 + if err := rows.Scan(&mint, &balance); err == nil { + balances[mint] = balance + } + } + return balances, rows.Err() } - - // Scale required amount by decimals - scaledRequired := requiredAmount * int64(math.Pow10(int(decimals))) - - app.logger.Debug("checkSolanaWalletTokenAccess", - zap.String("wallet", solanaWallet), - zap.String("mint", tokenMint), - zap.Uint64("balance", balance), - zap.Int64("scaledRequired", scaledRequired), - zap.Bool("hasAccess", int64(balance) >= scaledRequired), - ) - - return int64(balance) >= scaledRequired, nil } diff --git a/api/solana_wallet_middleware_test.go b/api/solana_wallet_middleware_test.go new file mode 100644 index 00000000..075b84d5 --- /dev/null +++ b/api/solana_wallet_middleware_test.go @@ -0,0 +1,83 @@ +package api + +import ( + "crypto/ed25519" + "io" + "net/http/httptest" + "testing" + + "github.com/gofiber/fiber/v2" + "github.com/mr-tron/base58" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSolanaWalletMiddleware(t *testing.T) { + app := emptyTestApp(t) + + var capturedWallet string + testApp := fiber.New() + testApp.Get("/", app.solanaWalletMiddleware, func(c *fiber.Ctx) error { + capturedWallet = app.tryGetSolanaWallet(c) + return c.SendStatus(fiber.StatusOK) + }) + + // Generate a fresh ed25519 keypair for testing + pub, priv, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + wallet := base58.Encode(pub) + message := "Sign in to Audius" + sig := ed25519.Sign(priv, []byte(message)) + + t.Run("no headers is a no-op", func(t *testing.T) { + capturedWallet = "" + req := httptest.NewRequest("GET", "/", nil) + res, err := testApp.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusOK, res.StatusCode) + assert.Equal(t, "", capturedWallet) + }) + + t.Run("valid signature sets solanaWallet", func(t *testing.T) { + capturedWallet = "" + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Solana-Wallet", wallet) + req.Header.Set("X-Solana-Message", message) + req.Header.Set("X-Solana-Signature", base58.Encode(sig)) + res, err := testApp.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusOK, res.StatusCode) + assert.Equal(t, wallet, capturedWallet) + }) + + t.Run("partial headers returns 401", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Solana-Wallet", wallet) + // missing message and signature + res, err := testApp.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusUnauthorized, res.StatusCode) + }) + + t.Run("invalid signature returns 401", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Solana-Wallet", wallet) + req.Header.Set("X-Solana-Message", "wrong message") + req.Header.Set("X-Solana-Signature", base58.Encode(sig)) + res, err := testApp.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusUnauthorized, res.StatusCode) + }) + + t.Run("invalid wallet pubkey returns 401", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Solana-Wallet", "notavalidkey") + req.Header.Set("X-Solana-Message", message) + req.Header.Set("X-Solana-Signature", base58.Encode(sig)) + res, err := testApp.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusUnauthorized, res.StatusCode) + body, _ := io.ReadAll(res.Body) + assert.Contains(t, string(body), "invalid Solana wallet public key") + }) +} diff --git a/api/v1_track_stream.go b/api/v1_track_stream.go index 032968ff..c160d4fe 100644 --- a/api/v1_track_stream.go +++ b/api/v1_track_stream.go @@ -9,14 +9,22 @@ func (app *ApiServer) v1TrackStream(c *fiber.Ctx) error { myId := app.getMyId(c) trackId := c.Locals("trackId").(int) - tracks, err := app.queries.Tracks(c.Context(), dbv1.TracksParams{ + params := dbv1.TracksParams{ GetTracksParams: dbv1.GetTracksParams{ MyID: myId, Ids: []int32{int32(trackId)}, AuthedWallet: app.tryGetAuthedWallet(c), IncludeUnlisted: true, }, - }) + } + + // If a verified Solana wallet is present, attach an on-chain balance + // fetcher so GetBulkTrackAccess can check token gates via RPC. + if solWallet := app.tryGetSolanaWallet(c); solWallet != "" { + params.TokenBalanceFetcher = app.newSolanaWalletTokenBalanceFetcher(solWallet) + } + + tracks, err := app.queries.Tracks(c.Context(), params) if err != nil { return err } @@ -27,36 +35,10 @@ func (app *ApiServer) v1TrackStream(c *fiber.Ctx) error { track := tracks[0] - // If standard access check passes, use the pre-built stream URL if track.Access.Stream { return app.redirectToStream(c, track.Stream) } - // Fallback: check if a verified Solana wallet has sufficient token balance - solWallet := app.tryGetSolanaWallet(c) - if solWallet != "" && track.StreamConditions != nil && track.StreamConditions.TokenGate != nil { - hasAccess, err := app.checkSolanaWalletTokenAccess( - c.Context(), - solWallet, - track.StreamConditions.TokenGate.TokenMint, - track.StreamConditions.TokenGate.TokenAmount, - ) - if err == nil && hasAccess { - // Build a stream URL on the fly since the track didn't include one - // (stream URLs are only pre-built when standard access is granted) - stream, err := dbv1.BuildMediaLink( - track.TrackCid.String, - track.TrackID, - myId, - &dbv1.Id3Tags{Title: track.Title.String, Artist: track.User.Name.String}, - ) - if err != nil { - return err - } - return app.redirectToStream(c, stream) - } - } - return fiber.NewError(fiber.StatusForbidden, "track not streamable") } From 715bc9c5329c2faed3a28fbff3f55e808628e151 Mon Sep 17 00:00:00 2001 From: Raymond Jacobson Date: Thu, 26 Mar 2026 12:32:30 -0700 Subject: [PATCH 4/5] Simplify: pass Solana wallet directly to GetBulkTrackAccess Replace the TokenBalanceFetcher abstraction with a plain solanaWallet string parameter. When present, a parallel query against sol_token_account_balances runs alongside the existing sol_user_balances lookup, merging the higher balance per mint after both complete. Co-Authored-By: Claude Opus 4.6 --- api/dbv1/access.go | 57 ++++++++++++++++++++++++------------- api/dbv1/tracks.go | 4 +-- api/solana_wallet_access.go | 35 ----------------------- api/v1_track_stream.go | 6 ++-- 4 files changed, 42 insertions(+), 60 deletions(-) delete mode 100644 api/solana_wallet_access.go diff --git a/api/dbv1/access.go b/api/dbv1/access.go index 35a6ea72..61faffdb 100644 --- a/api/dbv1/access.go +++ b/api/dbv1/access.go @@ -13,10 +13,6 @@ type Access struct { Download bool `json:"download"` } -// TokenBalanceFetcher fetches on-chain token balances for a set of mints. -// Returns a map of mint → raw balance (smallest unit, already scaled by decimals). -type TokenBalanceFetcher func(ctx context.Context, mints []string) (map[string]int64, error) - func (q *Queries) GetPlaylistAccess( ctx context.Context, myId int32, @@ -90,7 +86,7 @@ func (q *Queries) GetBulkTrackAccess( myId int32, tracks []*GetTracksRow, users map[int32]*User, - tokenBalanceFetcher TokenBalanceFetcher, + solanaWallet string, ) (map[int32]Access, error) { // Initialize result map result := make(map[int32]Access) @@ -208,6 +204,7 @@ func (q *Queries) GetBulkTrackAccess( purchasedPlaylists := make(map[int32]bool) prevPurchasedPlaylists := make(map[int32]bool) userTokenBalances := make(map[string]int64) + walletTokenBalances := make(map[string]int64) coinDecimals := make(map[string]int32) g, ctx := errgroup.WithContext(ctx) @@ -285,26 +282,39 @@ func (q *Queries) GetBulkTrackAccess( // Query for token balances if len(tokenGateTokenMintsSlice) > 0 { - if tokenBalanceFetcher != nil { - // Use the provided fetcher (e.g. on-chain RPC for Solana wallet auth) - g.Go(func() error { - balances, err := tokenBalanceFetcher(ctx, tokenGateTokenMintsSlice) - if err != nil { - return err - } - for mint, balance := range balances { + // Look up balances from the per-user aggregate table + g.Go(func() error { + rows, err := q.db.Query(ctx, ` + SELECT mint, COALESCE(balance, 0) + FROM sol_user_balances + WHERE user_id = $1 + AND mint = ANY($2) + `, myId, tokenGateTokenMintsSlice) + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var mint string + var balance int64 + if err := rows.Scan(&mint, &balance); err == nil { userTokenBalances[mint] = balance } - return nil - }) - } else { + } + return rows.Err() + }) + + // If a Solana wallet was provided (e.g. signed via middleware), + // also check balances from the token account balances table. + // Results are merged after g.Wait() to avoid concurrent map writes. + if solanaWallet != "" { g.Go(func() error { rows, err := q.db.Query(ctx, ` SELECT mint, COALESCE(balance, 0) - FROM sol_user_balances - WHERE user_id = $1 + FROM sol_token_account_balances + WHERE owner = $1 AND mint = ANY($2) - `, myId, tokenGateTokenMintsSlice) + `, solanaWallet, tokenGateTokenMintsSlice) if err != nil { return err } @@ -313,7 +323,7 @@ func (q *Queries) GetBulkTrackAccess( var mint string var balance int64 if err := rows.Scan(&mint, &balance); err == nil { - userTokenBalances[mint] = balance + walletTokenBalances[mint] = balance } } return rows.Err() @@ -408,6 +418,13 @@ func (q *Queries) GetBulkTrackAccess( return nil, err } + // Merge wallet balances, keeping the higher value per mint + for mint, balance := range walletTokenBalances { + if balance > userTokenBalances[mint] { + userTokenBalances[mint] = balance + } + } + // Now determine access for each track for _, track := range tracks { if track == nil { diff --git a/api/dbv1/tracks.go b/api/dbv1/tracks.go index f6a97ec0..e7c8bda3 100644 --- a/api/dbv1/tracks.go +++ b/api/dbv1/tracks.go @@ -10,7 +10,7 @@ import ( type TracksParams struct { GetTracksParams - TokenBalanceFetcher TokenBalanceFetcher + SolanaWallet string } // Track is the standard track type containing all track data @@ -85,7 +85,7 @@ func (q *Queries) TracksKeyed(ctx context.Context, arg TracksParams) (map[int32] } // Get bulk access for all tracks - accessMap, err := q.GetBulkTrackAccess(ctx, arg.MyID.(int32), trackPtrs, userPtrMap, arg.TokenBalanceFetcher) + accessMap, err := q.GetBulkTrackAccess(ctx, arg.MyID.(int32), trackPtrs, userPtrMap, arg.SolanaWallet) if err != nil { return nil, err } diff --git a/api/solana_wallet_access.go b/api/solana_wallet_access.go deleted file mode 100644 index 63660fa2..00000000 --- a/api/solana_wallet_access.go +++ /dev/null @@ -1,35 +0,0 @@ -package api - -import ( - "context" - - "api.audius.co/api/dbv1" -) - -// newSolanaWalletTokenBalanceFetcher returns a TokenBalanceFetcher that looks up -// on-chain token balances for the given Solana wallet from the indexed -// sol_token_account_balances table. The returned balances are raw (smallest unit) -// matching the format used by GetBulkTrackAccess. -func (app *ApiServer) newSolanaWalletTokenBalanceFetcher(wallet string) dbv1.TokenBalanceFetcher { - return func(ctx context.Context, mints []string) (map[string]int64, error) { - balances := make(map[string]int64, len(mints)) - rows, err := app.pool.Query(ctx, ` - SELECT mint, COALESCE(balance, 0) - FROM sol_token_account_balances - WHERE owner = $1 - AND mint = ANY($2) - `, wallet, mints) - if err != nil { - return nil, err - } - defer rows.Close() - for rows.Next() { - var mint string - var balance int64 - if err := rows.Scan(&mint, &balance); err == nil { - balances[mint] = balance - } - } - return balances, rows.Err() - } -} diff --git a/api/v1_track_stream.go b/api/v1_track_stream.go index c160d4fe..64a28786 100644 --- a/api/v1_track_stream.go +++ b/api/v1_track_stream.go @@ -18,10 +18,10 @@ func (app *ApiServer) v1TrackStream(c *fiber.Ctx) error { }, } - // If a verified Solana wallet is present, attach an on-chain balance - // fetcher so GetBulkTrackAccess can check token gates via RPC. + // If a verified Solana wallet is present, pass it through so + // GetBulkTrackAccess can check token gate balances for it. if solWallet := app.tryGetSolanaWallet(c); solWallet != "" { - params.TokenBalanceFetcher = app.newSolanaWalletTokenBalanceFetcher(solWallet) + params.SolanaWallet = solWallet } tracks, err := app.queries.Tracks(c.Context(), params) From a2aae0073b8d7ad79057431ec3426bbaeb34d8b8 Mon Sep 17 00:00:00 2001 From: Raymond Jacobson Date: Thu, 26 Mar 2026 12:39:12 -0700 Subject: [PATCH 5/5] Sum user and wallet token balances instead of taking max Co-Authored-By: Claude Opus 4.6 --- api/dbv1/access.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/api/dbv1/access.go b/api/dbv1/access.go index 61faffdb..6dd8109a 100644 --- a/api/dbv1/access.go +++ b/api/dbv1/access.go @@ -418,11 +418,9 @@ func (q *Queries) GetBulkTrackAccess( return nil, err } - // Merge wallet balances, keeping the higher value per mint + // Merge wallet balances by summing with user balances for mint, balance := range walletTokenBalances { - if balance > userTokenBalances[mint] { - userTokenBalances[mint] = balance - } + userTokenBalances[mint] += balance } // Now determine access for each track