Commit 1a4b100 — Merge pull request #24 from awslabs/feature/nova-models
Added Support for Amazon Nova Lite and Pro models 🚀
2 parents: c8866ab + 751f415

File tree: 8 files changed (+162 −57 lines)

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "pyrhubarb"
3-
version = "0.0.3"
3+
version = "0.0.4"
44
description = "A Python framework for multi-modal document understanding with generative AI"
55
authors = ["Rhubarb Developers <rhubarb-developers@amazon.com>"]
66
license = "Apache 2.0"

src/rhubarb/analyze.py

+44-29
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import logging
55
from typing import Any, List, Optional, Generator
66

7-
from pydantic import Field, BaseModel, PrivateAttr, model_validator
7+
from pydantic import Field, BaseModel, PrivateAttr, validator, model_validator
88
from botocore.config import Config
99

1010
from rhubarb.models import LanguageModels
1111
from rhubarb.invocations import Invocations
12-
from rhubarb.user_prompts import AnthropicMessages
12+
from rhubarb.user_prompts import UserMessages
1313
from rhubarb.system_prompts import SystemPrompts
1414

1515
logger = logging.getLogger(__name__)
@@ -53,9 +53,15 @@ class DocAnalysis(BaseModel):
5353
modelId: LanguageModels = Field(default=LanguageModels.CLAUDE_SONNET_V2)
5454
"""Bedrock Model ID"""
5555

56-
system_prompt: str = Field(default=SystemPrompts().DefaultSysPrompt)
56+
system_prompt: str = Field(default="")
5757
"""System prompt"""
5858

59+
@validator("system_prompt", pre=True, always=True)
60+
def set_system_prompt(cls, v, values):
61+
return SystemPrompts(
62+
model_id=values.get("modelId", LanguageModels.CLAUDE_SONNET_V2)
63+
).DefaultSysPrompt
64+
5965
boto3_session: Any
6066
"""Instance of boto3.session.Session"""
6167

@@ -129,14 +135,14 @@ def validate_model(cls, values: dict) -> dict:
129135
def history(self) -> Any:
130136
return self._message_history
131137

