Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 46 additions & 29 deletions pkg/modelprovider/huggingface/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,36 +71,57 @@ func parseModelURL(modelURL string) (owner, repo string, err error) {
return owner, repo, nil
}

// checkHuggingFaceAuth checks if the user is authenticated with HuggingFace
func checkHuggingFaceAuth() error {
// Try to find the HF token
token := os.Getenv("HF_TOKEN")
if token != "" {
return nil
// tokenFilePaths returns the list of candidate token file paths to check,
// in priority order. It respects the HF_HOME environment variable and
// supports both legacy (~/.huggingface/token) and modern
// (~/.cache/huggingface/token) locations.
func tokenFilePaths() []string {
var paths []string

// If HF_HOME is set, check there first
if hfHome := os.Getenv("HF_HOME"); hfHome != "" {
paths = append(paths, filepath.Join(hfHome, "token"))
}

// Check if the token file exists
homeDir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to get user home directory: %w", err)
if err == nil {
// Modern location used by huggingface_hub >= 0.14 and the `hf` CLI
paths = append(paths, filepath.Join(homeDir, ".cache", "huggingface", "token"))
// Legacy location
paths = append(paths, filepath.Join(homeDir, ".huggingface", "token"))
}

tokenPath := filepath.Join(homeDir, ".huggingface", "token")
if _, err := os.Stat(tokenPath); err == nil {
return paths
}

// checkHuggingFaceAuth checks if the user is authenticated with HuggingFace
func checkHuggingFaceAuth() error {
// Try to find the HF token via environment variable
token := os.Getenv("HF_TOKEN")
if token != "" {
return nil
}

// Try using whoami command
if _, err := exec.LookPath("huggingface-cli"); err == nil {
cmd := exec.Command("huggingface-cli", "whoami")
cmd.Stdout = io.Discard
cmd.Stderr = io.Discard
if err := cmd.Run(); err == nil {
// Check token file paths (modern and legacy locations)
for _, tokenPath := range tokenFilePaths() {
if _, err := os.Stat(tokenPath); err == nil {
return nil
}
}

return fmt.Errorf("not authenticated with HuggingFace. Please run: huggingface-cli login")
// Try using whoami command with available CLI tool
for _, cli := range []string{"hf", "huggingface-cli"} {
if path, err := exec.LookPath(cli); err == nil {
cmd := exec.Command(path, "whoami")
cmd.Stdout = io.Discard
cmd.Stderr = io.Discard
if err := cmd.Run(); err == nil {
return nil
}
}
}

return fmt.Errorf("not authenticated with HuggingFace. Please run: hf auth login")
}

// getToken retrieves the HuggingFace token from environment or token file
Expand All @@ -111,17 +132,13 @@ func getToken() (string, error) {
return token, nil
}

// Then check the token file
homeDir, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("failed to get user home directory: %w", err)
}

tokenPath := filepath.Join(homeDir, ".huggingface", "token")
data, err := os.ReadFile(tokenPath)
if err != nil {
return "", fmt.Errorf("failed to read token file: %w", err)
// Check token file paths (modern and legacy locations)
for _, tokenPath := range tokenFilePaths() {
data, err := os.ReadFile(tokenPath)
if err == nil {
return strings.TrimSpace(string(data)), nil
}
}

return strings.TrimSpace(string(data)), nil
return "", fmt.Errorf("HuggingFace token not found. Please run: hf auth login")
}
262 changes: 262 additions & 0 deletions pkg/modelprovider/huggingface/downloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package huggingface

import (
"os"
"path/filepath"
"strings"
"testing"
)
Expand Down Expand Up @@ -166,3 +168,263 @@ func TestProvider_Name(t *testing.T) {
t.Errorf("Provider.Name() = %v, want %v", got, "huggingface")
}
}

func TestTokenFilePaths(t *testing.T) {
t.Run("without HF_HOME", func(t *testing.T) {
// Ensure HF_HOME is unset
t.Setenv("HF_HOME", "")

paths := tokenFilePaths()

// Should have exactly 2 paths: modern and legacy (no HF_HOME path)
if len(paths) != 2 {
t.Fatalf("tokenFilePaths() returned %d paths, want 2", len(paths))
}

// First should be the modern cache path
if !strings.Contains(paths[0], filepath.Join(".cache", "huggingface", "token")) {
t.Errorf("tokenFilePaths()[0] = %q, want path containing .cache/huggingface/token", paths[0])
}

// Second should be the legacy path
if !strings.Contains(paths[1], filepath.Join(".huggingface", "token")) {
t.Errorf("tokenFilePaths()[1] = %q, want path containing .huggingface/token", paths[1])
}
})

t.Run("with HF_HOME", func(t *testing.T) {
customDir := t.TempDir()
t.Setenv("HF_HOME", customDir)

paths := tokenFilePaths()

// Should have 3 paths: HF_HOME, modern, legacy
if len(paths) != 3 {
t.Fatalf("tokenFilePaths() returned %d paths, want 3", len(paths))
}

// First should be the HF_HOME path
expected := filepath.Join(customDir, "token")
if paths[0] != expected {
t.Errorf("tokenFilePaths()[0] = %q, want %q", paths[0], expected)
}
})
}

func TestCheckHuggingFaceAuth(t *testing.T) {
t.Run("authenticated via HF_TOKEN env var", func(t *testing.T) {
t.Setenv("HF_TOKEN", "hf_test_token_123")

err := checkHuggingFaceAuth()
if err != nil {
t.Errorf("checkHuggingFaceAuth() returned error %v, want nil", err)
}
})

t.Run("authenticated via token file at HF_HOME", func(t *testing.T) {
t.Setenv("HF_TOKEN", "")

tmpDir := t.TempDir()
t.Setenv("HF_HOME", tmpDir)
// Override HOME so the modern/legacy paths don't accidentally find a real token
t.Setenv("HOME", t.TempDir())

tokenPath := filepath.Join(tmpDir, "token")
if err := os.WriteFile(tokenPath, []byte("hf_test_token"), 0644); err != nil {
t.Fatal(err)
}

err := checkHuggingFaceAuth()
if err != nil {
t.Errorf("checkHuggingFaceAuth() returned error %v, want nil", err)
}
})

t.Run("authenticated via modern token path", func(t *testing.T) {
t.Setenv("HF_TOKEN", "")
t.Setenv("HF_HOME", "")

fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)

tokenPath := filepath.Join(fakeHome, ".cache", "huggingface", "token")
if err := os.MkdirAll(filepath.Dir(tokenPath), 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(tokenPath, []byte("hf_test_token"), 0644); err != nil {
t.Fatal(err)
}

err := checkHuggingFaceAuth()
if err != nil {
t.Errorf("checkHuggingFaceAuth() returned error %v, want nil", err)
}
})

t.Run("authenticated via legacy token path", func(t *testing.T) {
t.Setenv("HF_TOKEN", "")
t.Setenv("HF_HOME", "")

fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)

tokenPath := filepath.Join(fakeHome, ".huggingface", "token")
if err := os.MkdirAll(filepath.Dir(tokenPath), 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(tokenPath, []byte("hf_test_token"), 0644); err != nil {
t.Fatal(err)
}

err := checkHuggingFaceAuth()
if err != nil {
t.Errorf("checkHuggingFaceAuth() returned error %v, want nil", err)
}
})

t.Run("not authenticated", func(t *testing.T) {
t.Setenv("HF_TOKEN", "")
t.Setenv("HF_HOME", "")
t.Setenv("HOME", t.TempDir())
// Ensure no CLI tools are found by overriding PATH
t.Setenv("PATH", t.TempDir())

err := checkHuggingFaceAuth()
if err == nil {
t.Error("checkHuggingFaceAuth() returned nil, want error")
}
if err != nil && !strings.Contains(err.Error(), "not authenticated") {
t.Errorf("checkHuggingFaceAuth() error = %q, want error containing 'not authenticated'", err.Error())
}
})
}

func TestGetToken(t *testing.T) {
t.Run("token from HF_TOKEN env var", func(t *testing.T) {
t.Setenv("HF_TOKEN", "hf_env_token_abc")

token, err := getToken()
if err != nil {
t.Fatalf("getToken() returned error: %v", err)
}
if token != "hf_env_token_abc" {
t.Errorf("getToken() = %q, want %q", token, "hf_env_token_abc")
}
})

t.Run("token from HF_HOME file", func(t *testing.T) {
t.Setenv("HF_TOKEN", "")

tmpDir := t.TempDir()
t.Setenv("HF_HOME", tmpDir)
t.Setenv("HOME", t.TempDir())

tokenPath := filepath.Join(tmpDir, "token")
if err := os.WriteFile(tokenPath, []byte(" hf_file_token_xyz \n"), 0644); err != nil {
t.Fatal(err)
}

token, err := getToken()
if err != nil {
t.Fatalf("getToken() returned error: %v", err)
}
if token != "hf_file_token_xyz" {
t.Errorf("getToken() = %q, want %q (should be trimmed)", token, "hf_file_token_xyz")
}
})

t.Run("token from modern cache path", func(t *testing.T) {
t.Setenv("HF_TOKEN", "")
t.Setenv("HF_HOME", "")

fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)

tokenPath := filepath.Join(fakeHome, ".cache", "huggingface", "token")
if err := os.MkdirAll(filepath.Dir(tokenPath), 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(tokenPath, []byte("hf_modern_token"), 0644); err != nil {
t.Fatal(err)
}

token, err := getToken()
if err != nil {
t.Fatalf("getToken() returned error: %v", err)
}
if token != "hf_modern_token" {
t.Errorf("getToken() = %q, want %q", token, "hf_modern_token")
}
})

t.Run("token from legacy path", func(t *testing.T) {
t.Setenv("HF_TOKEN", "")
t.Setenv("HF_HOME", "")

fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)

tokenPath := filepath.Join(fakeHome, ".huggingface", "token")
if err := os.MkdirAll(filepath.Dir(tokenPath), 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(tokenPath, []byte("hf_legacy_token"), 0644); err != nil {
t.Fatal(err)
}

token, err := getToken()
if err != nil {
t.Fatalf("getToken() returned error: %v", err)
}
if token != "hf_legacy_token" {
t.Errorf("getToken() = %q, want %q", token, "hf_legacy_token")
}
})

t.Run("modern path takes priority over legacy", func(t *testing.T) {
t.Setenv("HF_TOKEN", "")
t.Setenv("HF_HOME", "")

fakeHome := t.TempDir()
t.Setenv("HOME", fakeHome)

// Create both modern and legacy token files
modernPath := filepath.Join(fakeHome, ".cache", "huggingface", "token")
if err := os.MkdirAll(filepath.Dir(modernPath), 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(modernPath, []byte("modern_token"), 0644); err != nil {
t.Fatal(err)
}

legacyPath := filepath.Join(fakeHome, ".huggingface", "token")
if err := os.MkdirAll(filepath.Dir(legacyPath), 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(legacyPath, []byte("legacy_token"), 0644); err != nil {
t.Fatal(err)
}

token, err := getToken()
if err != nil {
t.Fatalf("getToken() returned error: %v", err)
}
if token != "modern_token" {
t.Errorf("getToken() = %q, want %q (modern should take priority)", token, "modern_token")
}
})

t.Run("no token found", func(t *testing.T) {
t.Setenv("HF_TOKEN", "")
t.Setenv("HF_HOME", "")
t.Setenv("HOME", t.TempDir())

_, err := getToken()
if err == nil {
t.Error("getToken() returned nil error, want error")
}
if err != nil && !strings.Contains(err.Error(), "token not found") {
t.Errorf("getToken() error = %q, want error containing 'token not found'", err.Error())
}
})
}
Loading
Loading