prototype jax with dqn #222

Merged
merged 27 commits on Jul 31, 2022
Commits (27)
88f33d6
Prototype JAX + DQN
kinalmehta Jun 28, 2022
219cd67
formatting changes
kinalmehta Jun 28, 2022
93304ab
bug fix: predicted q value in mse
kinalmehta Jun 28, 2022
c79684e
Prototype JAX + DQN + Atari
kinalmehta Jun 28, 2022
873cd3c
formatting changes
kinalmehta Jun 28, 2022
ce554b4
Fix `UNKNOWN: CUDNN_STATUS_EXECUTION`
vwxyzjn Jun 28, 2022
2e2d664
update mse loss calculation to be (target-pred) instead of (pred-target)
kinalmehta Jun 29, 2022
bc6f16d
Fix image format and Conv padding
vwxyzjn Jun 29, 2022
f51ab19
Adapting to the TrainState API
kinalmehta Jun 30, 2022
a2045dd
Merge branch 'vwxyzjn:master' into master
kinalmehta Jul 1, 2022
6dfd0c0
Add assets
vwxyzjn Jul 18, 2022
a0426b6
Add my benchmark script
vwxyzjn Jul 18, 2022
22ddeb8
fix benchmark script embed
kinalmehta Jul 19, 2022
89dcbb4
docs: add DQN + JAX documentation
kinalmehta Jul 19, 2022
6942c97
jit action selection and linear_schedule
kinalmehta Jul 26, 2022
0a48ade
docs fix
vwxyzjn Jul 28, 2022
ac36f61
update docs
vwxyzjn Jul 28, 2022
c3a82f1
Merge branch 'master' into kinalmehta/master
vwxyzjn Jul 28, 2022
b35ac8c
change documentation addr
vwxyzjn Jul 28, 2022
bf4c52e
add test cases
vwxyzjn Jul 28, 2022
5d61f6d
update ci
vwxyzjn Jul 28, 2022
aec5166
Add warning on installing jax on windows
vwxyzjn Jul 28, 2022
e77e789
fix pre-commit
vwxyzjn Jul 28, 2022
1794f15
revert back changes
vwxyzjn Jul 31, 2022
b18f980
update benchmark scripts
vwxyzjn Jul 31, 2022
2990502
Add docs
vwxyzjn Jul 31, 2022
8cb5fef
update docs
vwxyzjn Jul 31, 2022
14 changes: 13 additions & 1 deletion .github/workflows/tests.yaml
@@ -36,6 +36,12 @@ jobs:
run: poetry run pip install setuptools==59.5.0
- name: Run core tests
run: poetry run pytest tests/test_classic_control.py
- name: Install jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry install -E jax
- name: Run core tests with jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry run pytest tests/test_classic_control_jax.py

test-atari-envs:
strategy:
@@ -62,6 +68,12 @@ jobs:
run: poetry run pip install setuptools==59.5.0
- name: Run atari tests
run: poetry run pytest tests/test_atari.py
- name: Install jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry install -E jax
- name: Run atari tests with jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: poetry run pytest tests/test_atari_jax.py

test-pybullet-envs:
strategy:
@@ -136,7 +148,7 @@ jobs:
with:
poetry-version: ${{ matrix.poetry-version }}

# pybullet tests
# mujoco tests
- name: Install core dependencies
run: poetry install -E pytest
- name: Install pybullet dependencies
2 changes: 2 additions & 0 deletions README.md
@@ -124,6 +124,8 @@ You may also use a prebuilt development environment hosted in Gitpod:
| | [`ppo_continuous_action_isaacgym.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_continuous_action_isaacgym/ppo_continuous_action_isaacgym.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_action_isaacgympy)
| ✅ [Deep Q-Learning (DQN)](https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf) | [`dqn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy) |
| | [`dqn_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_ataripy) |
| | [`dqn_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_jaxpy) |
| | [`dqn_atari_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_atari_jaxpy) |
| ✅ [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf) | [`c51.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51.py), [docs](https://docs.cleanrl.dev/rl-algorithms/c51/#c51py) |
| | [`c51_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/c51/#c51_ataripy) |
| ✅ [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf) | [`sac_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py), [docs](https://docs.cleanrl.dev/rl-algorithms/sac/#sac_continuous_actionpy) |
16 changes: 16 additions & 0 deletions benchmark/dqn.sh
@@ -11,3 +11,19 @@ OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--command "poetry run python cleanrl/dqn_atari.py --track --capture-video" \
--num-seeds 3 \
--workers 1

poetry install -E "jax"
poetry run pip install --upgrade "jax[cuda]==0.3.14" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
--command "poetry run python cleanrl/dqn_jax.py --track --capture-video" \
--num-seeds 3 \
--workers 1

poetry install -E "atari jax"
poetry run pip install --upgrade "jax[cuda]==0.3.14" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \
--command "poetry run python cleanrl/dqn_atari_jax.py --track --capture-video" \
--num-seeds 3 \
--workers 1
266 changes: 266 additions & 0 deletions cleanrl/dqn_atari_jax.py
@@ -0,0 +1,266 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_atari_jaxpy
import argparse
import os
import random
import time
from distutils.util import strtobool

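# JAX preallocates most of the GPU memory by default; capping it at 70% leaves room for cuDNN workspace allocations and avoids the `UNKNOWN: CUDNN_STATUS_EXECUTION` failures noted in the commit history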
os.environ[
"XLA_PYTHON_CLIENT_MEM_FRACTION"
] = "0.7" # see https://github.com/google/jax/discussions/6332#discussioncomment-1279991

import flax
import flax.linen as nn
import gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState
from stable_baselines3.common.atari_wrappers import (
ClipRewardEnv,
EpisodicLifeEnv,
FireResetEnv,
MaxAndSkipEnv,
NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter


def parse_args():
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
help="the name of this experiment")
parser.add_argument("--seed", type=int, default=1,
help="seed of the experiment")
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, `torch.backends.cudnn.deterministic=False`")
parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, cuda will be enabled by default")
parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="if toggled, this experiment will be tracked with Weights and Biases")
parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
help="the wandb's project name")
parser.add_argument("--wandb-entity", type=str, default=None,
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="weather to capture videos of the agent performances (check out `videos` folder)")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="BreakoutNoFrameskip-v4",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=10000000,
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=1e-4,
help="the learning rate of the optimizer")
parser.add_argument("--buffer-size", type=int, default=1000000,
help="the replay memory buffer size")
parser.add_argument("--gamma", type=float, default=0.99,
help="the discount factor gamma")
parser.add_argument("--target-network-frequency", type=int, default=1000,
help="the timesteps it takes to update the target network")
parser.add_argument("--batch-size", type=int, default=32,
help="the batch size of sample from the reply memory")
parser.add_argument("--start-e", type=float, default=1,
help="the starting epsilon for exploration")
parser.add_argument("--end-e", type=float, default=0.01,
help="the ending epsilon for exploration")
parser.add_argument("--exploration-fraction", type=float, default=0.10,
help="the fraction of `total-timesteps` it takes from start-e to go end-e")
parser.add_argument("--learning-starts", type=int, default=80000,
help="timestep to start learning")
parser.add_argument("--train-frequency", type=int, default=4,
help="the frequency of training")
args = parser.parse_args()
# fmt: on
return args


def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
if "FIRE" in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = ClipRewardEnv(env)
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env

return thunk


# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
action_dim: int

@nn.compact
def __call__(self, x):
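# observations arrive as (batch, 4, 84, 84) stacked grayscale frames; Flax's Conv expects channels last, so transpose to NHWC and rescale pixels to [0, 1]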
x = jnp.transpose(x, (0, 2, 3, 1))
x = x / (255.0)
x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x)
x = nn.relu(x)
x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x)
x = nn.relu(x)
x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x)
x = nn.relu(x)
x = x.reshape((x.shape[0], -1))
x = nn.Dense(512)(x)
x = nn.relu(x)
x = nn.Dense(self.action_dim)(x)
return x


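# extend Flax's TrainState with a second parameter pytree so the target network's parameters travel alongside the online parameters and optimizer state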
class TrainState(TrainState):
target_params: flax.core.FrozenDict


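# anneal epsilon linearly from start_e to end_e over `duration` steps, then hold it at end_e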
def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
slope = (end_e - start_e) / duration
return max(slope * t + start_e, end_e)


if __name__ == "__main__":
args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
import wandb

wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
key = jax.random.PRNGKey(args.seed)
key, q_key = jax.random.split(key, 2)

# env setup
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

obs = envs.reset()

q_network = QNetwork(action_dim=envs.single_action_space.n)

q_state = TrainState.create(
apply_fn=q_network.apply,
params=q_network.init(q_key, obs),
target_params=q_network.init(q_key, obs),
tx=optax.adam(learning_rate=args.learning_rate),
)

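# jit-compile the forward pass once; subsequent greedy-action and target computations reuse the compiled XLA program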
q_network.apply = jax.jit(q_network.apply)
# This step is not strictly necessary: `init` called with the same key and observation always produces the same parameters, so the target network already matches the online network
q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

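# stable-baselines3 replay buffer; transitions are stored on the CPU, and sampled batches are converted to NumPy below before being handed to the jitted update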
rb = ReplayBuffer(
args.buffer_size,
envs.single_observation_space,
envs.single_action_space,
"cpu",
optimize_memory_usage=True,
handle_timeout_termination=True,
)

@jax.jit
def update(q_state, observations, actions, next_observations, rewards, dones):
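# one-step TD target from the frozen target network: y = r + gamma * (1 - done) * max_a' Q_target(s', a')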
q_next_target = q_network.apply(q_state.target_params, next_observations) # (batch_size, num_actions)
q_next_target = jnp.max(q_next_target, axis=-1) # (batch_size,)
next_q_value = rewards + (1 - dones) * args.gamma * q_next_target

def mse_loss(params):
q_pred = q_network.apply(params, observations) # (batch_size, num_actions)
q_pred = q_pred[np.arange(q_pred.shape[0]), actions.squeeze()] # (batch_size,)
return ((q_pred - next_q_value) ** 2).mean(), q_pred

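# take a single gradient step on the mean squared error between the online network's Q(s, a) and the TD target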
(loss_value, q_pred), grads = jax.value_and_grad(mse_loss, has_aux=True)(q_state.params)
q_state = q_state.apply_gradients(grads=grads)
return loss_value, q_pred, q_state

start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
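# epsilon-greedy exploration: with probability epsilon take a random action, otherwise act greedily with respect to the online Q-network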
epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
if random.random() < epsilon:
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
else:
# obs = jax.device_put(obs)
logits = q_network.apply(q_state.params, obs)
actions = logits.argmax(axis=-1)
actions = jax.device_get(actions)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
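# the vector env auto-resets finished sub-environments, so the true final observation must be recovered from `infos[idx]["terminal_observation"]` before storing the transition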
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs

# ALGO LOGIC: training.
if global_step > args.learning_starts and global_step % args.train_frequency == 0:
data = rb.sample(args.batch_size)
# perform a gradient-descent step
loss, old_val, q_state = update(
q_state,
data.observations.numpy(),
data.actions.numpy(),
data.next_observations.numpy(),
data.rewards.flatten().numpy(),
data.dones.flatten().numpy(),
)

if global_step % 100 == 0:
writer.add_scalar("losses/td_loss", jax.device_get(loss), global_step)
writer.add_scalar("losses/q_values", jax.device_get(old_val).mean(), global_step)
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

# update the target network
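# hard update: with step_size=1.0, optax.incremental_update copies the online parameters into the target parameters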
if global_step % args.target_network_frequency == 0:
q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

envs.close()
writer.close()