4
4
import logging
5
5
from typing import Any , List , Optional , Generator
6
6
7
- from pydantic import Field , BaseModel , PrivateAttr , model_validator
7
+ from pydantic import Field , BaseModel , PrivateAttr , validator , model_validator
8
8
from botocore .config import Config
9
9
10
10
from rhubarb .models import LanguageModels
11
11
from rhubarb .invocations import Invocations
12
- from rhubarb .user_prompts import AnthropicMessages
12
+ from rhubarb .user_prompts import UserMessages
13
13
from rhubarb .system_prompts import SystemPrompts
14
14
15
15
logger = logging .getLogger (__name__ )
@@ -53,9 +53,15 @@ class DocAnalysis(BaseModel):
53
53
modelId : LanguageModels = Field (default = LanguageModels .CLAUDE_SONNET_V2 )
54
54
"""Bedrock Model ID"""
55
55
56
- system_prompt : str = Field (default = SystemPrompts (). DefaultSysPrompt )
56
+ system_prompt : str = Field (default = "" )
57
57
"""System prompt"""
58
58
59
@validator("system_prompt", pre=True, always=True)
def set_system_prompt(cls, v, values):
    """Resolve the system prompt, defaulting to the model-specific one.

    Runs even when ``system_prompt`` is omitted (``always=True``) and before
    type coercion (``pre=True``).

    Args:
        v: The raw ``system_prompt`` supplied by the caller; the field
            default is an empty string.
        values: Fields validated so far. ``modelId`` is declared before
            ``system_prompt`` on the model, so it is expected to be present
            here unless its own validation failed.

    Returns:
        The caller-supplied prompt when it is non-empty, otherwise
        ``DefaultSysPrompt`` built for the selected model.
    """
    # Honour an explicit prompt: previously the supplied value was ignored
    # and unconditionally replaced by the default, so a custom prompt could
    # never take effect.
    if v:
        return v

    # modelId is absent from `values` if it failed validation; fall back to
    # the same default model the field itself declares.
    model_id = values.get("modelId", LanguageModels.CLAUDE_SONNET_V2)
    return SystemPrompts(model_id=model_id).DefaultSysPrompt
64
+
59
65
boto3_session : Any
60
66
"""Instance of boto3.session.Session"""
61
67
@@ -129,14 +135,14 @@ def validate_model(cls, values: dict) -> dict:
129
135
def history(self) -> Any:
    """Return the message history currently stored on this instance."""
    recorded = self._message_history
    return recorded
