Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nemo/collections/common/prompts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from nemo.collections.common.prompts.canary import CanaryPromptFormatter
from nemo.collections.common.prompts.canary2 import Canary2PromptFormatter
from nemo.collections.common.prompts.formatter import PromptFormatter
from nemo.collections.common.prompts.gemma import GemmaPromptFormatter
from nemo.collections.common.prompts.gemma import GemmaPromptFormatter, Gemma4PromptFormatter
from nemo.collections.common.prompts.llama import Llama2PromptFormatter, Llama3PromptFormatter
from nemo.collections.common.prompts.mistral import MistralPromptFormatter
from nemo.collections.common.prompts.nemotron_h import NemotronHPromptFormatter
Expand Down
67 changes: 62 additions & 5 deletions nemo/collections/common/prompts/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Implemented following the guide at https://www.promptingguide.ai/models/gemma#gemma-7b-prompt-format
"""
Gemma1 prompt format reference:
https://www.promptingguide.ai/models/gemma#gemma-7b-prompt-format

Gemma4 prompt format reference (multimodal: text + image + audio):
<|turn>user
Describe this image: <|image|>
And translate this audio: <|audio|><turn|>
<|turn>model
"""
from lhotse.cut import Cut, MixedCut

from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn
Expand Down Expand Up @@ -58,9 +63,61 @@ def gemma1(cut: Cut, prompt: GemmaPromptFormatter):
context = cut.question
else:
context = cut.default_context

turns = [{"role": "user", "slots": {"message": context}}]
if (answer := cut.supervisions[0].text) is not None:
turns.append({"role": "assistant", "slots": {"message": answer}})

return prompt.encode_dialog(turns)


# Literal marker strings for the Gemma4 chat layout. The turn markers are
# interpolated into the prompt templates below; the image/audio markers are
# appended to the user message, one per attached media item.
GEMMA4_BOT = "<|turn>" # beginning-of-turn
GEMMA4_EOT = "<turn|>" # end-of-turn
GEMMA4_IMAGE = "<|image|>" # image placeholder token
GEMMA4_AUDIO = "<|audio|>" # audio placeholder token


class Gemma4PromptFormatter(PromptFormatter):
    """
    Prompt formatter for the Gemma4 turn-based chat layout.

    A user turn is rendered as::

        <|turn>user
        {message}<turn|>
        <|turn>model

    i.e. the user template already opens the model turn, so the assistant
    template only has to emit the reply followed by the end-of-turn marker
    and a newline.
    """

    NAME = "gemma4"
    OUTPUT_ROLE = "assistant"
    # NOTE(review): BOS/EOS insertion is implemented by the PromptFormatter
    # base class, which is not visible here -- confirm exact semantics there.
    INSERT_BOS = True
    INSERT_EOS = True
    TEMPLATE = {
        # User turn: wraps the message in turn markers and opens the model turn.
        "user": {
            "template": f"{GEMMA4_BOT}user\n|message|{GEMMA4_EOT}\n{GEMMA4_BOT}model\n",
            "slots": {"message": Modality.Text},
        },
        # Model reply: the turn was opened by the user template above.
        OUTPUT_ROLE: {
            "template": f"|message|{GEMMA4_EOT}\n",
            "slots": {"message": Modality.Text},
        },
    }


@registered_prompt_format_fn(Cut, Gemma4PromptFormatter)
def gemma4(cut: Cut, prompt: Gemma4PromptFormatter):
    """
    Build and encode a Gemma4 dialog for a single cut.

    The user message is the text context (if non-empty) followed by one
    placeholder per attached media item, all joined with newlines:
    an image placeholder when the cut carries a non-None custom ``image``,
    an audio placeholder when the cut has a recording or a custom
    ``audio_filepath``, and one more audio placeholder for every entry in a
    non-empty custom ``extra_audios``.

    Args:
        cut: The cut to format. For a ``MixedCut`` its first non-padding
            cut is used instead.
        prompt: The formatter used to encode the resulting turns.

    Returns:
        Whatever ``prompt.encode_dialog`` returns for the built turns.
    """
    if isinstance(cut, MixedCut):
        cut = cut.first_non_padding_cut

    # Text context precedence: custom "context", then custom "question",
    # then the cut's default context.
    if cut.has_custom("context"):
        instruction = cut.context
    elif cut.has_custom("question"):
        instruction = cut.question
    else:
        instruction = cut.default_context

    # An empty/None context contributes nothing to the user message.
    segments = [instruction] if instruction else []

    if cut.has_custom("image") and cut.image is not None:
        segments.append(GEMMA4_IMAGE)
    if getattr(cut, "has_recording", False) or cut.has_custom("audio_filepath"):
        segments.append(GEMMA4_AUDIO)
    if cut.has_custom("extra_audios") and cut.extra_audios:
        # One audio placeholder per additional audio entry.
        segments.extend(GEMMA4_AUDIO for _ in cut.extra_audios)

    dialog = [{"role": "user", "slots": {"message": "\n".join(segments)}}]

    # Only add the assistant turn when the first supervision carries text.
    reply = cut.supervisions[0].text
    if reply is not None:
        dialog.append({"role": "assistant", "slots": {"message": reply}})

    return prompt.encode_dialog(dialog)
Loading
Loading