Commit 3663704

Update HybrIK support by @Jeff-sjtu
1 parent 53273e0 commit 3663704

20 files changed, with 2390 additions and 32 deletions

.gitignore

+1
@@ -9,3 +9,4 @@ results/*
 force_push.sh
 scripts/vis*
 scripts/process_all*
+.idea

README.md

+5 -4
@@ -46,6 +46,7 @@
 <br />

 ## News :triangular_flag_on_post:
+- [2022/04/26] <a href="https://github.com/Jeff-sjtu/HybrIK">HybrIK (SMPL)</a> is supported as optional HPS by <a href="https://jeffli.site/">Jiefeng Li</a>.
 - [2022/03/05] <a href="https://github.com/YadiraF/PIXIE">PIXIE (SMPL-X)</a>, <a href="https://github.com/mkocabas/PARE">PARE (SMPL)</a>, <a href="https://github.com/HongwenZhang/PyMAF">PyMAF (SMPL)</a> are all supported as optional HPS.
 - [2022/02/07] <a href='https://colab.research.google.com/drive/1-AWeWhPvCTBX0KfMtgtMk10uPU05ihoA?usp=sharing' style='padding-left: 0.5rem;'><img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Google Colab'></a> is ready to use.

@@ -119,7 +120,7 @@
 ## TODO

 - [x] testing code and pretrained models (*self-implemented version)
-  - [x] ICON (w/ & w/o global encoder, w/ PyMAF/PIXIE/PARE as HPS)
+  - [x] ICON (w/ & w/o global encoder, w/ PyMAF/HybrIK/PIXIE/PARE as HPS)
   - [x] PIFu* (RGB image + predicted normal map as input)
   - [x] PaMIR* (RGB image + predicted normal map as input, w/ PyMAF/PARE as HPS)
 - [x] colab notebook <a href='https://colab.research.google.com/drive/1-AWeWhPvCTBX0KfMtgtMk10uPU05ihoA?usp=sharing' style='padding-left: 0.5rem;'>
@@ -150,10 +151,10 @@ python infer.py -cfg ../configs/pifu.yaml -gpu 0 -in_dir ../examples -out_dir ..
 python infer.py -cfg ../configs/pamir.yaml -gpu 0 -in_dir ../examples -out_dir ../results

 # ICON w/ global filter (better visual details --> lower Normal Error))
-python infer.py -cfg ../configs/icon-filter.yaml -gpu 0 -in_dir ../examples -out_dir ../results -hps_type {pixie/pymaf/pare}
+python infer.py -cfg ../configs/icon-filter.yaml -gpu 0 -in_dir ../examples -out_dir ../results -hps_type {pixie/pymaf/pare/hybrik}

 # ICON w/o global filter (higher evaluation scores --> lower P2S/Chamfer Error))
-python infer.py -cfg ../configs/icon-nofilter.yaml -gpu 0 -in_dir ../examples -out_dir ../results -hps_type {pixie/pymaf/pare}
+python infer.py -cfg ../configs/icon-nofilter.yaml -gpu 0 -in_dir ../examples -out_dir ../results -hps_type {pixie/pymaf/pare/hybrik}
 ```

 ## More Qualitative Results
@@ -194,7 +195,7 @@ Here are some great resources we benefit from:
 - [PaMIR](https://github.com/ZhengZerong/PaMIR), [PIFu](https://github.com/shunsukesaito/PIFu), [PIFuHD](https://github.com/facebookresearch/pifuhd), and [MonoPort](https://github.com/Project-Splinter/MonoPort) for Benchmark
 - [SCANimate](https://github.com/shunsukesaito/SCANimate) and [AIST++](https://github.com/google/aistplusplus_api) for Animation
 - [rembg](https://github.com/danielgatis/rembg) for Human Segmentation
-- [smplx](https://github.com/vchoutas/smplx), [PARE](https://github.com/mkocabas/PARE), [PyMAF](https://github.com/HongwenZhang/PyMAF), and [PIXIE](https://github.com/YadiraF/PIXIE) for Human Pose & Shape Estimation
+- [smplx](https://github.com/vchoutas/smplx), [PARE](https://github.com/mkocabas/PARE), [PyMAF](https://github.com/HongwenZhang/PyMAF), [PIXIE](https://github.com/YadiraF/PIXIE), and [HybrIK](https://github.com/Jeff-sjtu/HybrIK) for Human Pose & Shape Estimation
 - [CAPE](https://github.com/qianlim/CAPE) and [THuman](https://github.com/ZhengZerong/DeepHuman/tree/master/THUmanDataset) for Dataset
 - [PyTorch3D](https://github.com/facebookresearch/pytorch3d) for Differential Rendering

assets/rendering/080.png (97 KB)

assets/rendering/SMPL_norm_B_080.png (72.7 KB)

assets/rendering/SMPL_norm_F_080.png (70.9 KB)

assets/rendering/norm_B_080.png (114 KB)

assets/rendering/norm_F_080.png (114 KB)

docs/dataset.md

+7
@@ -30,3 +30,10 @@ bash render_batch.sh gen all
 ```

 Then you will get the whole generated dataset under `data/thuman2_{num_views}views`
+
+## Examples
+
+|<img src="assets/../../assets/rendering/080.png" width="150">|<img src="assets/../../assets/rendering/norm_F_080.png" width="150">|<img src="assets/../../assets/rendering/norm_B_080.png" width="150">|<img src="assets/../../assets/rendering/SMPL_norm_F_080.png" width="150">|<img src="assets/../../assets/rendering/SMPL_norm_B_080.png" width="150">|
+|---|---|---|---|---|
+|Image|Normal(Front)|Normal(Back)|Normal(SMPL, Front)|Normal(SMPL, Back)|
+

docs/installation.md

+9 -1
@@ -36,6 +36,9 @@ source activate icon
 pip install -r requirements.txt --use-deprecated=legacy-resolver
 ```

+
+:warning: If you have trouble accessing Google Drive, you need a VPN to use `rembg` for the first time.
+
 ## Register at [ICON's website](https://icon.is.tue.mpg.de/)

 ![Register](../assets/register.png)
@@ -58,7 +61,7 @@ Optional:
 cd ICON
 bash fetch_data.sh # requires username and password
 ```
-* Download [PyMAF](https://github.com/HongwenZhang/PyMAF#necessary-files), [PARE (optional, SMPL)](https://github.com/mkocabas/PARE#demo), [PIXIE (optional, SMPL-X)](https://pixie.is.tue.mpg.de/)
+* Download [PyMAF](https://github.com/HongwenZhang/PyMAF#necessary-files), [PARE (optional, SMPL)](https://github.com/mkocabas/PARE#demo), [PIXIE (optional, SMPL-X)](https://pixie.is.tue.mpg.de/), [HybrIK (optional, SMPL)](https://github.com/Jeff-sjtu/HybrIK)

 ```bash
 bash fetch_hps.sh
@@ -75,6 +78,11 @@ data/
 │   ├── normal.ckpt
 │   ├── pamir.ckpt
 │   └── pifu.ckpt
+├── hybrik_data/
+│   ├── h36m_mean_beta.npy
+│   ├── J_regressor_h36m.npy
+│   ├── hybrik_config.yaml
+│   └── pretrained_w_cam.pth
 ├── pare_data/
 │   ├── J_regressor_{extra,h36m}.npy
 │   ├── pare/

fetch_hps.sh

+25 -2
@@ -20,11 +20,12 @@ rm -rf data && rm -f data.tar.gz
 source activate icon
 pip install gdown --upgrade
 gdown https://drive.google.com/drive/u/1/folders/1CkF79XRaZzdRlj6eJUt4W0nbTORv2t7O -O pretrained_model --folder
-cd ..
+cd ../..
 echo "PyMAF done!"

 function download_pare(){
     # (optional) download PARE
+    cd data
     wget https://www.dropbox.com/s/aeulffqzb3zmh8x/pare-github-data.zip
     unzip pare-github-data.zip && mv data pare_data
     rm -f pare-github-data.zip
@@ -54,6 +55,20 @@ function download_pixie(){
     cd ../../
 }

+function download_hybrik(){
+    mkdir -p data/hybrik_data
+
+    # (optional) download HybrIK
+    # gdown https://drive.google.com/uc?id=16Y_MGUynFeEzV8GVtKTE5AtkHSi3xsF9 -O data/hybrik_data/pretrained_w_cam.pth
+    gdown https://drive.google.com/uc?id=1lEWZgqxiDNNJgvpjlIXef2VuxcGbtXzi -O data/hybrik_data.zip
+    cd data
+    unzip hybrik_data.zip
+    rm -r *.zip __MACOSX
+    cd ..
+
+    echo "HybrIK done!"
+}
+
 read -p "(optional) Download PARE[SMPL] (y/n)?" choice
 case "$choice" in
   y|Y ) download_pare;;
@@ -66,4 +81,12 @@ case "$choice" in
   y|Y ) download_pixie;;
   n|N ) echo "PIXIE Done!";;
   * ) echo "Invalid input! Please use y|Y or n|N";;
-esac
+esac
+
+pwd
+read -p "(optional) Download HybrIK[SMPL] (y/n)?" choice
+case "$choice" in
+  y|Y ) download_hybrik;;
+  n|N ) echo "HybrIK Done!";;
+  * ) echo "Invalid input! Please use y|Y or n|N";;
+esac
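
After `download_hybrik` runs, it can be useful to confirm that the archive unpacked into the layout that `docs/installation.md` lists under `data/hybrik_data/`. The following is a minimal sketch of such a check, not part of the commit: the four file names are copied from that listing, while the helper name `missing_hybrik_files` and the default `data` root are assumptions made for illustration.

```python
from pathlib import Path

# File names copied from the data/hybrik_data/ listing in docs/installation.md.
EXPECTED_HYBRIK_FILES = (
    "h36m_mean_beta.npy",
    "J_regressor_h36m.npy",
    "hybrik_config.yaml",
    "pretrained_w_cam.pth",
)


def missing_hybrik_files(data_root="data"):
    """Return the expected HybrIK files that are absent under data_root/hybrik_data.

    Hypothetical helper: the commit itself does not ship a check like this.
    """
    hybrik_dir = Path(data_root) / "hybrik_data"
    return [name for name in EXPECTED_HYBRIK_FILES
            if not (hybrik_dir / name).is_file()]


if __name__ == "__main__":
    missing = missing_hybrik_files()
    if missing:
        print("Missing HybrIK files:", ", ".join(missing))
    else:
        print("All HybrIK files found.")
```

Run it from the repository root (the directory that contains `data/`), which is where the script is expected to end up after its `cd ..` calls.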

lib/dataset/TestDataset.py

+18 -4
@@ -50,6 +50,9 @@
 from lib.pixielib.pixie import PIXIE
 from lib.pixielib.utils.config import cfg as pixie_cfg

+# for hybrik
+from lib.hybrik.models.simple3dpose import HybrIKBaseSMPLCam
+

 class TestDataset():
     def __init__(self, cfg, device):
@@ -104,8 +107,12 @@ def __init__(self, cfg, device):
         elif self.hps_type == 'pixie':
             self.hps = PIXIE(config = pixie_cfg, device=self.device)
             self.smpl_model = self.hps.smplx
-
-
+        elif self.hps_type == 'hybrik':
+            smpl_path = osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl")
+            self.hps = HybrIKBaseSMPLCam(cfg_file=path_config.HYBRIK_CFG, smpl_path=smpl_path, data_path=path_config.hybrik_data_dir)
+            self.hps.load_state_dict(torch.load(path_config.HYBRIK_CKPT, map_location='cpu'), strict=False)
+            self.hps.to(self.device)
+
         print(colored(f"Using {self.hps_type} as HPS Estimator\n", "green"))

         self.render = Render(size=512, device=device)
@@ -217,6 +224,14 @@ def __getitem__(self, index):
             data_dict['smpl_verts'] = preds_dict['vertices']
             scale, tranX, tranY = preds_dict['cam'][0, :3]

+        elif self.hps_type == 'hybrik':
+            data_dict['body_pose'] = preds_dict['pred_theta_mats'][:, 1:]
+            data_dict['global_orient'] = preds_dict['pred_theta_mats'][:, [0]]
+            data_dict['betas'] = preds_dict['pred_shape']
+            data_dict['smpl_verts'] = preds_dict['pred_vertices']
+            scale, tranX, tranY = preds_dict['pred_camera'][0, :3]
+            scale = scale * 2
+
         data_dict['scale'] = scale
         data_dict['trans'] = torch.tensor([tranX, tranY, 0.0]).to(self.device)

@@ -246,7 +261,6 @@ def visualize_alignment(self, data):
                                        global_orient=data['global_orient'],
                                        pose2rot=False)
             smpl_verts = ((smpl_out.vertices + data['trans'])* data['scale']).detach().cpu().numpy()[0]
-
         else:
             smpl_verts, _, _ = self.smpl_model(shape_params=data['betas'],
                                                expression_params=data['exp'],
@@ -303,7 +317,7 @@ def visualize_alignment(self, data):
         {
             'image_dir': "../examples",
             'has_det': True, # w/ or w/o detection
-            'hps_type': 'pixie' # pymaf/pare/pixie
+            'hps_type': 'hybrik' # pymaf/pare/pixie/hybrik
         }, device)

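
The `hybrik` branch added to `__getitem__` renames HybrIK's outputs to the keys the other HPS back ends already fill in, and doubles the predicted scale. The sketch below replays that mapping on dummy NumPy arrays so the slicing is easy to inspect. It is an illustration under stated assumptions, not the repository's code: the function name `split_hybrik_prediction` is hypothetical, NumPy stands in for the torch tensors used in the diff, and the array shapes (`(B, 24, 3, 3)` rotation matrices, 10 shape coefficients, 6890 vertices) are assumed from the usual SMPL conventions rather than stated in the commit.

```python
import numpy as np


def split_hybrik_prediction(preds):
    """Rename HybrIK-style outputs to the keys used in the diff's data_dict.

    Hypothetical helper mirroring the 'hybrik' branch; shapes are assumptions.
    """
    theta = preds["pred_theta_mats"]  # assumed shape: (B, 24, 3, 3)
    scale, tran_x, tran_y = preds["pred_camera"][0, :3]
    return {
        # Indexing with the list [0] keeps the joint axis, giving (B, 1, 3, 3).
        "global_orient": theta[:, [0]],
        # The remaining joints form the body pose, giving (B, 23, 3, 3).
        "body_pose": theta[:, 1:],
        "betas": preds["pred_shape"],
        "smpl_verts": preds["pred_vertices"],
        # The diff multiplies HybrIK's predicted scale by 2.
        "scale": scale * 2,
        "trans": np.array([tran_x, tran_y, 0.0]),
    }


# Dummy prediction with identity rotations for every joint.
dummy = {
    "pred_theta_mats": np.tile(np.eye(3), (1, 24, 1, 1)),
    "pred_shape": np.zeros((1, 10)),
    "pred_vertices": np.zeros((1, 6890, 3)),
    "pred_camera": np.array([[0.5, 0.1, -0.2]]),
}
out = split_hybrik_prediction(dummy)
print(out["global_orient"].shape, out["body_pose"].shape, out["scale"])
```

Keeping the joint axis on `global_orient` matters because `visualize_alignment` passes both tensors to the SMPL model with `pose2rot=False`, which expects rotation matrices for both arguments.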

lib/hybrik/models/layers/Resnet.py

+166
@@ -0,0 +1,166 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None, dcn=None):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1,
+                 downsample=None, norm_layer=nn.BatchNorm2d, dcn=None):
+        super(Bottleneck, self).__init__()
+        self.dcn = dcn
+        self.with_dcn = dcn is not None
+
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = norm_layer(planes, momentum=0.1)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                               padding=1, bias=False)
+
+        self.bn2 = norm_layer(planes, momentum=0.1)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = norm_layer(planes * 4, momentum=0.1)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
+        if not self.with_dcn:
+            out = F.relu(self.bn2(self.conv2(out)), inplace=True)
+        elif self.with_modulated_dcn:
+            offset_mask = self.conv2_offset(out)
+            offset = offset_mask[:, :18 * self.deformable_groups, :, :]
+            mask = offset_mask[:, -9 * self.deformable_groups:, :, :]
+            mask = mask.sigmoid()
+            out = F.relu(self.bn2(self.conv2(out, offset, mask)))
+        else:
+            offset = self.conv2_offset(out)
+            out = F.relu(self.bn2(self.conv2(out, offset)), inplace=True)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = F.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+    """ ResNet """
+
+    def __init__(self, architecture, norm_layer=nn.BatchNorm2d, dcn=None, stage_with_dcn=(False, False, False, False)):
+        super(ResNet, self).__init__()
+        self._norm_layer = norm_layer
+        assert architecture in ["resnet18", "resnet34", "resnet50", "resnet101", 'resnet152']
+        layers = {
+            'resnet18': [2, 2, 2, 2],
+            'resnet34': [3, 4, 6, 3],
+            'resnet50': [3, 4, 6, 3],
+            'resnet101': [3, 4, 23, 3],
+            'resnet152': [3, 8, 36, 3],
+        }
+        self.inplanes = 64
+        if architecture == "resnet18" or architecture == 'resnet34':
+            self.block = BasicBlock
+        else:
+            self.block = Bottleneck
+        self.layers = layers[architecture]
+
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7,
+                               stride=2, padding=3, bias=False)
+        self.bn1 = norm_layer(64, eps=1e-5, momentum=0.1, affine=True)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+        stage_dcn = [dcn if with_dcn else None for with_dcn in stage_with_dcn]
+
+        self.layer1 = self.make_layer(
+            self.block, 64, self.layers[0], dcn=stage_dcn[0])
+        self.layer2 = self.make_layer(
+            self.block, 128, self.layers[1], stride=2, dcn=stage_dcn[1])
+        self.layer3 = self.make_layer(
+            self.block, 256, self.layers[2], stride=2, dcn=stage_dcn[2])
+
+        self.layer4 = self.make_layer(
+            self.block, 512, self.layers[3], stride=2, dcn=stage_dcn[3])
+
+    def forward(self, x):
+        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))  # 64 * h/4 * w/4
+        x = self.layer1(x)  # 256 * h/4 * w/4
+        x = self.layer2(x)  # 512 * h/8 * w/8
+        x = self.layer3(x)  # 1024 * h/16 * w/16
+        x = self.layer4(x)  # 2048 * h/32 * w/32
+        return x
+
+    def stages(self):
+        return [self.layer1, self.layer2, self.layer3, self.layer4]
+
+    def make_layer(self, block, planes, blocks, stride=1, dcn=None):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                self._norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample,
+                            norm_layer=self._norm_layer, dcn=dcn))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes,
+                                norm_layer=self._norm_layer, dcn=dcn))
+
+        return nn.Sequential(*layers)
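
The shape comments in `ResNet.forward` (256, 512, 1024, 2048 channels) match the `Bottleneck` variants, whose `expansion` is 4. With `BasicBlock` (`resnet18`, `resnet34`) the expansion is 1, so the stages emit 64, 128, 256, and 512 channels instead. The following sketch works this out without building the network. It is a plain-Python illustration, not code from the commit: the constants are copied from the class above, while the helper names `conv_out` and `stage_shapes` are assumptions introduced here.

```python
# Block expansion per architecture, taken from BasicBlock (1) and Bottleneck (4).
EXPANSION = {"resnet18": 1, "resnet34": 1,
             "resnet50": 4, "resnet101": 4, "resnet152": 4}
# Base widths and strides passed to make_layer for layer1..layer4.
STAGE_PLANES = (64, 128, 256, 512)
STAGE_STRIDES = (1, 2, 2, 2)


def conv_out(size, kernel, stride, padding):
    """Output length of a convolution or pooling window along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_shapes(architecture, height, width):
    """Return (channels, height, width) after layer1..layer4 for one image.

    Hypothetical helper; it only does the arithmetic implied by the layers.
    """
    # Stem: 7x7 conv (stride 2, padding 3), then 3x3 max-pool (stride 2, padding 1).
    h = conv_out(conv_out(height, 7, 2, 3), 3, 2, 1)
    w = conv_out(conv_out(width, 7, 2, 3), 3, 2, 1)
    shapes = []
    for planes, stride in zip(STAGE_PLANES, STAGE_STRIDES):
        # Only the first block of a stage strides, via a 3x3 conv with padding 1;
        # the later blocks in the stage keep the spatial size.
        h = conv_out(h, 3, stride, 1)
        w = conv_out(w, 3, stride, 1)
        shapes.append((planes * EXPANSION[architecture], h, w))
    return shapes


for arch in ("resnet18", "resnet50"):
    print(arch, stage_shapes(arch, 256, 256))
```

For a 256x256 input the spatial sizes are 64, 32, 16, and 8 for every architecture; only the channel counts differ between the two block types.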
