|
24 | 24 | conv1_kernel_size=5,
|
25 | 25 | bn_momentum=0.02)),
|
26 | 26 | memory=dict(type='MultilevelMemory', in_channels=[32, 64, 128, 256], queue=-1, vmp_layer=(0,1,2,3)),
|
27 |
| - # memory=dict(type='MultilevelMemory', in_channels=[32, 64, 128, 256], queue=-1, vmp_layer=(2,3)), |
28 | 27 | pool=dict(type='GeoAwarePooling', channel_proj=96),
|
29 | 28 | decoder=dict(
|
30 | 29 | type='ScanNetMixQueryDecoder',
|
31 | 30 | num_layers=3,
|
32 | 31 | share_attn_mlp=False,
|
33 | 32 | share_mask_mlp=False,
|
34 |
| - temporal_attn=False, # TODO: to be extended |
| 33 | + temporal_attn=False, |
35 | 34 | # the last mp_mode should be "P"
|
36 | 35 | cross_attn_mode=["", "SP", "SP", "SP"],
|
37 | 36 | mask_pred_mode=["SP", "SP", "P", "P"],
|
|
51 | 50 | fix_attention=True,
|
52 | 51 | objectness_flag=False,
|
53 | 52 | bbox_flag=use_bbox),
|
54 |
| - merge_head=dict(type='MergeHead', in_channels=256, out_channels=256), |
| 53 | + merge_head=dict(type='MergeHead', in_channels=256, out_channels=256, norm='layer'), |
55 | 54 | merge_criterion=dict(type='ScanNetMergeCriterion_Fast', tmp=True, p2s=False),
|
56 | 55 | criterion=dict(
|
57 | 56 | type='ScanNetMixedCriterion',
|
|
76 | 75 | fix_dice_loss_weight=True,
|
77 | 76 | iter_matcher=True,
|
78 | 77 | fix_mean_loss=True)),
|
79 |
| - train_cfg=dict(), |
| 78 | + train_cfg=None, |
80 | 79 | test_cfg=dict(
|
81 | 80 | # TODO: a larger topK may be better
|
82 | 81 | topk_insts=20,
|
|
91 | 90 | stuff_classes=[0, 1],
|
92 | 91 | merge_type='learnable_online'))
|
93 | 92 |
|
94 |
| -# TODO: complete the dataset |
95 | 93 | dataset_type = 'ScanNet200SegMVDataset_'
|
96 | 94 | data_root = 'data/scenenn-mv/'
|
97 | 95 |
|
|
163 | 161 | with_seg_3d=True,
|
164 | 162 | with_sp_mask_3d=True,
|
165 | 163 | with_rec=True,
|
166 |
| - dataset_type = 'scenenn'), |
167 |
| - # dict(type='SwapChairAndFloorWithRec'), |
| 164 | + dataset_type='scenenn'), |
168 | 165 | dict(type='PointSegClassMappingWithRec'),
|
169 | 166 | dict(
|
170 | 167 | type='MultiScaleFlipAug3D',
|
|
186 | 183 | dict(type='Pack3DDetInputs_Online', keys=['points', 'sp_pts_mask'])
|
187 | 184 | ]
|
188 | 185 |
|
| 186 | +train_dataloader = None |
| 187 | + |
189 | 188 | val_dataloader = dict(
|
190 | 189 | # persistent_workers=False,
|
191 | 190 | # num_workers=0,
|
|
233 | 232 | metric_meta=metric_meta)
|
234 | 233 | test_evaluator = val_evaluator
|
235 | 234 |
|
236 |
| - |
237 | 235 | custom_hooks = [dict(type='EmptyCacheHook', after_iter=True)]
|
238 | 236 | default_hooks = dict(
|
239 | 237 | checkpoint=dict(
|
|
242 | 240 | save_best=['all_ap_50%'],
|
243 | 241 | rule='greater'))
|
244 | 242 |
|
245 |
| - |
246 |
| -# training schedule for 1x |
247 | 243 | val_cfg = dict(type='ValLoop')
|
248 | 244 | test_cfg = dict(type='TestLoop')
|
0 commit comments