モンテカルロ木探索 (MCTS) でどうぶつしょうぎのゲームAIを作る

algorithmmachinelearninggolang

モンテカルロ木探索 (Monte Carlo Tree Search; MCTS) は AlphaGo などのゲームAIで最適な行動を探索するのに用いられているアルゴリズム。同種のアルゴリズムとしては、相手が(自分にとって)最悪の手を常に打つと仮定したときの最良の手を網羅的に探索する Mini-Max 法や、それを枝刈りすることで効率化した αβ 法がある。これらの手法が途中盤面の評価値を与えてやる必要があるのに対して、モンテカルロ木探索はランダムにゲームを進めた結果の勝ち負けをもって評価することができる。

ゲームの実装

まずシミュレーションのためのゲームを実装した。全体のコードは GitHub にある。 どうぶつしょうぎは自分の駒を動かすほか、通常の将棋と同じく取った駒を置くこともでき、最終的に相手のライオンを取るか、自分のライオンを相手陣の最奥まで動かすと勝ちとなる。ひよこのみ最奥まで動かすとにわとりに成ることができる。

func (b *Board) IsGameOver() (bool, player) {
	var p1Lion, p2Lion bool
	for row := 0; row < 4; row++ {
		for col := 0; col < 3; col++ {
			square := b.squares[row][col]
			if square.piece == lion {
				if square.player == Player1 {
					p1Lion = true
					// 自分のライオンを相手陣の1段目に移動させる「トライ」
					if row == 0 {
						return true, Player1
					}
				} else if square.player == Player2 {
					p2Lion = true
					if row == 3 {
						return true, Player2
					}
				}
			}
		}
	}

	// 相手のライオンを取る「キャッチ」
	if !p1Lion {
		return true, Player2
	}
	if !p2Lion {
		return true, Player1
	}

	return false, none
}

func (b *Board) GetAllValidMoves() []Move {
	var allMoves []Move

	// 持ち駒を出す場合
	for _, piece := range b.captured[b.turnPlayer] {
		drops := b.getValidDrops(piece)
		for _, drop := range drops {
			nextBoard := b.ApplyMove(drop)
			// 千日手(手番が全く同じ状態が3回現れる)は除く
			if nextBoard.occurrences[nextBoard.String()] < 4 {
				allMoves = append(allMoves, drop)
			}
		}
	}

	// 盤面の駒を移動する場合
	for row := 0; row < 4; row++ {
		for col := 0; col < 3; col++ {
			square := b.squares[row][col]
			if square.player == b.turnPlayer {
				moves := b.getValidMoves(row, col)
				for _, move := range moves {
					nextBoard := b.ApplyMove(move)
					if nextBoard.occurrences[nextBoard.String()] < 4 {
						allMoves = append(allMoves, move)
					}
				}
			}
		}
	}

	return allMoves
}

func (b *Board) getValidDrops(piece piece) []Move {
	var drops []Move

	for row := 0; row < 4; row++ {
		for col := 0; col < 3; col++ {
			if b.squares[row][col].piece == empty {
				drops = append(drops, Move{-1, -1, row, col, piece})
			}
		}
	}

	return drops
}

func (b *Board) getValidMoves(row, col int) []Move {
	var moves []Move
	if row < 0 || row >= 4 || col < 0 || col >= 3 {
		return moves
	}

	square := b.squares[row][col]
	if square.piece == empty {
		return moves
	}

	addIfValid := func(toRow, toCol int) {
		if toRow >= 0 && toRow < 4 && toCol >= 0 && toCol < 3 {
			if b.squares[toRow][toCol].player != square.player {
				moves = append(moves, Move{row, col, toRow, toCol, empty})
			}
		}
	}

	switch square.piece {
	// 隣接する8マスのいずれかに進むことができる。
	case lion:
		for drow := -1; drow <= 1; drow++ {
			for dcol := -1; dcol <= 1; dcol++ {
				if drow == 0 && dcol == 0 {
					continue
				}
				addIfValid(row+drow, col+dcol)
			}
		}

	// 斜めの4マスのいずれかに進むことができる
	case elephant:
		directions := [][2]int{{-1, -1}, {-1, 1}, {1, -1}, {1, 1}}
		for _, d := range directions {
			addIfValid(row+d[0], col+d[1])
		}

	// 縦・横の4マスのいずれかに進むことができる
	case giraffe:
		directions := [][2]int{{-1, 0}, {1, 0}, {0, -1}, {0, 1}}
		for _, d := range directions {
			addIfValid(row+d[0], col+d[1])
		}

	// 前の1マスにのみ進むことができる。
	case chick:
		if square.player == Player1 {
			addIfValid(row-1, col)
		} else {
			addIfValid(row+1, col)
		}

	// 斜め後ろ以外の6マスのいずれかに進むことができる。
	case chicken:
		for drow := -1; drow <= 1; drow++ {
			for dcol := -1; dcol <= 1; dcol++ {
				if drow == 0 && dcol == 0 {
					continue
				}
				if (square.player == Player1 && drow == 1) ||
					(square.player == Player2 && drow == -1) {
					continue
				}
				addIfValid(row+drow, col+dcol)
			}
		}
	}

	return moves
}

