Skip to content

Commit 24adbb7

Browse files
committed
complete project init
1 parent 2e10cff commit 24adbb7

7 files changed

+1195
-0
lines changed

cpu.h

+243
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
#ifndef CPU_H
2+
#define CPU_H
3+
4+
#include <unordered_map>
5+
#include <algorithm>
6+
#include <cmath>
7+
#include "game.h"
8+
#include "ntuple.h"
9+
10+
// Bound-type flags stored in TTEntry::flag and checked in Cpu::getScore:
// EXACT - the stored value is returned directly when the entry is current.
#define EXACT 0
11+
// LOWER - the stored value is used to raise alpha.
#define LOWER 1
12+
// UPPER - the stored value is used to lower beta.
#define UPPER 2
13+
14+
// Magnitude used for win/loss scores and for the initial alpha-beta window
// (-INF, INF); scores beyond +/-INF/2 are treated as decided games.
#define INF 100000
15+
16+
using namespace std;
17+
18+
/// A candidate move paired with the score the search has assigned to it.
class Move {
public:
    int move;     ///< Move identifier, as obtained from Game::getMoves().
    float score;  ///< Score assigned by the search; starts at 0.

    /// Creates a move with the given identifier and a score of 0.
    explicit Move(int move) : move(move), score(0) {}

    /// Orders moves by score only; the move identifier is ignored.
    bool operator < (const Move& other) const {
        return score < other.score;
    }

    /// Orders moves by score only; the move identifier is ignored.
    bool operator > (const Move& other) const {
        return score > other.score;
    }

