Skip to content

Commit 70207e7

Browse files
committed
fix fused_mlp import bug
1 parent d0eeadb commit 70207e7

File tree

3 files changed

+4
-0
lines changed

3 files changed

+4
-0
lines changed

scripts/vit_triplane_sit_sample.py

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import traceback
1212

1313
import torch as th
14+
from xformers.components.feedforward import fused_mlp
1415
# from xformers.triton import FusedLayerNorm as LayerNorm
1516
import torch.multiprocessing as mp
1617
import torch.distributed as dist

scripts/vit_triplane_sit_train.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import nsr
3535
import nsr.lsgm
3636
# from nsr.train_util_diffusion import TrainLoop3DDiffusion as TrainLoop
37+
from xformers.components.feedforward import fused_mlp
3738

3839
from datasets.eg3d_dataset import LMDBDataset_MV_Compressed_eg3d
3940
from nsr.script_util import create_3DAE_model, encoder_and_nsr_defaults, loss_defaults, rendering_options_defaults, eg3d_options_default, dataset_defaults

scripts/vit_triplane_train.py

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import torch as th
1717

18+
from xformers.components.feedforward import fused_mlp
19+
1820
# if th.cuda.is_available(): # FIXME
1921
# from xformers.triton import FusedLayerNorm as LayerNorm
2022

0 commit comments

Comments
 (0)