forked from junhyukoh/value-prediction-network
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: mult_agent_plot.py
47 lines (35 loc) · 1.1 KB
/
mult_agent_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
from os.path import join
import pandas as pd
# Directory of result files to aggregate. Every entry in it is parsed with
# pd.read_csv (first column used as the index), so it should contain only
# those CSV files.
path = './results'
# Sorted so the concatenated row order does not depend on the arbitrary
# ordering returned by os.listdir.
files = sorted(os.listdir(path))
# Read each file once and concatenate in a single pass. DataFrame.append was
# deprecated in pandas 1.4 and removed in 2.0, and appending inside a loop
# copies the accumulated frame on every iteration (quadratic in file count).
frames = [pd.read_csv(join(path, file), index_col=0) for file in files]
# pd.concat raises ValueError on an empty list; fall back to an empty frame so
# an empty directory yields the same result as before. ignore_index=True
# discards the per-file indexes, as the old append(..., ignore_index=True) did.
df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
# to_plot = df.groupby(by=['num_goals', 'branch'], observed=True).mean().reset_index(level=[0, 1], inplace=False)
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# List the branch configurations present in the loaded results.
print(df['branch'].unique())

# Branch configurations to compare. Entries are matched as strings, exactly as
# they appear in the 'branch' column; uncomment a line to include it.
selected_branches = [
    '[4]',
    # '[4, 4]',
    '[4, 4, 4]',
    # '[4, 4, 4, 4]',
    # '[4, 4, 4, 4, 4]',
    '[4, 4, 4, 1, 1, 1]',
    # '[1, 1, 1, 1, 1, 1, 1, 4, 4, 4]',
    # '[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]',
    # '[1]',
    # '[1, 1, 1]',
    '[1, 1, 1, 4, 4, 4]',
    '[4, 4, 4, 1, 1, 1, 1, 1, 1, 1]',
]
keep = df['branch'].isin(selected_branches)
to_plot = df[keep]

# One line per branch: score against num_goals, with uncertainty drawn as
# error bars rather than a shaded band.
sns.lineplot(data=to_plot, x="num_goals", y="score", hue='branch',
             err_style="bars")

# Red markers at score == num_goals for 1..10, labelled "best possible".
ideal = np.arange(1, 11)
sns.scatterplot(x=ideal, y=ideal, marker='_', color='red',
                label="best possible")
plt.show()