diff --git a/api/auth_middleware.go b/api/auth_middleware.go index c86e6320..6f303edc 100644 --- a/api/auth_middleware.go +++ b/api/auth_middleware.go @@ -321,8 +321,11 @@ func (app *ApiServer) authMiddleware(c *fiber.Ctx) error { c.Locals("authedWallet", wallet) + // A valid PKCE access token already proves the user authorized this client + _, pkceAuthed := c.Locals("oauthScope").(string) + // Not authorized to act on behalf of myId - if myId != 0 && !app.isAuthorizedRequest(c.Context(), myId, wallet) { + if myId != 0 && !pkceAuthed && !app.isAuthorizedRequest(c.Context(), myId, wallet) { return fiber.NewError( fiber.StatusForbidden, fmt.Sprintf( diff --git a/api/dbv1/access.go b/api/dbv1/access.go index 253e773c..6dd8109a 100644 --- a/api/dbv1/access.go +++ b/api/dbv1/access.go @@ -86,6 +86,7 @@ func (q *Queries) GetBulkTrackAccess( myId int32, tracks []*GetTracksRow, users map[int32]*User, + solanaWallet string, ) (map[int32]Access, error) { // Initialize result map result := make(map[int32]Access) @@ -203,6 +204,7 @@ func (q *Queries) GetBulkTrackAccess( purchasedPlaylists := make(map[int32]bool) prevPurchasedPlaylists := make(map[int32]bool) userTokenBalances := make(map[string]int64) + walletTokenBalances := make(map[string]int64) coinDecimals := make(map[string]int32) g, ctx := errgroup.WithContext(ctx) @@ -280,6 +282,7 @@ func (q *Queries) GetBulkTrackAccess( // Query for token balances if len(tokenGateTokenMintsSlice) > 0 { + // Look up balances from the per-user aggregate table g.Go(func() error { rows, err := q.db.Query(ctx, ` SELECT mint, COALESCE(balance, 0) @@ -301,6 +304,32 @@ func (q *Queries) GetBulkTrackAccess( return rows.Err() }) + // If a Solana wallet was provided (e.g. signed via middleware), + // also check balances from the token account balances table. + // Results are merged after g.Wait() to avoid concurrent map writes. + if solanaWallet != "" { + g.Go(func() error { + rows, err := q.db.Query(ctx, ` + SELECT mint, COALESCE(balance, 0) + FROM sol_token_account_balances + WHERE owner = $1 + AND mint = ANY($2) + `, solanaWallet, tokenGateTokenMintsSlice) + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var mint string + var balance int64 + if err := rows.Scan(&mint, &balance); err == nil { + walletTokenBalances[mint] = balance + } + } + return rows.Err() + }) + } + // Query for coin decimals g.Go(func() error { rows, err := q.db.Query(ctx, ` @@ -389,6 +418,11 @@ func (q *Queries) GetBulkTrackAccess( return nil, err } + // Merge wallet balances by summing with user balances + for mint, balance := range walletTokenBalances { + userTokenBalances[mint] += balance + } + // Now determine access for each track for _, track := range tracks { if track == nil { diff --git a/api/dbv1/tracks.go b/api/dbv1/tracks.go index e2a0f57e..e7c8bda3 100644 --- a/api/dbv1/tracks.go +++ b/api/dbv1/tracks.go @@ -10,6 +10,7 @@ import ( type TracksParams struct { GetTracksParams + SolanaWallet string } // Track is the standard track type containing all track data @@ -84,7 +85,7 @@ func (q *Queries) TracksKeyed(ctx context.Context, arg TracksParams) (map[int32] } // Get bulk access for all tracks - accessMap, err := q.GetBulkTrackAccess(ctx, arg.MyID.(int32), trackPtrs, userPtrMap) + accessMap, err := q.GetBulkTrackAccess(ctx, arg.MyID.(int32), trackPtrs, userPtrMap, arg.SolanaWallet) if err != nil { return nil, err } diff --git a/api/server.go b/api/server.go index e0f7224a..55b4deef 100644 --- a/api/server.go +++ b/api/server.go @@ -364,6 +364,7 @@ func NewApiServer(config config.Config) *ApiServer { app.Use(app.isFullMiddleware) app.Use(app.resolveMyIdMiddleware) app.Use(app.authMiddleware) + app.Use(app.solanaWalletMiddleware) v1 := app.Group("/v1") v1Full := app.Group("/v1/full") diff --git a/api/solana_wallet_middleware.go b/api/solana_wallet_middleware.go new file mode 100644 index 00000000..15bd4c57 --- /dev/null +++ b/api/solana_wallet_middleware.go @@ -0,0 +1,61 @@ +package api + +import ( + "crypto/ed25519" + + "github.com/gofiber/fiber/v2" + "github.com/mr-tron/base58" + "go.uber.org/zap" +) + +// solanaWalletMiddleware verifies Solana wallet signatures from request headers. +// If the X-Solana-Wallet, X-Solana-Message, and X-Solana-Signature headers are +// present and valid, the verified wallet public key is stored in c.Locals("solanaWallet"). +// If headers are absent, the middleware is a no-op. If present but invalid, returns 401. +func (app *ApiServer) solanaWalletMiddleware(c *fiber.Ctx) error { + wallet := c.Get("X-Solana-Wallet") + message := c.Get("X-Solana-Message") + signature := c.Get("X-Solana-Signature") + + // No Solana headers — skip silently + if wallet == "" && message == "" && signature == "" { + return c.Next() + } + + // Partial headers — reject + if wallet == "" || message == "" || signature == "" { + return fiber.NewError(fiber.StatusUnauthorized, "incomplete Solana wallet headers: X-Solana-Wallet, X-Solana-Message, and X-Solana-Signature are all required") + } + + // Decode the base58 public key (32 bytes for ed25519) + pubkeyBytes, err := base58.Decode(wallet) + if err != nil || len(pubkeyBytes) != ed25519.PublicKeySize { + app.logger.Warn("solanaWalletMiddleware: invalid wallet public key", zap.String("wallet", wallet), zap.Error(err)) + return fiber.NewError(fiber.StatusUnauthorized, "invalid Solana wallet public key") + } + + // Decode the base58 signature (64 bytes for ed25519) + sigBytes, err := base58.Decode(signature) + if err != nil || len(sigBytes) != ed25519.SignatureSize { + app.logger.Warn("solanaWalletMiddleware: invalid signature", zap.Error(err)) + return fiber.NewError(fiber.StatusUnauthorized, "invalid Solana signature") + } + + // Verify the ed25519 signature + if !ed25519.Verify(pubkeyBytes, []byte(message), sigBytes) { + app.logger.Warn("solanaWalletMiddleware: signature verification failed", zap.String("wallet", wallet)) + return fiber.NewError(fiber.StatusUnauthorized, "Solana signature verification failed") + } + + app.logger.Debug("solanaWalletMiddleware: verified", zap.String("wallet", wallet)) + c.Locals("solanaWallet", wallet) + return c.Next() +} + +// tryGetSolanaWallet returns the verified Solana wallet from context, or "" if not set. +func (app *ApiServer) tryGetSolanaWallet(c *fiber.Ctx) string { + if w, ok := c.Locals("solanaWallet").(string); ok { + return w + } + return "" +} diff --git a/api/solana_wallet_middleware_test.go b/api/solana_wallet_middleware_test.go new file mode 100644 index 00000000..075b84d5 --- /dev/null +++ b/api/solana_wallet_middleware_test.go @@ -0,0 +1,83 @@ +package api + +import ( + "crypto/ed25519" + "io" + "net/http/httptest" + "testing" + + "github.com/gofiber/fiber/v2" + "github.com/mr-tron/base58" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSolanaWalletMiddleware(t *testing.T) { + app := emptyTestApp(t) + + var capturedWallet string + testApp := fiber.New() + testApp.Get("/", app.solanaWalletMiddleware, func(c *fiber.Ctx) error { + capturedWallet = app.tryGetSolanaWallet(c) + return c.SendStatus(fiber.StatusOK) + }) + + // Generate a fresh ed25519 keypair for testing + pub, priv, err := ed25519.GenerateKey(nil) + require.NoError(t, err) + wallet := base58.Encode(pub) + message := "Sign in to Audius" + sig := ed25519.Sign(priv, []byte(message)) + + t.Run("no headers is a no-op", func(t *testing.T) { + capturedWallet = "" + req := httptest.NewRequest("GET", "/", nil) + res, err := testApp.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusOK, res.StatusCode) + assert.Equal(t, "", capturedWallet) + }) + + t.Run("valid signature sets solanaWallet", func(t *testing.T) { + capturedWallet = "" + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Solana-Wallet", wallet) + req.Header.Set("X-Solana-Message", message) + req.Header.Set("X-Solana-Signature", base58.Encode(sig)) + res, err := testApp.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusOK, res.StatusCode) + assert.Equal(t, wallet, capturedWallet) + }) + + t.Run("partial headers returns 401", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Solana-Wallet", wallet) + // missing message and signature + res, err := testApp.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusUnauthorized, res.StatusCode) + }) + + t.Run("invalid signature returns 401", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Solana-Wallet", wallet) + req.Header.Set("X-Solana-Message", "wrong message") + req.Header.Set("X-Solana-Signature", base58.Encode(sig)) + res, err := testApp.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusUnauthorized, res.StatusCode) + }) + + t.Run("invalid wallet pubkey returns 401", func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Solana-Wallet", "notavalidkey") + req.Header.Set("X-Solana-Message", message) + req.Header.Set("X-Solana-Signature", base58.Encode(sig)) + res, err := testApp.Test(req, -1) + assert.NoError(t, err) + assert.Equal(t, fiber.StatusUnauthorized, res.StatusCode) + body, _ := io.ReadAll(res.Body) + assert.Contains(t, string(body), "invalid Solana wallet public key") + }) +} diff --git a/api/v1_track_stream.go b/api/v1_track_stream.go index bbae56dc..64a28786 100644 --- a/api/v1_track_stream.go +++ b/api/v1_track_stream.go @@ -9,14 +9,22 @@ func (app *ApiServer) v1TrackStream(c *fiber.Ctx) error { myId := app.getMyId(c) trackId := c.Locals("trackId").(int) - tracks, err := app.queries.Tracks(c.Context(), dbv1.TracksParams{ + params := dbv1.TracksParams{ GetTracksParams: dbv1.GetTracksParams{ MyID: myId, Ids: []int32{int32(trackId)}, AuthedWallet: app.tryGetAuthedWallet(c), IncludeUnlisted: true, }, - }) + } + + // If a verified Solana wallet is present, pass it through so + // GetBulkTrackAccess can check token gate balances for it. + if solWallet := app.tryGetSolanaWallet(c); solWallet != "" { + params.SolanaWallet = solWallet + } + + tracks, err := app.queries.Tracks(c.Context(), params) if err != nil { return err } @@ -26,11 +34,16 @@ func (app *ApiServer) v1TrackStream(c *fiber.Ctx) error { } track := tracks[0] - if !track.Access.Stream { - return fiber.NewError(fiber.StatusForbidden, "track not streamable") + + if track.Access.Stream { + return app.redirectToStream(c, track.Stream) } - streamURL := tryFindWorkingUrl(track.Stream) + return fiber.NewError(fiber.StatusForbidden, "track not streamable") +} + +func (app *ApiServer) redirectToStream(c *fiber.Ctx, stream *dbv1.MediaLink) error { + streamURL := tryFindWorkingUrl(stream) if skipPlayCount := c.Query("skip_play_count"); skipPlayCount != "" { q := streamURL.Query()