Commit b0a0ed6

add eval_race.py for race competition evaluation
1 parent 72630e1 commit b0a0ed6

File tree

1 file changed: +82 -0

duckiesim/rl/eval_race.py

import torch
import gymnasium as gym
import cv2

from munchausen import QNetwork


def evaluate_with_rendering(model_path, env_id, num_episodes=10, seed=0, tau_soft=1.0):
    """
    Evaluate the saved model with visual rendering, using a softmax policy.

    Args:
        model_path (str): Path to the saved model file.
        env_id (str): Identifier of the environment.
        num_episodes (int): Number of episodes to evaluate.
        seed (int): Initial seed, to make the results reproducible.
        tau_soft (float): Temperature of the softmax.
    """
    # Create the environment (env.render() is assumed to return an RGB frame)
    env = gym.make(env_id)

    # Load the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = QNetwork(env.action_space, env.observation_space).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    avg_counter_checkpts = 0
    avg_steps = 0
    avg_return = 0

    # Evaluation
    for episode in range(num_episodes):
        obs, _ = env.reset(seed=seed + episode)

        done = False
        total_reward = 0

        print(f"\nEpisode {episode + 1}/{num_episodes}")

        while not done:
            obs_tensor = torch.Tensor(obs).unsqueeze(0).to(device)

            with torch.no_grad():
                q_values = model(obs_tensor)
                policy = torch.softmax(q_values / tau_soft, dim=-1)

            # Greedy action; use torch.multinomial(policy, num_samples=1)
            # instead to sample stochastically from the softmax policy.
            action = torch.argmax(policy, dim=-1).item()

            img = env.render()
            # Note: cv2 expects BGR; convert with cv2.cvtColor(img,
            # cv2.COLOR_RGB2BGR) if the colors look swapped.
            cv2.imshow("Duckietown", img)
            cv2.waitKey(1)

            # Gymnasium returns separate terminated/truncated flags; the
            # episode is over when either one is set.
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_reward += reward
            print(f"step {env.unwrapped.episodic_length} reward {reward}")

        avg_counter_checkpts += env.unwrapped.race.counter_checkpt
        avg_steps += env.unwrapped.episodic_length
        avg_return += total_reward

        print(f"Total reward for the episode: {total_reward}")

    print(f"Avg. N. Checkpoints: {avg_counter_checkpts / num_episodes}")
    print(f"Avg. N. Steps: {avg_steps / num_episodes}")
    print(f"Avg. Return: {avg_return / num_episodes}")

    env.close()


if __name__ == "__main__":
    # model_path = "/home/p.le-tolguenec/Documents/duckietown_rl_course/model/exp_1/munchausen_430404_1746.0135750578204.pt"  # adapt to the actual path
    model_path = (
        "/home/dcas/g.ferraro/Desktop/models/munchausen_1000000_1648.556805028793.pt"
    )

    env_id = "DuckietownDiscrete-v0"

    evaluate_with_rendering(model_path, env_id, num_episodes=20, seed=42, tau_soft=0.05)
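
For intuition on the tau_soft temperature used above, here is a minimal standalone sketch (the Q-values are made up, not produced by the model) showing how a low temperature collapses the softmax policy onto the greedy argmax action, while a temperature near 1 leaves enough spread for torch.multinomial to sample other actions:

import torch

# Made-up Q-values for a 4-action discrete environment.
q_values = torch.tensor([[1.0, 2.0, 1.5, 0.5]])

for tau_soft in (1.0, 0.05):
    policy = torch.softmax(q_values / tau_soft, dim=-1)
    greedy = torch.argmax(policy, dim=-1).item()
    sampled = torch.multinomial(policy, num_samples=1).item()
    print(f"tau={tau_soft}: policy={policy.numpy().round(3)}, greedy={greedy}, sampled={sampled}")

With tau_soft=0.05, as in the __main__ block above, the softmax is nearly one-hot, so the policy is effectively greedy and the script simply takes the argmax.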
