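"""Evaluate a saved Munchausen agent (QNetwork) with on-screen rendering.

Actions are selected from a softmax (Boltzmann) policy over the learned
Q-values, controlled by the temperature ``tau_soft``.
"""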
import torch
import gymnasium as gym
from munchausen import QNetwork

import cv2


def evaluate_with_rendering(model_path, env_id, num_episodes=10, seed=0, tau_soft=1.0):
    """
    Evaluate the saved model with visual rendering, using a softmax policy.

    Args:
        model_path (str): Path to the saved model file.
        env_id (str): Identifier of the environment.
        num_episodes (int): Number of episodes to evaluate.
        seed (int): Initial seed for reproducible results.
        tau_soft (float): Temperature for the softmax.
    """
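    # Temperature intuition (illustrative numbers, not from the source):
    # with Q = [1.0, 2.0] and tau_soft = 0.05, softmax(Q / tau_soft) is
    # roughly [2e-9, 1.0], i.e. essentially greedy; a large tau_soft
    # flattens the policy toward uniform over actions.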
    # Load the environment and the model
    env = gym.make(env_id, render_mode="rgb_array")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = QNetwork(env.action_space, env.observation_space).to(device)
    # map_location lets a CUDA-trained checkpoint load on a CPU-only machine
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    avg_counter_checkpts = 0
    avg_steps = 0
    avg_return = 0

    # Evaluation loop
    for episode in range(num_episodes):
        obs, _ = env.reset(seed=seed + episode)

        done = False
        total_reward = 0

        print(f"\nEpisode {episode + 1}/{num_episodes}")

        while not done:
            obs_tensor = torch.Tensor(obs).unsqueeze(0).to(device)

            with torch.no_grad():
                q_values = model(obs_tensor)
                # Boltzmann policy: pi(a|s) = softmax(Q(s, .) / tau_soft)
                policy = torch.softmax(q_values / tau_soft, dim=-1)
                # Greedy action (argmax of the policy, i.e. argmax of Q)
                action = torch.argmax(policy, dim=-1).item()
                # Alternatively, sample stochastically from the policy:
                # action = torch.multinomial(policy, num_samples=1).item()

            # gymnasium renders RGB frames; OpenCV expects BGR
            img = env.render()
            cv2.imshow("Duckietown", cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)

            # gymnasium's step() returns separate terminated/truncated flags
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_reward += reward
            print(f"step {env.unwrapped.episodic_length} reward {reward}")

        avg_counter_checkpts += env.unwrapped.race.counter_checkpt
        avg_steps += env.unwrapped.episodic_length
        avg_return += total_reward

        print(f"Total episode reward: {total_reward}")

    print(f"Avg. N. Checkpoints: {avg_counter_checkpts / num_episodes}")
    print(f"Avg. N. Steps: {avg_steps / num_episodes}")
    print(f"Avg. Return: {avg_return / num_episodes}")

    env.close()


if __name__ == "__main__":
    # Update this to the actual saved model path
    model_path = (
        "/home/dcas/g.ferraro/Desktop/models/munchausen_1000000_1648.556805028793.pt"
    )

    env_id = "DuckietownDiscrete-v0"

    evaluate_with_rendering(model_path, env_id, num_episodes=20, seed=42, tau_soft=0.05)
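
# Note: gym.make("DuckietownDiscrete-v0") assumes the custom Duckietown
# environment is registered with gymnasium before this script runs
# (typically via an import side effect in the project package; the exact
# registration module is project-specific and not shown here).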