Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 256 additions & 43 deletions src/specify_cli/__init__.py

Large diffs are not rendered by default.

236 changes: 206 additions & 30 deletions src/specify_cli/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import json
import hashlib
import os
import sys
import tarfile
import tempfile
import zipfile
import shutil
Expand Down Expand Up @@ -106,6 +108,137 @@ def normalize_priority(value: Any, default: int = 10) -> int:
return priority if priority >= 1 else default


def detect_archive_format(url: str, content_type: str = "") -> str:
    """Detect archive format from a URL/file-path extension or Content-Type header.

    The extension is checked first; the ``Content-Type`` value is only
    consulted when the extension is inconclusive.

    Args:
        url: URL or local file path to inspect.
        content_type: Optional ``Content-Type`` header value from the HTTP
            response. ``None`` is treated like an empty string.

    Returns:
        ``"zip"`` for ZIP archives, ``"tar.gz"`` for gzipped tarballs, or ``""``
        when the format cannot be determined.
    """
    zip_suffixes = (".zip",)
    tarball_suffixes = (".tar.gz", ".tgz")

    location = url.lower()
    # Query-string / fragment removed, so only the path part is examined.
    stripped = location.split("?")[0].split("#")[0]

    if "://" in location:
        # A real URL: only the path component may decide the format, so
        # ``https://host/download?file=x.zip`` stays undetermined here.
        candidates = (stripped,)
    else:
        # A local file path may legitimately contain "?" or "#" (for example
        # ``/tmp/build#42/ext.tgz``); look at it verbatim first and only then
        # fall back to the stripped form for scheme-less URL-like strings.
        candidates = (location, stripped)

    for candidate in candidates:
        if candidate.endswith(zip_suffixes):
            return "zip"
        if candidate.endswith(tarball_suffixes):
            return "tar.gz"

    # Fall back to Content-Type header inspection. Substring matching is used
    # so parameters such as "; charset=binary" do not prevent a match.
    ct = (content_type or "").lower()
    if "application/zip" in ct or "application/x-zip" in ct:
        return "zip"
    if any(
        t in ct
        for t in (
            "application/gzip",
            "application/x-gzip",
            "application/x-tar+gzip",
        )
    ):
        return "tar.gz"

    return ""


