Skip to content

Commit 4cde60a

Browse files
committed
initialize the repo with data downloading, pre-processing and SFT format converting
1 parent 84dddb4 commit 4cde60a

File tree

7 files changed

+105
-155
lines changed

7 files changed

+105
-155
lines changed

configs/config.yaml

+6-36
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,30 @@
1-
# 全局配置
1+
# Global Configuration
22
log_dir: "logs"
33

4-
# 数据下载配置
4+
# Data Download Configuration
55
data_download:
66
challenge_id: "L_if1ihd1jmMJq4WUbrYe"
77
save_path: "data/raw"
88
save_raw_data: true
99
username: ""
1010
password: ""
1111

12-
# 处理模式
13-
mode: "dom_tree" # "vision"
12+
# Processing Mode
13+
mode: "dom_tree" # or "vision"
1414

15-
# 数据处理配置
15+
# Data Processing Configuration
1616
data_processing:
1717
save_path: "data/processed/processed_data.json"
1818
valid_actions: ["click", "type", "hover", "press_enter", "paste", "copy"]
1919
min_steps: 2
2020
max_steps: 50
2121
dom_tree:
2222
max_sequence_length: 1024
23-
include_attributes: ["tag", "text", "href", "src"]
2423

2524
vision:
2625
image_size: [1024, 1024]
2726
augmentation: false
28-
29-
# 训练配置
30-
training:
31-
save_path: "data/sft"
32-
templates:
33-
system_prompt: true
34-
include_retrieved_axtree: true
35-
dom_tree:
36-
model_type: "text-to-text"
37-
max_length: 2048
38-
39-
vision:
40-
planning:
41-
model_type: "vision-language"
42-
image_encoder: "clip"
43-
grounding:
44-
model_type: "vision-coordinates"
45-
image_size: [1024, 1024]
46-
47-
# 推理配置
48-
inference:
49-
browser_type: "playwright"
50-
timeout: 30
51-
max_retries: 3
52-
53-
# 评估配置
54-
evaluation:
55-
metrics: ["success_rate", "steps_per_task", "completion_time"]
56-
max_steps: 20
57-
27+
5828
# Data paths
5929
data:
6030
processed_data_path: "data/processed/processed_data.json"

data_processing/converter/data_converter.py

+23-34
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
from .utils import download_json, format_node, find_node_by_path, find_node_by_axtid
77

88
class SFTConverter:
9-
"""将处理后的轨迹数据转换为SFT训练格式"""
9+
"""Convert processed trajectory data to SFT training format"""
1010

1111
def __init__(self, config: Dict):
1212
self.config = config
1313
self._setup_logging()
1414
self._setup_templates()
1515

1616
def _setup_logging(self):
17-
"""设置日志"""
17+
"""Initialize logging configuration"""
1818
log_dir = self.config.get('log_dir', 'logs/converter')
1919
os.makedirs(log_dir, exist_ok=True)
2020

@@ -28,8 +28,8 @@ def _setup_logging(self):
2828
)
2929

3030
def _setup_templates(self):
31-
"""设置提示模板"""
32-
# 系统提示信息
31+
"""Setup prompt templates for model input/output"""
32+
# System prompt
3333
self.prompt_system = '''
3434
# CONTEXT
3535
@@ -60,7 +60,7 @@ def _setup_templates(self):
6060
3. Format actions correctly using the specified structure.
6161
'''
6262

63-
# 输入提示模板
63+
# Input prompt template
6464
self.prompt_input_template = '''
6565
# OBSERVATION
6666
@@ -79,22 +79,22 @@ def _setup_templates(self):
7979
{action_list}
8080
'''
8181

82-
# 输出提示模板
82+
# Output prompt template
8383
self.prompt_output_template = '''
8484
Based on the observation and objective, I will:
8585
8686
{action}
8787
'''
8888

