diff --git a/pkg/modelprovider/huggingface/downloader.go b/pkg/modelprovider/huggingface/downloader.go index 3f5616e..0250333 100644 --- a/pkg/modelprovider/huggingface/downloader.go +++ b/pkg/modelprovider/huggingface/downloader.go @@ -71,36 +71,57 @@ func parseModelURL(modelURL string) (owner, repo string, err error) { return owner, repo, nil } -// checkHuggingFaceAuth checks if the user is authenticated with HuggingFace -func checkHuggingFaceAuth() error { - // Try to find the HF token - token := os.Getenv("HF_TOKEN") - if token != "" { - return nil +// tokenFilePaths returns the list of candidate token file paths to check, +// in priority order. It respects the HF_HOME environment variable and +// supports both legacy (~/.huggingface/token) and modern +// (~/.cache/huggingface/token) locations. +func tokenFilePaths() []string { + var paths []string + + // If HF_HOME is set, check there first + if hfHome := os.Getenv("HF_HOME"); hfHome != "" { + paths = append(paths, filepath.Join(hfHome, "token")) } - // Check if the token file exists homeDir, err := os.UserHomeDir() - if err != nil { - return fmt.Errorf("failed to get user home directory: %w", err) + if err == nil { + // Modern location used by huggingface_hub >= 0.14 and the `hf` CLI + paths = append(paths, filepath.Join(homeDir, ".cache", "huggingface", "token")) + // Legacy location + paths = append(paths, filepath.Join(homeDir, ".huggingface", "token")) } - tokenPath := filepath.Join(homeDir, ".huggingface", "token") - if _, err := os.Stat(tokenPath); err == nil { + return paths +} + +// checkHuggingFaceAuth checks if the user is authenticated with HuggingFace +func checkHuggingFaceAuth() error { + // Try to find the HF token via environment variable + token := os.Getenv("HF_TOKEN") + if token != "" { return nil } - // Try using whoami command - if _, err := exec.LookPath("huggingface-cli"); err == nil { - cmd := exec.Command("huggingface-cli", "whoami") - cmd.Stdout = io.Discard - cmd.Stderr = io.Discard - if err := cmd.Run(); err == nil { + // Check token file paths (modern and legacy locations) + for _, tokenPath := range tokenFilePaths() { + if _, err := os.Stat(tokenPath); err == nil { return nil } } - return fmt.Errorf("not authenticated with HuggingFace. Please run: huggingface-cli login") + // Try using whoami command with available CLI tool + for _, cli := range []string{"hf", "huggingface-cli"} { + if path, err := exec.LookPath(cli); err == nil { + cmd := exec.Command(path, "whoami") + cmd.Stdout = io.Discard + cmd.Stderr = io.Discard + if err := cmd.Run(); err == nil { + return nil + } + } + } + + return fmt.Errorf("not authenticated with HuggingFace. Please run: hf auth login") } // getToken retrieves the HuggingFace token from environment or token file @@ -111,17 +132,13 @@ func getToken() (string, error) { return token, nil } - // Then check the token file - homeDir, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get user home directory: %w", err) - } - - tokenPath := filepath.Join(homeDir, ".huggingface", "token") - data, err := os.ReadFile(tokenPath) - if err != nil { - return "", fmt.Errorf("failed to read token file: %w", err) + // Check token file paths (modern and legacy locations) + for _, tokenPath := range tokenFilePaths() { + data, err := os.ReadFile(tokenPath) + if err == nil { + return strings.TrimSpace(string(data)), nil + } } - return strings.TrimSpace(string(data)), nil + return "", fmt.Errorf("HuggingFace token not found. Please run: hf auth login") } diff --git a/pkg/modelprovider/huggingface/downloader_test.go b/pkg/modelprovider/huggingface/downloader_test.go index f7c385c..5951877 100644 --- a/pkg/modelprovider/huggingface/downloader_test.go +++ b/pkg/modelprovider/huggingface/downloader_test.go @@ -17,6 +17,8 @@ package huggingface import ( + "os" + "path/filepath" "strings" "testing" ) @@ -166,3 +168,263 @@ func TestProvider_Name(t *testing.T) { t.Errorf("Provider.Name() = %v, want %v", got, "huggingface") } } + +func TestTokenFilePaths(t *testing.T) { + t.Run("without HF_HOME", func(t *testing.T) { + // Ensure HF_HOME is unset + t.Setenv("HF_HOME", "") + + paths := tokenFilePaths() + + // Should have exactly 2 paths: modern and legacy (no HF_HOME path) + if len(paths) != 2 { + t.Fatalf("tokenFilePaths() returned %d paths, want 2", len(paths)) + } + + // First should be the modern cache path + if !strings.Contains(paths[0], filepath.Join(".cache", "huggingface", "token")) { + t.Errorf("tokenFilePaths()[0] = %q, want path containing .cache/huggingface/token", paths[0]) + } + + // Second should be the legacy path + if !strings.Contains(paths[1], filepath.Join(".huggingface", "token")) { + t.Errorf("tokenFilePaths()[1] = %q, want path containing .huggingface/token", paths[1]) + } + }) + + t.Run("with HF_HOME", func(t *testing.T) { + customDir := t.TempDir() + t.Setenv("HF_HOME", customDir) + + paths := tokenFilePaths() + + // Should have 3 paths: HF_HOME, modern, legacy + if len(paths) != 3 { + t.Fatalf("tokenFilePaths() returned %d paths, want 3", len(paths)) + } + + // First should be the HF_HOME path + expected := filepath.Join(customDir, "token") + if paths[0] != expected { + t.Errorf("tokenFilePaths()[0] = %q, want %q", paths[0], expected) + } + }) +} + +func TestCheckHuggingFaceAuth(t *testing.T) { + t.Run("authenticated via HF_TOKEN env var", func(t *testing.T) { + t.Setenv("HF_TOKEN", "hf_test_token_123") + + err := checkHuggingFaceAuth() + if err != nil { + t.Errorf("checkHuggingFaceAuth() returned error %v, want nil", err) + } + }) + + t.Run("authenticated via token file at HF_HOME", func(t *testing.T) { + t.Setenv("HF_TOKEN", "") + + tmpDir := t.TempDir() + t.Setenv("HF_HOME", tmpDir) + // Override HOME so the modern/legacy paths don't accidentally find a real token + t.Setenv("HOME", t.TempDir()) + + tokenPath := filepath.Join(tmpDir, "token") + if err := os.WriteFile(tokenPath, []byte("hf_test_token"), 0644); err != nil { + t.Fatal(err) + } + + err := checkHuggingFaceAuth() + if err != nil { + t.Errorf("checkHuggingFaceAuth() returned error %v, want nil", err) + } + }) + + t.Run("authenticated via modern token path", func(t *testing.T) { + t.Setenv("HF_TOKEN", "") + t.Setenv("HF_HOME", "") + + fakeHome := t.TempDir() + t.Setenv("HOME", fakeHome) + + tokenPath := filepath.Join(fakeHome, ".cache", "huggingface", "token") + if err := os.MkdirAll(filepath.Dir(tokenPath), 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(tokenPath, []byte("hf_test_token"), 0644); err != nil { + t.Fatal(err) + } + + err := checkHuggingFaceAuth() + if err != nil { + t.Errorf("checkHuggingFaceAuth() returned error %v, want nil", err) + } + }) + + t.Run("authenticated via legacy token path", func(t *testing.T) { + t.Setenv("HF_TOKEN", "") + t.Setenv("HF_HOME", "") + + fakeHome := t.TempDir() + t.Setenv("HOME", fakeHome) + + tokenPath := filepath.Join(fakeHome, ".huggingface", "token") + if err := os.MkdirAll(filepath.Dir(tokenPath), 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(tokenPath, []byte("hf_test_token"), 0644); err != nil { + t.Fatal(err) + } + + err := checkHuggingFaceAuth() + if err != nil { + t.Errorf("checkHuggingFaceAuth() returned error %v, want nil", err) + } + }) + + t.Run("not authenticated", func(t *testing.T) { + t.Setenv("HF_TOKEN", "") + t.Setenv("HF_HOME", "") + t.Setenv("HOME", t.TempDir()) + // Ensure no CLI tools are found by overriding PATH + t.Setenv("PATH", t.TempDir()) + + err := checkHuggingFaceAuth() + if err == nil { + t.Error("checkHuggingFaceAuth() returned nil, want error") + } + if err != nil && !strings.Contains(err.Error(), "not authenticated") { + t.Errorf("checkHuggingFaceAuth() error = %q, want error containing 'not authenticated'", err.Error()) + } + }) +} + +func TestGetToken(t *testing.T) { + t.Run("token from HF_TOKEN env var", func(t *testing.T) { + t.Setenv("HF_TOKEN", "hf_env_token_abc") + + token, err := getToken() + if err != nil { + t.Fatalf("getToken() returned error: %v", err) + } + if token != "hf_env_token_abc" { + t.Errorf("getToken() = %q, want %q", token, "hf_env_token_abc") + } + }) + + t.Run("token from HF_HOME file", func(t *testing.T) { + t.Setenv("HF_TOKEN", "") + + tmpDir := t.TempDir() + t.Setenv("HF_HOME", tmpDir) + t.Setenv("HOME", t.TempDir()) + + tokenPath := filepath.Join(tmpDir, "token") + if err := os.WriteFile(tokenPath, []byte(" hf_file_token_xyz \n"), 0644); err != nil { + t.Fatal(err) + } + + token, err := getToken() + if err != nil { + t.Fatalf("getToken() returned error: %v", err) + } + if token != "hf_file_token_xyz" { + t.Errorf("getToken() = %q, want %q (should be trimmed)", token, "hf_file_token_xyz") + } + }) + + t.Run("token from modern cache path", func(t *testing.T) { + t.Setenv("HF_TOKEN", "") + t.Setenv("HF_HOME", "") + + fakeHome := t.TempDir() + t.Setenv("HOME", fakeHome) + + tokenPath := filepath.Join(fakeHome, ".cache", "huggingface", "token") + if err := os.MkdirAll(filepath.Dir(tokenPath), 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(tokenPath, []byte("hf_modern_token"), 0644); err != nil { + t.Fatal(err) + } + + token, err := getToken() + if err != nil { + t.Fatalf("getToken() returned error: %v", err) + } + if token != "hf_modern_token" { + t.Errorf("getToken() = %q, want %q", token, "hf_modern_token") + } + }) + + t.Run("token from legacy path", func(t *testing.T) { + t.Setenv("HF_TOKEN", "") + t.Setenv("HF_HOME", "") + + fakeHome := t.TempDir() + t.Setenv("HOME", fakeHome) + + tokenPath := filepath.Join(fakeHome, ".huggingface", "token") + if err := os.MkdirAll(filepath.Dir(tokenPath), 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(tokenPath, []byte("hf_legacy_token"), 0644); err != nil { + t.Fatal(err) + } + + token, err := getToken() + if err != nil { + t.Fatalf("getToken() returned error: %v", err) + } + if token != "hf_legacy_token" { + t.Errorf("getToken() = %q, want %q", token, "hf_legacy_token") + } + }) + + t.Run("modern path takes priority over legacy", func(t *testing.T) { + t.Setenv("HF_TOKEN", "") + t.Setenv("HF_HOME", "") + + fakeHome := t.TempDir() + t.Setenv("HOME", fakeHome) + + // Create both modern and legacy token files + modernPath := filepath.Join(fakeHome, ".cache", "huggingface", "token") + if err := os.MkdirAll(filepath.Dir(modernPath), 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(modernPath, []byte("modern_token"), 0644); err != nil { + t.Fatal(err) + } + + legacyPath := filepath.Join(fakeHome, ".huggingface", "token") + if err := os.MkdirAll(filepath.Dir(legacyPath), 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(legacyPath, []byte("legacy_token"), 0644); err != nil { + t.Fatal(err) + } + + token, err := getToken() + if err != nil { + t.Fatalf("getToken() returned error: %v", err) + } + if token != "modern_token" { + t.Errorf("getToken() = %q, want %q (modern should take priority)", token, "modern_token") + } + }) + + t.Run("no token found", func(t *testing.T) { + t.Setenv("HF_TOKEN", "") + t.Setenv("HF_HOME", "") + t.Setenv("HOME", t.TempDir()) + + _, err := getToken() + if err == nil { + t.Error("getToken() returned nil error, want error") + } + if err != nil && !strings.Contains(err.Error(), "token not found") { + t.Errorf("getToken() error = %q, want error containing 'token not found'", err.Error()) + } + }) +} diff --git a/pkg/modelprovider/huggingface/provider.go b/pkg/modelprovider/huggingface/provider.go index a4cc769..56ff05d 100644 --- a/pkg/modelprovider/huggingface/provider.go +++ b/pkg/modelprovider/huggingface/provider.go @@ -48,7 +48,19 @@ func (p *Provider) SupportsURL(url string) bool { return strings.Contains(url, "huggingface.co") } -// DownloadModel downloads a model from HuggingFace using the huggingface-cli +// findHFCLI returns the path of the first available HuggingFace CLI tool +// and whether it is the legacy huggingface-cli. +// It checks for the modern `hf` CLI first, then falls back to `huggingface-cli`. +func findHFCLI() (path string, isLegacy bool, err error) { + for _, cli := range []string{"hf", "huggingface-cli"} { + if path, err := exec.LookPath(cli); err == nil { + return path, cli == "huggingface-cli", nil + } + } + return "", false, fmt.Errorf("neither 'hf' nor 'huggingface-cli' found in PATH. Please install the HuggingFace CLI: pip install huggingface_hub[cli]") +} + +// DownloadModel downloads a model from HuggingFace using the HuggingFace CLI func (p *Provider) DownloadModel(ctx context.Context, modelURL, destDir string) (string, error) { owner, repo, err := parseModelURL(modelURL) if err != nil { @@ -57,9 +69,10 @@ func (p *Provider) DownloadModel(ctx context.Context, modelURL, destDir string) repoID := fmt.Sprintf("%s/%s", owner, repo) - // Check if huggingface-cli is available - if _, err := exec.LookPath("huggingface-cli"); err != nil { - return "", fmt.Errorf("huggingface-cli not found in PATH. Please install it using: pip install huggingface_hub[cli]") + // Find available HuggingFace CLI tool (hf or huggingface-cli) + cliPath, isLegacy, err := findHFCLI() + if err != nil { + return "", err } // Create destination directory if it doesn't exist @@ -70,15 +83,21 @@ func (p *Provider) DownloadModel(ctx context.Context, modelURL, destDir string) // Construct the download path downloadPath := filepath.Join(destDir, repo) - // Use huggingface-cli to download the model - // The --local-dir-use-symlinks=False flag ensures files are copied, not symlinked - cmd := exec.CommandContext(ctx, "huggingface-cli", "download", repoID, "--local-dir", downloadPath, "--local-dir-use-symlinks", "False") + // Build CLI arguments. The legacy huggingface-cli supports + // --local-dir-use-symlinks to ensure files are copied, not symlinked. + // The modern hf CLI removed that flag; --local-dir alone is sufficient. + args := []string{"download", repoID, "--local-dir", downloadPath} + if isLegacy { + args = append(args, "--local-dir-use-symlinks", "False") + } + + cmd := exec.CommandContext(ctx, cliPath, args...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Run(); err != nil { - return "", fmt.Errorf("failed to download model using huggingface-cli: %w", err) + return "", fmt.Errorf("failed to download model using %s: %w", filepath.Base(cliPath), err) } return downloadPath, nil