Building a game AI for Animal Shogi using Monte Carlo Tree Search (MCTS)
Tags: algorithm, machinelearning, golang
Monte Carlo Tree Search (MCTS) is an algorithm for searching for good moves in game AI such as AlphaGo. Related algorithms include minimax, which exhaustively searches for the best move on the assumption that the opponent always replies with the move that is worst for us (that is, the opponent's own best move), and alpha-beta pruning, which makes that search more efficient by pruning branches that cannot affect the result. While these algorithms require assigning evaluation values to intermediate board states, Monte Carlo Tree Search can evaluate a move from the win/loss results of games played out randomly.
Game Implementation
First, I implemented the game for simulation. The complete code is available on GitHub. In Animal Shogi, you can place captured pieces like in regular shogi in addition to moving your own pieces. You win by either capturing the opponent’s lion or moving your lion to the last row of the opponent’s territory. Only chicks can promote to chickens when they reach the opponent’s last row.
// IsGameOver reports whether the game has ended and, if it has, who won.
//
// Two winning conditions are recognised:
//   - "try": a lion stands on the far row of the board (row 0 for Player1,
//     row 3 for Player2);
//   - "catch": a player's lion is no longer on the board, in which case the
//     other player wins.
//
// The board is scanned row by row starting at row 0, and the first lion found
// on its goal row decides the game. While the game is still running the
// second result is none.
func (b *Board) IsGameOver() (bool, player) {
	p1Alive, p2Alive := false, false
	for r := 0; r < 4; r++ {
		for c := 0; c < 3; c++ {
			sq := b.squares[r][c]
			if sq.piece != lion {
				continue
			}
			switch sq.player {
			case Player1:
				p1Alive = true
				// "Try": the lion has reached the opponent's back row.
				if r == 0 {
					return true, Player1
				}
			case Player2:
				p2Alive = true
				if r == 3 {
					return true, Player2
				}
			}
		}
	}

	// "Catch": a missing lion means it was captured by the opponent.
	switch {
	case !p1Alive:
		return true, Player2
	case !p2Alive:
		return true, Player1
	default:
		return false, none
	}
}
// GetAllValidMoves returns every move available to the player whose turn it
// is: first all drops of that player's captured pieces, then all moves of the
// player's pieces on the board in row-major order.
//
// A candidate is discarded when the position it leads to is already recorded
// four or more times in that position's occurrences table, which keeps
// repetition (sennichite) out of the move list.
// NOTE(review): the original comment describes the rule as "the same position
// appearing three times", while the check admits counts up to 3 — confirm
// against how ApplyMove maintains occurrences.
func (b *Board) GetAllValidMoves() []Move {
	var legal []Move

	// keep appends the candidates that do not run into the repetition limit.
	keep := func(candidates []Move) {
		for _, m := range candidates {
			next := b.ApplyMove(m)
			if next.occurrences[next.String()] >= 4 {
				continue
			}
			legal = append(legal, m)
		}
	}

	// Drops of captured pieces.
	for _, held := range b.captured[b.turnPlayer] {
		keep(b.getValidDrops(held))
	}

	// Moves of pieces already on the board.
	for r := 0; r < 4; r++ {
		for c := 0; c < 3; c++ {
			if b.squares[r][c].player != b.turnPlayer {
				continue
			}
			keep(b.getValidMoves(r, c))
		}
	}
	return legal
}
// getValidDrops lists every drop of the given piece: one move per empty
// square, visited in row-major order. A drop is encoded as a Move whose first
// two fields are -1, followed by the target row and column and the piece
// being dropped.
func (b *Board) getValidDrops(piece piece) []Move {
	var result []Move
	for r := 0; r < 4; r++ {
		for c := 0; c < 3; c++ {
			if b.squares[r][c].piece != empty {
				continue
			}
			result = append(result, Move{-1, -1, r, c, piece})
		}
	}
	return result
}
// getValidMoves returns the moves available to the piece standing on
// (row, col), regardless of whose turn it is. It returns no moves when the
// coordinates are off the board or the square is empty.
//
// A target square is accepted when it lies on the 4x3 board and is not
// occupied by a piece of the same player, so captures are included. Each
// result is a Move holding the origin, the target and empty as its last field.
//
// Movement per piece:
//   - lion: any of the 8 adjacent squares;
//   - elephant: any of the 4 diagonal squares;
//   - giraffe: any of the 4 orthogonal squares;
//   - chick: the single square straight ahead (row-1 for Player1, row+1
//     otherwise);
//   - chicken: any adjacent square except the two diagonally behind it,
//     i.e. 6 squares.
func (b *Board) getValidMoves(row, col int) []Move {
	var moves []Move
	if row < 0 || row >= 4 || col < 0 || col >= 3 {
		return moves
	}
	square := b.squares[row][col]
	if square.piece == empty {
		return moves
	}

	// addIfValid records a step to (toRow, toCol) when the target is on the
	// board and not held by one of the mover's own pieces.
	addIfValid := func(toRow, toCol int) {
		if toRow < 0 || toRow >= 4 || toCol < 0 || toCol >= 3 {
			return
		}
		if b.squares[toRow][toCol].player == square.player {
			return
		}
		moves = append(moves, Move{row, col, toRow, toCol, empty})
	}

	switch square.piece {
	case lion:
		// Any of the 8 neighbouring squares.
		for drow := -1; drow <= 1; drow++ {
			for dcol := -1; dcol <= 1; dcol++ {
				if drow == 0 && dcol == 0 {
					continue
				}
				addIfValid(row+drow, col+dcol)
			}
		}
	case elephant:
		// Any of the 4 diagonal squares.
		for _, d := range [][2]int{{-1, -1}, {-1, 1}, {1, -1}, {1, 1}} {
			addIfValid(row+d[0], col+d[1])
		}
	case giraffe:
		// Any of the 4 orthogonal squares.
		for _, d := range [][2]int{{-1, 0}, {1, 0}, {0, -1}, {0, 1}} {
			addIfValid(row+d[0], col+d[1])
		}
	case chick:
		// Only the square straight ahead.
		if square.player == Player1 {
			addIfValid(row-1, col)
		} else {
			addIfValid(row+1, col)
		}
	case chicken:
		// Every neighbouring square except the two diagonally behind.
		for drow := -1; drow <= 1; drow++ {
			for dcol := -1; dcol <= 1; dcol++ {
				if drow == 0 && dcol == 0 {
					continue
				}
				// Backward is row+1 for Player1 and row-1 for Player2. Only
				// the diagonal backward steps are forbidden; the straight
				// backward step (dcol == 0) stays legal.
				backward := (square.player == Player1 && drow == 1) ||
					(square.player == Player2 && drow == -1)
				if backward && dcol != 0 {
					continue
				}
				addIfValid(row+drow, col+dcol)
			}
		}
	}
	return moves
}
// Play runs a game to completion starting from b. While the current position
// is not game over, it passes a copy of the position to chooseMove and applies
// the move it returns. The final position is returned; b itself is returned
// when the game is already over.
func (b *Board) Play(chooseMove func(Board) Move) *Board {
	position := b
	finished, _ := position.IsGameOver()
	for !finished {
		position = position.ApplyMove(chooseMove(*position))
		finished, _ = position.IsGameOver()
	}
	return position
}
When both players move randomly, the win rate is approximately 50%.
$ cat main.go
package main
import (
"fmt"
"math/rand"
ds "github.com/sambaiz/doubutsu_shogi/doubutsushogi"
)
// main plays a single game in which every move is picked uniformly at random
// from the valid moves, printing each position before a move is chosen, then
// the final position and the winner.
func main() {
	// pickRandom prints the position it is asked about and answers with a
	// random valid move.
	pickRandom := func(b ds.Board) ds.Move {
		fmt.Println(b)
		candidates := b.GetAllValidMoves()
		idx := rand.Intn(len(candidates))
		return candidates[idx]
	}

	final := ds.NewBoard().Play(pickRandom)
	fmt.Println(final)

	_, winner := final.IsGameOver()
	fmt.Printf("Player%d win!\n", winner)
}
$ go run main.go
2G 2L 2E
-- 2C --
-- 1C --
1E 1L 1G
Player1 has:
Player2 has:
- Player1's turn -
2G 2L 2E
-- 2C --
-- 1C 1L
1E -- 1G
Player1 has:
Player2 has:
- Player2's turn -
-- 2L 2E
2G 2C --
-- 1C 1L
1E -- 1G
Player1 has:
Player2 has:
- Player1's turn -
-- 2L 2E
2G 2C 1L
-- 1C --
1E -- 1G
Player1 has:
Player2 has:
- Player2's turn -
-- 2L 2E
-- 2C 1L
2G 1C --
1E -- 1G
Player1 has:
Player2 has:
- Player1's turn -
-- 2L 2E
-- 1L --
2G 1C --
1E -- 1G
Player1 has: C
Player2 has:
- Player2's turn -
-- 2L 2E
-- 1L --
-- 2G --
1E -- 1G
Player1 has: C
Player2 has: C
- Player1's turn -
-- 1L 2E
-- -- --
-- 2G --
1E -- 1G
Player1 has: L C
Player2 has: C
- Player2's turn -
Player1 win!
Implementation of Monte Carlo Tree Search
Monte Carlo Tree Search repeatedly selects the leaf node with the highest evaluation value, plays random moves from there until the game ends, and updates the evaluation values along the path based on the result. The evaluation function is UCT (Upper Confidence bounds applied to Trees), which is based on the UCB1 bandit algorithm. \(Q_i\) is the total reward accumulated by node \(i\), \(n_i\) is the number of times that node has been explored, and \(N\) is the number of times its parent node has been explored. \(C\) is a tuning parameter, commonly set to \(\sqrt{2}\).
$$ UCT = \frac{Q_i}{n_i} + C\sqrt{\frac{\ln{N}}{n_i}} $$
package mcts
import (
"fmt"
"math"
"strings"
"math/rand"
ds "github.com/sambaiz/doubutsu_shogi/doubutsushogi"
)
// Node is one position in the search tree together with the statistics
// gathered for it during the search.
type Node struct {
// board is the position this node represents.
board *ds.Board
// move is the move that was applied to the parent's board to reach board;
// the root created in main has a nil move.
move *ds.Move
// totalScore is the accumulated reward: SimulateAndExpand adds 1 for every
// playout whose winner is not the player to move on board.
totalScore float64
// visitCount is the number of playouts that were backpropagated through
// this node.
visitCount int64
// parent is the node one level up, or nil for the root.
parent *Node
// children holds the nodes reachable by one move from board.
children []*Node
}
// NewNode creates the node for board, reached via move from parent, and
// eagerly attaches one child node per valid move on board. The children are
// created without their own children; they are filled in later by Expand.
//
// move and parent may be nil, as they are for the root of the tree.
func NewNode(board *ds.Board, move *ds.Move, parent *Node) *Node {
	node := &Node{
		board:  board,
		move:   move,
		parent: parent,
	}
	childMoves := board.GetAllValidMoves()
	node.children = make([]*Node, 0, len(childMoves))
	for i := range childMoves {
		// Take the address of a variable declared inside the loop body.
		// Before Go 1.22 a range variable is shared by all iterations, so
		// taking its address would make every child point at the same
		// (last) move.
		childMove := childMoves[i]
		node.children = append(node.children, &Node{
			board:  board.ApplyMove(childMove),
			move:   &childMove,
			parent: node,
		})
	}
	return node
}
// String renders the node's board and the move that led to it, each under its
// own heading. A nil board or move is left out.
func (n *Node) String() string {
	var out strings.Builder
	if n.board != nil {
		fmt.Fprintf(&out, "Board:\n%s\n", n.board.String())
	}
	if n.move != nil {
		fmt.Fprintf(&out, "Move:\n%s\n", n.move.String())
	}
	return out.String()
}
// PrettyString renders the subtree below n as an indented listing, one line
// per node that has a move:
//
//	<prefix>├──<move> [<score>/<total score>/<visit count>]
//
// A node without a move (the root) prints no line of its own and does not
// deepen the indentation of its children.
func (n *Node) PrettyString(prefix string) string {
	var out strings.Builder
	childPrefix := prefix
	if n.move != nil {
		fmt.Fprintf(&out, "%s├──%s [%f/%d/%d]\n", prefix, n.move.String(), n.calculateScore(), int(n.totalScore), n.visitCount)
		childPrefix = prefix + "│ "
	}
	for _, child := range n.children {
		out.WriteString(child.PrettyString(childPrefix))
	}
	return out.String()
}
// SelectBestChild returns the child with the highest score as computed by
// calculateScore, preferring the earliest child on ties. It returns nil when
// the node has no children.
func (n *Node) SelectBestChild() *Node {
	var (
		best      *Node
		bestScore = math.Inf(-1)
	)
	// With no children the loop body never runs and best stays nil.
	for i := range n.children {
		candidate := n.children[i]
		if s := candidate.calculateScore(); s > bestScore {
			best, bestScore = candidate, s
		}
	}
	return best
}
// SelectRandomChild returns one of the node's children chosen uniformly at
// random, or nil when there are none.
func (n *Node) SelectRandomChild() *Node {
	count := len(n.children)
	if count == 0 {
		return nil
	}
	pick := rand.Intn(count)
	return n.children[pick]
}
// calculateScore returns the UCT value of the node:
//
//	totalScore/visitCount + sqrt(2) * sqrt(ln(parent.visitCount)/visitCount)
//
// An unvisited node scores +Inf so that it is tried before any visited
// sibling. When the node has no parent, or the parent has not been visited,
// the exploration term is undefined and only the average reward is returned
// instead of dereferencing a nil parent or producing NaN.
func (n *Node) calculateScore() float64 {
	if n.visitCount == 0 {
		return math.Inf(1) // default for nodes that were never visited
	}
	visits := float64(n.visitCount)
	exploitation := n.totalScore / visits
	if n.parent == nil || n.parent.visitCount <= 0 {
		return exploitation
	}
	exploration := math.Sqrt(2) * math.Sqrt(
		math.Log(float64(n.parent.visitCount))/visits,
	)
	return exploitation + exploration
}
// Expand replaces the node's children with one node per valid move whose
// resulting position is not game over; positions that end the game are left
// out of the tree. It does nothing when the node's own position is already
// game over.
//
// Because NewNode eagerly attaches each child's own children, a single call
// can deepen the tree by up to two levels (the opponent's turn plus our own).
func (n *Node) Expand() {
	if gameover, _ := n.board.IsGameOver(); gameover {
		return
	}
	moves := n.board.GetAllValidMoves()
	n.children = make([]*Node, 0, len(moves))
	for i := range moves {
		// Declare the move inside the loop body before taking its address:
		// before Go 1.22 a range variable is shared by all iterations, so
		// every child would otherwise point at the same (last) move.
		move := moves[i]
		next := n.board.ApplyMove(move)
		// Check for a finished game before building the node, so no
		// grandchildren are generated for a position that is discarded.
		if gameover, _ := next.IsGameOver(); gameover {
			continue
		}
		n.children = append(n.children, NewNode(next, &move, n))
	}
}
// EXPANSION_THRESHOLD is the number of visits a childless node must have
// accumulated before SimulateAndExpand expands it.
const EXPANSION_THRESHOLD = 10

// SimulateAndExpand plays one random game from the node's position and
// backpropagates the result from the node up to the root. It reports whether
// Player1 won the playout.
//
// For the node and each of its ancestors it
//   - increments visitCount,
//   - adds 1 to totalScore when the winner is not the player to move on that
//     node's board (i.e. the move leading to the node paid off for the player
//     who made it), and
//   - calls Expand once a childless node has reached EXPANSION_THRESHOLD
//     visits.
//
// NOTE(review): the random playout indexes into the valid-move list without
// checking for an empty list, so it panics if a non-terminal position ever has
// no valid move.
func (n *Node) SimulateAndExpand() bool {
	randomMove := func(b ds.Board) ds.Move {
		moves := b.GetAllValidMoves()
		return moves[rand.Intn(len(moves))]
	}
	endBoard := n.board.Play(randomMove)
	_, winPlayer := endBoard.IsGameOver()

	// Backpropagation: walk from this node up to the root without
	// reassigning the receiver.
	for node := n; node != nil; node = node.parent {
		node.visitCount++
		if winPlayer != node.board.TurnPlayer() {
			node.totalScore++
		}
		if len(node.children) == 0 && node.visitCount >= EXPANSION_THRESHOLD {
			node.Expand()
		}
	}
	return winPlayer == ds.Player1
}
After around 200,000 iterations, it won about 99% of its games against a random opponent, and some batches of 1,000 games were won without a single loss.
$ cat main.go
package main
import (
"fmt"
ds "github.com/sambaiz/doubutsu_shogi/doubutsushogi"
"github.com/sambaiz/doubutsu_shogi/mcts"
)
func main() {
board := ds.NewBoard()
rootNode := mcts.NewNode(board, nil, nil)
for i := 0; i <= 1000; i++ {
winCount := 0
for j := 0; j < 1000; j++ {
selectedNode := rootNode
for {
// Player 1 は最もスコアが高いノードを選ぶ
nextNode := selectedNode.SelectBestChild()
if nextNode == nil {
break
}
selectedNode = nextNode
// Player 2 はランダムに選ぶ
nextNode = selectedNode.SelectRandomChild()
if nextNode == nil {
break
}
selectedNode = nextNode
}
if selectedNode.SimulateAndExpand() {
winCount++
}
}
if i%100 == 0 {
fmt.Printf("%d: win rate: %.1f%%\n", i*1000, 100.0*float64(winCount)/1000.0)
}
}
// fmt.Println(rootNode.PrettyString(""))
}
$ go run main.go
0: win rate: 61.1%
100000: win rate: 98.7%
200000: win rate: 99.1%
300000: win rate: 99.5%
400000: win rate: 99.9%
500000: win rate: 99.4%
600000: win rate: 99.9%
700000: win rate: 99.5%
800000: win rate: 99.6%
900000: win rate: 99.8%
1000000: win rate: 100.0%
The tree has also been updated. Each line shows a move followed by [UCT score / total score / visit count].
├──(2, 1) -> (1, 1) [0.839369/116/234]
│ ├──(0, 0) -> (1, 0) [0.831078/20/53]
│ │ ├──C -> (0, 0) [1.960251/1/3]
│ │ ├──C -> (1, 2) [1.860205/3/5]
│ │ ├──C -> (2, 0) [1.922211/6/7]
│ │ ├──C -> (2, 1) [1.983738/5/6]
│ │ ├──C -> (2, 2) [1.626918/0/3]
│ │ ├──(1, 1) -> (0, 1) [1.939301/9/9]
│ │ ├──(3, 0) -> (2, 1) [1.908952/2/4]
│ │ ├──(3, 1) -> (2, 0) [1.817071/4/6]
│ │ ├──(3, 1) -> (2, 1) [1.626918/0/3]
│ │ ├──(3, 1) -> (2, 2) [1.908952/2/4]
│ │ ├──(3, 2) -> (2, 2) [1.960251/1/3]
│ ├──(0, 1) -> (1, 0) [0.954886/18/41]
│ │ ├──C -> (0, 1) [1.862639/2/4]
...