89-
# 动作模板
89+
# Action template
9090
self.action_template = '''
9191
## Action {i}
9292
- action_type: {action_type}
9393
- action_value: {action_value}
9494
'''
9595

9696
def convert_to_sft_format(self, processed_data: List[Dict]) -> List[Dict]:
97-
"""转换为SFT训练格式"""
97+
"""Convert data to SFT training format"""
9898
sft_data = []
9999

100100
for traj_idx, trajectory in enumerate(processed_data):
@@ -106,33 +106,27 @@ def convert_to_sft_format(self, processed_data: List[Dict]) -> List[Dict]:
106106
steps = json.loads(steps_str)
107107
logging.info(f"Found {len(steps)} steps in trajectory")
108108

109-
# 创建轨迹特定的目录
110109
traj_dirs = self._create_trajectory_dirs(traj_idx)
111110

112-
# 处理每个步骤
113111
for step_idx, step in enumerate(steps):
114112
logging.info(f"Processing step {step_idx}")
115113

116114
if not self._validate_step(step):
117115
continue
118116

119117
try:
120-
# 处理 axtree 数据
121118
formatted_axtree, retrieved_axtree = self._process_axtree(step, traj_idx, step_idx, traj_dirs)
122119
if not formatted_axtree or not retrieved_axtree:
123120
continue
124121

125-
# 构建动作历史
126122
action_list = self._build_action_list(steps[:step_idx])
127123

128-
# 构建当前动作
129124
current_action = {
130125
"action_type": step["type"],
131126
"action_id": step.get("axtId", ""),
132127
"action_value": step.get("value", "")
133128
}
134129

135-
# 构建训练样本
136130
sample = {
137131
"prompt_system": self.prompt_system,
138132
"prompt_input": self.prompt_input_template.format(
@@ -150,10 +144,8 @@ def convert_to_sft_format(self, processed_data: List[Dict]) -> List[Dict]:
150144
"url": step.get("href", "")
151145
}
152146
}
153-
154147
sft_data.append(sample)
155-
logging.info(f"Successfully created sample for step {step_idx}")
156-
148+
157149
except Exception as e:
158150
logging.error(f"Error processing step {step_idx}: {str(e)}")
159151
continue
@@ -169,16 +161,16 @@ def convert_to_sft_format(self, processed_data: List[Dict]) -> List[Dict]:
169161
return sft_data
170162

171163
def _validate_step(self, step: Dict) -> bool:
172-
"""验证步骤数据的有效性"""
173-
required_fields = ['type', 'href'] # 移除了 formatted_axtree 的要求
164+
"""Validate step data completeness"""
165+
required_fields = ['type', 'href']
174166
valid = all(field in step for field in required_fields)
175167
if not valid:
176168
missing_fields = [field for field in required_fields if field not in step]
177169
logging.warning(f"Missing required fields: {missing_fields}")
178170
return valid
179171

