Skip to content

Commit d520316

Browse files
committed
fix: address mkozakov review — remove stale model names and revert AWS changes
- README: remove specific model names from Supported APIs and Model Availability sections (will go out of date quickly per review feedback) - Revert all AWS client changes (aws_client.py, cohere_aws/client.py, test_aws_client_unit.py) — this OCI PR should not touch AWS code
1 parent 76d357b commit d520316

4 files changed

Lines changed: 289 additions & 11 deletions

File tree

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,16 +140,16 @@ co = cohere.OciClient(
140140
### Supported OCI APIs
141141

142142
The OCI client supports the following Cohere APIs:
143-
- **Embed**: Full support for all embedding models (embed-english-v3.0, embed-light-v3.0, embed-multilingual-v3.0)
143+
- **Embed**: Full support for all embedding models
144144
- **Chat**: Full support with both V1 (`OciClient`) and V2 (`OciClientV2`) APIs
145145
- Streaming available via `chat_stream()`
146146
- Supports Command-R and Command-A model families
147147

148148
### OCI Model Availability and Limitations
149149

150150
**Available on OCI On-Demand Inference:**
151-
-**Embed models**: embed-english-v3.0, embed-light-v3.0, embed-multilingual-v3.0
152-
-**Chat models**: command-r-08-2024, command-r-plus, command-a-03-2025
151+
-**Embed models**: available on OCI Generative AI
152+
-**Chat models**: available via `OciClient` (V1) and `OciClientV2` (V2)
153153

154154
**Not Available on OCI On-Demand Inference:**
155155
-**Generate API**: OCI TEXT_GENERATION models are base models that require fine-tuning before deployment

src/cohere/aws_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ def _event_hook(request: httpx.Request) -> None:
239239
)
240240
request.url = URL(url)
241241
request.headers["host"] = request.url.host
242+
headers["host"] = request.url.host
242243

243244
if endpoint == "rerank":
244245
body["api_version"] = get_api_version(version=api_version)

src/cohere/manually_maintained/cohere_aws/client.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,29 @@ class Client:
2020
def __init__(
2121
self,
2222
aws_region: typing.Optional[str] = None,
23+
mode: Mode = Mode.SAGEMAKER,
2324
):
2425
"""
2526
By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with
2627
`aws configure set region us-west-2` or override it with `region_name` parameter.
2728
"""
28-
self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
29-
self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
29+
self.mode = mode
3030
if os.environ.get('AWS_DEFAULT_REGION') is None:
3131
os.environ['AWS_DEFAULT_REGION'] = aws_region
32-
self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
33-
self.mode = Mode.SAGEMAKER
3432

33+
if self.mode == Mode.SAGEMAKER:
34+
self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
35+
self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
36+
self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
37+
elif self.mode == Mode.BEDROCK:
38+
self._client = lazy_boto3().client("bedrock-runtime", region_name=aws_region)
39+
self._service_client = lazy_boto3().client("bedrock", region_name=aws_region)
40+
self._sess = None
41+
self._endpoint_name = None
3542

43+
def _require_sagemaker(self) -> None:
44+
if self.mode != Mode.SAGEMAKER:
45+
raise CohereError("This method is only supported in SageMaker mode.")
3646

3747
def _does_endpoint_exist(self, endpoint_name: str) -> bool:
3848
try:
@@ -50,6 +60,7 @@ def connect_to_endpoint(self, endpoint_name: str) -> None:
5060
Raises:
5161
CohereError: Connection to the endpoint failed.
5262
"""
63+
self._require_sagemaker()
5364
if not self._does_endpoint_exist(endpoint_name):
5465
raise CohereError(f"Endpoint {endpoint_name} does not exist.")
5566
self._endpoint_name = endpoint_name
@@ -137,6 +148,7 @@ def create_endpoint(
137148
will be used to get the role. This should work when one uses the client inside SageMaker. If this errors
138149
out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker.
139150
"""
151+
self._require_sagemaker()
140152
# First, check if endpoint already exists
141153
if self._does_endpoint_exist(endpoint_name):
142154
if recreate:
@@ -550,11 +562,15 @@ def embed(
550562
variant: Optional[str] = None,
551563
input_type: Optional[str] = None,
552564
model_id: Optional[str] = None,
553-
) -> Embeddings:
565+
output_dimension: Optional[int] = None,
566+
embedding_types: Optional[List[str]] = None,
567+
) -> Union[Embeddings, Dict[str, List]]:
554568
json_params = {
555569
'texts': texts,
556570
'truncate': truncate,
557-
"input_type": input_type
571+
"input_type": input_type,
572+
"output_dimension": output_dimension,
573+
"embedding_types": embedding_types,
558574
}
559575
for key, value in list(json_params.items()):
560576
if value is None:
@@ -591,7 +607,10 @@ def _sagemaker_embed(self, json_params: Dict[str, Any], variant: str):
591607
# ValidationError, e.g. when variant is bad
592608
raise CohereError(str(e))
593609

