Building a game AI for Animal Shogi using Monte Carlo Tree Search (MCTS)

Tags: algorithm, machine learning, golang

Monte Carlo Tree Search (MCTS) is an algorithm for finding good moves in game AI, used for example in AlphaGo. Related algorithms include minimax, which exhaustively searches for the best move assuming the opponent always replies with the move that is worst for us, and alpha-beta pruning, which makes minimax more efficient by pruning branches that cannot affect the result. While these algorithms require assigning evaluation values to intermediate board states, Monte Carlo Tree Search can evaluate a move from the win/loss results of randomly playing out the game from it.

Game Implementation

First, I implemented the game itself so that it can be simulated. The complete code is available on GitHub. As in regular shogi, in Animal Shogi you can drop captured pieces onto the board in addition to moving your own pieces. You win either by capturing the opponent's lion or by moving your lion to the opponent's back row. Only chicks can promote: a chick becomes a chicken when it reaches the opponent's back row.

// IsGameOver reports whether the game has ended and, if so, which player won.
//
// Two winning conditions are checked:
//   - Try: a player's lion stands on the opponent's back row (row 0 for
//     Player1, row 3 for Player2).
//   - Catch: a player's lion is no longer on the board, so the opponent wins.
//
// The board is scanned row by row starting at row 0, and the first try found
// decides the result. When neither condition holds it returns (false, none).
//
// NOTE(review): in the published rules a try only counts if the lion cannot
// be captured on that square; this check awards it immediately — confirm the
// simplification is intended.
func (b *Board) IsGameOver() (bool, player) {
	foundP1Lion, foundP2Lion := false, false

	for row := 0; row < 4; row++ {
		for col := 0; col < 3; col++ {
			sq := b.squares[row][col]
			if sq.piece != lion {
				continue
			}
			switch sq.player {
			case Player1:
				foundP1Lion = true
				if row == 0 { // try
					return true, Player1
				}
			case Player2:
				foundP2Lion = true
				if row == 3 { // try
					return true, Player2
				}
			}
		}
	}

	// Catch: a lion missing from the board has been captured.
	switch {
	case !foundP1Lion:
		return true, Player2
	case !foundP2Lion:
		return true, Player1
	default:
		return false, none
	}
}

// GetAllValidMoves returns every legal move for the player to move. Drops of
// captured pieces come first, then moves of pieces on the board, scanned row
// by row. Moves whose resulting position has already occurred too often
// (repetition, "sennichite") are left out.
//
// Each distinct captured piece type is considered only once. Holding two
// copies of a piece (e.g. two chicks) therefore does not list the same drop
// moves twice, which would bias random move selection and duplicate tree
// nodes.
func (b *Board) GetAllValidMoves() []Move {
	var allMoves []Move

	// addUnlessRepetition keeps m only if the position it produces has
	// occurred fewer than 4 times according to occurrences.
	// NOTE(review): the original comment described the excluded case as "the
	// same position with the same side to move appearing 3 times"; whether
	// `< 4` matches that depends on how ApplyMove updates occurrences —
	// confirm.
	addUnlessRepetition := func(m Move) {
		next := b.ApplyMove(m)
		if next.occurrences[next.String()] < 4 {
			allMoves = append(allMoves, m)
		}
	}

	// Drops of captured pieces, generated once per distinct piece type.
	dropped := make(map[piece]bool)
	for _, p := range b.captured[b.turnPlayer] {
		if dropped[p] {
			continue
		}
		dropped[p] = true
		for _, drop := range b.getValidDrops(p) {
			addUnlessRepetition(drop)
		}
	}

	// Moves of the side to move's pieces already on the board.
	for row := 0; row < 4; row++ {
		for col := 0; col < 3; col++ {
			if b.squares[row][col].player != b.turnPlayer {
				continue
			}
			for _, move := range b.getValidMoves(row, col) {
				addUnlessRepetition(move)
			}
		}
	}

	return allMoves
}

// getValidDrops lists a drop of the given captured piece onto every empty
// square, in row-major order (row 0..3, column 0..2). Drop moves carry -1 as
// their source row and column. It returns nil when no square is empty.
func (b *Board) getValidDrops(piece piece) []Move {
	var result []Move
	for idx := 0; idx < 4*3; idx++ {
		r, c := idx/3, idx%3
		if b.squares[r][c].piece != empty {
			continue
		}
		result = append(result, Move{-1, -1, r, c, piece})
	}
	return result
}

// getValidMoves returns the moves available to the piece at (row, col). It
// does not check repetition or game-over. A piece may move to any on-board
// square within its movement pattern that is not occupied by one of its own
// side's pieces (i.e. an empty square or an enemy piece). It returns nil if
// (row, col) is off the board or empty.
//
// Row 0 is Player1's forward edge: Player1's pieces advance toward row 0 and
// Player2's toward row 3.
func (b *Board) getValidMoves(row, col int) []Move {
	var moves []Move
	if row < 0 || row >= 4 || col < 0 || col >= 3 {
		return moves
	}

	square := b.squares[row][col]
	if square.piece == empty {
		return moves
	}

	// addIfValid appends a move to (toRow, toCol) if that square is on the
	// board and not occupied by the mover's own piece.
	addIfValid := func(toRow, toCol int) {
		if toRow >= 0 && toRow < 4 && toCol >= 0 && toCol < 3 {
			if b.squares[toRow][toCol].player != square.player {
				moves = append(moves, Move{row, col, toRow, toCol, empty})
			}
		}
	}

	switch square.piece {
	// Lion: any of the 8 adjacent squares.
	case lion:
		for drow := -1; drow <= 1; drow++ {
			for dcol := -1; dcol <= 1; dcol++ {
				if drow == 0 && dcol == 0 {
					continue
				}
				addIfValid(row+drow, col+dcol)
			}
		}

	// Elephant: any of the 4 diagonal squares.
	case elephant:
		directions := [][2]int{{-1, -1}, {-1, 1}, {1, -1}, {1, 1}}
		for _, d := range directions {
			addIfValid(row+d[0], col+d[1])
		}

	// Giraffe: any of the 4 orthogonal squares.
	case giraffe:
		directions := [][2]int{{-1, 0}, {1, 0}, {0, -1}, {0, 1}}
		for _, d := range directions {
			addIfValid(row+d[0], col+d[1])
		}

	// Chick: one square straight forward only.
	case chick:
		if square.player == Player1 {
			addIfValid(row-1, col)
		} else {
			addIfValid(row+1, col)
		}

	// Chicken: any adjacent square except the two diagonally backward ones,
	// i.e. 6 squares: forward, both forward diagonals, both sides, and
	// straight back. Only the backward diagonals are excluded; straight back
	// is allowed.
	case chicken:
		// back is the row offset pointing away from the owner's forward edge.
		back := 1
		if square.player == Player2 {
			back = -1
		}
		for drow := -1; drow <= 1; drow++ {
			for dcol := -1; dcol <= 1; dcol++ {
				if drow == 0 && dcol == 0 {
					continue
				}
				if drow == back && dcol != 0 {
					continue
				}
				addIfValid(row+drow, col+dcol)
			}
		}
	}

	return moves
}

// Play plays the game out from b. Until IsGameOver reports that the game has
// ended, it passes a copy of the current board to chooseMove and applies the
// returned move with ApplyMove. It returns the final board. If b is already
// game over, b itself is returned without calling chooseMove.
func (b *Board) Play(chooseMove func(Board) Move) *Board {
	board := b
	for over, _ := board.IsGameOver(); !over; over, _ = board.IsGameOver() {
		board = board.ApplyMove(chooseMove(*board))
	}
	return board
}

When both players move randomly, each player wins roughly 50% of the games.

$ cat main.go
package main

import (
	"fmt"

	"math/rand"

	ds "github.com/sambaiz/doubutsu_shogi/doubutsushogi"
)

// main plays one game in which both players pick uniformly random legal
// moves. It prints every position before a move is chosen, then the final
// board and the winner.
func main() {
	pickRandom := func(b ds.Board) ds.Move {
		fmt.Println(b)
		candidates := b.GetAllValidMoves()
		return candidates[rand.Intn(len(candidates))]
	}

	final := ds.NewBoard().Play(pickRandom)
	fmt.Println(final)

	_, winner := final.IsGameOver()
	fmt.Printf("Player%d win!\n", winner)
}

$ go run main.go

2G 2L 2E 
-- 2C -- 
-- 1C -- 
1E 1L 1G 

Player1 has: 
Player2 has: 

- Player1's turn -

2G 2L 2E 
-- 2C -- 
-- 1C 1L 
1E -- 1G 

Player1 has: 
Player2 has: 

- Player2's turn -

-- 2L 2E 
2G 2C -- 
-- 1C 1L 
1E -- 1G 

Player1 has: 
Player2 has: 

- Player1's turn -

-- 2L 2E 
2G 2C 1L 
-- 1C -- 
1E -- 1G 

Player1 has: 
Player2 has: 

- Player2's turn -

-- 2L 2E 
-- 2C 1L 
2G 1C -- 
1E -- 1G 

Player1 has: 
Player2 has: 

- Player1's turn -

-- 2L 2E 
-- 1L -- 
2G 1C -- 
1E -- 1G 

Player1 has: C 
Player2 has: 

- Player2's turn -

-- 2L 2E 
-- 1L -- 
-- 2G -- 
1E -- 1G 

Player1 has: C 
Player2 has: C 

- Player1's turn -

-- 1L 2E 
-- -- -- 
-- 2G -- 
1E -- 1G 

Player1 has: L C 
Player2 has: C 

- Player2's turn -

Player1 win!

Monte Carlo Tree Search repeatedly descends the tree by following the children with the highest evaluation values until it reaches a leaf node, plays random moves from there to the end of the game, and updates the evaluation values along the path based on the result. The evaluation function is UCT (Upper Confidence Bound applied to Trees), which is based on the UCB1 bandit algorithm. \(Q_i\) is the total reward of node \(i\), \(n_i\) is the number of times node \(i\) has been visited, and \(N_i\) is the number of times its parent node has been visited. \(C\) is a tuning parameter, commonly set to \(\sqrt{2}\).

$$ UCT_i = \frac{Q_i}{n_i} + C\sqrt{\frac{\ln{N_i}}{n_i}} $$

package mcts

import (
	"fmt"
	"math"
	"strings"

	"math/rand"

	ds "github.com/sambaiz/doubutsu_shogi/doubutsushogi"
)

// Node is a node of the Monte Carlo search tree: a board position together
// with the statistics gathered from the playouts that passed through it.
type Node struct {
	// board is the position at this node.
	board      *ds.Board
	// move is the move that led here from parent; nil for the root.
	move       *ds.Move
	// totalScore counts backed-up playouts won by the player who made move,
	// i.e. whose winner is not the player to move on board.
	totalScore float64
	// visitCount is the number of playouts backed up through this node.
	visitCount int64
	// parent is nil for the root.
	parent     *Node
	// children holds one node per legal move; empty for leaves that have not
	// been expanded yet.
	children   []*Node
}

// NewNode creates a search-tree node for board, reached by move from parent.
// For the root, move and parent are nil.
//
// The node's children are created eagerly, one per legal move, but each child
// is left unexpanded (it has no children of its own). If board is already
// game over, no children are created, so the tree never continues play past
// the end of a game.
func NewNode(board *ds.Board, move *ds.Move, parent *Node) *Node {
	node := &Node{
		board:  board,
		move:   move,
		parent: parent,
	}

	if gameover, _ := board.IsGameOver(); gameover {
		return node
	}

	childMoves := board.GetAllValidMoves()
	node.children = make([]*Node, 0, len(childMoves))
	for _, childMove := range childMoves {
		// Copy before taking the address so every child owns its own Move,
		// independent of the Go version's loop-variable semantics.
		m := childMove
		node.children = append(node.children, &Node{
			board:  board.ApplyMove(m),
			move:   &m,
			parent: node,
		})
	}

	return node
}

// String describes the node: its board (if set) under a "Board:" heading,
// followed by the move that led to it (if set) under a "Move:" heading.
func (n *Node) String() string {
	var out strings.Builder

	if n.board != nil {
		fmt.Fprintf(&out, "Board:\n%s\n", n.board.String())
	}
	if n.move != nil {
		fmt.Fprintf(&out, "Move:\n%s\n", n.move.String())
	}

	return out.String()
}

// PrettyString renders the subtree rooted at n as an indented tree. Every
// node that has a move is printed on one line: prefix, "├──", the move, and
// [UCT score/total score/visit count]. Its children are then indented one
// level further. A node without a move (the root) prints no line itself and
// does not add indentation for its children.
func (n *Node) PrettyString(prefix string) string {
	var out strings.Builder

	childPrefix := prefix
	if n.move != nil {
		fmt.Fprintf(&out, "%s├──%s [%f/%d/%d]\n", prefix, n.move.String(), n.calculateScore(), int(n.totalScore), n.visitCount)
		childPrefix = prefix + "│   "
	}

	for _, c := range n.children {
		out.WriteString(c.PrettyString(childPrefix))
	}

	return out.String()
}

// SelectBestChild returns the child with the highest UCT score from
// calculateScore. On ties the earliest child wins, and unvisited children
// score +Inf, so they are chosen first. It returns nil when n has no
// children.
func (n *Node) SelectBestChild() *Node {
	var best *Node
	bestScore := math.Inf(-1)

	for _, c := range n.children {
		if s := c.calculateScore(); s > bestScore {
			best, bestScore = c, s
		}
	}

	return best
}

// SelectRandomChild returns one of n's children chosen uniformly at random
// using math/rand, or nil when n has no children.
func (n *Node) SelectRandomChild() *Node {
	switch count := len(n.children); count {
	case 0:
		return nil
	default:
		return n.children[rand.Intn(count)]
	}
}

// calculateScore returns the UCT (UCB1 applied to trees) value of n:
//
//	totalScore/visitCount + C*sqrt(ln(parent.visitCount)/visitCount), C = sqrt(2)
//
// totalScore counts playouts won by the player who made the move leading to
// n (the winner is not the player to move on n's board). A higher value is
// therefore better from that player's point of view.
//
// An unvisited node scores +Inf, so every child is tried once before any
// child is revisited. A node without a parent (the root) has no exploration
// term and scores its mean reward alone, instead of dereferencing a nil
// parent.
func (n *Node) calculateScore() float64 {
	if n.visitCount == 0 {
		return math.Inf(1)
	}

	visits := float64(n.visitCount)
	exploitation := n.totalScore / visits

	if n.parent == nil {
		return exploitation
	}

	exploration := math.Sqrt(2) * math.Sqrt(
		math.Log(float64(n.parent.visitCount))/visits,
	)

	return exploitation + exploration
}

// Expand attaches one child to n for every legal move from n's board. It
// does nothing if that board is already game over.
//
// Non-terminal children are built with NewNode, which also creates their own
// (unexpanded) children. One call can therefore grow the tree by up to two
// plies: the opponent's reply and the following move.
//
// Children whose board ends the game are kept as childless leaves rather than
// dropped. This lets winning replies (e.g. capturing the lion) be selected
// and their results backed up into the tree. It also ensures a node whose
// every reply ends the game still gets children and is not re-expanded on
// every backpropagation.
func (n *Node) Expand() {
	if gameover, _ := n.board.IsGameOver(); gameover {
		return
	}

	moves := n.board.GetAllValidMoves()
	n.children = make([]*Node, 0, len(moves))

	for _, move := range moves {
		// Copy before taking the address so each child owns its own Move,
		// independent of the Go version's loop-variable semantics.
		m := move
		next := n.board.ApplyMove(m)
		if gameover, _ := next.IsGameOver(); gameover {
			n.children = append(n.children, &Node{board: next, move: &m, parent: n})
			continue
		}
		n.children = append(n.children, NewNode(next, &m, n))
	}
}

// EXPANSION_THRESHOLD is the number of visits a childless node must reach
// before SimulateAndExpand expands it during backpropagation.
// NOTE(review): idiomatic Go would be ExpansionThreshold; not renamed because
// the name is exported.
const EXPANSION_THRESHOLD = 10

// SimulateAndExpand runs one Monte Carlo iteration starting at n. It works
// in two steps:
//
//   - Simulation: a uniformly random playout from n's board to the end of
//     the game.
//   - Backpropagation: a walk from n up to the root that updates the
//     statistics of every node on the path.
//
// Each node on the path gets one more visit. It also gets one more point if
// the playout's winner is not the player to move on that node's board, i.e.
// the player who made the move into the node won. A childless node that
// reaches EXPANSION_THRESHOLD visits is expanded on the way up.
//
// It reports whether Player1 won the playout.
//
// NOTE(review): the random policy panics (rand.Intn(0)) if GetAllValidMoves
// ever returns no moves for a position that is not game over — presumably
// unreachable; confirm.
func (n *Node) SimulateAndExpand() bool {
	randomMove := func(b ds.Board) ds.Move {
		moves := b.GetAllValidMoves()
		return moves[rand.Intn(len(moves))]
	}

	_, winPlayer := n.board.Play(randomMove).IsGameOver()

	// Backpropagation
	for node := n; node != nil; node = node.parent {
		node.visitCount++
		if winPlayer != node.board.TurnPlayer() {
			node.totalScore++
		}
		if len(node.children) == 0 && node.visitCount >= EXPANSION_THRESHOLD {
			node.Expand()
		}
	}

	return winPlayer == ds.Player1
}

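After around 200,000 iterations, it won about 99% of games against a random opponent, and a 100% win rate was also observed. (Each reported win rate is measured over a batch of 1,000 iterations: Player1 follows the tree and then plays randomly once it leaves it, while Player2 plays randomly throughout.)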

$ cat main.go
package main

import (
	"fmt"

	ds "github.com/sambaiz/doubutsu_shogi/doubutsushogi"
	"github.com/sambaiz/doubutsu_shogi/mcts"
)

func main() {
	board := ds.NewBoard()
	rootNode := mcts.NewNode(board, nil, nil)

	for i := 0; i <= 1000; i++ {
		winCount := 0

		for j := 0; j < 1000; j++ {
			selectedNode := rootNode
			for {
				// Player 1 は最もスコアが高いノードを選ぶ
				nextNode := selectedNode.SelectBestChild()
				if nextNode == nil {
					break
				}
				selectedNode = nextNode

				// Player 2 はランダムに選ぶ
				nextNode = selectedNode.SelectRandomChild()
				if nextNode == nil {
					break
				}
				selectedNode = nextNode
			}

			if selectedNode.SimulateAndExpand() {
				winCount++
			}
		}
		if i%100 == 0 {
			fmt.Printf("%d: win rate: %.1f%%\n", i*1000, 100.0*float64(winCount)/1000.0)
		}
	}

	// fmt.Println(rootNode.PrettyString(""))
}

$ go run main.go

0: win rate: 61.1%
100000: win rate: 98.7%
200000: win rate: 99.1%
300000: win rate: 99.5%
400000: win rate: 99.9%
500000: win rate: 99.4%
600000: win rate: 99.9%
700000: win rate: 99.5%
800000: win rate: 99.6%
900000: win rate: 99.8%
1000000: win rate: 100.0%

The search tree has also been updated. Each node is printed as the move followed by [UCT score/total reward/visit count]:

├──(2, 1) -> (1, 1) [0.839369/116/234]
│   ├──(0, 0) -> (1, 0) [0.831078/20/53]
│   │   ├──C -> (0, 0) [1.960251/1/3]
│   │   ├──C -> (1, 2) [1.860205/3/5]
│   │   ├──C -> (2, 0) [1.922211/6/7]
│   │   ├──C -> (2, 1) [1.983738/5/6]
│   │   ├──C -> (2, 2) [1.626918/0/3]
│   │   ├──(1, 1) -> (0, 1) [1.939301/9/9]
│   │   ├──(3, 0) -> (2, 1) [1.908952/2/4]
│   │   ├──(3, 1) -> (2, 0) [1.817071/4/6]
│   │   ├──(3, 1) -> (2, 1) [1.626918/0/3]
│   │   ├──(3, 1) -> (2, 2) [1.908952/2/4]
│   │   ├──(3, 2) -> (2, 2) [1.960251/1/3]
│   ├──(0, 1) -> (1, 0) [0.954886/18/41]
│   │   ├──C -> (0, 1) [1.862639/2/4]
...

References

モンテカルロ木探索 | 備忘録

Monte Carlo Tree Search - About