6
6
from .utils import download_json , format_node , find_node_by_path , find_node_by_axtid
7
7
8
8
class SFTConverter :
9
- """将处理后的轨迹数据转换为SFT训练格式 """
9
+ """Convert processed trajectory data to SFT training format """
10
10
11
11
def __init__(self, config: Dict):
    """Keep the configuration and run the converter's one-time setup.

    Args:
        config: Settings mapping stored on the instance as ``self.config``.
    """
    # The config is assigned first: _setup_logging reads self.config.
    self.config = config

    # One-time initialisation, logging first and prompt templates second.
    self._setup_logging()
    self._setup_templates()
15
15
16
16
def _setup_logging (self ):
17
- """设置日志 """
17
+ """Initialize logging configuration """
18
18
log_dir = self .config .get ('log_dir' , 'logs/converter' )
19
19
os .makedirs (log_dir , exist_ok = True )
20
20
@@ -28,8 +28,8 @@ def _setup_logging(self):
28
28
)
29
29
30
30
def _setup_templates (self ):
31
- """设置提示模板 """
32
- # 系统提示信息
31
+ """Setup prompt templates for model input/output """
32
+ # System prompt
33
33
self .prompt_system = '''
34
34
# CONTEXT
35
35
@@ -60,7 +60,7 @@ def _setup_templates(self):
60
60
3. Format actions correctly using the specified structure.
61
61
'''
62
62
63
- # 输入提示模板
63
+ # Input prompt template
64
64
self .prompt_input_template = '''
65
65
# OBSERVATION
66
66
@@ -79,22 +79,22 @@ def _setup_templates(self):
79
79
{action_list}
80
80
'''
81
81
82
- # 输出提示模板
82
+ # Output prompt template
83
83
self .prompt_output_template = '''
84
84
Based on the observation and objective, I will:
85
85
86
86
{action}
87
87
'''
88
88
89
- # 动作模板
89
+ # Action template
90
90
self .action_template = '''
91
91
## Action {i}
92
92
- action_type: {action_type}
93
93
- action_value: {action_value}
94
94
'''
95
95
96
96
def convert_to_sft_format (self , processed_data : List [Dict ]) -> List [Dict ]:
97
- """转换为SFT训练格式 """
97
+ """Convert data to SFT training format """
98
98
sft_data = []
99
99
100
100
for traj_idx , trajectory in enumerate (processed_data ):
@@ -106,33 +106,27 @@ def convert_to_sft_format(self, processed_data: List[Dict]) -> List[Dict]:
106
106
steps = json .loads (steps_str )
107
107
logging .info (f"Found { len (steps )} steps in trajectory" )
108
108
109
- # 创建轨迹特定的目录
110
109
traj_dirs = self ._create_trajectory_dirs (traj_idx )
111
110
112
- # 处理每个步骤
113
111
for step_idx , step in enumerate (steps ):
114
112
logging .info (f"Processing step { step_idx } " )
115
113
116
114
if not self ._validate_step (step ):
117
115
continue
118
116
119
117
try :
120
- # 处理 axtree 数据
121
118
formatted_axtree , retrieved_axtree = self ._process_axtree (step , traj_idx , step_idx , traj_dirs )
122
119
if not formatted_axtree or not retrieved_axtree :
123
120
continue
124
121
125
- # 构建动作历史
126
122
action_list = self ._build_action_list (steps [:step_idx ])
127
123
128
- # 构建当前动作
129
124
current_action = {
130
125
"action_type" : step ["type" ],
131
126
"action_id" : step .get ("axtId" , "" ),
132
127
"action_value" : step .get ("value" , "" )
133
128
}
134
129
135
- # 构建训练样本
136
130
sample = {
137
131
"prompt_system" : self .prompt_system ,
138
132
"prompt_input" : self .prompt_input_template .format (
@@ -150,10 +144,8 @@ def convert_to_sft_format(self, processed_data: List[Dict]) -> List[Dict]:
150
144
"url" : step .get ("href" , "" )
151
145
}
152
146
}
153
-
154
147
sft_data .append (sample )
155
- logging .info (f"Successfully created sample for step { step_idx } " )
156
-
148
+
157
149
except Exception as e :
158
150
logging .error (f"Error processing step { step_idx } : { str (e )} " )
159
151
continue
@@ -169,16 +161,16 @@ def convert_to_sft_format(self, processed_data: List[Dict]) -> List[Dict]:
169
161
return sft_data
170
162
171
163
def _validate_step (self , step : Dict ) -> bool :
172
- """验证步骤数据的有效性 """
173
- required_fields = ['type' , 'href' ] # 移除了 formatted_axtree 的要求
164
+ """Validate step data completeness """
165
+ required_fields = ['type' , 'href' ]
174
166
valid = all (field in step for field in required_fields )
175
167
if not valid :
176
168
missing_fields = [field for field in required_fields if field not in step ]
177
169
logging .warning (f"Missing required fields: { missing_fields } " )
178
170
return valid
179
171
180
172
def _build_action_list (self , previous_steps : List [Dict ]) -> str :
181
- """构建动作历史列表 """
173
+ """Build history of previous actions """
182
174
action_list = ""
183
175
for i , step in enumerate (previous_steps ):
184
176
action_list += self .action_template .format (
@@ -189,18 +181,17 @@ def _build_action_list(self, previous_steps: List[Dict]) -> str:
189
181
return action_list
190
182
191
183
def save_sft_data(self, sft_data: List[Dict], output_path: str):
    """Write SFT training samples to ``output_path`` in JSON Lines format.

    Each sample is serialized with ``json.dumps(..., ensure_ascii=False)`` and
    written on its own line, UTF-8 encoded. The file is opened in ``'w'``
    mode, so existing content is replaced.

    Args:
        sft_data: Samples to persist; each must be JSON-serializable.
        output_path: Destination file. Missing parent directories are
            created; a bare filename is written to the current directory.
    """
    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(
            json.dumps(sample, ensure_ascii=False) + '\n' for sample in sft_data
        )

    logging.info(f"Saved {len(sft_data)} training samples to {output_path}")
201
192
202
193
def _create_trajectory_dirs (self , traj_idx : int ) -> Dict [str , str ]:
203
- """创建轨迹相关的目录 """
194
+ """Create directories for trajectory data """
204
195
dirs = {
205
196
'raw' : os .path .join (self .config .get ('raw_axtree_dir' , 'data/raw_axtree' ), str (traj_idx )),
206
197
'formatted' : os .path .join (self .config .get ('formatted_axtree_dir' , 'data/formatted_axtree' ), str (traj_idx )),
@@ -213,41 +204,39 @@ def _create_trajectory_dirs(self, traj_idx: int) -> Dict[str, str]:
213
204
return dirs
214
205
215
206
def _process_axtree (self , step : Dict , traj_idx : int , step_idx : int , traj_dirs : Dict [str , str ]) -> tuple :
216
- """处理单个步骤的 axtree 数据 """
207
+ """Process axtree data for a single step """
217
208
if not step .get ("axTree" ):
218
209
logging .warning (f"Step { step_idx } has no axTree data" )
219
210
return None , None
220
211
221
212
try :
222
- # 下载和保存原始 axtree
213
+ # Download and save raw axtree
223
214
raw_path = os .path .join (traj_dirs ['raw' ], f"{ step_idx } .json" )
224
215
download_json (step ["axTree" ], raw_path )
225
216
226
- # 读取原始 axtree
227
217
with open (raw_path , 'r' , encoding = 'utf-8' ) as f :
228
218
raw_axtree = json .load (f )
229
219
230
- # 格式化完整 axtree
220
+ # Format complete axtree
231
221
formatted_nodes = format_node (raw_axtree )
232
222
formatted_axtree = "\n " .join (formatted_nodes )
233
223
234
- # 保存格式化后的 axtree
224
+ # Save formatted axtree
235
225
formatted_path = os .path .join (traj_dirs ['formatted' ], f"{ step_idx } .txt" )
236
226
with open (formatted_path , 'w' , encoding = 'utf-8' ) as f :
237
227
f .write (formatted_axtree )
238
228
239
- # 查找目标节点
229
+ # Find target node
240
230
retrieved_node = None
241
231
if "axtId" in step and step ["axtId" ]:
242
- # 优先使用 axtId 查找
243
232
logging .info (f"Searching node by axtId: { step ['axtId' ]} " )
244
233
retrieved_node = find_node_by_axtid (raw_axtree , step ["axtId" ])
245
234
if retrieved_node :
246
235
logging .info (f"Found node by axtId: { step ['axtId' ]} " )
247
236
else :
248
237
logging .warning (f"Node not found by axtId: { step ['axtId' ]} , falling back to path search" )
249
238
250
- # 如果没有 axtId 或未找到,则使用 path 查找
239
+ # If no axtId or not found, use path search
251
240
if retrieved_node is None and "path" in step :
252
241
logging .info (f"Searching node by path: { step ['path' ]} " )
253
242
path = ["html" ] + step ["path" ].split ('>' )
@@ -261,16 +250,16 @@ def _process_axtree(self, step: Dict, traj_idx: int, step_idx: int, traj_dirs: D
261
250
logging .warning (f"No node found for step { step_idx } " )
262
251
return formatted_axtree , ""
263
252
264
- # 格式化检索到的节点
253
+ # Format retrieved node
265
254
retrieved_nodes = format_node (retrieved_node )
266
255
retrieved_axtree = "\n " .join (retrieved_nodes )
267
256
268
- # 保存检索到的节点
257
+ # Save retrieved node
269
258
retrieved_path = os .path .join (traj_dirs ['retrieved' ], f"{ step_idx } .txt" )
270
259
with open (retrieved_path , 'w' , encoding = 'utf-8' ) as f :
271
260
f .write (retrieved_axtree )
272
261
273
- # 验证找到的节点的 axtId 是否匹配(如果原始步骤中有 axtId)
262
+ # Verify found node's axtId matches (if original step has axtId)
274
263
if "axtId" in step and step ["axtId" ]:
275
264
found_axt_id = retrieved_node .get ("attributes" , {}).get ("data-imean-axt-id" )
276
265
if found_axt_id != step ["axtId" ]:
0 commit comments