11 changes: 10 additions & 1 deletion CMakeLists.txt
@@ -103,6 +103,7 @@ include(${PROJECT_SOURCE_DIR}/tools/cmake/Codegen.cmake)
include(${PROJECT_SOURCE_DIR}/tools/cmake/Utils.cmake)
include(CMakeDependentOption)
include(ExternalProject)
include(FetchContent)
include(GNUInstallDirs)

if(NOT CMAKE_CXX_STANDARD)
@@ -406,6 +407,14 @@ set(_common_include_directories
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/executorch/runtime/core/portable_type/c10>
)

if(TARGET jinja2cpp)
install(
TARGETS jinja2cpp
EXPORT ExecuTorchTargets
DESTINATION ${CMAKE_INSTALL_LIBDIR}
)
endif()

#
# The `_<target>_srcs` lists are defined by executorch_load_build_variables.
#
@@ -803,7 +812,7 @@ endif()

if(EXECUTORCH_BUILD_EXTENSION_LLM)
if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
set(SUPPORT_REGEX_LOOKAHEAD ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/chat_template)
# llama/runner/CMakeLists.txt builds a shared library libllama_runner.so
# that transitively depends on tokenizers. Need to build tokenizers with
# -fPIC.
18 changes: 18 additions & 0 deletions examples/models/llama/runner/eager.py
@@ -82,6 +82,20 @@ def build_args_parser() -> argparse.ArgumentParser:
help="Have multi-turn chat with the model",
)

parser.add_argument(
"--system_prompt",
type=str,
default="",
help="System prompt for chat formatting (optional).",
)

parser.add_argument(
"--chat_template_file",
type=str,
default="",
help="Path to a custom Jinja2 chat template file for chat mode.",
)

parser.add_argument(
"--tokenizer_config_path",
type=str,
@@ -104,6 +118,8 @@ def execute_runner(runner_class: Type[LlamaRunner]) -> None:
temperature = args.temperature
show_tokens = args.show_tokens
chat_mode = args.chat
system_prompt = args.system_prompt
chat_template_file = args.chat_template_file
tokenizer_config_path = args.tokenizer_config_path
use_attention_sink = args.use_attention_sink

@@ -113,6 +129,8 @@ def execute_runner(runner_class: Type[LlamaRunner]) -> None:
llm_config=llm_config,
tokenizer_config_path=tokenizer_config_path,
use_attention_sink=use_attention_sink,
system_prompt=system_prompt,
chat_template_file=chat_template_file or None,
)

generated_tokens = (
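For context on the two new flags, here is a minimal sketch of how they behave (flag names are taken from this diff; the prompt value is made up for demonstration). Note how the `or None` normalization above turns the empty-string default into `None`, so the runner treats "no path given" as "no custom template" rather than as an empty file path:

```python
# Illustrative only: mirrors the argparse wiring added in this diff.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--system_prompt", type=str, default="")
parser.add_argument("--chat_template_file", type=str, default="")

# Hypothetical invocation: system prompt set, no custom template file.
args = parser.parse_args(["--system_prompt", "You are a helpful assistant."])

# `chat_template_file or None` collapses the "" default to None, so the
# runner only tries to load a template when a path was actually supplied.
chat_template_file = args.chat_template_file or None
assert args.system_prompt == "You are a helpful assistant."
assert chat_template_file is None
```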
63 changes: 55 additions & 8 deletions examples/models/llama/runner/generation.py
@@ -9,7 +9,6 @@
from typing import List, Optional

import torch

from pytorch_tokenizers import get_tokenizer


@@ -56,6 +55,9 @@ def __init__(
max_batch_size: int,
use_kv_cache: bool,
vocab_size: int,
chat_format: str = "none",
system_prompt: str = "",
chat_template_file: Optional[str] = None,
device: str = "cpu",
):
"""
@@ -74,6 +76,10 @@ def __init__(
self.use_kv_cache = use_kv_cache
self.tokenizer = get_tokenizer(tokenizer_path, tokenizer_config_path)
self.device = device
self.chat_format = chat_format
self.system_prompt = system_prompt
self.chat_template_file = chat_template_file
self._chat_formatter = None
# For some models like qwen, mismatch is acceptable: https://github.com/QwenLM/Qwen2.5/issues/466#issuecomment-2146759706
if vocab_size != self.tokenizer.n_words:
print(
@@ -207,9 +213,14 @@ def chat_completion(
prompt = input("Me: ")
while prompt and prompt != exit_prompt:
print("LLM: ", end="", flush=True)
prompt_tokens = self.tokenizer.encode(
self._format_prompt(prompt), bos=True, eos=False
formatter = self._get_chat_formatter()
formatted_prompt = (
formatter.format(prompt, self.system_prompt)
if formatter is not None
else prompt
)
bos = not (formatter is not None and formatter.includes_bos())
prompt_tokens = self.tokenizer.encode(formatted_prompt, bos=bos, eos=False)
generated_tokens = self.generate(
prompt_tokens=pre_stop_token + prompt_tokens,
max_seq_len=max_seq_len,
@@ -227,8 +238,44 @@ def chat_completion(
return tokens

def _format_prompt(self, prompt: str) -> str:
return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
formatter = self._get_chat_formatter()
if formatter is None:
return prompt
return formatter.format(prompt, self.system_prompt)

def _resolve_template_type(self, chat_template_type) -> Optional["object"]:
normalized = (self.chat_format or "").lower()
if normalized in ("", "none"):
return None
if normalized == "llama3":
return chat_template_type.Llama3
if normalized in ("llama3.2", "llama32", "llama3_2"):
return chat_template_type.Llama32
if normalized == "gemma3":
return chat_template_type.Gemma3
return None

def _get_chat_formatter(self):
if self._chat_formatter is not None:
return self._chat_formatter
try:
from executorch.extension.llm.runner import (
ChatTemplateType,
JinjaChatFormatter,
)
except ImportError as exc:
raise RuntimeError(
"Jinja chat templates require ExecuTorch pybindings. "
"Build with EXECUTORCH_BUILD_PYBIND=ON."
) from exc

if self.chat_template_file:
self._chat_formatter = JinjaChatFormatter.from_file(self.chat_template_file)
return self._chat_formatter

template_type = self._resolve_template_type(ChatTemplateType)
if template_type is None:
return None

self._chat_formatter = JinjaChatFormatter.from_template(template_type)
return self._chat_formatter
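Since the `JinjaChatFormatter` pybinding only exists when ExecuTorch is built with `EXECUTORCH_BUILD_PYBIND=ON`, here is a minimal pure-Python sketch of the same flow using the `jinja2` package. The class name `SketchChatFormatter` and its methods are illustrative stand-ins for the real binding, not its API; the point is to show why `chat_completion` encodes with `bos=False` when the formatter's template already emits the BOS token itself:

```python
# Pure-Python stand-in for the pybound JinjaChatFormatter (hypothetical
# names; requires `pip install jinja2`). The template string is the Llama3
# template embedded in chat_templates.h.
import jinja2

LLAMA3_TEMPLATE = (
    "{{ bos_token }}{%- for message in messages -%}"
    "<|start_header_id|>{{ message.role }}<|end_header_id|>\n\n"
    "{{ message.content }}<|eot_id|>{%- endfor -%}"
    "{%- if add_generation_prompt -%}"
    "<|start_header_id|>assistant<|end_header_id|>\n\n{%- endif -%}"
)

class SketchChatFormatter:
    def __init__(self, source: str, bos_token: str):
        self._source = source
        self._template = jinja2.Template(source)
        self._bos_token = bos_token

    def includes_bos(self) -> bool:
        # The embedded templates render {{ bos_token }} themselves.
        return "bos_token" in self._source

    def format(self, prompt: str, system_prompt: str = "") -> str:
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return self._template.render(
            bos_token=self._bos_token,
            messages=messages,
            add_generation_prompt=True,
        )

formatter = SketchChatFormatter(LLAMA3_TEMPLATE, "<|begin_of_text|>")
text = formatter.format("Hi!", "You are a helpful assistant")
# The template already emitted <|begin_of_text|>, so the tokenizer must
# NOT prepend another BOS -- this is the `bos = not (...)` logic above.
bos = not formatter.includes_bos()
assert bos is False
assert text.startswith("<|begin_of_text|>")
```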
18 changes: 18 additions & 0 deletions extension/llm/chat_template/BUCK
@@ -0,0 +1,18 @@
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
oncall("executorch")

# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

non_fbcode_target(_kind = define_common_targets,)

# !!!! fbcode/executorch/extension/llm/chat_template/TARGETS was merged into this file, see https://fburl.com/workplace/xl8l9yuo for more info !!!!

# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

fbcode_target(_kind = define_common_targets,)
116 changes: 116 additions & 0 deletions extension/llm/chat_template/CMakeLists.txt
@@ -0,0 +1,116 @@
if(NOT EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
return()
endif()

include(FetchContent)
cmake_policy(SET CMP0077 NEW)

FetchContent_Declare(
jinja2cpp
GIT_REPOSITORY https://github.com/jinja2cpp/Jinja2Cpp.git
GIT_TAG 1.3.2
GIT_SUBMODULES_RECURSE TRUE
)

set(JINJA2CPP_BUILD_TESTS
OFF
CACHE BOOL ""
FORCE
)
set(JINJA2CPP_BUILD_SHARED
OFF
CACHE BOOL ""
FORCE
)
set(JINJA2CPP_INSTALL
OFF
CACHE BOOL ""
FORCE
)
# Enable PCRE2-based regex lookahead support in Jinja2Cpp. This must be set
# BEFORE FetchContent_MakeAvailable(jinja2cpp) so it propagates to the
# Jinja2Cpp configure step.
set(SUPPORT_REGEX_LOOKAHEAD
ON
CACHE BOOL ""
FORCE
)

FetchContent_MakeAvailable(jinja2cpp)
if(NOT TARGET jinja2cpp)
message(FATAL_ERROR "Jinja2Cpp target not found after FetchContent.")
endif()

if(DEFINED jinja2cpp_SOURCE_DIR)
function(executorch_copy_nonstd_header dep_name target header_name dest_root)
set(_copied FALSE)
if(TARGET ${target})
get_target_property(_aliased ${target} ALIASED_TARGET)
if(_aliased)
set(_resolved_target ${_aliased})
else()
set(_resolved_target ${target})
endif()
get_target_property(
_include_dirs ${_resolved_target} INTERFACE_INCLUDE_DIRECTORIES
)
foreach(_dir IN LISTS _include_dirs)
if(EXISTS "${_dir}/nonstd/${header_name}")
file(MAKE_DIRECTORY "${dest_root}/nonstd")
file(
COPY "${_dir}/nonstd/${header_name}"
DESTINATION "${dest_root}/nonstd"
)
set(_copied TRUE)
break()
endif()
endforeach()
endif()
if(NOT _copied)
set(
_fallback_path
"${CMAKE_BINARY_DIR}/_deps/${dep_name}-src/include/nonstd/${header_name}"
)
if(EXISTS "${_fallback_path}")
file(MAKE_DIRECTORY "${dest_root}/nonstd")
file(COPY "${_fallback_path}" DESTINATION "${dest_root}/nonstd")
endif()
endif()
endfunction()

set(_jinja2cpp_nonstd_root
"${jinja2cpp_SOURCE_DIR}/thirdparty/nonstd"
)
executorch_copy_nonstd_header(
expected-lite
nonstd::expected-lite
expected.hpp
"${_jinja2cpp_nonstd_root}/expected-lite/include"
)
executorch_copy_nonstd_header(
variant-lite
nonstd::variant-lite
variant.hpp
"${_jinja2cpp_nonstd_root}/variant-lite/include"
)
executorch_copy_nonstd_header(
optional-lite
nonstd::optional-lite
optional.hpp
"${_jinja2cpp_nonstd_root}/optional-lite/include"
)
executorch_copy_nonstd_header(
string-view-lite
nonstd::string-view-lite
string_view.hpp
"${_jinja2cpp_nonstd_root}/string-view-lite/include"
)
endif()

# Install the chat_templates.h header so that downstream consumers of the
# installed ExecuTorch SDK can include
# <executorch/extension/llm/chat_template/chat_templates.h>.
install(
FILES chat_templates.h
DESTINATION include/executorch/extension/llm/chat_template
)
51 changes: 51 additions & 0 deletions extension/llm/chat_template/chat_templates.h
@@ -0,0 +1,51 @@
#pragma once

#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

namespace executorch::extension::llm {

enum class ChatTemplateType {
None,
Llama3,
Llama32,
Gemma3,
Custom,
};

constexpr std::string_view kLlama3Template = R"({{ bos_token }}{%- for message in messages -%}<|start_header_id|>{{ message.role }}<|end_header_id|>

{{ message.content }}<|eot_id|>{%- endfor -%}{%- if add_generation_prompt -%}<|start_header_id|>assistant<|end_header_id|>

{%- endif -%})";

constexpr std::string_view kGemma3Template = R"({{ bos_token }}{%- for message in messages -%}{%- if message.role == 'assistant' -%}<start_of_turn>model
{%- else -%}<start_of_turn>{{ message.role }}
{%- endif -%}{{ message.content }}<end_of_turn>{%- endfor -%}{%- if add_generation_prompt -%}<start_of_turn>model
{%- endif -%})";

inline const std::unordered_map<ChatTemplateType, std::string_view>
kEmbeddedTemplates = {
{ChatTemplateType::Llama3, kLlama3Template},
{ChatTemplateType::Llama32, kLlama3Template},
{ChatTemplateType::Gemma3, kGemma3Template},
};

struct ModelTokens {
std::string bos_token;
std::string eos_token;
std::vector<std::string> stop_tokens;
};

inline const std::unordered_map<ChatTemplateType, ModelTokens> kModelTokens = {
{ChatTemplateType::Llama3,
{"<|begin_of_text|>", "<|eot_id|>", {"<|eot_id|>", "<|end_of_text|>"}}},
{ChatTemplateType::Llama32,
{"<|begin_of_text|>", "<|eot_id|>", {"<|eot_id|>", "<|end_of_text|>"}}},
{ChatTemplateType::Gemma3,
{"<bos>", "<end_of_turn>", {"<end_of_turn>", "<eos>"}}},
};

} // namespace executorch::extension::llm
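As a quick illustration of the Gemma3 template's role mapping, rendering `kGemma3Template` with Python's `jinja2` (used here only as a stand-in for Jinja2Cpp, which executes the same template syntax) shows that assistant messages are emitted under a `model` turn header, matching Gemma's chat convention:

```python
# Render a copy of kGemma3Template with Python jinja2 (stand-in for
# Jinja2Cpp). bos_token matches the Gemma3 entry in kModelTokens.
import jinja2

GEMMA3_TEMPLATE = (
    "{{ bos_token }}{%- for message in messages -%}"
    "{%- if message.role == 'assistant' -%}<start_of_turn>model\n"
    "{%- else -%}<start_of_turn>{{ message.role }}\n"
    "{%- endif -%}{{ message.content }}<end_of_turn>{%- endfor -%}"
    "{%- if add_generation_prompt -%}<start_of_turn>model\n{%- endif -%}"
)

text = jinja2.Template(GEMMA3_TEMPLATE).render(
    bos_token="<bos>",
    messages=[
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ],
    add_generation_prompt=True,
)
assert text.startswith("<bos><start_of_turn>user")
assert "<start_of_turn>model" in text  # assistant turns map to "model"
```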
16 changes: 16 additions & 0 deletions extension/llm/chat_template/targets.bzl
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
runtime.cxx_library(
name = "chat_templates",
exported_headers = [
"chat_templates.h",
],
visibility = ["PUBLIC"],
)