func (b *Board) Play(chooseMove func(Board) Move) *Board {
	currentBoard := b

	for {
		if gameover, _ := currentBoard.IsGameOver(); gameover {
			return currentBoard
		}

		currentBoard = currentBoard.ApplyMove(chooseMove(*currentBoard))
	}
}

お互いランダムに動かす場合、勝敗はほぼ50%となった。

$ cat main.go
package main

import (
	"fmt"

	"math/rand"

	ds "github.com/sambaiz/doubutsu_shogi/doubutsushogi"
)

func main() {
	board := ds.NewBoard()

	randomMove := func(b ds.Board) ds.Move {
		fmt.Println(b)
		moves := b.GetAllValidMoves()
		return moves[rand.Intn(len(moves))]
	}

	endBoard := board.Play(randomMove)
	fmt.Println(endBoard)

	_, winPlayer := endBoard.IsGameOver()
	fmt.Printf("Player%d win!\n", winPlayer)
}

$ go run main

2G 2L 2E 
-- 2C -- 
-- 1C -- 
1E 1L 1G 

Player1 has: 
Player2 has: 

- Player1's turn -

2G 2L 2E 
-- 2C -- 
-- 1C 1L 
1E -- 1G 

Player1 has: 
Player2 has: 

- Player2's turn -

-- 2L 2E 
2G 2C -- 
-- 1C 1L 
1E -- 1G 

Player1 has: 
Player2 has: 

- Player1's turn -

-- 2L 2E 
2G 2C 1L 
-- 1C -- 
1E -- 1G 

Player1 has: 
Player2 has: 

- Player2's turn -

-- 2L 2E 
-- 2C 1L 
2G 1C -- 
1E -- 1G 

Player1 has: 
Player2 has: 

- Player1's turn -

-- 2L 2E 
-- 1L -- 
2G 1C -- 
1E -- 1G 

Player1 has: C 
Player2 has: 

- Player2's turn -

-- 2L 2E 
-- 1L -- 
-- 2G -- 
1E -- 1G 

Player1 has: C 
Player2 has: C 

- Player1's turn -

-- 1L 2E 
-- -- -- 
-- 2G -- 
1E -- 1G 

Player1 has: L C 
Player2 has: C 

- Player2's turn -

Player1 win!

モンテカルロ木探索の実装

モンテカルロ木探索は最も評価値の高い葉を選び、そこからランダムにゲームを進め、その結果によって評価値を更新し一定回数探索した葉から木を拡張する、というのを繰り返す。評価関数としてはバンディットアルゴリズムの UCB1 ベースの UCT (Upper Confidence Bound for Trees) が用いられる。\(Q_i\) がそのノードの合計報酬、\(n_i\) がそのノードの探索回数、\(N_i\) が親ノードの探索回数。\(C\) は調整用のパラメータで、\(\sqrt{2}\) とすることが多い。

$$ UCT = \frac{Q_i}{n_i} + C\sqrt{\frac{\ln{N}}{n_i}} $$

package mcts

import (
	"fmt"
	"math"
	"strings"

	"math/rand"

	ds "github.com/sambaiz/doubutsu_shogi/doubutsushogi"
)

type Node struct {
	board      *ds.Board
	move       *ds.Move
	totalScore float64
	visitCount int64
	parent     *Node
	children   []*Node
}

func NewNode(board *ds.Board, move *ds.Move, parent *Node) *Node {
	root := &Node{
		board:  board,
		move:   move,
		parent: parent,
	}

	childMoves := board.GetAllValidMoves()
	root.children = make([]*Node, 0, len(childMoves))
	for _, childMove := range childMoves {
		root.children = append(root.children, &Node{
			board:  board.ApplyMove(childMove),
			move:   &childMove,
			parent: root,
		})
	}

	return root
}

func (n *Node) String() string {
	var sb strings.Builder

	if n.board != nil {
		sb.WriteString(fmt.Sprintf("Board:\n%s\n", n.board.String()))
	}
	if n.move != nil {
		sb.WriteString(fmt.Sprintf("Move:\n%s\n", n.move.String()))
	}

	return sb.String()
}

func (n *Node) PrettyString(prefix string) string {
	var sb strings.Builder

	if n.move != nil {
		sb.WriteString(fmt.Sprintf("%s├──%s [%f/%d/%d]\n", prefix, n.move.String(), n.calculateScore(), int(n.totalScore), n.visitCount))
	}

	for _, child := range n.children {
		newPrefix := prefix
		if n.move != nil {
			newPrefix += "│   "
		}
		sb.WriteString(child.PrettyString(newPrefix))
	}

	return sb.String()
}

