モンテカルロ木探索 (Monte Carlo Tree Search; MCTS) は AlphaGo などのゲームAIで最適な行動を探索するのに用いられているアルゴリズム。同種のアルゴリズムとしては、相手が(自分にとって)最悪の手を常に打つと仮定したときの最良の手を網羅的に探索する Mini-Max 法や、それを枝刈りすることで効率化した αβ 法がある。これらの手法が途中盤面の評価値を与えてやる必要があるのに対して、モンテカルロ木探索はランダムにゲームを進めた結果の勝ち負けをもって評価することができる。
ゲームの実装
まずシミュレーションのためのゲームを実装した。全体のコードは GitHub にある。 どうぶつしょうぎは自分の駒を動かすほか、通常の将棋と同じく取った駒を置くこともでき、最終的に相手のライオンを取るか、自分のライオンを相手陣の最奥まで動かすと勝ちとなる。ひよこのみ最奥まで動かすとにわとりに成ることができる。
func (b *Board) IsGameOver() (bool, player) {
var p1Lion, p2Lion bool
for row := 0; row < 4; row++ {
for col := 0; col < 3; col++ {
square := b.squares[row][col]
if square.piece == lion {
if square.player == Player1 {
p1Lion = true
// 自分のライオンを相手陣の1段目に移動させる「トライ」
if row == 0 {
return true, Player1
}
} else if square.player == Player2 {
p2Lion = true
if row == 3 {
return true, Player2
}
}
}
}
}
// 相手のライオンを取る「キャッチ」
if !p1Lion {
return true, Player2
}
if !p2Lion {
return true, Player1
}
return false, none
}
func (b *Board) GetAllValidMoves() []Move {
var allMoves []Move
// 持ち駒を出す場合
for _, piece := range b.captured[b.turnPlayer] {
drops := b.getValidDrops(piece)
for _, drop := range drops {
nextBoard := b.ApplyMove(drop)
// 千日手(手番が全く同じ状態が3回現れる)は除く
if nextBoard.occurrences[nextBoard.String()] < 4 {
allMoves = append(allMoves, drop)
}
}
}
// 盤面の駒を移動する場合
for row := 0; row < 4; row++ {
for col := 0; col < 3; col++ {
square := b.squares[row][col]
if square.player == b.turnPlayer {
moves := b.getValidMoves(row, col)
for _, move := range moves {
nextBoard := b.ApplyMove(move)
if nextBoard.occurrences[nextBoard.String()] < 4 {
allMoves = append(allMoves, move)
}
}
}
}
}
return allMoves
}
func (b *Board) getValidDrops(piece piece) []Move {
var drops []Move
for row := 0; row < 4; row++ {
for col := 0; col < 3; col++ {
if b.squares[row][col].piece == empty {
drops = append(drops, Move{-1, -1, row, col, piece})
}
}
}
return drops
}
func (b *Board) getValidMoves(row, col int) []Move {
var moves []Move
if row < 0 || row >= 4 || col < 0 || col >= 3 {
return moves
}
square := b.squares[row][col]
if square.piece == empty {
return moves
}
addIfValid := func(toRow, toCol int) {
if toRow >= 0 && toRow < 4 && toCol >= 0 && toCol < 3 {
if b.squares[toRow][toCol].player != square.player {
moves = append(moves, Move{row, col, toRow, toCol, empty})
}
}
}
switch square.piece {
// 隣接する8マスのいずれかに進むことができる。
case lion:
for drow := -1; drow <= 1; drow++ {
for dcol := -1; dcol <= 1; dcol++ {
if drow == 0 && dcol == 0 {
continue
}
addIfValid(row+drow, col+dcol)
}
}
// 斜めの4マスのいずれかに進むことができる
case elephant:
directions := [][2]int{{-1, -1}, {-1, 1}, {1, -1}, {1, 1}}
for _, d := range directions {
addIfValid(row+d[0], col+d[1])
}
// 縦・横の4マスのいずれかに進むことができる
case giraffe:
directions := [][2]int{{-1, 0}, {1, 0}, {0, -1}, {0, 1}}
for _, d := range directions {
addIfValid(row+d[0], col+d[1])
}
// 前の1マスにのみ進むことができる。
case chick:
if square.player == Player1 {
addIfValid(row-1, col)
} else {
addIfValid(row+1, col)
}
// 斜め後ろ以外の6マスのいずれかに進むことができる。
case chicken:
for drow := -1; drow <= 1; drow++ {
for dcol := -1; dcol <= 1; dcol++ {
if drow == 0 && dcol == 0 {
continue
}
if (square.player == Player1 && drow == 1) ||
(square.player == Player2 && drow == -1) {
continue
}
addIfValid(row+drow, col+dcol)
}
}
}
return moves
}
func (b *Board) Play(chooseMove func(Board) Move) *Board {
currentBoard := b
for {
if gameover, _ := currentBoard.IsGameOver(); gameover {
return currentBoard
}
currentBoard = currentBoard.ApplyMove(chooseMove(*currentBoard))
}
}
お互いランダムに動かす場合、勝敗はほぼ50%となった。
$ cat main.go
package main
import (
"fmt"
"math/rand"
ds "github.com/sambaiz/doubutsu_shogi/doubutsushogi"
)
func main() {
board := ds.NewBoard()
randomMove := func(b ds.Board) ds.Move {
fmt.Println(b)
moves := b.GetAllValidMoves()
return moves[rand.Intn(len(moves))]
}
endBoard := board.Play(randomMove)
fmt.Println(endBoard)
_, winPlayer := endBoard.IsGameOver()
fmt.Printf("Player%d win!\n", winPlayer)
}
$ go run main
2G 2L 2E
-- 2C --
-- 1C --
1E 1L 1G
Player1 has:
Player2 has:
- Player1's turn -
2G 2L 2E
-- 2C --
-- 1C 1L
1E -- 1G
Player1 has:
Player2 has:
- Player2's turn -
-- 2L 2E
2G 2C --
-- 1C 1L
1E -- 1G
Player1 has:
Player2 has:
- Player1's turn -
-- 2L 2E
2G 2C 1L
-- 1C --
1E -- 1G
Player1 has:
Player2 has:
- Player2's turn -
-- 2L 2E
-- 2C 1L
2G 1C --
1E -- 1G
Player1 has:
Player2 has:
- Player1's turn -
-- 2L 2E
-- 1L --
2G 1C --
1E -- 1G
Player1 has: C
Player2 has:
- Player2's turn -
-- 2L 2E
-- 1L --
-- 2G --
1E -- 1G
Player1 has: C
Player2 has: C
- Player1's turn -
-- 1L 2E
-- -- --
-- 2G --
1E -- 1G
Player1 has: L C
Player2 has: C
- Player2's turn -
Player1 win!
モンテカルロ木探索の実装
モンテカルロ木探索は最も評価値の高い葉を選び、そこからランダムにゲームを進め、その結果によって評価値を更新し一定回数探索した葉から木を拡張する、というのを繰り返す。評価関数としてはバンディットアルゴリズムの UCB1 ベースの UCT (Upper Confidence Bound for Trees) が用いられる。\(Q_i\) がそのノードの合計報酬、\(n_i\) がそのノードの探索回数、\(N_i\) が親ノードの探索回数。\(C\) は調整用のパラメータで、\(\sqrt{2}\) とすることが多い。
$$ UCT = \frac{Q_i}{n_i} + C\sqrt{\frac{\ln{N}}{n_i}} $$
package mcts
import (
"fmt"
"math"
"strings"
"math/rand"
ds "github.com/sambaiz/doubutsu_shogi/doubutsushogi"
)
type Node struct {
board *ds.Board
move *ds.Move
totalScore float64
visitCount int64
parent *Node
children []*Node
}
func NewNode(board *ds.Board, move *ds.Move, parent *Node) *Node {
root := &Node{
board: board,
move: move,
parent: parent,
}
childMoves := board.GetAllValidMoves()
root.children = make([]*Node, 0, len(childMoves))
for _, childMove := range childMoves {
root.children = append(root.children, &Node{
board: board.ApplyMove(childMove),
move: &childMove,
parent: root,
})
}
return root
}
func (n *Node) String() string {
var sb strings.Builder
if n.board != nil {
sb.WriteString(fmt.Sprintf("Board:\n%s\n", n.board.String()))
}
if n.move != nil {
sb.WriteString(fmt.Sprintf("Move:\n%s\n", n.move.String()))
}
return sb.String()
}
func (n *Node) PrettyString(prefix string) string {
var sb strings.Builder
if n.move != nil {
sb.WriteString(fmt.Sprintf("%s├──%s [%f/%d/%d]\n", prefix, n.move.String(), n.calculateScore(), int(n.totalScore), n.visitCount))
}
for _, child := range n.children {
newPrefix := prefix
if n.move != nil {
newPrefix += "│ "
}
sb.WriteString(child.PrettyString(newPrefix))
}
return sb.String()
}
func (n *Node) SelectBestChild() *Node {
if len(n.children) == 0 {
return nil
}
maxScore := math.Inf(-1)
var bestChild *Node
for _, child := range n.children {
score := child.calculateScore()
if score > maxScore {
maxScore = score
bestChild = child
}
}
return bestChild
}
func (n *Node) SelectRandomChild() *Node {
if len(n.children) == 0 {
return nil
}
return n.children[rand.Intn(len(n.children))]
}
func (n *Node) calculateScore() float64 {
if n.visitCount == 0 {
return math.Inf(1) // デフォルト値
}
exploitation := n.totalScore / float64(n.visitCount)
exploration := math.Sqrt(2) * math.Sqrt(
math.Log(float64(n.parent.visitCount))/float64(n.visitCount),
)
return exploitation + exploration
}
// 相手 + 自分の手番で最大2レベル増えます
func (n *Node) Expand() {
if gameover, _ := n.board.IsGameOver(); gameover {
return
}
moves := n.board.GetAllValidMoves()
n.children = make([]*Node, 0, len(moves))
for _, move := range moves {
childNode := NewNode(n.board.ApplyMove(move), &move, n)
if gameover, _ := childNode.board.IsGameOver(); gameover {
continue
}
n.children = append(n.children, childNode)
}
}
const EXPANSION_THRESHOLD = 10
func (n *Node) SimulateAndExpand() bool {
board := n.board
randomMove := func(b ds.Board) ds.Move {
moves := b.GetAllValidMoves()
return moves[rand.Intn(len(moves))]
}
endBoard := board.Play(randomMove)
_, winPlayer := endBoard.IsGameOver()
// Backpropagation
for {
if n == nil {
break
}
n.visitCount += 1
if winPlayer != n.board.TurnPlayer() {
n.totalScore += 1
}
if len(n.children) == 0 && n.visitCount >= EXPANSION_THRESHOLD {
n.Expand()
}
n = n.parent
}
return winPlayer == 1
}
200000 ループほどでランダム相手に 99 % 勝てるようになり、全勝も観測された。
$ cat main.go
package main
import (
"fmt"
ds "github.com/sambaiz/doubutsu_shogi/doubutsushogi"
"github.com/sambaiz/doubutsu_shogi/mcts"
)
func main() {
board := ds.NewBoard()
rootNode := mcts.NewNode(board, nil, nil)
for i := 0; i <= 1000; i++ {
winCount := 0
for j := 0; j < 1000; j++ {
selectedNode := rootNode
for {
// Player 1 は最もスコアが高いノードを選ぶ
nextNode := selectedNode.SelectBestChild()
if nextNode == nil {
break
}
selectedNode = nextNode
// Player 2 はランダムに選ぶ
nextNode = selectedNode.SelectRandomChild()
if nextNode == nil {
break
}
selectedNode = nextNode
}
if selectedNode.SimulateAndExpand() {
winCount++
}
}
if i%100 == 0 {
fmt.Printf("%d: win rate: %.1f%%\n", i*1000, 100.0*float64(winCount)/1000.0)
}
}
// fmt.Println(rootNode.PrettyString(""))
}
$ go run main.go
0: win rate: 61.1%
100000: win rate: 98.7%
200000: win rate: 99.1%
300000: win rate: 99.5%
400000: win rate: 99.9%
500000: win rate: 99.4%
600000: win rate: 99.9%
700000: win rate: 99.5%
800000: win rate: 99.6%
900000: win rate: 99.8%
1000000: win rate: 100.0%
木も更新されている。
├──(2, 1) -> (1, 1) [0.839369/116/234]
│ ├──(0, 0) -> (1, 0) [0.831078/20/53]
│ │ ├──C -> (0, 0) [1.960251/1/3]
│ │ ├──C -> (1, 2) [1.860205/3/5]
│ │ ├──C -> (2, 0) [1.922211/6/7]
│ │ ├──C -> (2, 1) [1.983738/5/6]
│ │ ├──C -> (2, 2) [1.626918/0/3]
│ │ ├──(1, 1) -> (0, 1) [1.939301/9/9]
│ │ ├──(3, 0) -> (2, 1) [1.908952/2/4]
│ │ ├──(3, 1) -> (2, 0) [1.817071/4/6]
│ │ ├──(3, 1) -> (2, 1) [1.626918/0/3]
│ │ ├──(3, 1) -> (2, 2) [1.908952/2/4]
│ │ ├──(3, 2) -> (2, 2) [1.960251/1/3]
│ ├──(0, 1) -> (1, 0) [0.954886/18/41]
│ │ ├──C -> (0, 1) [1.862639/2/4]
...