Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion api/auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,11 @@ func (app *ApiServer) authMiddleware(c *fiber.Ctx) error {

c.Locals("authedWallet", wallet)

// A valid PKCE access token already proves the user authorized this client
_, pkceAuthed := c.Locals("oauthScope").(string)

// Not authorized to act on behalf of myId
if myId != 0 && !app.isAuthorizedRequest(c.Context(), myId, wallet) {
if myId != 0 && !pkceAuthed && !app.isAuthorizedRequest(c.Context(), myId, wallet) {
return fiber.NewError(
fiber.StatusForbidden,
fmt.Sprintf(
Expand Down
34 changes: 34 additions & 0 deletions api/dbv1/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ func (q *Queries) GetBulkTrackAccess(
myId int32,
tracks []*GetTracksRow,
users map[int32]*User,
solanaWallet string,
) (map[int32]Access, error) {
// Initialize result map
result := make(map[int32]Access)
Expand Down Expand Up @@ -203,6 +204,7 @@ func (q *Queries) GetBulkTrackAccess(
purchasedPlaylists := make(map[int32]bool)
prevPurchasedPlaylists := make(map[int32]bool)
userTokenBalances := make(map[string]int64)
walletTokenBalances := make(map[string]int64)
coinDecimals := make(map[string]int32)

g, ctx := errgroup.WithContext(ctx)
Expand Down Expand Up @@ -280,6 +282,7 @@ func (q *Queries) GetBulkTrackAccess(

// Query for token balances
if len(tokenGateTokenMintsSlice) > 0 {
// Look up balances from the per-user aggregate table
g.Go(func() error {
rows, err := q.db.Query(ctx, `
SELECT mint, COALESCE(balance, 0)
Expand All @@ -301,6 +304,32 @@ func (q *Queries) GetBulkTrackAccess(
return rows.Err()
})

// If a Solana wallet was provided (e.g. signed via middleware),
// also check balances from the token account balances table.
// Results are merged after g.Wait() to avoid concurrent map writes.
if solanaWallet != "" {
g.Go(func() error {
rows, err := q.db.Query(ctx, `
SELECT mint, COALESCE(balance, 0)
FROM sol_token_account_balances
WHERE owner = $1
AND mint = ANY($2)
`, solanaWallet, tokenGateTokenMintsSlice)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var mint string
var balance int64
if err := rows.Scan(&mint, &balance); err == nil {
walletTokenBalances[mint] = balance
}
}
return rows.Err()
})
}

// Query for coin decimals
g.Go(func() error {
rows, err := q.db.Query(ctx, `
Expand Down Expand Up @@ -389,6 +418,11 @@ func (q *Queries) GetBulkTrackAccess(
return nil, err
}

// Merge wallet balances by summing with user balances
for mint, balance := range walletTokenBalances {
userTokenBalances[mint] += balance
}

// Now determine access for each track
for _, track := range tracks {
if track == nil {
Expand Down
3 changes: 2 additions & 1 deletion api/dbv1/tracks.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

type TracksParams struct {
GetTracksParams
SolanaWallet string
}

// Track is the standard track type containing all track data
Expand Down Expand Up @@ -84,7 +85,7 @@ func (q *Queries) TracksKeyed(ctx context.Context, arg TracksParams) (map[int32]
}

// Get bulk access for all tracks
accessMap, err := q.GetBulkTrackAccess(ctx, arg.MyID.(int32), trackPtrs, userPtrMap)
accessMap, err := q.GetBulkTrackAccess(ctx, arg.MyID.(int32), trackPtrs, userPtrMap, arg.SolanaWallet)
if err != nil {
return nil, err
}
Expand Down
1 change: 1 addition & 0 deletions api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ func NewApiServer(config config.Config) *ApiServer {
app.Use(app.isFullMiddleware)
app.Use(app.resolveMyIdMiddleware)
app.Use(app.authMiddleware)
app.Use(app.solanaWalletMiddleware)

v1 := app.Group("/v1")
v1Full := app.Group("/v1/full")
Expand Down
61 changes: 61 additions & 0 deletions api/solana_wallet_middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package api

import (
"crypto/ed25519"

"github.com/gofiber/fiber/v2"
"github.com/mr-tron/base58"
"go.uber.org/zap"
)

// solanaWalletMiddleware verifies Solana wallet signatures from request headers.
// If the X-Solana-Wallet, X-Solana-Message, and X-Solana-Signature headers are
// present and valid, the verified wallet public key is stored in c.Locals("solanaWallet").
// If headers are absent, the middleware is a no-op. If present but invalid, returns 401.
func (app *ApiServer) solanaWalletMiddleware(c *fiber.Ctx) error {
wallet := c.Get("X-Solana-Wallet")
message := c.Get("X-Solana-Message")
signature := c.Get("X-Solana-Signature")

// No Solana headers — skip silently
if wallet == "" && message == "" && signature == "" {
return c.Next()
}

// Partial headers — reject
if wallet == "" || message == "" || signature == "" {
return fiber.NewError(fiber.StatusUnauthorized, "incomplete Solana wallet headers: X-Solana-Wallet, X-Solana-Message, and X-Solana-Signature are all required")
}

// Decode the base58 public key (32 bytes for ed25519)
pubkeyBytes, err := base58.Decode(wallet)
if err != nil || len(pubkeyBytes) != ed25519.PublicKeySize {
app.logger.Warn("solanaWalletMiddleware: invalid wallet public key", zap.String("wallet", wallet), zap.Error(err))
return fiber.NewError(fiber.StatusUnauthorized, "invalid Solana wallet public key")
}

// Decode the base58 signature (64 bytes for ed25519)
sigBytes, err := base58.Decode(signature)
if err != nil || len(sigBytes) != ed25519.SignatureSize {
app.logger.Warn("solanaWalletMiddleware: invalid signature", zap.Error(err))
return fiber.NewError(fiber.StatusUnauthorized, "invalid Solana signature")
}

// Verify the ed25519 signature
if !ed25519.Verify(pubkeyBytes, []byte(message), sigBytes) {
app.logger.Warn("solanaWalletMiddleware: signature verification failed", zap.String("wallet", wallet))
return fiber.NewError(fiber.StatusUnauthorized, "Solana signature verification failed")
}

app.logger.Debug("solanaWalletMiddleware: verified", zap.String("wallet", wallet))
c.Locals("solanaWallet", wallet)
return c.Next()
}

// tryGetSolanaWallet returns the verified Solana wallet from context, or "" if not set.
func (app *ApiServer) tryGetSolanaWallet(c *fiber.Ctx) string {
if w, ok := c.Locals("solanaWallet").(string); ok {
return w
}
return ""
}
83 changes: 83 additions & 0 deletions api/solana_wallet_middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package api

import (
"crypto/ed25519"
"io"
"net/http/httptest"
"testing"

"github.com/gofiber/fiber/v2"
"github.com/mr-tron/base58"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestSolanaWalletMiddleware(t *testing.T) {
app := emptyTestApp(t)

var capturedWallet string
testApp := fiber.New()
testApp.Get("/", app.solanaWalletMiddleware, func(c *fiber.Ctx) error {
capturedWallet = app.tryGetSolanaWallet(c)
return c.SendStatus(fiber.StatusOK)
})

// Generate a fresh ed25519 keypair for testing
pub, priv, err := ed25519.GenerateKey(nil)
require.NoError(t, err)
wallet := base58.Encode(pub)
message := "Sign in to Audius"
sig := ed25519.Sign(priv, []byte(message))

t.Run("no headers is a no-op", func(t *testing.T) {
capturedWallet = ""
req := httptest.NewRequest("GET", "/", nil)
res, err := testApp.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusOK, res.StatusCode)
assert.Equal(t, "", capturedWallet)
})

t.Run("valid signature sets solanaWallet", func(t *testing.T) {
capturedWallet = ""
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("X-Solana-Wallet", wallet)
req.Header.Set("X-Solana-Message", message)
req.Header.Set("X-Solana-Signature", base58.Encode(sig))
res, err := testApp.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusOK, res.StatusCode)
assert.Equal(t, wallet, capturedWallet)
})

t.Run("partial headers returns 401", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("X-Solana-Wallet", wallet)
// missing message and signature
res, err := testApp.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusUnauthorized, res.StatusCode)
})

t.Run("invalid signature returns 401", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("X-Solana-Wallet", wallet)
req.Header.Set("X-Solana-Message", "wrong message")
req.Header.Set("X-Solana-Signature", base58.Encode(sig))
res, err := testApp.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusUnauthorized, res.StatusCode)
})

t.Run("invalid wallet pubkey returns 401", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("X-Solana-Wallet", "notavalidkey")
req.Header.Set("X-Solana-Message", message)
req.Header.Set("X-Solana-Signature", base58.Encode(sig))
res, err := testApp.Test(req, -1)
assert.NoError(t, err)
assert.Equal(t, fiber.StatusUnauthorized, res.StatusCode)
body, _ := io.ReadAll(res.Body)
assert.Contains(t, string(body), "invalid Solana wallet public key")
})
}
23 changes: 18 additions & 5 deletions api/v1_track_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,22 @@ func (app *ApiServer) v1TrackStream(c *fiber.Ctx) error {
myId := app.getMyId(c)
trackId := c.Locals("trackId").(int)

tracks, err := app.queries.Tracks(c.Context(), dbv1.TracksParams{
params := dbv1.TracksParams{
GetTracksParams: dbv1.GetTracksParams{
MyID: myId,
Ids: []int32{int32(trackId)},
AuthedWallet: app.tryGetAuthedWallet(c),
IncludeUnlisted: true,
},
})
}

// If a verified Solana wallet is present, pass it through so
// GetBulkTrackAccess can check token gate balances for it.
if solWallet := app.tryGetSolanaWallet(c); solWallet != "" {
params.SolanaWallet = solWallet
}

tracks, err := app.queries.Tracks(c.Context(), params)
if err != nil {
return err
}
Expand All @@ -26,11 +34,16 @@ func (app *ApiServer) v1TrackStream(c *fiber.Ctx) error {
}

track := tracks[0]
if !track.Access.Stream {
return fiber.NewError(fiber.StatusForbidden, "track not streamable")

if track.Access.Stream {
return app.redirectToStream(c, track.Stream)
}

streamURL := tryFindWorkingUrl(track.Stream)
return fiber.NewError(fiber.StatusForbidden, "track not streamable")
}

func (app *ApiServer) redirectToStream(c *fiber.Ctx, stream *dbv1.MediaLink) error {
streamURL := tryFindWorkingUrl(stream)

if skipPlayCount := c.Query("skip_play_count"); skipPlayCount != "" {
q := streamURL.Query()
Expand Down
Loading