Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions steps/src/toxicity_guardrail/toxicity_guardrail.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,9 @@
#

from typing import Any, Dict

from transformers import pipeline

class ToxicityGuardrailStep:
"""
A serving graph step that filters out toxic requests using a pre-trained
text classification model.

If the toxicity score of the input text meets or exceeds the threshold,
the request is blocked with a ValueError. Safe requests are passed through
unchanged.

The classifier label "toxic" maps directly to the toxicity score; any
other label (e.g. "non-toxic") inverts the model's confidence score.
"""

def __init__(
self,
context=None,
Expand All @@ -37,13 +25,23 @@ def __init__(
model_name: str = "unitary/toxic-bert",
**kwargs,
):
"""
A serving graph step that filters out toxic requests using a pre-trained
text classification model.

:param context: MLRun context object, injected automatically by the serving graph.
:param name: Name of this step in the serving graph.
:param threshold: Toxicity score threshold; requests whose toxicity score meets or
exceeds this value are blocked with a ValueError. Defaults to 0.5.
:param model_name: HuggingFace model identifier used for text classification.
Defaults to "unitary/toxic-bert".
:param kwargs: Additional keyword arguments forwarded to the serving graph step base.
"""
self.threshold = threshold
self.model_name = model_name
self._classifier = None

def post_init(self, mode="sync", **kwargs):
    """
    Build the text-classification pipeline and store it on the step.

    The ``transformers`` import is performed here, at call time, instead of
    at module import time.

    :param mode: Not used by this method.
    :param kwargs: Not used by this method.
    """
    # NOTE(review): presumably invoked by the serving graph after the step is
    # constructed -- confirm against the MLRun step lifecycle.
    from transformers import pipeline as build_pipeline

    task = "text-classification"
    self._classifier = build_pipeline(task, model=self.model_name)

def do(self, event: Dict[str, Any]) -> Dict[str, Any]:
Expand Down
Loading