Skip to content

Commit

Permalink
Merge pull request #148 from masa-su/feature/tutorial_v3
Browse files Browse the repository at this point in the history
Feature/tutorial v3
  • Loading branch information
masa-su authored Oct 27, 2020
2 parents 850ffac + 066e03e commit 766ed7a
Show file tree
Hide file tree
Showing 15 changed files with 10,818 additions and 0 deletions.
1,352 changes: 1,352 additions & 0 deletions tutorial/English/00-PixyzOverview.ipynb

Large diffs are not rendered by default.

724 changes: 724 additions & 0 deletions tutorial/English/01-DistributionAPITutorial.ipynb

Large diffs are not rendered by default.

859 changes: 859 additions & 0 deletions tutorial/English/02-LossAPITutorial.ipynb

Large diffs are not rendered by default.

531 changes: 531 additions & 0 deletions tutorial/English/03-ModelAPITutorial.ipynb

Large diffs are not rendered by default.

3,551 changes: 3,551 additions & 0 deletions tutorial/English/04-DeepMarkovModel.ipynb

Large diffs are not rendered by default.

34 changes: 34 additions & 0 deletions tutorial/English/prepare_cartpole_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import gym
import pickle
import numpy as np
import cv2
env = gym.make("CartPole-v1")
observation = env.reset()

episodes = {"frames":[], "actions":[]}

# for 56 *56 episode num = 500
# for 28 * 28 episode num = 1000
for _episode in range(1000):
frames = []
actions = []
for _frame in range(30):
action = env.action_space.sample() # your agent here (this takes random actions)
frame = env.render(mode='rgb_array')
observation, reward, done, info = env.step(action)

img = frame
img = img[150:350, 200:400]
img = cv2.resize(img, (28,28))

frames.append(img)
actions.append(action)
observation = env.reset()
episodes["frames"].append(frames)
episodes["actions"].append(actions)
env.close()

data = [np.array(episodes["frames"]), np.array(episodes["actions"])]
print(data[0].shape, data[1].shape)
with open('cartpole_28.pickle', mode='wb') as f:
pickle.dump(data, f)
43 changes: 43 additions & 0 deletions tutorial/English/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from torch.utils.data import Dataset
import pickle
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt

def imshow(img_tensors):
img = torchvision.utils.make_grid(img_tensors)
npimg = img.numpy()
plt.figure(figsize=(16, 12))
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()



class DMMDataset(Dataset):
def __init__(self, pickle_path="cartpole_28.pickle"):

with open(pickle_path, mode='rb') as f:
data = pickle.load(f)
episode_frames, actions = data
# episode_frames: np.array([episode_num, one_episode_length, height, width, Channels]) (10000, 30, 28, 28, 3)
# actions: np.array([episode_num, one_episode_length]) (10000, 30)
# HWC → CHW
episode_frames = episode_frames.transpose(0, 1, 4, 2, 3)
actions = actions[:, :, np.newaxis]

self.episode_frames = torch.from_numpy(episode_frames.astype(np.float32))
self.actions = torch.from_numpy(actions.astype(np.float32))

def __len__(self):
return len(self.episode_frames)

def __getitem__(self, idx):
return {
"episode_frames": self.episode_frames[idx] / 255,
"actions": self.actions[idx]
}


if __name__ == "__main__":
pass
Loading

0 comments on commit 766ed7a

Please sign in to comment.