Skip to content

Commit 8f27fda

Browse files
committed
Implement a new AI algorithm and fix atomic bug
Implement another AI algorithm, negamax, and provide its helper package, zobrist. We have shrunk HASH_TABLE_SIZE, because otherwise the table would be too large to allocate with kmalloc(). The atomic variable "turn" was implemented incorrectly, which caused numerous error messages in kernel space, so we change "turn" to the type "char" and use a memory barrier together with the macros "READ_ONCE()" and "WRITE_ONCE()" to keep the data consistent between different processors.
1 parent 21a4c9d commit 8f27fda

File tree

9 files changed

+232
-32
lines changed

9 files changed

+232
-32
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
TARGET = kmldrv
2-
kmldrv-objs = simrupt.o game.o mcts.o
2+
kmldrv-objs = simrupt.o game.o wyhash.o mcts.o negamax.o zobrist.o
33
obj-m := $(TARGET).o
44

55
KDIR ?= /lib/modules/$(shell uname -r)/build

game.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ typedef unsigned fixed_point_t;
3030
#define DRAW_SIZE (N_GRIDS + BOARD_SIZE)
3131
#define DRAWBUFFER_SIZE \
3232
((BOARD_SIZE * (BOARD_SIZE + 1) << 1) + (BOARD_SIZE * BOARD_SIZE) + \
33-
((BOARD_SIZE << 1) + 1))
33+
((BOARD_SIZE << 1) + 1) + 1)
3434

3535
extern const line_t lines[4];
3636

negamax.c

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#include <linux/slab.h>
2+
#include <linux/sort.h>
3+
#include <linux/string.h>
4+
5+
#include "game.h"
6+
#include "negamax.h"
7+
#include "util.h"
8+
#include "zobrist.h"
9+
10+
#define MAX_SEARCH_DEPTH 6
11+
12+
static int history_score_sum[N_GRIDS];
13+
static int history_count[N_GRIDS];
14+
15+
static u64 hash_value;
16+
17+
static int cmp_moves(const void *a, const void *b)
18+
{
19+
int *_a = (int *) a, *_b = (int *) b;
20+
int score_a = 0, score_b = 0;
21+
22+
if (history_count[*_a])
23+
score_a = history_score_sum[*_a] / history_count[*_a];
24+
if (history_count[*_b])
25+
score_b = history_score_sum[*_b] / history_count[*_b];
26+
return score_b - score_a;
27+
}
28+
29+
static move_t negamax(char *table, int depth, char player, int alpha, int beta)
30+
{
31+
if (check_win(table) != ' ' || depth == 0) {
32+
move_t result = {get_score(table, player), -1};
33+
return result;
34+
}
35+
zobrist_entry_t *entry = zobrist_get(hash_value);
36+
if (entry)
37+
return (move_t){.score = entry->score, .move = entry->move};
38+
39+
int score;
40+
move_t best_move = {-10000, -1};
41+
int *moves = available_moves(table);
42+
int n_moves = 0;
43+
while (n_moves < N_GRIDS && moves[n_moves] != -1)
44+
++n_moves;
45+
46+
sort(moves, n_moves, sizeof(int), cmp_moves, NULL);
47+
48+
for (int i = 0; i < n_moves; i++) {
49+
table[moves[i]] = player;
50+
hash_value ^= zobrist_table[moves[i]][player == 'X'];
51+
if (!i)
52+
score = -negamax(table, depth - 1, player == 'X' ? 'O' : 'X', -beta,
53+
-alpha)
54+
.score;
55+
else {
56+
score = -negamax(table, depth - 1, player == 'X' ? 'O' : 'X',
57+
-alpha - 1, -alpha)
58+
.score;
59+
if (alpha < score && score < beta)
60+
score = -negamax(table, depth - 1, player == 'X' ? 'O' : 'X',
61+
-beta, -score)
62+
.score;
63+
}
64+
history_count[moves[i]]++;
65+
history_score_sum[moves[i]] += score;
66+
if (score > best_move.score) {
67+
best_move.score = score;
68+
best_move.move = moves[i];
69+
}
70+
table[moves[i]] = ' ';
71+
hash_value ^= zobrist_table[moves[i]][player == 'X'];
72+
if (score > alpha)
73+
alpha = score;
74+
if (alpha >= beta)
75+
break;
76+
}
77+
78+
kfree((char *) moves);
79+
zobrist_put(hash_value, best_move.score, best_move.move);
80+
return best_move;
81+
}
82+
83+
void negamax_init(void)
84+
{
85+
zobrist_init();
86+
hash_value = 0;
87+
}
88+
89+
move_t negamax_predict(char *table, char player)
90+
{
91+
memset(history_score_sum, 0, sizeof(history_score_sum));
92+
memset(history_count, 0, sizeof(history_count));
93+
move_t result;
94+
for (int depth = 2; depth <= MAX_SEARCH_DEPTH; depth += 2) {
95+
result = negamax(table, depth, player, -100000, 100000);
96+
zobrist_clear();
97+
}
98+
return result;
99+
}