    /// Writes "<score>: <move>" to @p out and returns the stream.
    /// The move is taken by const reference so that const objects and
    /// temporaries can be streamed as well.
    friend std::ostream& operator<< (std::ostream &out, const Move &move) {
        out << move.score << ": " << move.move;
        return out;
    }
};
40+
41+
/// One transposition-table record; Cpu stores these keyed by Game::getHash().
class TTEntry {
public:
    float value = 0;  ///< Score recorded for the position.
    int flag = 0;     ///< Bound type of `value`: EXACT, LOWER or UPPER.
    int age = 0;      ///< Value of Cpu's iteration counter when the entry was written.
    // Best move found for the position. Defaults to -1 ("no move", the same
    // sentinel Cpu::getScore uses) because entries stored with the UPPER flag
    // never assign it, and reading it uninitialized would be undefined.
    int move = -1;
};
48+
49+
class Cpu {
50+
public:
51+
explicit Cpu() {
52+
53+
}
54+
55+
void setPlayer(int player) {
56+
this->player = player;
57+
}
58+
59+
int getPlayer() {
60+
return player;
61+
}
62+
63+
void setNetwork(NTupleNetwork *network) {
64+
this->network = network;
65+
}
66+
67+
void setGame(Game *game) {
68+
this->game = game;
69+
}
70+
71+
Move getBestMove(int levels) {
72+
vector<int> movesNum = game->getMoves();
73+
vector<Move> moves; moves.reserve(movesNum.size());
74+
for (auto & moveNum : movesNum) {
75+
Move move = Move(moveNum);
76+
game->makeMove(move.move);
77+
if (game->isOver()) {
78+
int winner = game->getWinner();
79+
if (winner == player) {
80+
move.score = INF - 10 * game->rounds;
81+
} else if (winner == (player^1) ){
82+
move.score = -INF + 10 * game->rounds;
83+
}
84+
} else {
85+
move.score = getScore();
86+
}
87+
game->undoMove();
88+
moves.push_back(move);
89+
}
90+
91+
Move bestMoveSoFar = getBestMoveSoFar(moves);
92+
if (bestMoveSoFar.score > INF/2) {
93+
return bestMoveSoFar;
94+
}
95+
96+
tt.clear();
97+
for (int i=1; i < levels; i++) {
98+
age = i;
99+
for (auto & move : moves) {
100+
if (move.score < -INF/2) {
101+
continue;
102+
}
103+
game->makeMove(move.move);
104+
move.score = -getScore(-1, i-1, -INF, INF);
105+
game->undoMove();
106+
if (move.score > INF/2) {
107+
return move;
108+
}
109+
if (move.score > bestMoveSoFar.score) {
110+
bestMoveSoFar = move;
111+
}
112+
}
113+
bestMoveSoFar = getBestMoveSoFar(moves);
114+
if (bestMoveSoFar.score < -INF/2) {
115+
break;
116+
}
117+
}
118+
119+
return bestMoveSoFar;
120+
}
121+
122+
private:
123+
int player;
124+
Game *game;
125+
unordered_map<int,TTEntry> tt;
126+
int age;
127+
NTupleNetwork *network = nullptr;
128+
129+
Move getBestMoveSoFar(vector<Move> & moves) {
130+
sort(moves.begin(),moves.end(), [](const Move & a, const Move & b) -> bool
131+
{
132+
return a.score > b.score;
133+
});
134+
135+
float score = moves[0].score;
136+
int n = 1;
137+
138+
for (size_t i=1; i < moves.size(); i++) {
139+
if (moves[i].score < score) break;
140+
n++;
141+
}
142+
143+
return moves[rand() % n];
144+
}
145+
146+
float getScore() {
147+
if (network != nullptr) {
148+
float s = network->predict(game->getTuples());
149+
return player == PLAYER_X ? s : -s;
150+
}
151+
return 0;
152+
}
153+
154+
float getScore(int color, float level, float alpha, float beta) {
155+
float output = -INF;
156+
157+
int state;
158+
int bestMove = -1;
159+
float alphaOrig = alpha;
160+
161+
if (level > 0) {
162+
state = game->getHash();
163+
if (tt.find(state) != tt.end()) {
164+
auto & entry = tt[state];
165+
if (entry.age != age) {
166+
bestMove = entry.move;
167+
} else {
168+
if (entry.flag == EXACT) {
169+
return entry.value;
170+
} else if (entry.flag == LOWER) {
171+
alpha = max(alpha, entry.value);
172+
} else if (entry.flag == UPPER) {
173+
beta = min(beta, entry.value);
174+
}
175+
if (alpha >= beta) {
176+
return entry.value;
177+
}
178+
}
179+
180+
}
181+
}
182+
183+
vector<int> moves = game->getMoves();
184+
185+
if (level > 0) {
186+
if (bestMove != -1) {
187+
auto it = find(moves.begin(),moves.end(),bestMove);
188+
if (it != moves.end()) {
189+
iter_swap(moves.begin(), it);
190+
}
191+
}
192+
}
193+
194+
for (auto & move : moves) {
195+
float score = 0;
196+
game->makeMove(move);
197+
if (game->isOver()) {
198+
int winner = game->getWinner();
199+
if (winner == player) {
200+
score = INF - 10 * game->rounds;
201+
} else if (winner == (player^1)) {
202+
score = -INF + 10 * game->rounds;
203+
}
204+
score *= color;
205+
} else if (level > 0) {
206+
score = -getScore(-color, level-1, -beta, -alpha);
207+
} else {
208+
score = color * getScore();
209+
}
210+
game->undoMove();
211+
if (score > output) {
212+
bestMove = move;
213+
}
214+
output = max(score, output);
215+
alpha = max(alpha, output);
216+
if (alpha >= beta) {
217+
break;
218+
}
219+
}
220+
221+
if (level > 0) {
222+
TTEntry entry;
223+
entry.value = output;
224+
entry.age = age;
225+
226+
if (output <= alphaOrig) {
227+
entry.flag = UPPER;
228+
} else if (output >= beta) {
229+
entry.move = bestMove;
230+
entry.flag = LOWER;
231+
} else {
232+
entry.move = bestMove;
233+
entry.flag = EXACT;
234+
}
235+
236+
tt[state] = entry;
237+
}
238+
239+
return output;
240+
}
241+
};
242+
243+
#endif // CPU_H

0 commit comments

Comments
 (0)