131
137
132
- def _get_anthropic_prompt (
138
+ def _get_user_prompt (
133
139
self ,
134
140
message : Any ,
135
141
sys_prompt : str ,
136
142
output_schema : Optional [dict ] = None ,
137
143
history : Optional [List [dict ]] = None ,
138
144
) -> Any :
139
- return AnthropicMessages (
145
+ return UserMessages (
140
146
file_path = self .file_path ,
141
147
s3_client = self ._s3_client ,
142
148
system_prompt = sys_prompt ,
@@ -147,6 +153,7 @@ def _get_anthropic_prompt(
147
153
pages = self .pages ,
148
154
use_converse_api = self .use_converse_api ,
149
155
message_history = history ,
156
+ modelId = self .modelId ,
150
157
)
151
158
152
159
def run (
@@ -163,12 +170,14 @@ def run(
163
170
- `output_schema` (`Optional[dict]`, optional): The output JSON schema for the language model response. Defaults to None.
164
171
"""
165
172
if (
166
- self .modelId == LanguageModels .CLAUDE_OPUS_V1
167
- or self .modelId == LanguageModels .CLAUDE_HAIKU_V1
173
+ self .modelId == LanguageModels .CLAUDE_HAIKU_V1
168
174
or self .modelId == LanguageModels .CLAUDE_SONNET_V1
169
- or self .modelId == LanguageModels .CLAUDE_SONNET_V2
175
+ or self .modelId == LanguageModels .CLAUDE_SONNET_V2
176
+ or self .modelId == LanguageModels .NOVA_LITE
177
+ or self .modelId == LanguageModels .NOVA_PRO
170
178
):
171
- a_msg = self ._get_anthropic_prompt (
179
+ # sys_prompt = SystemPrompts(model_id=self.modelId).DefaultSysPrompt
180
+ a_msg = self ._get_user_prompt (
172
181
message = message ,
173
182
output_schema = output_schema ,
174
183
sys_prompt = self .system_prompt ,
@@ -182,8 +191,8 @@ def run(
182
191
boto3_session = self .boto3_session ,
183
192
model_id = self .modelId .value ,
184
193
output_schema = output_schema ,
185
- use_converse_api = self .use_converse_api ,
186
- enable_cri = self .enable_cri
194
+ use_converse_api = self .use_converse_api ,
195
+ enable_cri = self .enable_cri ,
187
196
)
188
197
response = model_invoke .run_inference ()
189
198
self ._message_history = model_invoke .message_history
@@ -202,20 +211,22 @@ def run_stream(
202
211
self .modelId == LanguageModels .CLAUDE_OPUS_V1
203
212
or self .modelId == LanguageModels .CLAUDE_HAIKU_V1
204
213
or self .modelId == LanguageModels .CLAUDE_SONNET_V1
205
- or self .modelId == LanguageModels .CLAUDE_SONNET_V2
214
+ or self .modelId == LanguageModels .CLAUDE_SONNET_V2
215
+ or self .modelId == LanguageModels .NOVA_LITE
216
+ or self .modelId == LanguageModels .NOVA_PRO
206
217
):
207
- a_msg = self ._get_anthropic_prompt (
218
+ a_msg = self ._get_user_prompt (
208
219
message = message , sys_prompt = self .system_prompt , history = history
209
220
)
210
221
body = a_msg .messages ()
211
222
212
223
model_invoke = Invocations (
213
- body = body ,
224
+ body = body ,
214
225
bedrock_client = self ._bedrock_client ,
215
226
boto3_session = self .boto3_session ,
216
227
model_id = self .modelId .value ,
217
- use_converse_api = self .use_converse_api ,
218
- enable_cri = self .enable_cri
228
+ use_converse_api = self .use_converse_api ,
229
+ enable_cri = self .enable_cri ,
219
230
)
220
231
for response in model_invoke .run_inference_stream ():
221
232
yield response
@@ -233,26 +244,28 @@ def run_entity(self, message: Any, entities: List[Any]) -> Any:
233
244
self .modelId == LanguageModels .CLAUDE_OPUS_V1
234
245
or self .modelId == LanguageModels .CLAUDE_HAIKU_V1
235
246
or self .modelId == LanguageModels .CLAUDE_SONNET_V1
236
- or self .modelId == LanguageModels .CLAUDE_SONNET_V2
247
+ or self .modelId == LanguageModels .CLAUDE_SONNET_V2
248
+ or self .modelId == LanguageModels .NOVA_LITE
249
+ or self .modelId == LanguageModels .NOVA_PRO
237
250
):
238
- sys_prompt = SystemPrompts (entities = entities ).NERSysPrompt
239
- a_msg = self ._get_anthropic_prompt (message = message , sys_prompt = sys_prompt )
251
+ sys_prompt = SystemPrompts (entities = entities , model_id = self . modelId ).NERSysPrompt
252
+ a_msg = self ._get_user_prompt (message = message , sys_prompt = sys_prompt )
240
253
body = a_msg .messages ()
241
254
242
255
model_invoke = Invocations (
243
- body = body ,
256
+ body = body ,
244
257
bedrock_client = self ._bedrock_client ,
245
258
boto3_session = self .boto3_session ,
246
259
model_id = self .modelId .value ,
247
- use_converse_api = self .use_converse_api ,
248
- enable_cri = self .enable_cri
260
+ use_converse_api = self .use_converse_api ,
261
+ enable_cri = self .enable_cri ,
249
262
)
250
263
response = model_invoke .run_inference ()
251
264
return response
252
265
253
266
def generate_schema (self , message : str , assistive_rephrase : Optional [bool ] = False ) -> dict :
254
267
"""
255
- Invokes the specified language model with the given message to genereate a JSON
268
+ Invokes the specified language model with the given message to generate a JSON
256
269
schema for a given document.
257
270
258
271
Args:
@@ -264,21 +277,23 @@ def generate_schema(self, message: str, assistive_rephrase: Optional[bool] = Fal
264
277
or self .modelId == LanguageModels .CLAUDE_HAIKU_V1
265
278
or self .modelId == LanguageModels .CLAUDE_SONNET_V1
266
279
or self .modelId == LanguageModels .CLAUDE_SONNET_V2
280
+ or self .modelId == LanguageModels .NOVA_LITE
281
+ or self .modelId == LanguageModels .NOVA_PRO
267
282
):
268
283
if assistive_rephrase :
269
- sys_prompt = SystemPrompts ().SchemaGenSysPromptWithRephrase
284
+ sys_prompt = SystemPrompts (model_id = self . modelId ).SchemaGenSysPromptWithRephrase
270
285
else :
271
- sys_prompt = SystemPrompts ().SchemaGenSysPrompt
272
- a_msg = self ._get_anthropic_prompt (message = message , sys_prompt = sys_prompt )
286
+ sys_prompt = SystemPrompts (model_id = self . modelId ).SchemaGenSysPrompt
287
+ a_msg = self ._get_user_prompt (message = message , sys_prompt = sys_prompt )
273
288
body = a_msg .messages ()
274
289
275
290
model_invoke = Invocations (
276
- body = body ,
291
+ body = body ,
277
292
bedrock_client = self ._bedrock_client ,
278
293
boto3_session = self .boto3_session ,
279
294
model_id = self .modelId .value ,
280
- use_converse_api = self .use_converse_api ,
281
- enable_cri = self .enable_cri
295
+ use_converse_api = self .use_converse_api ,
296
+ enable_cri = self .enable_cri ,
282
297
)
283
298
response = model_invoke .run_inference ()
284
299
return response
0 commit comments