negamax.h

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#pragma once
2+
3+
/*
 * Result of a negamax search: the score that was found and the move index
 * that produced it. The search uses move == -1 to mean "no move chosen".
 */
typedef struct {
4+
int score, move;
5+
} move_t;
6+
7+
/* Initialise the Zobrist module and reset the search's position hash. */
void negamax_init(void);
8+
/*
 * Search @table for @player ('X' or 'O') with increasing depth and return
 * the result of the deepest search.
 */
move_t negamax_predict(char *table, char player);

simrupt.c

+14-12
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "game.h"
1414
#include "mcts.h"
15+
#include "negamax.h"
1516

1617
MODULE_LICENSE("Dual MIT/GPL");
1718
MODULE_AUTHOR("National Cheng Kung University, Taiwan");
@@ -147,7 +148,7 @@ static void simrupt_work_func(struct work_struct *w)
147148
wake_up_interruptible(&rx_wait);
148149
}
149150

150-
static atomic_t turn;
151+
static char turn;
151152

152153
static void ai_one_work_func(struct work_struct *w)
153154
{
@@ -159,23 +160,23 @@ static void ai_one_work_func(struct work_struct *w)
159160
WARN_ON_ONCE(in_softirq());
160161
WARN_ON_ONCE(in_interrupt());
161162

162-
int expect = 0;
163-
atomic_read_acquire(&turn);
164-
if (!atomic_try_cmpxchg(&turn, &expect, 1))
163+
READ_ONCE(turn);
164+
if (turn == 'X')
165165
return;
166166

167167
cpu = get_cpu();
168168
pr_info("simrupt: [CPU#%d] start doing %s\n", cpu, __func__);
169169
tv_start = ktime_get();
170170
mutex_lock(&producer_lock);
171-
int move = mcts(table, 'O');
171+
int move;
172+
WRITE_ONCE(move, mcts(table, 'O'));
172173

173174
smp_mb();
174175

175176
if (move != -1)
176177
WRITE_ONCE(table[move], 'O');
177178

178-
atomic_set_release(&turn, 1);
179+
WRITE_ONCE(turn, 'X');
179180
mutex_unlock(&producer_lock);
180181
tv_end = ktime_get();
181182

@@ -195,23 +196,23 @@ static void ai_two_work_func(struct work_struct *w)
195196
WARN_ON_ONCE(in_softirq());
196197
WARN_ON_ONCE(in_interrupt());
197198

198-
int expect = 1;
199-
atomic_read_acquire(&turn);
200-
if (!atomic_try_cmpxchg(&turn, &expect, 0))
199+
READ_ONCE(turn);
200+
if (turn == 'O')
201201
return;
202202

203203
cpu = get_cpu();
204204
pr_info("simrupt: [CPU#%d] start doing %s\n", cpu, __func__);
205205
tv_start = ktime_get();
206206
mutex_lock(&producer_lock);
207-
int move = mcts(table, 'X');
207+
int move;
208+
WRITE_ONCE(move, negamax_predict(table, 'X').move);
208209

209210
smp_mb();
210211

211212
if (move != -1)
212213
WRITE_ONCE(table[move], 'X');
213214

214-
atomic_set_release(&turn, 0);
215+
WRITE_ONCE(turn, 'O');
215216
mutex_unlock(&producer_lock);
216217
tv_end = ktime_get();
217218

@@ -443,8 +444,9 @@ static int __init simrupt_init(void)
443444
goto error_cdev;
444445
}
445446

447+
negamax_init();
446448
memset(table, ' ', N_GRIDS);
447-
atomic_set_release(&turn, 0);
449+
turn = 'O';
448450
/* Setup the timer */
449451
timer_setup(&timer, timer_handler, 0);
450452
atomic_set(&open_cnt, 0);

wyhash.c

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#include <linux/string.h>
2+
#include <linux/timer.h>
3+
4+
#include "wyhash.h"
5+
6+
/*
 * Advance *seed by a fixed increment and mix it into a 64-bit value.
 *
 * Each of the two mixing rounds multiplies into a 128-bit product and
 * folds it back to 64 bits by XOR-ing the high half into the low half.
 *
 * @seed: generator state; updated in place
 *
 * Returns the mixed 64-bit value.
 */
static inline u64 wyhash64_stateless(u64 *seed)
{
    u128 product;
    u64 folded;

    *seed += 0x60bee2bee120fc15;

    product = (u128) *seed * 0xa3b195354a39b70d;
    folded = (product >> 64) ^ product;

    product = (u128) folded * 0x1b03738712fad5c9;
    return (product >> 64) ^ product;
}
16+
17+
u64 wyhash64(void)
18+
{
19+
u64 seed = (u64) ktime_to_ns(ktime_get());
20+
return wyhash64_stateless(&seed);
21+
}

wyhash.h

+1-18
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,2 @@
1-
#include <linux/string.h>
2-
#include <linux/timer.h>
31