def safe_extract_tarball(
    archive_path: Path,
    dest_dir: Path,
    error_class: "type[Exception]" = Exception,
) -> None:
    """Safely extract a ``.tar.gz`` or ``.tgz`` archive into *dest_dir*.

    All members are validated before anything is extracted to prevent
    *tar slip* (path traversal) attacks. Symlinks, hard links, and special
    files (devices, FIFOs, etc.) are rejected.

    Only the pre-validated members are handed to ``extractall()``. When the
    running interpreter supports tarfile extraction filters (Python 3.12+,
    and the security releases of older versions that backported them), the
    ``"data"`` filter is applied as an additional layer of protection.

    Args:
        archive_path: Path to the ``.tar.gz``/``.tgz`` archive.
        dest_dir: Destination directory (must already exist).
        error_class: Exception class to raise on unsafe entries.

    Raises:
        error_class: If any member is unsafe or the archive cannot be read.
    """
    dest_resolved = dest_dir.resolve()
    # Tar metadata member types to skip during validation — they carry no
    # extractable payload and are generated automatically by many common
    # archiving tools (e.g. PAX headers, GNU longname/longlink entries).
    # GNUTYPE_SPARSE is intentionally excluded: it carries a real file payload
    # and tarfile.TarInfo.isreg() returns True for it, so it passes the
    # regular-file check below and is extracted correctly.
    _TAR_METADATA_TYPES = (
        tarfile.XHDTYPE,           # PAX extended header
        tarfile.XGLTYPE,           # PAX global extended header
        tarfile.SOLARIS_XHDTYPE,   # Solaris PAX extended header
        tarfile.GNUTYPE_LONGNAME,  # GNU long path name (metadata only)
        tarfile.GNUTYPE_LONGLINK,  # GNU long link name (metadata only)
    )

    try:
        with tarfile.open(archive_path, "r:gz") as tf:
            safe_members = []

            # Validate every member before extracting anything.
            for member in tf.getmembers():
                # Reject absolute paths and any path component that is "..".
                # Backslashes are normalised so Windows-style separators
                # cannot be used to hide a ".." component.
                if os.path.isabs(member.name) or any(
                    part == ".." for part in member.name.replace("\\", "/").split("/")
                ):
                    raise error_class(
                        f"Unsafe path in tar archive: {member.name} (potential path traversal)"
                    )

                # Confirm the resolved path stays inside dest_dir.
                member_path = (dest_dir / member.name).resolve()
                try:
                    member_path.relative_to(dest_resolved)
                except ValueError:
                    raise error_class(
                        f"Unsafe path in tar archive: {member.name} (potential path traversal)"
                    )

                # Skip tar metadata members — they carry no extractable payload.
                if member.type in _TAR_METADATA_TYPES:
                    continue

                # Reject symlinks and hard links.
                if member.issym() or member.islnk():
                    raise error_class(
                        f"Symlinks are not allowed in archive: {member.name}"
                    )

                # Reject devices, FIFOs and other special file types.
                if not (member.isreg() or member.isdir()):
                    raise error_class(
                        f"Non-regular file in archive: {member.name}"
                    )

                safe_members.append(member)

            # Extract exactly the vetted members. Feature-detect the "data"
            # filter instead of comparing version numbers: extraction filters
            # were also backported to security releases of older Pythons.
            if hasattr(tarfile, "data_filter"):
                tf.extractall(  # type: ignore[call-arg]
                    dest_dir, members=safe_members, filter="data"
                )
            else:
                tf.extractall(dest_dir, members=safe_members)  # noqa: S202 — validated above
    except error_class:
        # Already the caller's exception type; propagate unchanged.
        raise
    except (tarfile.TarError, OSError) as e:
        raise error_class(f"Failed to read archive {archive_path}: {e}") from e


@dataclass
class CatalogEntry:
"""Represents a single catalog entry in the catalog stack."""
Expand Down Expand Up @@ -1202,18 +1335,19 @@ def install_from_zip(
speckit_version: str,
priority: int = 10,
) -> ExtensionManifest:
"""Install extension from ZIP file.
"""Install extension from a ZIP or ``.tar.gz``/``.tgz`` archive.

Args:
zip_path: Path to extension ZIP file
zip_path: Path to the extension archive (ZIP or gzipped tarball).
speckit_version: Current spec-kit version
priority: Resolution priority (lower = higher precedence, default 10)

Returns:
Installed extension manifest

