Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 85 additions & 6 deletions github_rest_api/github.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
"""Simple wrapper of GitHub REST APIs."""

import re
import sys
from abc import ABCMeta, abstractmethod
from base64 import b64encode
from collections.abc import Sequence
from enum import StrEnum
from pathlib import Path
from typing import Any, Callable
from urllib.parse import quote

import requests
from nacl import encoding, public

from github_rest_api.pr_content import (
deterministic_body,
deterministic_title,
generate_pr_content,
)

URL_API = "https://api.github.com"

_SECRET_NAME_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
# Default LiteLLM 'provider/model' used to generate PR titles and descriptions.
DEFAULT_PR_MODEL = "anthropic/claude-haiku-4-5-20251001"


def _validate_secret_name(name: str) -> None:
Expand All @@ -34,7 +43,7 @@ def _validate_secret_name(name: str) -> None:
f"Invalid secret name {name!r}: names must not start with the "
"reserved 'GITHUB_' prefix."
)
if not _SECRET_NAME_PATTERN.fullmatch(name):
if not re.fullmatch(r"^[A-Za-z_][A-Za-z0-9_]*$", name):
raise ValueError(
f"Invalid secret name {name!r}: names may only contain alphanumeric "
"characters and underscores, and must not start with a digit."
Expand Down Expand Up @@ -211,6 +220,7 @@ def __init__(self, token: str, repo: str):
self._url_issues = f"{self._url_repo}/issues"
self._url_releases = f"{self._url_repo}/releases"
self._url_secrets = f"{self._url_repo}/actions/secrets"
self._url_compare = f"{self._url_repo}/compare"

def get_releases(self, n: int = 0) -> list[dict[str, Any]]:
"""List releases in this repository."""
Expand Down Expand Up @@ -270,21 +280,72 @@ def get_pull_requests(self, n: int = 0) -> list[dict[str, Any]]:
"""List pull requests in this repository."""
return self._extract_all(url=self._url_pull, n=n)

def create_pull_request(self, json: dict[str, str]) -> dict[str, Any] | None:
def _generate_pull_request_content(
    self, base: str, head: str, model: str
) -> tuple[str, str] | None:
    """Generate a `(title, body)` for a new PR from the head/base comparison.

    Returns None when there is nothing to describe (an empty comparison), so
    the caller keeps the provided title.

    :param base: The base branch the PR merges into.
    :param head: The head branch containing the changes.
    :param model: The LiteLLM 'provider/model' string for LLM generation.
        When empty, the title and body are generated deterministically from
        the commit messages and changed files; otherwise an LLM is used,
        falling back to deterministic generation on failure.
    :return: The generated `(title, body)` tuple, or None when the comparison
        contains no commits.
    """
    # Imported locally so this method is self-contained; a library should
    # report through `logging` (which consumers can configure, capture or
    # silence) rather than printing to stderr directly.
    import logging

    compare = self.compare(base=base, head=head)
    if not compare.get("commits"):
        return None
    if model:
        try:
            return generate_pr_content(compare, model=model)
        except Exception as error:
            # Deliberately broad: LLM generation is best-effort, and any
            # failure must degrade to the deterministic content below
            # instead of aborting PR creation.
            logging.getLogger(__name__).warning(
                "LLM PR generation failed (%s); "
                "falling back to deterministic content.",
                error,
            )
    return deterministic_title(compare), deterministic_body(compare)

def create_pull_request(
self,
json: dict[str, str],
model: str = "",
) -> dict[str, Any] | None:
"""Create a pull request.

A Conventional-Commits title and a detailed Markdown body are generated
for a newly created PR. A caller-provided 'title' or 'body' in `json`
takes precedence; only the missing field(s) are generated. To skip
generation entirely, provide both 'title' and 'body' in `json`.

:param json: A dict containing info (e.g., base, head, title, body, etc.)
about the pull request to be created.
It's passed to the json parameter of requests.post.
:param model: The LiteLLM 'provider/model' string. When empty (the
default), missing title/body are generated deterministically from
the commit messages and changed files. When non-empty, an LLM is
used (with the optional 'ai' extra installed and the matching
provider API key read from the environment), falling back to
deterministic generation on failure.
"""
if not ("head" in json and "base" in json):
raise ValueError("The data dict must contains keys head and base!")
# return an existing PR
prs = self.get_pull_requests()
for pr in prs:
for pr in self.get_pull_requests():
if pr["head"]["ref"] == json["head"] and pr["base"]["ref"] == json["base"]:
return pr
# creat a new PR
# generate any title/body not already provided by the caller
if "title" not in json or "body" not in json:
content = self._generate_pull_request_content(
base=json["base"], head=json["head"], model=model
)
if content is not None:
# caller-provided title/body take precedence over generated ones
json = {"title": content[0], "body": content[1], **json}
# create a new PR
resp = self._post(
url=self._url_pull,
json=json,
Expand All @@ -308,11 +369,14 @@ def update_branch(self, update: str, upstream: str) -> dict[str, Any] | None:
:param update: The branch to update.
:param upstream: The upstream branch.
"""
# Provide a title and an (empty) body so no description is generated for
# this mechanical, immediately merged update PR.
pr = self.create_pull_request(
{
"base": update,
"head": upstream,
"title": f"Merge {upstream} into {update}",
"body": "",
},
)
if pr is None:
Expand All @@ -328,6 +392,21 @@ def get_pull_request_files(
"""
return self._extract_all(url=f"{self._url_pull}/{pr_number}/files", n=n)

def compare(self, base: str, head: str) -> dict[str, Any]:
    """Fetch the comparison between two refs of this repository.

    :param base: The base branch (or commit) of the comparison.
    :param head: The head branch (or commit) of the comparison.
    :return: The decoded JSON comparison result containing `commits` and
        `files` (each file with `filename`, `status`, `additions`,
        `deletions`, and `patch`). See
        https://docs.github.com/en/rest/commits/commits#compare-two-commits.
    """
    # Percent-encode every character of each ref (slashes included, since
    # branch names may contain them); only the literal `...` separator that
    # joins the two refs stays unescaped.
    encoded_base, encoded_head = (quote(ref, safe="") for ref in (base, head))
    response = self._get(
        url=self._url_compare + "/" + encoded_base + "..." + encoded_head
    )
    return response.json()

def get_branches(self, n: int = 0) -> list[dict[str, Any]]:
"""List branches in this repository."""
return self._extract_all(url=self._url_branches, n=n)
Expand Down
226 changes: 226 additions & 0 deletions github_rest_api/pr_content.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
"""Generate pull request titles and descriptions from a comparison result.

The functions here operate on the JSON returned by the GitHub "compare two
commits" endpoint (see `Repository.compare`). They provide two layers:

- A deterministic layer (`deterministic_title` / `deterministic_body`) that
derives a Conventional-Commits title and a Markdown body purely from commit
messages and changed files. It has no extra dependencies and is always
available.
- An optional AI layer (`generate_pr_content`) that uses LiteLLM to produce a
richer title and description. It raises on any failure (missing `litellm`,
missing provider key, malformed reply) so the caller can fall back to the
deterministic layer.
"""

import json
import re
from collections import Counter
from typing import Any, cast


def parse_conventional(subject: str) -> tuple[str, str | None, bool, str] | None:
"""Parse a Conventional-Commits subject line.

:param subject: The first line of a commit message.
:return: A `(type, scope, breaking, description)` tuple, or None when the
subject does not follow the Conventional-Commits grammar.
"""
pattern = re.compile(
r"^(?P<type>feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert)"
r"(?:\((?P<scope>[^)]+)\))?(?P<breaking>!)?: (?P<description>.+)$"
)
match = pattern.match(subject)
if match is None:
return None
return (
match["type"],
match["scope"],
bool(match["breaking"]),
match["description"],
)


def _commit_messages(compare: dict[str, Any], skip_merges: bool = True) -> list[str]:
"""Extract commit messages from a comparison result.

:param compare: The comparison result from `Repository.compare`.
:param skip_merges: Whether to skip merge commits (subjects starting with
``Merge ``).
"""
messages = []
for commit in compare.get("commits", []):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If the commits key in the API response is present but its value is null (or if it is mocked as None in tests), compare.get("commits", []) will return None, causing a TypeError when iterated. Using compare.get("commits") or [] is a safer, more defensive approach.

Suggested change
for commit in compare.get("commits", []):
for commit in (compare.get("commits") or []):

message = (commit.get("commit") or {}).get("message", "")
if not message:
continue
subject = message.splitlines()[0].strip()
if skip_merges and subject.startswith("Merge "):
continue
messages.append(message)
return messages


def _common_scope(compare: dict[str, Any]) -> str | None:
"""Derive a Conventional-Commits scope from the changed files.

The scope is the common top-level directory shared by every changed file,
or None when the files do not share one (or live at the repository root).
"""
filenames = [
file["filename"] for file in compare.get("files", []) if file.get("filename")
]
if not filenames or not all("/" in name for name in filenames):
return None
segments = {name.split("/", 1)[0] for name in filenames}
return next(iter(segments)) if len(segments) == 1 else None


def deterministic_title(compare: dict[str, Any]) -> str:
    """Derive a Conventional-Commits title from a comparison result.

    The type is the most significant type present across the commits
    (``feat`` > ``fix`` > the most frequent parsed type > ``chore``); the scope
    is the common top-level directory of the changed files; ``!`` is appended
    for breaking changes.

    :param compare: The comparison result from `Repository.compare`.
    :return: A single-line title of the form ``type(scope)!: description``,
        where the scope and the ``!`` marker are optional.
    """
    messages = _commit_messages(compare)
    subjects = [message.splitlines()[0].strip() for message in messages]
    parsed = [parse_conventional(subject) for subject in subjects]
    types = [item[0] for item in parsed if item]
    scope = _common_scope(compare)
    prefix_scope = f"({scope})" if scope else ""
    if not subjects:
        return f"chore{prefix_scope}: update"
    if "feat" in types:
        type_ = "feat"
    elif "fix" in types:
        type_ = "fix"
    elif types:
        type_ = Counter(types).most_common(1)[0][0]
    else:
        type_ = "chore"
    # The Conventional Commits spec allows the breaking-change footer token to
    # be spelled either "BREAKING CHANGE" or "BREAKING-CHANGE", and requires it
    # to be followed by a colon. Anchoring the token to the start of a line
    # keeps a commit that merely mentions the phrase in its subject or prose
    # from flagging the whole PR as breaking.
    breaking_footer_pattern = re.compile(r"^BREAKING[ -]CHANGE:", re.MULTILINE)
    breaking = any(item[2] for item in parsed if item) or any(
        breaking_footer_pattern.search(message) for message in messages
    )
    if len(subjects) == 1:
        # A lone commit supplies the description directly; fall back to the
        # raw subject when it is not a Conventional-Commits subject.
        description = parsed[0][3] if parsed[0] else subjects[0]
    else:
        # Use the first commit of the chosen type as the headline, or a
        # generic summary when no commit parsed to that type.
        description = next(
            (item[3] for item in parsed if item and item[0] == type_),
            f"update {len(subjects)} commits",
        )
    return f"{type_}{prefix_scope}{'!' if breaking else ''}: {description}"


def deterministic_body(compare: dict[str, Any]) -> str:
    """Build a Markdown PR body from a comparison result.

    The body has up to three sections, each emitted only when it has content:
    ``## Summary`` (one bullet per commit subject returned by
    `_commit_messages`), ``## Changed files`` (files grouped by status with
    their added/deleted line counts) and ``## Commits`` (short SHA and subject
    of every commit in the comparison).

    :param compare: The comparison result from `Repository.compare`.
    :return: The Markdown body, or an empty string when there is nothing to
        describe.
    """
    sections: list[str] = []
    subjects = [
        message.splitlines()[0].strip() for message in _commit_messages(compare)
    ]
    if subjects:
        summary = "\n".join(f"- {subject}" for subject in subjects)
        sections.append(f"## Summary\n\n{summary}")
    # `files` may be present but null, in which case a `.get()` default would
    # not apply; `or []` covers both the missing and the null case.
    files = compare.get("files") or []
    if files:
        # File statuses ordered for a stable, readable "Changed files"
        # section. Any status not listed is appended afterwards.
        status_order = ("added", "modified", "removed", "renamed", "copied", "changed")
        by_status: dict[str, list[dict[str, Any]]] = {}
        for file in files:
            # A file without a status is grouped with the modified ones.
            by_status.setdefault(file.get("status") or "modified", []).append(file)
        ordered = [status for status in status_order if status in by_status]
        ordered += [status for status in by_status if status not in status_order]
        lines = []
        for status in ordered:
            lines.append(f"**{status.capitalize()}**")
            for file in by_status[status]:
                lines.append(
                    f"- `{file.get('filename', '')}` "
                    f"(+{file.get('additions', 0)}/-{file.get('deletions', 0)})"
                )
        sections.append("## Changed files\n\n" + "\n".join(lines))
    # Same null-tolerant lookup as for `files` above.
    commits = compare.get("commits") or []
    if commits:
        lines = []
        for commit in commits:
            sha = (commit.get("sha") or "")[:7]
            message = (commit.get("commit") or {}).get("message") or ""
            subject = message.splitlines()[0].strip() if message else ""
            lines.append(f"- {sha} {subject}")
        sections.append("## Commits\n\n" + "\n".join(lines))
    return "\n\n".join(sections)


def summarize_for_ai(compare: dict[str, Any], max_chars: int = 12000) -> str:
    """Assemble a compact, size-capped context for the AI prompt.

    Combines the deterministic body with per-file patches, truncating once the
    character budget is exhausted so very large diffs do not blow up the prompt.

    :param compare: The comparison result from `Repository.compare`.
    :param max_chars: The approximate maximum size of the returned context.
        The separators between parts and the truncation marker are not
        counted against it.
    :return: The deterministic body, a ``## Diff`` heading and as many
        per-file patches as fit, joined by blank lines.
    """
    parts = [deterministic_body(compare), "## Diff"]
    budget = max_chars - sum(len(part) for part in parts)
    # `files` may be present but null, in which case a `.get()` default would
    # not apply; `or []` covers both the missing and the null case.
    for file in compare.get("files") or []:
        patch = file.get("patch")
        if not patch:
            # Files without a textual patch contribute nothing to the diff.
            continue
        chunk = f"### {file.get('filename', '')}\n{patch}"
        if len(chunk) > budget:
            # Stop at the first patch that does not fit rather than skipping
            # ahead, so the included patches keep their original order.
            parts.append("<remaining patches omitted: size limit reached>")
            break
        parts.append(chunk)
        budget -= len(chunk)
    return "\n\n".join(parts)


def generate_pr_content(compare: dict[str, Any], model: str) -> tuple[str, str]:
    """Generate a PR title and body with an LLM via LiteLLM.

    :param compare: The comparison result from `Repository.compare`.
    :param model: A LiteLLM ``provider/model`` string (e.g.
        ``anthropic/claude-haiku-4-5-20251001``, ``gemini/gemini-2.5-flash``).
        The matching provider API key is read from the environment by LiteLLM.
    :return: A `(title, body)` tuple, both non-empty and stripped of
        surrounding whitespace.
    :raises Exception: If LiteLLM is not installed, no provider key is set, the
        request fails, or the reply cannot be parsed. The caller is expected to
        fall back to the deterministic layer.
    """
    # Imported lazily so LiteLLM is only required when this function is called.
    import litellm

    prompt = (
        "You are writing a GitHub pull request from the changes below. "
        "Respond with a JSON object containing exactly two string keys: "
        "'title' and 'body'. The 'title' must be a single concise line "
        "following the Conventional Commits specification "
        "(e.g. 'feat(scope): summary'). The 'body' must be GitHub-flavored "
        "Markdown describing the motivation and the key changes.\n\n"
        f"{summarize_for_ai(compare)}"
    )
    response = litellm.completion(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        response_format={"type": "json_object"},
        # Drop response_format for providers that do not support it rather than
        # erroring; the prompt already requests JSON, so parsing still works.
        drop_params=True,
    )
    response_any = cast(Any, response)
    content = (response_any["choices"][0]["message"]["content"] or "").strip()
    # When response_format is dropped or not honoured, the model may wrap the
    # JSON in a Markdown code fence; unwrap it so json.loads sees the object.
    fenced = re.fullmatch(
        r"```(?:json)?\s*(.*?)\s*```", content, flags=re.DOTALL | re.IGNORECASE
    )
    if fenced:
        content = fenced.group(1)
    data = json.loads(content)
    if not isinstance(data, dict):
        raise ValueError("The model did not return a JSON object.")
    # `or ""` keeps a JSON null from being rendered as the string "None".
    title = str(data.get("title") or "").strip()
    body = str(data.get("body") or "").strip()
    if not title or not body:
        raise ValueError("The model returned an empty title or body.")
    return title, body
Loading
Loading