4-
static inline u64 wyhash64_stateless(u64 *seed)
5-
{
6-
*seed += 0x60bee2bee120fc15;
7-
u128 tmp;
8-
tmp = (u128) *seed * 0xa3b195354a39b70d;
9-
u64 m1 = (tmp >> 64) ^ tmp;
10-
tmp = (u128) m1 * 0x1b03738712fad5c9;
11-
u64 m2 = (tmp >> 64) ^ tmp;
12-
return m2;
13-
}
14-
15-
u64 wyhash64(void)
16-
{
17-
u64 seed = (u64) ktime_to_ns(ktime_get());
18-
return wyhash64_stateless(&seed);
19-
}
2+
u64 wyhash64(void);

zobrist.c

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#include <linux/slab.h>
2+
3+
#include "wyhash.h"
4+
#include "zobrist.h"
5+
6+
u64 zobrist_table[N_GRIDS][2];
7+
8+
#define HASH(key) ((key) % HASH_TABLE_SIZE)
9+
10+
static struct hlist_head *hash_table;
11+
12+
void zobrist_init(void)
13+
{
14+
int i;
15+
for (i = 0; i < N_GRIDS; i++) {
16+
zobrist_table[i][0] = wyhash64();
17+
zobrist_table[i][1] = wyhash64();
18+
}
19+
hash_table =
20+
kmalloc(sizeof(struct hlist_head) * HASH_TABLE_SIZE, GFP_KERNEL);
21+
if (!hash_table) {
22+
pr_info("simrupt: Failed to allocate space for hash_table\n");
23+
return;
24+
}
25+
for (i = 0; i < HASH_TABLE_SIZE; i++)
26+
INIT_HLIST_HEAD(&hash_table[i]);
27+
}
28+
29+
/*
 * Look up the transposition-table entry stored under @key.
 *
 * @key: full 64-bit position hash; the bucket is chosen by HASH(key) and
 *       entries within the bucket are matched on the full key.
 *
 * Returns the matching entry, or NULL when there is none or when the
 * table was never allocated (zobrist_init() failed).
 */
zobrist_entry_t *zobrist_get(u64 key)
{
    zobrist_entry_t *entry;

    /* zobrist_init() leaves hash_table NULL when its allocation fails. */
    if (!hash_table)
        return NULL;

    /* An empty bucket simply yields zero iterations. */
    hlist_for_each_entry (entry, &hash_table[HASH(key)], ht_list) {
        if (entry->key == key)
            return entry;
    }
    return NULL;
}
44+
45+
/*
 * Store a search result in the transposition table.
 *
 * @key:   full 64-bit position hash
 * @score: score to remember for the position
 * @move:  move to remember for the position
 *
 * A new entry is allocated and added at the head of the bucket selected
 * by HASH(key). The table is only a cache, so if the table itself or the
 * new entry cannot be allocated the result is silently not stored.
 */
void zobrist_put(u64 key, int score, int move)
{
    zobrist_entry_t *new_entry;

    /* zobrist_init() leaves hash_table NULL when its allocation fails. */
    if (!hash_table)
        return;

    new_entry = kmalloc(sizeof(*new_entry), GFP_KERNEL);
    if (!new_entry)
        return;

    new_entry->key = key;
    new_entry->move = move;
    new_entry->score = score;
    hlist_add_head(&new_entry->ht_list, &hash_table[HASH(key)]);
}
54+
55+
void zobrist_clear(void)
56+
{
57+
for (int i = 0; i < HASH_TABLE_SIZE; i++) {
58+
while (!hlist_empty(&hash_table[i])) {
59+
zobrist_entry_t *entry =
60+
hlist_entry(hash_table[i].first, zobrist_entry_t, ht_list);
61+
hlist_del(&entry->ht_list);
62+
kfree(entry);
63+
}
64+
INIT_HLIST_HEAD(&hash_table[i]);
65+
}
66+
}

zobrist.h

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#pragma once
2+
3+
#include <linux/list.h>
4+
5+
#include "game.h"
6+
7+
/* Number of buckets in the transposition table; keys are reduced modulo it. */
#define HASH_TABLE_SIZE (100003)
8+
9+
/*
 * Hash keys, one per grid cell and side. The search indexes the second
 * dimension with (player == 'X').
 */
extern u64 zobrist_table[N_GRIDS][2];
10+
11+
/* One cached search result, chained into its hash bucket. */
typedef struct {
12+
/* Full 64-bit position hash, compared on lookup. */
u64 key;
13+
/* Score stored for the position. */
int score;
14+
/* Move stored for the position. */
int move;
15+
/* Link in the bucket's hlist. */
struct hlist_node ht_list;
16+
} zobrist_entry_t;
17+
18+
/* Fill zobrist_table with hash values and allocate the bucket array. */
void zobrist_init(void);
19+
/* Return the entry whose key equals @key, or NULL if there is none. */
zobrist_entry_t *zobrist_get(u64 key);
20+
/* Add an entry holding @score and @move to the bucket selected by @key. */
void zobrist_put(u64 key, int score, int move);
21+
/* Free every stored entry and leave all buckets empty. */
void zobrist_clear(void);

0 commit comments

Comments
 (0)