Raises:
ValidationError: If manifest is invalid or priority is invalid
ValidationError: If manifest is invalid, the archive is unsafe, or
priority is invalid
CompatibilityError: If extension is incompatible
"""
# Validate priority early
Expand All @@ -1223,21 +1357,27 @@ def install_from_zip(
with tempfile.TemporaryDirectory() as tmpdir:
temp_path = Path(tmpdir)

# Extract ZIP safely (prevent Zip Slip attack)
with zipfile.ZipFile(zip_path, 'r') as zf:
# Validate all paths first before extracting anything
temp_path_resolved = temp_path.resolve()
for member in zf.namelist():
member_path = (temp_path / member).resolve()
# Use is_relative_to for safe path containment check
try:
member_path.relative_to(temp_path_resolved)
except ValueError:
raise ValidationError(
f"Unsafe path in ZIP archive: {member} (potential path traversal)"
)
# Only extract after all paths are validated
zf.extractall(temp_path)
archive_fmt = detect_archive_format(str(zip_path))

if archive_fmt == "tar.gz":
# Extract tarball safely (prevent tar slip attack)
safe_extract_tarball(zip_path, temp_path, ValidationError)
else:
# Extract ZIP safely (prevent Zip Slip attack)
with zipfile.ZipFile(zip_path, 'r') as zf:
# Validate all paths first before extracting anything
temp_path_resolved = temp_path.resolve()
for member in zf.namelist():
member_path = (temp_path / member).resolve()
# Use is_relative_to for safe path containment check
try:
member_path.relative_to(temp_path_resolved)
except ValueError:
raise ValidationError(
f"Unsafe path in ZIP archive: {member} (potential path traversal)"
)
# Only extract after all paths are validated
zf.extractall(temp_path)

# Find extension directory (may be nested)
extension_dir = temp_path
Expand All @@ -1251,7 +1391,7 @@ def install_from_zip(
manifest_path = extension_dir / "extension.yml"

if not manifest_path.exists():
raise ValidationError("No extension.yml found in ZIP file")
raise ValidationError("No extension.yml found in archive")

# Install from extracted directory
return self.install_from_directory(extension_dir, speckit_version, priority=priority)
Expand Down Expand Up @@ -1965,14 +2105,18 @@ def get_extension_info(self, extension_id: str) -> Optional[Dict[str, Any]]:
return None

def download_extension(self, extension_id: str, target_dir: Optional[Path] = None) -> Path:
"""Download extension ZIP from catalog.
"""Download extension archive from catalog.

Supports both ZIP (``.zip``) and gzipped tarball (``.tar.gz``/``.tgz``)
archives. The format is detected from the download URL's path extension;
when ambiguous the ``Content-Type`` header is used as a fallback.

Args:
extension_id: ID of the extension to download
target_dir: Directory to save ZIP file (defaults to temp directory)
target_dir: Directory to save the archive (defaults to cache directory)

Returns:
Path to downloaded ZIP file
Path to downloaded archive file

Raises:
ExtensionError: If extension not found or download fails
Expand Down Expand Up @@ -2011,21 +2155,53 @@ def download_extension(self, extension_id: str, target_dir: Optional[Path] = Non
target_dir.mkdir(parents=True, exist_ok=True)

version = ext_info.get("version", "unknown")
zip_filename = f"{extension_id}-{version}.zip"
zip_path = target_dir / zip_filename

# Download the ZIP file
# Detect archive format from URL; resolve via Content-Type when needed.
# `final_url` may differ from `download_url` if the server redirects.
archive_fmt = detect_archive_format(download_url)
final_url = download_url

# Download the archive
try:
with self._open_url(download_url, timeout=60) as response:
zip_data = response.read()

zip_path.write_bytes(zip_data)
return zip_path
final_url = response.geturl()
if not archive_fmt:
content_type = response.headers.get("Content-Type", "")
archive_fmt = detect_archive_format(final_url, content_type)
archive_data = response.read()
Comment on lines +2159 to +2171

except urllib.error.URLError as e:
raise ExtensionError(f"Failed to download extension from {download_url}: {e}")
except IOError as e:
raise ExtensionError(f"Failed to save extension ZIP: {e}")
raise ExtensionError(f"Failed to read extension archive from {download_url}: {e}")

# Re-validate scheme after any redirect to guard against scheme-downgrade.
_final_parsed = urlparse(final_url)
_final_is_localhost = _final_parsed.hostname in ("localhost", "127.0.0.1", "::1")
if _final_parsed.scheme != "https" and not (
_final_parsed.scheme == "http" and _final_is_localhost
):
raise ExtensionError(
f"Extension download URL was redirected to a non-HTTPS URL: {final_url}"
)

# Choose file extension based on detected format.
if not archive_fmt:
raise ExtensionError(
f"Could not determine archive format for {download_url}. "
"Ensure the URL points to a .zip or .tar.gz/.tgz file."
)
if archive_fmt == "tar.gz":
archive_filename = f"{extension_id}-{version}.tar.gz"
else:
archive_filename = f"{extension_id}-{version}.zip"

Comment thread
mnriem marked this conversation as resolved.
archive_path = target_dir / archive_filename
try:
archive_path.write_bytes(archive_data)
except IOError as e:
raise ExtensionError(f"Failed to save extension archive: {e}")
return archive_path

def clear_cache(self):
"""Clear the catalog cache (both legacy and URL-hash-based files)."""
Expand Down
Loading