@@ -16,11 +16,13 @@
 import warnings
 from pathlib import Path
 from typing import Literal
+from uuid import uuid4
 
 from crab import (
     BenchmarkConfig,
     Experiment,
     MessageType,
+    Task,
     TaskGenerator,
     create_benchmark,
 )
@@ -37,6 +39,8 @@
 )
 from crab.core.agent_policy import AgentPolicy
 from crab.core.benchmark import Benchmark
+from crab.core.decorators import evaluator
+from crab.environments.macos import mac_env
 
 from .android_env import ANDROID_ENV
 from .dataset.android_subtasks import android_subtasks
@@ -79,6 +83,11 @@ def get_prompt(self):
         return result_prompt
 
 
+@evaluator(env_name="macos")
+def empty_evaluator() -> bool:
+    return False
+
+
 def get_benchmark(env: str, ubuntu_url: str):
     ubuntu_env = UBUNTU_ENV.model_copy()
     ubuntu_env.remote_url = ubuntu_url
@@ -88,6 +97,9 @@ def get_benchmark(env: str, ubuntu_url: str):
     android_tool = {
         "screenshot": groundingdino_easyocr(font_size=40) >> get_elements_prompt
     }
+    mac_tool = {
+        "screenshot": groundingdino_easyocr(font_size=24) >> get_elements_prompt
+    }
 
     if env == "ubuntu":
         prompting_tools = {"ubuntu": ubuntu_tool}
@@ -122,6 +134,22 @@ def get_benchmark(env: str, ubuntu_url: str):
             root_action_space=[complete],
             multienv=True,
         )
+    elif env == "mac":
+        task = Task(
+            description="Open firefox in both macos and android.",
+            id="0",
+            evaluator=empty_evaluator,
+        )
+        prompting_tools = {"macos": mac_tool, "android": android_tool}
+        mac_env.remote_url = "http://10.85.170.240:8000"
+        benchmark_config = BenchmarkConfig(
+            name="mac_benchmark",
+            tasks=[task],
+            environments=[mac_env, ANDROID_ENV],
+            prompting_tools=prompting_tools,
+            root_action_space=[complete],
+            multienv=True,
+        )
     else:
         raise ValueError("Env not support")
 
@@ -169,7 +197,13 @@ def get_benchmark(env: str, ubuntu_url: str):
         help="ubuntu, android or cross",
         default="cross",
     )
-    parser.add_argument("--task-id", type=str, help="task id")
+    parser.add_argument("--task-id", type=str, help="task id", default=None)
+    parser.add_argument(
+        "--task-description",
+        type=str,
+        help="task description. If provided, will overwrite the task id.",
+        default=None,
+    )
     parser.add_argument(
         "--loglevel",
         type=str,
@@ -180,20 +214,39 @@ def get_benchmark(env: str, ubuntu_url: str):
     loglevel = args.loglevel
     numeric_level = getattr(logging, loglevel.upper(), None)
     if not isinstance(numeric_level, int):
-        raise ValueError('Invalid log level: %s' % loglevel)
+        raise ValueError("Invalid log level: %s" % loglevel)
     logging.basicConfig(level=numeric_level)
 
-
     benchmark = get_benchmark(args.env, args.remote_url)
 
+    if args.task_description is not None:
+        task_id = str(uuid4())
+        benchmark.tasks = [
+            Task(
+                id=task_id,
+                description=args.task_description,
+                evaluator=empty_evaluator,
+            )
+        ]
+    else:
+        task_id = args.task_id
+
+    history_messages_len = 2
+
     if args.model == "gpt4o":
-        model = OpenAIModel(model="gpt-4o")
+        model = OpenAIModel(model="gpt-4o", history_messages_len=history_messages_len)
     elif args.policy == "gpt4turbo":
-        model = OpenAIModel(model="gpt-4-turbo")
+        model = OpenAIModel(
+            model="gpt-4-turbo", history_messages_len=history_messages_len
+        )
     elif args.policy == "gemini":
-        model = GeminiModel(model="gemini-1.5-pro-latest")
+        model = GeminiModel(
+            model="gemini-1.5-pro-latest", history_messages_len=history_messages_len
+        )
     elif args.policy == "claude":
-        model = ClaudeModel(model="claude-3-opus-20240229")
+        model = ClaudeModel(
+            model="claude-3-opus-20240229", history_messages_len=history_messages_len
+        )
     else:
         print("Unsupported model: ", args.model)
         exit()
@@ -215,7 +268,7 @@ def get_benchmark(env: str, ubuntu_url: str):
     log_dir = (Path(__file__).parent / "logs").resolve()
     expeirment = CrabBenchmarkV0(
         benchmark=benchmark,
-        task_id=args.task_id,
+        task_id=task_id,
         agent_policy=agent_policy,
         log_dir=log_dir,
     )
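
Taken together, the new --task-description flag lets the runner execute an ad-hoc task: a fresh UUID becomes the task id, and the placeholder empty_evaluator (which always returns False) is attached, so completion is never reported automatically and has to be judged from the logs. Below is a minimal usage sketch; the entry-point name and the spelling of the remote-URL flag are assumptions, not taken from this diff, while --env, --model, and --task-description come from the argument parser shown above.

# Hypothetical invocation (script name and --remote-url flag are assumed):
#   python main.py --env mac --remote-url http://localhost:8000 \
#       --model gpt4o --task-description "Open firefox in both macos and android."
#
# Roughly what the new branch builds for such an invocation:
from uuid import uuid4

from crab import Task

ad_hoc_task = Task(
    id=str(uuid4()),             # random id; any --task-id value is ignored
    description="Open firefox in both macos and android.",
    evaluator=empty_evaluator,   # placeholder defined in the diff; always False
)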