Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/avatar_agents/tavus/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ This example demonstrates how to create a animated avatar using [Tavus](https://
```bash
# Tavus Config
export TAVUS_API_KEY="..."
export TAVUS_REPLICA_ID="..."
export TAVUS_FACE_ID="..."

# OpenAI config (or other models, tts, stt)
export OPENAI_API_KEY="..."
Expand Down
6 changes: 3 additions & 3 deletions examples/avatar_agents/tavus/agent_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ async def entrypoint(ctx: JobContext):
resume_false_interruption=False,
)

persona_id = os.getenv("TAVUS_PERSONA_ID")
replica_id = os.getenv("TAVUS_REPLICA_ID")
tavus_avatar = tavus.AvatarSession(persona_id=persona_id, replica_id=replica_id)
pal_id = os.getenv("TAVUS_PAL_ID")
face_id = os.getenv("TAVUS_FACE_ID")
tavus_avatar = tavus.AvatarSession(pal_id=pal_id, face_id=face_id)
await tavus_avatar.start(session, room=ctx.room)

# start the agent, it will join the room and wait for the avatar to join
Expand Down
105 changes: 92 additions & 13 deletions livekit-plugins/livekit-plugins-tavus/livekit/plugins/tavus/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import os
import warnings
from typing import Any

import aiohttp
Expand All @@ -22,6 +23,37 @@ class TavusException(Exception):


DEFAULT_API_URL = "https://tavusapi.com/v2"
# Stock Tavus PAL. Use create_pal() to create a PAL with the appearance you'd like.
DEFAULT_PAL_ID = "pb87e71797da"


def _coalesce_with_deprecated(
new_value: NotGivenOr[str],
deprecated_value: NotGivenOr[str],
*,
deprecated_name: str,
new_name: str,
) -> NotGivenOr[str]:
# Prefer the new arg; fall back to the deprecated alias and warn only when it's used.
if deprecated_value and not new_value:
warnings.warn(
f"`{deprecated_name}` is deprecated, use `{new_name}` instead",
DeprecationWarning,
stacklevel=3,
)
return new_value or deprecated_value


def _deprecated_env(deprecated_name: str, new_name: str) -> str | None:
# Read a deprecated env var, warning if it's set so callers migrate to `new_name`.
value = os.getenv(deprecated_name)
if value:
warnings.warn(
f"`{deprecated_name}` is deprecated, use `{new_name}` instead",
DeprecationWarning,
stacklevel=3,
)
return value


class TavusAPI:
Expand All @@ -45,26 +77,43 @@ def __init__(
async def create_conversation(
self,
*,
face_id: NotGivenOr[str] = NOT_GIVEN,
pal_id: NotGivenOr[str] = NOT_GIVEN,
replica_id: NotGivenOr[str] = NOT_GIVEN,
persona_id: NotGivenOr[str] = NOT_GIVEN,
properties: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
extra_payload: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
) -> str:
replica_id = replica_id or (os.getenv("TAVUS_REPLICA_ID") or NOT_GIVEN)
if not replica_id:
raise TavusException("TAVUS_REPLICA_ID must be set")

persona_id = persona_id or (os.getenv("TAVUS_PERSONA_ID") or NOT_GIVEN)
if not persona_id:
# create a persona if not provided
persona_id = await self.create_persona()
# `replica_id`/`persona_id` are deprecated aliases for `face_id`/`pal_id`.
face_id = _coalesce_with_deprecated(
face_id, replica_id, deprecated_name="replica_id", new_name="face_id"
)
pal_id = _coalesce_with_deprecated(
pal_id, persona_id, deprecated_name="persona_id", new_name="pal_id"
)

face_id = (
face_id
or os.getenv("TAVUS_FACE_ID")
or _deprecated_env("TAVUS_REPLICA_ID", "TAVUS_FACE_ID")
or NOT_GIVEN
)
pal_id = (
pal_id
or os.getenv("TAVUS_PAL_ID")
or _deprecated_env("TAVUS_PERSONA_ID", "TAVUS_PAL_ID")
or NOT_GIVEN
)

if not pal_id:
# no pal supplied — use the default stock pal (carries its own face)
pal_id = DEFAULT_PAL_ID

properties = properties or {}
payload = {
"replica_id": replica_id,
"persona_id": persona_id,
"properties": properties,
}
payload: dict[str, Any] = {"pal_id": pal_id, "properties": properties}
# send face_id only when given; otherwise the pal's default_face_id is used
if face_id:
payload["face_id"] = face_id
if utils.is_given(extra_payload):
payload.update(extra_payload)

Expand All @@ -74,12 +123,42 @@ async def create_conversation(
response_data = await self._post("conversations", payload)
return response_data["conversation_id"] # type: ignore

async def create_pal(
self,
name: NotGivenOr[str] = NOT_GIVEN,
*,
default_face_id: str,
extra_payload: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
) -> str:
name = name or utils.shortuuid("lk_pal_")

payload = {
"pal_name": name,
"default_face_id": default_face_id,
"pipeline_mode": "echo",
"layers": {
"transport": {"transport_type": "livekit"},
},
}

if utils.is_given(extra_payload):
payload.update(extra_payload)

response_data = await self._post("pals", payload)
return response_data["pal_id"] # type: ignore

async def create_persona(
self,
name: NotGivenOr[str] = NOT_GIVEN,
*,
extra_payload: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
) -> str:
# Deprecated: use create_pal(). Kept on the legacy /v2/personas endpoint.
warnings.warn(
"`create_persona` is deprecated, use `create_pal` instead",
DeprecationWarning,
stacklevel=2,
)
name = name or utils.shortuuid("lk_persona_")

payload = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from livekit.agents.voice.avatar import AvatarSession as BaseAvatarSession, DataStreamAudioOutput
from livekit.agents.voice.room_io import ATTRIBUTE_PUBLISH_ON_BEHALF

from .api import TavusAPI, TavusException
from .api import TavusAPI, TavusException, _coalesce_with_deprecated
from .log import logger

SAMPLE_RATE = 24000
Expand All @@ -31,6 +31,8 @@ class AvatarSession(BaseAvatarSession):
def __init__(
self,
*,
face_id: NotGivenOr[str] = NOT_GIVEN,
pal_id: NotGivenOr[str] = NOT_GIVEN,
replica_id: NotGivenOr[str] = NOT_GIVEN,
persona_id: NotGivenOr[str] = NOT_GIVEN,
api_url: NotGivenOr[str] = NOT_GIVEN,
Expand All @@ -43,8 +45,13 @@ def __init__(
self._http_session: aiohttp.ClientSession | None = None
self._conn_options = conn_options
self.conversation_id: str | None = None
self._persona_id = persona_id
self._replica_id = replica_id
# `replica_id`/`persona_id` are deprecated aliases for `face_id`/`pal_id`.
self._pal_id = _coalesce_with_deprecated(
pal_id, persona_id, deprecated_name="persona_id", new_name="pal_id"
)
self._face_id = _coalesce_with_deprecated(
face_id, replica_id, deprecated_name="replica_id", new_name="face_id"
)
self._api = TavusAPI(
api_url=api_url,
api_key=api_key,
Expand Down Expand Up @@ -104,8 +111,8 @@ async def start(

logger.debug("starting avatar session")
self.conversation_id = await self._api.create_conversation(
persona_id=self._persona_id,
replica_id=self._replica_id,
pal_id=self._pal_id,
face_id=self._face_id,
properties={"livekit_ws_url": livekit_url, "livekit_room_token": livekit_token},
)

Expand Down
140 changes: 140 additions & 0 deletions livekit-plugins/livekit-plugins-tavus/tests/test_tavus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import warnings
Comment thread
tinalenguyen marked this conversation as resolved.
from unittest.mock import AsyncMock, patch

import pytest

from livekit.agents.utils import http_context
from livekit.plugins.tavus.api import DEFAULT_PAL_ID, TavusAPI
from livekit.plugins.tavus.avatar import AvatarSession

pytestmark = pytest.mark.unit


@pytest.fixture(autouse=True)
def _env(monkeypatch):
for v in ("TAVUS_FACE_ID", "TAVUS_PAL_ID", "TAVUS_REPLICA_ID", "TAVUS_PERSONA_ID"):
monkeypatch.delenv(v, raising=False)
monkeypatch.setenv("TAVUS_API_KEY", "test-key")


def _api() -> TavusAPI:
# session is unused because _post is always mocked in these tests
return TavusAPI(session=object()) # type: ignore[arg-type]


def _mock_post() -> AsyncMock:
return AsyncMock(return_value={"conversation_id": "conv1", "persona_id": "pal_auto"})


def _no_deprecation(rec: list[warnings.WarningMessage]) -> bool:
return not [w for w in rec if issubclass(w.category, DeprecationWarning)]


async def test_new_args_map_to_unchanged_wire_keys():
api = _api()
with patch.object(api, "_post", new=_mock_post()) as m:
with warnings.catch_warnings(record=True) as rec:
warnings.simplefilter("always")
cid = await api.create_conversation(face_id="f1", pal_id="p1")
assert cid == "conv1"
payload = m.call_args.args[1]
assert payload["face_id"] == "f1"
assert payload["pal_id"] == "p1"
assert _no_deprecation(rec)


async def test_deprecated_args_still_work_and_warn():
api = _api()
with patch.object(api, "_post", new=_mock_post()) as m:
with pytest.warns(DeprecationWarning) as rec:
await api.create_conversation(replica_id="r1", persona_id="x1")
payload = m.call_args.args[1]
assert payload["face_id"] == "r1"
assert payload["pal_id"] == "x1"
msgs = [str(w.message) for w in rec]
assert any("replica_id" in s and "face_id" in s for s in msgs)
assert any("persona_id" in s and "pal_id" in s for s in msgs)


async def test_no_warning_when_new_and_deprecated_both_given():
api = _api()
with patch.object(api, "_post", new=_mock_post()) as m:
with warnings.catch_warnings(record=True) as rec:
warnings.simplefilter("always")
await api.create_conversation(
face_id="f1", replica_id="r1", pal_id="p1", persona_id="x1"
)
# the new values win, so the deprecated aliases are unused -> no warning
assert _no_deprecation(rec)
payload = m.call_args.args[1]
assert payload["face_id"] == "f1"
assert payload["pal_id"] == "p1"


async def test_new_env_vars_fallback(monkeypatch):
monkeypatch.setenv("TAVUS_FACE_ID", "envf")
monkeypatch.setenv("TAVUS_PAL_ID", "envp")
api = _api()
with patch.object(api, "_post", new=_mock_post()) as m:
with warnings.catch_warnings(record=True) as rec:
warnings.simplefilter("always")
await api.create_conversation()
payload = m.call_args.args[1]
assert payload["face_id"] == "envf"
assert payload["pal_id"] == "envp"
assert _no_deprecation(rec)


async def test_deprecated_env_vars_still_work_and_warn(monkeypatch):
monkeypatch.setenv("TAVUS_REPLICA_ID", "oldf")
monkeypatch.setenv("TAVUS_PERSONA_ID", "oldp")
api = _api()
with patch.object(api, "_post", new=_mock_post()) as m:
with pytest.warns(DeprecationWarning):
await api.create_conversation()
payload = m.call_args.args[1]
assert payload["face_id"] == "oldf"
assert payload["pal_id"] == "oldp"


async def test_no_pal_uses_default_pal_with_face_override():
api = _api()
with patch.object(api, "_post", new=_mock_post()) as m:
await api.create_conversation(face_id="f1")
assert "pals" not in [c.args[0] for c in m.call_args_list] # no pal is created
payload = m.call_args.args[1]
assert payload["pal_id"] == DEFAULT_PAL_ID
assert payload["face_id"] == "f1"


async def test_pal_id_only_skips_pal_creation_and_omits_face():
api = _api()
with patch.object(api, "_post", new=_mock_post()) as m:
await api.create_conversation(pal_id="p1")
endpoints = [c.args[0] for c in m.call_args_list]
assert "pals" not in endpoints # an existing pal carries its own default face
payload = m.call_args.args[1]
assert payload["pal_id"] == "p1"
assert "face_id" not in payload


async def test_defaults_to_stock_pal_when_neither_given():
api = _api()
with patch.object(api, "_post", new=_mock_post()) as m:
await api.create_conversation()
assert "pals" not in [c.args[0] for c in m.call_args_list] # no pal is created
payload = m.call_args.args[1]
assert payload["pal_id"] == DEFAULT_PAL_ID
assert "face_id" not in payload


async def test_avatar_session_resolves_new_and_deprecated_args():
async with http_context.open():
with pytest.warns(DeprecationWarning):
deprecated = AvatarSession(replica_id="r9", persona_id="x9")
assert deprecated._face_id == "r9"
assert deprecated._pal_id == "x9"

renamed = AvatarSession(face_id="f9", pal_id="p9")
assert renamed._face_id == "f9"
assert renamed._pal_id == "p9"
Loading