180172
def _build_action_list(self, previous_steps: List[Dict]) -> str:
181-
"""构建动作历史列表"""
173+
"""Build history of previous actions"""
182174
action_list = ""
183175
for i, step in enumerate(previous_steps):
184176
action_list += self.action_template.format(
@@ -189,18 +181,17 @@ def _build_action_list(self, previous_steps: List[Dict]) -> str:
189181
return action_list
190182

191183
def save_sft_data(self, sft_data: List[Dict], output_path: str):
192-
"""保存SFT训练数据"""
184+
"""Save SFT training data to JSONL format"""
193185
os.makedirs(os.path.dirname(output_path), exist_ok=True)
194186

195-
# 保存为JSONL格式
196187
with open(output_path, 'w', encoding='utf-8') as f:
197188
for sample in sft_data:
198189
f.write(json.dumps(sample, ensure_ascii=False) + '\n')
199190

200191
logging.info(f"Saved {len(sft_data)} training samples to {output_path}")
201192

202193
def _create_trajectory_dirs(self, traj_idx: int) -> Dict[str, str]:
203-
"""创建轨迹相关的目录"""
194+
"""Create directories for trajectory data"""
204195
dirs = {
205196
'raw': os.path.join(self.config.get('raw_axtree_dir', 'data/raw_axtree'), str(traj_idx)),
206197
'formatted': os.path.join(self.config.get('formatted_axtree_dir', 'data/formatted_axtree'), str(traj_idx)),
@@ -213,41 +204,39 @@ def _create_trajectory_dirs(self, traj_idx: int) -> Dict[str, str]:
213204
return dirs
214205

215206
def _process_axtree(self, step: Dict, traj_idx: int, step_idx: int, traj_dirs: Dict[str, str]) -> tuple:
216-
"""处理单个步骤的 axtree 数据"""
207+
"""Process axtree data for a single step"""
217208
if not step.get("axTree"):
218209
logging.warning(f"Step {step_idx} has no axTree data")
219210
return None, None
220211

221212
try:
222-
# 下载和保存原始 axtree
213+
# Download and save raw axtree
223214
raw_path = os.path.join(traj_dirs['raw'], f"{step_idx}.json")
224215
download_json(step["axTree"], raw_path)
225216

226-
# 读取原始 axtree
227217
with open(raw_path, 'r', encoding='utf-8') as f:
228218
raw_axtree = json.load(f)
229219

230-
# 格式化完整 axtree
220+
# Format complete axtree
231221
formatted_nodes = format_node(raw_axtree)
232222
formatted_axtree = "\n".join(formatted_nodes)
233223

234-
# 保存格式化后的 axtree
224+
# Save formatted axtree
235225
formatted_path = os.path.join(traj_dirs['formatted'], f"{step_idx}.txt")
236226
with open(formatted_path, 'w', encoding='utf-8') as f:
237227
f.write(formatted_axtree)
238228

239-
# 查找目标节点
229+
# Find target node
240230
retrieved_node = None
241231
if "axtId" in step and step["axtId"]:
242-
# 优先使用 axtId 查找
243232
logging.info(f"Searching node by axtId: {step['axtId']}")
244233
retrieved_node = find_node_by_axtid(raw_axtree, step["axtId"])
245234
if retrieved_node:
246235
logging.info(f"Found node by axtId: {step['axtId']}")
247236
else:
248237
logging.warning(f"Node not found by axtId: {step['axtId']}, falling back to path search")
249238

250-
# 如果没有 axtId 或未找到,则使用 path 查找
239+
# If no axtId or not found, use path search
251240
if retrieved_node is None and "path" in step:
252241
logging.info(f"Searching node by path: {step['path']}")
253242
path = ["html"] + step["path"].split('>')
@@ -261,16 +250,16 @@ def _process_axtree(self, step: Dict, traj_idx: int, step_idx: int, traj_dirs: D
261250
logging.warning(f"No node found for step {step_idx}")
262251
return formatted_axtree, ""
263252

264-
# 格式化检索到的节点
253+
# Format retrieved node
265254
retrieved_nodes = format_node(retrieved_node)
266255
retrieved_axtree = "\n".join(retrieved_nodes)
267256

268-
# 保存检索到的节点
257+
# Save retrieved node
269258
retrieved_path = os.path.join(traj_dirs['retrieved'], f"{step_idx}.txt")
270259
with open(retrieved_path, 'w', encoding='utf-8') as f:
271260
f.write(retrieved_axtree)
272261

273-
# 验证找到的节点的 axtId 是否匹配(如果原始步骤中有 axtId
262+
# Verify found node's axtId matches (if original step has axtId)
274263
if "axtId" in step and step["axtId"]:
275264
found_axt_id = retrieved_node.get("attributes", {}).get("data-imean-axt-id")
276265
if found_axt_id != step["axtId"]:

data_processing/converter/utils.py

+17-22
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,14 @@
33

44
def download_json(url, output_file='output.json'):
55
try:
6-
# 发送GET请求获取数据
6+
# Send GET request
77
response = requests.get(url)
8-
response.raise_for_status() # 检查请求是否成功
8+
response.raise_for_status()
99

10-
# 将JSON数据保存到文件
10+
# Save JSON data to file
1111
with open(output_file, 'w', encoding='utf-8') as f:
1212
json.dump(response.json(), f, ensure_ascii=False, indent=2)
1313

14-
# print(f"Successfully downloaded JSON to {output_file}")
15-
1614
except requests.exceptions.RequestException as e:
1715
print(f"Error downloading the file: {e}")
1816
except json.JSONDecodeError as e:
@@ -21,21 +19,23 @@ def download_json(url, output_file='output.json'):
2119

2220
def find_node_by_axtid(node, axt_id):
2321
"""
24-
递归遍历 axtree,寻找指定 axt_id 的节点。
22+
Recursively traverse axtree to find node with specified axt_id.
2523
26-
:param node: 当前节点
27-
:param axt_id: 要查找的 axt_id
28-
:return: 如果找到匹配的节点,返回节点对象;否则返回 None
24+
Args:
25+
node: Current node
26+
axt_id: Target axt_id to find
27+
Returns:
28+
Matching node object if found, None otherwise
2929
"""
3030
if node is None:
3131
return None
3232

33-
# 检查当前节点的 axt_id
33+
# Check current node's axt_id
3434
current_axt_id = node.get("attributes", {}).get("data-imean-axt-id")
3535
if current_axt_id == axt_id:
3636
return node
3737

38-
# 递归检查子节点
38+
# Check child nodes recursively
3939
for child in node.get("children", []):
4040
result = find_node_by_axtid(child, axt_id)
4141
if result:
@@ -45,46 +45,41 @@ def find_node_by_axtid(node, axt_id):
4545

4646
def find_node_by_path(node, path, current_level=0):
4747
"""
48-
递归遍历 axtree,寻找路径为 path 的节点。
49-
作为备选方案,当 axtId 不存在或未找到时使用。
48+
Recursively traverse axtree to find node at specified path.
49+
Used as fallback when axtId is not available or not found.
5050
"""
5151
if node is None:
5252
return None
5353

54-
# 获取当前节点的标签
5554
html_tag = node.get("attributes", {}).get("html_tag", "")
5655

57-
# 检查当前节点是否匹配路径的当前部分
5856
if html_tag != path[current_level]:
5957
return None
6058

61-
# 如果已经匹配到路径的最后一级,返回当前节点
6259
if current_level == len(path) - 1:
6360
return node
6461

65-
# 遍历子节点,递归查找下一层级
6662
for child in node.get("children", []):
6763
result = find_node_by_path(child, path, current_level + 1)
6864
if result:
6965
return result
7066

71-
# 如果没有找到,返回 None
7267
return None
7368

7469
def format_node(node, level=0):
70+
"""Format node and its children into a readable tree structure"""
7571
result = []
76-
indent = " " * level # 2 spaces per level
72+
indent = " " * level
7773

7874
if node is None:
7975
return result
8076

81-
# Get attributes for current node
77+
# Get node attributes
8278
axt_id = node.get("attributes", {}).get("data-imean-axt-id")
8379
role = node.get("role")
8480
name = node.get("name")
8581
value = node.get("value")
8682

87-
# Add formatted string if node has all required attributes
8883
if axt_id and role:
8984
formatted = indent + f"[{axt_id}] {role}"
9085
if name:
@@ -93,7 +88,7 @@ def format_node(node, level=0):
9388
formatted += f" '{value}'"
9489
result.append(formatted)
9590

96-
# Recursively process children
91+
# Process children recursively
9792
children = node.get("children", [])
9893
for child in children:
9994
result.extend(format_node(child, level + 1))

0 commit comments

Comments
 (0)