594-
return Embeddings(response['embeddings'])
610+
embeddings = response['embeddings']
611+
if isinstance(embeddings, dict):
612+
return embeddings
613+
return Embeddings(embeddings)
595614

596615
def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
597616
if not model_id:
@@ -612,7 +631,10 @@ def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
612631
# ValidationError, e.g. when variant is bad
613632
raise CohereError(str(e))
614633

615-
return Embeddings(response['embeddings'])
634+
embeddings = response['embeddings']
635+
if isinstance(embeddings, dict):
636+
return embeddings
637+
return Embeddings(embeddings)
616638

617639

618640
def rerank(self,
@@ -805,6 +827,7 @@ def export_finetune(
805827
This should work when one uses the client inside SageMaker. If this errors out,
806828
the default role "ServiceRoleSagemaker" will be used, which generally works outside SageMaker.
807829
"""
830+
self._require_sagemaker()
808831
if name == "model":
809832
raise ValueError("name cannot be 'model'")
810833

@@ -948,6 +971,7 @@ def summarize(
948971
additional_command: Optional[str] = "",
949972
variant: Optional[str] = None
950973
) -> Summary:
974+
self._require_sagemaker()
951975

952976
if self._endpoint_name is None:
953977
raise CohereError("No endpoint connected. "
@@ -989,6 +1013,7 @@ def summarize(
9891013

9901014

9911015
def delete_endpoint(self) -> None:
1016+
self._require_sagemaker()
9921017
if self._endpoint_name is None:
9931018
raise CohereError("No endpoint connected.")
9941019
try:

tests/test_aws_client_unit.py

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
"""
2+
Unit tests (mocked, no AWS credentials needed) for AWS client fixes.
3+
4+
Covers:
5+
- Fix 1: SigV4 signing uses the correct host header after URL rewrite
6+
- Fix 2: cohere_aws.Client conditionally initializes based on mode
7+
- Fix 3: embed() accepts and passes output_dimension and embedding_types
8+
"""
9+
10+
import inspect
11+
import json
12+
import os
13+
import unittest
14+
from unittest.mock import MagicMock, patch
15+
16+
import httpx
17+
18+
from cohere.manually_maintained.cohere_aws.mode import Mode
19+
20+
21+
class TestSigV4HostHeader(unittest.TestCase):
22+
"""Fix 1: The headers dict passed to AWSRequest for SigV4 signing must
23+
contain the rewritten Bedrock/SageMaker host, not the stale api.cohere.com."""
24+
25+
def test_sigv4_signs_with_correct_host(self) -> None:
26+
captured_aws_request_kwargs: dict = {}
27+
28+
mock_aws_request_cls = MagicMock()
29+
30+
def capture_aws_request(**kwargs): # type: ignore
31+
captured_aws_request_kwargs.update(kwargs)
32+
mock_req = MagicMock()
33+
mock_req.prepare.return_value = MagicMock(
34+
headers={"host": "bedrock-runtime.us-east-1.amazonaws.com"}
35+
)
36+
return mock_req
37+
38+
mock_aws_request_cls.side_effect = capture_aws_request
39+
40+
mock_botocore = MagicMock()
41+
mock_botocore.awsrequest.AWSRequest = mock_aws_request_cls
42+
mock_botocore.auth.SigV4Auth.return_value = MagicMock()
43+
44+
mock_boto3 = MagicMock()
45+
mock_session = MagicMock()
46+
mock_session.region_name = "us-east-1"
47+
mock_session.get_credentials.return_value = MagicMock()
48+
mock_boto3.Session.return_value = mock_session
49+
50+
with patch("cohere.aws_client.lazy_botocore", return_value=mock_botocore), \
51+
patch("cohere.aws_client.lazy_boto3", return_value=mock_boto3):
52+
53+
from cohere.aws_client import map_request_to_bedrock
54+
55+
hook = map_request_to_bedrock(service="bedrock", aws_region="us-east-1")
56+
57+
request = httpx.Request(
58+
method="POST",
59+
url="https://api.cohere.com/v1/chat",
60+
headers={"connection": "keep-alive"},
61+
json={"model": "cohere.command-r-plus-v1:0", "message": "hello"},
62+
)
63+
64+
self.assertEqual(request.url.host, "api.cohere.com")
65+
66+
hook(request)
67+
68+
self.assertIn("bedrock-runtime.us-east-1.amazonaws.com", str(request.url))
69+
70+
signed_headers = captured_aws_request_kwargs["headers"]
71+
self.assertEqual(
72+
signed_headers["host"],
73+
"bedrock-runtime.us-east-1.amazonaws.com",
74+
)
75+
76+
77+
class TestModeConditionalInit(unittest.TestCase):
78+
"""Fix 2: cohere_aws.Client should initialize different boto3 clients
79+
depending on mode, and default to SAGEMAKER for backwards compat."""
80+
81+
def test_sagemaker_mode_creates_sagemaker_clients(self) -> None:
82+
mock_boto3 = MagicMock()
83+
mock_sagemaker = MagicMock()
84+
85+
with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \
86+
patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=mock_sagemaker), \
87+
patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):
88+
89+
from cohere.manually_maintained.cohere_aws.client import Client
90+
91+
client = Client(aws_region="us-east-1")
92+
93+
self.assertEqual(client.mode, Mode.SAGEMAKER)
94+
95+
service_names = [c[0][0] for c in mock_boto3.client.call_args_list]
96+
self.assertIn("sagemaker-runtime", service_names)
97+
self.assertIn("sagemaker", service_names)
98+
self.assertNotIn("bedrock-runtime", service_names)
99+
self.assertNotIn("bedrock", service_names)
100+
101+
mock_sagemaker.Session.assert_called_once()
102+
103+
def test_bedrock_mode_creates_bedrock_clients(self) -> None:
104+
mock_boto3 = MagicMock()
105+
mock_sagemaker = MagicMock()
106+
107+
with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \
108+
patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=mock_sagemaker), \
109+
patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-west-2"}):
110+
111+
from cohere.manually_maintained.cohere_aws.client import Client
112+
113+
client = Client(aws_region="us-west-2", mode=Mode.BEDROCK)
114+
115+
self.assertEqual(client.mode, Mode.BEDROCK)
116+
117+
service_names = [c[0][0] for c in mock_boto3.client.call_args_list]
118+
self.assertIn("bedrock-runtime", service_names)
119+
self.assertIn("bedrock", service_names)
120+
self.assertNotIn("sagemaker-runtime", service_names)
121+
self.assertNotIn("sagemaker", service_names)
122+
123+
mock_sagemaker.Session.assert_not_called()
124+
125+
def test_default_mode_is_sagemaker(self) -> None:
126+
from cohere.manually_maintained.cohere_aws.client import Client
127+
128+
sig = inspect.signature(Client.__init__)
129+
self.assertEqual(sig.parameters["mode"].default, Mode.SAGEMAKER)
130+
131+
132+
class TestEmbedV4Params(unittest.TestCase):
133+
"""Fix 3: embed() should accept output_dimension and embedding_types,
134+
pass them through to the request body, and strip them when None."""
135+
136+
@staticmethod
137+
def _make_bedrock_client(): # type: ignore
138+
mock_boto3 = MagicMock()
139+
mock_botocore = MagicMock()
140+
captured_body: dict = {}
141+
142+
def fake_invoke_model(**kwargs): # type: ignore
143+
captured_body.update(json.loads(kwargs["body"]))
144+
mock_body = MagicMock()
145+
mock_body.read.return_value = json.dumps({"embeddings": [[0.1, 0.2]]}).encode()
146+
return {"body": mock_body}
147+
148+
mock_bedrock_client = MagicMock()
149+
mock_bedrock_client.invoke_model.side_effect = fake_invoke_model
150+
151+
def fake_boto3_client(service_name, **kwargs): # type: ignore
152+
if service_name == "bedrock-runtime":
153+
return mock_bedrock_client
154+
return MagicMock()
155+
156+
mock_boto3.client.side_effect = fake_boto3_client
157+
return mock_boto3, mock_botocore, captured_body
158+
159+
def test_embed_accepts_new_params(self) -> None:
160+
from cohere.manually_maintained.cohere_aws.client import Client
161+
162+
sig = inspect.signature(Client.embed)
163+
self.assertIn("output_dimension", sig.parameters)
164+
self.assertIn("embedding_types", sig.parameters)
165+
self.assertIsNone(sig.parameters["output_dimension"].default)
166+
self.assertIsNone(sig.parameters["embedding_types"].default)
167+
168+
def test_embed_passes_params_to_bedrock(self) -> None:
169+
mock_boto3, mock_botocore, captured_body = self._make_bedrock_client()
170+
171+
with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \
172+
patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=mock_botocore), \
173+
patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \
174+
patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):
175+
176+
from cohere.manually_maintained.cohere_aws.client import Client
177+
178+
client = Client(aws_region="us-east-1", mode=Mode.BEDROCK)
179+
client.embed(
180+
texts=["hello world"],
181+
input_type="search_document",
182+
model_id="cohere.embed-english-v3",
183+
output_dimension=256,
184+
embedding_types=["float", "int8"],
185+
)
186+
187+
self.assertEqual(captured_body["output_dimension"], 256)
188+
self.assertEqual(captured_body["embedding_types"], ["float", "int8"])
189+
190+
def test_embed_omits_none_params(self) -> None:
191+
mock_boto3, mock_botocore, captured_body = self._make_bedrock_client()
192+
193+
with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \
194+
patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=mock_botocore), \
195+
patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \
196+
patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):
197+
198+
from cohere.manually_maintained.cohere_aws.client import Client
199+
200+
client = Client(aws_region="us-east-1", mode=Mode.BEDROCK)
201+
client.embed(
202+
texts=["hello world"],
203+
input_type="search_document",
204+
model_id="cohere.embed-english-v3",
205+
)
206+
207+
self.assertNotIn("output_dimension", captured_body)
208+
self.assertNotIn("embedding_types", captured_body)
209+
210+
def test_embed_with_embedding_types_returns_dict(self) -> None:
211+
"""When embedding_types is specified, the API returns embeddings as a dict.
212+
The client should return that dict rather than wrapping it in Embeddings."""
213+
mock_boto3 = MagicMock()
214+
mock_botocore = MagicMock()
215+
216+
by_type_embeddings = {"float": [[0.1, 0.2]], "int8": [[1, 2]]}
217+
218+
def fake_invoke_model(**kwargs): # type: ignore
219+
mock_body = MagicMock()
220+
mock_body.read.return_value = json.dumps({
221+
"embeddings": by_type_embeddings,
222+
"response_type": "embeddings_by_type",
223+
}).encode()
224+
return {"body": mock_body}
225+
226+
mock_bedrock_client = MagicMock()
227+
mock_bedrock_client.invoke_model.side_effect = fake_invoke_model
228+
229+
def fake_boto3_client(service_name, **kwargs): # type: ignore
230+
if service_name == "bedrock-runtime":
231+
return mock_bedrock_client
232+
return MagicMock()
233+
234+
mock_boto3.client.side_effect = fake_boto3_client
235+
236+
with patch("cohere.manually_maintained.cohere_aws.client.lazy_boto3", return_value=mock_boto3), \
237+
patch("cohere.manually_maintained.cohere_aws.client.lazy_botocore", return_value=mock_botocore), \
238+
patch("cohere.manually_maintained.cohere_aws.client.lazy_sagemaker", return_value=MagicMock()), \
239+
patch.dict(os.environ, {"AWS_DEFAULT_REGION": "us-east-1"}):
240+
241+
from cohere.manually_maintained.cohere_aws.client import Client
242+
243+
client = Client(aws_region="us-east-1", mode=Mode.BEDROCK)
244+
result = client.embed(
245+
texts=["hello world"],
246+
input_type="search_document",
247+
model_id="cohere.embed-english-v3",
248+
embedding_types=["float", "int8"],
249+
)
250+
251+
self.assertIsInstance(result, dict)
252+
self.assertEqual(result, by_type_embeddings)

0 commit comments

Comments
 (0)