[feat] add causal vqvae ✨ #145

Merged · 1 commit · Mar 18, 2024
8 changes: 7 additions & 1 deletion opensora/models/ae/__init__.py
@@ -2,7 +2,13 @@
from .videobase import videobase_ae, videovae, videovqvae, videobase_ae_stride, videobase_ae_channel
from .videobase import (
    VQVAEConfiguration,
-   VQVAEModel
+   VQVAEModel,
+   VQVAEDataset,
+   VQVAETrainer,
+   CausalVQVAEModel,
+   CausalVQVAEConfiguration,
+   CausalVQVAEDataset,
+   CausalVQVAETrainer
)

ae_stride_config = {}
8 changes: 7 additions & 1 deletion opensora/models/ae/videobase/__init__.py
@@ -5,7 +5,13 @@
    VQVAEConfiguration,
    VQVAEModel,
    VQVAETrainer,
-   VQVAEDataset, VideoGPTVQVAEWrapper
+   VQVAEDataset, VideoGPTVQVAEWrapper,
)
+from .causal_vqvae import (
+   CausalVQVAEConfiguration,
+   CausalVQVAEDataset,
+   CausalVQVAETrainer,
+   CausalVQVAEModel
+)


4 changes: 4 additions & 0 deletions opensora/models/ae/videobase/causal_vqvae/__init__.py
@@ -0,0 +1,4 @@
from .configuration_causalvqvae import CausalVQVAEConfiguration
from .modeling_causalvqvae import CausalVQVAEModel
from .trainer_causalvqvae import CausalVQVAETrainer
from .dataset_causalvqvae import CausalVQVAEDataset
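
Taken together with the two __init__.py hunks above, the new classes are re-exported at three package levels. A minimal import sketch (editor's addition, assuming the repository root is on PYTHONPATH; it only uses names visible in this diff):

# Any of these import paths should resolve once the PR is merged,
# per the re-exports added in the three __init__.py files above.
from opensora.models.ae import CausalVQVAEModel
from opensora.models.ae.videobase import CausalVQVAEConfiguration
from opensora.models.ae.videobase.causal_vqvae import (
    CausalVQVAEDataset,
    CausalVQVAETrainer,
)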
28 changes: 28 additions & 0 deletions opensora/models/ae/videobase/causal_vqvae/configuration_causalvqvae.py
@@ -0,0 +1,28 @@
from ..configuration_videobase import VideoBaseConfiguration
from typing import Union, Tuple

class CausalVQVAEConfiguration(VideoBaseConfiguration):
    def __init__(
        self,
        embedding_dim: int = 256,
        n_codes: int = 2048,
        n_hiddens: int = 240,
        n_res_layers: int = 4,
        resolution: int = 128,
        sequence_length: int = 16,
        time_downsample: int = 4,
        spatial_downsample: int = 8,
        no_pos_embd: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.embedding_dim = embedding_dim
        self.n_codes = n_codes
        self.n_hiddens = n_hiddens
        self.n_res_layers = n_res_layers
        self.resolution = resolution
        self.sequence_length = sequence_length
        self.time_downsample = time_downsample
        self.spatial_downsample = spatial_downsample
        self.no_pos_embd = no_pos_embd
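
The defaults above imply 4x temporal and 8x spatial compression. A hedged sketch of constructing the configuration and reading off the implied latent grid (editor's addition; assumes VideoBaseConfiguration needs no required arguments and that the downsample fields carry their conventional VQ-VAE meaning):

# Construct with the defaults: embedding_dim=256, n_codes=2048, etc.
from opensora.models.ae.videobase import CausalVQVAEConfiguration

config = CausalVQVAEConfiguration()

# Implied latent grid, assuming each factor divides the input size evenly.
latent_t = config.sequence_length // config.time_downsample   # 16 // 4 = 4
latent_hw = config.resolution // config.spatial_downsample    # 128 // 8 = 16
print(latent_t, latent_hw)  # -> 4 16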
153 changes: 153 additions & 0 deletions opensora/models/ae/videobase/causal_vqvae/dataset_causalvqvae.py
@@ -0,0 +1,153 @@
import os.path as osp
import math
import pickle
import warnings

import glob
from PIL import Image
import torch
import torch.utils.data as data
import torch.distributed as dist
import torch.nn.functional as F
from torchvision.datasets.video_utils import VideoClips
import torchvision.transforms as transforms

# Copied from https://github.com/wilson1yan/VideoGPT
class CausalVQVAEDataset(data.Dataset):
""" Generic dataset for videos files stored in folders
Returns BCTHW videos in the range [-0.5, 0.5] """
video_exts = ['avi', 'mp4', 'webm']
image_exts = ['png', 'jpg', 'jpeg']
def __init__(self, video_folder, sequence_length, image_folder=None, train=True, resolution=64):
"""
Args:
data_folder: path to the folder with videos. The folder
should contain a 'train' and a 'test' directory,
each with corresponding videos stored
sequence_length: length of extracted video sequences
"""
super().__init__()
if image_folder is not None:
raise NotImplementedError("Image training is not supported now.")

self.train = train
self.sequence_length = sequence_length
self.resolution = resolution

files = []
video_files = []
image_files = []
for data_folder in [video_folder, image_folder]:
if data_folder is None:
continue
folder = data_folder
video_files += sum([glob.glob(osp.join(folder, '**', f'*.{ext}'), recursive=True)
for ext in self.video_exts], [])
image_files += sum([glob.glob(osp.join(folder, '**', f'*.{ext}'), recursive=True)
for ext in self.image_exts], [])
files = video_files + image_files
# hacky way to compute # of classes (count # of unique parent directories)
# self.classes = list(set([get_parent_dir(f) for f in files]))
# self.classes.sort()
# self.class_to_label = {c: i for i, c in enumerate(self.classes)}

warnings.filterwarnings('ignore')
if len(video_files) != 0:
cache_file = osp.join(folder, f"metadata_{sequence_length}.pkl")
if not osp.exists(cache_file):
clips = VideoClips(video_files, sequence_length, num_workers=32)
if dist.is_initialized() and dist.get_rank() == 0:
pickle.dump(clips.metadata, open(cache_file, 'wb'))
else:
metadata = pickle.load(open(cache_file, 'rb'))
clips = VideoClips(video_files, sequence_length,
_precomputed_metadata=metadata)

self._clips = clips
self._clips_num = self._clips.num_clips()
else:
self._clips = None
self._clips_num = 0
self.image_files = image_files

@property
def n_classes(self):
return len(self.classes)

def __len__(self):
return self._clips_num + len(self.image_files)

def __getitem__(self, idx):
resolution = self.resolution
if idx < self._clips_num:
video, _, _, idx = self._clips.get_clip(idx)
video = preprocess(video, resolution)
class_name = get_parent_dir(self._clips.video_paths[idx])
else:
idx -= self._clips_num
image = Image.open(self.image_files[idx])
video = preprocess_image(image, resolution, self.sequence_length)
# label = self.class_to_label[class_name]
return dict(video=video, label="")

# Copied from https://github.com/wilson1yan/VideoGPT
def get_parent_dir(path):
    return osp.basename(osp.dirname(path))

# Copied from https://github.com/wilson1yan/VideoGPT
def preprocess(video, resolution, sequence_length=None):
    # video: THWC, {0, ..., 255}
    video = video.permute(0, 3, 1, 2).float() / 255.  # TCHW
    t, c, h, w = video.shape

    # temporal crop
    if sequence_length is not None:
        assert sequence_length <= t
        video = video[:sequence_length]

    # scale shorter side to resolution
    scale = resolution / min(h, w)
    if h < w:
        target_size = (resolution, math.ceil(w * scale))
    else:
        target_size = (math.ceil(h * scale), resolution)
    video = F.interpolate(video, size=target_size, mode='bilinear',
                          align_corners=False)

    # center crop
    t, c, h, w = video.shape
    w_start = (w - resolution) // 2
    h_start = (h - resolution) // 2
    video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
    video = video.permute(1, 0, 2, 3).contiguous()  # CTHW

    video -= 0.5

    return video


def preprocess_image(image, resolution, sequence_length=1):
    # image: HWC, {0, ..., 255}
    image = image.convert("RGB")
    w, h = image.size
    scale = resolution / min(h, w)
    if h < w:
        target_size = (resolution, math.ceil(w * scale))
    else:
        target_size = (math.ceil(h * scale), resolution)
    # NOTE: target_size follows the (h, w) convention used in preprocess(),
    # but PIL's Image.resize expects (width, height), so the axes come out
    # swapped for non-square inputs; image inputs are disabled in __init__ anyway.
    image = image.resize(target_size)

    image = transforms.ToTensor()(image)
    image = image.float()
    c, h, w = image.shape
    w_start = (w - resolution) // 2
    h_start = (h - resolution) // 2
    image = image[:, h_start:h_start + resolution, w_start:w_start + resolution]
    image -= 0.5
    c, h, w = image.shape
    # place the image in frame 0 of a sequence_length-frame clip; the
    # remaining frames stay zero
    new_image = torch.zeros((c, sequence_length, h, w))
    new_image = new_image.to(image.device)
    new_image[:, :1, :, :] = image.unsqueeze(1)
    new_image = new_image.contiguous()

    return new_image
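
Finally, a hedged usage sketch for the dataset (editor's addition; the directory path, batch size, and worker count are placeholders, and the first run will scan the videos with torchvision's VideoClips before any clips are served):

from torch.utils.data import DataLoader
from opensora.models.ae.videobase import CausalVQVAEDataset

# Each item is {'video': CTHW float tensor in [-0.5, 0.5], 'label': ''};
# batching therefore yields the BCTHW layout named in the class docstring.
dataset = CausalVQVAEDataset("data/videos", sequence_length=16, resolution=128)
loader = DataLoader(dataset, batch_size=2, num_workers=4, shuffle=True)

batch = next(iter(loader))
video = batch["video"]  # shape: (2, 3, 16, 128, 128)
assert video.min() >= -0.5 and video.max() <= 0.5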