func (n *Node) SelectBestChild() *Node {
	if len(n.children) == 0 {
		return nil
	}

	maxScore := math.Inf(-1)
	var bestChild *Node

	for _, child := range n.children {
		score := child.calculateScore()
		if score > maxScore {
			maxScore = score
			bestChild = child
		}
	}

	return bestChild
}

func (n *Node) SelectRandomChild() *Node {
	if len(n.children) == 0 {
		return nil
	}

	return n.children[rand.Intn(len(n.children))]
}

func (n *Node) calculateScore() float64 {
	if n.visitCount == 0 {
		return math.Inf(1) // デフォルト値
	}

	exploitation := n.totalScore / float64(n.visitCount)

	exploration := math.Sqrt(2) * math.Sqrt(
		math.Log(float64(n.parent.visitCount))/float64(n.visitCount),
	)

	return exploitation + exploration
}

// 相手 + 自分の手番で最大2レベル増えます
func (n *Node) Expand() {
	if gameover, _ := n.board.IsGameOver(); gameover {
		return
	}

	moves := n.board.GetAllValidMoves()
	n.children = make([]*Node, 0, len(moves))

	for _, move := range moves {
		childNode := NewNode(n.board.ApplyMove(move), &move, n)
		if gameover, _ := childNode.board.IsGameOver(); gameover {
			continue
		}
		n.children = append(n.children, childNode)
	}
}

const EXPANSION_THRESHOLD = 10

func (n *Node) SimulateAndExpand() bool {
	board := n.board

	randomMove := func(b ds.Board) ds.Move {
		moves := b.GetAllValidMoves()
		return moves[rand.Intn(len(moves))]
	}

	endBoard := board.Play(randomMove)

	_, winPlayer := endBoard.IsGameOver()

	// Backpropagation
	for {
		if n == nil {
			break
		}
		n.visitCount += 1
		if winPlayer != n.board.TurnPlayer() {
			n.totalScore += 1
		}
		if len(n.children) == 0 && n.visitCount >= EXPANSION_THRESHOLD {
			n.Expand()
		}
		n = n.parent
	}

	return winPlayer == 1
}

200000 ループほどでランダム相手に 99 % 勝てるようになり、全勝も観測された。

$ cat main.go
package main

import (
	"fmt"

	ds "github.com/sambaiz/doubutsu_shogi/doubutsushogi"
	"github.com/sambaiz/doubutsu_shogi/mcts"
)

func main() {
	board := ds.NewBoard()
	rootNode := mcts.NewNode(board, nil, nil)

	for i := 0; i <= 1000; i++ {
		winCount := 0

		for j := 0; j < 1000; j++ {
			selectedNode := rootNode
			for {
				// Player 1 は最もスコアが高いノードを選ぶ
				nextNode := selectedNode.SelectBestChild()
				if nextNode == nil {
					break
				}
				selectedNode = nextNode

				// Player 2 はランダムに選ぶ
				nextNode = selectedNode.SelectRandomChild()
				if nextNode == nil {
					break
				}
				selectedNode = nextNode
			}

			if selectedNode.SimulateAndExpand() {
				winCount++
			}
		}
		if i%100 == 0 {
			fmt.Printf("%d: win rate: %.1f%%\n", i*1000, 100.0*float64(winCount)/1000.0)
		}
	}

	// fmt.Println(rootNode.PrettyString(""))
}

$ go run main.go

0: win rate: 61.1%
100000: win rate: 98.7%
200000: win rate: 99.1%
300000: win rate: 99.5%
400000: win rate: 99.9%
500000: win rate: 99.4%
600000: win rate: 99.9%
700000: win rate: 99.5%
800000: win rate: 99.6%
900000: win rate: 99.8%
1000000: win rate: 100.0%

木も更新されている。

├──(2, 1) -> (1, 1) [0.839369/116/234]
│   ├──(0, 0) -> (1, 0) [0.831078/20/53]
│   │   ├──C -> (0, 0) [1.960251/1/3]
│   │   ├──C -> (1, 2) [1.860205/3/5]
│   │   ├──C -> (2, 0) [1.922211/6/7]
│   │   ├──C -> (2, 1) [1.983738/5/6]
│   │   ├──C -> (2, 2) [1.626918/0/3]
│   │   ├──(1, 1) -> (0, 1) [1.939301/9/9]
│   │   ├──(3, 0) -> (2, 1) [1.908952/2/4]
│   │   ├──(3, 1) -> (2, 0) [1.817071/4/6]
│   │   ├──(3, 1) -> (2, 1) [1.626918/0/3]
│   │   ├──(3, 1) -> (2, 2) [1.908952/2/4]
│   │   ├──(3, 2) -> (2, 2) [1.960251/1/3]
│   ├──(0, 1) -> (1, 0) [0.954886/18/41]
│   │   ├──C -> (0, 1) [1.862639/2/4]
...

参考

モンテカルロ木探索 | 備忘録

Monte Carlo Tree Search - About