Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions database/create_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,9 @@ def create_schema(conn: Connection):
INNER JOIN informal i ON d.name = i.symbol_name
INNER JOIN symbol s ON d.name = s.name
""",

"""
CREATE SCHEMA physlibsearch
""",

"""
CREATE TABLE physlibsearch.query (
id UUID PRIMARY KEY,
Expand All @@ -108,7 +106,7 @@ def create_schema(conn: Connection):
declaration_name JSONB REFERENCES declaration(name) NOT NULL,
action TEXT NOT NULL,
PRIMARY KEY (query_id, declaration_name)
)"""
)""",
]

with conn.cursor() as cursor:
Expand Down
4 changes: 1 addition & 3 deletions database/informalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def generate_informal(conn: Connection, batch_size: int = 50, limit_level: int |
max_level = limit_level

with conn.cursor(row_factory=scalar_row) as cnt_cursor:
total_remaining = cnt_cursor.execute(
"SELECT COUNT(*) FROM symbol s WHERE NOT EXISTS(SELECT 1 FROM informal i WHERE i.symbol_name = s.name)"
).fetchone() or 0
total_remaining = cnt_cursor.execute("SELECT COUNT(*) FROM symbol s WHERE NOT EXISTS(SELECT 1 FROM informal i WHERE i.symbol_name = s.name)").fetchone() or 0
done = 0
logger.warning("starting informalization: %d declarations remaining", total_remaining)

Expand Down
32 changes: 19 additions & 13 deletions database/jixia_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,38 @@
from pathlib import Path

from jixia import LeanProject
from jixia.structs import LeanName, Symbol, Declaration, is_internal, StringRange
from jixia.structs import LeanName, Symbol, Declaration, is_internal
from psycopg import Connection
from psycopg.types.json import Jsonb
from psycopg.types.range import Range

logger = logging.getLogger(__name__)


def _get_signature(declaration: Declaration, module_content):
    """Return the signature text of ``declaration``.

    The pretty-printed form (``signature.pp``) is preferred. When it is
    missing, the signature's source range is sliced out of ``module_content``
    and decoded. An empty string is returned when neither is available.

    Args:
        declaration: Object exposing ``signature.pp`` and ``signature.range``.
        module_content: Sliceable buffer whose slices support ``.decode()``
            (presumably the module's raw bytes -- confirm against caller).
    """
    signature = declaration.signature
    if signature.pp is not None:
        return signature.pp
    if signature.range is not None:
        return module_content[signature.range.as_slice()].decode()
    return ""


def _get_value(declaration: Declaration, module_content):
    """Return the source text of the declaration's value, or ``None``.

    ``None`` is returned when the declaration has no value or the value has
    no source range; otherwise the range is sliced out of ``module_content``
    and decoded.
    """
    value = declaration.value
    if value is None or value.range is None:
        return None
    return module_content[value.range.as_slice()].decode()


def _get_range(declaration: Declaration):
    """Convert the declaration's ``ref.range`` into a psycopg ``Range``.

    Returns ``None`` when the declaration carries no range; otherwise a
    ``Range`` built from the range's ``start`` and ``stop``.
    """
    span = declaration.ref.range
    if span is None:
        return None
    return Range(span.start, span.stop)


def load_data(project: LeanProject, prefixes: list[LeanName], conn: Connection):
def load_module(data: Iterable[LeanName], base_dir: Path):
values = ((Jsonb(m), project.path_of_module(m, base_dir).read_bytes(), project.load_module_info(m).docstring) for m in data)
Expand Down Expand Up @@ -115,17 +119,19 @@ def load_declaration(module_name: LeanName):
for index, decl in enumerate(declarations):
if is_internal(decl.name) or decl.kind == "proofWanted":
continue
db_declarations.append({
"module_name": Jsonb(module_name),
"index" : index,
"name" : Jsonb(decl.name) if decl.kind != "example" else None,
"visible" : decl.modifiers.visibility != "private" and decl.kind != "example",
"docstring" : decl.modifiers.docstring,
"kind" : decl.kind,
"signature" : _get_signature(decl, module_content),
"value" : _get_value(decl, module_content),
"range" : _get_range(decl),
})
db_declarations.append(
{
"module_name": Jsonb(module_name),
"index": index,
"name": Jsonb(decl.name) if decl.kind != "example" else None,
"visible": decl.modifiers.visibility != "private" and decl.kind != "example",
"docstring": decl.modifiers.docstring,
"kind": decl.kind,
"signature": _get_signature(decl, module_content),
"value": _get_value(decl, module_content),
"range": _get_range(decl),
}
)
cursor.executemany(
"""
INSERT INTO declaration (module_name, index, name, visible, docstring, kind, signature, value)
Expand Down
1 change: 0 additions & 1 deletion engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from collections.abc import Iterable

import chromadb
Expand Down
18 changes: 12 additions & 6 deletions prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,43 @@
from jixia import LeanProject
from jixia.structs import parse_name, is_prefix_of, LeanName


def format(lean_name: LeanName, with_indent: bool = False) -> str:
    """Render a Lean name as a dot-separated string.

    Args:
        lean_name: Sequence of name components; each is passed through ``str``.
        with_indent: When true, prefix the result with one space for every
            component beyond the first.

    Returns:
        The dotted name, optionally indented.
    """
    # NOTE: the name shadows the ``format`` builtin; it is kept because it is
    # this module's public interface.
    dotted = ".".join(str(part) for part in lean_name)
    if not with_indent:
        return dotted
    return " " * (len(lean_name) - 1) + dotted


def sort(lean_names: list[LeanName]) -> list[LeanName]:
    """Return a new list of the names, ordered by their ``format`` string."""
    return sorted(lean_names, key=format)


def main(project_root: str, prefixes: str | None) -> None:
    """Print every module of a Lean project, then those matching ``prefixes``.

    Args:
        project_root: Path handed to ``LeanProject``.
        prefixes: Comma-separated module prefixes, or ``None`` to skip the
            filtered listing.
    """
    project = LeanProject(project_root)
    all_module_names: list[LeanName] = project.find_modules()

    print("____________ALL MODULES_____________")
    for module_name in sort(all_module_names):
        print(format(module_name, with_indent=True))

    if prefixes is None:
        return

    print("__________MODULES THAT MATCH YOUR PREFIX___________")
    # Accept "A, B" and a trailing comma: surrounding whitespace is removed
    # and empty pieces are skipped instead of being parsed as names.
    pieces = (piece.strip() for piece in prefixes.split(","))
    prefix_names: list[LeanName] = [parse_name(piece) for piece in pieces if piece]
    matching_names = [n for n in all_module_names if any(is_prefix_of(p, n) for p in prefix_names)]
    for module_name in sort(matching_names):
        print(format(module_name))


if __name__ == "__main__":
dotenv.load_dotenv()
path_to_lean = path_to_lean = Path(os.environ["LEAN_SYSROOT"]) / "src" / "lean"
parser = ArgumentParser(description=f"""
parser = ArgumentParser(
description=f"""
Helper command that helps you understand what files are available for indexing.
For example, you can run it with:
python -m prefix --project_root "{path_to_lean}" --prefixes Init.Grind,Init.Control.Lawful
""")
"""
)
parser.add_argument("--project_root", help="Path to the project you want to index", required=False)
parser.add_argument("--prefixes", help="Comma-separated list of module prefixes to be included in the index; e.g., Init.Grind,Init.Control.Lawful", required=False)
args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion query_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, model: str):
)
self.model = model
# Matches everything after "Hypothetical: " — the full hypothetical declaration
self.pattern = re.compile(r'Hypothetical:\s*(.*)', re.DOTALL)
self.pattern = re.compile(r"Hypothetical:\s*(.*)", re.DOTALL)

async def expand(self, user_input: str) -> str | None:
"""
Expand Down
48 changes: 25 additions & 23 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
from typing import Annotated

import dotenv
import psycopg
from fastapi import FastAPI, Body, Response, Cookie
from jixia.structs import LeanName
from psycopg import Connection
from psycopg.rows import scalar_row, class_row, dict_row
from psycopg.rows import scalar_row, class_row
from psycopg.types.json import Jsonb
from psycopg_pool import ConnectionPool
from pydantic import BaseModel
Expand All @@ -28,9 +26,9 @@
async def lifespan(app: FastAPI):
dotenv.load_dotenv()
with ConnectionPool(
os.environ["CONNECTION_STRING"],
kwargs={"autocommit": True},
check=ConnectionPool.check_connection,
os.environ["CONNECTION_STRING"],
kwargs={"autocommit": True},
check=ConnectionPool.check_connection,
) as pool:
app.expander = QueryExpander(os.environ["GEMINI_FAST_MODEL"])
app.retriever = PhyslibSearchEngine(os.environ["CHROMA_PATH"], None)
Expand Down Expand Up @@ -63,25 +61,31 @@ async def set_connection(request: Request, call_next):

@app.post("/search")
def search(
    response: Response,
    query: list[str],
    num_results: Annotated[int, Body(gt=0, le=150)] = 10,
) -> list[list[QueryResult]]:
    """Record the submitted queries and return the retriever's results.

    Every query string is inserted into ``physlibsearch.query``. When exactly
    one query is submitted, the id of the inserted row is also written to the
    ``session`` cookie on ``response``.

    Args:
        response: Response object used to set the ``session`` cookie.
        query: Query strings to look up.
        num_results: Passed through to ``find_declarations``; validated to
            be in ``1..150``.

    Returns:
        Whatever ``app.retriever.find_declarations(query, num_results)`` yields.
    """
    if len(query) == 1:
        # scalar_row makes fetchone() return the bare id rather than a tuple.
        with app.retriever.conn.cursor(row_factory=scalar_row) as cursor:
            cursor.execute(
                """
                INSERT INTO physlibsearch.query(id, query, time)
                VALUES (GEN_RANDOM_UUID(), %s, NOW())
                RETURNING id
                """,
                (query[0],),
            )
            session_id = cursor.fetchone()
            response.set_cookie("session", str(session_id))
    else:
        with app.retriever.conn.cursor() as cursor:
            cursor.executemany(
                """
                INSERT INTO physlibsearch.query(id, query, time)
                VALUES (GEN_RANDOM_UUID(), %s, NOW())
                """,
                [(q,) for q in query],
            )

    return app.retriever.find_declarations(query, num_results)

Expand Down Expand Up @@ -138,13 +142,16 @@ def list_modules(request: Request) -> list[ModuleInfo]:
def module_declarations(request: Request, module_name: LeanName) -> list[Record]:
    """Return the records of one module's visible declarations.

    Rows come from ``record`` joined to ``declaration`` by name, restricted
    to ``module_name`` and ``visible = TRUE``, ordered by declaration index.
    Each row is materialised as a ``Record`` via ``class_row``.
    """
    with app.pool.connection() as conn:
        with conn.cursor(row_factory=class_row(Record)) as cursor:
            cursor.execute(
                """
                SELECT r.*
                FROM record r
                INNER JOIN declaration d ON r.name = d.name
                WHERE r.module_name = %s AND d.visible = TRUE
                ORDER BY d.index
                """,
                (Jsonb(module_name),),
            )
            return cursor.fetchall()


Expand All @@ -153,6 +160,7 @@ def module_declarations(request: Request, module_name: LeanName) -> list[Record]
async def user_feedback(request: Request, body: UserFeedback):
if not (1 <= body.rating <= 5):
from fastapi import HTTPException

raise HTTPException(status_code=422, detail="Rating must be between 1 and 5")
with app.pool.connection() as conn:
with conn.cursor() as cursor:
Expand All @@ -171,13 +179,7 @@ async def feedback(session: Annotated[str, Cookie()], body: Feedback):
query_id = uuid.UUID(session)
if body.cancel:
with app.retriever.conn.cursor() as cursor:
cursor.execute(
"DELETE FROM physlibsearch.feedback WHERE query_id = %s AND declaration_name = %s",
(query_id, Jsonb(body.declaration))
)
cursor.execute("DELETE FROM physlibsearch.feedback WHERE query_id = %s AND declaration_name = %s", (query_id, Jsonb(body.declaration)))
else:
with app.retriever.conn.cursor() as cursor:
cursor.execute(
"INSERT INTO physlibsearch.feedback(query_id, declaration_name, action) VALUES (%s, %s, %s)",
(query_id, Jsonb(body.declaration), body.action)
)
cursor.execute("INSERT INTO physlibsearch.feedback(query_id, declaration_name, action) VALUES (%s, %s, %s)", (query_id, Jsonb(body.declaration), body.action))
Loading