diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index 7dc0328f7..7189fcf04 100644 --- a/lzero/entry/__init__.py +++ b/lzero/entry/__init__.py @@ -9,4 +9,14 @@ from .train_rezero import train_rezero from .train_unizero import train_unizero from .train_unizero_segment import train_unizero_segment + +from .train_muzero_multitask_segment_noddp import train_muzero_multitask_segment_noddp +from .train_muzero_multitask_segment_ddp import train_muzero_multitask_segment_ddp + + +from .train_unizero_multitask_serial import train_unizero_multitask_serial +from .train_unizero_multitask_segment_ddp import train_unizero_multitask_segment_ddp +from .train_unizero_multitask_segment_serial import train_unizero_multitask_segment_serial + +from .train_unizero_multitask_segment_eval import train_unizero_multitask_segment_eval from .utils import * diff --git a/lzero/entry/train_muzero_multitask_segment_ddp.py b/lzero/entry/train_muzero_multitask_segment_ddp.py new file mode 100644 index 000000000..b717b710e --- /dev/null +++ b/lzero/entry/train_muzero_multitask_segment_ddp.py @@ -0,0 +1,579 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.mcts import MuZeroGameBuffer as GameBuffer +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from ding.utils import EasyTimer +import torch.distributed as dist + +import concurrent.futures + +# ========== 超时时间设置 ========== +TIMEOUT = 3600 # 例如,60分钟 + + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int +) -> Tuple[Optional[bool], Optional[float]]: + """ + 安全地执行评估操作,防止因超时导致训练过程阻塞。 + + Args: + evaluator (Evaluator): 评估器实例。 + learner (BaseLearner): 学习器实例。 + collector (Collector): 数据收集器实例。 + rank (int): 当前进程的排名。 + world_size (int): 总进程数。 + + Returns: + Tuple[Optional[bool], Optional[float]]: + - stop (Optional[bool]): 评估是否停止的标志。 + - reward (Optional[float]): 评估得到的奖励。 + """ + print(f"=========评估前 Rank {rank}/{world_size}===========") + # 重置 stop_event,确保每次评估前都处于未设置状态 + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + # 提交 evaluator.eval 任务 + future = executor.submit( + evaluator.eval, + learner.save_checkpoint, + learner.train_iter, + collector.envstep + ) + + try: + stop, reward = future.result(timeout=TIMEOUT) + except concurrent.futures.TimeoutError: + # 超时,设置 evaluator 的 stop_event + evaluator.stop_event.set() + print(f"评估操作在 Rank {rank}/{world_size} 上超过 {TIMEOUT} 秒超时。") + return None, None + + print(f"======评估后 Rank {rank}/{world_size}======") + return stop, reward + + +def allocate_batch_size( + cfgs: List, + game_buffers: List[GameBuffer], + alpha: float = 1.0, + clip_scale: int = 1 +) -> List[int]: + """ + 根据不同任务的 num_of_collected_episodes 反比分配 batch_size, + 并动态调整 batch_size 限制范围以提高训练的稳定性和效率。 + + Args: + cfgs (List): 每个任务的配置列表。 + game_buffers (List[GameBuffer]): 每个任务的 replay_buffer 实例列表。 + alpha (float): 控制反比程度的超参数 (默认为1.0)。 + clip_scale (int): 动态调整的缩放因子 
(默认为1)。 + + Returns: + List[int]: 分配后的 batch_size 列表。 + """ + # 提取每个任务的 num_of_collected_episodes + buffer_num_of_collected_episodes = [ + buffer.num_of_collected_episodes for buffer in game_buffers + ] + + # 获取当前的 world_size 和 rank + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # 收集所有 rank 的 num_of_collected_episodes 列表 + all_task_num_of_collected_episodes = [None for _ in range(world_size)] + torch.distributed.all_gather_object( + all_task_num_of_collected_episodes, + buffer_num_of_collected_episodes + ) + + # 将所有 rank 的 num_of_collected_episodes 拼接成一个大列表 + all_task_num_of_collected_episodes = [ + item for sublist in all_task_num_of_collected_episodes for item in sublist + ] + if rank == 0: + print(f'all_task_num_of_collected_episodes: {all_task_num_of_collected_episodes}') + + # 计算每个任务的反比权重 + inv_episodes = np.array([ + 1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes + ]) + inv_sum = np.sum(inv_episodes) + + # 计算总的 batch_size (所有任务 cfg.policy.max_batch_size 的和) + max_batch_size = cfgs[0].policy.max_batch_size + + # 动态调整的部分:最小和最大的 batch_size 范围 + avg_batch_size = max_batch_size / world_size + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale + + # 动态调整 alpha,让 batch_size 的变化更加平滑 + task_weights = (inv_episodes / inv_sum) ** alpha + batch_sizes = max_batch_size * task_weights + + # 控制 batch_size 在 [min_batch_size, max_batch_size] 之间 + batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + + # 确保 batch_size 是整数 + batch_sizes = [int(size) for size in batch_sizes] + + # 返回最终分配的 batch_size 列表 + return batch_sizes + + +def train_muzero_multitask_segment_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + The train entry for multi-task MuZero, adapted from UniZero's multi-task training. + This script aims to enhance the planning capabilities of reinforcement learning agents + by leveraging multi-task learning to address diverse environments. + + Args: + input_cfg_list (List[Tuple[int, Tuple[dict, dict]]]): + Configurations for different tasks as a list of tuples containing task ID and configuration dictionaries. + seed (int): + Random seed for reproducibility. + model (Optional[torch.nn.Module]): + Predefined model instance. If provided, it will be used instead of creating a new one. + model_path (Optional[str]): + Path to the pretrained model checkpoint. Should point to the ckpt file of the pretrained model. + max_train_iter (Optional[int]): + Maximum number of training iterations. Defaults to 1e10. + max_env_step (Optional[int]): + Maximum number of environment interaction steps. Defaults to 1e10. + + Returns: + Policy: + The trained policy instance. 
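        Example:
            A minimal illustrative invocation only. The ``(cfg, create_cfg)`` pairs below are
            assumed to be standard LightZero config tuples prepared by the caller, and the
            distributed process group is assumed to be initialised by the DDP launcher
            (e.g. ``torchrun``):

            .. code-block:: python

                # hypothetical per-task config tuples
                input_cfg_list = [
                    (0, (pong_muzero_cfg, pong_muzero_create_cfg)),
                    (1, (qbert_muzero_cfg, qbert_muzero_create_cfg)),
                ]
                policy = train_muzero_multitask_segment_ddp(input_cfg_list, seed=0, max_env_step=int(1e6))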
+ """ + # 获取当前进程的 rank 和总的进程数 + rank = get_rank() + world_size = get_world_size() + + # 任务划分 + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + # 确保至少有一个任务 + if len(tasks_for_this_rank) == 0: + logging.warning(f"Rank {rank}: 未分配任何任务,继续运行但无任务处理。") + # 初始化一些空列表以避免后续代码报错 + cfgs, game_buffers, collector_envs, evaluator_envs, collectors, evaluators = [], [], [], [], [], [] + return + + print(f"Rank {rank}/{world_size}, 处理任务 {start_idx} 到 {end_idx - 1}") + + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + # 使用第一个任务的配置来创建共享的 policy + task_id, [cfg, create_cfg] = tasks_for_this_rank[0] + + # 设置每个任务的随机种子和任务编号 + for config in tasks_for_this_rank: + config[1][0].policy.task_num = len(tasks_for_this_rank) + + # 根据 CUDA 可用性设置设备 + cfg.policy.device = cfg.policy.model.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # 编译配置 + cfg = compile_config( + cfg, + seed=seed, + env=None, + auto=True, + create_cfg=create_cfg, + save_cfg=True + ) + # 创建共享的 policy + policy = create_policy( + cfg.policy, + model=model, + enable_field=['learn', 'collect', 'eval'] + ) + + # 如果指定了预训练模型,则加载 + if model_path is not None: + logging.info(f'开始加载模型来自 {model_path}...') + policy.learn_mode.load_state_dict( + torch.load(model_path, map_location=cfg.policy.device) + ) + logging.info(f'完成加载模型来自 {model_path}.') + + # 创建 TensorBoard 的日志记录器 + log_dir = os.path.join(f'./{cfg.exp_name}/log', f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # 创建共享的 learner + learner = BaseLearner( + cfg.policy.learn.learner, + policy.learn_mode, + tb_logger, + exp_name=cfg.exp_name + ) + + policy_config = cfg.policy + batch_size = policy_config.batch_size[0] + + # 只处理当前进程分配到的任务 + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank): + # 设置每个任务自己的随机种子 + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config( + cfg, + seed=seed + task_id, + env=None, + auto=True, + create_cfg=create_cfg, + save_cfg=True + ) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager( + cfg.env.manager, + [partial(env_fn, cfg=c) for c in collector_env_cfg] + ) + evaluator_env = create_env_manager( + cfg.env.manager, + [partial(env_fn, cfg=c) for c in evaluator_env_cfg] + ) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # 为每个任务创建不同的 game buffer、collector、evaluator + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + 
tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + value_priority_tasks = {} + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + while True: + torch.cuda.empty_cache() + + if cfg.policy.allocated_batch_sizes: + # TODO========== + # 线性变化的 随着 train_epoch 从 0 增加到 1000, clip_scale 从 1 线性增加到 4 + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size( + cfgs, + game_buffers, + alpha=1.0, + clip_scale=clip_scale + ) + if rank == 0: + print("分配后的 batch_sizes: ", allocated_batch_sizes) + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers) + ): + cfg.policy.batch_size = allocated_batch_sizes[idx] + policy._cfg.batch_size[idx] = allocated_batch_sizes[idx] + + # 对于当前进程的每个任务,进行数据收集和评估 + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers) + ): + + log_buffer_memory_usage( + learner.train_iter, + replay_buffer, + tb_logger, + cfg.policy.task_id + ) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # 默认的 epsilon 值 + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): + print('=' * 20) + print(f'Rank {rank} 评估 task_id: {cfg.policy.task_id}...') + + # 在训练进程中调用 safe_eval + stop, reward = safe_eval( + evaluator, + learner, + collector, + rank, + world_size + ) + # 判断评估是否成功 + if stop is None or reward is None: + print(f"Rank {rank} 在评估期间遇到问题。继续训练中...") + else: + print(f"评估成功: stop={stop}, reward={reward}") + + print('=' * 20) + print(f'entry: Rank {rank} 收集 task_id: {cfg.policy.task_id}...') + + # 收集数据 + new_data = collector.collect( + train_iter=learner.train_iter, + policy_kwargs=collect_kwargs + ) + + # 更新 replay buffer + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # 周期性地重新分析缓冲区 + if cfg.policy.buffer_reanalyze_freq >= 1: + # 在一个训练 epoch 中重新分析缓冲区 次 + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + # 每 <1/buffer_reanalyze_freq> 个训练 epoch 重新分析一次缓冲区 + if ( + train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > + int(reanalyze_batch_size / cfg.policy.reanalyze_partition) + ): + with timer: + # 每个重新分析过程将重新分析 个序列 + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析计数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析时间: {timer.value}') + + # 数据收集结束后添加日志 + logging.info(f'Rank {rank}: 完成任务 {cfg.policy.task_id} 的数据收集') + + # 
检查是否有足够的数据进行训练 + not_enough_data = any( + replay_buffer.get_num_of_transitions() < cfgs[0].policy.max_batch_size / world_size + for replay_buffer in game_buffers + ) + + # 同步训练前所有 rank 的准备状态 + try: + dist.barrier() + logging.info(f'Rank {rank}: 通过训练前的 barrier') + except Exception as e: + logging.error(f'Rank {rank}: Barrier 失败,错误: {e}') + break # 或者进行其他错误处理 + + # 学习策略 + if not not_enough_data: + # Learner 将在一次迭代中训练 update_per_collect 次 + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for idx, (cfg, collector, replay_buffer) in enumerate( + zip(cfgs, collectors, game_buffers) + ): + envstep_multi_task += collector.envstep + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + if replay_buffer.get_num_of_transitions() > batch_size: + if cfg.policy.buffer_reanalyze_freq >= 1: + # 在一个训练 epoch 中重新分析缓冲区 次 + if ( + i % reanalyze_interval == 0 and + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > + int(reanalyze_batch_size / cfg.policy.reanalyze_partition) + ): + with timer: + # 每个重新分析过程将重新分析 个序列 + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析计数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析时间: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + # 追加 task_id,以便在训练时区分任务 + train_data.append(cfg.policy.task_id) + train_data_multi_task.append(train_data) + else: + logging.warning( + f'Replay buffer 中的数据不足以采样一个 mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + # 在训练时,DDP 会自动同步梯度和参数 + log_vars = learner.train(train_data_multi_task, envstep_multi_task) + + if cfg.policy.use_priority: + for idx, (cfg, replay_buffer) in enumerate( + zip(cfgs, game_buffers) + ): + # 更新任务特定的 replay buffer 的优先级 + task_id = cfg.policy.task_id + replay_buffer.update_priority( + train_data_multi_task[idx], + log_vars[0][f'value_priority_task{task_id}'] + ) + + current_priorities = log_vars[0][f'value_priority_task{task_id}'] + + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + + alpha = 0.1 # 运行均值的平滑因子 + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + # 如果不存在,则初始化运行均值 + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + # 更新运行均值 + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + + # 使用运行均值计算归一化的优先级 + running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = ( + current_priorities - running_mean_priority + ) / (std_priority + 1e-6) + + # 如果需要,可以将归一化的优先级存储回 replay buffer + # replay_buffer.update_priority(train_data_multi_task[idx], normalized_priorities) + + # 如果设置了 print_task_priority_logs 标志,则记录统计信息 + if cfg.policy.print_task_priority_logs: + print( + f"任务 {task_id} - 平均优先级: {mean_priority:.8f}, " + f"运行平均优先级: {running_mean_priority:.8f}, " + f"标准差: {std_priority:.8f}" + ) + + train_epoch += 1 + + # 同步所有 Rank,确保所有 Rank 都完成了训练 + try: + dist.barrier() + logging.info(f'Rank {rank}: 通过训练后的 barrier') + except Exception as e: + logging.error(f'Rank {rank}: Barrier 失败,错误: {e}') + break # 或者进行其他错误处理 + + # 检查是否需要终止训练 + try: + # local_envsteps 不再需要填充 + local_envsteps = [collector.envstep for collector in collectors] + + total_envsteps = [None for _ in range(world_size)] + dist.all_gather_object(total_envsteps, local_envsteps) 
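            # NOTE: all_gather_object collects every rank's per-task envstep list (the lists may
            # differ in length across ranks), so the max_env_step check below is based on global
            # progress rather than only the tasks owned by this rank.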
+ + # 将所有 envsteps 拼接在一起 + all_envsteps = torch.cat([ + torch.tensor(envsteps, device=cfg.policy.device) + for envsteps in total_envsteps + ]) + max_envstep_reached = torch.all(all_envsteps >= max_env_step) + + # 收集所有进程的 train_iter + global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) + all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)] + dist.all_gather(all_train_iters, global_train_iter) + + max_train_iter_reached = torch.any( + torch.stack(all_train_iters) >= max_train_iter + ) + + if max_envstep_reached.item() or max_train_iter_reached.item(): + logging.info(f'Rank {rank}: 满足终止条件') + dist.barrier() # 确保所有进程同步 + break + else: + pass + + except Exception as e: + logging.error(f'Rank {rank}: 终止检查失败,错误: {e}') + break # 或者进行其他错误处理 + + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_muzero_multitask_segment_noddp.py b/lzero/entry/train_muzero_multitask_segment_noddp.py new file mode 100644 index 000000000..bbeecb227 --- /dev/null +++ b/lzero/entry/train_muzero_multitask_segment_noddp.py @@ -0,0 +1,270 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage, log_buffer_run_time +from lzero.policy import visit_count_temperature +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from lzero.mcts import MuZeroGameBuffer as GameBuffer # 根据不同策略选择合适的 GameBuffer +from .utils import random_collect + +from ding.utils import EasyTimer +timer = EasyTimer() +from line_profiler import line_profiler + + +def train_muzero_multitask_segment_noddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': # noqa + """ + Overview: + 多任务训练入口,基于 MuZero 的多任务版本,支持多任务环境的训练。 + 参考论文 UniZero: Generalized and Efficient Planning with Scalable Latent World Models。 + Arguments: + - input_cfg_list (List[Tuple[int, Tuple[dict, dict]]]): 不同任务的配置列表。 + - seed (int): 随机种子。 + - model (Optional[torch.nn.Module]): torch.nn.Module 的实例。 + - model_path (Optional[str]): 预训练模型路径,指向预训练模型的 ckpt 文件。 + - max_train_iter (Optional[int]): 最大训练迭代次数。 + - max_env_step (Optional[int]): 最大环境交互步数。 + Returns: + - policy (Policy): 收敛后的策略。 + """ + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + task_id, [cfg, create_cfg] = input_cfg_list[0] + + # Ensure the specified policy type is supported + assert create_cfg.policy.type in ['muzero_multitask'], "train_muzero entry now only supports 'muzero'" + + # Set device based on CUDA availability + cfg.policy.device = cfg.policy.model.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # Compile the configuration + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, 
save_cfg=True) + # Create shared policy for all tasks + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # Load pretrained model if specified + if model_path is not None: + logging.info(f'Loading model from {model_path} begin...') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Loading model from {model_path} end!') + + # Create SummaryWriter for TensorBoard logging + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + # Create shared learner for all tasks + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # TODO task_id = 0: + policy_config = cfg.policy + batch_size = policy_config.batch_size[0] + + # 初始化多任务配置 + for task_id, input_cfg in input_cfg_list: + + if task_id > 0: + # Get the configuration for each task + cfg, create_cfg = input_cfg + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # ===== NOTE: Create different game buffer, collector, evaluator for each task ==== + # TODO: share replay buffer for all tasks + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + value_priority_tasks = {} + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + while True: + torch.cuda.empty_cache() + + # 遍历每个任务进行数据收集和评估 + for task_id, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # 默认 epsilon 值 + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + 
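                # Anneal the exploration epsilon from `start` to `end` over `decay` steps according
                # to the configured schedule type, indexed by this task's collected env steps.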
epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + # 评估策略性能 + if learner.train_iter ==0 or evaluator.should_eval(learner.train_iter): + logging.info(f'========== 评估任务 {task_id} ==========') + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # 收集数据 + logging.info(f'========== 收集任务 {task_id} 数据 ==========') + # collector.reset() + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # 确定每次收集后的更新次数 + if update_per_collect is None: + collected_transitions_num = sum( + min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0]) + update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) + + # 更新回放缓冲区 + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # 定期重新分析缓冲区 + if cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + if train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析时间: {timer.value}') + + # 检查是否有足够的数据进行训练 + not_enough_data = any(replay_buffer.get_num_of_transitions() < batch_size for replay_buffer in game_buffers) + + if not not_enough_data: + # 进行训练 + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for task_id, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + envstep_multi_task += collector.envstep + if replay_buffer.get_num_of_transitions() > batch_size: + batch_size = cfg.policy.batch_size[task_id] + + + if cfg.policy.buffer_reanalyze_freq >= 1: + if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析时间: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(task_id) # 添加 task_id + train_data_multi_task.append(train_data) + else: + logging.warning( + f'回放缓冲区数据不足以采样 mini-batch: ' + f'batch_size: {batch_size}, 回放缓冲区: {replay_buffer}' + ) + break + + if train_data_multi_task: + log_vars = learner.train(train_data_multi_task, envstep_multi_task) + + if cfg.policy.use_priority: + for task_id, replay_buffer in enumerate(game_buffers): + current_priorities = log_vars[0][f'value_priority_task{task_id}'] + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + alpha = 0.1 + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + running_mean_priority = 
value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6) + + if cfg.policy.print_task_priority_logs: + print(f"任务 {task_id} - 平均优先级: {mean_priority:.8f}, " + f"运行平均优先级: {running_mean_priority:.8f}, " + f"标准差: {std_priority:.8f}") + + # 清除位置嵌入缓存 + train_epoch += 1 + + # 检查是否达到训练结束条件 + if all(collector.envstep >= max_env_step for collector in collectors) or learner.train_iter >= max_train_iter: + break + + # 调用学习器的 after_run 钩子 + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero.py b/lzero/entry/train_unizero.py index cd7ff7605..c1687749b 100644 --- a/lzero/entry/train_unizero.py +++ b/lzero/entry/train_unizero.py @@ -119,6 +119,10 @@ def train_unizero( batch_size = policy._cfg.batch_size + # TODO: for visualize + # stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + # import sys; sys.exit(0) + while True: # Log buffer memory usage log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) diff --git a/lzero/entry/train_unizero_multitask_segment_ddp.py b/lzero/entry/train_unizero_multitask_segment_ddp.py new file mode 100644 index 000000000..aea9a4ba6 --- /dev/null +++ b/lzero/entry/train_unizero_multitask_segment_ddp.py @@ -0,0 +1,557 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.mcts import UniZeroGameBuffer as GameBuffer +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from ding.utils import EasyTimer +import torch.distributed as dist + +import concurrent.futures + +# 设置超时时间 (秒) +TIMEOUT = 12000 # 例如200分钟 + +timer = EasyTimer() + + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int +) -> Tuple[Optional[bool], Optional[float]]: + """ + Safely执行评估任务,避免超时。 + + Args: + evaluator (Evaluator): 评估器实例。 + learner (BaseLearner): 学习器实例。 + collector (Collector): 数据收集器实例。 + rank (int): 当前进程的rank。 + world_size (int): 总进程数。 + + Returns: + Tuple[Optional[bool], Optional[float]]: 如果评估成功,返回停止标志和奖励,否则返回(None, None)。 + """ + try: + print(f"=========评估开始 Rank {rank}/{world_size}===========") + # 重置 stop_event,确保每次评估前都处于未设置状态 + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + # 提交评估任务 + future = executor.submit(evaluator.eval, learner.save_checkpoint, learner.train_iter, collector.envstep) + try: + stop, reward = future.result(timeout=TIMEOUT) + except concurrent.futures.TimeoutError: + # 超时,设置 stop_event + evaluator.stop_event.set() + print(f"评估操作在 Rank {rank}/{world_size} 上超时,耗时 {TIMEOUT} 秒。") + return None, None + + print(f"======评估结束 Rank {rank}/{world_size}======") + return stop, reward + except Exception as e: + print(f"Rank {rank}/{world_size} 评估过程中发生错误: {e}") + return None, None + + +def allocate_batch_size( + cfgs: List[dict], + game_buffers: List[GameBuffer], + alpha: float = 
1.0, + clip_scale: int = 1 +) -> List[int]: + """ + 根据不同任务的收集剧集数反比分配batch_size,并动态调整batch_size范围以提高训练稳定性和效率。 + + Args: + cfgs (List[dict]): 每个任务的配置列表。 + game_buffers (List[GameBuffer]): 每个任务的重放缓冲区实例列表。 + alpha (float, optional): 控制反比程度的超参数。默认为1.0。 + clip_scale (int, optional): 动态调整的clip比例。默认为1。 + + Returns: + List[int]: 分配后的batch_size列表。 + """ + # 提取每个任务的 collected episodes 数量 + buffer_num_of_collected_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers] + + # 获取当前的 world_size 和 rank + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # 收集所有 rank 的 collected episodes 列表 + all_task_num_of_collected_episodes = [None for _ in range(world_size)] + torch.distributed.all_gather_object(all_task_num_of_collected_episodes, buffer_num_of_collected_episodes) + + # 将所有 rank 的 collected episodes 合并为一个大列表 + all_task_num_of_collected_episodes = [ + episode for sublist in all_task_num_of_collected_episodes for episode in sublist + ] + if rank == 0: + print(f'所有任务的 collected episodes: {all_task_num_of_collected_episodes}') + + # 计算每个任务的反比权重 + inv_episodes = np.array([1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes]) + inv_sum = np.sum(inv_episodes) + + # 计算总的batch_size (所有任务 cfg.policy.batch_size 的和) + total_batch_size = cfgs[0].policy.total_batch_size + + # 动态调整的部分:最小和最大的 batch_size 范围 + avg_batch_size = total_batch_size / world_size + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale + + # 动态调整 alpha,让 batch_size 的变化更加平滑 + task_weights = (inv_episodes / inv_sum) ** alpha + batch_sizes = total_batch_size * task_weights + + # 控制 batch_size 在 [min_batch_size, max_batch_size] 之间 + batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + + # 确保 batch_size 是整数 + batch_sizes = [int(size) for size in batch_sizes] + + return batch_sizes + + +def train_unizero_multitask_segment_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + UniZero的训练入口,旨在通过解决MuZero类算法在需要捕捉长期依赖环境中的局限性,提高强化学习代理的规划能力。 + 详细信息请参阅 https://arxiv.org/abs/2406.10667。 + + Args: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): 不同任务的配置列表。 + - seed (:obj:`int`): 随机种子。 + - model (:obj:`Optional[torch.nn.Module]`): torch.nn.Module实例。 + - model_path (:obj:`Optional[str]`): 预训练模型路径,应指向预训练模型的ckpt文件。 + - max_train_iter (:obj:`Optional[int]`): 训练中的最大策略更新迭代次数。 + - max_env_step (:obj:`Optional[int]`): 最大收集环境交互步数。 + + Returns: + - policy (:obj:`Policy`): 收敛的策略。 + """ + # 获取当前进程的rank和总进程数 + rank = get_rank() + world_size = get_world_size() + + # 任务划分 + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + # 确保至少有一个任务 + if len(tasks_for_this_rank) == 0: + logging.warning(f"Rank {rank}: 未分配任务,继续执行。") + # 初始化空列表以避免后续代码报错 + cfgs, game_buffers, collector_envs, evaluator_envs, collectors, evaluators = [], [], [], [], [], [] + else: + print(f"Rank {rank}/{world_size}, 处理任务 {start_idx} 到 {end_idx - 1}") + + cfgs = [] + game_buffers = [] + collector_envs = [] + 
evaluator_envs = [] + collectors = [] + evaluators = [] + + if tasks_for_this_rank: + # 使用第一个任务的配置创建共享的policy + task_id, [cfg, create_cfg] = tasks_for_this_rank[0] + + for config in tasks_for_this_rank: + config[1][0].policy.task_num = tasks_per_rank + + # 确保指定的policy类型是支持的 + assert create_cfg.policy.type in ['unizero_multitask'], "当前仅支持 'unizero_multitask' 类型的policy" + + # 根据CUDA可用性设置设备 + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'配置的设备: {cfg.policy.device}') + + # 编译配置 + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # 创建共享的policy + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # 加载预训练模型(如果提供) + if model_path is not None: + logging.info(f'开始加载模型: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'完成加载模型: {model_path}') + + # 创建TensorBoard日志记录器 + log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # 创建共享的learner + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + policy_config = cfg.policy + + # 处理当前进程分配到的每个任务 + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank): + # 设置每个任务的随机种子 + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + # 创建环境 + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # 创建不同的game buffer、collector和evaluator + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + # 调用learner的before_run钩子 + learner.call_hook('before_run') + value_priority_tasks = {} + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + if cfg.policy.eval_offline: + eval_train_iter_list = [] + eval_train_envstep_list = [] + + while True: + # 动态调整batch_size + if cfg.policy.allocated_batch_sizes: + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + 
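            # clip_scale grows linearly from 1 to 4 over the first 1000 training epochs, widening the
            # allowed per-task batch-size range [avg / clip_scale, avg * clip_scale] that is enforced
            # inside allocate_batch_size as training progresses.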
allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + print("分配后的 batch_sizes: ", allocated_batch_sizes) + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + cfg.policy.batch_size = allocated_batch_sizes + policy._cfg.batch_size = allocated_batch_sizes + + # 对于当前进程的每个任务,进行数据收集和评估 + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + + # 记录缓冲区内存使用情况 + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # 默认的epsilon值 + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + # 判断是否需要进行评估 + if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): + print(f'cfg.policy.eval_offline:{cfg.policy.eval_offline}') + if cfg.policy.eval_offline: + eval_train_iter_list.append(learner.train_iter) + eval_train_envstep_list.append(collector.envstep) + else: + print('=' * 20) + print(f'Rank {rank} 评估任务_id: {cfg.policy.task_id}...') + + # 执行安全评估 + stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) + # 判断评估是否成功 + if stop is None or reward is None: + print(f"Rank {rank} 在评估过程中遇到问题,继续训练...") + else: + print(f"评估成功: stop={stop}, reward={reward}") + + print('=' * 20) + print(f'开始收集 Rank {rank} 的任务_id: {cfg.policy.task_id}...') + + # 在每次收集之前重置初始数据,这对于多任务设置非常重要 + collector._policy.reset(reset_init_data=True) + # 收集数据 + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # 更新重放缓冲区 + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # 周期性地重新分析缓冲区 + if cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + if train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and \ + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析耗时: {timer.value}') + + # 数据收集结束后添加日志 + logging.info(f'Rank {rank}: 完成任务 {cfg.policy.task_id} 的数据收集') + + # 检查是否有足够的数据进行训练 + not_enough_data = any( + replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size + for replay_buffer in game_buffers + ) + + # 同步训练前所有rank的准备状态 + try: + dist.barrier() + logging.info(f'Rank {rank}: 通过训练前的同步障碍') + except Exception as e: + logging.error(f'Rank {rank}: 同步障碍失败,错误: {e}') + break + + # 学习策略 + if not not_enough_data: + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for idx, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + envstep_multi_task += collector.envstep + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + if 
replay_buffer.get_num_of_transitions() > batch_size: + if cfg.policy.buffer_reanalyze_freq >= 1: + if i % reanalyze_interval == 0 and \ + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析耗时: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(cfg.policy.task_id) # 追加task_id以区分任务 + train_data_multi_task.append(train_data) + else: + logging.warning( + f'重放缓冲区中的数据不足以采样mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + # 在训练时,DDP会自动同步梯度和参数 + log_vars = learner.train(train_data_multi_task, envstep_multi_task) + + if cfg.policy.use_priority: + for idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)): + # 更新任务特定的重放缓冲区优先级 + task_id = cfg.policy.task_id + replay_buffer.update_priority( + train_data_multi_task[idx], + log_vars[0][f'value_priority_task{task_id}'] + ) + + current_priorities = log_vars[0][f'value_priority_task{task_id}'] + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + + alpha = 0.1 # 平滑因子 + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + + # 使用运行均值计算归一化的优先级 + running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6) + + # 如果需要,可以将归一化的优先级存储回重放缓冲区 + # replay_buffer.update_priority(train_data_multi_task[idx], normalized_priorities) + + # 记录优先级统计信息 + if cfg.policy.print_task_priority_logs: + print(f"任务 {task_id} - 平均优先级: {mean_priority:.8f}, " + f"运行平均优先级: {running_mean_priority:.8f}, " + f"标准差: {std_priority:.8f}") + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # 同步所有Rank,确保所有Rank完成训练 + try: + dist.barrier() + logging.info(f'Rank {rank}: 通过训练后的同步障碍') + except Exception as e: + logging.error(f'Rank {rank}: 同步障碍失败,错误: {e}') + break + + # 检查是否需要终止训练 + try: + local_envsteps = [collector.envstep for collector in collectors] + total_envsteps = [None for _ in range(world_size)] + dist.all_gather_object(total_envsteps, local_envsteps) + + all_envsteps = torch.cat([torch.tensor(envsteps, device=cfg.policy.device) for envsteps in total_envsteps]) + max_envstep_reached = torch.all(all_envsteps >= max_env_step) + + # 收集所有进程的train_iter + global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) + all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)] + dist.all_gather(all_train_iters, global_train_iter) + + max_train_iter_reached = torch.any(torch.stack(all_train_iters) >= max_train_iter) + + # 同步所有Rank,确保所有Rank完成训练 + try: + dist.barrier() + logging.info(f'Rank {rank}: 通过训练后的同步障碍') + except Exception as e: + logging.error(f'Rank {rank}: 同步障碍失败,错误: {e}') + break + + if max_envstep_reached.item() or max_train_iter_reached.item(): + logging.info(f'Rank {rank}: 达到终止条件') + + if cfg.policy.eval_offline: + # 对于当前进程的每个任务,进行数据收集和评估 + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + 
zip(cfgs, collectors, evaluators, game_buffers)): + + logging.info(f'Rank {rank} 评估任务_id: {cfg.policy.task_id}: eval offline beginning...') + + # ========= 注意目前只有rank0存储ckpt ========= + # ckpt_dirname = './data_unizero_mt_ddp-8gpu_20241226/8games_brf0.02_seed0/Pong_seed0/ckpt' + + # 让 rank0 生成 ckpt_dirname,其他 Rank 等待接收 + if rank == 0: + ckpt_dirname = './{}/ckpt'.format(learner.exp_name) + logging.info(f'Rank {rank}: 生成 ckpt_dirname 为 {ckpt_dirname}') + else: + ckpt_dirname = None + + # 使用一个列表来存储 ckpt_dirname + ckpt_dirname_list = [ckpt_dirname] + # 广播 ckpt_dirname + dist.broadcast_object_list(ckpt_dirname_list, src=0) + # 从列表中提取更新后的 ckpt_dirname + ckpt_dirname = ckpt_dirname_list[0] + + # 确认所有 Rank 都接收到正确的 ckpt_dirname + logging.info(f'Rank {rank}: 接收到的 ckpt_dirname 为 {ckpt_dirname}') + + # 检查 ckpt_dirname 是否有效 + if not isinstance(ckpt_dirname, str): + logging.error(f'Rank {rank}: 接收到的 ckpt_dirname 无效') + continue + + # Evaluate the performance of the pretrained model. + for train_iter, collector_envstep in zip(eval_train_iter_list, eval_train_envstep_list): + # if train_iter==0: + # continue + ckpt_name = 'iteration_{}.pth.tar'.format(train_iter) + ckpt_path = os.path.join(ckpt_dirname, ckpt_name) + try: + # load the ckpt of pretrained model + policy.learn_mode.load_state_dict(torch.load(ckpt_path, map_location=cfg.policy.device)) + except Exception as e: + logging.error(f'Rank {rank}: load_state_dict 失败,错误: {e}') + continue + + stop, reward = evaluator.eval(learner.save_checkpoint, train_iter, collector_envstep) + logging.info(f'Rank {rank} 评估任务_id: {cfg.policy.task_id}: eval offline at train_iter: {train_iter}, collector_envstep: {collector_envstep}, reward: {reward}') + + logging.info(f'eval_train_envstep_list: {eval_train_envstep_list}, eval_train_iter_list:{eval_train_iter_list}') + + logging.info(f'Rank {rank} 评估任务_id: {cfg.policy.task_id}: eval offline finished!') + + + + dist.barrier() # 确保所有进程同步 + # 评估结束后,显式关闭所有评估器 + for evaluator in evaluators: + evaluator.close() + break + except Exception as e: + logging.error(f'Rank {rank}: 终止检查失败,错误: {e}') + break + + + # 调用learner的after_run钩子 + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask_segment_eval.py b/lzero/entry/train_unizero_multitask_segment_eval.py new file mode 100644 index 000000000..f98e4c41b --- /dev/null +++ b/lzero/entry/train_unizero_multitask_segment_eval.py @@ -0,0 +1,480 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List, Dict, Any + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size, EasyTimer +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.mcts import UniZeroGameBuffer as GameBuffer +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector + +import torch.distributed as dist +import concurrent.futures + +# 设置超时时间 (秒) +TIMEOUT = 12000 # 例如200分钟 + + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int +) -> Tuple[Optional[bool], Optional[float]]: + """ + Safely evaluates the policy using the 
evaluator with a timeout. + + Args: + evaluator (Evaluator): The evaluator instance. + learner (BaseLearner): The learner instance. + collector (Collector): The collector instance. + rank (int): The rank of the current process. + world_size (int): Total number of processes. + + Returns: + Tuple[Optional[bool], Optional[float]]: A tuple containing the stop flag and reward. + """ + try: + print(f"=========before eval Rank {rank}/{world_size}===========") + # 重置 stop_event,确保每次评估前都处于未设置状态 + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + # 提交 evaluator.eval 任务 + future = executor.submit( + evaluator.eval, + learner.save_checkpoint, + learner.train_iter, + collector.envstep + ) + + try: + stop, reward = future.result(timeout=TIMEOUT) + except concurrent.futures.TimeoutError: + # 超时,设置 evaluator 的 stop_event + evaluator.stop_event.set() + print(f"Eval operation timed out after {TIMEOUT} seconds on Rank {rank}/{world_size}.") + return None, None + + print(f"======after eval Rank {rank}/{world_size}======") + return stop, reward + except Exception as e: + print(f"An error occurred during evaluation on Rank {rank}/{world_size}: {e}") + return None, None + + +def allocate_batch_size( + cfgs: List[Any], + game_buffers: List[GameBuffer], + alpha: float = 1.0, + clip_scale: int = 1 +) -> List[int]: + """ + Allocates batch sizes inversely proportional to the number of collected episodes for each task. + Dynamically adjusts batch size within a specified range to enhance training stability and efficiency. + + Args: + cfgs (List[Any]): List of configurations for each task. + game_buffers (List[GameBuffer]): List of replay buffer instances for each task. + alpha (float): The hyperparameter controlling the degree of inverse proportionality. Default is 1.0. + clip_scale (int): The scaling factor to clip the batch size. Default is 1. + + Returns: + List[int]: A list of allocated batch sizes for each task. 
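    Example:
        Illustrative numbers only. With gathered episode counts ``[9, 4]`` for two tasks,
        ``alpha=1.0``, ``clip_scale=2``, ``world_size=2`` and ``total_batch_size=512``:
        the inverse weights are ``[1/10, 1/5]``, normalised to ``[1/3, 2/3]``, which gives raw
        sizes ``[170.7, 341.3]``; both lie inside the clip range ``[512/2/2, 512/2*2] = [128, 512]``,
        so the final integer allocation is ``[170, 341]``.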
+ """ + # 提取每个任务的 num_of_collected_episodes + buffer_num_of_collected_episodes = [ + buffer.num_of_collected_episodes for buffer in game_buffers + ] + + # 获取当前的 world_size 和 rank + world_size = get_world_size() + rank = get_rank() + + # 收集所有 rank 的 num_of_collected_episodes 列表 + all_task_num_of_collected_episodes = [None for _ in range(world_size)] + dist.all_gather_object(all_task_num_of_collected_episodes, buffer_num_of_collected_episodes) + + # 将所有 rank 的 num_of_collected_episodes 拼接成一个大列表 + all_task_num_of_collected_episodes = [ + item for sublist in all_task_num_of_collected_episodes for item in sublist + ] + if rank == 0: + print(f'all_task_num_of_collected_episodes: {all_task_num_of_collected_episodes}') + + # 计算每个任务的反比权重 + inv_episodes = np.array([ + 1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes + ]) + inv_sum = np.sum(inv_episodes) + + # 计算总的 batch_size (所有任务 cfg.policy.batch_size 的和) + total_batch_size = cfgs[0].policy.total_batch_size + + # 动态调整的部分:最小和最大的 batch_size 范围 + avg_batch_size = total_batch_size / world_size + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale + + # 动态调整 alpha,让 batch_size 的变化更加平滑 + task_weights = (inv_episodes / inv_sum) ** alpha + batch_sizes = total_batch_size * task_weights + + # 控制 batch_size 在 [min_batch_size, max_batch_size] 之间 + batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + + # 确保 batch_size 是整数 + batch_sizes = [int(size) for size in batch_sizes] + + # 返回最终分配的 batch_size 列表 + return batch_sizes + + +def train_unizero_multitask_segment_eval( + input_cfg_list: List[Tuple[int, Tuple[Dict[str, Any], Dict[str, Any]]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + The training entry point for UniZero, as proposed in the paper "UniZero: Generalized and Efficient Planning with Scalable Latent World Models". + UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing limitations found in MuZero-style algorithms, + particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. + + Args: + input_cfg_list (List[Tuple[int, Tuple[Dict[str, Any], Dict[str, Any]]]]): + List of configurations for different tasks. Each item is a tuple containing a task ID and a tuple of configuration dictionaries. + seed (int): + Random seed for reproducibility. + model (Optional[torch.nn.Module]): + Instance of torch.nn.Module representing the model. If None, a new model will be created. + model_path (Optional[str]): + Path to a pretrained model checkpoint. Should point to the ckpt file of the pretrained model. + max_train_iter (Optional[int]): + Maximum number of policy update iterations during training. Default is a very large number. + max_env_step (Optional[int]): + Maximum number of environment interaction steps to collect. Default is a very large number. + + Returns: + 'Policy': + The converged policy after training. 
+ """ + # 获取当前进程的 rank 和总的进程数 + rank = get_rank() + world_size = get_world_size() + + # 任务划分 + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + # 确保至少有一个任务 + if len(tasks_for_this_rank) == 0: + logging.warning(f"Rank {rank}: No tasks assigned, continuing without tasks.") + # 初始化一些空列表以避免后续代码报错 + cfgs, game_buffers, collectors, evaluators = [], [], [], [] + else: + print(f"Rank {rank}/{world_size}, handling tasks {start_idx} to {end_idx - 1}") + + cfgs: List[Any] = [] + game_buffers: List[GameBuffer] = [] + collectors: List[Collector] = [] + evaluators: List[Evaluator] = [] + + # 使用本rank的第一个任务的配置来创建共享的 policy + task_id, (cfg, create_cfg) = tasks_for_this_rank[0] + + # 设置每个任务的 task_num 以用于 learner_log + for config in tasks_for_this_rank: + config[1][0].policy.task_num = tasks_per_rank + + # 确保指定的 policy 类型是支持的 + assert create_cfg.policy.type in [ + 'unizero_multitask'], "train_unizero entry now only supports 'unizero_multitask'" + + # 根据 CUDA 可用性设置设备 + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # 编译配置 + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # 创建共享的 policy + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # 如果指定了预训练模型,则加载 + if model_path is not None: + logging.info(f'Loading model from {model_path} begin...') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Loading model from {model_path} end!') + + # 创建 TensorBoard 的日志记录器 + log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # 创建共享的 learner + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + policy_config = cfg.policy + batch_size = policy_config.batch_size[0] + + # 只处理当前进程分配到的任务 + for local_task_id, (task_id, (cfg, create_cfg)) in enumerate(tasks_for_this_rank): + # 设置每个任务自己的随机种子 + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # 为每个任务创建不同的 game buffer、collector、evaluator + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + 
eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + + game_buffers.append(replay_buffer) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + while True: + # 预先计算位置嵌入矩阵(如果需要) + # policy._collect_model.world_model.precompute_pos_emb_diff_kv() + # policy._target_model.world_model.precompute_pos_emb_diff_kv() + + if cfg.policy.allocated_batch_sizes: + # 动态调整 clip_scale 随着 train_epoch 从 0 增加到 1000, clip_scale 从 1 线性增加到 4 + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + print("分配后的 batch_sizes: ", allocated_batch_sizes) + for cfg, _collector, _evaluator, replay_buffer in zip(cfgs, collectors, evaluators, game_buffers): + cfg.policy.batch_size = allocated_batch_sizes + policy._cfg.batch_size = allocated_batch_sizes + + # 对于当前进程的每个任务,进行数据收集和评估 + for cfg, collector, evaluator, replay_buffer in zip(cfgs, collectors, evaluators, game_buffers): + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # 默认的 epsilon 值 + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + if evaluator.should_eval(learner.train_iter): + print('=' * 20) + print(f'Rank {rank} evaluates task_id: {cfg.policy.task_id}...') + + # 在训练进程中调用 safe_eval + stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) + # 判断评估是否成功 + if stop is None or reward is None: + print(f"Rank {rank} encountered an issue during evaluation. 
Continuing training...") + else: + print(f"Evaluation successful: stop={stop}, reward={reward}") + + print('=' * 20) + print(f'entry: Rank {rank} collects task_id: {cfg.policy.task_id}...') + + # NOTE: 在每次收集之前重置初始数据,这对于多任务设置非常重要 + collector._policy.reset(reset_init_data=True) + # 收集数据 + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # 更新 replay buffer + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # 周期性地重新分析缓冲区 + if cfg.policy.buffer_reanalyze_freq >= 1: + # 在一个训练 epoch 中重新分析缓冲区 次 + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + # 每 <1/buffer_reanalyze_freq> 个训练 epoch 重新分析一次缓冲区 + if (train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > + int(reanalyze_batch_size / cfg.policy.reanalyze_partition)): + with EasyTimer() as timer: + # 每个重新分析过程将重新分析 个序列 + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time: {timer.value}') + + # 数据收集结束后添加日志 + logging.info(f'Rank {rank}: Completed data collection for task {cfg.policy.task_id}') + + # 检查是否有足够的数据进行训练 + not_enough_data = any( + replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size + for replay_buffer in game_buffers + ) + + # 同步训练前所有 rank 的准备状态 + try: + dist.barrier() + logging.info(f'Rank {rank}: Passed barrier before training') + except Exception as e: + logging.error(f'Rank {rank}: Barrier failed with error {e}') + break # 或者进行其他错误处理 + + # 学习策略 + if not not_enough_data: + # Learner 将在一次迭代中训练 update_per_collect 次 + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for cfg, collector, replay_buffer in zip(cfgs, collectors, game_buffers): + envstep_multi_task += collector.envstep + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + if replay_buffer.get_num_of_transitions() > batch_size: + if cfg.policy.buffer_reanalyze_freq >= 1: + # 在一个训练 epoch 中重新分析缓冲区 次 + if (i % reanalyze_interval == 0 and + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > + int(reanalyze_batch_size / cfg.policy.reanalyze_partition)): + with EasyTimer() as timer: + # 每个重新分析过程将重新分析 个序列 + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + # 追加 task_id,以便在训练时区分任务 + train_data.append(cfg.policy.task_id) + train_data_multi_task.append(train_data) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + # 在训练时,DDP 会自动同步梯度和参数 + log_vars = learner.train(train_data_multi_task, envstep_multi_task) + + # 同步训练前所有 rank 的准备状态 + try: + dist.barrier() + logging.info(f'Rank {rank}: Passed barrier during training') + except Exception as e: + logging.error(f'Rank {rank}: Barrier failed with error {e}') + break # 或者进行其他错误处理 + + # TODO: 可选:终止进程 + import sys + sys.exit(0) + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # 同步所有 Rank,确保所有 Rank 都完成了训练 + try: + dist.barrier() + logging.info(f'Rank {rank}: Passed barrier after training') + except 
Exception as e: + logging.error(f'Rank {rank}: Barrier failed with error {e}') + break # 或者进行其他错误处理 + + # 检查是否需要终止训练 + try: + # 收集本地的 envsteps + local_envsteps = [collector.envstep for collector in collectors] + + # 收集所有进程的 envsteps + total_envsteps: List[Optional[int]] = [None for _ in range(world_size)] + dist.all_gather_object(total_envsteps, local_envsteps) + + # 将所有 envsteps 拼接在一起进行检查 + all_envsteps = torch.cat([ + torch.tensor(envsteps, device=cfg.policy.device) for envsteps in total_envsteps + ]) + max_envstep_reached = torch.all(all_envsteps >= max_env_step) + + # 收集所有进程的 train_iter + global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) + all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)] + dist.all_gather(all_train_iters, global_train_iter) + + max_train_iter_reached = torch.any(torch.stack(all_train_iters) >= max_train_iter) + + if max_envstep_reached.item() or max_train_iter_reached.item(): + logging.info(f'Rank {rank}: Termination condition met') + dist.barrier() # 确保所有进程同步 + break + except Exception as e: + logging.error(f'Rank {rank}: Termination check failed with error {e}') + break # 或者进行其他错误处理 + + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask_segment_serial.py b/lzero/entry/train_unizero_multitask_segment_serial.py new file mode 100644 index 000000000..9302ba19d --- /dev/null +++ b/lzero/entry/train_unizero_multitask_segment_serial.py @@ -0,0 +1,295 @@ +import logging +import os +from functools import partial +from typing import List, Optional, Tuple + +import numpy as np +import torch +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import EasyTimer, set_pkg_seed, get_rank +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.mcts import UniZeroGameBuffer as GameBuffer +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector + + +timer = EasyTimer() + + +def train_unizero_multitask_segment_serial( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + 概述: + UniZero的训练入口,基于论文《UniZero: Generalized and Efficient Planning with Scalable Latent World Models》提出。 + UniZero旨在通过解决MuZero风格算法在需要捕捉长期依赖的环境中的局限性,增强强化学习代理的规划能力。 + 详细内容可参考 https://arxiv.org/abs/2406.10667。 + + 参数: + - input_cfg_list (List[Tuple[int, Tuple[dict, dict]]]): 不同任务的配置列表。 + - seed (int): 随机种子。 + - model (Optional[torch.nn.Module]): torch.nn.Module的实例。 + - model_path (Optional[str]): 预训练模型路径,应指向预训练模型的ckpt文件。 + - max_train_iter (Optional[int]): 训练中的最大策略更新迭代次数。 + - max_env_step (Optional[int]): 收集环境交互步骤的最大数量。 + + 返回: + - policy (Policy): 收敛的策略对象。 + """ + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + # 获取第一个任务的配置 + task_id, [cfg, create_cfg] = input_cfg_list[0] + + # 确保指定的策略类型受支持 + assert create_cfg.policy.type in ['unizero_multitask'], "train_unizero entry 目前仅支持 'unizero_multitask'" + + # 根据CUDA可用性设置设备 + cfg.policy.device = 
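The termination check gathers every rank's env steps and train iterations before deciding to stop, so all ranks leave the loop together. A minimal sketch of that pattern with `torch.distributed` (it assumes a process group is already initialised; names are illustrative):

import torch
import torch.distributed as dist

def all_ranks_done(local_envsteps, train_iter, max_env_step, max_train_iter, device):
    world_size = dist.get_world_size()
    # Gather the per-task env steps of every rank as plain Python objects.
    gathered = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, local_envsteps)
    all_envsteps = torch.tensor([s for steps in gathered for s in steps], device=device)
    env_done = torch.all(all_envsteps >= max_env_step)
    # Gather train_iter as tensors; any rank hitting the cap stops everyone.
    local_iter = torch.tensor([train_iter], device=device)
    iters = [torch.zeros_like(local_iter) for _ in range(world_size)]
    dist.all_gather(iters, local_iter)
    iter_done = torch.any(torch.stack(iters) >= max_train_iter)
    return bool(env_done.item() or iter_done.item())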
cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # 编译配置 + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # 为所有任务创建共享策略 + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # 如果指定了预训练模型路径,加载预训练模型 + if model_path is not None: + logging.info(f'开始加载模型: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'完成加载模型: {model_path}') + + # 为TensorBoard日志创建SummaryWriter + tb_logger = SummaryWriter(os.path.join(f'./{cfg.exp_name}/log/', 'serial')) if get_rank() == 0 else None + # 为所有任务创建共享学习器 + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + policy_config = cfg.policy + batch_size = policy_config.batch_size[0] + + # 遍历所有任务的配置 + for task_id, input_cfg in input_cfg_list: + if task_id > 0: + cfg, create_cfg = input_cfg + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + # 更新收集和评估模式的配置 + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + # 创建环境管理器 + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # 创建各任务专属的游戏缓存、收集器和评估器 + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + value_priority_tasks = {} + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + while True: + # 预计算收集和评估时的位置嵌入矩阵(非训练阶段) + # policy._collect_model.world_model.precompute_pos_emb_diff_kv() + # policy._target_model.world_model.precompute_pos_emb_diff_kv() + + # 为每个任务收集数据 + for task_id, (cfg, collector, evaluator, replay_buffer) in enumerate(zip(cfgs, collectors, evaluators, game_buffers)): + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 
'epsilon': 0.0 # 默认epsilon值 + } + + # 如果启用了epsilon-greedy探索,计算当前epsilon值 + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + # 评估阶段 + if evaluator.should_eval(learner.train_iter): + print('=' * 20) + print(f'开始评估任务 id: {task_id}...') + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + print('=' * 20) + print(f'开始收集任务 id: {task_id}...') + + # 在每次收集前重置初始数据,对于多任务设置非常重要 + collector._policy.reset(reset_init_data=True) + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # 确定每次收集后的更新次数 + if update_per_collect is None: + # 如果未设置update_per_collect,则根据收集的转换数量和重放比例计算 + collected_transitions_num = sum( + min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0] + ) + update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) + + # 更新重放缓存 + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # 定期重新分析重放缓存 + if cfg.policy.buffer_reanalyze_freq >= 1: + # 一个训练epoch内重新分析buffer的次数 + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + # 每隔一定数量的训练epoch重新分析buffer + if train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + # 每次重新分析处理reanalyze_batch_size个序列 + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'重放缓存重新分析次数: {buffer_reanalyze_count}') + logging.info(f'重放缓存重新分析时间: {timer.value}') + + # 检查是否有重放缓存数据不足 + not_enough_data = any(replay_buffer.get_num_of_transitions() < batch_size for replay_buffer in game_buffers) + + # 从收集的数据中学习策略 + if not not_enough_data: + # 学习器将在一次迭代中进行update_per_collect次训练 + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for task_id, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + envstep_multi_task += collector.envstep + if replay_buffer.get_num_of_transitions() > batch_size: + batch_size = cfg.policy.batch_size[task_id] + + if cfg.policy.buffer_reanalyze_freq >= 1: + # 在一个训练epoch内按照频率重新分析buffer + if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'重放缓存重新分析次数: {buffer_reanalyze_count}') + logging.info(f'重放缓存重新分析时间: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + # 将task_id附加到训练数据 + train_data.append(task_id) + train_data_multi_task.append(train_data) + else: + logging.warning( + f'重放缓存中的数据不足以采样一个小批量: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + log_vars = learner.train(train_data_multi_task, envstep_multi_task) + + # 如果使用优先级重放,更新各任务的优先级 + if cfg.policy.use_priority: + for task_id, replay_buffer in enumerate(game_buffers): + # 更新任务特定重放缓存的优先级 + replay_buffer.update_priority(train_data_multi_task[task_id], log_vars[0][f'value_priority_task{task_id}']) + + # 获取当前任务的更新后优先级 + current_priorities = 
log_vars[0][f'value_priority_task{task_id}'] + + # 计算优先级的均值和标准差 + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + + # 使用指数移动平均计算运行中的均值 + alpha = 0.1 # 平滑因子,可根据需要调整 + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + # 初始化运行均值 + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + # 更新运行均值 + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + + # 计算归一化优先级 + running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6) + + # 记录统计信息 + if cfg.policy.print_task_priority_logs: + print( + f"任务 {task_id} - 优先级均值: {mean_priority:.8f}, " + f"运行均值优先级: {running_mean_priority:.8f}, " + f"标准差: {std_priority:.8f}" + ) + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # 检查是否达到训练或环境步数的最大限制 + if all(collector.envstep >= max_env_step for collector in collectors) or learner.train_iter >= max_train_iter: + break + + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask_serial.py b/lzero/entry/train_unizero_multitask_serial.py new file mode 100644 index 000000000..0a5aaae25 --- /dev/null +++ b/lzero/entry/train_unizero_multitask_serial.py @@ -0,0 +1,256 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroCollector as Collector, MuZeroEvaluator as Evaluator +from lzero.mcts import UniZeroGameBuffer as GameBuffer + +from line_profiler import line_profiler + +#@profile +def train_unizero_multitask_serial( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + The train entry for UniZero, proposed in our paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models. + UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found in MuZero-style algorithms, + particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. + Arguments: + - input_cfg_list (List[Tuple[int, Tuple[dict, dict]]]): List of configurations for different tasks. + - seed (int): Random seed. + - model (Optional[torch.nn.Module]): Instance of torch.nn.Module. + - model_path (Optional[str]): The pretrained model path, which should point to the ckpt file of the pretrained model. + - max_train_iter (Optional[int]): Maximum policy update iterations in training. + - max_env_step (Optional[int]): Maximum collected environment interaction steps. + Returns: + - policy (Policy): Converged policy. 
+ """ + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + task_id, [cfg, create_cfg] = input_cfg_list[0] + + # Ensure the specified policy type is supported + assert create_cfg.policy.type in ['unizero_multitask'], "train_unizero entry now only supports 'unizero'" + + # Set device based on CUDA availability + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # Compile the configuration + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create shared policy for all tasks + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # Load pretrained model if specified + if model_path is not None: + logging.info(f'Loading model from {model_path} begin...') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Loading model from {model_path} end!') + + # Create SummaryWriter for TensorBoard logging + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + # Create shared learner for all tasks + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # TODO task_id = 0: + policy_config = cfg.policy + batch_size = policy_config.batch_size[0] + + for task_id, input_cfg in input_cfg_list: + if task_id > 0: + # Get the configuration for each task + cfg, create_cfg = input_cfg + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # ===== NOTE: Create different game buffer, collector, evaluator for each task ==== + # TODO: share replay buffer for all tasks + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + value_priority_tasks = {} + update_per_collect = cfg.policy.update_per_collect + + while True: + # Precompute positional embedding matrices for collect/eval (not training) + 
policy._collect_model.world_model.precompute_pos_emb_diff_kv() + policy._target_model.world_model.precompute_pos_emb_diff_kv() + + # Collect data for each task + for task_id, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # Default epsilon value + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + if evaluator.should_eval(learner.train_iter): + print('=' * 20) + print(f'evaluate task_id: {task_id}...') + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + print('=' * 20) + print(f'collect task_id: {task_id}...') + + # Reset initial data before each collection + collector._policy.reset(reset_init_data=True) + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # Determine updates per collection + if update_per_collect is None: + collected_transitions_num = sum(len(game_segment) for game_segment in new_data[0]) + update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) + + # Update replay buffer + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + not_enough_data = any(replay_buffer.get_num_of_transitions() < batch_size for replay_buffer in game_buffers) + + # Learn policy from collected data. + if not not_enough_data: + # Learner will train ``update_per_collect`` times in one iteration. + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for task_id, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + envstep_multi_task += collector.envstep + if replay_buffer.get_num_of_transitions() > batch_size: + batch_size = cfg.policy.batch_size[task_id] + train_data = replay_buffer.sample(batch_size, policy) + if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0: + policy.recompute_pos_emb_diff_and_clear_cache() + # Append task_id to train_data + train_data.append(task_id) + train_data_multi_task.append(train_data) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + log_vars = learner.train(train_data_multi_task, envstep_multi_task) + + if cfg.policy.use_priority: + for task_id, replay_buffer in enumerate(game_buffers): + # Update the priority for the task-specific replay buffer. + replay_buffer.update_priority(train_data_multi_task[task_id], log_vars[0][f'value_priority_task{task_id}']) + + # Retrieve the updated priorities for the current task. + current_priorities = log_vars[0][f'value_priority_task{task_id}'] + + # Calculate statistics: mean, running mean, standard deviation for the priorities. 
+ mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + + # Using exponential moving average for running mean (alpha is the smoothing factor). + alpha = 0.1 # You can adjust this smoothing factor as needed. + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + # Initialize running mean if it does not exist. + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + # Update running mean. + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + + # Calculate the normalized priority using the running mean. + running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6) + + # Store the normalized priorities back to the replay buffer (if needed). + # replay_buffer.update_priority(train_data_multi_task[task_id], normalized_priorities) + + # Log the statistics if the print_task_priority_logs flag is set. + if cfg.policy.print_task_priority_logs: + print(f"Task {task_id} - Mean Priority: {mean_priority:.8f}, " + f"Running Mean Priority: {running_mean_priority:.8f}, " + f"Standard Deviation: {std_priority:.8f}") + + + if all(collector.envstep >= max_env_step for collector in collectors) or learner.train_iter >= max_train_iter: + break + + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index d2c23f930..97cebdc9c 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -63,7 +63,7 @@ def random_collect( collector.reset_policy(policy.collect_mode) -def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: +def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter, task_id=0) -> None: """ Overview: Log the memory usage of the buffer and the current process to TensorBoard. @@ -74,9 +74,9 @@ def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: Summa """ # "writer is None" means we are in a slave process in the DDP setup. if writer is not None: - writer.add_scalar('Buffer/num_of_all_collected_episodes', buffer.num_of_collected_episodes, train_iter) - writer.add_scalar('Buffer/num_of_game_segments', len(buffer.game_segment_buffer), train_iter) - writer.add_scalar('Buffer/num_of_transitions', len(buffer.game_segment_game_pos_look_up), train_iter) + writer.add_scalar(f'Buffer/num_of_all_collected_episodes_{task_id}', buffer.num_of_collected_episodes, train_iter) + writer.add_scalar(f'Buffer/num_of_game_segments_{task_id}', len(buffer.game_segment_buffer), train_iter) + writer.add_scalar(f'Buffer/num_of_transitions_{task_id}', len(buffer.game_segment_game_pos_look_up), train_iter) game_segment_buffer = buffer.game_segment_buffer @@ -87,7 +87,7 @@ def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: Summa buffer_memory_usage_mb = buffer_memory_usage / (1024 * 1024) # Record the memory usage of self.game_segment_buffer to TensorBoard. - writer.add_scalar('Buffer/memory_usage/game_segment_buffer', buffer_memory_usage_mb, train_iter) + writer.add_scalar(f'Buffer/memory_usage/game_segment_buffer_{task_id}', buffer_memory_usage_mb, train_iter) # Get the amount of memory currently used by the process (in bytes). 
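Both serial entries keep a per-task running mean of the sampled value priorities and normalise fresh priorities against it with a plain exponential moving average. A self-contained sketch of that bookkeeping (the class name is illustrative):

import numpy as np

class RunningPriorityNormalizer:
    """Tracks an EMA of the mean priority per task and normalises raw priorities."""

    def __init__(self, alpha: float = 0.1):
        self.alpha = alpha          # smoothing factor, matching the entries above
        self.running_mean = {}      # task_id -> EMA of the mean priority

    def update(self, task_id: int, priorities: np.ndarray) -> np.ndarray:
        mean, std = priorities.mean(), priorities.std()
        if task_id not in self.running_mean:
            self.running_mean[task_id] = mean
        else:
            self.running_mean[task_id] = (
                self.alpha * mean + (1 - self.alpha) * self.running_mean[task_id]
            )
        # Centre by the running mean; 1e-6 guards against a zero standard deviation.
        return (priorities - self.running_mean[task_id]) / (std + 1e-6)

norm = RunningPriorityNormalizer()
print(norm.update(0, np.array([0.2, 0.4, 0.9])))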
process = psutil.Process(os.getpid()) @@ -97,7 +97,7 @@ def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: Summa process_memory_usage_mb = process_memory_usage / (1024 * 1024) # Record the memory usage of the process to TensorBoard. - writer.add_scalar('Buffer/memory_usage/process', process_memory_usage_mb, train_iter) + writer.add_scalar(f'Buffer/memory_usage/process_{task_id}', process_memory_usage_mb, train_iter) def log_buffer_run_time(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index 2678066e9..3383dd2ef 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -102,22 +102,23 @@ def _make_batch(self, orig_data: Any, reanalyze_ratio: float) -> Tuple[Any]: """ pass - def _sample_orig_data(self, batch_size: int) -> Tuple: + def _sample_orig_data(self, batch_size: int, print_priority_logs: bool = False) -> Tuple: """ Overview: - sample orig_data that contains: - game_segment_list: a list of game segments - pos_in_game_segment_list: transition index in game (relative index) - batch_index_list: the index of start transition of sampled minibatch in replay buffer - weights_list: the weight concerning the priority - make_time: the time the batch is made (for correctly updating replay buffer when data is deleted) + Sample original data which includes: + - game_segment_list: A list of game segments. + - pos_in_game_segment_list: Transition index in the game (relative index). + - batch_index_list: The index of the start transition of the sampled mini-batch in the replay buffer. + - weights_list: The weight concerning the priority. + - make_time: The time the batch is made (for correctly updating the replay buffer when data is deleted). Arguments: - - batch_size (:obj:`int`): batch size - - beta: float the parameter in PER for calculating the priority + - batch_size (:obj:`int`): The size of the batch. + - print_priority_logs (:obj:`bool`): Whether to print logs related to priority statistics, defaults to False. 
""" - assert self._beta > 0 + assert self._beta > 0, "Beta should be greater than 0" num_of_transitions = self.get_num_of_transitions() - if self._cfg.use_priority is False: + if not self._cfg.use_priority: + # If priority is not used, set all priorities to 1 self.game_pos_priorities = np.ones_like(self.game_pos_priorities) # +1e-6 for numerical stability @@ -126,20 +127,21 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: # sample according to transition index batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False) - - if self._cfg.reanalyze_outdated is True: - # NOTE: used in reanalyze part + + if self._cfg.reanalyze_outdated: + # Sort the batch indices if reanalyze is enabled batch_index_list.sort() - + + # Calculate weights for the sampled transitions weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self._beta) - weights_list /= weights_list.max() + weights_list /= weights_list.max() # Normalize weights game_segment_list = [] pos_in_game_segment_list = [] for idx in batch_index_list: game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx] - game_segment_idx -= self.base_idx + game_segment_idx -= self.base_idx # Adjust index based on base index game_segment = self.game_segment_buffer[game_segment_idx] game_segment_list.append(game_segment) @@ -151,14 +153,10 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: # Indices exceeding `game_segment_length` are padded with the next segment and are not updated # in the current implementation. Therefore, we need to sample `pos_in_game_segment` within # [0, game_segment_length - num_unroll_steps] to avoid padded data. - # TODO: Consider increasing `self._cfg.game_segment_length` to ensure sampling efficiency. - # if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps: - # pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item() - # NOTE: Sample the init position from the whole segment, but not from the padded part - if pos_in_game_segment >= self._cfg.game_segment_length: - pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item() + if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps: + pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item() pos_in_game_segment_list.append(pos_in_game_segment) @@ -166,6 +164,12 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: make_time = [time.time() for _ in range(len(batch_index_list))] orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) + + if print_priority_logs: + print(f"Sampled batch indices: {batch_index_list}") + print(f"Sampled priorities: {self.game_pos_priorities[batch_index_list]}") + print(f"Sampled weights: {weights_list}") + return orig_data def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple: @@ -585,7 +589,8 @@ def remove_oldest_data_to_fit(self) -> None: Overview: remove some oldest data if the replay buffer is full. 
""" - assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size" + if isinstance(self._cfg.batch_size, int): + assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size" nums_of_game_segments = self.get_num_of_game_segments() total_transition = self.get_num_of_transitions() if total_transition > self.replay_buffer_size: @@ -597,8 +602,15 @@ def remove_oldest_data_to_fit(self) -> None: # find the max game_segment index to keep in the buffer index = i break - if total_transition >= self._cfg.batch_size: - self._remove(index + 1) + if isinstance(self._cfg.batch_size, int): + if total_transition >= self._cfg.batch_size: + self._remove(index + 1) + else: + try: + if total_transition >= self._cfg.batch_size[0]: + self._remove(index + 1) + except Exception as e: + print(e) def _remove(self, excess_game_segment_index: List[int]) -> None: """ diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index 2ba8180de..6a5bcf218 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -61,6 +61,13 @@ def __init__(self, cfg: dict): self.sample_times = 0 self.active_root_num = 0 + if hasattr(self._cfg, 'task_id'): + self.task_id = self._cfg.task_id + print(f"Task ID is set to {self.task_id}.") + else: + self.task_id = None + print("No task_id found in configuration. Task ID is set to None.") + def reset_runtime_metrics(self): """ Overview: @@ -448,7 +455,11 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A end_index = self._cfg.mini_infer_size * (i + 1) m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device) # calculate the target value - m_output = model.initial_inference(m_obs) + if self.task_id is not None: + m_output = model.initial_inference(m_obs, task_id=self.task_id) + else: + m_output = model.initial_inference(m_obs) + if not model.training: # if not in training, obtain the scalars of the value/reward @@ -573,7 +584,10 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: beg_index = self._cfg.mini_infer_size * i end_index = self._cfg.mini_infer_size * (i + 1) m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device) - m_output = model.initial_inference(m_obs) + if self.task_id is not None: + m_output = model.initial_inference(m_obs, task_id=self.task_id) + else: + m_output = model.initial_inference(m_obs) if not model.training: # if not in training, obtain the scalars of the value/reward @@ -603,7 +617,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model with self._origin_search_timer: - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + if self.task_id is not None: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + self.origin_search_time += self._origin_search_timer.value else: # python mcts_tree @@ -613,7 +631,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: else: roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + if 
self.task_id is not None: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + roots_legal_actions_list = legal_actions roots_distributions = roots.get_distributions() @@ -757,6 +779,7 @@ def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) - NOTE: train_data = [current_batch, target_batch] current_batch = [obs_list, action_list, improved_policy_list(only in Gumbel MuZero), mask_list, batch_index_list, weights, make_time_list] + target_batch = [batch_rewards, batch_target_values, batch_target_policies] """ indices = train_data[0][-3] metas = {'make_time': train_data[0][-1], 'batch_priorities': batch_priorities} diff --git a/lzero/mcts/buffer/game_buffer_sampled_unizero.py b/lzero/mcts/buffer/game_buffer_sampled_unizero.py index abb7c92a8..24c9c78aa 100644 --- a/lzero/mcts/buffer/game_buffer_sampled_unizero.py +++ b/lzero/mcts/buffer/game_buffer_sampled_unizero.py @@ -112,6 +112,7 @@ def _make_batch_for_reanalyze(self, batch_size: int, reanalyze_ratio: float) -> mask_tmp = [1. for i in range(len(root_sampled_actions_tmp))] mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + # pad random action if self._cfg.model.continuous_action_space: actions_tmp += [ diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py index fcd47d851..642ebff13 100644 --- a/lzero/mcts/buffer/game_buffer_unizero.py +++ b/lzero/mcts/buffer/game_buffer_unizero.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy +from line_profiler import line_profiler @BUFFER_REGISTRY.register('game_buffer_unizero') @@ -48,6 +49,14 @@ def __init__(self, cfg: dict): self.game_segment_game_pos_look_up = [] self.sample_type = self._cfg.sample_type # 'transition' or 'episode' + if hasattr(self._cfg, 'task_id'): + self.task_id = self._cfg.task_id + print(f"Task ID is set to {self.task_id}.") + else: + self.task_id = None + print("No task_id found in configuration. Task ID is set to None.") + + #@profile def sample( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] ) -> List[Any]: @@ -94,6 +103,7 @@ def sample( train_data = [current_batch, target_batch] return train_data + #@profile def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: """ Overview: @@ -136,6 +146,10 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: mask_tmp = [1. for i in range(min(len(actions_tmp), self._cfg.game_segment_length - pos_in_game_segment))] mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + # TODO: original buffer mask + # mask_tmp = [1. for i in range(min(len(actions_tmp), self._cfg.game_segment_length - pos_in_game_segment))] + # mask_tmp += [0. 
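The buffers and tree-search classes repeat the same dispatch: pass `task_id` through only when it is set, so single-task models keep their original call signature. One way to express that branch once is a small helper like the hypothetical one below (not part of the patch):

def call_with_optional_task_id(fn, *args, task_id=None, **kwargs):
    """Forward task_id only when it is given, keeping single-task call sites unchanged."""
    if task_id is not None:
        return fn(*args, task_id=task_id, **kwargs)
    return fn(*args, **kwargs)

# e.g. m_output = call_with_optional_task_id(model.initial_inference, m_obs, task_id=self.task_id)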
for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + # pad random action actions_tmp += [ np.random.randint(0, game.action_space_size) @@ -411,7 +425,12 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: # =============== NOTE: The key difference with MuZero ================= # To obtain the target policy from MCTS guided by the recent target model # TODO: batch_obs (policy_obs_list) is at timestep t, batch_action is at timestep t - m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num]) # NOTE: :self.reanalyze_num + + if self.task_id is not None: + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], task_id=self.task_id) # NOTE: :self.reanalyze_num + else: + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num]) # NOTE: :self.reanalyze_num + # ======================================================================= if not model.training: @@ -438,13 +457,19 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: roots = MCTSCtree.roots(transition_batch_size, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + if self.task_id is not None: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) else: # python mcts_tree roots = MCTSPtree.roots(transition_batch_size, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + if self.task_id is not None: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) roots_legal_actions_list = legal_actions roots_distributions = roots.get_distributions() @@ -519,7 +544,11 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # =============== NOTE: The key difference with MuZero ================= # calculate the bootstrapped value and target value # NOTE: batch_obs(value_obs_list) is at t+td_steps, batch_action is at timestep t+td_steps - m_output = model.initial_inference(batch_obs, batch_action) + if self.task_id is not None: + m_output = model.initial_inference(batch_obs, batch_action, task_id=self.task_id) + else: + m_output = model.initial_inference(batch_obs, batch_action) + # ====================================================================== if not model.training: diff --git a/lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp index 7c5d11dd2..83f50e2da 100644 --- a/lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp +++ b/lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp @@ -22,6 +22,7 @@ #include #include + #ifdef _WIN32 #include "..\..\common_lib\utils.cpp" #else diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py index dd58e8682..50d4b0927 100644 --- a/lzero/mcts/tree_search/mcts_ctree.py +++ b/lzero/mcts/tree_search/mcts_ctree.py @@ -15,6 +15,7 @@ from lzero.mcts.ctree.ctree_muzero import mz_tree as mz_ctree from lzero.mcts.ctree.ctree_gumbel_muzero import gmz_tree as 
gmz_ctree +from line_profiler import line_profiler class UniZeroMCTSCtree(object): """ @@ -71,10 +72,10 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree return ctree.Roots(active_collect_env_num, legal_actions) - # @profile + #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]] + List[Any]], task_id=None ) -> None: """ Overview: @@ -144,7 +145,12 @@ def search( At the end of the simulation, the statistics along the trajectory are updated. """ # for UniZero - network_output = model.recurrent_inference(state_action_history, simulation_index, latent_state_index_in_search_path) + if task_id is not None: + # multi task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, latent_state_index_in_search_path, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, latent_state_index_in_search_path) network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) @@ -225,10 +231,10 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree return ctree.Roots(active_collect_env_num, legal_actions) - # @profile + # #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]] + List[Any]], task_id=None ) -> None: """ Overview: @@ -298,6 +304,13 @@ def search( """ network_output = model.recurrent_inference(latent_states, last_actions) # for classic muzero + if task_id is not None: + # multi task setting + network_output = model.recurrent_inference(latent_states, last_actions, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(latent_states, last_actions) + network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value)) @@ -495,7 +508,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "e """ return tree_muzero.Roots(active_collect_env_num, legal_actions) - # @profile + # #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], world_model_latent_history_roots: List[Any], to_play_batch: Union[int, List[Any]], ready_env_id=None, diff --git a/lzero/model/common.py b/lzero/model/common.py index 22afa95fe..7d8005fc1 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -185,6 +185,8 @@ def __init__(self, observation_shape: SequenceType, out_channels: int, super().__init__() assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + assert num_resblocks == 1, "num_resblocks must be 1 in DownSample" + self.observation_shape = observation_shape self.conv1 = nn.Conv2d( observation_shape[0], @@ -231,7 +233,7 @@ def __init__(self, observation_shape: SequenceType, out_channels: int, [ ResBlock( in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(1) + ) for _ in range(num_resblocks) ] ) self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) @@ -318,6 +320,7 @@ def __init__( 
num_channels, activation=activation, norm_type=norm_type, + num_resblocks=1, ) else: self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False) @@ -343,10 +346,10 @@ def __init__( self.embedding_dim = embedding_dim if self.observation_shape[1] == 64: - self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False) + self.last_linear = nn.Linear(num_channels * 8 * 8, self.embedding_dim, bias=False) elif self.observation_shape[1] in [84, 96]: - self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False) + self.last_linear = nn.Linear(num_channels * 6 * 6, self.embedding_dim, bias=False) self.sim_norm = SimNorm(simnorm_dim=group_size) @@ -817,9 +820,9 @@ def __init__( self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) if observation_shape[1] == 96: - latent_shape = (observation_shape[1] / 16, observation_shape[2] / 16) + latent_shape = (observation_shape[1] // 16, observation_shape[2] // 16) elif observation_shape[1] == 64: - latent_shape = (observation_shape[1] / 8, observation_shape[2] / 8) + latent_shape = (observation_shape[1] // 8, observation_shape[2] // 8) if norm_type == 'BN': self.norm_value = nn.BatchNorm2d(value_head_channels) diff --git a/lzero/model/muzero_model_multitask.py b/lzero/model/muzero_model_multitask.py new file mode 100644 index 000000000..6d7326152 --- /dev/null +++ b/lzero/model/muzero_model_multitask.py @@ -0,0 +1,389 @@ +from typing import Optional, Tuple + +import math +import torch +import torch.nn as nn +from ding.torch_utils import MLP, ResBlock +from ding.utils import MODEL_REGISTRY, SequenceType +from numpy import ndarray + +from .common import MZNetworkOutput, RepresentationNetwork, PredictionNetwork, FeatureAndGradientHook +from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean + + +@MODEL_REGISTRY.register('MuZeroMTModel') +class MuZeroMTModel(nn.Module): + + def __init__( + self, + observation_shape: SequenceType = (12, 96, 96), + action_space_size: int = 6, + num_res_blocks: int = 1, + num_channels: int = 64, + reward_head_channels: int = 16, + value_head_channels: int = 16, + policy_head_channels: int = 16, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = False, + categorical_distribution: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + downsample: bool = False, + norm_type: Optional[str] = 'BN', + discrete_action_encoding_type: str = 'one_hot', + analysis_sim_norm: bool = False, + task_num: int = 1, # 任务数量 + *args, + **kwargs + ): + """ + 多任务MuZero模型的定义,继承自MuZeroModel。 + 增加了多任务相关的处理,如任务数量和动作空间大小调整。 + """ + super(MuZeroMTModel, self).__init__() + + print(f'==========MuZeroMTModel, num_res_blocks:{num_res_blocks}, num_channels:{num_channels}, task_num:{task_num}===========') + + if discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + + if isinstance(observation_shape, int) or len(observation_shape) == 1: + # for vector obs input, e.g. 
classical control and box2d environments + # to be compatible with LightZero model/policy, transform to shape: [C, W, H] + observation_shape = [1, observation_shape, 1] + + self.categorical_distribution = categorical_distribution + if self.categorical_distribution: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + else: + self.reward_support_size = 1 + self.value_support_size = 1 + + self.task_num = task_num + self.action_space_size = 18 # 假设每个任务的动作空间相同 + + self.categorical_distribution = categorical_distribution + + self.discrete_action_encoding_type = 'one_hot' + + # 共享表示网络 + self.representation_network = RepresentationNetwork( + observation_shape, + num_res_blocks, + num_channels, + downsample, + activation=activation, + norm_type=norm_type + ) + + # ====== for analysis ====== + if analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + # 共享动态网络 + self.dynamics_network = DynamicsNetwork( + observation_shape, + action_encoding_dim=self.action_encoding_dim, + num_res_blocks=num_res_blocks, + num_channels=num_channels + self.action_encoding_dim, + reward_head_channels=reward_head_channels, + fc_reward_layers=fc_reward_layers, + output_support_size=reward_support_size, + flatten_output_size_for_reward_head=reward_head_channels * self._get_latent_size(observation_shape, downsample), + downsample=downsample, + last_linear_layer_init_zero=last_linear_layer_init_zero, + activation=activation, + norm_type=norm_type + ) + + # 独立的预测网络,每个任务一个 + # 计算flatten_output_size + value_flatten_size = int(value_head_channels * self._get_latent_size(observation_shape, downsample)) + policy_flatten_size = int(policy_head_channels * self._get_latent_size(observation_shape, downsample)) + + self.prediction_networks = nn.ModuleList([ + PredictionNetwork( + observation_shape, + action_space_size, + num_res_blocks, + num_channels, + value_head_channels, + policy_head_channels, + fc_value_layers, + fc_policy_layers, + self.value_support_size, + flatten_output_size_for_value_head=value_flatten_size, + flatten_output_size_for_policy_head=policy_flatten_size, + downsample=downsample, + last_linear_layer_init_zero=last_linear_layer_init_zero, + activation=activation, + norm_type=norm_type + ) for _ in range(task_num) + ]) + + # 共享投影和预测头(如果使用自监督学习损失) + if self_supervised_learning_loss: + self.projection_network = nn.Sequential( + nn.Linear(num_channels * self._get_latent_size(observation_shape, downsample), proj_hid), + nn.BatchNorm1d(proj_hid), + activation, + nn.Linear(proj_hid, proj_hid), + nn.BatchNorm1d(proj_hid), + activation, + nn.Linear(proj_hid, proj_out), + nn.BatchNorm1d(proj_out) + ) + + self.prediction_head = nn.Sequential( + nn.Linear(proj_out, pred_hid), + nn.BatchNorm1d(pred_hid), + activation, + nn.Linear(pred_hid, pred_out), + ) + + self.self_supervised_learning_loss = self_supervised_learning_loss + self.state_norm = state_norm + self.downsample = downsample + + def _get_latent_size(self, observation_shape: SequenceType, downsample: bool) -> int: + """ + 辅助函数,根据观测形状和下采样选项计算潜在状态的大小。 + """ + if downsample: + return math.ceil(observation_shape[-2] / 16) * math.ceil(observation_shape[-1] / 16) + else: + return observation_shape[-2] * observation_shape[-1] + + def initial_inference(self, obs: torch.Tensor, task_id: int = 0) -> MZNetworkOutput: + """ + 多任务初始推理,基于任务ID选择对应的预测网络。 + """ + batch_size = obs.size(0) + latent_state = self.representation_network(obs) + if self.state_norm: + 
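MuZeroMTModel shares the representation and dynamics networks and keeps one prediction head per task in an nn.ModuleList, indexed by `task_id` at inference time. A stripped-down sketch of that layout with toy MLP heads (everything here is illustrative, not the actual LightZero networks):

import torch
import torch.nn as nn

class TinyMultiTaskModel(nn.Module):
    def __init__(self, obs_dim=8, latent_dim=16, action_dim=4, task_num=3):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, latent_dim)                 # shared representation
        self.heads = nn.ModuleList([                                  # one prediction head per task
            nn.Linear(latent_dim, action_dim + 1) for _ in range(task_num)
        ])

    def initial_inference(self, obs, task_id: int = 0):
        latent = torch.relu(self.encoder(obs))
        out = self.heads[task_id](latent)
        policy_logits, value = out[:, :-1], out[:, -1:]
        return policy_logits, value, latent

model = TinyMultiTaskModel()
policy_logits, value, latent = model.initial_inference(torch.randn(2, 8), task_id=1)
print(policy_logits.shape, value.shape)   # torch.Size([2, 4]) torch.Size([2, 1])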
latent_state = renormalize(latent_state) + prediction_net = self.prediction_networks[task_id] + policy_logits, value = prediction_net(latent_state) + + return MZNetworkOutput( + value, + [0. for _ in range(batch_size)], + policy_logits, + latent_state, + ) + + def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor, task_id: int = 0) -> MZNetworkOutput: + """ + 多任务递归推理,根据任务ID选择对应的预测网络。 + """ + next_latent_state, reward = self._dynamics(latent_state, action) + + if self.state_norm: + next_latent_state = renormalize(next_latent_state) + prediction_net = self.prediction_networks[task_id] + policy_logits, value = prediction_net(next_latent_state) + + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) + + + def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + and ``reward``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. + - reward (:obj:`torch.Tensor`): The predicted reward of the current latent state and selected action. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + """ + # NOTE: the discrete action encoding type is important for some environments + + # discrete action space + if self.discrete_action_encoding_type == 'one_hot': + # Stack latent_state with the one hot encoded action. + # The final action_encoding shape is (batch_size, action_space_size, latent_state[2], latent_state[3]), e.g. (8, 2, 4, 1). + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + + # transform action to one-hot encoding. + # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) + # transform action to torch.int64 + action = action.long() + action_one_hot.scatter_(1, action, 1) + + action_encoding_tmp = action_one_hot.unsqueeze(-1).unsqueeze(-1) + action_encoding = action_encoding_tmp.expand( + latent_state.shape[0], self.action_space_size, latent_state.shape[2], latent_state.shape[3] + ) + + elif self.discrete_action_encoding_type == 'not_one_hot': + # Stack latent_state with the normalized encoded action. + # The final action_encoding shape is (batch_size, 1, latent_state[2], latent_state[3]), e.g. (8, 1, 4, 1). 
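The one-hot branch below turns a batch of discrete actions into constant spatial planes that are concatenated to the latent state before the dynamics network. The same few tensor operations in isolation, with toy shapes:

import torch

batch_size, action_space_size, h, w = 2, 4, 3, 3
latent_state = torch.randn(batch_size, 8, h, w)
action = torch.tensor([1, 3])

one_hot = torch.zeros(batch_size, action_space_size)
one_hot.scatter_(1, action.long().unsqueeze(-1), 1)              # (B, A)
action_planes = one_hot.unsqueeze(-1).unsqueeze(-1).expand(       # (B, A, H, W)
    batch_size, action_space_size, h, w
)
state_action = torch.cat((latent_state, action_planes), dim=1)    # (B, C + A, H, W)
print(state_action.shape)   # torch.Size([2, 12, 3, 3])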
+ if len(action.shape) == 2: + # (batch_size, action_dim=1) -> (batch_size, 1, 1, 1) + # e.g., torch.Size([8, 1]) -> torch.Size([8, 1, 1, 1]) + action = action.unsqueeze(-1).unsqueeze(-1) + elif len(action.shape) == 1: + # (batch_size,) -> (batch_size, 1, 1, 1) + # e.g., -> torch.Size([8, 1, 1, 1]) + action = action.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + + action_encoding = action.expand( + latent_state.shape[0], 1, latent_state.shape[2], latent_state.shape[3] + ) / self.action_space_size + + # state_action_encoding shape: (batch_size, latent_state[1] + action_dim, latent_state[2], latent_state[3]) or + # (batch_size, latent_state[1] + action_space_size, latent_state[2], latent_state[3]) depending on the discrete_action_encoding_type. + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + next_latent_state, reward = self.dynamics_network(state_action_encoding) + if self.state_norm: + next_latent_state = renormalize(next_latent_state) + return next_latent_state, reward + + def project(self, latent_state: torch.Tensor, with_grad: bool = True) -> torch.Tensor: + """ + 多任务投影方法,当前实现为共享投影网络。 + """ + if not self.self_supervised_learning_loss: + raise NotImplementedError("Self-supervised learning loss is not enabled for this model.") + + latent_state = latent_state.reshape(latent_state.shape[0], -1) + proj = self.projection_network(latent_state) + if with_grad: + return self.prediction_head(proj) + else: + return proj.detach() + + def get_params_mean(self) -> float: + return get_params_mean(self) + + +class DynamicsNetwork(nn.Module): + + def __init__( + self, + observation_shape: SequenceType, + action_encoding_dim: int = 2, + num_res_blocks: int = 1, + num_channels: int = 64, + reward_head_channels: int = 64, + fc_reward_layers: SequenceType = [32], + output_support_size: int = 601, + flatten_output_size_for_reward_head: int = 64, + downsample: bool = False, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + ): + """ + DynamicsNetwork定义,适用于多任务共享。 + """ + super().__init__() + assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']" + assert num_channels > action_encoding_dim, f'num_channels:{num_channels} <= action_encoding_dim:{action_encoding_dim}' + + self.num_channels = num_channels + self.flatten_output_size_for_reward_head = flatten_output_size_for_reward_head + + self.action_encoding_dim = action_encoding_dim + self.conv = nn.Conv2d(num_channels, num_channels - self.action_encoding_dim, kernel_size=3, stride=1, padding=1, bias=False) + + if norm_type == 'BN': + self.norm_common = nn.BatchNorm2d(num_channels - self.action_encoding_dim) + elif norm_type == 'LN': + if downsample: + self.norm_common = nn.LayerNorm([num_channels - self.action_encoding_dim, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]) + else: + self.norm_common = nn.LayerNorm([num_channels - self.action_encoding_dim, observation_shape[-2], observation_shape[-1]]) + + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels - self.action_encoding_dim, activation=activation, norm_type='BN', res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + + self.conv1x1_reward = nn.Conv2d(num_channels - self.action_encoding_dim, reward_head_channels, 1) + + if norm_type == 'BN': + self.norm_reward = nn.BatchNorm2d(reward_head_channels) + elif norm_type == 'LN': + if downsample: + self.norm_reward = nn.LayerNorm([reward_head_channels, 
math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]) + else: + self.norm_reward = nn.LayerNorm([reward_head_channels, observation_shape[-2], observation_shape[-1]]) + + self.fc_reward_head = MLP( + self.flatten_output_size_for_reward_head, + hidden_channels=fc_reward_layers[0], + layer_num=len(fc_reward_layers) + 1, + out_channels=output_support_size, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.activation = activation + + def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + DynamicsNetwork的前向传播,预测下一个潜在状态和奖励。 + """ + # 提取状态编码(去除动作编码部分) + state_encoding = state_action_encoding[:, :-self.action_encoding_dim, :, :] + x = self.conv(state_action_encoding) + x = self.norm_common(x) + + # 残差连接 + x += state_encoding + x = self.activation(x) + + for block in self.resblocks: + x = block(x) + next_latent_state = x + + x = self.conv1x1_reward(next_latent_state) + x = self.norm_reward(x) + x = self.activation(x) + x = x.view(x.shape[0], -1) + + # 使用全连接层预测奖励 + reward = self.fc_reward_head(x) + + return next_latent_state, reward + + def get_dynamic_mean(self) -> float: + return get_dynamic_mean(self) + + def get_reward_mean(self) -> Tuple[ndarray, float]: + return get_reward_mean(self) \ No newline at end of file diff --git a/lzero/model/unizero_model_multitask.py b/lzero/model/unizero_model_multitask.py new file mode 100644 index 000000000..e64e8b55f --- /dev/null +++ b/lzero/model/unizero_model_multitask.py @@ -0,0 +1,236 @@ +from typing import Optional + +import torch +import torch.nn as nn +from ding.utils import MODEL_REGISTRY, SequenceType +from easydict import EasyDict + +from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ + VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook +from .unizero_world_models.tokenizer import Tokenizer +from .unizero_world_models.world_model_multitask import WorldModelMT + +from line_profiler import line_profiler + +# use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document. +@MODEL_REGISTRY.register('UniZeroMTModel') +class UniZeroMTModel(nn.Module): + + #@profile + def __init__( + self, + observation_shape: SequenceType = (4, 64, 64), + action_space_size: int = 6, + num_res_blocks: int = 1, + num_channels: int = 64, + activation: nn.Module = nn.GELU(approximate='tanh'), + downsample: bool = True, + norm_type: Optional[str] = 'BN', + world_model_cfg: EasyDict = None, + task_num: int = 1, + *args, + **kwargs + ): + """ + Overview: + The definition of data procession in the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), including two main parts: + - initial_inference, which is used to predict the value, policy, and latent state based on the current observation. + - recurrent_inference, which is used to predict the value, policy, reward, and next latent state based on the current latent state and action. + The world model consists of three main components: + - a tokenizer, which encodes observations into embeddings, + - a transformer, which processes the input sequences, + - and heads, which generate the logits for observations, rewards, policy, and value. + Arguments: + - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. [C, W, H]=[3, 64, 64] for Atari. 
+ - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - num_res_blocks (:obj:`int`): The number of res blocks in UniZero model. + - num_channels (:obj:`int`): The channels of hidden states in representation network. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - world_model_cfg (:obj:`EasyDict`): The configuration of the world model, including the following keys: + - obs_type (:obj:`str`): The type of observation, which can be 'image', 'vector', or 'image_memory'. + - embed_dim (:obj:`int`): The dimension of the embedding. + - group_size (:obj:`int`): The group size of the transformer. + - max_blocks (:obj:`int`): The maximum number of blocks in the transformer. + - max_tokens (:obj:`int`): The maximum number of tokens in the transformer. + - context_length (:obj:`int`): The context length of the transformer. + - device (:obj:`str`): The device of the model, which can be 'cuda' or 'cpu'. + - action_space_size (:obj:`int`): The shape of the action. + - num_layers (:obj:`int`): The number of layers in the transformer. + - num_heads (:obj:`int`): The number of heads in the transformer. + - policy_entropy_weight (:obj:`float`): The weight of the policy entropy. + - analysis_sim_norm (:obj:`bool`): Whether to analyze the similarity of the norm. + """ + super(UniZeroMTModel, self).__init__() + + print(f'==========UniZeroMTModel, num_res_blocks:{num_res_blocks}, num_channels:{num_channels}===========') + + self.action_space_size = action_space_size + + # for multi-task + self.action_space_size = 18 + self.task_num = task_num + + self.activation = activation + self.downsample = downsample + world_model_cfg.norm_type = norm_type + assert world_model_cfg.max_tokens == 2 * world_model_cfg.max_blocks, 'max_tokens should be 2 * max_blocks, because each timestep has 2 tokens: obs and action' + + if world_model_cfg.obs_type == 'vector': + self.representation_network = RepresentationNetworkMLP( + observation_shape, + hidden_channels=world_model_cfg.embed_dim, + layer_num=2, + activation=self.activation, + group_size=world_model_cfg.group_size, + ) + # TODO: only for MemoryEnv now + self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25) + self.tokenizer = Tokenizer(encoder=self.representation_network, + decoder_network=self.decoder_network, with_lpips=False) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) + elif world_model_cfg.obs_type == 'image': + self.representation_network = nn.ModuleList() + # for task_id in range(self.task_num): # TODO: N independent encoder + for task_id in range(1): # TODO: one share encoder + 
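+                # NOTE: only a single encoder is instantiated here and shared across all tasks
+                # (the loop runs over range(1)), and Tokenizer.encode_to_obs_embeddings likewise
+                # forces task_id = 0; switch to range(self.task_num) for per-task encoders.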
self.representation_network.append(RepresentationNetworkUniZero( + observation_shape, + num_res_blocks, + num_channels, + self.downsample, + activation=self.activation, + norm_type=norm_type, + embedding_dim=world_model_cfg.embed_dim, + group_size=world_model_cfg.group_size, + )) + # TODO: we should change the output_shape to the real observation shape + # self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, output_shape=(3, 64, 64)) + + # ====== for analysis ====== + if world_model_cfg.analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) + elif world_model_cfg.obs_type == 'image_memory': + self.representation_network = LatentEncoderForMemoryEnv( + image_shape=(3, 5, 5), + embedding_size=world_model_cfg.embed_dim, + channels=[16, 32, 64], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation=self.activation, + group_size=world_model_cfg.group_size, + ) + self.decoder_network = LatentDecoderForMemoryEnv( + image_shape=(3, 5, 5), + embedding_size=world_model_cfg.embed_dim, + channels=[64, 32, 16], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation=self.activation, + ) + + if world_model_cfg.analysis_sim_norm: + # ====== for analysis ====== + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + self.tokenizer = Tokenizer(with_lpips=True, encoder=self.representation_network, + decoder_network=self.decoder_network) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)') + + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network') + print('==' * 20) + + #@profile + def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None, task_id=None) -> MZNetworkOutput: + """ + Overview: + Initial inference of UniZero model, which is the first step of the UniZero model. + To perform the initial inference, we first use the representation network to obtain the ``latent_state``. + Then we use the prediction network to predict ``value`` and ``policy_logits`` of the ``latent_state``. + Arguments: + - obs_batch (:obj:`torch.Tensor`): The 3D image observation data. 
+ Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. \ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + """ + batch_size = obs_batch.size(0) + obs_act_dict = {'obs': obs_batch, 'action': action_batch, 'current_obs': current_obs_batch} + _, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(obs_act_dict, task_id=task_id) + latent_state, reward, policy_logits, value = obs_token, logits_rewards, logits_policy, logits_value + policy_logits = policy_logits.squeeze(1) + value = value.squeeze(1) + + return MZNetworkOutput( + value, + [0. for _ in range(batch_size)], + policy_logits, + latent_state, + ) + + #@profile + def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index=0, + latent_state_index_in_search_path=[], task_id=None) -> MZNetworkOutput: + """ + Overview: + Recurrent inference of UniZero model.To perform the recurrent inference, we concurrently predict the latent dynamics (reward/next_latent_state) + and decision-oriented quantities (value/policy) conditioned on the learned latent history in the world_model. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. 
+ """ + _, logits_observations, logits_rewards, logits_policy, logits_value = self.world_model.forward_recurrent_inference( + state_action_history, simulation_index, latent_state_index_in_search_path, task_id=task_id) + next_latent_state, reward, policy_logits, value = logits_observations, logits_rewards, logits_policy, logits_value + policy_logits = policy_logits.squeeze(1) + value = value.squeeze(1) + reward = reward.squeeze(1) + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) \ No newline at end of file diff --git a/lzero/model/unizero_world_models/lpips.py b/lzero/model/unizero_world_models/lpips.py index c6ee6426c..7abd5c062 100644 --- a/lzero/model/unizero_world_models/lpips.py +++ b/lzero/model/unizero_world_models/lpips.py @@ -22,11 +22,13 @@ def __init__(self, use_dropout: bool = True): self.chns = [64, 128, 256, 512, 512] # vg16 features # Comment out the following line if you don't need perceptual loss # self.net = vgg16(pretrained=True, requires_grad=False) - self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) - self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) - self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) - self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) - self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + + # self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + # self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + # self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + # self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + # self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + # Comment out the following line if you don't need perceptual loss # self.load_from_pretrained() # for param in self.parameters(): diff --git a/lzero/model/unizero_world_models/moe.py b/lzero/model/unizero_world_models/moe.py new file mode 100644 index 000000000..159afd69e --- /dev/null +++ b/lzero/model/unizero_world_models/moe.py @@ -0,0 +1,49 @@ +import dataclasses +from typing import List + +import torch +import torch.nn.functional as F +from simple_parsing.helpers import Serializable +from torch import nn + +# Modified from https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/transformer.py#L108 +class MultiplicationFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + + self.w1 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False) + self.w2 = nn.Linear(4 * config.embed_dim, config.embed_dim, bias=False) + self.w3 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore + +@dataclasses.dataclass +class MoeArgs(Serializable): + num_experts: int + num_experts_per_tok: int + + +class MoeLayer(nn.Module): + def __init__(self, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok=1): + super().__init__() + assert len(experts) > 0 + self.experts = nn.ModuleList(experts) + self.gate = gate + self.num_experts_per_tok = num_experts_per_tok + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + # if len(self.experts) == 1: + # # 只有一个专家时,直接使用该专家 + # return self.experts[0](inputs) + + gate_logits = self.gate(inputs) + weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_tok) + weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype) + results = torch.zeros_like(inputs) + for i, expert in 
enumerate(self.experts):
+            # batch_idx, nth_expert = torch.where(selected_experts == i)
+            # results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs[batch_idx])
+            batch_idx, token_idx, nth_expert = torch.where(selected_experts == i)
+            results[batch_idx, token_idx] += weights[batch_idx, token_idx, nth_expert][:, None] * expert(inputs[batch_idx, token_idx])
+        return results
\ No newline at end of file
diff --git a/lzero/model/unizero_world_models/test_moe.py b/lzero/model/unizero_world_models/test_moe.py
new file mode 100644
index 000000000..6ab93cc16
--- /dev/null
+++ b/lzero/model/unizero_world_models/test_moe.py
@@ -0,0 +1,107 @@
+import dataclasses
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from simple_parsing.helpers import Serializable
+from torch import nn
+
+# MoeArgs dataclass storing the MoE configuration parameters.
+@dataclasses.dataclass
+class MoeArgs(Serializable):
+    num_experts: int
+    num_experts_per_tok: int
+
+# Mixture of Experts (MoE) layer.
+class MoeLayer(nn.Module):
+    def __init__(self, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok=1):
+        super().__init__()
+        assert len(experts) > 0
+        self.experts = nn.ModuleList(experts)
+        self.gate = gate
+        self.num_experts_per_tok = num_experts_per_tok
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        if len(self.experts) == 1:
+            # With a single expert, bypass the gate and use it directly.
+            return self.experts[0](inputs)
+
+        gate_logits = self.gate(inputs)
+        weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_tok)
+        weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype)
+        results = torch.zeros_like(inputs)
+        for i, expert in enumerate(self.experts):
+            batch_idx, token_idx, nth_expert = torch.where(selected_experts == i)
+            results[batch_idx, token_idx] += weights[batch_idx, token_idx, nth_expert][:, None] * expert(inputs[batch_idx, token_idx])
+        return results
+
+# A simple Transformer block used only for this test.
+class TransformerBlock(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(config.embed_dim, 4 * config.embed_dim),
+            nn.GELU(approximate='tanh'),
+            nn.Linear(4 * config.embed_dim, config.embed_dim),
+            nn.Dropout(config.resid_pdrop),
+        )
+
+        if config.moe_in_transformer:
+            self.feed_forward = MoeLayer(
+                experts=[self.mlp for _ in range(config.num_experts_of_moe_in_transformer)],
+                gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False),
+                num_experts_per_tok=1,
+            )
+            print("="*20)
+            print('Using MoE in the feed_forward of the Transformer block')
+            print("="*20)
+        else:
+            self.feed_forward = self.mlp
+
+    def forward(self, x):
+        return self.feed_forward(x)
+
+# Minimal configuration class.
+class Config:
+    def __init__(self, embed_dim, resid_pdrop, num_experts_of_moe_in_transformer, moe_in_transformer):
+        self.embed_dim = embed_dim
+        self.resid_pdrop = resid_pdrop
+        self.num_experts_of_moe_in_transformer = num_experts_of_moe_in_transformer
+        self.moe_in_transformer = moe_in_transformer
+
+# Test code.
+def test_transformer_block():
+    # Initialize the configuration.
+    embed_dim = 64
+    resid_pdrop = 0.1
+    num_experts_of_moe_in_transformer = 1
+
+    # Create the input data.
+    inputs = torch.randn(10, 5, embed_dim)  # (batch_size, seq_len, embed_dim)
+
+    # Initialize the two output variables.
+    outputs_true = None
+    outputs_false = None
+
+    # Run the block with moe_in_transformer set to True and to False.
+    for moe_in_transformer in [True, False]:
+        config = Config(embed_dim, resid_pdrop, num_experts_of_moe_in_transformer, moe_in_transformer)
+        transformer_block = TransformerBlock(config)
+
+        outputs = transformer_block(inputs)
+        print(f"moe_in_transformer={moe_in_transformer}:
outputs={outputs}") + + if moe_in_transformer: + outputs_true = outputs + else: + outputs_false = outputs + + # 计算输出的差异 + mse_difference = None + if outputs_true is not None and outputs_false is not None: + mse_difference = F.mse_loss(outputs_true, outputs_false).item() + + print(f"输出差异的均方误差(MSE): {mse_difference}") + +if __name__ == "__main__": + test_transformer_block() \ No newline at end of file diff --git a/lzero/model/unizero_world_models/tokenizer.py b/lzero/model/unizero_world_models/tokenizer.py index bd066ccec..4cd022fe0 100644 --- a/lzero/model/unizero_world_models/tokenizer.py +++ b/lzero/model/unizero_world_models/tokenizer.py @@ -54,35 +54,48 @@ def __init__(self, encoder=None, decoder_network=None, with_lpips: bool = False) self.encoder = encoder self.decoder_network = decoder_network - def encode_to_obs_embeddings(self, x: torch.Tensor) -> torch.Tensor: + def encode_to_obs_embeddings(self, x: torch.Tensor, task_id = None) -> torch.Tensor: """ Encode observations to embeddings. Arguments: - x (torch.Tensor): Input tensor of shape (B, ...). + - x (torch.Tensor): Input tensor of shape (B, ...). Returns: - torch.Tensor: Encoded embeddings of shape (B, 1, E). + - torch.Tensor: Encoded embeddings of shape (B, 1, E). """ shape = x.shape + if task_id is None: + # for compatibility with multitask setting + task_id = 0 + else: + task_id = 0 # one share encoder + # task_id = task_id # TODO: one encoder per task + # Process input tensor based on its dimensionality if len(shape) == 2: # Case when input is 2D (B, E) - obs_embeddings = self.encoder(x) + obs_embeddings = self.encoder[task_id](x) obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') elif len(shape) == 3: # Case when input is 3D (B, T, E) x = x.contiguous().view(-1, shape[-1]) # Flatten the last two dimensions (B * T, E) - obs_embeddings = self.encoder(x) + obs_embeddings = self.encoder[task_id](x) obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') elif len(shape) == 4: # Case when input is 4D (B, C, H, W) - obs_embeddings = self.encoder(x) + try: + obs_embeddings = self.encoder[task_id](x) + except Exception as e: + obs_embeddings = self.encoder(x) # TODO: for memory env obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') elif len(shape) == 5: # Case when input is 5D (B, T, C, H, W) x = x.contiguous().view(-1, *shape[-3:]) # Flatten the first two dimensions (B * T, C, H, W) - obs_embeddings = self.encoder(x) + try: + obs_embeddings = self.encoder[task_id](x) + except Exception as e: + obs_embeddings = self.encoder(x) # TODO: for memory env obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') else: raise ValueError(f"Invalid input shape: {shape}") diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index 62536c892..43ed53a7f 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -13,6 +13,8 @@ from torch.nn import functional as F from .kv_caching import KeysValues +from .moe import MoeLayer, MultiplicationFeedForward +from line_profiler import line_profiler @dataclass @@ -69,6 +71,7 @@ def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: device = self.ln_f.weight.device # Assumption: All submodules are on the same device return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device) + #@profile def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None, valid_context_lengths: 
Optional[torch.Tensor] = None) -> torch.Tensor: """ @@ -91,6 +94,8 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues return x + + class Block(nn.Module): """ Transformer block class. @@ -121,12 +126,48 @@ def __init__(self, config: TransformerConfig) -> None: self.ln1 = nn.LayerNorm(config.embed_dim) self.ln2 = nn.LayerNorm(config.embed_dim) self.attn = SelfAttention(config) - self.mlp = nn.Sequential( - nn.Linear(config.embed_dim, 4 * config.embed_dim), - nn.GELU(approximate='tanh'), - nn.Linear(4 * config.embed_dim, config.embed_dim), - nn.Dropout(config.resid_pdrop), - ) + if config.moe_in_transformer: + # 创Create multiple independent MLP instances + self.experts = nn.ModuleList([ + nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) for _ in range(config.num_experts_of_moe_in_transformer) + ]) + + self.feed_forward = MoeLayer( + experts=self.experts, + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + + print("="*20) + print(f'use moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') + print("="*20) + elif config.multiplication_moe_in_transformer: + # Create multiple FeedForward instances for multiplication-based MoE + self.experts = nn.ModuleList([ + MultiplicationFeedForward(config) for _ in range(config.num_experts_of_moe_in_transformer) + ]) + + self.feed_forward = MoeLayer( + experts=self.experts, + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + + print("="*20) + print(f'use multiplication moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') + print("="*20) + else: + self.feed_forward = nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: @@ -144,10 +185,10 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths) if self.gru_gating: x = self.gate1(x, x_attn) - x = self.gate2(x, self.mlp(self.ln2(x))) + x = self.gate2(x, self.feed_forward(self.ln2(x))) else: x = x + x_attn - x = x + self.mlp(self.ln2(x)) + x = x + self.feed_forward(self.ln2(x)) return x @@ -188,6 +229,7 @@ def __init__(self, config: TransformerConfig) -> None: causal_mask = torch.tril(torch.ones(config.max_tokens, config.max_tokens)) self.register_buffer('mask', causal_mask) + #@profile def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: """ @@ -205,7 +247,10 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, B, T, C = x.size() if kv_cache is not None: b, nh, L, c = kv_cache.shape - assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions do not match input dimensions." + try: + assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions do not match input dimensions." 
+ except Exception as e: + print('debug') else: L = 0 diff --git a/lzero/model/unizero_world_models/utils.py b/lzero/model/unizero_world_models/utils.py index 99c841cbe..0a0c9dd51 100644 --- a/lzero/model/unizero_world_models/utils.py +++ b/lzero/model/unizero_world_models/utils.py @@ -215,8 +215,14 @@ def init_weights(module, norm_type='BN'): module.bias.data.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): print(f"Init {module} using zero bias, 1 weight") - module.bias.data.zero_() - module.weight.data.fill_(1.0) + try: + module.bias.data.zero_() + except Exception as e: + print(e) + try: + module.weight.data.fill_(1.0) + except Exception as e: + print(e) elif isinstance(module, nn.BatchNorm2d): print(f"Init nn.BatchNorm2d using zero bias, 1 weight") module.weight.data.fill_(1.0) @@ -294,7 +300,7 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, continu self.loss_total += self.perceptual_loss_weight * v self.intermediate_losses = { - k: v if isinstance(v, dict) or isinstance(v, torch.Tensor) else (v if isinstance(v, float) else v.item()) + k: v if isinstance(v, dict) or isinstance(v, np.ndarray) or isinstance(v, torch.Tensor) else (v if isinstance(v, float) else v.item()) for k, v in kwargs.items() } diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index 37d4cd3ec..481054972 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -46,7 +46,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.tokenizer = tokenizer self.config = config self.transformer = Transformer(self.config) - + self.task_num = 1 if self.config.device == 'cpu': self.device = torch.device('cpu') else: @@ -392,6 +392,7 @@ def precompute_pos_emb_diff_kv(self): self.pos_emb_diff_k.append(layer_pos_emb_diff_k) self.pos_emb_diff_v.append(layer_pos_emb_diff_v) + #@profile def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: """ Helper function to get positional embedding for a given layer and attention type. @@ -413,6 +414,7 @@ def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads ).transpose(1, 2).detach() + #@profile def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tuple]], past_keys_values: Optional[torch.Tensor] = None, kvcache_independent: bool = False, is_init_infer: bool = True, @@ -484,6 +486,7 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu # logits_ends is None return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + #@profile def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths): """ @@ -512,6 +515,7 @@ def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_in valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1) return embeddings + position_embeddings + #@profile def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_steps): """ Process combined observation embeddings and action tokens. 
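The `MoeLayer` introduced in `lzero/model/unizero_world_models/moe.py` above, and wired into the transformer `Block` as an optional replacement for the dense `mlp`, routes each token through its top-k experts: a linear gate scores all experts per token, `torch.topk` keeps the best k, the kept scores are re-normalized, and the token's output is the weighted sum of the selected experts' outputs. The snippet below is a minimal, self-contained sketch of that routing pattern; the class name `TinyMoE`, the per-token softmax over the selected gate logits, and all dimensions are illustrative rather than the library's exact code.

```python
# Minimal sketch of token-level top-k expert routing (illustrative names only).
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyMoE(nn.Module):
    def __init__(self, embed_dim: int, num_experts: int, num_experts_per_tok: int = 1):
        super().__init__()
        self.experts = nn.ModuleList(
            [nn.Linear(embed_dim, embed_dim) for _ in range(num_experts)]
        )
        self.gate = nn.Linear(embed_dim, num_experts, bias=False)
        self.num_experts_per_tok = num_experts_per_tok

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, embed_dim)
        gate_logits = self.gate(x)                                              # (B, T, E)
        weights, selected = torch.topk(gate_logits, self.num_experts_per_tok)   # (B, T, k)
        # Re-normalize the k selected gate scores per token.
        weights = F.softmax(weights, dim=-1, dtype=torch.float).to(x.dtype)
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            # All (batch, token) positions routed to expert i, and the slot they occupy.
            batch_idx, token_idx, slot_idx = torch.where(selected == i)
            out[batch_idx, token_idx] += (
                weights[batch_idx, token_idx, slot_idx].unsqueeze(-1)
                * expert(x[batch_idx, token_idx])
            )
        return out


if __name__ == "__main__":
    moe = TinyMoE(embed_dim=8, num_experts=4, num_experts_per_tok=1)
    tokens = torch.randn(2, 5, 8)
    print(moe(tokens).shape)  # torch.Size([2, 5, 8])
```

With `num_experts_per_tok=1`, the value used in `transformer.py` above, every token is processed by exactly one expert, so per-token compute stays close to a single dense feed-forward even as the number of experts grows.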
@@ -548,6 +552,7 @@ def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_step return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + #@profile def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): """ Process combined observation embeddings and action tokens. @@ -577,6 +582,7 @@ def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + #@profile def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths): """ Pass sequences through the transformer. @@ -597,6 +603,7 @@ def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, va else: return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths) + #@profile @torch.no_grad() def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: """ @@ -631,6 +638,7 @@ def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor) -> torch. return outputs_wm, self.latent_state + #@profile @torch.no_grad() def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTensor, batch_action=None, @@ -754,6 +762,7 @@ def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTens return outputs_wm + #@profile @torch.no_grad() def forward_initial_inference(self, obs_act_dict): """ @@ -771,6 +780,7 @@ def forward_initial_inference(self, obs_act_dict): return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value) + #@profile @torch.no_grad() def forward_recurrent_inference(self, state_action_history, simulation_index=0, latent_state_index_in_search_path=[]): @@ -856,6 +866,7 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0, return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) + #@profile def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: """ Adjusts the key-value cache for each environment to ensure they all have the same size. 
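The multitask models in this diff (`MuZeroMTModel`, `UniZeroMTModel`, and the `WorldModelMT` defined below) share one structural idea: the backbone weights (representation, dynamics, transformer) are shared across tasks, while each task owns its prediction heads, stored in an `nn.ModuleList` and selected by `task_id` at forward time (`prediction_networks[task_id]`, `head_policy_multi_task[task_id]`, and so on). The sketch below illustrates that dispatch pattern; the class and attribute names are hypothetical and not part of the library.

```python
# Minimal sketch of the shared-backbone / per-task-head pattern (illustrative names only).
import torch
import torch.nn as nn


class SharedBackbonePerTaskHeads(nn.Module):
    def __init__(self, embed_dim: int, action_space_size: int, task_num: int):
        super().__init__()
        # Shared across all tasks (stands in for representation/dynamics/transformer).
        self.backbone = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(approximate='tanh'),
        )
        # One policy head and one value head per task, indexed by task_id.
        self.policy_heads = nn.ModuleList(
            [nn.Linear(embed_dim, action_space_size) for _ in range(task_num)]
        )
        self.value_heads = nn.ModuleList(
            [nn.Linear(embed_dim, 1) for _ in range(task_num)]
        )

    def forward(self, latent: torch.Tensor, task_id: int = 0):
        x = self.backbone(latent)                      # shared computation
        policy_logits = self.policy_heads[task_id](x)  # task-specific outputs
        value = self.value_heads[task_id](x)
        return policy_logits, value


if __name__ == "__main__":
    model = SharedBackbonePerTaskHeads(embed_dim=16, action_space_size=18, task_num=3)
    latent = torch.randn(4, 16)
    policy_logits, value = model(latent, task_id=2)
    print(policy_logits.shape, value.shape)  # torch.Size([4, 18]) torch.Size([4, 1])
```

Only the selected head participates in a given forward pass, so each task's gradients update the shared backbone plus its own head; the diff fixes `action_space_size = 18` (the full Atari action set) so that every task-specific policy head has the same output width.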
@@ -908,6 +919,7 @@ def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: return self.keys_values_wm_size_list + #@profile def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, latent_state_index_in_search_path=[], valid_context_lengths=None): """ @@ -1049,6 +1061,7 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde self.past_kv_cache_recurrent_infer[cache_key] = cache_index + #@profile def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, simulation_index: int = 0) -> list: """ @@ -1509,6 +1522,8 @@ def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[torch.Tenso # KL as projector target_log_prob_sampled_actions = torch.log(target_normalized_visit_count + 1e-6) + + # KL as projector policy_loss = -torch.sum( torch.exp(target_log_prob_sampled_actions.detach()) * log_prob_sampled_actions, 1 ) * mask_batch @@ -1558,6 +1573,7 @@ def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): return loss + #@profile def compute_policy_entropy_loss(self, logits, mask): # Compute entropy of the policy probs = torch.softmax(logits, dim=1) @@ -1567,6 +1583,7 @@ def compute_policy_entropy_loss(self, logits, mask): entropy_loss = (entropy * mask) return entropy_loss + #@profile def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag @@ -1586,6 +1603,8 @@ def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torc return labels_observations, labels_rewards.view(-1, self.support_size), None + + #@profile def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute labels for value and policy predictions. 
""" diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py new file mode 100644 index 000000000..2f3532cc6 --- /dev/null +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -0,0 +1,1842 @@ +import collections +import logging +from typing import Any, Tuple +from typing import Optional +from typing import Union, Dict + +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from lzero.model.common import SimNorm +from lzero.model.unizero_world_models.world_model import WorldModel +from lzero.model.utils import cal_dormant_ratio +from .moe import MoeLayer, MultiplicationFeedForward +from .slicer import Head +from .tokenizer import Tokenizer +from .transformer import Transformer, TransformerConfig +from .utils import LossWithIntermediateLosses, init_weights +from .utils import WorldModelOutput, hash_state + +logging.getLogger().setLevel(logging.DEBUG) +from ding.utils import get_rank +import torch.distributed as dist +from sklearn.manifold import TSNE +import os +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.patches import Patch +from matplotlib.offsetbox import OffsetImage, AnnotationBbox +import torch + + +class WorldModelMT(WorldModel): + """ + Overview: + The WorldModel class is responsible for the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), + which is used to predict the next latent state, rewards, policy, and value based on the current latent state and action. + The world model consists of three main components: + - a tokenizer, which encodes observations into embeddings, + - a transformer, which processes the input sequences, + - and heads, which generate the logits for observations, rewards, policy, and value. + """ + + #@profile + def __init__(self, config: TransformerConfig, tokenizer) -> None: + """ + Overview: + Initialize the WorldModel class. + Arguments: + - config (:obj:`TransformerConfig`): The configuration for the transformer. + - tokenizer (:obj:`Tokenizer`): The tokenizer. 
+ """ + super().__init__(config, tokenizer) + self.tokenizer = tokenizer + self.config = config + self.transformer = Transformer(self.config) + + # TODO ======== + self.analysis_tsne = self.config.get('analysis_tsne', False) + + if self.analysis_tsne: + self.env_id_list = self.config.env_id_list + # 自动生成 self.env_short_names + self.env_short_names = {} + + # 遍历 env_id_list,提取短名称 + for env_id in self.config.env_id_list: + # 提取 'NoFrameskip-v4' 之前的部分作为短名称 + short_name = env_id.replace('NoFrameskip-v4', '') + self.env_short_names[env_id] = short_name + # 映射环境 ID 到简写名称 + # self.env_short_names = { + # 'PongNoFrameskip-v4': 'Pong', + # 'MsPacmanNoFrameskip-v4': 'MsPacman', + # 'SeaquestNoFrameskip-v4': 'Seaquest', + # 'BoxingNoFrameskip-v4': 'Boxing', + # 'AlienNoFrameskip-v4': 'Alien', + # 'ChopperCommandNoFrameskip-v4': 'Chopper', + # 'HeroNoFrameskip-v4': 'Hero', + # 'RoadRunnerNoFrameskip-v4': 'RoadRunner' + # } + # 颜色映射,确保每个任务有固定的颜色 + self.num_tasks = len(self.env_id_list) + + # 生成足够多的颜色 + self.colors = self._generate_colors(len(self.env_id_list)) + + + # TODO: multitask + self.task_num = config.task_num + self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1) # TODO + self.head_policy_multi_task = nn.ModuleList() + self.head_value_multi_task = nn.ModuleList() + self.head_rewards_multi_task = nn.ModuleList() + self.head_observations_multi_task = nn.ModuleList() + + self.num_experts_in_moe_head = config.num_experts_in_moe_head + self.use_normal_head = config.use_normal_head + self.use_moe_head = config.use_moe_head + self.use_softmoe_head = config.use_softmoe_head + + if self.config.device == 'cpu': + self.device = torch.device('cpu') + else: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # Move all modules to the specified device + print(f"self.device: {self.device}") + self.to(self.device) + + # Initialize configuration parameters + self._initialize_config_parameters() + + # Initialize patterns for block masks + self._initialize_patterns() + + self.hidden_size = config.embed_dim // config.num_heads + + # Position embedding + self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device) + self.precompute_pos_emb_diff_kv() + print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") + self.continuous_action_space = self.config.continuous_action_space + + # Initialize action embedding table + if self.continuous_action_space: + # TODO: check the effect of SimNorm + self.act_embedding_table = nn.Sequential( + nn.Linear(config.action_space_size, config.embed_dim, device=self.device, bias=False), + SimNorm(simnorm_dim=self.group_size)) + else: + # for discrete action space + self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device) + print(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") + + # if self.num_experts_in_moe_head == -1: + assert self.num_experts_in_moe_head > 0 + if self.use_normal_head: + print('We use normal head') + # TODO: Normal Head + for task_id in range(self.task_num): # TODO + self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) + self.head_policy_multi_task.append(self.head_policy) + + self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) + self.head_value_multi_task.append(self.head_value) + + self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) + self.head_rewards_multi_task.append(self.head_rewards) + + 
self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, + self.obs_per_embdding_dim, + self.sim_norm) # NOTE: we add a sim_norm to the head for observations + self.head_observations_multi_task.append(self.head_observations) + elif self.use_softmoe_head: + print(f'We use softmoe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}') + # Dictionary to store SoftMoE instances + self.soft_moe_instances = {} + + # Create softmoe head modules + self.create_head_modules_softmoe() + + self.head_policy_multi_task.append(self.head_policy) + self.head_value_multi_task.append(self.head_value) + self.head_rewards_multi_task.append(self.head_rewards) + self.head_observations_multi_task.append(self.head_observations) + elif self.use_moe_head: + print(f'We use moe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}') + # Dictionary to store moe instances + self.moe_instances = {} + + # Create moe head modules + self.create_head_modules_moe() + + self.head_policy_multi_task.append(self.head_policy) + self.head_value_multi_task.append(self.head_value) + self.head_rewards_multi_task.append(self.head_rewards) + self.head_observations_multi_task.append(self.head_observations) + + + # Apply weight initialization, the order is important + self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type)) + self._initialize_last_layer() + + # Cache structures + self._initialize_cache_structures() + + # Projection input dimension + self._initialize_projection_input_dim() + + # Hit count and query count statistics + self._initialize_statistics() + + # Initialize keys and values for transformer + self._initialize_transformer_keys_values() + + self.latent_recon_loss = torch.tensor(0., device=self.device) + self.perceptual_loss = torch.tensor(0., device=self.device) + + # TODO: check the size of the shared pool + # for self.kv_cache_recurrent_infer + # If needed, recurrent_infer should store the results of the one MCTS search. + self.shared_pool_size = int(50*self.env_num) + self.shared_pool_recur_infer = [None] * self.shared_pool_size + self.shared_pool_index = 0 + + # for self.kv_cache_init_infer + # In contrast, init_infer only needs to retain the results of the most recent step. + # self.shared_pool_size_init = int(2*self.env_num) + self.shared_pool_size_init = int(2) # NOTE: Will having too many cause incorrect retrieval of the kv cache? 
+ self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)] + + # for self.kv_cache_wm + self.shared_pool_size_wm = int(self.env_num) + self.shared_pool_wm = [None] * self.shared_pool_size_wm + self.shared_pool_index_wm = 0 + + self.reanalyze_phase = False + self._rank = get_rank() + + def _generate_colors(self, num_colors): + """ + 生成足够多的独特颜色,适用于大量分类。 + + 参数: + - num_colors: 所需颜色数量。 + + 返回: + - colors: 颜色列表。 + """ + # 使用多个matplotlib离散色图拼接 + color_maps = ['tab20', 'tab20b', 'tab20c'] + colors = [] + for cmap_name in color_maps: + cmap = plt.get_cmap(cmap_name) + colors.extend([cmap(i) for i in range(cmap.N)]) + if len(colors) >= num_colors: + break + if len(colors) < num_colors: + # 生成额外的颜色,如果需要 + additional_colors = plt.cm.get_cmap('hsv', num_colors - len(colors)) + colors.extend([additional_colors(i) for i in range(num_colors - len(colors))]) + return colors[:num_colors] + + def _initialize_config_parameters(self) -> None: + """Initialize configuration parameters.""" + self.policy_entropy_weight = self.config.policy_entropy_weight + self.predict_latent_loss_type = self.config.predict_latent_loss_type + self.group_size = self.config.group_size + self.num_groups = self.config.embed_dim // self.group_size + self.obs_type = self.config.obs_type + self.embed_dim = self.config.embed_dim + self.num_heads = self.config.num_heads + self.gamma = self.config.gamma + self.context_length = self.config.context_length + self.dormant_threshold = self.config.dormant_threshold + self.analysis_dormant_ratio = self.config.analysis_dormant_ratio + self.num_observations_tokens = self.config.tokens_per_block - 1 + self.latent_recon_loss_weight = self.config.latent_recon_loss_weight + self.perceptual_loss_weight = self.config.perceptual_loss_weight + self.device = self.config.device + self.support_size = self.config.support_size + self.action_space_size = self.config.action_space_size + self.max_cache_size = self.config.max_cache_size + self.env_num = self.config.env_num + self.num_layers = self.config.num_layers + self.obs_per_embdding_dim = self.config.embed_dim + self.sim_norm = SimNorm(simnorm_dim=self.group_size) + + def _initialize_patterns(self) -> None: + """Initialize patterns for block masks.""" + self.all_but_last_latent_state_pattern = torch.ones(self.config.tokens_per_block) + self.all_but_last_latent_state_pattern[-2] = 0 + self.act_tokens_pattern = torch.zeros(self.config.tokens_per_block) + self.act_tokens_pattern[-1] = 1 + self.value_policy_tokens_pattern = torch.zeros(self.config.tokens_per_block) + self.value_policy_tokens_pattern[-2] = 1 + + def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: + """Create head modules for the transformer.""" + modules = [ + nn.Linear(self.config.embed_dim, self.config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def _create_head_moe(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None, moe=None) -> Head: + """Create moe head modules for the transformer.""" + modules = [ + moe, + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + 
) + def get_moe(self, name): + """Get or create a MoE instance""" + if name not in self.moe_instances: + # Create multiple FeedForward instances for multiplication-based MoE + self.experts = nn.ModuleList([ + MultiplicationFeedForward(self.config) for _ in range(self.config.num_experts_of_moe_in_transformer) + ]) + + self.moe_instances[name] = MoeLayer( + experts=self.experts, + gate=nn.Linear(self.config.embed_dim, self.config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + + return self.moe_instances[name] + + def create_head_modules_moe(self): + """Create all softmoe head modules""" + # Rewards head + self.head_rewards = self._create_head_moe( + self.act_tokens_pattern, + self.support_size, + moe=self.get_moe("rewards_moe") + ) + + # Observations head + self.head_observations = self._create_head_moe( + self.all_but_last_latent_state_pattern, + self.obs_per_embdding_dim, + norm_layer=self.sim_norm, # NOTE + moe=self.get_moe("observations_moe") + ) + + # Policy head + self.head_policy = self._create_head_moe( + self.value_policy_tokens_pattern, + self.action_space_size, + moe=self.get_moe("policy_moe") + ) + + # Value head + self.head_value = self._create_head_moe( + self.value_policy_tokens_pattern, + self.support_size, + moe=self.get_moe("value_moe") + ) + + def _create_head_softmoe(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None, soft_moe=None) -> Head: + """Create softmoe head modules for the transformer.""" + modules = [ + soft_moe, + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def get_soft_moe(self, name): + """Get or create a SoftMoE instance""" + # from soft_moe_pytorch import SoftMoE + # if name not in self.soft_moe_instances: + # self.soft_moe_instances[name] = SoftMoE( + # dim=self.embed_dim, + # seq_len=20, # TODO + # num_experts=self.num_experts_in_moe_head, + # ) + from soft_moe_pytorch import DynamicSlotsSoftMoE as SoftMoE + if name not in self.soft_moe_instances: + self.soft_moe_instances[name] = SoftMoE( + dim=self.embed_dim, + num_experts=self.num_experts_in_moe_head, + geglu = True + ) + return self.soft_moe_instances[name] + + def create_head_modules_softmoe(self): + """Create all softmoe head modules""" + # Rewards head + self.head_rewards = self._create_head_softmoe( + self.act_tokens_pattern, + self.support_size, + soft_moe=self.get_soft_moe("rewards_soft_moe") + ) + + # Observations head + self.head_observations = self._create_head_softmoe( + self.all_but_last_latent_state_pattern, + self.obs_per_embdding_dim, + norm_layer=self.sim_norm, # NOTE + soft_moe=self.get_soft_moe("observations_soft_moe") + ) + + # Policy head + self.head_policy = self._create_head_softmoe( + self.value_policy_tokens_pattern, + self.action_space_size, + soft_moe=self.get_soft_moe("policy_soft_moe") + ) + + # Value head + self.head_value = self._create_head_softmoe( + self.value_policy_tokens_pattern, + self.support_size, + soft_moe=self.get_soft_moe("value_soft_moe") + ) + + def _initialize_last_layer(self) -> None: + """Initialize the last linear layer.""" + last_linear_layer_init_zero = True + if last_linear_layer_init_zero: + # TODO: multitask + if self.task_num == 1: + for head in [self.head_policy, self.head_value, self.head_rewards, self.head_observations]: + for layer in reversed(head.head_module): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + 
if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + elif self.task_num > 1: + for head in self.head_policy_multi_task + self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task: + for layer in reversed(head.head_module): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + + def _initialize_cache_structures(self) -> None: + """Initialize cache structures for past keys and values.""" + self.past_kv_cache_recurrent_infer = collections.OrderedDict() + self.past_kv_cache_init_infer = collections.OrderedDict() + self.past_kv_cache_init_infer_envs = [collections.OrderedDict() for _ in range(self.env_num)] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + + def _initialize_projection_input_dim(self) -> None: + """Initialize the projection input dimension based on the number of observation tokens.""" + if self.num_observations_tokens == 16: + self.projection_input_dim = 128 + elif self.num_observations_tokens == 1: + self.projection_input_dim = self.obs_per_embdding_dim + + def _initialize_statistics(self) -> None: + """Initialize counters for hit count and query count statistics.""" + self.hit_count = 0 + self.total_query_count = 0 + self.length_largethan_maxminus5_context_cnt = 0 + self.length_largethan_maxminus7_context_cnt = 0 + self.root_hit_cnt = 0 + self.root_total_query_cnt = 0 + + #@profile + def _initialize_transformer_keys_values(self) -> None: + """Initialize keys and values for the transformer.""" + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, + max_tokens=self.context_length) + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, + max_tokens=self.context_length) + + #@profile + def precompute_pos_emb_diff_kv(self): + """ Precompute positional embedding differences for key and value. """ + if self.context_length <= 2: + # If context length is 2 or less, no context is present + return + + # Precompute positional embedding matrices for inference in collect/eval stages, not for training + self.positional_embedding_k = [ + self._get_positional_embedding(layer, 'key') + for layer in range(self.config.num_layers) + ] + self.positional_embedding_v = [ + self._get_positional_embedding(layer, 'value') + for layer in range(self.config.num_layers) + ] + + # Precompute all possible positional embedding differences + self.pos_emb_diff_k = [] + self.pos_emb_diff_v = [] + + for layer in range(self.config.num_layers): + layer_pos_emb_diff_k = {} + layer_pos_emb_diff_v = {} + + for start in [2]: + for end in [self.context_length - 1]: + original_pos_emb_k = self.positional_embedding_k[layer][:, :, start:end, :] + new_pos_emb_k = self.positional_embedding_k[layer][:, :, :end - start, :] + layer_pos_emb_diff_k[(start, end)] = new_pos_emb_k - original_pos_emb_k + + original_pos_emb_v = self.positional_embedding_v[layer][:, :, start:end, :] + new_pos_emb_v = self.positional_embedding_v[layer][:, :, :end - start, :] + layer_pos_emb_diff_v[(start, end)] = new_pos_emb_v - original_pos_emb_v + + self.pos_emb_diff_k.append(layer_pos_emb_diff_k) + self.pos_emb_diff_v.append(layer_pos_emb_diff_v) + + #@profile + def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: + """ + Helper function to get positional embedding for a given layer and attention type. + + Arguments: + - layer (:obj:`int`): Layer index. + - attn_type (:obj:`str`): Attention type, either 'key' or 'value'. 
+ + Returns: + - torch.Tensor: The positional embedding tensor. + """ + attn_func = getattr(self.transformer.blocks[layer].attn, attn_type) + if torch.cuda.is_available(): + return attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2).to(self.device).detach() + else: + return attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2).detach() + + #@profile + def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tuple]], + past_keys_values: Optional[torch.Tensor] = None, + kvcache_independent: bool = False, is_init_infer: bool = True, + valid_context_lengths: Optional[torch.Tensor] = None, task_id=0) -> WorldModelOutput: + """ + Forward pass for the model. + + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing observation embeddings or action tokens. + - past_keys_values (:obj:`Optional[torch.Tensor]`): Previous keys and values for transformer. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - is_init_infer (:obj:`bool`): Initialize inference. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths. + Returns: + - WorldModelOutput: Model output containing logits for observations, rewards, policy, and value. + """ + # task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) # NOTE: TODO + self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device) # ============= TODO: no task_embeddings now ============= + + # Determine previous steps based on key-value caching method + if kvcache_independent: + prev_steps = torch.tensor([0 if past_keys_values is None else past_kv.size for past_kv in past_keys_values], + device=self.device) + else: + prev_steps = 0 if past_keys_values is None else past_keys_values.size + + # Reset valid_context_lengths during initial inference + if is_init_infer: + valid_context_lengths = None + + # Process observation embeddings + if 'obs_embeddings' in obs_embeddings_or_act_tokens: + obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] + if len(obs_embeddings.shape) == 2: + obs_embeddings = obs_embeddings.unsqueeze(1) + num_steps = obs_embeddings.size(1) + sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent, + is_init_infer, valid_context_lengths) + # TODO: multitask + sequences = sequences + self.task_embeddings + + # Process action tokens + elif 'act_tokens' in obs_embeddings_or_act_tokens: + act_tokens = obs_embeddings_or_act_tokens['act_tokens'] + if len(act_tokens.shape) == 3: + act_tokens = act_tokens.squeeze(1) + num_steps = act_tokens.size(1) + act_embeddings = self.act_embedding_table(act_tokens) + sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent, + is_init_infer, valid_context_lengths) + + # TODO: multitask + # TODO: 对于action_token不需要增加task_embeddings会造成歧义,反而干扰学习 + self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device) + sequences = sequences + self.task_embeddings + + # Process combined observation embeddings and action tokens + else: + sequences, num_steps = self._process_obs_act_combined(obs_embeddings_or_act_tokens, prev_steps) + + # Pass sequences through transformer + x = self._transformer_pass(sequences, past_keys_values, kvcache_independent, valid_context_lengths) + + # Generate logits + + # 1,...,0,1 
https://github.com/eloialonso/iris/issues/19 + # TODO: one head or moe head + if self.use_moe_head: + logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) + else: + # TODO: in total N head, one head per task + logits_observations = self.head_observations_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = self.head_rewards_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + + # logits_ends is None + return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + + + #@profile + def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, + valid_context_lengths): + """ + Add position embeddings to the input embeddings. + + Arguments: + - embeddings (:obj:`torch.Tensor`): Input embeddings. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + - num_steps (:obj:`int`): Number of steps. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - is_init_infer (:obj:`bool`): Initialize inference. + - valid_context_lengths (:obj:`torch.Tensor`): Valid context lengths. + Returns: + - torch.Tensor: Embeddings with position information added. + """ + if kvcache_independent: + steps_indices = prev_steps + torch.arange(num_steps, device=embeddings.device) + position_embeddings = self.pos_emb(steps_indices).view(-1, num_steps, embeddings.shape[-1]) + return embeddings + position_embeddings + else: + if is_init_infer: + return embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) + else: + valid_context_lengths = torch.tensor(self.keys_values_wm_size_list_current, device=self.device) + position_embeddings = self.pos_emb( + valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1) + return embeddings + position_embeddings + + #@profile + def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): + """ + Process combined observation embeddings and action tokens. + + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + Returns: + - torch.Tensor: Combined observation and action embeddings with position information added. 
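In the non-MoE branch of forward above, each prediction is produced by one head per task (head_policy_multi_task[task_id], etc.), so only the selected head sees gradients for a given batch while the transformer trunk stays shared. A minimal sketch of that "N heads, one per task" pattern (the names and sizes below are illustrative, not the module's real attributes):

import torch
import torch.nn as nn

class PerTaskHeads(nn.Module):
    """One independent linear head per task on top of a shared trunk."""

    def __init__(self, embed_dim: int, output_dim: int, task_num: int):
        super().__init__()
        self.heads = nn.ModuleList([nn.Linear(embed_dim, output_dim) for _ in range(task_num)])

    def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        # x: (batch, seq, embed_dim) produced by the shared transformer
        return self.heads[task_id](x)

# usage sketch: heads = PerTaskHeads(embed_dim=768, output_dim=601, task_num=8)
#               logits = heads(torch.randn(4, 10, 768), task_id=3)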
+ """ + obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + if len(obs_embeddings.shape) == 3: + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, + -1) + + num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) + act_embeddings = self.act_embedding_table(act_tokens) + + B, L, K, E = obs_embeddings.size() + obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device) + + for i in range(L): + # obs = obs_embeddings[:, i, :, :] + obs = obs_embeddings[:, i, :, :] + self.task_embeddings # Shape: (B, K, E) TODO: task_embeddings + act = act_embeddings[:, i, 0, :].unsqueeze(1) + obs_act = torch.cat([obs, act], dim=1) + obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + + #@profile + def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths): + """ + Pass sequences through the transformer. + + Arguments: + - sequences (:obj:`torch.Tensor`): Input sequences. + - past_keys_values (:obj:`Optional[torch.Tensor]`): Previous keys and values for transformer. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - valid_context_lengths (:obj:`torch.Tensor`): Valid context lengths. + Returns: + - torch.Tensor: Transformer output. + """ + if kvcache_independent: + x = [self.transformer(sequences[k].unsqueeze(0), past_kv, + valid_context_lengths=valid_context_lengths[k].unsqueeze(0)) for k, past_kv in + enumerate(past_keys_values)] + return torch.cat(x, dim=0) + else: + return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths) + + #@profile + @torch.no_grad() + def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, task_id = 0) -> torch.FloatTensor: + """ + Reset the model state based on initial observations and actions. + + Arguments: + - obs_act_dict (:obj:`torch.FloatTensor`): A dictionary containing 'obs', 'action', and 'current_obs'. + Returns: + - torch.FloatTensor: The outputs from the world model and the latent state. + """ + # Extract observations, actions, and current observations from the dictionary. + if isinstance(obs_act_dict, dict): + batch_obs = obs_act_dict['obs'] + batch_action = obs_act_dict['action'] + batch_current_obs = obs_act_dict['current_obs'] + + # Encode observations to latent embeddings. 
+ obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_obs, task_id=task_id) + + if batch_current_obs is not None: + # ================ Collect and Evaluation Phase ================ + # Encode current observations to latent embeddings + current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_current_obs, task_id=task_id) + # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") + self.latent_state = current_obs_embeddings + outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, batch_action, current_obs_embeddings, task_id=task_id) + else: + # ================ calculate the target value in Train phase ================ + self.latent_state = obs_embeddings + outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, batch_action, None, task_id=task_id) + + return outputs_wm, self.latent_state + + + #@profile + @torch.no_grad() + def wm_forward_for_initial_inference(self, last_obs_embeddings: torch.LongTensor, + batch_action=None, + current_obs_embeddings=None, task_id = 0) -> torch.FloatTensor: + """ + Refresh key-value pairs with the initial latent state for inference. + + Arguments: + - latent_state (:obj:`torch.LongTensor`): The latent state embeddings. + - batch_action (optional): Actions taken. + - current_obs_embeddings (optional): Current observation embeddings. + Returns: + - torch.FloatTensor: The outputs from the world model. + """ + n, num_observations_tokens, _ = last_obs_embeddings.shape + if n <= self.env_num and current_obs_embeddings is not None: + # ================ Collect and Evaluation Phase ================ + if current_obs_embeddings is not None: + if self.continuous_action_space: + first_step_flag = not isinstance(batch_action[0], np.ndarray) + else: + first_step_flag = max(batch_action) == -1 + if first_step_flag: + # First step in an episode + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=current_obs_embeddings.shape[0], + max_tokens=self.context_length) + # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, task_id=task_id) + + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + else: + # Assume latest_state is the new latent_state, containing information from ready_env_num environments + ready_env_num = current_obs_embeddings.shape[0] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + for i in range(ready_env_num): + # Retrieve latent state for a single environment + state_single_env = last_obs_embeddings[i] + # Compute hash value using latent state for a single environment + cache_key = hash_state( + state_single_env.view(-1).cpu().numpy()) # last_obs_embeddings[i] is torch.Tensor + + # Retrieve cached value + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None + + self.root_total_query_cnt += 1 + if matched_value is not None: + # If a matching value is found, add it to the list + self.root_hit_cnt += 1 + # deepcopy is needed because forward modifies matched_value in place + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # Reset using zero values + self.keys_values_wm_single_env = 
self.transformer.generate_empty_keys_values(n=1, max_tokens=self.context_length) + outputs_wm = self.forward({'obs_embeddings': state_single_env.unsqueeze(0)}, + past_keys_values=self.keys_values_wm_single_env, + is_init_infer=True, task_id=task_id) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + # Input self.keys_values_wm_list, output self.keys_values_wm + self.keys_values_wm_size_list_current = self.trim_and_pad_kv_cache(is_init_infer=True) + + batch_action = batch_action[:ready_env_num] + # if ready_env_num < self.env_num: + # print(f'init inference ready_env_num: {ready_env_num} < env_num: {self.env_num}') + act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(-1) + outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm, + is_init_infer=True, task_id=task_id) + + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, task_id=task_id) + + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + + # elif n > self.env_num and batch_action is not None and current_obs_embeddings is None: + elif batch_action is not None and current_obs_embeddings is None: + # ================ calculate the target value in Train phase ================ + # [192, 16, 64] -> [32, 6, 16, 64] + last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens, + self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 + + last_obs_embeddings = last_obs_embeddings[:, :-1, :] + batch_action = torch.from_numpy(batch_action).to(last_obs_embeddings.device) + act_tokens = rearrange(batch_action, 'b l -> b l 1') + + # select the last timestep for each sample + # This will select the last column while keeping the dimensions unchanged, and the target policy/value in the final step itself is not used. + last_steps_act = act_tokens[:, -1:, :] + act_tokens = torch.cat((act_tokens, last_steps_act), dim=1) + + outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, task_id=task_id) + + # select the last timestep for each sample + last_steps_value = outputs_wm.logits_value[:, -1:, :] + outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) + + last_steps_policy = outputs_wm.logits_policy[:, -1:, :] + outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) + + # Reshape your tensors + # outputs_wm.logits_value.shape (B, H, 101) = (B*H, 101) + outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') + outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') + + return outputs_wm + + + #@profile + @torch.no_grad() + def forward_initial_inference(self, obs_act_dict, task_id = 0): + """ + Perform initial inference based on the given observation-action dictionary. + + Arguments: + - obs_act_dict (:obj:`dict`): Dictionary containing observations and actions. + Returns: + - tuple: A tuple containing output sequence, latent state, logits rewards, logits policy, and logits value. 
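In the training branch above, the last action token and later the last value/policy logits are duplicated with torch.cat((x, x[:, -1:, ...]), dim=1) so that the unrolled outputs keep the same time length as the targets; the code notes that the duplicated final step itself is not used as a learning target. A tiny sketch of the pattern, including the final flatten used before computing the loss:

import torch
from einops import rearrange

logits_value = torch.randn(2, 5, 101)                                      # (B, H, support_size)
logits_value = torch.cat((logits_value, logits_value[:, -1:, :]), dim=1)   # (B, H + 1, support_size)
flat = rearrange(logits_value, 'b t e -> (b t) e')                         # (B * (H + 1), support_size)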
+ """ + # UniZero has context in the root node + outputs_wm, latent_state = self.reset_for_initial_inference(obs_act_dict, task_id=task_id) + self.past_kv_cache_recurrent_infer.clear() + + return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, + outputs_wm.logits_policy, outputs_wm.logits_value) + + #@profile + @torch.no_grad() + def forward_recurrent_inference(self, state_action_history, simulation_index=0, + latent_state_index_in_search_path=[], task_id = 0): + """ + Perform recurrent inference based on the state-action history. + + Arguments: + - state_action_history (:obj:`list`): List containing tuples of state and action history. + - simulation_index (:obj:`int`, optional): Index of the current simulation. Defaults to 0. + - latent_state_index_in_search_path (:obj:`list`, optional): List containing indices of latent states in the search path. Defaults to []. + Returns: + - tuple: A tuple containing output sequence, updated latent state, reward, logits policy, and logits value. + """ + latest_state, action = state_action_history[-1] + ready_env_num = latest_state.shape[0] + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + self.keys_values_wm_size_list = self.retrieve_or_generate_kvcache(latest_state, ready_env_num, simulation_index, task_id=task_id) + + latent_state_list = [] + token = action.reshape(-1, 1) + + # ======= Print statistics for debugging ============= + # min_size = min(self.keys_values_wm_size_list) + # if min_size >= self.config.max_tokens - 5: + # self.length_largethan_maxminus5_context_cnt += len(self.keys_values_wm_size_list) + # if min_size >= self.config.max_tokens - 7: + # self.length_largethan_maxminus7_context_cnt += len(self.keys_values_wm_size_list) + # if self.total_query_count > 0 and self.total_query_count % 10000 == 0: + # self.hit_freq = self.hit_count / self.total_query_count + # print('total_query_count:', self.total_query_count) + # length_largethan_maxminus5_context_cnt_ratio = self.length_largethan_maxminus5_context_cnt / self.total_query_count + # print('recurrent largethan_maxminus5_context:', self.length_largethan_maxminus5_context_cnt) + # print('recurrent largethan_maxminus5_context_ratio:', length_largethan_maxminus5_context_cnt_ratio) + # length_largethan_maxminus7_context_cnt_ratio = self.length_largethan_maxminus7_context_cnt / self.total_query_count + # print('recurrent largethan_maxminus7_context_ratio:', length_largethan_maxminus7_context_cnt_ratio) + # print('recurrent largethan_maxminus7_context:', self.length_largethan_maxminus7_context_cnt) + + # Trim and pad kv_cache + self.keys_values_wm_size_list = self.trim_and_pad_kv_cache(is_init_infer=False) + self.keys_values_wm_size_list_current = self.keys_values_wm_size_list + + for k in range(2): + # action_token obs_token + if k == 0: + obs_embeddings_or_act_tokens = {'act_tokens': token} + else: + obs_embeddings_or_act_tokens = {'obs_embeddings': token} + + # Perform forward pass + outputs_wm = self.forward( + obs_embeddings_or_act_tokens, + past_keys_values=self.keys_values_wm, + kvcache_independent=False, + is_init_infer=False, + task_id = task_id + ) + + self.keys_values_wm_size_list_current = [i + 1 for i in self.keys_values_wm_size_list_current] + + if k == 0: + reward = outputs_wm.logits_rewards # (B,) + + if k < self.num_observations_tokens: + token = outputs_wm.logits_observations + if len(token.shape) != 3: + token = token.unsqueeze(1) # (8,1024) -> (8,1,1024) + latent_state_list.append(token) + + del self.latent_state # Very 
important to minimize cuda memory usage + self.latent_state = torch.cat(latent_state_list, dim=1) # (B, K) + + self.update_cache_context( + self.latent_state, + is_init_infer=False, + simulation_index=simulation_index, + latent_state_index_in_search_path=latent_state_index_in_search_path + ) + + return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) + + def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: + """ + Adjusts the key-value cache for each environment to ensure they all have the same size. + + In a multi-environment setting, the key-value cache (kv_cache) for each environment is stored separately. + During recurrent inference, the kv_cache sizes may vary across environments. This method pads each kv_cache + to match the largest size found among them, facilitating batch processing in the transformer forward pass. + + Arguments: + - is_init_infer (:obj:`bool`): Indicates if this is an initial inference. Default is True. + Returns: + - list: Updated sizes of the key-value caches. + """ + # Find the maximum size among all key-value caches + max_size = max(self.keys_values_wm_size_list) + + # Iterate over each layer of the transformer + for layer in range(self.num_layers): + kv_cache_k_list = [] + kv_cache_v_list = [] + + # Enumerate through each environment's key-value pairs + for idx, keys_values in enumerate(self.keys_values_wm_list): + k_cache = keys_values[layer]._k_cache._cache + v_cache = keys_values[layer]._v_cache._cache + + effective_size = self.keys_values_wm_size_list[idx] + pad_size = max_size - effective_size + + # If padding is required, trim the end and pad the beginning of the cache + if pad_size > 0: + k_cache_trimmed = k_cache[:, :, :-pad_size, :] + v_cache_trimmed = v_cache[:, :, :-pad_size, :] + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + else: + k_cache_padded = k_cache + v_cache_padded = v_cache + + kv_cache_k_list.append(k_cache_padded) + kv_cache_v_list.append(v_cache_padded) + + # Stack the caches along a new dimension and remove any extra dimensions + self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) + self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) + + # Update the cache size to the maximum size + self.keys_values_wm._keys_values[layer]._k_cache._size = max_size + self.keys_values_wm._keys_values[layer]._v_cache._size = max_size + + return self.keys_values_wm_size_list + + #@profile + def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, + latent_state_index_in_search_path=[], valid_context_lengths=None): + """ + Update the cache context with the given latent state. + + Arguments: + - latent_state (:obj:`torch.Tensor`): The latent state tensor. + - is_init_infer (:obj:`bool`): Flag to indicate if this is the initial inference. + - simulation_index (:obj:`int`): Index of the simulation. + - latent_state_index_in_search_path (:obj:`list`): List of indices in the search path. + - valid_context_lengths (:obj:`list`): List of valid context lengths. + """ + if self.context_length <= 2: + # No context to update if the context length is less than or equal to 2. 
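trim_and_pad_kv_cache above makes the per-environment caches batchable by first trimming unused tail slots and then zero-padding each cache at the front of the sequence dimension up to the largest effective length, so the tensor shape stays fixed. A standalone sketch of that step for a single cache tensor with the same [batch, num_heads, seq_len, head_dim] layout assumed in the method:

import torch
import torch.nn.functional as F

def pad_cache_to(cache: torch.Tensor, effective_size: int, max_size: int) -> torch.Tensor:
    """cache: (1, num_heads, seq_len, head_dim); only the first `effective_size` slots hold real entries."""
    pad_size = max_size - effective_size
    if pad_size <= 0:
        return cache
    trimmed = cache[:, :, :-pad_size, :]                       # drop unused tail slots
    # F.pad pairs apply from the last dim backwards: (feat_left, feat_right, seq_left, seq_right)
    return F.pad(trimmed, (0, 0, pad_size, 0), "constant", 0)  # zero-pad at the front of the sequence dim

# usage sketch: padded = pad_cache_to(torch.randn(1, 8, 6, 64), effective_size=4, max_size=6)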
+ return + for i in range(latent_state.size(0)): + # ============ Iterate over each environment ============ + cache_key = hash_state(latent_state[i].view(-1).cpu().numpy()) # latent_state[i] is torch.Tensor + context_length = self.context_length + + if not is_init_infer: + # ============ Internal Node ============ + # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding + current_max_context_length = max(self.keys_values_wm_size_list_current) + trim_size = current_max_context_length - self.keys_values_wm_size_list_current[i] + for layer in range(self.num_layers): + # ============ Apply trimming and padding to each layer of kv_cache ============ + # cache shape [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + if trim_size > 0: + # Trim invalid leading zeros as per effective length + # Remove the first trim_size zero kv items + k_cache_trimmed = k_cache_current[:, trim_size:, :] + v_cache_trimmed = v_cache_current[:, trim_size:, :] + # If effective length < current_max_context_length, pad the end of cache with 'trim_size' zeros + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, 0, trim_size), "constant", + 0) # Pad with 'trim_size' zeros at end of cache + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, 0, trim_size), "constant", 0) + else: + k_cache_padded = k_cache_current + v_cache_padded = v_cache_current + + # Update cache of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + # Update size of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = \ + self.keys_values_wm_size_list_current[i] + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = \ + self.keys_values_wm_size_list_current[i] + + # ============ NOTE: Very Important ============ + if self.keys_values_wm_single_env._keys_values[layer]._k_cache._size >= context_length - 1: + # Keep only the last self.context_length-3 timesteps of context + # For memory environments, training is for H steps, recurrent_inference might exceed H steps + # Assuming cache dimension is [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache + v_cache_current = self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache + + # Remove the first 2 steps, keep the last self.context_length-3 steps + k_cache_trimmed = k_cache_current[:, :, 2:context_length - 1, :].squeeze(0) + v_cache_trimmed = v_cache_current[:, :, 2:context_length - 1, :].squeeze(0) + + # Index pre-computed positional encoding differences + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) + + # Pad the last 3 steps along the third dimension with zeros + # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). 
For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). + padding_size = (0, 0, 0, 3) + k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) + # Update single environment cache + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + + else: + # ============ Root Node ============ + # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding + + for layer in range(self.num_layers): + # ============ Apply trimming and padding to each layer of kv_cache ============ + + if self.keys_values_wm._keys_values[layer]._k_cache._size < context_length - 1: # Keep only the last self.context_length-1 timesteps of context + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # Shape torch.Size([2, 100, 512]) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size + else: + # Assuming cache dimension is [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + # Remove the first 2 steps, keep the last self.context_length-3 steps + k_cache_trimmed = k_cache_current[:, 2:context_length - 1, :] + v_cache_trimmed = v_cache_current[:, 2:context_length - 1, :] + + # Index pre-computed positional encoding differences + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) + + # Pad the last 3 steps along the third dimension with zeros + # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). 
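As the comment above says, F.pad's flat tuple is read pairwise starting from the last dimension, so (0, 0, 0, 3) on a (num_heads, seq_len, head_dim) slice leaves the feature dim untouched and appends three zero steps to the sequence dim. A quick standalone check of exactly that call:

import torch
import torch.nn.functional as F

x = torch.randn(4, 7, 16)                  # (num_heads, seq_len, head_dim)
y = F.pad(x, (0, 0, 0, 3), 'constant', 0)  # head_dim: (0, 0), seq_len: (0, 3)
assert y.shape == (4, 10, 16)              # three zero steps appended along the sequence dimension
assert torch.all(y[:, 7:, :] == 0)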
+ padding_size = (0, 0, 0, 3) + k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) + # Update cache of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + # Update size of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + + if is_init_infer: + # Store the latest key-value cache for initial inference + cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + else: + # Store the latest key-value cache for recurrent inference + cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + self.past_kv_cache_recurrent_infer[cache_key] = cache_index + + #@profile + def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, + simulation_index: int = 0, task_id = 0) -> list: + """ + Retrieves or generates key-value caches for each environment based on the latent state. + + For each environment, this method either retrieves a matching cache from the predefined + caches if available, or generates a new cache if no match is found. The method updates + the internal lists with these caches and their sizes. + + Arguments: + - latent_state (:obj:`list`): List of latent states for each environment. + - ready_env_num (:obj:`int`): Number of environments ready for processing. + - simulation_index (:obj:`int`, optional): Index for simulation tracking. Default is 0. + Returns: + - list: Sizes of the key-value caches for each environment. 
+ """ + for i in range(ready_env_num): + self.total_query_count += 1 + state_single_env = latent_state[i] # latent_state[i] is np.array + cache_key = hash_state(state_single_env) + + if self.reanalyze_phase: + # TODO: check if this is correct + matched_value = None + else: + # Try to retrieve the cached value from past_kv_cache_init_infer_envs + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None + + # If not found, try to retrieve from past_kv_cache_recurrent_infer + if matched_value is None: + matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)] + + if matched_value is not None: + # If a matching cache is found, add it to the lists + self.hit_count += 1 + # Perform a deep copy because the transformer's forward pass might modify matched_value in-place + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # If no matching cache is found, generate a new one using zero reset + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values( + n=1, max_tokens=self.context_length + ) + self.forward( + {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, + past_keys_values=self.keys_values_wm_single_env, is_init_infer=True, task_id=task_id + ) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + return self.keys_values_wm_size_list + + # TODO: legend在下边 + # def plot_embeddings(self, tsne_results, task_ids, observations, samples_per_task=5, save_dir='tsne_plots_26games'): + # """ + # 生成 t-SNE 可视化图,并在图中为每个任务随机标注指定数量的观测样本图像。 + + # 参数: + # - tsne_results: t-SNE 降维结果 (N x 2 的数组) + # - task_ids: 环境任务 ID,用于着色 (N 的数组) + # - observations: 对应的观测样本 (N x C x H x W 的张量或数组) + # - samples_per_task: 每个任务选择的样本数量,默认 5 + # - save_dir: 保存路径,默认 'tsne_plots_26games' + # """ + # import matplotlib.colors as mcolors + + # # 创建保存目录 + # os.makedirs(save_dir, exist_ok=True) + # print(f"[INFO] 保存目录已创建或已存在: {save_dir}") + + # # 创建 t-SNE 图 + # print("[INFO] 开始绘制 t-SNE 散点图...") + # plt.figure(figsize=(16, 10)) + + # # 散点图 + # scatter = plt.scatter( + # tsne_results[:, 0], + # tsne_results[:, 1], + # c=[self.colors[tid] for tid in task_ids], + # alpha=0.6, + # edgecolor='w', + # linewidth=0.5 + # ) + + # # 创建自定义图例 + # legend_elements = [] + # for idx, env_id in enumerate(self.env_id_list): + # short_name = self.env_short_names.get(env_id, env_id) + # color = self.colors[idx] + # legend_elements.append( + # Patch(facecolor=color, edgecolor='w', label=f"{idx}: {short_name}") + # ) + + # # 动态调整图例的列数,根据任务数量 + # num_cols = min(4, len(legend_elements)) # 最多4列 + + # # 将图例放在图像下方 + # plt.legend( + # handles=legend_elements, + # title="Environment IDs", + # loc='upper center', + # bbox_to_anchor=(0.5, -0.15), # 调整 y 值将图例移至下方 + # fontsize=10, + # title_fontsize=12, + # ncol=num_cols, + # frameon=False # 去除图例边框,增强美观 + # ) + + # # 设置标题和轴标签 + # plt.title("t-SNE of Latent States across Environments", fontsize=16) + # plt.xlabel("t-SNE Dimension 1", fontsize=14) + # plt.ylabel("t-SNE Dimension 2", fontsize=14) + # plt.xticks(fontsize=12) + # plt.yticks(fontsize=12) + # plt.grid(True, linestyle='--', alpha=0.5) + # print(f"[INFO] t-SNE 散点图绘制完成,共有 {len(tsne_results)} 个点。") + + # # 为每个任务选择指定数量的样本进行图像标注 + # print(f"[INFO] 开始为每个任务选择 {samples_per_task} 个样本进行图像标注...") 
+ # for task_id in range(len(self.env_id_list)): + # # 找到当前任务的所有索引 + # task_indices = np.where(task_ids == task_id)[0] + # if len(task_indices) == 0: + # print(f"[WARNING] 任务 ID {task_id} 没有对应的样本。") + # continue + # # 如果样本数量少于所需,全部选取 + # if len(task_indices) < samples_per_task: + # selected_indices = task_indices + # print(f"[INFO] 任务 ID {task_id} 的样本数量 ({len(task_indices)}) 少于 {samples_per_task},选取全部。") + # else: + # selected_indices = np.random.choice(task_indices, size=samples_per_task, replace=False) + # print(f"[INFO] 任务 ID {task_id} 随机选取 {samples_per_task} 个样本进行标注。") + + # for idx in selected_indices: + # img = observations[idx] + # if isinstance(img, torch.Tensor): + # img = img.cpu().numpy() + # if img.shape[0] == 1 or img.shape[0] == 3: # 处理灰度图或 RGB 图 + # img = np.transpose(img, (1, 2, 0)) + # else: + # raise ValueError(f"Unsupported image shape: {img.shape}") + + # # 标准化图像到 [0,1] 范围 + # img_min, img_max = img.min(), img.max() + # if img_max - img_min > 1e-5: + # img = (img - img_min) / (img_max - img_min) + # else: + # img = np.zeros_like(img) + + # imagebox = OffsetImage(img, zoom=0.5) + # ab = AnnotationBbox( + # imagebox, + # (tsne_results[idx, 0], tsne_results[idx, 1]), + # frameon=False, + # pad=0.3 + # ) + # plt.gca().add_artist(ab) + # print(f"[INFO] 已添加图像标注: 任务 ID {task_id}, 点索引 {idx}, t-SNE 坐标 ({tsne_results[idx, 0]:.2f}, {tsne_results[idx, 1]:.2f})") + + # # 调整布局以适应图例 + # plt.tight_layout(rect=[0, 0.05, 1, 1]) # 为下方的图例预留空间 + + # # 保存图像,使用高分辨率 + # save_path_png = os.path.join(save_dir, 'tsne_plot.png') + # save_path_pdf = os.path.join(save_dir, 'tsne_plot.pdf') + # plt.savefig(save_path_png, dpi=300, bbox_inches='tight') + # plt.savefig(save_path_pdf, dpi=300, bbox_inches='tight') + # print(f"[INFO] t-SNE 可视化图已保存至: {save_path_png} 和 {save_path_pdf}") + # plt.close() + + # TODO: legend在右边 + def plot_embeddings(self, tsne_results, task_ids, observations, samples_per_task=5, save_dir='tsne_plots_26games'): + """ + 生成 t-SNE 可视化图,并在图中为每个任务随机标注指定数量的观测样本图像。 + + 参数: + - tsne_results: t-SNE 降维结果 (N x 2 的数组) + - task_ids: 环境任务 ID,用于着色 (N 的数组) + - observations: 对应的观测样本 (N x C x H x W 的张量或数组) + - samples_per_task: 每个任务选择的样本数量,默认 5 + - save_dir: 保存路径,默认 'tsne_plots_26games' + """ + + # 创建保存目录 + os.makedirs(save_dir, exist_ok=True) + print(f"[INFO] 保存目录已创建或已存在: {save_dir}") + + # 创建 t-SNE 图 + print("[INFO] 开始绘制 t-SNE 散点图...") + plt.figure(figsize=(18, 10)) # 增大图像宽度以适应右侧图例 + + # 散点图 + scatter = plt.scatter( + tsne_results[:, 0], + tsne_results[:, 1], + c=[self.colors[tid] for tid in task_ids], + alpha=0.6, + edgecolor='w', + linewidth=0.5 + ) + + # 创建自定义图例 + legend_elements = [] + for idx, env_id in enumerate(self.env_id_list): + short_name = self.env_short_names.get(env_id, env_id) + color = self.colors[idx] + legend_elements.append( + Patch(facecolor=color, edgecolor='w', label=f"{idx}: {short_name}") + ) + + # 将图例放在图像右侧,并且每个图例项占一行 + plt.legend( + handles=legend_elements, + title="Environment IDs", + loc='center left', + bbox_to_anchor=(1, 0.5), # 图例在图像右侧中央 + fontsize=10, + title_fontsize=12, + ncol=1, + frameon=False # 去除图例边框,增强美观 + ) + + # 设置标题和轴标签 + plt.title("t-SNE of Latent States across Environments", fontsize=16) + plt.xlabel("t-SNE Dimension 1", fontsize=14) + plt.ylabel("t-SNE Dimension 2", fontsize=14) + plt.xticks(fontsize=12) + plt.yticks(fontsize=12) + plt.grid(True, linestyle='--', alpha=0.5) + print(f"[INFO] t-SNE 散点图绘制完成,共有 {len(tsne_results)} 个点。") + + # 为每个任务选择指定数量的样本进行图像标注 + print(f"[INFO] 开始为每个任务选择 {samples_per_task} 个样本进行图像标注...") + for task_id in 
range(len(self.env_id_list)): + # 找到当前任务的所有索引 + task_indices = np.where(task_ids == task_id)[0] + if len(task_indices) == 0: + print(f"[WARNING] 任务 ID {task_id} 没有对应的样本。") + continue + # 如果样本数量少于所需,全部选取 + if len(task_indices) < samples_per_task: + selected_indices = task_indices + print(f"[INFO] 任务 ID {task_id} 的样本数量 ({len(task_indices)}) 少于 {samples_per_task},选取全部。") + else: + selected_indices = np.random.choice(task_indices, size=samples_per_task, replace=False) + print(f"[INFO] 任务 ID {task_id} 随机选取 {samples_per_task} 个样本进行标注。") + + for idx in selected_indices: + img = observations[idx] + if isinstance(img, torch.Tensor): + img = img.cpu().numpy() + if img.shape[0] == 1 or img.shape[0] == 3: # 处理灰度图或 RGB 图 + img = np.transpose(img, (1, 2, 0)) + else: + raise ValueError(f"Unsupported image shape: {img.shape}") + + # 标准化图像到 [0,1] 范围 + img_min, img_max = img.min(), img.max() + if img_max - img_min > 1e-5: + img = (img - img_min) / (img_max - img_min) + else: + img = np.zeros_like(img) + + imagebox = OffsetImage(img, zoom=0.5) + ab = AnnotationBbox( + imagebox, + (tsne_results[idx, 0], tsne_results[idx, 1]), + frameon=False, + pad=0.3 + ) + plt.gca().add_artist(ab) + print(f"[INFO] 已添加图像标注: 任务 ID {task_id}, 点索引 {idx}, t-SNE 坐标 ({tsne_results[idx, 0]:.2f}, {tsne_results[idx, 1]:.2f})") + + # 调整布局以适应图例 + plt.tight_layout(rect=[0, 0, 0.9, 1]) # 为右侧的图例预留空间 + + # 保存图像,使用高分辨率 + save_path_png = os.path.join(save_dir, 'tsne_plot.png') + save_path_pdf = os.path.join(save_dir, 'tsne_plot.pdf') + plt.savefig(save_path_png, dpi=300, bbox_inches='tight') + plt.savefig(save_path_pdf, dpi=300, bbox_inches='tight') + print(f"[INFO] t-SNE 可视化图已保存至: {save_path_png} 和 {save_path_pdf}") + plt.close() + + @torch.no_grad() + def gather_and_plot(self, local_embeddings, local_task_ids, local_observations): + world_size = dist.get_world_size() + rank = dist.get_rank() + + # 准备接收来自所有进程的CUDA张量 + embeddings_list = [torch.zeros_like(local_embeddings) for _ in range(world_size)] + task_ids_list = [torch.zeros_like(local_task_ids) for _ in range(world_size)] + + # 准备接收来自所有进程的CPU对象 + observations_list = [None for _ in range(world_size)] + + try: + # 收集CUDA张量:embeddings和task_ids + dist.all_gather(embeddings_list, local_embeddings) + dist.all_gather(task_ids_list, local_task_ids) + + # 收集CPU对象:observations + local_observations_cpu = local_observations.cpu().numpy().tolist() + dist.all_gather_object(observations_list, local_observations_cpu) + except RuntimeError as e: + print(f"Rank {rank}: all_gather failed with error: {e}") + return + + if rank == 0: + # 拼接所有embeddings和task_ids + all_embeddings = torch.cat(embeddings_list, dim=0).cpu().numpy() + all_task_ids = torch.cat(task_ids_list, dim=0).cpu().numpy() + + # 拼接所有observations + all_observations = [] + for obs in observations_list: + all_observations.extend(obs) + all_observations = np.array(all_observations) + + print(f"Shape of all_embeddings: {all_embeddings.shape}") + all_embeddings = all_embeddings.reshape(-1, all_embeddings.shape[-1]) + print(f"Shape of all_observations: {all_observations.shape}") + all_observations = all_observations.reshape(-1, *all_observations.shape[-3:]) + + # 执行t-SNE降维 + tsne = TSNE(n_components=2, random_state=42) + tsne_results = tsne.fit_transform(all_embeddings) + + # 绘制并保存图像 + self.plot_embeddings(tsne_results, all_task_ids, all_observations, save_dir=f'tsne_plots_{self.num_tasks}games') + + #@profile + def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, task_id = 0, **kwargs: Any) -> 
LossWithIntermediateLosses: + # Encode observations into latent state representations + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id) + + if self.analysis_tsne: + # =========== tsne analysis =========== + # 确保embeddings在CUDA设备上且为稠密张量 + if not obs_embeddings.is_cuda: + obs_embeddings = obs_embeddings.cuda() + obs_embeddings = obs_embeddings.contiguous() + + # 保存当前进程的 embeddings 和 task_id + local_embeddings = obs_embeddings.detach() + local_task_ids = torch.full((local_embeddings.size(0),), task_id, dtype=torch.long, device=local_embeddings.device) + + # 将observations移到CPU并转换为numpy + local_observations = batch['observations'].detach().cpu() + + # 进行数据收集和可视化 + self.gather_and_plot(local_embeddings, local_task_ids, local_observations) + + # ========= logging for analysis ========= + if self.analysis_dormant_ratio: + # Calculate dormant ratio of the encoder + shape = batch['observations'].shape # (..., C, H, W) + inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64) + dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.representation_network, inputs.detach(), + percentage=self.dormant_threshold) + self.past_kv_cache_init_infer.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + torch.cuda.empty_cache() + else: + dormant_ratio_encoder = torch.tensor(0.) + + # Calculate the L2 norm of the latent state roots + latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() + + if self.obs_type == 'image': + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # original_images, reconstructed_images = batch['observations'], reconstructed_images + # target_policy = batch['target_policy'] + # ==== for value priority ==== + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # Calculate reconstruction loss and perceptual loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + # perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + elif self.obs_type == 'vector': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings.reshape(-1, self.embed_dim)) + # # Calculate reconstruction loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 25), + # reconstructed_images) + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + elif self.obs_type == 'image_memory': + # 
Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + # original_images, reconstructed_images = batch['observations'], reconstructed_images + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # target_policy = batch['target_policy'] + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # Calculate reconstruction loss and perceptual loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 5, 5), + # reconstructed_images) + + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + # Action tokens + act_tokens = rearrange(batch['actions'], 'b l -> b l 1') + + # Forward pass to obtain predictions for observations, rewards, and policies + outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, task_id=task_id) + + # ========= logging for analysis ========= + if self.analysis_dormant_ratio: + # Calculate dormant ratio of the world model + dormant_ratio_world_model = cal_dormant_ratio(self, { + 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, + percentage=self.dormant_threshold) + self.past_kv_cache_init_infer.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + torch.cuda.empty_cache() + else: + dormant_ratio_world_model = torch.tensor(0.) 
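The analysis_dormant_ratio branches above log how "dormant" the encoder and world model are. cal_dormant_ratio itself lives in lzero.model.utils and is not shown in this diff; the function below is only a rough, illustrative approximation of the underlying idea used in the dormant-neuron literature (a unit counts as dormant when its normalized mean absolute activation falls below a threshold):

import torch
import torch.nn as nn

@torch.no_grad()
def dormant_ratio(module: nn.Module, inputs: torch.Tensor, tau: float = 0.025) -> float:
    """Fraction of linear/conv output units whose normalized mean |activation| is below tau."""
    scores = []

    def hook(_mod, _inp, out):
        act = out.abs().mean(dim=0)                                    # mean |activation| per unit over the batch
        act = act.flatten(1).mean(dim=-1) if act.dim() > 1 else act    # average spatial dims for conv maps
        scores.append(act / (act.mean() + 1e-9))                       # normalize by the layer average

    handles = [m.register_forward_hook(hook) for m in module.modules()
               if isinstance(m, (nn.Linear, nn.Conv2d))]
    module(inputs)
    for h in handles:
        h.remove()
    all_scores = torch.cat(scores)
    return (all_scores <= tau).float().mean().item()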
+ + # ========== for visualization ========== + # Uncomment the lines below for visualization + # predict_policy = outputs.logits_policy + # predict_policy = F.softmax(outputs.logits_policy, dim=-1) + # predict_value = inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # predict_rewards = inverse_scalar_transform_handle(outputs.logits_rewards.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # import pdb; pdb.set_trace() + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=[], suffix='pong_H10_H4_0613') + + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_success_episode') + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_fail_episode') + # ========== for visualization ========== + + # For training stability, use target_tokenizer to compute the true next latent state representations + with torch.no_grad(): + target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id) + + # Compute labels for observations, rewards, and ends + labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(target_obs_embeddings, + batch['rewards'], + batch['ends'], + batch['mask_padding']) + + # Reshape the logits and labels for observations + logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') + labels_observations = labels_observations.reshape(-1, self.projection_input_dim) + + # Compute prediction loss for observations. 
Options: MSE and Group KL + if self.predict_latent_loss_type == 'mse': + # MSE loss, directly compare logits and labels + loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations, reduction='none').mean( + -1) + elif self.predict_latent_loss_type == 'group_kl': + # Group KL loss, group features and calculate KL divergence within each group + batch_size, num_features = logits_observations.shape + epsilon = 1e-6 + logits_reshaped = logits_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + labels_reshaped = labels_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + + loss_obs = F.kl_div(logits_reshaped.log(), labels_reshaped, reduction='none').sum(dim=-1).mean(dim=-1) + + # ========== for debugging ========== + # print('loss_obs:', loss_obs.mean()) + # assert not torch.isnan(loss_obs).any(), "loss_obs contains NaN values" + # assert not torch.isinf(loss_obs).any(), "loss_obs contains Inf values" + # for name, param in self.tokenizer.representation_network.named_parameters(): + # print('name, param.mean(), param.std():', name, param.mean(), param.std()) + + # Apply mask to loss_obs + mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) + loss_obs = (loss_obs * mask_padding_expanded) + + # Compute labels for policy and value + labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], + batch['target_policy'], + batch['mask_padding']) + + # Compute losses for rewards, policy, and value + loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') + loss_policy, orig_policy_loss, policy_entropy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, + element='policy') + loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') + + # Compute timesteps + timesteps = torch.arange(batch['actions'].shape[1], device=batch['actions'].device) + # Compute discount coefficients for each timestep + discounts = self.gamma ** timesteps + + if batch['mask_padding'].sum() == 0: + assert False, "mask_padding is all zeros" + + # Group losses into first step, middle step, and last step + first_step_losses = {} + middle_step_losses = {} + last_step_losses = {} + # batch['mask_padding'] indicates mask status for future H steps, exclude masked losses to maintain accurate mean statistics + # Group losses for each loss item + for loss_name, loss_tmp in zip( + ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy'], + [loss_obs, loss_rewards, loss_value, loss_policy, orig_policy_loss, policy_entropy] + ): + if loss_name == 'loss_obs': + seq_len = batch['actions'].shape[1] - 1 + # Get the corresponding mask_padding + mask_padding = batch['mask_padding'][:, 1:seq_len] + else: + seq_len = batch['actions'].shape[1] + # Get the corresponding mask_padding + mask_padding = batch['mask_padding'][:, :seq_len] + + # Adjust loss shape to (batch_size, seq_len) + loss_tmp = loss_tmp.view(-1, seq_len) + + # First step loss + first_step_mask = mask_padding[:, 0] + first_step_losses[loss_name] = loss_tmp[:, 0][first_step_mask].mean() + + # Middle step loss + middle_step_index = seq_len // 2 + middle_step_mask = mask_padding[:, middle_step_index] + middle_step_losses[loss_name] = loss_tmp[:, middle_step_index][middle_step_mask].mean() + + # Last step loss + last_step_mask = mask_padding[:, -1] + last_step_losses[loss_name] = loss_tmp[:, -1][last_step_mask].mean() + + # Discount 
reconstruction loss and perceptual loss + discounted_latent_recon_loss = latent_recon_loss + discounted_perceptual_loss = perceptual_loss + + # Calculate overall discounted loss + discounted_loss_obs = (loss_obs.view(-1, batch['actions'].shape[1] - 1) * discounts[1:]).sum()/ batch['mask_padding'][:,1:].sum() + discounted_loss_rewards = (loss_rewards.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_loss_value = (loss_value.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_loss_policy = (loss_policy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + + if self.continuous_action_space: + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + continuous_action_space=True, + loss_obs=discounted_loss_obs, + loss_rewards=discounted_loss_rewards, + loss_value=discounted_loss_value, + loss_policy=discounted_loss_policy, + latent_recon_loss=discounted_latent_recon_loss, + perceptual_loss=discounted_perceptual_loss, + orig_policy_loss=discounted_orig_policy_loss, + policy_entropy=discounted_policy_entropy, + first_step_losses=first_step_losses, + middle_step_losses=middle_step_losses, + last_step_losses=last_step_losses, + dormant_ratio_encoder=dormant_ratio_encoder, + dormant_ratio_world_model=dormant_ratio_world_model, + latent_state_l2_norms=latent_state_l2_norms, + policy_mu=mu, + policy_sigma=sigma, + target_sampled_actions=target_sampled_actions, + ) + else: + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + continuous_action_space=False, + loss_obs=discounted_loss_obs, + loss_rewards=discounted_loss_rewards, + loss_value=discounted_loss_value, + loss_policy=discounted_loss_policy, + latent_recon_loss=discounted_latent_recon_loss, + perceptual_loss=discounted_perceptual_loss, + orig_policy_loss=discounted_orig_policy_loss, + policy_entropy=discounted_policy_entropy, + first_step_losses=first_step_losses, + middle_step_losses=middle_step_losses, + last_step_losses=last_step_losses, + dormant_ratio_encoder=dormant_ratio_encoder, + dormant_ratio_world_model=dormant_ratio_world_model, + latent_state_l2_norms=latent_state_l2_norms, + ) + + #@profile + def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): + # Assume outputs is an object with logits attributes like 'rewards', 'policy', and 'value'. + # labels is a target tensor for comparison. batch is a dictionary with a mask indicating valid timesteps. + + logits = getattr(outputs, f'logits_{element}') + + # Reshape your tensors + logits = rearrange(logits, 'b t e -> (b t) e') + labels = labels.reshape(-1, labels.shape[-1]) # Assume labels initially have shape [batch, time, dim] + + # Reshape your mask. True indicates valid data. 
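The group_kl option in compute_loss above splits each latent vector into num_groups chunks of group_size and measures a per-group KL divergence between predicted and target latents; both sides are expected to be (approximately) normalized within each group, which is what the SimNorm layer attached to the observation head is for, and the epsilon keeps the logarithm well-defined. A standalone sketch with dummy, explicitly normalized groups:

import torch
import torch.nn.functional as F

batch, num_groups, group_size = 6, 8, 8
eps = 1e-6

# pretend predictions/targets are already normalized within each group (as SimNorm-style latents are)
pred = torch.softmax(torch.randn(batch, num_groups, group_size), dim=-1) + eps
target = torch.softmax(torch.randn(batch, num_groups, group_size), dim=-1) + eps

# KL(target || pred) per group, summed over the group, averaged over groups -> one loss per sample
loss_obs = F.kl_div(pred.log(), target, reduction='none').sum(dim=-1).mean(dim=-1)
assert loss_obs.shape == (batch,)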
+ mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') + + # Compute cross-entropy loss + loss = -(torch.log_softmax(logits, dim=1) * labels).sum(1) + loss = (loss * mask_padding) + + # if torch.isnan(loss).any(): + # raise ValueError(f"NaN detected in outputs for batch {batch} and element '{element}'") + + if element == 'policy': + # Compute policy entropy loss + policy_entropy = self.compute_policy_entropy_loss(logits, mask_padding) + # Combine losses with specified weight + combined_loss = loss - self.policy_entropy_weight * policy_entropy + return combined_loss, loss, policy_entropy + + return loss + + #@profile + def compute_policy_entropy_loss(self, logits, mask): + # Compute entropy of the policy + probs = torch.softmax(logits, dim=1) + log_probs = torch.log_softmax(logits, dim=1) + entropy = -(probs * log_probs).sum(1) + # Apply mask and return average entropy loss + entropy_loss = (entropy * mask) + return entropy_loss + + #@profile + def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag + mask_fill = torch.logical_not(mask_padding) + + # Prepare observation labels + labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] + + # Fill the masked areas of rewards + mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) + labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) + + # Fill the masked areas of ends + # labels_ends = ends.masked_fill(mask_fill, -100) + + # return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) + return labels_observations, labels_rewards.view(-1, self.support_size), None + + #@profile + def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ Compute labels for value and policy predictions. """ + mask_fill = torch.logical_not(mask_padding) + + # Fill the masked areas of policy + mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) + labels_policy = target_policy.masked_fill(mask_fill_policy, -100) + + # Fill the masked areas of value + mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) + labels_value = target_value.masked_fill(mask_fill_value, -100) + + if self.continuous_action_space: + return None, labels_value.reshape(-1, self.support_size) + else: + return labels_policy.reshape(-1, self.action_space_size), labels_value.reshape(-1, self.support_size) + + #@profile + def clear_caches(self): + """ + Clears the caches of the world model. 
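compute_cross_entropy_loss above works with distributional (soft) labels, so the loss is -(log_softmax(logits) * labels).sum(-1) rather than F.cross_entropy with class indices, masked per timestep; for the policy it additionally subtracts a weighted entropy bonus computed by compute_policy_entropy_loss. A condensed standalone version of that computation (function and argument names here are illustrative):

import torch

def masked_soft_ce(logits: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor,
                   entropy_weight: float = 0.0):
    """logits/labels: (N, dim) with labels as target distributions; mask: (N,) with 1 for valid steps."""
    log_probs = torch.log_softmax(logits, dim=1)
    ce = -(log_probs * labels).sum(1) * mask             # per-step cross-entropy, masked
    probs = torch.softmax(logits, dim=1)
    entropy = -(probs * log_probs).sum(1) * mask         # per-step policy entropy, masked
    return ce - entropy_weight * entropy, ce, entropy

# usage sketch:
# loss, ce, ent = masked_soft_ce(torch.randn(12, 6), torch.softmax(torch.randn(12, 6), 1), torch.ones(12), 0.01)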
+ """ + for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + + print(f'rank {self._rank} Cleared {self.__class__.__name__} past_kv_cache.') + + def __repr__(self) -> str: + return "transformer-based latent world_model of UniZero" diff --git a/lzero/policy/muzero_multitask.py b/lzero/policy/muzero_multitask.py new file mode 100644 index 000000000..933a61f34 --- /dev/null +++ b/lzero/policy/muzero_multitask.py @@ -0,0 +1,859 @@ +import copy +from typing import List, Dict, Tuple, Union + +import numpy as np +import torch +import torch.optim as optim +from ding.model import model_wrap +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY + +from lzero.mcts import MuZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.model.utils import cal_dormant_ratio +from lzero.policy import ( + scalar_transform, + InverseScalarTransform, + cross_entropy_loss, + phi_transform, + DiscreteSupport, + to_torch_float_tensor, + mz_network_output_unpack, + select_action, + negative_cosine_similarity, + prepare_obs, +) +from lzero.policy.muzero import MuZeroPolicy + + +def generate_task_loss_dict(multi_task_losses, task_name_template, task_id): + """ + 生成每个任务的损失字典 + :param multi_task_losses: 包含每个任务损失的列表 + :param task_name_template: 任务名称模板,例如 'loss_task{}' + :param task_id: 任务起始ID + :return: 一个字典,包含每个任务的损失 + """ + task_loss_dict = {} + for task_idx, task_loss in enumerate(multi_task_losses): + task_name = task_name_template.format(task_idx + task_id) + try: + task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss + except Exception: + task_loss_dict[task_name] = task_loss + return task_loss_dict + +class WrappedModelV2: + def __init__(self, tokenizer, transformer, pos_emb, task_emb, act_embedding_table): + self.tokenizer = tokenizer + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self): + # 返回 tokenizer, transformer 以及所有嵌入层的参数 + return ( + list(self.tokenizer.parameters()) + + list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + list(self.task_emb.parameters()) + + list(self.act_embedding_table.parameters()) + ) + + def zero_grad(self, set_to_none=False): + # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + +@POLICY_REGISTRY.register('muzero_multitask') +class MuZeroMTPolicy(MuZeroPolicy): + """ + 概述: + MuZero 的多任务策略类,扩展自 MuZeroPolicy。支持同时训练多个任务,通过分离每个任务的损失并进行优化。 + """ + + # MuZeroMTPolicy 的默认配置 + config = dict( + type='muzero_multitask', + model=dict( + model_type='conv', # options={'mlp', 'conv'} + continuous_action_space=False, + observation_shape=(4, 96, 96), # example shape + self_supervised_learning_loss=False, + categorical_distribution=True, + image_channel=1, + frame_stack_num=1, + num_res_blocks=1, + num_channels=64, + support_scale=300, + bias=True, + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='BN', + analysis_sim_norm=False, + analysis_dormant_ratio=False, + harmony_balance=False, + ), + # ****** common ****** + use_rnd_model=False, + multi_gpu=False, + sampled_algo=False, + 
gumbel_algo=False, + mcts_ctree=True, + cuda=True, + collector_env_num=8, + evaluator_env_num=3, + env_type='not_board_games', + action_type='fixed_action_space', + battle_mode='play_with_bot_mode', + monitor_extra_statistics=True, + game_segment_length=200, + eval_offline=False, + cal_dormant_ratio=False, + analysis_sim_norm=False, + analysis_dormant_ratio=False, + + # ****** observation ****** + transform2string=False, + gray_scale=False, + use_augmentation=False, + augmentation=['shift', 'intensity'], + + # ******* learn ****** + ignore_done=False, + update_per_collect=None, + replay_ratio=0.25, + batch_size=256, + optim_type='SGD', + learning_rate=0.2, + target_update_freq=100, + target_update_freq_for_intrinsic_reward=1000, + weight_decay=1e-4, + momentum=0.9, + grad_clip_value=10, + n_episode=8, + num_segments=8, + num_simulations=50, + discount_factor=0.997, + td_steps=5, + num_unroll_steps=5, + reward_loss_weight=1, + value_loss_weight=0.25, + policy_loss_weight=1, + policy_entropy_weight=0, + ssl_loss_weight=0, + lr_piecewise_constant_decay=True, + threshold_training_steps_for_final_lr=int(5e4), + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(1e5), + fixed_temperature_value=0.25, + use_ture_chance_label_in_chance_encoder=False, + + # ****** Priority ****** + use_priority=False, + priority_prob_alpha=0.6, + priority_prob_beta=0.4, + + # ****** UCB ****** + root_dirichlet_alpha=0.3, + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + eps_greedy_exploration_in_collect=False, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + + # ****** 多任务相关 ****** + task_num=2, # 任务数量,根据实际需求调整 + task_id=0, # 当前任务的起始ID + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + 概述: + 返回该算法的默认模型设置。 + 返回: + - model_info (:obj:`Tuple[str, List[str]]`): 模型名称和模型导入路径列表。 + """ + return 'MuZeroMTModel', ['lzero.model.muzero_model_multitask'] + + def _init_learn(self) -> None: + """ + 概述: + 学习模式初始化方法。初始化学习模型、优化器和MCTS工具。 + """ + super()._init_learn() + + assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type + # NOTE: in board_games, for fixed lr 0.003, 'Adam' is better than 'SGD'. + if self._cfg.optim_type == 'SGD': + self._optimizer = optim.SGD( + self._model.parameters(), + lr=self._cfg.learning_rate, + momentum=self._cfg.momentum, + weight_decay=self._cfg.weight_decay, + ) + elif self._cfg.optim_type == 'Adam': + self._optimizer = optim.Adam( + self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay + ) + elif self._cfg.optim_type == 'AdamW': + self._optimizer = configure_optimizers(model=self._model, weight_decay=self._cfg.weight_decay, + learning_rate=self._cfg.learning_rate, device_type=self._cfg.device) + + if self._cfg.lr_piecewise_constant_decay: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. 
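+ # Worked example: with threshold_training_steps_for_final_lr = 5e4, steps below 25k keep the base lr,
+ # steps in [25k, 50k) use base_lr * 0.1, and steps from 50k on use base_lr * 0.01.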
+ lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + # ============================================================== + # harmonydream (learnable weights for different losses) + # ============================================================== + if self._cfg.model.harmony_balance: + # List of parameter names + harmony_names = ["harmony_dynamics", "harmony_policy", "harmony_value", "harmony_reward", "harmony_entropy"] + # Initialize and name each parameter + for name in harmony_names: + param = torch.nn.Parameter(-torch.log(torch.tensor(1.0))) + setattr(self, name, param) + + if self._cfg.use_rnd_model: + if self._cfg.target_model_for_intrinsic_reward_update_type == 'assign': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq_for_intrinsic_reward} + ) + elif self._cfg.target_model_for_intrinsic_reward_update_type == 'momentum': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta_for_intrinsic_reward} + ) + + # ========= logging for analysis ========= + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. + self.dormant_ratio_encoder = 0. + self.dormant_ratio_dynamics = 0. 
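+ # Multi-task bookkeeping (assumed setup): when tasks are sharded across DDP ranks, cfg.task_num is the
+ # number of tasks handled by this rank and cfg.task_id is the global id of its first task, so the
+ # per-task log keys produced in _forward_learn run from task_id to task_id + task_num - 1.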
+ # 初始化多任务相关参数 + self.task_num_for_current_rank = self._cfg.task_num + self.task_id = self._cfg.task_id + + def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> Dict[str, Union[float, int]]: + """ + 概述: + 学习模式的前向函数,是学习过程的核心。数据从重放缓冲区采样,计算损失并反向传播更新模型。 + 参数: + - data (:obj:`List[Tuple[torch.Tensor, torch.Tensor, int]]`): 每个任务的数据元组列表, + 每个元组包含 (current_batch, target_batch, task_id)。 + 返回: + - info_dict (:obj:`Dict[str, Union[float, int]]`): 用于记录的信息字典,包含当前学习损失和学习统计信息。 + """ + self._learn_model.train() + self._target_model.train() + + # 初始化多任务损失列表 + reward_loss_multi_task = [] + policy_loss_multi_task = [] + value_loss_multi_task = [] + consistency_loss_multi_task = [] + policy_entropy_multi_task = [] + lambd_multi_task = [] + value_priority_multi_task = [] + value_priority_mean_multi_task = [] + + weighted_total_loss = 0.0 # 初始化为0 + losses_list = [] # 用于存储每个任务的损失 + + for task_idx, (current_batch, target_batch, task_id) in enumerate(data): + obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch + target_reward, target_value, target_policy = target_batch + + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # 数据增强 + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # 准备动作批次并转换为张量 + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() + data_list = [mask_batch, target_reward, target_value, target_policy, weights] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor( + data_list, self._cfg.device + ) + + target_reward = target_reward.view(self._cfg.batch_size[task_idx], -1) + target_value = target_value.view(self._cfg.batch_size[task_idx], -1) + + assert obs_batch.size(0) == self._cfg.batch_size[task_idx] == target_reward.size(0) + + # 变换奖励和价值到缩放形式 + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # 转换为类别分布 + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # 初始推理 + network_output = self._learn_model.initial_inference(obs_batch, task_id=task_id) + + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # 记录 Dormant Ratio 和 L2 Norm + if self._cfg.cal_dormant_ratio: + self.dormant_ratio_encoder = cal_dormant_ratio( + self._learn_model.representation_network, obs_batch.detach(), + percentage=self._cfg.dormant_threshold + ) + latent_state_l2_norms = torch.norm(latent_state.view(latent_state.shape[0], -1), p=2, dim=1).mean() + + # 逆变换价值 + original_value = self.inverse_scalar_transform_handle(value) + + # 初始化预测值和策略 + predicted_rewards = [] + if self._cfg.monitor_extra_statistics: + predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( + policy_logits, dim=1 + ).detach().cpu() + + # 计算优先级 + value_priority = torch.nn.L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) + value_priority = value_priority.data.cpu().numpy() + 1e-6 + + # 计算第一个步骤的策略和价值损失 + policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0]) + value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) + + prob = torch.softmax(policy_logits, dim=-1) + entropy = -(prob * torch.log(prob + 1e-9)).sum(-1) + 
policy_entropy_loss = -entropy + + reward_loss = torch.zeros(self._cfg.batch_size[task_idx], device=self._cfg.device) + consistency_loss = torch.zeros(self._cfg.batch_size[task_idx], device=self._cfg.device) + target_policy_entropy = 0 + + # 循环进行多个unroll步骤 + for step_k in range(self._cfg.num_unroll_steps): + # 使用动态函数进行递归推理 + network_output = self._learn_model.recurrent_inference(latent_state, action_batch[:, step_k]) + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # 记录 Dormant Ratio + if step_k == self._cfg.num_unroll_steps - 1 and self._cfg.cal_dormant_ratio: + action_tmp = action_batch[:, step_k] + if len(action_tmp.shape) == 1: + action_tmp = action_tmp.unsqueeze(-1) + # 转换动作为独热编码 + action_one_hot = torch.zeros(action_tmp.shape[0], policy_logits.shape[-1], device=action_tmp.device) + action_tmp = action_tmp.long() + action_one_hot.scatter_(1, action_tmp, 1) + action_encoding_tmp = action_one_hot.unsqueeze(-1).unsqueeze(-1) + action_encoding = action_encoding_tmp.expand( + latent_state.shape[0], policy_logits.shape[-1], latent_state.shape[2], latent_state.shape[3] + ) + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + self.dormant_ratio_dynamics = cal_dormant_ratio( + self._learn_model.dynamics_network, + state_action_encoding.detach(), + percentage=self._cfg.dormant_threshold + ) + + # 逆变换价值 + original_value = self.inverse_scalar_transform_handle(value) + + # 计算一致性损失 + if self._cfg.model.self_supervised_learning_loss and self._cfg.ssl_loss_weight > 0: + beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index], task_id=task_id) + + latent_state = to_tensor(latent_state) + representation_state = to_tensor(network_output.latent_state) + + dynamic_proj = self._learn_model.project(latent_state, with_grad=True) + observation_proj = self._learn_model.project(representation_state, with_grad=False) + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] + consistency_loss += temp_loss + + # 计算策略和价值损失 + policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_k + 1]) + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1]) + reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k]) + + # 计算策略熵损失 + prob = torch.softmax(policy_logits, dim=-1) + entropy = -(prob * torch.log(prob + 1e-9)).sum(-1) + policy_entropy_loss += -entropy + + # 计算目标策略熵(仅用于调试) + target_normalized_visit_count = target_policy[:, step_k + 1] + non_masked_indices = torch.nonzero(mask_batch[:, step_k + 1]).squeeze(-1) + if len(non_masked_indices) > 0: + target_normalized_visit_count_masked = torch.index_select( + target_normalized_visit_count, 0, non_masked_indices + ) + target_policy_entropy += -( + (target_normalized_visit_count_masked + 1e-6) * + torch.log(target_normalized_visit_count_masked + 1e-6) + ).sum(-1).mean() + else: + target_policy_entropy += torch.log( + torch.tensor(target_normalized_visit_count.shape[-1], device=self._cfg.device) + ) + + + # 记录预测值和奖励(如果监控额外统计) + if self._cfg.monitor_extra_statistics: + original_rewards = self.inverse_scalar_transform_handle(reward) + original_rewards_cpu = original_rewards.detach().cpu() + + predicted_values = torch.cat( + (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + ) + predicted_rewards.append(original_rewards_cpu) + predicted_policies = torch.cat( + 
(predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu()) + ) + + # 核心学习模型更新步骤 + weighted_loss = self._cfg.policy_loss_weight * policy_loss + \ + self._cfg.value_loss_weight * value_loss + \ + self._cfg.reward_loss_weight * reward_loss + \ + self._cfg.ssl_loss_weight * consistency_loss + \ + self._cfg.policy_entropy_weight * policy_entropy_loss + + # 将多个任务的损失累加 + weighted_total_loss += weighted_loss.mean() + + # 保留每个任务的损失用于日志记录 + reward_loss_multi_task.append(reward_loss.mean().item()) + policy_loss_multi_task.append(policy_loss.mean().item()) + value_loss_multi_task.append(value_loss.mean().item()) + consistency_loss_multi_task.append(consistency_loss.mean().item()) + policy_entropy_multi_task.append(policy_entropy_loss.mean().item()) + lambd_multi_task.append(torch.tensor(0., device=self._cfg.device).item()) # TODO: 如果使用梯度校正,可以在这里调整 + value_priority_multi_task.append(value_priority.mean().item()) + value_priority_mean_multi_task.append(value_priority.mean().item()) + losses_list.append(weighted_loss.mean().item()) + + # 清零优化器的梯度 + self._optimizer.zero_grad() + + # 反向传播 + weighted_total_loss.backward() + + # 梯度裁剪 + total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_( + self._learn_model.parameters(), + self._cfg.grad_clip_value + ) + + # 多GPU训练时同步梯度 + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + + # 更新优化器 + self._optimizer.step() + if self._cfg.lr_piecewise_constant_decay: + self.lr_scheduler.step() + + # 更新目标模型 + self._target_model.update(self._learn_model.state_dict()) + + # 获取GPU内存使用情况 + if torch.cuda.is_available(): + torch.cuda.synchronize() + current_memory_allocated = torch.cuda.memory_allocated() + max_memory_allocated = torch.cuda.max_memory_allocated() + current_memory_allocated_gb = current_memory_allocated / (1024 ** 3) + max_memory_allocated_gb = max_memory_allocated / (1024 ** 3) + else: + current_memory_allocated_gb = 0.0 + max_memory_allocated_gb = 0.0 + + # 构建返回的损失字典 + return_loss_dict = { + 'Current_GPU': current_memory_allocated_gb, + 'Max_GPU': max_memory_allocated_gb, + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self.collect_epsilon, + 'cur_lr_world_model': self._optimizer.param_groups[0]['lr'], + 'weighted_total_loss': weighted_total_loss.item(), + 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), + } + + # print(f'self.task_id:{self.task_id}') + # 生成任务相关的损失字典,并为每个任务相关的 loss 添加前缀 "noreduce_" + multi_task_loss_dicts = { + **generate_task_loss_dict(consistency_loss_multi_task, 'noreduce_consistency_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(lambd_multi_task, 'noreduce_lambd_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id), + } + + # 合并两个字典 + return_loss_dict.update(multi_task_loss_dicts) + + # 返回最终的损失字典 + return return_loss_dict + + + def _monitor_vars_learn(self, num_tasks: int = None) -> List[str]: + """ + 
概述: + 注册学习模式中需要监控的变量。注册的变量将根据 `_forward_learn` 的返回值记录到tensorboard。 + 如果提供了 `num_tasks`,则为每个任务生成监控变量。 + 参数: + - num_tasks (:obj:`int`, 可选): 任务数量。 + 返回: + - monitored_vars (:obj:`List[str]`): 需要监控的变量列表。 + """ + # 基本监控变量 + monitored_vars = [ + 'Current_GPU', + 'Max_GPU', + 'collect_epsilon', + 'collect_mcts_temperature', + 'cur_lr_world_model', + 'weighted_total_loss', + 'total_grad_norm_before_clip_wm', + ] + + # 任务特定的监控变量 + task_specific_vars = [ + 'noreduce_consistency_loss', + 'noreduce_reward_loss', + 'noreduce_policy_loss', + 'noreduce_value_loss', + 'noreduce_policy_entropy', + 'noreduce_lambd', + 'noreduce_value_priority', + 'noreduce_value_priority_mean', + ] + # self.task_num_for_current_rank 作为当前rank的base_index + num_tasks = self.task_num_for_current_rank + print(f'self.task_num_for_current_rank: {self.task_num_for_current_rank}') + if num_tasks is not None: + for var in task_specific_vars: + for task_idx in range(num_tasks): + monitored_vars.append(f'{var}_task{self.task_id + task_idx}') + else: + monitored_vars.extend(task_specific_vars) + + return monitored_vars + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. + """ + self._collect_model = self._model + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self._collect_mcts_temperature = 1. + self.collect_epsilon = 0.0 + if self._cfg.model.model_type == 'conv_context': + self.last_batch_obs = torch.zeros([8, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(8)] + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + task_id: int = None, + ) -> Dict: + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - epsilon (:obj:`float`): The epsilon of the eps greedy exploration. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - epsilon: :math:`(1, )`. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
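+ .. note::
+ The additional ``task_id`` argument selects the task-specific components of the multi-task model
+ during inference and MCTS search; it is assumed that ``None`` falls back to single-task behaviour.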
+ """ + self._collect_model.eval() + self._collect_mcts_temperature = temperature + self.collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + if self._cfg.model.model_type in ["conv", "mlp"]: + network_output = self._collect_model.initial_inference(data, task_id=task_id) + elif self._cfg.model.model_type == "conv_context": + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, + data, task_id=task_id) + + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + if not self._cfg.collect_with_pure_policy: + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, task_id=task_id) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self.collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + if self._cfg.model.model_type in ["conv_context"]: + batch_action.append(action) + + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs = data + self.last_batch_action = batch_action + else: + for i, env_id in enumerate(ready_env_id): + policy_values = torch.softmax(torch.tensor([policy_logits[i][a] for a in legal_actions[i]]), + dim=0).tolist() + policy_values = policy_values / np.sum(policy_values) + action_index_in_legal_action_set = np.random.choice(len(legal_actions[i]), p=policy_values) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'searched_value': pred_values[i], + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + + return output + + def _get_target_obs_index_in_step_k(self, step): + """ + Overview: + Get the begin index and end index of the target obs in step k. + Arguments: + - step (:obj:`int`): The current step k. + Returns: + - beg_index (:obj:`int`): The begin index of the target obs in step k. + - end_index (:obj:`int`): The end index of the target obs in step k. + Examples: + >>> self._cfg.model.model_type = 'conv' + >>> self._cfg.model.image_channel = 3 + >>> self._cfg.model.frame_stack_num = 4 + >>> self._get_target_obs_index_in_step_k(0) + >>> (0, 12) + """ + if self._cfg.model.model_type in ['conv', 'conv_context']: + beg_index = self._cfg.model.image_channel * step + end_index = self._cfg.model.image_channel * (step + self._cfg.model.frame_stack_num) + elif self._cfg.model.model_type in ['mlp', 'mlp_context']: + beg_index = self._cfg.model.observation_shape * step + end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) + return beg_index, end_index + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. + """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + if self._cfg.model.model_type == 'conv_context': + self.last_batch_obs = torch.zeros([3, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(3)] + # elif self._cfg.model.model_type == 'mlp_context': + # self.last_batch_obs = torch.zeros([3, self._cfg.model.observation_shape]).to(self._cfg.device) + # self.last_batch_action = [-1 for _ in range(3)] + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, task_id: int = None) -> Dict: + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. 
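+ - task_id (:obj:`int`): The task id used to select the task-specific components of the multi-task model. Default is None.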
+ Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + if self._cfg.model.model_type in ["conv", "mlp"]: + network_output = self._eval_model.initial_inference(data, task_id=task_id) + elif self._cfg.model.model_type == "conv_context": + network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + if not self._eval_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, task_id=task_id) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than + # sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + if self._cfg.model.model_type in ["conv_context"]: + batch_action.append(action) + + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs = data + self.last_batch_action = batch_action + + return output + diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index cf95c46d9..2976388ea 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -155,7 +155,7 @@ class UniZeroPolicy(MuZeroPolicy): # (bool) Whether to use the pure policy to collect data. collect_with_pure_policy=False, # (int) The evaluation frequency. - eval_freq=int(2e3), + eval_freq=int(5e3), # (str) The sample type. Options are ['episode', 'transition']. sample_type='transition', # ****** observation ****** @@ -563,7 +563,8 @@ def _forward_collect( temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, - ready_env_id: np.array = None + ready_env_id: np.array = None, + task_id: int = None, ) -> Dict: """ Overview: @@ -575,6 +576,7 @@ def _forward_collect( - temperature (:obj:`float`): The temperature of the policy. - to_play (:obj:`int`): The player to play. - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + - task_id (:obj:`int`): The task id. Default is None, which means UniZero is in the single-task mode. Shape: - data (:obj:`torch.Tensor`): - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ @@ -697,7 +699,7 @@ def _init_eval(self) -> None: self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, - ready_env_id: np.array = None) -> Dict: + ready_env_id: np.array = None, task_id: int = None,) -> Dict: """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. @@ -707,6 +709,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - to_play (:obj:`int`): The player to play. - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + - task_id (:obj:`int`): The task id. Default is None, which means UniZero is in the single-task mode. 
Shape: - data (:obj:`torch.Tensor`): - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py new file mode 100644 index 000000000..de16eb379 --- /dev/null +++ b/lzero/policy/unizero_multitask.py @@ -0,0 +1,1231 @@ +import copy +from collections import defaultdict +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY + +from lzero.entry.utils import initialize_zeros_batch +from lzero.mcts import UniZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.policy import prepare_obs_stack4_for_unizero +from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs +from lzero.policy.unizero import UniZeroPolicy +from .utils import configure_optimizers_nanogpt + + +# sys.path.append('/Users/puyuan/code/LibMTL/') +# from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect +# from LibMTL.weighting.CAGrad_unizero import CAGrad as GradCorrect + +# from LibMTL.weighting.abstract_weighting import AbsWeighting + + +def generate_task_loss_dict(multi_task_losses, task_name_template, task_id): + """ + 生成每个任务的损失字典 + :param multi_task_losses: 包含每个任务损失的列表 + :param task_name_template: 任务名称模板,例如 'obs_loss_task{}' + :return: 一个字典,包含每个任务的损失 + """ + task_loss_dict = {} + for task_idx, task_loss in enumerate(multi_task_losses): + task_name = task_name_template.format(task_idx + task_id) + try: + task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss + except Exception as e: + task_loss_dict[task_name] = task_loss + return task_loss_dict + + + +class WrappedModel: + def __init__(self, world_model): + self.world_model = world_model + + def parameters(self): + # 返回 tokenizer, transformer 以及所有嵌入层的参数 + return self.world_model.parameters() + + def zero_grad(self, set_to_none=False): + # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + self.world_model.zero_grad(set_to_none=set_to_none) + + +class WrappedModelV2: + def __init__(self, tokenizer, transformer, pos_emb, task_emb, act_embedding_table): + self.tokenizer = tokenizer + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self): + # 返回 tokenizer, transformer 以及所有嵌入层的参数 + return (list(self.tokenizer.parameters()) + + list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + list(self.task_emb.parameters()) + + list(self.act_embedding_table.parameters())) + + def zero_grad(self, set_to_none=False): + # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + +class WrappedModelV3: + def __init__(self, transformer, pos_emb, task_emb, act_embedding_table): + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self): + # 返回 tokenizer, transformer 以及所有嵌入层的参数 + return (list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + list(self.task_emb.parameters()) + + 
list(self.act_embedding_table.parameters())) + + def zero_grad(self, set_to_none=False): + # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + # self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + + +@POLICY_REGISTRY.register('unizero_multitask') +class UniZeroMTPolicy(UniZeroPolicy): + """ + Overview: + The policy class for UniZero, official implementation for paper UniZero: Generalized and Efficient Planning + with Scalable LatentWorld Models. UniZero aims to enhance the planning capabilities of reinforcement learning agents + by addressing the limitations found in MuZero-style algorithms, particularly in environments requiring the + capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. + """ + + # The default_config for UniZero policy. + config = dict( + type='unizero_multitask', + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The obs shape. + observation_shape=(3, 64, 64), + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=True, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=3, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (int) The number of res blocks in MuZero model. + num_res_blocks=1, + # (int) The number of channels of hidden states in MuZero model. + num_channels=64, + # (int) The scale of supports used in categorical distribution. + # This variable is only effective when ``categorical_distribution=True``. + support_scale=50, + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (bool) whether to use res connection in dynamics. + res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'BN'. + norm_type='LN', # NOTE: TODO + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (int) The save interval of the model. + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), + world_model_cfg=dict( + # (int) The number of tokens per block. + tokens_per_block=2, + # (int) The maximum number of blocks. + max_blocks=10, + # (int) The maximum number of tokens, calculated as tokens per block multiplied by max blocks. + max_tokens=2 * 10, + # (int) The context length, usually calculated as twice the number of some base unit. + context_length=2 * 4, + # (bool) Whether to use GRU gating mechanism. + gru_gating=False, + # (str) The device to be used for computation, e.g., 'cpu' or 'cuda'. + device='cpu', + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to analyze dormant ratio. + analysis_dormant_ratio=False, + # (int) The shape of the action space. + action_space_size=6, + # (int) The size of the group, related to simulation normalization. + group_size=8, # NOTE: sim_norm + # (str) The type of attention mechanism used. Options could be ['causal']. 
+ attention='causal', + # (int) The number of layers in the model. + num_layers=2, + # (int) The number of attention heads. + num_heads=8, + # (int) The dimension of the embedding. + embed_dim=768, + # (float) The dropout probability for the embedding layer. + embed_pdrop=0.1, + # (float) The dropout probability for the residual connections. + resid_pdrop=0.1, + # (float) The dropout probability for the attention mechanism. + attn_pdrop=0.1, + # (int) The size of the support set for value and reward heads. + support_size=101, + # (int) The maximum size of the cache. + max_cache_size=5000, + # (int) The number of environments. + env_num=8, + # (float) The weight of the latent reconstruction loss. + latent_recon_loss_weight=0., + # (float) The weight of the perceptual loss. + perceptual_loss_weight=0., + # (float) The weight of the policy entropy. + policy_entropy_weight=1e-4, + # (str) The type of loss for predicting latent variables. Options could be ['group_kl', 'mse']. + predict_latent_loss_type='group_kl', + # (str) The type of observation. Options are ['image', 'vector']. + obs_type='image', + # (float) The discount factor for future rewards. + gamma=1, + # (float) The threshold for a dormant neuron. + dormant_threshold=0.025, + ), + ), + # ****** common ****** + # (bool) whether to use rnd model. + use_rnd_model=False, + # (bool) Whether to use multi-gpu training. + multi_gpu=True, + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) + gumbel_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. Options are ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space']. + action_type='fixed_action_space', + # (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=400, + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to use the pure policy to collect data. + collect_with_pure_policy=False, + # (int) The evaluation frequency. + eval_freq=int(5e3), + # (str) The sample type. Options are ['episode', 'transition']. + sample_type='transition', + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. + # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, + # we should set it to True to avoid the influence of the done flag. 
+ ignore_done=False, + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + replay_ratio=0.25, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. + optim_type='AdamW', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.0001, + # (int) Frequency of hard target network update. + target_update_freq=100, + # (int) Frequency of soft target network update. + target_update_theta=0.05, + # (int) Frequency of target network update. + target_update_freq_for_intrinsic_reward=1000, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=5, + # (int) The number of episodes in each collecting stage when use muzero_collector. + n_episode=8, + # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. + num_segments=8, + # (int) the number of simulations in MCTS. + num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=10, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=0, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + lr_piecewise_constant_decay=False, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (bool) Whether to use manually decayed temperature. + manual_temperature_decay=False, + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. + use_ture_chance_label_in_chance_encoder=False, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=False, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. 
+ priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + # (int) The initial Env Steps for training. + train_start_after_envsteps=int(0), + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + # (int) The number of episodes to collect data randomly before training. + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + eps_greedy_exploration_in_collect=False, + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. + end=0.05, + # (int) The decay steps from start to end eps. + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting for demonstration. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + .. note:: + The user can define and use customized network model but must obey the same interface definition indicated \ + by import_names path. For MuZero, ``lzero.model.unizero_model.MuZeroModel`` + """ + # NOTE: multi-task model + return 'UniZeroMTModel', ['lzero.model.unizero_model_multitask'] + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. + """ + # NOTE: nanoGPT optimizer + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + # Ensure that the installed torch version is greater than or equal to 2.0 + assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "We need torch version >= 2.0" + self._model = torch.compile(self._model) + self._target_model = torch.compile(self._target_model) + # NOTE: soft target + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + self.intermediate_losses = defaultdict(float) + self.l2_norm_before = 0. + self.l2_norm_after = 0. 
+ self.grad_norm_before = 0.
+ self.grad_norm_after = 0.
+
+ # Create the wrapped model used for (optional) multi-task gradient correction.
+ # Option 1: share all parameters, i.e. every parameter takes part in gradient correction.
+ # wrapped_model = WrappedModel(
+ # self._learn_model.world_model,
+ # )
+
+ # Option 2: the prediction heads are excluded from gradient correction.
+ wrapped_model = WrappedModelV2(
+ # self._learn_model.world_model.tokenizer, # TODO:
+ self._learn_model.world_model.tokenizer.encoder[0], # TODO: one encoder
+ self._learn_model.world_model.transformer,
+ self._learn_model.world_model.pos_emb,
+ self._learn_model.world_model.task_emb,
+ self._learn_model.world_model.act_embedding_table,
+ )
+
+ # Option 3: both the heads and tokenizer.encoder are excluded from gradient correction.
+ # wrapped_model = WrappedModelV3(
+ # self._learn_model.world_model.transformer,
+ # self._learn_model.world_model.pos_emb,
+ # self._learn_model.world_model.task_emb,
+ # self._learn_model.world_model.act_embedding_table,
+ # )
+
+ # Pass wrapped_model to GradCorrect as the shared model.
+ # ========= Initialize the MoCo / CAGrad parameters =========
+ # self.grad_correct = GradCorrect(wrapped_model, self.task_num, self._cfg.device)
+ # self.grad_correct.init_param()
+ # self.grad_correct.rep_grad = False
+
+ self.task_id = self._cfg.task_id
+ self.task_num_for_current_rank = self._cfg.task_num
+
+
+ #@profile
+ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
+ """
+ Overview:
+ The forward function for learning policy in learn mode, which is the core of the learning process.
+ The data is sampled from the replay buffers of all tasks handled by this rank; the loss is
+ computed per task, summed, and backpropagated to update the shared model.
+ Arguments:
+ - data (:obj:`Tuple[torch.Tensor]`): A sequence with one entry per task; each entry is a tuple of \
+ (current_batch, target_batch, task_id) sampled from that task's replay buffer.
+ Returns:
+ - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \
+ current learning loss and learning statistics.
+ """ + self._learn_model.train() + self._target_model.train() + + obs_loss_multi_task = [] + reward_loss_multi_task = [] + policy_loss_multi_task = [] + value_loss_multi_task = [] + latent_recon_loss_multi_task = [] + perceptual_loss_multi_task = [] + orig_policy_loss_multi_task = [] + policy_entropy_multi_task = [] + weighted_total_loss = 0.0 # 初始化为0,避免使用in-place操作 + + latent_state_l2_norms_multi_task = [] + average_target_policy_entropy_multi_task = [] + value_priority_multi_task = [] + value_priority_mean_multi_task = [] + + + losses_list = [] # 用于存储每个任务的损失 + for task_id, data_one_task in enumerate(data): + current_batch, target_batch, task_id = data_one_task + # current_batch, target_batch, _ = data + obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time = current_batch + target_reward, target_value, target_policy = target_batch + + # Prepare observations based on frame stack number + if self._cfg.model.frame_stack_num == 4: + obs_batch, obs_target_batch = prepare_obs_stack4_for_unizero(obs_batch_ori, self._cfg) + else: + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # Apply augmentations if needed + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # Prepare action batch and convert to torch tensor + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze( + -1).long() # For discrete action space + data_list = [mask_batch, target_reward.astype('float32'), target_value.astype('float32'), target_policy, + weights] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list, + self._cfg.device) + + + # rank = get_rank() + # print(f'Rank {rank}: cfg.policy.task_id : {self._cfg.task_id}, self._cfg.batch_size {self._cfg.batch_size}') + + target_reward = target_reward.view(self._cfg.batch_size[task_id], -1) + target_value = target_value.view(self._cfg.batch_size[task_id], -1) + + target_reward = target_reward.view(self._cfg.batch_size[task_id], -1) + target_value = target_value.view(self._cfg.batch_size[task_id], -1) + + # assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) + + # Transform rewards and values to their scaled forms + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # Convert to categorical distributions + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # Prepare batch for a transformer-based world model + batch_for_gpt = {} + if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape) == 1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size[task_id], -1, self._cfg.model.observation_shape) + elif len(self._cfg.model.observation_shape) == 3: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size[task_id], -1, *self._cfg.model.observation_shape) + + batch_for_gpt['actions'] = action_batch.squeeze(-1) + batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] + batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data + batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] 
+ batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] + batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, + device=self._cfg.device) + batch_for_gpt['target_value'] = target_value_categorical[:, :-1] + batch_for_gpt['target_policy'] = target_policy[:, :-1] + + # Extract valid target policy data and compute entropy + valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] + target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) + average_target_policy_entropy = target_policy_entropy.mean().item() + + # Update world model + intermediate_losses = defaultdict(float) + losses = self._learn_model.world_model.compute_loss( + batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, task_id=task_id + ) + + weighted_total_loss += losses.loss_total # TODO + + assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" + assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values" + + losses_list.append(losses.loss_total) # TODO: for moco + + for loss_name, loss_value in losses.intermediate_losses.items(): + intermediate_losses[f"{loss_name}"] = loss_value + + obs_loss = intermediate_losses['loss_obs'] + reward_loss = intermediate_losses['loss_rewards'] + policy_loss = intermediate_losses['loss_policy'] + orig_policy_loss = intermediate_losses['orig_policy_loss'] + policy_entropy = intermediate_losses['policy_entropy'] + value_loss = intermediate_losses['loss_value'] + latent_recon_loss = intermediate_losses['latent_recon_loss'] + perceptual_loss = intermediate_losses['perceptual_loss'] + latent_state_l2_norms = intermediate_losses['latent_state_l2_norms'] + # value_priority = intermediate_losses['value_priority'] + # logits_value = intermediate_losses['logits_value'] + + # print(f'logits_value:" {logits_value}') + # print(f'logits_value.shape:" {logits_value.shape}') + # print(f"batch_for_gpt['observations'].shape: {batch_for_gpt['observations'].shape}") + + # ============ for value priority ============ + # transform the categorical representation of the scaled value to its original value + # original_value = self.inverse_scalar_transform_handle(logits_value.reshape(-1, 101)).reshape( + # batch_for_gpt['observations'].shape[0], batch_for_gpt['observations'].shape[1], 1) + # calculate the new priorities for each transition. 
+            # value_priority = torch.nn.L1Loss(reduction='none')(original_value.squeeze(-1)[:,0], target_value[:, 0]) # TODO: mix of mean and sum
+            # value_priority = value_priority.data.cpu().numpy() + 1e-6 # TODO: log-reduce not support array now
+            value_priority = torch.tensor(0., device=self._cfg.device)
+            # ============ for value priority ============
+
+            obs_loss_multi_task.append(obs_loss)
+            reward_loss_multi_task.append(reward_loss)
+            policy_loss_multi_task.append(policy_loss)
+            orig_policy_loss_multi_task.append(orig_policy_loss)
+            policy_entropy_multi_task.append(policy_entropy)
+            value_loss_multi_task.append(value_loss)
+            latent_recon_loss_multi_task.append(latent_recon_loss)
+            perceptual_loss_multi_task.append(perceptual_loss)
+            latent_state_l2_norms_multi_task.append(latent_state_l2_norms)
+            average_target_policy_entropy_multi_task.append(average_target_policy_entropy)
+            value_priority_multi_task.append(value_priority)
+            value_priority_mean_multi_task.append(value_priority.mean().item())
+
+        # Core learn model update step
+        self._optimizer_world_model.zero_grad()
+
+        # TODO: use MoCo or CAGrad to compute per-task gradients and weights
+        # ============= for CAGrad and MoCo =============
+        # lambd = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
+
+        # ============= TODO: the case without gradient correction =============
+        lambd = torch.tensor([0. for i in range(self.task_num_for_current_rank)], device=self._cfg.device)
+        weighted_total_loss.backward()
+
+        # ========== for debugging ==========
+        # for name, param in self._learn_model.world_model.tokenizer.encoder.named_parameters():
+        #     print('name, param.mean(), param.std():', name, param.mean(), param.std())
+        #     if param.requires_grad:
+        #         print(name, param.grad.norm())
+
+        if self._cfg.analysis_sim_norm:
+            del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after
+            self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze()
+            self._target_model.encoder_hook.clear_data()
+
+        total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(),
+                                                                        self._cfg.grad_clip_value)
+
+        if self._cfg.multi_gpu:
+            # Very important to sync gradients before updating the model
+            # rank = get_rank()
+            # print(f'Rank {rank} train task_id: {self._cfg.task_id} sync grad begin...')
+            self.sync_gradients(self._learn_model)
+            # print(f'Rank {rank} train task_id: {self._cfg.task_id} sync grad end...')
+
+        self._optimizer_world_model.step()
+        if self._cfg.lr_piecewise_constant_decay:
+            self.lr_scheduler.step()
+
+        # Core target model update step
+        self._target_model.update(self._learn_model.state_dict())
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+            current_memory_allocated = torch.cuda.memory_allocated()
+            max_memory_allocated = torch.cuda.max_memory_allocated()
+            current_memory_allocated_gb = current_memory_allocated / (1024 ** 3)
+            max_memory_allocated_gb = max_memory_allocated / (1024 ** 3)
+        else:
+            current_memory_allocated_gb = 0.
+            max_memory_allocated_gb = 0.
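+        # Illustration of the assumed behaviour of the ``generate_task_loss_dict`` helper used below: with
+        # ``self.task_id == 2`` and two tasks on this rank,
+        #   generate_task_loss_dict([a, b], 'noreduce_obs_loss_task{}', task_id=2)
+        # is expected to return {'noreduce_obs_loss_task2': a, 'noreduce_obs_loss_task3': b}, i.e. the i-th
+        # entry is logged under the globally indexed task key.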
+
+        # Build the loss dict that will be returned to the learner for logging.
+        return_loss_dict = {
+            'Current_GPU': current_memory_allocated_gb,
+            'Max_GPU': max_memory_allocated_gb,
+            'collect_mcts_temperature': self._collect_mcts_temperature,
+            'collect_epsilon': self._collect_epsilon,
+            'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'],
+            'weighted_total_loss': weighted_total_loss.item(),
+            # 'policy_entropy': policy_entropy,
+            # 'target_policy_entropy': average_target_policy_entropy,
+            'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(),
+        }
+
+        # Generate the task-specific loss dict; every task-related loss key is prefixed with "noreduce_".
+        multi_task_loss_dicts = {
+            **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(latent_recon_loss_multi_task, 'noreduce_latent_recon_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(perceptual_loss_multi_task, 'noreduce_perceptual_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'noreduce_latent_state_l2_norms_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(orig_policy_loss_multi_task, 'noreduce_orig_policy_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(average_target_policy_entropy_multi_task, 'noreduce_target_policy_entropy_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(lambd, 'noreduce_lambd_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id),
+        }
+        # Merge the per-task losses into the common loss dict.
+        return_loss_dict.update(multi_task_loss_dicts)
+        # print(f'return_loss_dict:{return_loss_dict}')
+
+        # Return the final loss dict.
+        return return_loss_dict
+
+    def monitor_weights_and_grads(self, model):
+        """Print the mean/std of the weights and gradients of every trainable parameter (debugging helper)."""
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                print(f"Layer: {name} | "
+                      f"Weight mean: {param.data.mean():.4f} | "
+                      f"Weight std: {param.data.std():.4f} | "
+                      f"Grad mean: {param.grad.mean():.4f} | "
+                      f"Grad std: {param.grad.std():.4f}")
+
+    def _init_collect(self) -> None:
+        """
+        Overview:
+            Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils.
+        """
+        self._collect_model = self._model
+
+        if self._cfg.mcts_ctree:
+            self._mcts_collect = MCTSCtree(self._cfg)
+        else:
+            self._mcts_collect = MCTSPtree(self._cfg)
+        self._collect_mcts_temperature = 1.
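+        # The attributes below keep per-environment context for the transformer-based world model:
+        # ``last_batch_obs`` and ``last_batch_action`` are passed to ``initial_inference`` together with the
+        # current observation in ``_forward_collect`` and are re-initialized in ``_reset_collect``.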
+        self._collect_epsilon = 0.0
+        self.collector_env_num = self._cfg.collector_env_num
+        if self._cfg.model.model_type == 'conv':
+            self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device)
+            self.last_batch_action = [-1 for i in range(self.collector_env_num)]
+        elif self._cfg.model.model_type == 'mlp':
+            self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device)
+            self.last_batch_action = [-1 for i in range(self.collector_env_num)]
+
+    # TODO: num_tasks
+    def _monitor_vars_learn(self, num_tasks=2) -> List[str]:
+        """
+        Overview:
+            Register the variables to be monitored in learn mode. The registered variables will be logged in
+            tensorboard according to the return value of ``_forward_learn``.
+            If num_tasks is provided, task-specific monitored variables are generated for each task.
+        """
+        # Basic monitored variables that do not depend on the number of tasks
+        monitored_vars = [
+            'Current_GPU',
+            'Max_GPU',
+            'collect_epsilon',
+            'collect_mcts_temperature',
+            'cur_lr_world_model',
+            'weighted_total_loss',
+            'total_grad_norm_before_clip_wm',
+        ]
+
+        # rank = get_rank()
+        task_specific_vars = [
+            'noreduce_obs_loss',
+            'noreduce_orig_policy_loss',
+            'noreduce_policy_loss',
+            'noreduce_latent_recon_loss',
+            'noreduce_policy_entropy',
+            'noreduce_target_policy_entropy',
+            'noreduce_reward_loss',
+            'noreduce_value_loss',
+            'noreduce_perceptual_loss',
+            'noreduce_latent_state_l2_norms',
+            'noreduce_lambd',
+            'noreduce_value_priority_mean',
+        ]
+        # self.task_num_for_current_rank is the number of tasks handled by the current rank;
+        # self.task_id is the global index of the first task on this rank.
+        num_tasks = self.task_num_for_current_rank
+        # If the number of tasks is provided, extend the monitored variables list with task-specific variables
+        if num_tasks is not None:
+            for var in task_specific_vars:
+                for task_idx in range(num_tasks):
+                    # print(f"learner policy Rank {rank}, self.task_id: {self.task_id}")
+                    monitored_vars.append(f'{var}_task{self.task_id+task_idx}')
+        else:
+            # If num_tasks is not provided, we assume there's only one task and keep the original variable names
+            monitored_vars.extend(task_specific_vars)
+
+        return monitored_vars
+
+    #@profile
+    def _forward_collect(
+            self,
+            data: torch.Tensor,
+            action_mask: list = None,
+            temperature: float = 1,
+            to_play: List = [-1],
+            epsilon: float = 0.25,
+            ready_env_id: np.array = None,
+            task_id: int = None,
+    ) -> Dict:
+        """
+        Overview:
+            The forward function for collecting data in collect mode. Use model to execute MCTS search.
+            The action is chosen by sampling during collect mode.
+        Arguments:
+            - data (:obj:`torch.Tensor`): The input data, i.e. the observation.
+            - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected.
+            - temperature (:obj:`float`): The temperature of the policy.
+            - to_play (:obj:`int`): The player to play.
+            - epsilon (:obj:`float`): The epsilon used for epsilon-greedy exploration during collection.
+            - ready_env_id (:obj:`list`): The id of the env that is ready to collect.
+            - task_id (:obj:`int`): The id of the task the collected data belongs to; None means single-task mode.
+        Shape:
+            - data (:obj:`torch.Tensor`):
+                - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \
+                    S is the number of stacked frames, H is the height of the image, W is the width of the image.
+                - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size.
+            - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env.
+            - temperature: :math:`(1, )`.
+            - to_play: :math:`(N, 1)`, where N is the number of collect_env.
+ - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._collect_model.eval() + + self._collect_mcts_temperature = temperature + self._collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + + with torch.no_grad(): + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, task_id=task_id) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self._collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + # ============== TODO: only for visualize ============== + # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + # distributions, temperature=self._collect_mcts_temperature, deterministic=True + # ) + # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + # ============== TODO: only for visualize ============== + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action + + # ========= TODO: for muzero_segment_collector now ========= + if active_collect_env_num < self.collector_env_num: + print('==========collect_forward============') + print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') + self._reset_collect(reset_init_data=True) + + return output + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. + """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + self.evaluator_env_num = self._cfg.evaluator_env_num + + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + + #@profile + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, task_id: int = None) -> Dict: + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
+ """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + if not self._eval_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, task_id=task_id) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # print("roots_visit_count_distributions:", distributions, "root_value:", value) + + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than + # sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action + + return output + + #@profile + def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True) -> None: + """ + Overview: + This method resets the collection process for a specific environment. It clears caches and memory + when certain conditions are met, ensuring optimal performance. If reset_init_data is True, the initial data + will be reset. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None or list, the function returns immediately. + - current_steps (:obj:`int`, optional): The current step count in the environment. Used to determine + whether to clear caches. + - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. 
+ """ + if reset_init_data: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.collector_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + # print('collector: last_batch_obs, last_batch_action reset()', self.last_batch_obs.shape) + + # Return immediately if env_id is None or a list + if env_id is None or isinstance(env_id, list): + return + + # Determine the clear interval based on the environment's sample type + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + + # Clear caches if the current steps are a multiple of the clear interval + if current_steps % clear_interval == 0: + print(f'clear_interval: {clear_interval}') + + # Clear various caches in the collect model's world model + world_model = self._collect_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up GPU memory + torch.cuda.empty_cache() + + print('collector: collect_model clear()') + print(f'eps_steps_lst[{env_id}]: {current_steps}') + + # TODO: check its correctness ========= + self._reset_target_model() + + #@profile + def _reset_target_model(self) -> None: + """ + Overview: + This method resets the target model. It clears caches and memory, ensuring optimal performance. + Arguments: + - None + """ + + # Clear various caches in the target_model + world_model = self._target_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up GPU memory + torch.cuda.empty_cache() + print('collector: target_model past_kv_cache.clear()') + + #@profile + def _reset_eval(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True) -> None: + """ + Overview: + This method resets the evaluation process for a specific environment. It clears caches and memory + when certain conditions are met, ensuring optimal performance. If reset_init_data is True, + the initial data will be reset. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None or list, the function returns immediately. + - current_steps (:obj:`int`, optional): The current step count in the environment. Used to determine + whether to clear caches. + - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. 
+ """ + if reset_init_data: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + # print('evaluator: last_batch_obs, last_batch_action reset()', self.last_batch_obs.shape) + + # Return immediately if env_id is None or a list + if env_id is None or isinstance(env_id, list): + return + + # Determine the clear interval based on the environment's sample type + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + + # Clear caches if the current steps are a multiple of the clear interval + if current_steps % clear_interval == 0: + print(f'clear_interval: {clear_interval}') + + # Clear various caches in the eval model's world model + world_model = self._eval_model.world_model + # world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up GPU memory + torch.cuda.empty_cache() + + print('evaluator: eval_model clear()') + print(f'eps_steps_lst[{env_id}]: {current_steps}') + + + def recompute_pos_emb_diff_and_clear_cache(self) -> None: + """ + Overview: + Clear the caches and precompute positional embedding matrices in the model. + """ + # NOTE: Clear caches and precompute positional embedding matrices both for the collect and target models + for model in [self._collect_model, self._target_model]: + model.world_model.precompute_pos_emb_diff_kv() + model.world_model.clear_caches() + torch.cuda.empty_cache() + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model, target_model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer_world_model': self._optimizer_world_model.state_dict(), + } + + # ========== TODO: original version: load all parameters ========== + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + """ + self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + + # # ========== TODO: pretrain-finetue version: only load encoder and transformer-backbone parameters, head use re init weight ========== + # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + # """ + # Overview: + # Load the state_dict variable into policy learn mode, excluding multi-task related parameters. + # Arguments: + # - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously. + # """ + # # 定义需要排除的参数前缀 + # exclude_prefixes = [ + # '_orig_mod.world_model.head_policy_multi_task.', + # '_orig_mod.world_model.head_value_multi_task.', + # '_orig_mod.world_model.head_rewards_multi_task.', + # '_orig_mod.world_model.head_observations_multi_task.', + # '_orig_mod.world_model.task_emb.' 
+ # ] + + # # 定义需要排除的具体参数(如果有特殊情况) + # exclude_keys = [ + # '_orig_mod.world_model.task_emb.weight', + # '_orig_mod.world_model.task_emb.bias', # 如果存在则添加 + # # 添加其他需要排除的具体参数名 + # ] + + # def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: + # """ + # 过滤掉需要排除的参数。 + # """ + # filtered = {} + # for k, v in state_dict_loader.items(): + # if any(k.startswith(prefix) for prefix in exclude_prefixes): + # print(f"Excluding parameter: {k}") # 调试用,查看哪些参数被排除 + # continue + # if k in exclude_keys: + # print(f"Excluding specific parameter: {k}") # 调试用 + # continue + # filtered[k] = v + # return filtered + + # # 过滤并加载 'model' 部分 + # if 'model' in state_dict: + # model_state_dict = state_dict['model'] + # filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) + # missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) + # if missing_keys: + # print(f"Missing keys when loading _learn_model: {missing_keys}") + # if unexpected_keys: + # print(f"Unexpected keys when loading _learn_model: {unexpected_keys}") + # else: + # print("No 'model' key found in the state_dict.") + + # # 过滤并加载 'target_model' 部分 + # if 'target_model' in state_dict: + # target_model_state_dict = state_dict['target_model'] + # filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) + # missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) + # if missing_keys: + # print(f"Missing keys when loading _target_model: {missing_keys}") + # if unexpected_keys: + # print(f"Unexpected keys when loading _target_model: {unexpected_keys}") + # else: + # print("No 'target_model' key found in the state_dict.") + + # # 加载优化器的 state_dict,不需要过滤,因为优化器通常不包含模型参数 + # if 'optimizer_world_model' in state_dict: + # optimizer_state_dict = state_dict['optimizer_world_model'] + # try: + # self._optimizer_world_model.load_state_dict(optimizer_state_dict) + # except Exception as e: + # print(f"Error loading optimizer state_dict: {e}") + # else: + # print("No 'optimizer_world_model' key found in the state_dict.") + + # # 如果需要,还可以加载其他部分,例如 scheduler 等 \ No newline at end of file diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index eff413df6..3f47c2eb4 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -41,6 +41,7 @@ def __init__( exp_name: Optional[str] = 'default_experiment', instance_name: Optional[str] = 'collector', policy_config: 'policy_config' = None, # noqa + task_id: int = None, ) -> None: """ Overview: @@ -53,7 +54,9 @@ def __init__( - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. - instance_name (:obj:`str`): Unique identifier for this collector instance. - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. + - task_id (:obj:`int`): Unique identifier for the task. If None, that means we are in the single task mode. 
""" + self.task_id = task_id self._exp_name = exp_name self._instance_name = instance_name self._collect_print_freq = collect_print_freq @@ -267,6 +270,7 @@ def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegm end_index = beg_index + self.unroll_plus_td_steps - 1 pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] + if self.policy_config.use_ture_chance_label_in_chance_encoder: chance_lst = game_segments[i].chance_segment[beg_index:end_index] @@ -293,7 +297,7 @@ def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegm game_segment element shape: obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 - action: game_segment_length -> 20 + action: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 child_visits: game_segment_length + num_unroll_steps -> 20 +5 to_play: game_segment_length -> 20 @@ -434,8 +438,13 @@ def collect(self, # Key policy forward step # ============================================================== # print(f'ready_env_id:{ready_env_id}') - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) - + # policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) + if self.task_id is None: + # single task setting + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) + else: + # multi-task setting + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, task_id=self.task_id) # Extract relevant policy outputs actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} @@ -554,9 +563,9 @@ def collect(self, completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) eps_steps_lst[env_id] += 1 - if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - # only for UniZero now - self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero', 'unizero_multitask']: + # TODO: only for UniZero now + self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) # NOTE: reset_init_data=False total_transitions += 1 @@ -774,10 +783,16 @@ def _output_log(self, train_iter: int) -> None: for k, v in info.items(): if k in ['each_reward']: continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + if self.task_id is None: + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + else: + self._tb_logger.add_scalar('{}_iter_task{}/'.format(self._instance_name, self.task_id) + k, v, train_iter) if k in ['total_envstep_count']: continue - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) + if self.task_id is None: + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) + else: + self._tb_logger.add_scalar('{}_step_task{}/'.format(self._instance_name, self.task_id) + k, v, self._total_envstep_count) if self.policy_config.use_wandb: wandb.log({'{}_step/'.format(self._instance_name) 
+ k: v for k, v in info.items()}, step=self._total_envstep_count) diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index f7cc39047..d88b9a221 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -15,6 +15,7 @@ from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation +import threading class MuZeroEvaluator(ISerialEvaluator): @@ -56,6 +57,7 @@ def __init__( exp_name: Optional[str] = 'default_experiment', instance_name: Optional[str] = 'evaluator', policy_config: 'policy_config' = None, # noqa + task_id: int = None, ) -> None: """ Overview: @@ -70,7 +72,10 @@ def __init__( - exp_name (:obj:`str`): Name of the experiment, used to determine output directory. - instance_name (:obj:`str`): Name of this evaluator instance. - policy_config (:obj:`Optional[dict]`): Optional configuration for the game policy. + - task_id (:obj:`int`): Unique identifier for the task. If None, that means we are in the single task mode. """ + self.stop_event = threading.Event() # Add stop event to handle timeouts + self.task_id = task_id self._eval_freq = eval_freq self._exp_name = exp_name self._instance_name = instance_name @@ -88,7 +93,19 @@ def __init__( './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name ) else: - self._logger, self._tb_logger = None, None # for close elegantly + # self._logger, self._tb_logger = None, None # for close elegantly + # ========== TODO: unizero_multitask ddp_v2 ======== + if tb_logger is not None: + self._logger, _ = build_logger( + './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False + ) + self._tb_logger = tb_logger + + + self._rank = get_rank() + + print(f'rank {self._rank}, self.task_id: {self.task_id}') + self.reset(policy, env) @@ -101,6 +118,9 @@ def __init__( # ============================================================== self.policy_config = policy_config + # def stop(self): + # self.stop_event.set() + def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: @@ -210,10 +230,19 @@ def eval( - stop_flag (:obj:`bool`): Indicates whether the training can be stopped based on the stop value. - episode_info (:obj:`Dict[str, Any]`): A dictionary containing information about the evaluation episodes. """ + print(f"=========in eval() Rank {get_rank()} ===========") + device = torch.cuda.current_device() + print(f"当前默认的 GPU 设备编号: {device}") + torch.cuda.set_device(get_rank()) + print(f"set device后的 GPU 设备编号: {get_rank()}") + # the evaluator only works on rank0 episode_info = None stop_flag = False - if get_rank() == 0: + # ======== TODO: unizero_multitask ddp_v2 ======== + # if get_rank() == 0: + if get_rank() >= 0: + if n_episode is None: n_episode = self._default_n_episode assert n_episode is not None, "please indicate eval n_episode" @@ -263,6 +292,12 @@ def eval( eps_steps_lst = np.zeros(env_nums) with self._timer: while not eval_monitor.is_finished(): + + # Check if stop_event is set (timeout occurred) + if self.stop_event.is_set(): + self._logger.info("[EVALUATOR]: Evaluation aborted due to timeout.") + break + # Get current ready env obs. 
obs = self._env.ready_obs new_available_env_id = set(obs.keys()).difference(ready_env_id) @@ -284,7 +319,13 @@ def eval( # ============================================================== # policy forward # ============================================================== - policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id) + # policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id) + if self.task_id is None: + # single task setting + policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id) + else: + # multi task setting + policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, task_id=self.task_id) actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} @@ -426,14 +467,23 @@ def eval( episode_info = eval_monitor.get_episode_info() if episode_info is not None: info.update(episode_info) + + print(f'rank {self._rank}, self.task_id: {self.task_id}') + self._logger.info(self._logger.get_tabulate_vars_hor(info)) for k, v in info.items(): if k in ['train_iter', 'ckpt_name', 'each_reward']: continue if not np.isscalar(v): continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + if self.task_id is None: + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + else: + self._tb_logger.add_scalar('{}_iter_task{}/'.format(self._instance_name, self.task_id) + k, v, + train_iter) + self._tb_logger.add_scalar('{}_step_task{}/'.format(self._instance_name, self.task_id) + k, v, + envstep) if self.policy_config.use_wandb: wandb.log({'{}_step/'.format(self._instance_name) + k: v}, step=envstep) @@ -451,12 +501,16 @@ def eval( ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details." ) - if get_world_size() > 1: - objects = [stop_flag, episode_info] - broadcast_object_list(objects, src=0) - stop_flag, episode_info = objects + # ========== TODO: unizero_multitask ddp_v2 ======== + # if get_world_size() > 1: + # objects = [stop_flag, episode_info] + # print(f'rank {self._rank}, self.task_id: {self.task_id}') + # print('before broadcast_object_list') + # broadcast_object_list(objects, src=0) + # print('evaluator after broadcast_object_list') + # stop_flag, episode_info = objects episode_info = to_item(episode_info) if return_trajectory: episode_info['trajectory'] = game_segments - return stop_flag, episode_info + return stop_flag, episode_info \ No newline at end of file diff --git a/lzero/worker/muzero_segment_collector.py b/lzero/worker/muzero_segment_collector.py index 137d8ec89..325ac4904 100644 --- a/lzero/worker/muzero_segment_collector.py +++ b/lzero/worker/muzero_segment_collector.py @@ -46,19 +46,22 @@ def __init__( exp_name: Optional[str] = 'default_experiment', instance_name: Optional[str] = 'collector', policy_config: 'policy_config' = None, # noqa + task_id: int = None, ) -> None: """ Overview: - Initialize the MuZeroSegmentCollector with the given parameters. + Initialize the MuZeroCollector with the given parameters. 
Arguments: - collect_print_freq (:obj:`int`): Frequency (in training steps) at which to print collection information. - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized environment manager. - - policy (:obj:`Optional[namedtuple]`): Namedtuple of the collection mode policy API. + - policy (:obj:`Optional[namedtuple]`): namedtuple of the collection mode policy API. - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger instance. - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. - instance_name (:obj:`str`): Unique identifier for this collector instance. - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. """ + self.task_id = task_id + self._exp_name = exp_name self._instance_name = instance_name self._collect_print_freq = collect_print_freq @@ -66,6 +69,10 @@ def __init__( self._end_flag = False self._rank = get_rank() + + print(f'rank {self._rank}, self.task_id: {self.task_id}') + + self._world_size = get_world_size() if self._rank == 0: if tb_logger is not None: @@ -83,7 +90,9 @@ def __init__( self._logger, _ = build_logger( path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False ) - self._tb_logger = None + # =========== TODO: for unizero_multitask ddp_v2 ======== + self._tb_logger = tb_logger + self.policy_config = policy_config self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy @@ -460,8 +469,14 @@ def collect(self, # ============================================================== # Key policy forward step # ============================================================== - # logging.info(f'ready_env_id:{ready_env_id}') - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) + # print(f'ready_env_id:{ready_env_id}') + if self.task_id is None: + # single task setting + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) + else: + # multi task setting + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, task_id=self.task_id) + # Extract relevant policy outputs actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} @@ -716,11 +731,13 @@ def collect(self, break collected_duration = sum([d['time'] for d in self._episode_info]) + # TODO: for atari multitask new ddp pipeline # reduce data when enables DDP - if self._world_size > 1: - collected_step = allreduce_data(collected_step, 'sum') - collected_episode = allreduce_data(collected_episode, 'sum') - collected_duration = allreduce_data(collected_duration, 'sum') + # if self._world_size > 1: + # collected_step = allreduce_data(collected_step, 'sum') + # collected_episode = allreduce_data(collected_episode, 'sum') + # collected_duration = allreduce_data(collected_duration, 'sum') + self._total_envstep_count += collected_step self._total_episode_count += collected_episode self._total_duration += collected_duration @@ -736,8 +753,9 @@ def _output_log(self, train_iter: int) -> None: Arguments: - train_iter (:obj:`int`): Current training iteration number for logging context. 
""" - if self._rank != 0: - return + # TODO: for atari multitask new ddp pipeline + # if self._rank != 0: + # return if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: self._last_train_iter = train_iter episode_count = len(self._episode_info) @@ -770,11 +788,20 @@ def _output_log(self, train_iter: int) -> None: if self.policy_config.gumbel_algo: info['completed_value'] = np.mean(completed_value) self._episode_info.clear() + print(f'collector output_log: rank {self._rank}, self.task_id: {self.task_id}') self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) for k, v in info.items(): if k in ['each_reward']: continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + if self.task_id is None: + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + else: + self._tb_logger.add_scalar('{}_iter_task{}/'.format(self._instance_name, self.task_id) + k, v, + train_iter) if k in ['total_envstep_count']: continue - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) \ No newline at end of file + if self.task_id is None: + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) + else: + self._tb_logger.add_scalar('{}_step_task{}/'.format(self._instance_name, self.task_id) + k, v, + self._total_envstep_count) diff --git a/requirements.txt b/requirements.txt index 831ae67c5..beec0c5ba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ moviepy pytest line_profiler xxhash +simple_parsing einops diff --git a/zoo/atari/config/atari_muzero_multitask_segment_8games_config.py b/zoo/atari/config/atari_muzero_multitask_segment_8games_config.py new file mode 100644 index 000000000..ce486a050 --- /dev/null +++ b/zoo/atari/config/atari_muzero_multitask_segment_8games_config.py @@ -0,0 +1,260 @@ +from easydict import EasyDict + +def create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments +): + + return EasyDict(dict( + env=dict( + stop_value=int(5e5), # Adjusted max_env_step based on user TODO + env_id=env_id, + observation_shape=(4, 96, 96), + frame_stack_num=4, + gray_scale=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + full_action_space=True, + # ===== TODO: only for debug ===== + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + learn=dict( + learner=dict( + hook=dict(save_ckpt_after_iter=200000,), # Adjusted checkpoint frequency + ), + ), + grad_correct_params=dict( + # Placeholder for gradient correction parameters if needed + ), + task_num=len(env_id_list), + model=dict( + device='cuda', + num_res_blocks=2, # NOTE: encoder for 4 game + num_channels=256, + reward_head_channels= 16, + value_head_channels= 16, + policy_head_channels= 16, + fc_reward_layers= [32], + fc_value_layers= [32], + fc_policy_layers= [32], + observation_shape=(4, 96, 96), + frame_stack_num=4, + gray_scale=True, + action_space_size=action_space_size, + norm_type=norm_type, + model_type='conv', + image_channel=1, + downsample=True, + 
self_supervised_learning_loss=True, + discrete_action_encoding_type='one_hot', + use_sim_norm=True, + use_sim_norm_kl_loss=False, + task_num=len(env_id_list), + ), + cuda=True, + env_type='not_board_games', + # train_start_after_envsteps=2000, + train_start_after_envsteps=0, + game_segment_length=20, # Fixed segment length as per user config + random_collect_episode_num=0, + use_augmentation=True, + use_priority=False, + replay_ratio=0.25, + num_unroll_steps=num_unroll_steps, + # =========== TODO: debug =========== + # update_per_collect=2, # TODO: debug + update_per_collect=80, # Consistent with UniZero config + batch_size=batch_size, + optim_type='SGD', + td_steps=5, + lr_piecewise_constant_decay=True, + manual_temperature_decay=False, + learning_rate=0.2, + target_update_freq=100, + num_segments=num_segments, + num_simulations=num_simulations, + policy_entropy_weight=5e-3, #TODO + ssl_loss_weight=2, + eval_freq=int(5e3), + replay_buffer_size=int(5e5), # Adjusted as per UniZero config + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs( + env_id_list, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + seed, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments +): + configs = [] + exp_name_prefix = ( + f'data_muzero_mt_8games/{len(env_id_list)}games_brf{buffer_reanalyze_freq}/' + f'{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_' + f'{len(env_id_list)}-pred-head_mbs-512_upc80_H{num_unroll_steps}_seed{seed}/' + ) + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + # collector_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, # TODO: different collector_env_num for Pong and Boxing + # evaluator_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + # n_episode if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments + ) + config.policy.task_id = task_id + config.exp_name = f"{exp_name_prefix}{env_id.split('NoFrameskip')[0]}_muzero-mt_seed{seed}" + + configs.append([task_id, [config, create_env_manager()]]) + + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + # env_manager=dict(type='base'), + policy=dict( + type='muzero_multitask', + import_names=['lzero.policy.muzero_multitask'], + ), + )) + +if __name__ == "__main__": + import sys + sys.path.insert(0, "/mnt/afs/niuyazhe/code/LightZero") + import lzero + print("lzero path:", lzero.__file__) + # import sys + # import os + # # 添加项目根目录到 PYTHONPATH + # sys.path.append(os.path.dirname(os.path.abspath(__file__))) + + from lzero.entry import train_muzero_multitask_segment_noddp + import argparse + + parser = 
argparse.ArgumentParser(description='Train MuZero Multitask on Atari') + parser.add_argument('--seed', type=int, default=0, help='Random seed') + args = parser.parse_args() + + # Define your list of environment IDs + env_id_list = [ + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', + 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', + 'RoadRunnerNoFrameskip-v4', + ] + # env_id_list = [ + # 'PongNoFrameskip-v4', + # 'MsPacmanNoFrameskip-v4', + # ] + + action_space_size = 18 # Full action space, adjust if different per env + seed = args.seed + collector_env_num = 8 + evaluator_env_num = 3 + num_segments = 8 + n_episode = 8 + num_simulations = 50 + reanalyze_ratio = 0.0 + + max_batch_size = 512 + batch_size = [int(min(64, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + print(f'=========== batch_size: {batch_size} ===========') + + num_unroll_steps = 5 + infer_context_length = 4 + # norm_type = 'LN' + norm_type = 'BN' + + buffer_reanalyze_freq = 1 / 50 # Adjusted as per UniZero config + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + num_segments = 8 + + # =========== TODO: debug =========== + # collector_env_num = 2 + # evaluator_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # num_simulations = 5 + # batch_size = [int(min(2, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + + # Generate configurations + configs = generate_configs( + env_id_list=env_id_list, + action_space_size=action_space_size, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments + ) + + # Start training + train_muzero_multitask_segment_noddp(configs, seed=seed, max_env_step=int(5e5)) \ No newline at end of file diff --git a/zoo/atari/config/atari_muzero_multitask_segment_8games_ddp_config.py b/zoo/atari/config/atari_muzero_multitask_segment_8games_ddp_config.py new file mode 100644 index 000000000..c11790337 --- /dev/null +++ b/zoo/atari/config/atari_muzero_multitask_segment_8games_ddp_config.py @@ -0,0 +1,275 @@ +# zoo/atari/config/atari_muzero_multitask_segment_8games_config.py + +from easydict import EasyDict +from copy import deepcopy +from atari_env_action_space_map import atari_env_action_space_map + +def create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments +): + + return EasyDict(dict( + env=dict( + stop_value=int(5e5), # Adjusted max_env_step based on user TODO + env_id=env_id, + observation_shape=(4, 96, 96), + frame_stack_num=4, + gray_scale=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + full_action_space=True, + # ===== only for debug ===== + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + multi_gpu=True, # ======== Very important for ddp 
============= + learn=dict( + learner=dict( + hook=dict(save_ckpt_after_iter=200000,), # Adjusted checkpoint frequency + ), + ), + grad_correct_params=dict( + # Placeholder for gradient correction parameters if needed + ), + task_num=len(env_id_list), + model=dict( + device='cuda', + num_res_blocks=2, # NOTE: encoder for 4 game + num_channels=256, + reward_head_channels= 16, + value_head_channels= 16, + policy_head_channels= 16, + fc_reward_layers= [32], + fc_value_layers= [32], + fc_policy_layers= [32], + observation_shape=(4, 96, 96), + frame_stack_num=4, + gray_scale=True, + action_space_size=action_space_size, + norm_type=norm_type, + model_type='conv', + image_channel=1, + downsample=True, + self_supervised_learning_loss=True, + discrete_action_encoding_type='one_hot', + use_sim_norm=True, + use_sim_norm_kl_loss=False, + task_num=len(env_id_list), + ), + allocated_batch_sizes=False, + # max_batch_size=max_batch_size, + max_batch_size=512,# TODO + cuda=True, + env_type='not_board_games', + # train_start_after_envsteps=2000, + train_start_after_envsteps=0, + game_segment_length=20, # Fixed segment length as per user config + random_collect_episode_num=0, + use_augmentation=True, + use_priority=False, + replay_ratio=0.25, + num_unroll_steps=num_unroll_steps, + # update_per_collect=2, # TODO: debug + update_per_collect=80, # Consistent with UniZero config + batch_size=batch_size, + optim_type='SGD', + td_steps=5, + lr_piecewise_constant_decay=True, + manual_temperature_decay=False, + learning_rate=0.2, + target_update_freq=100, + num_segments=num_segments, + num_simulations=num_simulations, + policy_entropy_weight=5e-3, #TODO + ssl_loss_weight=2, + eval_freq=int(5e3), + replay_buffer_size=int(5e5), # Adjusted as per UniZero config + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs( + env_id_list, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + seed, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments +): + configs = [] + exp_name_prefix = ( + f'data_muzero_mt_8games_ddp_8gpu_1129/{len(env_id_list)}games_brf{buffer_reanalyze_freq}/' + f'{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_' + f'{len(env_id_list)}-pred-head_mbs-512_upc80_H{num_unroll_steps}_seed{seed}/' + ) + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + # collector_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, # TODO: different collector_env_num for Pong and Boxing + # evaluator_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + # n_episode if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments + ) + config.policy.task_id = task_id + config.exp_name = f"{exp_name_prefix}{env_id.split('NoFrameskip')[0]}_muzero-mt_seed{seed}" + + configs.append([task_id, [config, 
create_env_manager()]]) + + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + # env_manager=dict(type='base'), + policy=dict( + type='muzero_multitask', + import_names=['lzero.policy.muzero_multitask'], + ), + )) + +if __name__ == "__main__": + import sys + sys.path.insert(0, "/mnt/afs/niuyazhe/code/LightZero") + import lzero + print("lzero path:", lzero.__file__) + + # parser = argparse.ArgumentParser(description='Train MuZero Multitask on Atari') + # parser.add_argument('--seed', type=int, default=0, help='Random seed') + # args = parser.parse_args() + + # Define your list of environment IDs + env_id_list = [ + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', + 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', + 'RoadRunnerNoFrameskip-v4', + ] + # env_id_list = [ + # 'PongNoFrameskip-v4', + # 'MsPacmanNoFrameskip-v4', + # ] + + action_space_size = 18 # Full action space, adjust if different per env + + # seed = args.seed + seed = 0 + + collector_env_num = 8 + evaluator_env_num = 3 + num_segments = 8 + n_episode = 8 + num_simulations = 50 + reanalyze_ratio = 0.0 + max_env_step = 5e5 + + max_batch_size = 512 + batch_size = [int(min(64, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + print(f'=========== batch_size: {batch_size} ===========') + + num_unroll_steps = 5 + infer_context_length = 4 + # norm_type = 'LN' + norm_type = 'BN' + + buffer_reanalyze_freq = 1 / 50 # Adjusted as per UniZero config + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + num_segments = 8 + + # =========== TODO: debug =========== + # collector_env_num = 2 + # evaluator_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # num_simulations = 5 + # batch_size = [int(min(2, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + + # Generate configurations + configs = generate_configs( + env_id_list=env_id_list, + action_space_size=action_space_size, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments + ) + + """ + Overview: + This script should be executed with GPUs. 
+ Run the following command to launch the script: + export NCCL_TIMEOUT=3600000 + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_muzero_multitask_segment_8games_ddp_config.py + 或者使用 torchrun: + torchrun --nproc_per_node=8 ./zoo/atari/config/atari_muzero_multitask_segment_8games_ddp_config.py + """ + from lzero.entry import train_muzero_multitask_segment_ddp + from ding.utils import DDPContext + with DDPContext(): + train_muzero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_muzero_segment_stack1_config.py b/zoo/atari/config/atari_muzero_segment_stack1_config.py new file mode 100644 index 000000000..785e10c20 --- /dev/null +++ b/zoo/atari/config/atari_muzero_segment_stack1_config.py @@ -0,0 +1,136 @@ +from easydict import EasyDict +from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map + +def main(env_id, seed): + action_space_size = atari_env_action_space_map[env_id] + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + collector_env_num = 8 + num_segments = 8 + game_segment_length = 20 + + evaluator_env_num = 3 + num_simulations = 50 + update_per_collect = None + replay_ratio = 0.25 + + num_unroll_steps = 5 + batch_size = 256 + max_env_step = int(4e5) + + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq = 1/50 + # buffer_reanalyze_freq = 1/10000 + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
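+    # Editor's note (added comment, not part of the original config): read these three knobs together.
+    # buffer_reanalyze_freq controls how often a reanalyze pass is triggered, reanalyze_batch_size how
+    # many sequences each pass refreshes, and reanalyze_partition which leading fraction of the replay
+    # buffer those sequences are drawn from (0.75 below = the first three quarters of the buffer).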
+ reanalyze_partition=0.75 + + # =========== for debug =========== + # collector_env_num = 2 + # num_segments = 2 + # evaluator_env_num = 2 + # num_simulations = 2 + # update_per_collect = 2 + # batch_size = 5 + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + + atari_muzero_config = dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + frame_stack_num=1, + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # TODO: debug + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + analysis_sim_norm=False, + cal_dormant_ratio=False, + model=dict( + observation_shape=(3, 64, 64), + image_channel=3, + frame_stack_num=1, + gray_scale=False, + action_space_size=action_space_size, + downsample=True, + self_supervised_learning_loss=True, # default is False + discrete_action_encoding_type='one_hot', + norm_type='BN', + use_sim_norm=True, # NOTE + use_sim_norm_kl_loss=False, + model_type='conv' + ), + cuda=True, + env_type='not_board_games', + num_segments=num_segments, + train_start_after_envsteps=2000, + game_segment_length=game_segment_length, + random_collect_episode_num=0, + use_augmentation=True, + use_priority=False, + replay_ratio=replay_ratio, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + td_steps=5, + lr_piecewise_constant_decay=True, + manual_temperature_decay=False, + learning_rate=0.2, + target_update_freq=100, + num_simulations=num_simulations, + ssl_loss_weight=2, + eval_freq=int(5e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+ reanalyze_partition=reanalyze_partition, + ), + ) + atari_muzero_config = EasyDict(atari_muzero_config) + main_config = atari_muzero_config + + atari_muzero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + ) + atari_muzero_create_config = EasyDict(atari_muzero_create_config) + create_config = atari_muzero_create_config + + # ============ use muzero_segment_collector instead of muzero_collector ============= + from lzero.entry import train_muzero_segment + main_config.exp_name = f'data_muzero_stack1_1205/{env_id[:-14]}/{env_id[:-14]}_mz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}_bs{batch_size}_seed{seed}' + train_muzero_segment([main_config, create_config], seed=seed, max_env_step=max_env_step) + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Process different environments and seeds.') + parser.add_argument('--env', type=str, help='The environment to use', default='PongNoFrameskip-v4') + parser.add_argument('--seed', type=int, help='The seed to use', default=0) + args = parser.parse_args() + + main(args.env, args.seed) \ No newline at end of file diff --git a/zoo/atari/config/atari_muzero_stack1_multitask_segment_8games_ddp_config.py b/zoo/atari/config/atari_muzero_stack1_multitask_segment_8games_ddp_config.py new file mode 100644 index 000000000..1cecef1fa --- /dev/null +++ b/zoo/atari/config/atari_muzero_stack1_multitask_segment_8games_ddp_config.py @@ -0,0 +1,288 @@ +# zoo/atari/config/atari_muzero_multitask_segment_8games_config.py + +from easydict import EasyDict +from copy import deepcopy +from atari_env_action_space_map import atari_env_action_space_map + +def create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments +): + + return EasyDict(dict( + env=dict( + stop_value=int(5e5), # Adjusted max_env_step based on user TODO + env_id=env_id, + # observation_shape=(4, 96, 96), + # frame_stack_num=4, + # gray_scale=True, + observation_shape=(3, 96, 96), + frame_stack_num=1, + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + full_action_space=True, + # ===== only for debug ===== + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + multi_gpu=True, # ======== Very important for ddp ============= + learn=dict( + learner=dict( + hook=dict(save_ckpt_after_iter=200000,), # Adjusted checkpoint frequency + ), + ), + grad_correct_params=dict( + # Placeholder for gradient correction parameters if needed + ), + task_num=len(env_id_list), + model=dict( + device='cuda', + + num_res_blocks=2, # NOTE: encoder for 4 game + # num_channels=256, + num_channels=64, + + reward_head_channels= 16, + value_head_channels= 16, + policy_head_channels= 16, + fc_reward_layers= [32], + fc_value_layers= [32], + fc_policy_layers= [32], + + # observation_shape=(4, 96, 96), + # frame_stack_num=4, + # gray_scale=True, + # image_channel=1, + + 
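+                # Editor's note: the active settings below use a single RGB frame, so the channel
+                # dimension is image_channel * frame_stack_num = 3 * 1 = 3, matching (3, 96, 96);
+                # the commented-out block above is the 4-frame grayscale variant (1 * 4 = 4 channels).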
observation_shape=(3, 96, 96), + frame_stack_num=1, + gray_scale=False, + image_channel=3, + + action_space_size=action_space_size, + norm_type=norm_type, + model_type='conv', + downsample=True, + self_supervised_learning_loss=True, + discrete_action_encoding_type='one_hot', + use_sim_norm=True, + use_sim_norm_kl_loss=False, + task_num=len(env_id_list), + ), + allocated_batch_sizes=False, + # max_batch_size=max_batch_size, + max_batch_size=512,# TODO + cuda=True, + env_type='not_board_games', + # train_start_after_envsteps=2000, + train_start_after_envsteps=0, + game_segment_length=20, # Fixed segment length as per user config + random_collect_episode_num=0, + use_augmentation=True, + use_priority=False, + replay_ratio=0.25, + num_unroll_steps=num_unroll_steps, + # update_per_collect=2, # TODO: debug + update_per_collect=80, # Consistent with UniZero config + batch_size=batch_size, + optim_type='SGD', + td_steps=5, + lr_piecewise_constant_decay=True, + manual_temperature_decay=False, + learning_rate=0.2, + target_update_freq=100, + num_segments=num_segments, + num_simulations=num_simulations, + policy_entropy_weight=5e-3, #TODO + ssl_loss_weight=2, + eval_freq=int(5e3), + replay_buffer_size=int(5e5), # Adjusted as per UniZero config + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs( + env_id_list, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + seed, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments +): + configs = [] + exp_name_prefix = ( + f'data_muzero_stack1_mt_8games_ddp_8gpu_1201/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_channel64/' + f'{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel64_gsl20_' + f'{len(env_id_list)}-pred-head_mbs-512_upc80_H{num_unroll_steps}_seed{seed}/' + ) + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + # collector_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, # TODO: different collector_env_num for Pong and Boxing + # evaluator_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + # n_episode if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments + ) + config.policy.task_id = task_id + config.exp_name = f"{exp_name_prefix}{env_id.split('NoFrameskip')[0]}_muzero-mt_seed{seed}" + + configs.append([task_id, [config, create_env_manager()]]) + + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + # env_manager=dict(type='base'), + policy=dict( + type='muzero_multitask', + import_names=['lzero.policy.muzero_multitask'], + ), + )) + +if __name__ == "__main__": + import sys + sys.path.insert(0, "/mnt/afs/niuyazhe/code/LightZero") + import lzero + 
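+    # Editor's note: the sys.path.insert above hard-codes a machine-specific absolute path. A more
+    # portable alternative (an assumption, not part of the original script) is an editable install
+    # from the repository root, e.g. `pip install -e .`, after which `import lzero` works without
+    # any path manipulation.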
print("lzero path:", lzero.__file__) + + # parser = argparse.ArgumentParser(description='Train MuZero Multitask on Atari') + # parser.add_argument('--seed', type=int, default=0, help='Random seed') + # args = parser.parse_args() + + # Define your list of environment IDs + env_id_list = [ + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', + 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', + 'RoadRunnerNoFrameskip-v4', + ] + # env_id_list = [ + # 'PongNoFrameskip-v4', + # 'MsPacmanNoFrameskip-v4', + # ] + + action_space_size = 18 # Full action space, adjust if different per env + + # seed = args.seed + seed = 0 + + collector_env_num = 8 + evaluator_env_num = 3 + num_segments = 8 + n_episode = 8 + num_simulations = 50 + reanalyze_ratio = 0.0 + max_env_step = 5e5 + + max_batch_size = 512 + batch_size = [int(min(64, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + print(f'=========== batch_size: {batch_size} ===========') + + num_unroll_steps = 5 + infer_context_length = 4 + # norm_type = 'LN' + norm_type = 'BN' + + buffer_reanalyze_freq = 1 / 50 # Adjusted as per UniZero config + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + num_segments = 8 + + # =========== TODO: debug =========== + # collector_env_num = 2 + # evaluator_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # num_simulations = 5 + # batch_size = [int(min(2, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + + # Generate configurations + configs = generate_configs( + env_id_list=env_id_list, + action_space_size=action_space_size, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments + ) + + """ + Overview: + This script should be executed with GPUs. 
+ Run the following command to launch the script: + export NCCL_TIMEOUT=3600000 + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_muzero_multitask_segment_8games_ddp_config.py + 或者使用 torchrun: + torchrun --nproc_per_node=8 ./zoo/atari/config/atari_muzero_multitask_segment_8games_ddp_config.py + """ + from lzero.entry import train_muzero_multitask_segment_ddp + from ding.utils import DDPContext + with DDPContext(): + train_muzero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_rezero_mz_config.py b/zoo/atari/config/atari_rezero_mz_config.py index c7787831b..91517afd5 100644 --- a/zoo/atari/config/atari_rezero_mz_config.py +++ b/zoo/atari/config/atari_rezero_mz_config.py @@ -18,6 +18,17 @@ reuse_search = True collect_with_pure_policy = True buffer_reanalyze_freq = 1 + +# ====== only for debug ===== +# collector_env_num = 8 +# num_segments = 8 +# evaluator_env_num = 2 +# num_simulations = 5 +# max_env_step = int(2e5) +# reanalyze_ratio = 0.1 +# batch_size = 64 +# num_unroll_steps = 10 +# replay_ratio = 0.01 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -32,6 +43,9 @@ evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, manager=dict(shared_memory=False, ), + # # TODO: only for debug + # collect_max_episode_steps=int(20), + # eval_max_episode_steps=int(20), ), policy=dict( model=dict( diff --git a/zoo/atari/config/atari_unizero_multigpu_ddp_config.py b/zoo/atari/config/atari_unizero_multigpu_ddp_config.py index 82f64f141..26ecff41c 100644 --- a/zoo/atari/config/atari_unizero_multigpu_ddp_config.py +++ b/zoo/atari/config/atari_unizero_multigpu_ddp_config.py @@ -55,13 +55,20 @@ max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action context_length=2 * infer_context_length, device='cuda', - # device='cpu', action_space_size=action_space_size, num_layers=2, num_heads=8, embed_dim=768, obs_type='image', env_num=max(collector_env_num, evaluator_env_num), + task_num=1, + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, ), ), # (str) The path of the pretrained model. If None, the model will be initialized by the default model. 
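Editor's note on the hunk above: it extends `world_model_cfg` in atari_unizero_multigpu_ddp_config.py with the multi-task head switches (`use_normal_head`, `use_softmoe_head`, `use_moe_head`) and the MoE sizing knobs (`num_experts_in_moe_head`, `moe_in_transformer`, `multiplication_moe_in_transformer`, `num_experts_of_moe_in_transformer`). The sketch below shows one way such mutually exclusive flags can be resolved into a prediction head; it is a minimal illustration under that assumption: `build_prediction_head` and `SimpleMoEHead` are invented names, and the real soft-MoE / MoE routing inside LightZero's world model is not reproduced here.

import torch
import torch.nn as nn


class SimpleMoEHead(nn.Module):
    """Toy mixture-of-experts head: averages expert outputs (a real router, e.g. soft-MoE, is omitted)."""

    def __init__(self, embed_dim: int, output_dim: int, num_experts: int):
        super().__init__()
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.GELU(), nn.Linear(embed_dim, output_dim))
            for _ in range(num_experts)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Uniform mixing over experts, purely for illustration.
        return torch.stack([expert(x) for expert in self.experts], dim=0).mean(dim=0)


def build_prediction_head(cfg: dict, embed_dim: int, output_dim: int) -> nn.Module:
    # Exactly one of the three head switches is expected to be enabled.
    assert sum([cfg['use_normal_head'], cfg['use_softmoe_head'], cfg['use_moe_head']]) == 1
    if cfg['use_normal_head']:
        # Plain shared MLP head: the single-task default (task_num=1 in the hunk above).
        return nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.GELU(), nn.Linear(embed_dim, output_dim))
    # Either MoE variant falls back to the toy mixture head in this sketch.
    return SimpleMoEHead(embed_dim, output_dim, cfg['num_experts_in_moe_head'])


# Example mirroring the values added above: embed_dim=768 and an 18-way policy output.
head = build_prediction_head(
    dict(use_normal_head=True, use_softmoe_head=False, use_moe_head=False, num_experts_in_moe_head=4),
    embed_dim=768,
    output_dim=18,
)
print(head(torch.zeros(2, 768)).shape)  # torch.Size([2, 18])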
diff --git a/zoo/atari/config/atari_unizero_multitask_26games_serial_config.py b/zoo/atari/config/atari_unizero_multitask_26games_serial_config.py new file mode 100644 index 000000000..f78e9e33f --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_26games_serial_config.py @@ -0,0 +1,158 @@ +from easydict import EasyDict +from copy import deepcopy + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + grad_correct_params=dict( + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + calpha=0.5, + rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + world_model_cfg=dict( + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=2, + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + update_per_collect=1000, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed): + configs = [] + exp_name_prefix = f'data_unizero_mt/{len(env_id_list)}games_1-encoder-{norm_type}_26-head_lsd768-nlayer2-nh8_max-bs2000_upc1000_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + from lzero.entry import train_unizero_multitask_serial + + env_id_list = [ + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + 'SeaquestNoFrameskip-v4', + 
'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', + 'AmidarNoFrameskip-v4', + 'AssaultNoFrameskip-v4', + 'AsterixNoFrameskip-v4', + 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', + 'ChopperCommandNoFrameskip-v4', + 'CrazyClimberNoFrameskip-v4', + 'DemonAttackNoFrameskip-v4', + 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', + 'GopherNoFrameskip-v4', + 'HeroNoFrameskip-v4', + 'JamesbondNoFrameskip-v4', + 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', + 'KungFuMasterNoFrameskip-v4', + 'PrivateEyeNoFrameskip-v4', + 'RoadRunnerNoFrameskip-v4', + 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', + 'BreakoutNoFrameskip-v4', + ] + + action_space_size = 18 + seed = 0 + collector_env_num = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(1e6) + reanalyze_ratio = 0.0 + max_batch_size = 1000 + batch_size = [int(max_batch_size / len(env_id_list)) for _ in range(len(env_id_list))] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed) + + # train_unizero_multitask_serial(configs[:4], seed=seed, max_env_step=max_env_step) # multitask learning on first four tasks + train_unizero_multitask_serial(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_26games_ddp_config.py b/zoo/atari/config/atari_unizero_multitask_segment_26games_ddp_config.py new file mode 100644 index 000000000..7f38fb67a --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_26games_ddp_config.py @@ -0,0 +1,172 @@ +from easydict import EasyDict + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + ), + policy=dict( + multi_gpu=True, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000))), + grad_correct_params=dict( + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + calpha=0.5, + rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=8, + num_heads=24, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + num_experts_in_moe_head=4, + ), + ), + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), + 
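+            # Editor's note: total_batch_size is the global learner budget; the per-task batch_size list
+            # is built in __main__ as int(min(64, total_batch_size / len(env_id_list))), so with
+            # total_batch_size=512 and 26 games each task gets int(512 / 26) = 19 samples per update.
+            # allocated_batch_sizes=False appears to keep that split fixed rather than re-allocating it.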
use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + eval_freq=int(1e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments): + configs = [] + exp_name_prefix = f'data_unizero_mt_ddp-8gpu-26game_1201/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_nlayer8-nhead24_seed{seed}/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
+ Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_26games_ddp_config.py + torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_26games_ddp_config.py + """ + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + + # List of Atari games used for multi-task learning + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + ] + + # Hyperparameters + action_space_size = 18 + collector_env_num = 8 + evaluator_env_num = 3 + n_episode = 8 + num_segments = 8 + num_simulations = 50 + reanalyze_ratio = 0.0 + max_env_step = int(5e5) + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in env_id_list] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + for seed in [0]: # Seed for reproducibility + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments) + + # Training with distributed data parallel (DDP) + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_26games_serial_config.py b/zoo/atari/config/atari_unizero_multitask_segment_26games_serial_config.py new file mode 100644 index 000000000..0bcca387b --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_26games_serial_config.py @@ -0,0 +1,194 @@ +from easydict import EasyDict +from copy import deepcopy + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments): + """ + Create the configuration for a specific environment. 
+ """ + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), # Input observation dimensions + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + # ===== TODO: only for debug ===== + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + grad_correct_params=dict( + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + calpha=0.5, + rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, # Encoder configuration + num_channels=256, + world_model_cfg=dict( + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=4, # Transformer layers + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_moe_head=False, + moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, # Update steps per collection + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + eval_freq=int(1e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments): + """ + Generate configurations for all environments in `env_id_list`. + """ + configs = [] + exp_name_prefix = f'data_unizero_mt_segcollect_1107/{len(env_id_list)}games_brf{buffer_reanalyze_freq}/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_lsd768-nlayer4-nh8_maxbs-640_upc80_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + """ + Create the environment manager configuration. 
+ """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + from lzero.entry import train_unizero_multitask_segment_serial + + # Define environments + env_id_list = [ + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', + 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', + 'RoadRunnerNoFrameskip-v4', + 'AmidarNoFrameskip-v4', + 'AssaultNoFrameskip-v4', + 'AsterixNoFrameskip-v4', + 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', + 'CrazyClimberNoFrameskip-v4', + 'DemonAttackNoFrameskip-v4', + 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', + 'GopherNoFrameskip-v4', + 'JamesbondNoFrameskip-v4', + 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', + 'KungFuMasterNoFrameskip-v4', + 'PrivateEyeNoFrameskip-v4', + 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', + 'BreakoutNoFrameskip-v4', + ] + + # Define hyperparameters + action_space_size = 18 + seed = 0 + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(1e6) + reanalyze_ratio = 0. + max_batch_size = 640 + batch_size = [int(min(64, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # evaluator_env_num = 2 + # num_simulations = 2 + # batch_size = [4, 4, 4, 4] + + # Generate configurations + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments) + + # Train using the generated configurations + train_unizero_multitask_segment_serial(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py b/zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py new file mode 100644 index 000000000..cb8e6811e --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py @@ -0,0 +1,174 @@ +from easydict import EasyDict + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, + total_batch_size): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + # ===== only for debug ===== + # collect_max_episode_steps=int(30), + # eval_max_episode_steps=int(30), + ), + policy=dict( + multi_gpu=True, # Very important for ddp + 
learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + # batch_size=64 8games训练时,每张卡大约占 12*3=36G cuda显存 + num_layers=12, + num_heads=24, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + multiplication_moe_in_transformer=False, + num_experts_in_moe_head=4, + use_moe_head=False, + ), + ), + eval_offline=False, + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + eval_freq=int(2e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size): + configs = [] + exp_name_prefix = f'data_unizero_mt_ddp-8gpu_20241226/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
+ Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py + torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + + os.environ["NCCL_TIMEOUT"] = "3600000000" + + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + + action_space_size = 18 + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(5e5) + reanalyze_ratio = 0.0 + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + collector_env_num = 2 + num_segments = 2 + n_episode = 2 + evaluator_env_num = 2 + num_simulations = 2 + batch_size = [4, 4, 4, 4, 4, 4, 4, 4] + + + for seed in [0]: + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size) + + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) + # train_unizero_multitask_segment_ddp(configs[:4], seed=seed, max_env_step=max_env_step) # train on the first four tasks diff --git a/zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config_evaloffline.py b/zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config_evaloffline.py new file mode 100644 index 000000000..d20e90824 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config_evaloffline.py @@ -0,0 +1,180 @@ +from easydict import EasyDict + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, + total_batch_size): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + # collect_max_episode_steps=int(5e3), + # eval_max_episode_steps=int(5e3), + # ===== only for debug ===== + collect_max_episode_steps=int(20), + eval_max_episode_steps=int(20), + ), + policy=dict( + multi_gpu=True, # Very important for ddp + # learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10))), # TODO + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + 
action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + # batch_size=64 8games训练时,每张卡大约占 12*3=36G cuda显存 + num_layers=2, + # num_layers=12, + num_heads=24, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + multiplication_moe_in_transformer=False, + num_experts_in_moe_head=4, + use_moe_head=False, + ), + ), + eval_offline=True, # TODO: 目前的版本需要load_ + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + # update_per_collect=80, + update_per_collect=20, # TODO + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + # eval_freq=int(2e4), + eval_freq=int(10), # TODO + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size): + configs = [] + exp_name_prefix = f'data_unizero_mt_ddp-8gpu_20241226/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
+ Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config_debug.py + torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + + os.environ["NCCL_TIMEOUT"] = "3600000000" + + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + + action_space_size = 18 + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + # max_env_step = int(5e5) + max_env_step = int(200) # TODO + + reanalyze_ratio = 0.0 + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + collector_env_num = 2 + num_segments = 2 + n_episode = 2 + evaluator_env_num = 2 + num_simulations = 2 + batch_size = [4, 4, 4, 4, 4, 4, 4, 4] + + + for seed in [0]: + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size) + + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) + # train_unizero_multitask_segment_ddp(configs[:4], seed=seed, max_env_step=max_env_step) # train on the first four tasks diff --git a/zoo/atari/config/atari_unizero_multitask_segment_8games_serial_config.py b/zoo/atari/config/atari_unizero_multitask_segment_8games_serial_config.py new file mode 100644 index 000000000..790e06a37 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_8games_serial_config.py @@ -0,0 +1,165 @@ +from easydict import EasyDict +from copy import deepcopy + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments): + """ + Create a configuration for a specific environment. 
+ """ + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + grad_correct_params=dict( + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + calpha=0.5, + rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=4, + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=160, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments): + """ + Generate configurations for all tasks in the environment list. + """ + configs = [] + exp_name_prefix = f'data_unizero_mt_serial/{len(env_id_list)}games_brf{buffer_reanalyze_freq}/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_nlayer4-nh8-lsd768_mbs-320_upc160_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + """ + Create the environment manager configuration. 
+ """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + from lzero.entry import train_unizero_multitask_segment_serial + + # Define the environment list + env_id_list = [ + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', + 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', + 'RoadRunnerNoFrameskip-v4', + ] + + # Define key parameters + action_space_size = 18 + seed = 0 + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(5e5) + reanalyze_ratio = 0. + max_batch_size = 320 + batch_size = [int(min(64, max_batch_size / len(env_id_list))) for _ in env_id_list] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # Generate configurations and start training + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments) + train_unizero_multitask_segment_serial(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py b/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py new file mode 100644 index 000000000..6bb9972c9 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py @@ -0,0 +1,166 @@ +from easydict import EasyDict + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + ), + policy=dict( + multi_gpu=True, # Enable multi-GPU for DDP + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, + MoCo_rho=0, calpha=0.5, rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + env_id_list=env_id_list, + analysis_tsne=True, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=8, # Transformer layers + num_heads=24, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + moe_in_transformer=False, + 
num_experts_of_moe_in_transformer=4, + ), + ), + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + eval_freq=int(2e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): + configs = [] + exp_name_prefix = f'data_unizero_mt_ddp-8gpu_eval-latent_state_tsne/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_nlayer8-nh24-lsd768_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, + n_episode, num_simulations, reanalyze_ratio, batch_size, + num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This program is designed to obtain the t-SNE of the latent states in 8games multi-task learning. + + This script should be executed with GPUs. 
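+        It loads a pretrained multi-task checkpoint (see pretrained_model_path below) and sets analysis_tsne=True in the world model config so that latent states can be collected per game.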
+ Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_eval_config.py + torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_eval_config.py + """ + + from lzero.entry import train_unizero_multitask_segment_eval + from ding.utils import DDPContext + + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + + action_space_size = 18 + + for seed in [0]: + collector_env_num = 2 + num_segments = 2 + n_episode = 2 + evaluator_env_num = 2 + num_simulations = 50 + max_env_step = int(4e5) + reanalyze_ratio = 0.0 + total_batch_size = int(4*len(env_id_list)) + batch_size = [4 for _ in range(len(env_id_list))] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1/50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + + configs = generate_configs( + env_id_list, action_space_size, collector_env_num, n_episode, + evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, + num_unroll_steps, infer_context_length, norm_type, seed, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size + ) + + # Pretrained model paths + # 8games + pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_mt_ddp-8gpu_1127/8games_brf0.02_nlayer8-nhead24_seed1/8games_brf0.02_1-encoder-LN-res2-channel256_gsl20_8-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed1/Pong_unizero-mt_seed1/ckpt/iteration_200000.pth.tar' + # 26games + # pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_mt_ddp-8gpu-26game_1127/26games_brf0.02_nlayer8-nhead24_seed0/26games_brf0.02_1-encoder-LN-res2-channel256_gsl20_26-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed0/Pong_unizero-mt_seed0/ckpt/iteration_150000.pth.tar' + + with DDPContext(): + train_unizero_multitask_segment_eval(configs, seed=seed, model_path=pretrained_model_path, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py b/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py new file mode 100644 index 000000000..aa11a8120 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py @@ -0,0 +1,169 @@ +from easydict import EasyDict + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + ), + policy=dict( + multi_gpu=True, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000))), + grad_correct_params=dict( # Gradient correction parameters + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + calpha=0.5, + rescale=1, + ), + 
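+            # NOTE: task_id below is a placeholder; generate_configs() overwrites it with each task's index in env_id_list.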
task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=8, + num_heads=24, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + eval_freq=int(2e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): + configs = [] + exp_name_prefix = f'data_unizero_mt_ddp-2gpu_1201/finetune_pong/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments, + total_batch_size + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
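+        It fine-tunes from a pretrained 8-game multi-task checkpoint (pretrained_model_path below); env_id_list is reduced to Pong only as a debug setup.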
+ Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py + torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + from easydict import EasyDict + + env_id_list = ['PongNoFrameskip-v4'] # Debug setup + action_space_size = 18 + + # NCCL environment setup + import os + os.environ["NCCL_TIMEOUT"] = "3600000000" + + for seed in [0, 1, 2]: + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(4e5) + + reanalyze_ratio = 0.0 + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size) + + pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_mt_ddp-8gpu_1127/8games_brf0.02_nlayer8-nhead24_seed1/8games_brf0.02_1-encoder-LN-res2-channel256_gsl20_8-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed1/Pong_unizero-mt_seed1/ckpt/iteration_200000.pth.tar' + + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, model_path=pretrained_model_path, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/envs/atari_lightzero_env.py b/zoo/atari/envs/atari_lightzero_env.py index 67a27c5f0..d7d0174c6 100644 --- a/zoo/atari/envs/atari_lightzero_env.py +++ b/zoo/atari/envs/atari_lightzero_env.py @@ -24,6 +24,8 @@ class AtariEnvLightZero(BaseEnv): _reward_space, obs, _eval_episode_return, has_reset, _seed, _dynamic_seed """ config = dict( + # (bool) Whether to use the full action space of the environment. Default is False. If set to True, the action space size is 18 for Atari. + full_action_space=False, # (int) The number of environment instances used for data collection. collector_env_num=8, # (int) The number of environment instances used for evaluator. @@ -156,6 +158,7 @@ def step(self, action: int) -> BaseEnvTimestep: observation = self.observe() if done: info['eval_episode_return'] = self._eval_episode_return + print(f'one episode of {self.cfg.env_id} done') return BaseEnvTimestep(observation, self.reward, done, info) diff --git a/zoo/atari/envs/atari_wrappers.py b/zoo/atari/envs/atari_wrappers.py index f38aa24d6..265ef31ac 100644 --- a/zoo/atari/envs/atari_wrappers.py +++ b/zoo/atari/envs/atari_wrappers.py @@ -93,9 +93,9 @@ def wrap_lightzero(config: EasyDict, episode_life: bool, clip_rewards: bool) -> - env (:obj:`gym.Env`): The wrapped Atari environment with the given configurations. 
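+    Note: ``config.full_action_space`` is forwarded to ``gym.make`` below, so the full 18-action Atari action space can be enabled from the env config.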
""" if config.render_mode_human: - env = gym.make(config.env_id, render_mode='human') + env = gym.make(config.env_id, render_mode='human', full_action_space=config.full_action_space) else: - env = gym.make(config.env_id, render_mode='rgb_array') + env = gym.make(config.env_id, render_mode='rgb_array', full_action_space=config.full_action_space) assert 'NoFrameskip' in env.spec.id if hasattr(config, 'save_replay') and config.save_replay \ and hasattr(config, 'replay_path') and config.replay_path is not None: diff --git a/zoo/dmc2gym/config/dmc2gym_pixels_sampled_unizero_config.py b/zoo/dmc2gym/config/dmc2gym_pixels_sampled_unizero_config.py new file mode 100644 index 000000000..4f5ca5bda --- /dev/null +++ b/zoo/dmc2gym/config/dmc2gym_pixels_sampled_unizero_config.py @@ -0,0 +1,132 @@ +from easydict import EasyDict +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== + +from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map + +env_id = 'cartpole-swingup' # You can specify any DMC task here +action_space_size = dmc_state_env_action_space_map[env_id] +obs_space_size = dmc_state_env_obs_space_map[env_id] +print(f'env_id: {env_id}, action_space_size: {action_space_size}, obs_space_size: {obs_space_size}') + +domain_name = env_id.split('-')[0] +task_name = env_id.split('-')[1] + +continuous_action_space = True +K = 20 # num_of_sampled_actions +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = None +replay_ratio = 0.25 +max_env_step = int(1e6) +reanalyze_ratio = 0 +batch_size = 64 +num_unroll_steps = 10 +infer_context_length = 4 +norm_type = 'LN' +seed = 0 + +# for debug +# collector_env_num = 2 +# n_episode = 2 +# evaluator_env_num = 1 +# num_simulations = 2 +# batch_size = 2 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +dmc2gym_pixels_cont_sampled_unizero_config = dict( + exp_name=f'data_sampled_unizero_0901/dmc2gym_{env_id}_image_cont_sampled_unizero_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_{norm_type}_seed{seed}', + env=dict( + env_id='dmc2gym-v0', + continuous=True, + domain_name=domain_name, + task_name=task_name, + from_pixels=True, # pixel/image obs + frame_skip=2, + warp_frame=True, + scale=True, + frame_stack_num=1, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(3, 84, 84), + action_space_size=action_space_size, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + world_model_cfg=dict( + obs_type='image', + num_unroll_steps=num_unroll_steps, + policy_entropy_loss_weight=5e-3, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + sigma_type='conditioned', + fixed_sigma_value=0.3, + bound_type=None, + model_type='conv', + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + # device='cpu', + device='cuda', + action_space_size=action_space_size, + num_layers=2, + num_heads=8, + 
embed_dim=768, + env_num=max(collector_env_num, evaluator_env_num), + ), + ), + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. + model_path=None, + num_unroll_steps=num_unroll_steps, + cuda=True, + use_augmentation=False, + env_type='not_board_games', + game_segment_length=100, + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='AdamW', + lr_piecewise_constant_decay=False, + learning_rate=0.0001, + target_update_freq=100, + grad_clip_value=5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +dmc2gym_pixels_cont_sampled_unizero_config = EasyDict(dmc2gym_pixels_cont_sampled_unizero_config) +main_config = dmc2gym_pixels_cont_sampled_unizero_config + +dmc2gym_pixels_cont_sampled_unizero_create_config = dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + # env_manager=dict(type='subprocess'), + env_manager=dict(type='base'), + policy=dict( + type='sampled_unizero', + import_names=['lzero.policy.sampled_unizero'], + ), +) +dmc2gym_pixels_cont_sampled_unizero_create_config = EasyDict(dmc2gym_pixels_cont_sampled_unizero_create_config) +create_config = dmc2gym_pixels_cont_sampled_unizero_create_config + +if __name__ == "__main__": + from lzero.entry import train_unizero + + train_unizero([main_config, create_config], seed=seed, max_env_step=max_env_step) diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py index 56e89d1a3..757044f16 100644 --- a/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py @@ -122,7 +122,8 @@ def main(env_id, seed): reanalyze_ratio=0, n_episode=n_episode, eval_freq=int(5e3), - replay_buffer_size=int(1e6), + # replay_buffer_size=int(1e6), + replay_buffer_size=int(5e5), # TODO collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, # ============= The key different params for ReZero =============