diff --git a/src/ai/OpenLoopMctsPlayer.js b/src/ai/OpenLoopMctsPlayer.js index 2d5be0d..468e458 100644 --- a/src/ai/OpenLoopMctsPlayer.js +++ b/src/ai/OpenLoopMctsPlayer.js @@ -22,7 +22,7 @@ export class OpenLoopMctsPlayer { } /** - * @param {InteractiveGame} game + * @param {Game} game * @returns {Command[]} */ decideMove(game) { @@ -79,9 +79,9 @@ export class OpenLoopMctsPlayer { _select(game, node) { while (!game.isTerminal()) { if (node.childrenSize === 0) { - return node.bestUctChild(this.args.expansionFactor, this.args.samplingExplorationChance); + return node.bestUctChild(game, this.args.expansionFactor); } else { - [game, node] = node.bestUctChild(this.args.expansionFactor, this.args.samplingExplorationChance); + [game, node] = node.bestUctChild(game, this.args.expansionFactor); } } // game is terminal diff --git a/src/ai/OpenLoopMctsPlayer.test.js b/src/ai/OpenLoopMctsPlayer.test.js index e4f18f9..098f2f4 100644 --- a/src/ai/OpenLoopMctsPlayer.test.js +++ b/src/ai/OpenLoopMctsPlayer.test.js @@ -1,9 +1,25 @@ import { fixedRandom, resetFixedRandom } from "../lib/random.js"; +import { Command } from "../model/commands/commands.js"; import makeGame from "../model/game.js"; import { TwoOnTwoMeleeScenario } from "../model/scenarios.js"; import { OpenLoopMctsPlayer } from "./OpenLoopMctsPlayer.js"; describe('OpenLoopMctsPlayer', () => { + test('decideMove returns a valid move', () => { + const game = makeGame(new TwoOnTwoMeleeScenario()); + const player = new OpenLoopMctsPlayer({ + expansionFactor: 2.1415, + playoutIterations: 20, + iterations: 1, + logfunction: () => {}, + }); + + const move = player.decideMove(game); + + expect(move).toBeInstanceOf(Array); + move.forEach(command => expect(command).toBeInstanceOf(Command)); + }); + const originalRandom = Math.random; beforeEach(() => { resetFixedRandom(); @@ -23,7 +39,7 @@ describe('OpenLoopMctsPlayer', () => { test('2 on 2', () => { const game = makeGame(new TwoOnTwoMeleeScenario()); - const root = player.search(game, 1); + const root = player.search(game, 2); const rootAsString = root.toString(); console.log(rootAsString); diff --git a/src/ai/OpenLoopNode.js b/src/ai/OpenLoopNode.js index 6d18c0b..7591ece 100644 --- a/src/ai/OpenLoopNode.js +++ b/src/ai/OpenLoopNode.js @@ -28,12 +28,11 @@ export default class OpenLoopNode { */ bestUctChild(game, expansionFactor= 2.4142) { ensure(this.side === game.currentSide, `The node's side ${this.side} must be the same as the game's current side ${game.currentSide}}`); - ensure(this.visits > 0, "The node must have been visited at least once"); let bestChild = undefined; let bestScore = -Infinity; let bestClone = undefined; - const logOfThisVisits = Math.log(this.visits); + const logOfThisVisits = this.visits === 0 ? 0 : Math.log(this.visits); for (const command of randomShuffleArray(game.validCommands())) { const clone = game.clone(); clone.executeCommand(command); @@ -126,6 +125,10 @@ export default class OpenLoopNode { return this.#children.size; } + get children() { + return this.#children.values(); + } + /** * @param {Command} command * @returns {OpenLoopNode} @@ -147,12 +150,16 @@ export default class OpenLoopNode { shape() { const result = []; + /** + * @param {OpenLoopNode} node + * @param {number} level + */ function traverse(node, level) { if (!result[level]) { result[level] = 0; } result[level]++; - for (const child of node.children) { + for (const [__, child] of node.children) { traverse(child, level + 1); } } diff --git a/src/ai/__snapshots__/OpenLoopMctsPlayer.test.js.snap b/src/ai/__snapshots__/OpenLoopMctsPlayer.test.js.snap index dd63738..94b2d83 100644 --- a/src/ai/__snapshots__/OpenLoopMctsPlayer.test.js.snap +++ b/src/ai/__snapshots__/OpenLoopMctsPlayer.test.js.snap @@ -1,6 +1,9 @@ // Jest Snapshot v1, https://goo.gl/fbAQLP exports[`OpenLoopMctsPlayer 2 on 2 1`] = ` -"undefined -> 0/0: Side: Roman -> 0 +"undefined -> -1/2: Side: Roman -> 2 + PlayCard(Order Three Units Left) -> -1/1: Side: Roman -> 0 + PlayCard(Order Heavy Troops) -> 0/1: Side: Roman -> 1 + End phase -> 0/1: Side: Roman -> 0 " `;