add sampling layer to ppo model
Bonifatius94 committed Dec 13, 2024
1 parent 5d44385 commit 069571c
Showing 5 changed files with 193 additions and 47 deletions.
47 changes: 21 additions & 26 deletions Schafkopf.Training.Tests/PPOTrainingSessionTests.cs
@@ -35,20 +35,12 @@ public void Training_WithPPOAgent_CanLearnCartPole()
buf.SliceRowsRaw(0, 1)[0] = (int)a0.Direction;
};

// TODO: model one-hot sampling as a new class
var uniform = new UniformDistribution();
var actionsCache = Enumerable.Range(0, config.NumEnvs)
.Select(x => new CartPoleAction()).ToArray();
var probsCache = new double[config.NumEnvs];
var sampleActions = (Matrix2D piOH) => {
for (int i = 0; i < piOH.NumRows; i++)
{
var probDist = piOH.SliceRowsRaw(i, 1);
var idx = uniform.Sample(probDist);
actionsCache[i].Direction = (CartPoleDirection)idx;
probsCache[i] = probDist[idx];
}
return ((IList<CartPoleAction>)actionsCache, (IList<double>)probsCache);
var sampleActions = (Matrix2D pi) => {
for (int i = 0; i < pi.NumRows; i++)
actionsCache[i].Direction = (CartPoleDirection)(int)pi.At(i, 0);
return (IList<CartPoleAction>)actionsCache;
};

var rollout = new PPORolloutBuffer<CartPoleState, CartPoleAction>(
@@ -75,7 +67,7 @@ public class SingleAgentExpCollector<TState, TAction>
public SingleAgentExpCollector(
PPOTrainingSettings config,
Action<TState, Matrix2D> encodeState,
Func<Matrix2D, (IList<TAction>, IList<double>)> sampleActions,
Func<Matrix2D, IList<TAction>> sampleActions,
Func<MDPEnv<TState, TAction>> envFactory)
{
this.config = config;
@@ -91,19 +83,21 @@ public SingleAgentExpCollector(

s0_enc = Matrix2D.Zeros(config.NumEnvs, config.NumStateDims);
v = Matrix2D.Zeros(config.NumEnvs, 1);
pi = Matrix2D.Zeros(config.NumEnvs, config.NumActionDims);
pi = Matrix2D.Zeros(config.NumEnvs, 1);
piProbs = Matrix2D.Zeros(config.NumEnvs, 1);
}

private readonly PPOTrainingSettings config;
private readonly VectorizedEnv<TState, TAction> vecEnv;
private readonly Action<TState, Matrix2D> encodeState;
private readonly Func<Matrix2D, (IList<TAction>, IList<double>)> sampleActions;
private readonly Func<Matrix2D, IList<TAction>> sampleActions;

private TState[] s0;
private PPOExp<TState, TAction>[] exps;
private Matrix2D s0_enc;
private Matrix2D v;
private Matrix2D pi;
private Matrix2D piProbs;

public void Collect(PPORolloutBuffer<TState, TAction> buffer, PPOModel model)
{
@@ -112,8 +106,8 @@ public void Collect(PPORolloutBuffer<TState, TAction> buffer, PPOModel model)
for (int i = 0; i < config.NumEnvs; i++)
encodeState(s0[i], s0_enc.SliceRows(i, 1));

model.Predict(s0_enc, v, pi);
(var a0, var p_a0) = sampleActions(pi);
model.Predict(s0_enc, pi, piProbs, v);
var a0 = sampleActions(pi);

(var s1, var r1, var t1) = vecEnv.Step(a0);

@@ -123,8 +117,8 @@ public void Collect(PPORolloutBuffer<TState, TAction> buffer, PPOModel model)
exps[i].Action = a0[i];
exps[i].Reward = r1[i];
exps[i].IsTerminal = t1[i];
exps[i].OldProb = p_a0[i];
exps[i].OldBaseline = v.SliceRowsRaw(i, 1)[0];
exps[i].OldProb = piProbs.At(i, 0);
exps[i].OldBaseline = v.At(i, 0);
}

for (int i = 0; i < config.NumEnvs; i++)
@@ -166,7 +160,7 @@ public VectorizedEnv(IList<MDPEnv<TState, TAction>> envs)
public IList<TState> Reset()
{
for (int i = 0; i < envs.Count; i++)
states[i] = envs[0].Reset();
states[i] = envs[i].Reset();
return states;
}

Expand Down Expand Up @@ -226,11 +220,12 @@ public class CartPoleEnv : MDPEnv<CartPoleState, CartPoleAction>
theta_dot + tau * thetaacc
);

// TODO: check if this condition is correct
var terminated =
x < -x_threshold
|| x > x_threshold
|| theta < -theta_threshold_radians
|| theta > theta_threshold_radians;
x > -x_threshold
&& x < x_threshold
&& theta > -theta_threshold_radians
&& theta < theta_threshold_radians;

var reward = 1.0;
return (state.Value, reward, terminated);
@@ -240,9 +235,9 @@ public CartPoleState Reset()
{
state = new CartPoleState(
sample(x_threshold * -2, x_threshold * 2),
sample(-0.05, 0.05),
sample(-10.0, 10.0),
sample(theta_threshold_radians * -2, theta_threshold_radians * 2),
sample(-0.05, 0.05)
sample(-Math.PI, Math.PI)
);
return state.Value;
}
55 changes: 49 additions & 6 deletions Schafkopf.Training/Algos/Distributions.cs
@@ -1,14 +1,26 @@
namespace Schafkopf.Training;

public class UniformDistribution
public static class UniformDistribution
{
public UniformDistribution(int? seed = null)
=> rng = seed != null ? new Random(seed.Value) : new Random();
private static Random rng = new Random();

private Random rng;
public static int Sample(this Span<double> probs, Random? rng = null)
{
rng = rng ?? UniformDistribution.rng;
double p = rng.NextDouble();
double sum = 0;
for (int i = 0; i < probs.Length - 1; i++)
{
sum += probs[i];
if (p < sum)
return i;
}
return probs.Length - 1;
}

public int Sample(ReadOnlySpan<double> probs)
public static int Sample(this ReadOnlySpan<double> probs, Random? rng = null)
{
rng = rng ?? UniformDistribution.rng;
double p = rng.NextDouble();
double sum = 0;
for (int i = 0; i < probs.Length - 1; i++)
@@ -20,5 +32,36 @@ public int Sample(ReadOnlySpan<double> probs)
return probs.Length - 1;
}

public int Sample(int numClasses) => rng.Next(0, numClasses);
public static int Sample(int numClasses, Random? rng)
=> (rng ?? UniformDistribution.rng).Next(0, numClasses);
}

public static class NormalDistribution
{
private static Random rng = new Random();

public static double Sample(
(double, double) mu_sigma, Random? rng = null, double eps = 1.19e-07)
{
rng = rng ?? NormalDistribution.rng;
(var mu, var sigma) = mu_sigma;
return Sample(mu, sigma, rng, eps);
}

public static double Sample(
double mu, double sigma, Random? rng = null, double eps = 1.19e-07)
{
const double TWO_PI = 2 * Math.PI;
rng = rng ?? NormalDistribution.rng;

double u1, u2;
do { u1 = rng.NextDouble(); } while (u1 <= eps);
u2 = rng.NextDouble();

double mag = sigma * Math.Sqrt(-2 * Math.Min(Math.Log(u1 + 1e-8), 0));
if (rng.NextDouble() > 0.5)
return mag * Math.Cos(TWO_PI * u2) + mu;
else
return mag * Math.Sin(TWO_PI * u2) + mu;
}
}
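
Note: the distribution helpers are now static classes with extension-method samplers and an optional Random argument, and NormalDistribution.Sample appears to draw via a Box-Muller transform. A minimal usage sketch (not part of the commit; the array and seed values are illustrative):

    using Schafkopf.Training;

    var probs = new double[] { 0.1, 0.6, 0.3 };
    int idx = probs.AsSpan().Sample();                      // categorical draw with the shared RNG
    int idxSeeded = probs.AsSpan().Sample(new Random(42));  // reproducible draw with an explicit RNG
    double x = NormalDistribution.Sample(0.0, 1.0);         // mu = 0, sigma = 1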
11 changes: 8 additions & 3 deletions Schafkopf.Training/Algos/PPOAgent.cs
@@ -39,13 +39,15 @@ public PPOModel(PPOTrainingSettings config)
new DenseLayer(1)
});

piSampler = new UniformSamplingLayer();
strategy = new FFModel(new ILayer[] {
new DenseLayer(64),
new ReLULayer(),
new DenseLayer(64),
new ReLULayer(),
new DenseLayer(config.NumActionDims),
new SoftmaxLayer()
new SoftmaxLayer(),
piSampler
});

valueFunc.Compile(config.BatchSize, config.NumStateDims);
@@ -60,18 +62,21 @@ public PPOModel(PPOTrainingSettings config)
private PPOTrainingSettings config;
private FFModel valueFunc;
private FFModel strategy;
private ISampler piSampler;
private IOptimizer strategyOpt;
private IOptimizer valueFuncOpt;
private Matrix2D featureCache;
private ILoss mse = new MeanSquaredError();

public int BatchSize => config.BatchSize;

public void Predict(Matrix2D s0, Matrix2D outPiOnehot, Matrix2D outV)
public void Predict(Matrix2D s0, Matrix2D outPi, Matrix2D outProbs, Matrix2D outV)
{
var predPi = strategy.PredictBatch(s0);
var predV = valueFunc.PredictBatch(s0);
Matrix2D.CopyData(predPi, outPiOnehot);
var predProbs = piSampler.FetchSelectionProbs();
Matrix2D.CopyData(predPi, outPi);
Matrix2D.CopyData(predProbs, outProbs);
Matrix2D.CopyData(predV, outV);
}

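
With the sampler wired into the strategy head, Predict now writes the sampled action index and its selection probability per row instead of the full softmax output. A calling sketch based on the test above (one row per environment; the buffer shapes are assumptions):

    var pi      = Matrix2D.Zeros(numEnvs, 1);  // sampled class index per environment
    var piProbs = Matrix2D.Zeros(numEnvs, 1);  // probability of the sampled class
    var v       = Matrix2D.Zeros(numEnvs, 1);  // state-value estimate
    model.Predict(s0_enc, pi, piProbs, v);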
102 changes: 102 additions & 0 deletions Schafkopf.Training/Algos/SamplingLayer.cs
@@ -0,0 +1,102 @@
using Schafkopf.Training;

namespace BackpropNet;

public interface ISampler : ILayer
{
void Seed(int seed);
Matrix2D FetchSelectionProbs();
}

public class UniformSamplingLayer : ISampler
{
public UniformSamplingLayer(bool sparse = true, int seed = 0)
{
this.sparse = sparse;
Seed(seed);
}

public LayerCache Cache { get; private set; }
public int InputDims { get; private set; }
public int OutputDims { get; private set; }

private bool sparse;
private Random Rng;
private Matrix2D SelectionProbs;

public void Compile(int inputDims)
{
InputDims = inputDims;
OutputDims = sparse ? 1 : inputDims;
}

public void CompileCache(Matrix2D inputs, Matrix2D deltasOut)
{
if (InputDims != inputs.NumCols)
throw new ArgumentException("Expected different amount of input dims!");

int batchSize = inputs.NumRows;
Cache = new LayerCache() {
Input = inputs,
Output = Matrix2D.Zeros(batchSize, OutputDims),
DeltasIn = Matrix2D.Zeros(batchSize, OutputDims),
DeltasOut = deltasOut,
Gradients = Matrix2D.Null(),
};

SelectionProbs = Matrix2D.Zeros(batchSize, 1);
}

public void Seed(int seed)
=> Rng = new Random(seed);

public Matrix2D FetchSelectionProbs()
=> SelectionProbs;

public void Forward()
{
int batchSize = Cache.Input.NumRows;
int numClasses = Cache.Input.NumCols;
bool sparse = Cache.Output.NumCols != numClasses;

var selProbs = SelectionProbs.SliceRowsRaw(0, batchSize);
var output = Cache.Output.SliceRowsRaw(0, batchSize);
int offset = 0;

for (int i = 0; i < batchSize; i++)
{
var probDist = Cache.Input.SliceRowsRaw(i, 1);
var idx = probDist.Sample(Rng);
selProbs[i] = probDist[idx];
if (sparse)
output[offset++] = idx;
else
for (int j = 0; j < numClasses; j++)
output[offset++] = j == idx ? 1 : 0;
}
}

public void Backward()
{
int batchSize = Cache.Input.NumRows;
int numClasses = Cache.Input.NumCols;
bool sparse = Cache.Output.NumCols != numClasses;

var output = Cache.Output.SliceRowsRaw(0, batchSize);
var deltasIn = Cache.DeltasIn.SliceRowsRaw(0, batchSize);
var deltasOut = Cache.DeltasOut.SliceRowsRaw(0, batchSize);
int offset = 0;

if (sparse)
for (int i = 0; i < batchSize; i++)
for (int j = 0; j < numClasses; j++)
deltasOut[offset++] = output[i] == j ? deltasIn[i] : 0;
else
Matrix2D.CopyData(Cache.DeltasIn, Cache.DeltasOut);
}

public void ApplyGrads()
{
// info: layer isn't trainable
}
}
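
A sketch of how the new layer might be wired into a standalone policy head, mirroring the PPOModel constructor above; the layer sizes, batch size and variable names are illustrative assumptions, not part of the commit:

    var sampler = new UniformSamplingLayer(sparse: true, seed: 42);
    var policy = new FFModel(new ILayer[] {
        new DenseLayer(64),
        new ReLULayer(),
        new DenseLayer(numActions),
        new SoftmaxLayer(),
        sampler                     // draws one class index per row of the softmax output
    });
    policy.Compile(batchSize, numStateDims);

    var sampledIdx = policy.PredictBatch(states);      // (batchSize x 1) sampled class indices
    var sampledProbs = sampler.FetchSelectionProbs();  // (batchSize x 1) matching probabilities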