From ce5c8c4118073e5820f8cb650b1ef4bf2314f8c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Tr=C3=B6ster?= Date: Sun, 8 Dec 2024 21:13:01 +0100 Subject: [PATCH] abstract ppo training with generic state and action --- Schafkopf.Lib/Card.cs | 6 +- .../FeatureVectorTests.cs | 2 +- Schafkopf.Training/Algos/PPOAgent.cs | 263 ++---------------- Schafkopf.Training/CardPicker/Experience.cs | 5 + .../{Common => CardPicker}/GameState.cs | 2 +- Schafkopf.Training/CardPicker/PPOAgent.cs | 231 +++++++++++++++ Schafkopf.Training/Common/Experience.cs | 30 +- Schafkopf.Training/Program.cs | 2 +- 8 files changed, 288 insertions(+), 253 deletions(-) create mode 100644 Schafkopf.Training/CardPicker/Experience.cs rename Schafkopf.Training/{Common => CardPicker}/GameState.cs (99%) create mode 100644 Schafkopf.Training/CardPicker/PPOAgent.cs diff --git a/Schafkopf.Lib/Card.cs b/Schafkopf.Lib/Card.cs index c27d84d..5755fc4 100644 --- a/Schafkopf.Lib/Card.cs +++ b/Schafkopf.Lib/Card.cs @@ -20,7 +20,7 @@ public enum CardColor Eichel } -public readonly struct Card +public readonly struct Card : IEquatable { public const byte EXISTING_FLAG = 0x20; public const byte TRUMPF_FLAG = 0x40; @@ -60,6 +60,9 @@ public Card(CardType type, CardColor color, bool exists, bool isTrumpf) public override bool Equals([NotNullWhen(true)] object? obj) => obj is Card c && (c.Id & ORIG_CARD_MASK) == (this.Id & ORIG_CARD_MASK); + public bool Equals(Card other) + => Equals((object?)other); + public override int GetHashCode() => Id & ORIG_CARD_MASK; public static bool operator ==(Card a, Card b) @@ -71,5 +74,6 @@ public override bool Equals([NotNullWhen(true)] object? obj) public override string ToString() => $"{Color} {Type}{(IsTrumpf ? " (trumpf)" : "")}"; + // TODO: add an emoji format } diff --git a/Schafkopf.Training.Tests/FeatureVectorTests.cs b/Schafkopf.Training.Tests/FeatureVectorTests.cs index 6701eb7..7da5cf1 100644 --- a/Schafkopf.Training.Tests/FeatureVectorTests.cs +++ b/Schafkopf.Training.Tests/FeatureVectorTests.cs @@ -11,7 +11,7 @@ public void Test_CanSerializeCompleteGame() var call = GameCall.Sauspiel(0, 1, CardColor.Schell); var history = generateHistoryWithCall(call); - var newExp = () => new SarsExp() { StateBefore = new GameState() }; + var newExp = () => new SchafkopfSarsExp() { StateBefore = new GameState() }; var states = Enumerable.Range(0, 32).Select(i => newExp()).ToArray(); serializer.SerializeSarsExps(history, states); diff --git a/Schafkopf.Training/Algos/PPOAgent.cs b/Schafkopf.Training/Algos/PPOAgent.cs index 496bea7..e137a97 100644 --- a/Schafkopf.Training/Algos/PPOAgent.cs +++ b/Schafkopf.Training/Algos/PPOAgent.cs @@ -25,70 +25,6 @@ public class PPOTrainingSettings public int NumTrainings => TrainSteps / StepsPerUpdate; } -public class PPOTrainingSession -{ - public PPOModel Train(PPOTrainingSettings config) - { - var model = new PPOModel(config); - var rollout = new PPORolloutBuffer(config); - var exps = new CardPickerExpCollector(); - var benchmark = new RandomPlayBenchmark(); - var agent = new PPOAgent(model); - - for (int ep = 0; ep < config.NumTrainings; ep++) - { - Console.WriteLine($"epoch {ep+1}"); - exps.Collect(rollout, model); - model.Train(rollout); - - model.RecompileCache(batchSize: 1); - double winRate = benchmark.Benchmark(agent); - model.RecompileCache(batchSize: config.BatchSize); - - Console.WriteLine($"win rate vs. random agents: {winRate}"); - Console.WriteLine("--------------------------------------"); - } - - return model; - } -} - -public class PPOAgent : ISchafkopfAIAgent -{ - public PPOAgent(PPOModel model) - { - this.model = model; - } - - private PPOModel model; - private HeuristicAgent heuristicAgent = new HeuristicAgent(); - private GameStateSerializer stateSerializer = new GameStateSerializer(); - private PossibleCardPicker sampler = new PossibleCardPicker(); - - private Matrix2D s0 = Matrix2D.Zeros(1, 90); - private Matrix2D piOh = Matrix2D.Zeros(1, 32); - private Matrix2D V = Matrix2D.Zeros(1, 1); - - public Card ChooseCard(GameLog log, ReadOnlySpan possibleCards) - { - var state = stateSerializer.SerializeState(log); - state.ExportFeatures(s0.SliceRowsRaw(0, 1)); - model.Predict(s0, piOh, V); - var predDist = piOh.SliceRowsRaw(0, 1); - return sampler.PickCard(possibleCards, predDist); - } - - public bool CallKontra(GameLog log) => heuristicAgent.CallKontra(log); - public bool CallRe(GameLog log) => heuristicAgent.CallRe(log); - public bool IsKlopfer(int position, ReadOnlySpan firstFourCards) - => heuristicAgent.IsKlopfer(position, firstFourCards); - public GameCall MakeCall( - ReadOnlySpan possibleCalls, - int position, Hand hand, int klopfer) - => heuristicAgent.MakeCall(possibleCalls, position, hand, klopfer); - public void OnGameFinished(GameLog final) => heuristicAgent.OnGameFinished(final); -} - public class PPOModel { public PPOModel(PPOTrainingSettings config) @@ -139,7 +75,9 @@ public void Predict(Matrix2D s0, Matrix2D outPiOnehot, Matrix2D outV) Matrix2D.CopyData(predV, outV); } - public void Train(PPORolloutBuffer memory) + public void Train(PPORolloutBuffer memory) + where TState : IEquatable, new() + where TAction : IEquatable, new() { int numBatches = memory.NumBatches( config.BatchSize, config.UpdateEpochs); @@ -228,175 +166,13 @@ public void RecompileCache(int batchSize) } } -public class CardPickerExpCollector -{ - public void Collect(PPORolloutBuffer buffer, PPOModel strategy) - { - int numGames = buffer.Steps / 8; - int numSessions = buffer.NumEnvs / 4; - var envs = Enumerable.Range(0, numSessions) - .Select(i => new MultiAgentCardPickerEnv()).ToArray(); - - var vecAgent = new VectorizedCardPickerAgent(strategy, numSessions); - var agents = Enumerable.Range(0, buffer.NumEnvs) - .Select(i => new AsyncCardPickerAgent(vecAgent)).ToArray(); - - var expCache = new PPOExp[buffer.NumEnvs]; - int t = 0; - var barr = new Barrier(buffer.NumEnvs, (b) => { - buffer.AppendStep(expCache, t++); - Console.Write($"\rcollecting ppo data {t} / {buffer.Steps} "); - }); - - var collectTasks = Enumerable.Range(0, buffer.NumEnvs) - .Select(i => Task.Run(() => { - var agent = agents[i]; - var env = envs[i / 4]; - foreach (var exp in agent.PlaySteps(i % 4, env, buffer.Steps)) - { - barr.SignalAndWait(); - expCache[i] = exp; - } - })) - .ToArray(); - - Task.WaitAll(collectTasks); - Console.WriteLine(); - } -} - -public class VectorizedCardPickerAgent -{ - public VectorizedCardPickerAgent(PPOModel strategy, int numSessions) - { - states = Matrix2D.Zeros(numSessions, GameState.NUM_FEATURES); - predPi = Matrix2D.Zeros(numSessions, 32); - predV = Matrix2D.Zeros(numSessions, 1); - - samplers = Enumerable.Range(0, numSessions) - .Select(i => new PossibleCardPicker()).ToArray(); - - threadIds = new int[numSessions]; - barr = new Barrier(numSessions, (b) => strategy.Predict(states, predPi, predV)); - } - - private int[] threadIds; - private Barrier barr; - - private Matrix2D states; - private Matrix2D predPi; - private Matrix2D predV; - - private PossibleCardPicker[] samplers; - - private int sessionIdByThread() - { - int threadId = Environment.CurrentManagedThreadId; - for (int i = 0; i < threadIds.Length; i++) - if (threadIds[i] == threadId) - return i; - throw new InvalidOperationException("Unregistered thread!"); - } - - public void Register(int sessionId) - { - threadIds[sessionId] = Environment.CurrentManagedThreadId; - } - - public (Card, double, double) Predict( - GameState state, ReadOnlySpan possCards) - { - int sessionId = sessionIdByThread(); - var s0Slice = states.SliceRowsRaw(sessionId, 1); - state.ExportFeatures(s0Slice); - - barr.SignalAndWait(); - - var predPiDistr = predPi.SliceRowsRaw(sessionId, 1); - var card = samplers[sessionId].PickCard(possCards, predPiDistr); - double pi = predPiDistr[card.Id % 32]; - double V = predV.At(sessionId, 0); - - return (card, pi, V); - } -} - -public class AsyncCardPickerAgent -{ - public AsyncCardPickerAgent(VectorizedCardPickerAgent vecAgent) - { - this.vecAgent = vecAgent; - } - - private VectorizedCardPickerAgent vecAgent; - private Card[] cardCache = new Card[8]; - private GameRules rules = new GameRules(); - private GameStateSerializer stateSerializer = new GameStateSerializer(); - - public IEnumerable PlaySteps( - int playerId, MultiAgentCardPickerEnv env, int steps) - { - var exp = new PPOExp(); - env.Register(playerId); - var state = env.Reset(); - - for (int i = 0; i < steps; i++) - { - (GameState s0, Card a0, double pi, double V) = predict(state); - (state, double r1, bool t1) = env.Step(a0); - if (t1) - state = env.Reset(); - - exp.StateBefore = s0; - exp.Action = a0; - exp.Reward = r1; - exp.IsTerminal = t1; - exp.OldProb = pi; - exp.OldBaseline = V; - yield return exp; - } - } - - private (GameState, Card, double, double) predict(GameLog state) - { - var possCards = rules.PossibleCards(state, cardCache); - var encState = stateSerializer.SerializeState(state); - (var a0, var pi, var V) = vecAgent.Predict(encState, possCards); - return (encState, a0, pi, V); - } -} - -public class PossibleCardPicker -{ - private UniformDistribution uniform = new UniformDistribution(); - - public Card PickCard(ReadOnlySpan possibleCards, ReadOnlySpan predPi) - => possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))]; - - private double[] probDistCache = new double[8]; - private ReadOnlySpan normProbDist( - ReadOnlySpan probDistAll, ReadOnlySpan possibleCards) - { - double probSum = 0; - for (int i = 0; i < possibleCards.Length; i++) - probDistCache[i] = probDistAll[possibleCards[i].Id & Card.ORIG_CARD_MASK]; - for (int i = 0; i < possibleCards.Length; i++) - probSum += probDistCache[i]; - double scale = 1 / probSum; - for (int i = 0; i < possibleCards.Length; i++) - probDistCache[i] *= scale; - - return probDistCache.AsSpan().Slice(0, possibleCards.Length); - } -} - public struct PPOTrainBatch { - public PPOTrainBatch(int size, int numStateDims) + public PPOTrainBatch(int size, int numStateDims, int numActionDims) { Size = size; StatesBefore = Matrix2D.Zeros(size, numStateDims); - Actions = Matrix2D.Zeros(size, 1); + Actions = Matrix2D.Zeros(size, numActionDims); Rewards = Matrix2D.Zeros(size, 1); Terminals = Matrix2D.Zeros(size, 1); Returns = Matrix2D.Zeros(size, 1); @@ -464,10 +240,18 @@ public PPOTrainBatch SliceRows(int rowid, int length) }; } -public class PPORolloutBuffer +public class PPORolloutBuffer + where TState : IEquatable, new() + where TAction : IEquatable, new() { - public PPORolloutBuffer(PPOTrainingSettings config) + public PPORolloutBuffer( + PPOTrainingSettings config, + Action encodeState, + Action encodeAction) { + this.encodeState = encodeState; + this.encodeAction = encodeAction; + NumEnvs = config.NumEnvs; Steps = config.StepsPerUpdate; gamma = config.RewardDiscount; @@ -478,13 +262,19 @@ public PPORolloutBuffer(PPOTrainingSettings config) int size = Steps * NumEnvs; int sizeWithExtraStep = (Steps + 1) * NumEnvs; - cache = new PPOTrainBatch(sizeWithExtraStep, config.NumStateDims); + cache = new PPOTrainBatch( + sizeWithExtraStep, + config.NumStateDims, + config.NumActionDims + ); cacheWithoutLastStep = cache.SliceRows(0, size); cacheOnlyFirstStep = cache.SliceRows(0, NumEnvs); cacheOnlyLastStep = cache.SliceRows(size, NumEnvs); permCache = Perm.Identity(size); } + private Action encodeState; + private Action encodeAction; public int NumEnvs; public int Steps; private double gamma; @@ -500,7 +290,7 @@ public PPORolloutBuffer(PPOTrainingSettings config) public int NumBatches(int batchSize, int epochs = 1) => cacheWithoutLastStep.Size / batchSize * epochs; - public void AppendStep(PPOExp[] exps, int t) + public void AppendStep(PPOExp[] exps, int t) { if (exps.Length != NumEnvs) throw new ArgumentException("Invalid amount of experiences!"); @@ -514,9 +304,10 @@ public void AppendStep(PPOExp[] exps, int t) var exp = exps[i]; unsafe { - var s0Dest = buffer.StatesBefore.SliceRowsRaw(i, 1); - exp.StateBefore.ExportFeatures(s0Dest); - buffer.Actions.Data[i] = exp.Action.Id % 32; + var s0Dest = buffer.StatesBefore.SliceRows(i, 1); + encodeState(exp.StateBefore, s0Dest); + var a0Dest = buffer.Actions.SliceRows(i, 1); + encodeAction(exp.Action, a0Dest); buffer.Rewards.Data[i] = exp.Reward; buffer.Terminals.Data[i] = exp.IsTerminal ? 1 : 0; buffer.OldProbs.Data[i] = exp.OldProb; diff --git a/Schafkopf.Training/CardPicker/Experience.cs b/Schafkopf.Training/CardPicker/Experience.cs new file mode 100644 index 0000000..5178d7b --- /dev/null +++ b/Schafkopf.Training/CardPicker/Experience.cs @@ -0,0 +1,5 @@ +namespace Schafkopf.Training; + +public class SchafkopfSarsExp : SarsExp { } + +public class SchafkopfPPOExp : PPOExp { } diff --git a/Schafkopf.Training/Common/GameState.cs b/Schafkopf.Training/CardPicker/GameState.cs similarity index 99% rename from Schafkopf.Training/Common/GameState.cs rename to Schafkopf.Training/CardPicker/GameState.cs index 510f8b7..368b6cd 100644 --- a/Schafkopf.Training/Common/GameState.cs +++ b/Schafkopf.Training/CardPicker/GameState.cs @@ -36,7 +36,7 @@ public static GameState[] NewBuffer() => Enumerable.Range(0, 36).Select(x => new GameState()).ToArray(); private GameState[] stateBuffer = NewBuffer(); - public void SerializeSarsExps(GameLog completedGame, SarsExp[] exps) + public void SerializeSarsExps(GameLog completedGame, SchafkopfSarsExp[] exps) { if (completedGame.CardCount != 32) throw new ArgumentException("Can only process finished games!"); diff --git a/Schafkopf.Training/CardPicker/PPOAgent.cs b/Schafkopf.Training/CardPicker/PPOAgent.cs new file mode 100644 index 0000000..b406765 --- /dev/null +++ b/Schafkopf.Training/CardPicker/PPOAgent.cs @@ -0,0 +1,231 @@ +namespace Schafkopf.Training; + +public class SchafkopfPPOTrainingSession +{ + public PPOModel Train(PPOTrainingSettings config) + { + var model = new PPOModel(config); + var rollout = new PPORolloutBuffer( + config, + (s0, buf) => s0.ExportFeatures(buf.SliceRowsRaw(0, 1)), + (a0, buf) => buf.SliceRowsRaw(0, 1)[0] = a0.Id % 32 + ); + var exps = new CardPickerExpCollector(); + var benchmark = new RandomPlayBenchmark(); + var agent = new SchafkopfPPOAgent(model); + + for (int ep = 0; ep < config.NumTrainings; ep++) + { + Console.WriteLine($"epoch {ep+1}"); + exps.Collect(rollout, model); + model.Train(rollout); + + model.RecompileCache(batchSize: 1); + double winRate = benchmark.Benchmark(agent); + model.RecompileCache(batchSize: config.BatchSize); + + Console.WriteLine($"win rate vs. random agents: {winRate}"); + Console.WriteLine("--------------------------------------"); + } + + return model; + } +} + +public class SchafkopfPPOAgent : ISchafkopfAIAgent +{ + public SchafkopfPPOAgent(PPOModel model) + { + this.model = model; + } + + private PPOModel model; + private HeuristicAgent heuristicAgent = new HeuristicAgent(); + private GameStateSerializer stateSerializer = new GameStateSerializer(); + private PossibleCardPicker sampler = new PossibleCardPicker(); + + private Matrix2D s0 = Matrix2D.Zeros(1, 90); + private Matrix2D piOh = Matrix2D.Zeros(1, 32); + private Matrix2D V = Matrix2D.Zeros(1, 1); + + public Card ChooseCard(GameLog log, ReadOnlySpan possibleCards) + { + var state = stateSerializer.SerializeState(log); + state.ExportFeatures(s0.SliceRowsRaw(0, 1)); + model.Predict(s0, piOh, V); + var predDist = piOh.SliceRowsRaw(0, 1); + return sampler.PickCard(possibleCards, predDist); + } + + public bool CallKontra(GameLog log) => heuristicAgent.CallKontra(log); + public bool CallRe(GameLog log) => heuristicAgent.CallRe(log); + public bool IsKlopfer(int position, ReadOnlySpan firstFourCards) + => heuristicAgent.IsKlopfer(position, firstFourCards); + public GameCall MakeCall( + ReadOnlySpan possibleCalls, + int position, Hand hand, int klopfer) + => heuristicAgent.MakeCall(possibleCalls, position, hand, klopfer); + public void OnGameFinished(GameLog final) => heuristicAgent.OnGameFinished(final); +} + +public class CardPickerExpCollector +{ + public void Collect(PPORolloutBuffer buffer, PPOModel strategy) + { + int numGames = buffer.Steps / 8; + int numSessions = buffer.NumEnvs / 4; + var envs = Enumerable.Range(0, numSessions) + .Select(i => new MultiAgentCardPickerEnv()).ToArray(); + + var vecAgent = new VectorizedCardPickerAgent(strategy, numSessions); + var agents = Enumerable.Range(0, buffer.NumEnvs) + .Select(i => new AsyncCardPickerAgent(vecAgent)).ToArray(); + + var expCache = new SchafkopfPPOExp[buffer.NumEnvs]; + int t = 0; + var barr = new Barrier(buffer.NumEnvs, (b) => { + buffer.AppendStep(expCache, t++); + Console.Write($"\rcollecting ppo data {t} / {buffer.Steps} "); + }); + + var collectTasks = Enumerable.Range(0, buffer.NumEnvs) + .Select(i => Task.Run(() => { + var agent = agents[i]; + var env = envs[i / 4]; + foreach (var exp in agent.PlaySteps(i % 4, env, buffer.Steps)) + { + barr.SignalAndWait(); + expCache[i] = exp; + } + })) + .ToArray(); + + Task.WaitAll(collectTasks); + Console.WriteLine(); + } +} + +public class VectorizedCardPickerAgent +{ + public VectorizedCardPickerAgent(PPOModel strategy, int numSessions) + { + states = Matrix2D.Zeros(numSessions, GameState.NUM_FEATURES); + predPi = Matrix2D.Zeros(numSessions, 32); + predV = Matrix2D.Zeros(numSessions, 1); + + samplers = Enumerable.Range(0, numSessions) + .Select(i => new PossibleCardPicker()).ToArray(); + + threadIds = new int[numSessions]; + barr = new Barrier(numSessions, (b) => strategy.Predict(states, predPi, predV)); + } + + private int[] threadIds; + private Barrier barr; + + private Matrix2D states; + private Matrix2D predPi; + private Matrix2D predV; + + private PossibleCardPicker[] samplers; + + private int sessionIdByThread() + { + int threadId = Environment.CurrentManagedThreadId; + for (int i = 0; i < threadIds.Length; i++) + if (threadIds[i] == threadId) + return i; + throw new InvalidOperationException("Unregistered thread!"); + } + + public void Register(int sessionId) + { + threadIds[sessionId] = Environment.CurrentManagedThreadId; + } + + public (Card, double, double) Predict( + GameState state, ReadOnlySpan possCards) + { + int sessionId = sessionIdByThread(); + var s0Slice = states.SliceRowsRaw(sessionId, 1); + state.ExportFeatures(s0Slice); + + barr.SignalAndWait(); + + var predPiDistr = predPi.SliceRowsRaw(sessionId, 1); + var card = samplers[sessionId].PickCard(possCards, predPiDistr); + double pi = predPiDistr[card.Id % 32]; + double V = predV.At(sessionId, 0); + + return (card, pi, V); + } +} + +public class AsyncCardPickerAgent +{ + public AsyncCardPickerAgent(VectorizedCardPickerAgent vecAgent) + { + this.vecAgent = vecAgent; + } + + private VectorizedCardPickerAgent vecAgent; + private Card[] cardCache = new Card[8]; + private GameRules rules = new GameRules(); + private GameStateSerializer stateSerializer = new GameStateSerializer(); + + public IEnumerable PlaySteps( + int playerId, MultiAgentCardPickerEnv env, int steps) + { + var exp = new SchafkopfPPOExp(); + env.Register(playerId); + var state = env.Reset(); + + for (int i = 0; i < steps; i++) + { + (GameState s0, Card a0, double pi, double V) = predict(state); + (state, double r1, bool t1) = env.Step(a0); + if (t1) + state = env.Reset(); + + exp.StateBefore = s0; + exp.Action = a0; + exp.Reward = r1; + exp.IsTerminal = t1; + exp.OldProb = pi; + exp.OldBaseline = V; + yield return exp; + } + } + + private (GameState, Card, double, double) predict(GameLog state) + { + var possCards = rules.PossibleCards(state, cardCache); + var encState = stateSerializer.SerializeState(state); + (var a0, var pi, var V) = vecAgent.Predict(encState, possCards); + return (encState, a0, pi, V); + } +} + +public class PossibleCardPicker +{ + private UniformDistribution uniform = new UniformDistribution(); + + public Card PickCard(ReadOnlySpan possibleCards, ReadOnlySpan predPi) + => possibleCards[uniform.Sample(normProbDist(predPi, possibleCards))]; + + private double[] probDistCache = new double[8]; + private ReadOnlySpan normProbDist( + ReadOnlySpan probDistAll, ReadOnlySpan possibleCards) + { + double probSum = 0; + for (int i = 0; i < possibleCards.Length; i++) + probDistCache[i] = probDistAll[possibleCards[i].Id & Card.ORIG_CARD_MASK]; + for (int i = 0; i < possibleCards.Length; i++) + probSum += probDistCache[i]; + double scale = 1 / probSum; + for (int i = 0; i < possibleCards.Length; i++) + probDistCache[i] *= scale; + + return probDistCache.AsSpan().Slice(0, possibleCards.Length); + } +} diff --git a/Schafkopf.Training/Common/Experience.cs b/Schafkopf.Training/Common/Experience.cs index 0aaa176..d412e7a 100644 --- a/Schafkopf.Training/Common/Experience.cs +++ b/Schafkopf.Training/Common/Experience.cs @@ -1,39 +1,43 @@ namespace Schafkopf.Training; -public struct SarsExp : IEquatable +public class SarsExp : IEquatable> + where TState : IEquatable, new() where TAction : IEquatable, new() { public SarsExp() { } - public GameState StateBefore = new GameState(); - public GameState StateAfter = new GameState(); - public Card Action = new Card(); + public TState StateBefore = new TState(); + public TState StateAfter = new TState(); + public TAction Action = new TAction(); public double Reward = 0.0; public bool IsTerminal = false; - public bool Equals(SarsExp other) - => StateBefore.Equals(other.StateBefore) + public bool Equals(SarsExp? other) + => other != null + && StateBefore.Equals(other.StateBefore) && StateAfter.Equals(other.StateAfter) - && Action == other.Action + && Action.Equals(other.Action) && Reward == other.Reward && IsTerminal == other.IsTerminal; public override int GetHashCode() => 0; } -public struct PPOExp : IEquatable +public class PPOExp : IEquatable> + where TState : IEquatable, new() where TAction : IEquatable, new() { public PPOExp() { } - public GameState StateBefore = new GameState(); - public Card Action = new Card(); + public TState StateBefore = new TState(); + public TAction Action = new TAction(); public double Reward = 0.0; public bool IsTerminal = false; public double OldProb = 0.0; public double OldBaseline = 0.0; - public bool Equals(PPOExp other) - => StateBefore.Equals(other.StateBefore) - && Action == other.Action + public bool Equals(PPOExp? other) + => other != null + && StateBefore.Equals(other.StateBefore) + && Action.Equals(other.Action) && Reward == other.Reward && IsTerminal == other.IsTerminal && OldProb == other.OldProb diff --git a/Schafkopf.Training/Program.cs b/Schafkopf.Training/Program.cs index 94bf5a8..d91c2a6 100644 --- a/Schafkopf.Training/Program.cs +++ b/Schafkopf.Training/Program.cs @@ -5,7 +5,7 @@ public class Program public static void Main(string[] args) { var config = new PPOTrainingSettings(); - var session = new PPOTrainingSession(); + var session = new SchafkopfPPOTrainingSession(); session.Train(config); } }