Commit a066bda

Author: Mikel
Add PettingZoo support for Craftium multi-agent environments
1 parent 2c5c085 commit a066bda

File tree: 1 file changed, +112 −0 lines


craftium/pettingzoo_env.py (+112)
@@ -0,0 +1,112 @@
from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector, wrappers
from typing import Optional
import os
import functools

from . import root_path
from .multiagent_env import MarlCraftiumEnv

# Registry of the available multi-agent environments and the default
# configuration each one is launched with.
AVAIL_ENVS = {
    "Craftium/MultiAgentCombat-v0": dict(
        env_dir="craftium-envs/multi-agent-combat",
        conf=dict(
            num_agents=2,
            obs_width=64,
            obs_height=64,
            max_timesteps=1000,
            rgb_observations=True,
            init_frames=200,
            sync_mode=False,
        ),
    ),
}


def env(env_name: str, render_mode: Optional[str] = None, **kwargs):
    env_dir = os.path.join(root_path, AVAIL_ENVS[env_name]["env_dir"])
    # note: in a dict union the right-hand operand wins, so the registered
    # `conf` values take precedence over user-supplied kwargs
    final_args = kwargs | AVAIL_ENVS[env_name]["conf"]

    return raw_env(
        env_dir,
        render_mode,
        **final_args,
    )


class raw_env(AECEnv):
    metadata = {"render_modes": ["human", "rgb_array"], "name": "craftium_env"}

    def __init__(self, env_dir: str, render_mode: Optional[str] = None, **kwargs):
        self.render_mode = render_mode
        self.na = kwargs["num_agents"]
        self.possible_agents = ["player_" + str(r) for r in range(self.na)]

        # mapping between agent name and agent ID
        self.agent_id_map = dict(
            zip(self.possible_agents, range(len(self.possible_agents)))
        )

        self.env = MarlCraftiumEnv(
            env_dir=env_dir,
            **kwargs,
        )

    @functools.lru_cache(maxsize=None)
    def observation_space(self, agent):
        # all agents share the same observation space
        return self.env.observation_space[0]

    @functools.lru_cache(maxsize=None)
    def action_space(self, agent):
        return self.env.action_space

    def render(self):
        pass

    def observe(self, agent):
        return self.observations[agent]

    def close(self):
        self.env.close()

    def reset(self, seed=None, options=None):
        # seed and options are accepted for API compatibility but are not
        # forwarded to the underlying environment
        observations, infos = self.env.reset()

        self.agents = self.possible_agents[:]
        self.rewards = {agent: 0 for agent in self.agents}
        self._cumulative_rewards = {agent: 0 for agent in self.agents}
        self.terminations = {agent: False for agent in self.agents}
        self.truncations = {agent: False for agent in self.agents}
        self.infos = {agent: infos for agent in self.agents}
        # self.state = {agent: observations[i] for i, agent in enumerate(self.agents)}
        self.observations = {
            agent: observations[i] for i, agent in enumerate(self.agents)
        }

        # the agent_selector utility allows easy cyclic stepping through the agents list
        self._agent_selector = agent_selector(self.agents)
        self.agent_selection = self._agent_selector.next()

    def step(self, action):
        if (
            self.terminations[self.agent_selection]
            or self.truncations[self.agent_selection]
        ):
            # handles stepping an agent that is already dead: accepts a None
            # action for that agent and moves agent_selection to the next dead
            # agent or, if there are none left, to the next live agent
            self._was_dead_step(action)
            return

        agent = self.agent_selection
        agent_id = self.agent_id_map[agent]

        # tell the underlying multi-agent env which agent is acting
        self.env.current_agent_id = agent_id
        observation, reward, termination, truncated, info = self.env.step_agent(action)

        self.observations[agent] = observation
        self.rewards[agent] = reward
        self.truncations[agent] = truncated
        self.terminations[agent] = termination
        self.infos[agent] = info

        # collect rewards once the last agent of the cycle has acted
        if self._agent_selector.is_last():
            # adds .rewards to ._cumulative_rewards
            self._accumulate_rewards()
        # select the next agent
        self.agent_selection = self._agent_selector.next()
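
For reference, the wrapper plugs into the standard PettingZoo AEC interaction loop. A minimal usage sketch, assuming the module is importable as craftium.pettingzoo_env and a working Minetest runtime is available; the random policy is illustrative only:

# Usage sketch (illustrative): drive the new AEC env with random actions.
# Assumes this module is importable as craftium.pettingzoo_env.
from craftium.pettingzoo_env import env as make_env

aec_env = make_env("Craftium/MultiAgentCombat-v0")
aec_env.reset(seed=42)

# agent_iter() cycles through agents in the order set by agent_selector;
# agents that are done must be stepped with a None action.
for agent in aec_env.agent_iter():
    observation, reward, termination, truncation, info = aec_env.last()
    action = None if termination or truncation else aec_env.action_space(agent).sample()
    aec_env.step(action)

aec_env.close()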
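
Further multi-agent tasks would only need a new entry in AVAIL_ENVS. A hypothetical entry, shown purely to illustrate the registry shape (the name and env_dir below are not in this commit):

# Hypothetical registry entry; env name and env_dir are illustrative only.
AVAIL_ENVS["Craftium/MultiAgentChase-v0"] = dict(
    env_dir="craftium-envs/multi-agent-chase",  # assumed directory, not in the repo
    conf=dict(
        num_agents=4,
        obs_width=64,
        obs_height=64,
        max_timesteps=2000,
        rgb_observations=True,
        init_frames=200,
        sync_mode=False,
    ),
)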