132-
def _get_anthropic_prompt(
138+
def _get_user_prompt(
133139
self,
134140
message: Any,
135141
sys_prompt: str,
136142
output_schema: Optional[dict] = None,
137143
history: Optional[List[dict]] = None,
138144
) -> Any:
139-
return AnthropicMessages(
145+
return UserMessages(
140146
file_path=self.file_path,
141147
s3_client=self._s3_client,
142148
system_prompt=sys_prompt,
@@ -147,6 +153,7 @@ def _get_anthropic_prompt(
147153
pages=self.pages,
148154
use_converse_api=self.use_converse_api,
149155
message_history=history,
156+
modelId=self.modelId,
150157
)
151158

152159
def run(
@@ -163,12 +170,14 @@ def run(
163170
- `output_schema` (`Optional[dict]`, optional): The output JSON schema for the language model response. Defaults to None.
164171
"""
165172
if (
166-
self.modelId == LanguageModels.CLAUDE_OPUS_V1
167-
or self.modelId == LanguageModels.CLAUDE_HAIKU_V1
173+
self.modelId == LanguageModels.CLAUDE_HAIKU_V1
168174
or self.modelId == LanguageModels.CLAUDE_SONNET_V1
169-
or self.modelId == LanguageModels.CLAUDE_SONNET_V2
175+
or self.modelId == LanguageModels.CLAUDE_SONNET_V2
176+
or self.modelId == LanguageModels.NOVA_LITE
177+
or self.modelId == LanguageModels.NOVA_PRO
170178
):
171-
a_msg = self._get_anthropic_prompt(
179+
# sys_prompt = SystemPrompts(model_id=self.modelId).DefaultSysPrompt
180+
a_msg = self._get_user_prompt(
172181
message=message,
173182
output_schema=output_schema,
174183
sys_prompt=self.system_prompt,
@@ -182,8 +191,8 @@ def run(
182191
boto3_session=self.boto3_session,
183192
model_id=self.modelId.value,
184193
output_schema=output_schema,
185-
use_converse_api = self.use_converse_api,
186-
enable_cri = self.enable_cri
194+
use_converse_api=self.use_converse_api,
195+
enable_cri=self.enable_cri,
187196
)
188197
response = model_invoke.run_inference()
189198
self._message_history = model_invoke.message_history
@@ -202,20 +211,22 @@ def run_stream(
202211
self.modelId == LanguageModels.CLAUDE_OPUS_V1
203212
or self.modelId == LanguageModels.CLAUDE_HAIKU_V1
204213
or self.modelId == LanguageModels.CLAUDE_SONNET_V1
205-
or self.modelId == LanguageModels.CLAUDE_SONNET_V2
214+
or self.modelId == LanguageModels.CLAUDE_SONNET_V2
215+
or self.modelId == LanguageModels.NOVA_LITE
216+
or self.modelId == LanguageModels.NOVA_PRO
206217
):
207-
a_msg = self._get_anthropic_prompt(
218+
a_msg = self._get_user_prompt(
208219
message=message, sys_prompt=self.system_prompt, history=history
209220
)
210221
body = a_msg.messages()
211222

212223
model_invoke = Invocations(
213-
body=body,
224+
body=body,
214225
bedrock_client=self._bedrock_client,
215226
boto3_session=self.boto3_session,
216227
model_id=self.modelId.value,
217-
use_converse_api = self.use_converse_api,
218-
enable_cri = self.enable_cri
228+
use_converse_api=self.use_converse_api,
229+
enable_cri=self.enable_cri,
219230
)
220231
for response in model_invoke.run_inference_stream():
221232
yield response
@@ -233,26 +244,28 @@ def run_entity(self, message: Any, entities: List[Any]) -> Any:
233244
self.modelId == LanguageModels.CLAUDE_OPUS_V1
234245
or self.modelId == LanguageModels.CLAUDE_HAIKU_V1
235246
or self.modelId == LanguageModels.CLAUDE_SONNET_V1
236-
or self.modelId == LanguageModels.CLAUDE_SONNET_V2
247+
or self.modelId == LanguageModels.CLAUDE_SONNET_V2
248+
or self.modelId == LanguageModels.NOVA_LITE
249+
or self.modelId == LanguageModels.NOVA_PRO
237250
):
238-
sys_prompt = SystemPrompts(entities=entities).NERSysPrompt
239-
a_msg = self._get_anthropic_prompt(message=message, sys_prompt=sys_prompt)
251+
sys_prompt = SystemPrompts(entities=entities, model_id=self.modelId).NERSysPrompt
252+
a_msg = self._get_user_prompt(message=message, sys_prompt=sys_prompt)
240253
body = a_msg.messages()
241254

242255
model_invoke = Invocations(
243-
body=body,
256+
body=body,
244257
bedrock_client=self._bedrock_client,
245258
boto3_session=self.boto3_session,
246259
model_id=self.modelId.value,
247-
use_converse_api = self.use_converse_api,
248-
enable_cri = self.enable_cri
260+
use_converse_api=self.use_converse_api,
261+
enable_cri=self.enable_cri,
249262
)
250263
response = model_invoke.run_inference()
251264
return response
252265

253266
def generate_schema(self, message: str, assistive_rephrase: Optional[bool] = False) -> dict:
254267
"""
255-
Invokes the specified language model with the given message to genereate a JSON
268+
Invokes the specified language model with the given message to generate a JSON
256269
schema for a given document.
257270
258271
Args:
@@ -264,21 +277,23 @@ def generate_schema(self, message: str, assistive_rephrase: Optional[bool] = Fal
264277
or self.modelId == LanguageModels.CLAUDE_HAIKU_V1
265278
or self.modelId == LanguageModels.CLAUDE_SONNET_V1
266279
or self.modelId == LanguageModels.CLAUDE_SONNET_V2
280+
or self.modelId == LanguageModels.NOVA_LITE
281+
or self.modelId == LanguageModels.NOVA_PRO
267282
):
268283
if assistive_rephrase:
269-
sys_prompt = SystemPrompts().SchemaGenSysPromptWithRephrase
284+
sys_prompt = SystemPrompts(model_id=self.modelId).SchemaGenSysPromptWithRephrase
270285
else:
271-
sys_prompt = SystemPrompts().SchemaGenSysPrompt
272-
a_msg = self._get_anthropic_prompt(message=message, sys_prompt=sys_prompt)
286+
sys_prompt = SystemPrompts(model_id=self.modelId).SchemaGenSysPrompt
287+
a_msg = self._get_user_prompt(message=message, sys_prompt=sys_prompt)
273288
body = a_msg.messages()
274289

275290
model_invoke = Invocations(
276-
body=body,
291+
body=body,
277292
bedrock_client=self._bedrock_client,
278293
boto3_session=self.boto3_session,
279294
model_id=self.modelId.value,
280-
use_converse_api = self.use_converse_api,
281-
enable_cri = self.enable_cri
295+
use_converse_api=self.use_converse_api,
296+
enable_cri=self.enable_cri,
282297
)
283298
response = model_invoke.run_inference()
284299
return response

src/rhubarb/invocations/invocations.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,14 @@ def _invoke_model_json(self) -> dict:
219219
with response["body"] as stream:
220220
response_body = json.load(stream)
221221

222-
response_text = response_body["content"][0]["text"]
223-
input_tokens = response_body["usage"]["input_tokens"]
224-
output_tokens = response_body["usage"]["output_tokens"]
222+
if "nova" in str((self.model_id)).lower():
223+
response_text = response_body["output"]["message"]["content"][0]["text"]
224+
input_tokens = response_body["usage"]["inputTokens"]
225+
output_tokens = response_body["usage"]["outputTokens"]
226+
else:
227+
response_text = response_body["content"][0]["text"]
228+
input_tokens = response_body["usage"]["input_tokens"]
229+
output_tokens = response_body["usage"]["output_tokens"]
225230
total_tokens = input_tokens + output_tokens
226231

227232
self.token_usage = {
@@ -231,7 +236,13 @@ def _invoke_model_json(self) -> dict:
231236
}
232237

233238
messages = self.body["messages"]
234-
messages.append({"role": response_body["role"], "content": response_body["content"]})
239+
messages.append(
240+
{
241+
"role": response_body.get("role", "assistant"),
242+
"content": response_body.get("content")
243+
or response_body.get("output", {}).get("message", {}).get("content", ""),
244+
}
245+
)
235246

236247
self.history = messages
237248
output = self._extract_json_from_markdown(response_text)

src/rhubarb/models.py

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ class LanguageModels(Enum):
99
CLAUDE_SONNET_V1 = "anthropic.claude-3-sonnet-20240229-v1:0"
1010
CLAUDE_HAIKU_V1 = "anthropic.claude-3-haiku-20240307-v1:0"
1111
CLAUDE_SONNET_V2 = "anthropic.claude-3-5-sonnet-20240620-v1:0"
12+
NOVA_PRO = "amazon.nova-pro-v1:0"
13+
NOVA_LITE = "amazon.nova-lite-v1:0"
1214

1315

1416
class EmbeddingModels(Enum):

Comments (0)