Skip to content

FEAT Beam Search for OpenAIResponseTarget#1346

Open
riedgar-ms wants to merge 264 commits into
microsoft:mainfrom
riedgar-ms:riedgar-ms/beam-search-01
Open

FEAT Beam Search for OpenAIResponseTarget#1346
riedgar-ms wants to merge 264 commits into
microsoft:mainfrom
riedgar-ms:riedgar-ms/beam-search-01

Conversation

@riedgar-ms

@riedgar-ms riedgar-ms commented Feb 2, 2026

Copy link
Copy Markdown
Contributor

Description

Use the Lark grammar feature of the OpenAIResponseTarget to create a beam search for PyRIT. This is a single turn attack, where a collection of candidate responses (the beams) are maintained. On each iteration, the model's response is allowed to extend a little for each beam. The beams are scored, with the worst performing ones discarded, and replaced with copies of higher scoring beams.

Tests and Documentation

Have basic unit tests of the classes added, but since this requires features only currently in the OpenAIResponseTarget there didn't seem much point in mocking that. There is a notebook which runs everything E2E.

@riedgar-ms riedgar-ms left a comment

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is ready for preliminary review; there aren't any docs or (proper) tests yet. I'd like to make sure that I'm manipulating the database correctly before delving into those.

Comment thread beam_search_test.py Outdated
@@ -0,0 +1,72 @@
import asyncio

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file will ultimately be converted into tests and/or a notebook. For now, it's the easiest way for me to test.

**deepcopy(kwargs),
}

def fresh_instance(self) -> "OpenAIResponseTarget":

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes are required because the OpenAI API takes a grammar as a tool and PyRIT makes the tool list part of the object, not the send_prompt_async() API. Since we have multiple beams being managed asynchronously, each task needs its own copy of the OpenAIResponseTarget

logger = logging.getLogger(__name__)


def _print_message(message: Message) -> None:

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will be deleted; this is for my debugging convenience.

return new_beams


class BeamSearchAttack(SingleTurnAttackStrategy):

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is largely copied from the PromptSendingAttack

Comment thread pyrit/executor/attack/single_turn/beam_search.py Outdated
target = self._get_target_for_beam(beam)

current_context = copy.deepcopy(self._start_context)
await self._setup_async(context=current_context)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not certain I'm handling the context correctly here. I end up making lots of copies of things, which is going to be filling up the database with fragmentary responses. Each time one is extended, it ends up being cloned and a new conversation started.

objective=context.objective,
)

aux_scores = scoring_results["auxiliary_scores"]

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the auxilliary scorer is required; it is used to assess the beams as they develop

Comment thread pyrit/executor/attack/single_turn/beam_search.py Outdated

new_beams = list(reversed(sorted_beams[: self.k]))
for i in range(len(beams) - len(new_beams)):
nxt = copy.deepcopy(new_beams[i % self.k])

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't just duplicate the highest scoring beam, in the hope of maintaining some variety

Args:
context (SingleTurnAttackContext): The attack context containing attack parameters.
"""
self._start_context = copy.deepcopy(context)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See note below. I duplicate the context and the message for each beam on each iteration. I'm not certain that this is the best way to use the database.

**kwargs_copy,
}

def fresh_instance(

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am very much aware that this is less than ideal. However, since tools are stored on the class rather than supplied when a completion is wanted, and we want to run multiple different calls in parallel, I think this is the smallest possible change.

It might be better to have the tools be an extra argument to send_prompt_async() but that would have quite a blast radius.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants