diff --git a/gsma_dataset_creation/argilla_cli.py b/gsma_dataset_creation/argilla_cli.py index d057caf..635966d 100644 --- a/gsma_dataset_creation/argilla_cli.py +++ b/gsma_dataset_creation/argilla_cli.py @@ -8,6 +8,7 @@ import typer from loguru import logger +from gsma_dataset_creation.clients.argilla_client import get_argilla_client app = typer.Typer(help="Argilla dataset annotation commands") @@ -53,6 +54,9 @@ def upload( api_key: str | None = typer.Option( None, "--api-key", help="Argilla API key (defaults to ARGILLA_API_KEY env var)" ), + hf_token: str | None = typer.Option( + None, "--hf-token", help="HuggingFace API token for private HF spaces (defaults to HF_TOKEN env var)" + ), logger_level: str = typer.Option("INFO", "--logger-level", help="Logging level"), ) -> None: """ @@ -91,6 +95,7 @@ def upload( labels=labels, api_url=api_url, api_key=api_key, + hf_token=hf_token ) logger.info(f"✅ Successfully uploaded dataset: {dataset_name}") except Exception as e: @@ -113,6 +118,9 @@ def delete( api_key: str | None = typer.Option( None, "--api-key", help="Argilla API key (defaults to ARGILLA_API_KEY env var)" ), + hf_token: str | None = typer.Option( + None, "--hf-token", help="HuggingFace API token for private HF spaces (defaults to HF_TOKEN env var)" + ), logger_level: str = typer.Option("INFO", "--logger-level", help="Logging level"), ) -> None: """ @@ -166,6 +174,7 @@ def delete( workspace=workspace, api_url=api_url, api_key=api_key, + hf_token=hf_token ) logger.info("✅ Successfully deleted dataset") except Exception as e: @@ -188,6 +197,9 @@ def download( api_key: str | None = typer.Option( None, "--api-key", help="Argilla API key (defaults to ARGILLA_API_KEY env var)" ), + hf_token: str | None = typer.Option( + None, "--hf-token", help="HuggingFace API token for private HF spaces (defaults to HF_TOKEN env var)" + ), logger_level: str = typer.Option("INFO", "--logger-level", help="Logging level"), ) -> None: """ @@ -227,6 +239,7 @@ def download( workspace=workspace, api_url=api_url, api_key=api_key, + hf_token=hf_token, ) logger.info(f"✅ Successfully downloaded dataset: {dataset_name}") except Exception as e: @@ -277,6 +290,9 @@ def upload_by_subgroup( api_key: str | None = typer.Option( None, "--api-key", help="Argilla API key (defaults to ARGILLA_API_KEY env var)" ), + hf_token: str | None = typer.Option( + None, "--hf-token", help="HuggingFace API token for private HF spaces (defaults to HF_TOKEN env var)" + ), dataset_name_prefix: str = typer.Option( "gsma_subgroup", "--dataset-name-prefix", @@ -330,7 +346,7 @@ def upload_by_subgroup( if dataset_path is not None and dataset_repo is not None: logger.error("❌ Error: Cannot specify both --dataset-path and --dataset-repo") raise typer.Exit(code=2) - + if dataset_path and not dataset_path.exists(): logger.error(f"❌ Dataset path not found: {dataset_path}") raise typer.Exit(code=2) @@ -362,6 +378,7 @@ def upload_by_subgroup( workspace_name=effective_workspace, api_url=api_url, api_key=api_key, + hf_token=hf_token, ) if dataset_name is None: @@ -374,6 +391,8 @@ def upload_by_subgroup( logger.info(f"🔒 Password: {effective_working_group.lower()}-gsma") except Exception as e: logger.error(f"❌ Upload failed: {e}") + import traceback + logger.error(traceback.format_exc()) raise typer.Exit(code=1) from e @@ -392,6 +411,9 @@ def delete_workspace( api_key: str | None = typer.Option( None, "--api-key", help="Argilla API key (defaults to ARGILLA_API_KEY env var)" ), + hf_token: str | None = typer.Option( + None, "--hf-token", help="HuggingFace API token for private HF spaces (defaults to HF_TOKEN env var)" + ), logger_level: str = typer.Option("INFO", "--logger-level", help="Logging level"), ) -> None: """ @@ -414,16 +436,6 @@ def delete_workspace( logger.error("Install with: pip install argilla") raise typer.Exit(code=1) from e - # Get API credentials - api_url = api_url or os.getenv("ARGILLA_API_URL") - api_key = api_key or os.getenv("ARGILLA_API_KEY") - - if not api_url or not api_key: - logger.error( - "❌ Argilla credentials not found. Set ARGILLA_API_URL and ARGILLA_API_KEY environment variables or use --api-url and --api-key options." - ) - raise typer.Exit(code=1) - # Confirmation prompt if not force: confirm = typer.confirm( @@ -434,8 +446,8 @@ def delete_workspace( raise typer.Exit(code=0) try: - # Create client - client = rg.Argilla(api_url=api_url, api_key=api_key) + # Initialising an Argilla client + client = get_argilla_client(api_url, api_key, hf_token) # Get workspace workspace = client.workspaces(name=workspace_name) @@ -512,6 +524,9 @@ def add_user( api_key: str | None = typer.Option( None, "--api-key", help="Argilla API key (defaults to ARGILLA_API_KEY env var)" ), + hf_token: str | None = typer.Option( + None, "--hf-token", help="HuggingFace API token for private HF spaces (defaults to HF_TOKEN env var)" + ), logger_level: str = typer.Option("INFO", "--logger-level", help="Logging level"), ) -> None: """ @@ -567,19 +582,9 @@ def add_user( logger.error("Install with: pip install argilla") raise typer.Exit(code=1) from e - # Get API credentials - api_url = api_url or os.getenv("ARGILLA_API_URL") - api_key = api_key or os.getenv("ARGILLA_API_KEY") - - if not api_url or not api_key: - logger.error( - "❌ Argilla credentials not found. Set ARGILLA_API_URL and ARGILLA_API_KEY environment variables or use --api-url and --api-key options." - ) - raise typer.Exit(code=1) - try: - # Create client - client = rg.Argilla(api_url=api_url, api_key=api_key) + # Initialising an Argilla client + client = get_argilla_client(api_url, api_key, hf_token) # Validate all workspaces exist workspace_objs = [] @@ -676,6 +681,9 @@ def add_to_workspace( api_key: str | None = typer.Option( None, "--api-key", help="Argilla API key (defaults to ARGILLA_API_KEY env var)" ), + hf_token: str | None = typer.Option( + None, "--hf-token", help="HuggingFace API token for private HF spaces (defaults to HF_TOKEN env var)" + ), logger_level: str = typer.Option("INFO", "--logger-level", help="Logging level"), ) -> None: """ @@ -706,19 +714,9 @@ def add_to_workspace( logger.error("Install with: pip install argilla") raise typer.Exit(code=1) from e - # Get API credentials - api_url = api_url or os.getenv("ARGILLA_API_URL") - api_key = api_key or os.getenv("ARGILLA_API_KEY") - - if not api_url or not api_key: - logger.error( - "❌ Argilla credentials not found. Set ARGILLA_API_URL and ARGILLA_API_KEY environment variables or use --api-url and --api-key options." - ) - raise typer.Exit(code=1) - try: - # Create client - client = rg.Argilla(api_url=api_url, api_key=api_key) + # Initialising an Argilla client + client = get_argilla_client(api_url, api_key, hf_token) # Check if user exists user = client.users(username=username) @@ -778,6 +776,9 @@ def add_users( api_key: str | None = typer.Option( None, "--api-key", help="Argilla API key (defaults to ARGILLA_API_KEY env var)" ), + hf_token: str | None = typer.Option( + None, "--hf-token", help="HuggingFace API token for private HF spaces (defaults to HF_TOKEN env var)" + ), logger_level: str = typer.Option("INFO", "--logger-level", help="Logging level"), ) -> None: """ @@ -815,19 +816,9 @@ def add_users( logger.error("Install with: pip install argilla") raise typer.Exit(code=1) from e - # Get API credentials - api_url = api_url or os.getenv("ARGILLA_API_URL") - api_key = api_key or os.getenv("ARGILLA_API_KEY") - - if not api_url or not api_key: - logger.error( - "❌ Argilla credentials not found. Set ARGILLA_API_URL and ARGILLA_API_KEY environment variables or use --api-url and --api-key options." - ) - raise typer.Exit(code=1) - try: - # Create client - client = rg.Argilla(api_url=api_url, api_key=api_key) + # Initialising an Argilla client + client = get_argilla_client(api_url, api_key, hf_token) # Validate workspace exists workspace_obj = client.workspaces(name=workspace) @@ -917,6 +908,9 @@ def list_users( api_key: str | None = typer.Option( None, "--api-key", help="Argilla API key (defaults to ARGILLA_API_KEY env var)" ), + hf_token: str | None = typer.Option( + None, "--hf-token", help="HuggingFace API token for private HF spaces (defaults to HF_TOKEN env var)" + ), logger_level: str = typer.Option("INFO", "--logger-level", help="Logging level"), ) -> None: """ @@ -955,19 +949,9 @@ def list_users( logger.error("Install with: pip install argilla") raise typer.Exit(code=1) from e - # Get API credentials - api_url = api_url or os.getenv("ARGILLA_API_URL") - api_key = api_key or os.getenv("ARGILLA_API_KEY") - - if not api_url or not api_key: - logger.error( - "❌ Argilla credentials not found. Set ARGILLA_API_URL and ARGILLA_API_KEY environment variables or use --api-url and --api-key options." - ) - raise typer.Exit(code=1) - try: - # Create client - client = rg.Argilla(api_url=api_url, api_key=api_key) + # Initialising an Argilla client + client = get_argilla_client(api_url, api_key, hf_token) if workspace: # Get workspace-specific users @@ -1094,6 +1078,9 @@ def track_progress( api_key: str | None = typer.Option( None, "--api-key", help="Argilla API key (defaults to ARGILLA_API_KEY env var)" ), + hf_token: str | None = typer.Option( + None, "--hf-token", help="HuggingFace API token for private HF spaces (defaults to HF_TOKEN env var)" + ), logger_level: str = typer.Option("INFO", "--logger-level", help="Logging level"), ) -> None: """ @@ -1126,19 +1113,9 @@ def track_progress( logger.error("Install with: pip install argilla") raise typer.Exit(code=1) from e - # Get API credentials - api_url = api_url or os.getenv("ARGILLA_API_URL") - api_key = api_key or os.getenv("ARGILLA_API_KEY") - - if not api_url or not api_key: - logger.error( - "❌ Argilla credentials not found. Set ARGILLA_API_URL and ARGILLA_API_KEY environment variables or use --api-url and --api-key options." - ) - raise typer.Exit(code=1) - try: - # Create client - client = rg.Argilla(api_url=api_url, api_key=api_key) + # Initialising an Argilla client + client = get_argilla_client(api_url, api_key, hf_token) # Get workspace workspace_obj = client.workspaces(name=workspace) @@ -1307,6 +1284,9 @@ def list_workspaces( api_key: str | None = typer.Option( None, "--api-key", help="Argilla API key (defaults to ARGILLA_API_KEY env var)" ), + hf_token: str | None = typer.Option( + None, "--hf-token", help="HuggingFace API token for private HF spaces (defaults to HF_TOKEN env var)" + ), logger_level: str = typer.Option("INFO", "--logger-level", help="Logging level"), ) -> None: """ @@ -1332,19 +1312,9 @@ def list_workspaces( logger.error("Install with: pip install argilla") raise typer.Exit(code=1) from e - # Get API credentials - api_url = api_url or os.getenv("ARGILLA_API_URL") - api_key = api_key or os.getenv("ARGILLA_API_KEY") - - if not api_url or not api_key: - logger.error( - "❌ Argilla credentials not found. Set ARGILLA_API_URL and ARGILLA_API_KEY environment variables or use --api-url and --api-key options." - ) - raise typer.Exit(code=1) - try: - # Create client - client = rg.Argilla(api_url=api_url, api_key=api_key) + # Initialising an Argilla client + client = get_argilla_client(api_url, api_key, hf_token) # Get all workspaces workspaces = sorted([ws.name for ws in client.workspaces]) @@ -1390,6 +1360,9 @@ def list_datasets( api_key: str | None = typer.Option( None, "--api-key", help="Argilla API key (defaults to ARGILLA_API_KEY env var)" ), + hf_token: str | None = typer.Option( + None, "--hf-token", help="HuggingFace API token for private HF spaces (defaults to HF_TOKEN env var)" + ), logger_level: str = typer.Option("INFO", "--logger-level", help="Logging level"), ) -> None: """ @@ -1419,19 +1392,9 @@ def list_datasets( logger.error("Install with: pip install argilla") raise typer.Exit(code=1) from e - # Get API credentials - api_url = api_url or os.getenv("ARGILLA_API_URL") - api_key = api_key or os.getenv("ARGILLA_API_KEY") - - if not api_url or not api_key: - logger.error( - "❌ Argilla credentials not found. Set ARGILLA_API_URL and ARGILLA_API_KEY environment variables or use --api-url and --api-key options." - ) - raise typer.Exit(code=1) - try: - # Create client - client = rg.Argilla(api_url=api_url, api_key=api_key) + # Initialising an Argilla client + client = get_argilla_client(api_url, api_key, hf_token) # Get workspace workspace_obj = client.workspaces(name=workspace) @@ -1489,6 +1452,9 @@ def delete_user( api_key: str | None = typer.Option( None, "--api-key", help="Argilla API key (defaults to ARGILLA_API_KEY env var)" ), + hf_token: str | None = typer.Option( + None, "--hf-token", help="HuggingFace API token for private HF spaces (defaults to HF_TOKEN env var)" + ), logger_level: str = typer.Option("INFO", "--logger-level", help="Logging level"), ) -> None: """ @@ -1518,16 +1484,6 @@ def delete_user( logger.error("Install with: pip install argilla") raise typer.Exit(code=1) from e - # Get API credentials - api_url = api_url or os.getenv("ARGILLA_API_URL") - api_key = api_key or os.getenv("ARGILLA_API_KEY") - - if not api_url or not api_key: - logger.error( - "❌ Argilla credentials not found. Set ARGILLA_API_URL and ARGILLA_API_KEY environment variables or use --api-url and --api-key options." - ) - raise typer.Exit(code=1) - # Confirmation prompt if not force: confirm = typer.confirm( @@ -1538,8 +1494,8 @@ def delete_user( raise typer.Exit(code=0) try: - # Create client - client = rg.Argilla(api_url=api_url, api_key=api_key) + # Initialising an Argilla client + client = get_argilla_client(api_url, api_key, hf_token) # Find user user = client.users(username=username) diff --git a/gsma_dataset_creation/clients/__init__.py b/gsma_dataset_creation/clients/__init__.py new file mode 100644 index 0000000..c423f6f --- /dev/null +++ b/gsma_dataset_creation/clients/__init__.py @@ -0,0 +1 @@ +"""All clients for GSMA-data-creation repo""" \ No newline at end of file diff --git a/gsma_dataset_creation/clients/argilla_client.py b/gsma_dataset_creation/clients/argilla_client.py new file mode 100644 index 0000000..c7a4542 --- /dev/null +++ b/gsma_dataset_creation/clients/argilla_client.py @@ -0,0 +1,55 @@ +import argilla as rg +from loguru import logger +import traceback +import os + +def get_argilla_client(api_url: str | None, api_key: str | None, hf_token: str | None): + # Connect to Argilla + api_url = api_url or os.getenv("ARGILLA_API_URL") + api_key = api_key or os.getenv("ARGILLA_API_KEY") + hf_token = hf_token or os.getenv("HF_TOKEN") + + if not api_url: + raise ValueError( + "ARGILLA_API_URL environment variable not set and no api_url provided" + ) + + if not api_key: + raise ValueError( + "ARGILLA_API_KEY environment variable not set and no api_key provided" + ) + + try: + logger.info(f"🔗 Connecting to Argilla at {api_url}...") + logger.debug(f"🔑 Using API key: {mask_api_key(api_key)}") + if hf_token: + logger.debug(f"hf_token passed through using: {mask_api_key(hf_token)}, using this for Argilla instances ", + "sitting on private Hugging Face spaces") + return rg.Argilla(api_url=api_url, api_key=api_key, headers={"Authorization": f"Bearer {hf_token}"}) + else: + return rg.Argilla(api_url=api_url, api_key=api_key) + except Exception as error: + logger.error(f"❌ Failed to connect to Argilla, with error: {error}") + raise + +def mask_api_key(api_key: str) -> str: + """ + Mask API key for safe logging. + + Shows first 4 and last 4 characters, masks the middle. + + Args: + api_key: API key to mask + + Returns: + Masked API key string + + Examples: + >>> mask_api_key("abcdefghijklmnop") + 'abcd...mnop' + >>> mask_api_key("abc") + '***' + """ + if len(api_key) <= 8: + return "***" + return f"{api_key[:4]}...{api_key[-4:]}" \ No newline at end of file diff --git a/gsma_dataset_creation/validation/argilla_subgroup_uploader.py b/gsma_dataset_creation/validation/argilla_subgroup_uploader.py index 2b1143c..e5a327f 100644 --- a/gsma_dataset_creation/validation/argilla_subgroup_uploader.py +++ b/gsma_dataset_creation/validation/argilla_subgroup_uploader.py @@ -370,6 +370,7 @@ def upload_subgroup_dataset_to_argilla( workspace_name: str | None = None, api_url: str | None = None, api_key: str | None = None, + hf_token: str | None = None, ) -> str | None: """ Upload a subgroup-filtered dataset to Argilla for Q&A quality annotation. @@ -410,11 +411,21 @@ def upload_subgroup_dataset_to_argilla( "ARGILLA_API_KEY not set. Set environment variable or use --api-key" ) + if not hf_token: + hf_token = os.getenv("HF_TOKEN") + if not hf_token: + logger.warning( + "HF_TOKEN not set. Set environment variable or use --hf-token. Private HF spaces require HuggingFace Token." + ) + logger.info(f"🔗 Connecting to Argilla at: {api_url}") logger.info(f"🔑 Using API key: {mask_api_key(api_key)}") # Initialize Argilla client - client = rg.Argilla(api_url=api_url, api_key=api_key) + if hf_token: + client = rg.Argilla(api_url=api_url, api_key=api_key, headers={"Authorization": f"Bearer {hf_token}"}) + else: + client = rg.Argilla(api_url=api_url, api_key=api_key) # Load dataset logger.info("📂 Loading dataset...") diff --git a/gsma_dataset_creation/validation/argilla_uploader.py b/gsma_dataset_creation/validation/argilla_uploader.py index 3ad1f75..3f3af18 100644 --- a/gsma_dataset_creation/validation/argilla_uploader.py +++ b/gsma_dataset_creation/validation/argilla_uploader.py @@ -17,6 +17,7 @@ import argilla as rg except ImportError: raise ImportError("Argilla SDK not installed. Install with: pip install argilla") +from clients.argilla_client import get_argilla_client ANNOTATION_GUIDELINES = """When reviewing each Q/A pair, please check for the following: @@ -100,6 +101,7 @@ def upload_dataset_to_argilla( labels: list[str] | None = None, api_url: str | None = None, api_key: str | None = None, + hf_token: str | None = None ) -> str: """ Upload validation dataset to Argilla for human annotation. @@ -115,6 +117,7 @@ def upload_dataset_to_argilla( labels: List of annotation labels (defaults to standard 5 labels) api_url: Argilla API URL (defaults to env ARGILLA_API_URL) api_key: Argilla API key (defaults to env ARGILLA_API_KEY) + hf_token: Hugging face token to allow access to private spaces (defaults to env HF_TOKEN) Returns: Name of created Argilla dataset (includes timestamp and git commit hash) @@ -166,31 +169,8 @@ def upload_dataset_to_argilla( .reset_index(drop=True) ) - # Connect to Argilla - api_url = api_url or os.getenv("ARGILLA_API_URL") - api_key = api_key or os.getenv("ARGILLA_API_KEY") - - if not api_url: - raise ValueError( - "ARGILLA_API_URL environment variable not set and no api_url provided" - ) - - if not api_key: - raise ValueError( - "ARGILLA_API_KEY environment variable not set and no api_key provided" - ) - - logger.info(f"🔗 Connecting to Argilla at {api_url}...") - logger.debug(f"🔑 Using API key: {mask_api_key(api_key)}") - - try: - client = rg.Argilla(api_url=api_url, api_key=api_key) - logger.info("✅ Connected successfully") - except Exception as e: - # Mask API key in error messages - error_msg = str(e).replace(api_key, mask_api_key(api_key)) - logger.error(f"❌ Connection failed: {error_msg}") - raise ValueError(f"Failed to connect to Argilla: {error_msg}") from e + # Initialising an Argilla client + client = get_argilla_client(api_url, api_key, hf_token) # Create or get workspace ws = client.workspaces(workspace) @@ -356,6 +336,7 @@ def delete_dataset_from_argilla( workspace: str, api_url: str | None = None, api_key: str | None = None, + hf_token: str | None = None ) -> None: """ Delete a dataset from Argilla. @@ -369,31 +350,8 @@ def delete_dataset_from_argilla( # Parse dataset identifier from name or URL dataset_identifier = parse_dataset_identifier(dataset_name_or_url) - # Connect to Argilla - api_url = api_url or os.getenv("ARGILLA_API_URL") - api_key = api_key or os.getenv("ARGILLA_API_KEY") - - if not api_url: - raise ValueError( - "ARGILLA_API_URL environment variable not set and no api_url provided" - ) - - if not api_key: - raise ValueError( - "ARGILLA_API_KEY environment variable not set and no api_key provided" - ) - - logger.info(f"🔗 Connecting to Argilla at {api_url}...") - logger.debug(f"🔑 Using API key: {mask_api_key(api_key)}") - - try: - client = rg.Argilla(api_url=api_url, api_key=api_key) - logger.info("✅ Connected successfully") - except Exception as e: - # Mask API key in error messages - error_msg = str(e).replace(api_key, mask_api_key(api_key)) - logger.error(f"❌ Connection failed: {error_msg}") - raise ValueError(f"Failed to connect to Argilla: {error_msg}") from e + # Initialising an Argilla client + client = get_argilla_client(api_url, api_key, hf_token) # Check if identifier looks like a UUID (dataset ID) is_uuid = len(dataset_identifier) == 36 and dataset_identifier.count("-") == 4 @@ -444,6 +402,7 @@ def download_dataset_from_argilla( workspace: str, api_url: str | None = None, api_key: str | None = None, + hf_token: str | None = None, ) -> None: """ Download a dataset from Argilla. @@ -458,31 +417,8 @@ def download_dataset_from_argilla( # Parse dataset identifier from name or URL dataset_identifier = parse_dataset_identifier(dataset_name) - # Connect to Argilla - api_url = api_url or os.getenv("ARGILLA_API_URL") - api_key = api_key or os.getenv("ARGILLA_API_KEY") - - if not api_url: - raise ValueError( - "ARGILLA_API_URL environment variable not set and no api_url provided" - ) - - if not api_key: - raise ValueError( - "ARGILLA_API_KEY environment variable not set and no api_key provided" - ) - - logger.info(f"🔗 Connecting to Argilla at {api_url}...") - logger.debug(f"🔑 Using API key: {mask_api_key(api_key)}") - - try: - client = rg.Argilla(api_url=api_url, api_key=api_key) - logger.info("✅ Connected successfully") - except Exception as e: - # Mask API key in error messages - error_msg = str(e).replace(api_key, mask_api_key(api_key)) - logger.error(f"❌ Connection failed: {error_msg}") - raise ValueError(f"Failed to connect to Argilla: {error_msg}") from e + # Initialising an Argilla client + client = get_argilla_client(api_url, api_key, hf_token) # Check if identifier looks like a UUID (dataset ID) is_uuid = len(dataset_identifier) == 36 and dataset_identifier.count("-") == 4