Skip to content

Commit

Permalink
Extended handling of terminal states for POMDPs (#617)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexBork authored Nov 28, 2024
1 parent 627fa29 commit 6432d9b
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 2 deletions.
157 changes: 155 additions & 2 deletions src/storm/generator/PrismNextStateGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,26 @@ StateBehavior<ValueType, StateType> PrismNextStateGenerator<ValueType, StateType
result.addStateReward(stateRewardValue);
}

// If a terminal expression was set and we must not expand this state, return now.
// If a terminal expression was set, we must not expand this state
if (!this->terminalStates.empty()) {
for (auto const& expressionBool : this->terminalStates) {
if (this->evaluator->asBool(expressionBool.first) == expressionBool.second) {
return result;
if (!isPartiallyObservable()) {
// If the model is not partially observable, return.
return result;
} else {
// for partially observable models, we need to add self-loops for all enabled actions.
result.setExpanded();
std::vector<Choice<ValueType>> allChoices = getSelfLoopsForAllActions(*this->state, stateToIdCallback);
if (allChoices.size() != 0) {
for (auto& choice : allChoices) {
result.addChoice(std::move(choice));
}

this->postprocess(result);
}
return result;
}
}
}
}
Expand Down Expand Up @@ -685,6 +700,144 @@ std::vector<Choice<ValueType>> PrismNextStateGenerator<ValueType, StateType>::ge
return result;
}

template<typename ValueType, typename StateType>
std::vector<Choice<ValueType>> PrismNextStateGenerator<ValueType, StateType>::getSelfLoopsForAllActions(CompressedState const& state,
                                                                                                        StateToIdCallback stateToIdCallback,
                                                                                                        CommandFilter const& commandFilter) {
    // Builds one self-loop choice (probability one back to the current state) for every action that
    // is enabled in the current state. Used for terminal states of partially observable models
    // (POMDPs), where all enabled actions must remain available instead of cutting exploration.
    // NOTE(review): the 'state' parameter is never read -- the self-loop target is always
    // *this->state, and the guards/rewards are evaluated on the state currently loaded into
    // this->evaluator. Confirm callers always pass the currently-loaded state.
    std::vector<Choice<ValueType>> result;

    // Asynchronous actions: one self-loop per enabled, non-synchronizing command.
    // Iterate over all modules.
    for (uint_fast64_t i = 0; i < program.getNumberOfModules(); ++i) {
        storm::prism::Module const& module = program.getModule(i);

        // Iterate over all commands.
        for (uint_fast64_t j = 0; j < module.getNumberOfCommands(); ++j) {
            storm::prism::Command const& command = module.getCommand(j);

            // Only consider commands that are not possibly synchronizing.
            if (isCommandPotentiallySynchronizing(command))
                continue;

            // Honor the action mask, if one was provided.
            if (this->actionMask != nullptr) {
                if (!this->actionMask->query(*this, command.getActionIndex())) {
                    continue;
                }
            }

            // Skip the command, if it is not enabled.
            if (!this->evaluator->asBool(command.getGuardExpression())) {
                continue;
            }

            result.push_back(Choice<ValueType>(command.getActionIndex(), command.isMarkovian()));
            Choice<ValueType>& choice = result.back();

            // Remember the choice origin only if we were asked to.
            if (this->options.isBuildChoiceOriginsSet()) {
                CommandSet commandIndex{command.getGlobalIndex()};
                choice.addOriginData(boost::any(std::move(commandIndex)));
            }
            // The self-loop: the sole successor is the current state, with probability one.
            choice.addProbability(stateToIdCallback(*this->state), storm::utility::one<ValueType>());

            // Create the state-action reward for the newly created choice.
            for (auto const& rewardModel : rewardModels) {
                ValueType stateActionRewardValue = storm::utility::zero<ValueType>();
                if (rewardModel.get().hasStateActionRewards()) {
                    for (auto const& stateActionReward : rewardModel.get().getStateActionRewards()) {
                        if (stateActionReward.getActionIndex() == choice.getActionIndex() &&
                            this->evaluator->asBool(stateActionReward.getStatePredicateExpression())) {
                            stateActionRewardValue += ValueType(this->evaluator->asRational(stateActionReward.getRewardValueExpression()));
                        }
                    }
                }
                choice.addReward(stateActionRewardValue);
            }

            if (this->options.isBuildChoiceLabelsSet() && command.isLabeled()) {
                choice.addLabel(program.getActionName(command.getActionIndex()));
            }
        }
    }

    // Synchronizing actions: one self-loop per feasible combination of enabled commands
    // (one command per participating module) for each synchronizing action label.
    for (uint_fast64_t actionIndex : program.getSynchronizingActionIndices()) {
        // Honor the action mask, if one was provided.
        if (this->actionMask != nullptr) {
            if (!this->actionMask->query(*this, actionIndex)) {
                continue;
            }
        }
        boost::optional<std::vector<std::vector<std::reference_wrapper<storm::prism::Command const>>>> optionalActiveCommandLists =
            getActiveCommandsByActionIndex(actionIndex, commandFilter);

        // Only process this action label, if there is at least one feasible solution.
        if (optionalActiveCommandLists) {
            std::vector<std::vector<std::reference_wrapper<storm::prism::Command const>>> const& activeCommandList = optionalActiveCommandLists.get();
            std::vector<std::vector<std::reference_wrapper<storm::prism::Command const>>::const_iterator> iteratorList(activeCommandList.size());

            // Initialize the list of iterators.
            for (size_t i = 0; i < activeCommandList.size(); ++i) {
                iteratorList[i] = activeCommandList[i].cbegin();
            }

            // Enumerate every combination of one active command per module (odometer-style over
            // iteratorList). Each combination yields exactly one self-loop choice; the commands'
            // updates are deliberately NOT applied, since the state must not be left.
            bool done = false;
            while (!done) {
                result.push_back(Choice<ValueType>(actionIndex));

                // Now create the actual distribution.
                Choice<ValueType>& choice = result.back();

                // Remember the choice label and origins only if we were asked to.
                if (this->options.isBuildChoiceLabelsSet()) {
                    choice.addLabel(program.getActionName(actionIndex));
                }
                if (this->options.isBuildChoiceOriginsSet()) {
                    CommandSet commandIndices;
                    for (uint_fast64_t i = 0; i < iteratorList.size(); ++i) {
                        commandIndices.insert(iteratorList[i]->get().getGlobalIndex());
                    }
                    choice.addOriginData(boost::any(std::move(commandIndices)));
                }
                // The self-loop: the sole successor is the current state, with probability one.
                choice.addProbability(stateToIdCallback(*this->state), storm::utility::one<ValueType>());

                // Create the state-action reward for the newly created choice.
                for (auto const& rewardModel : rewardModels) {
                    ValueType stateActionRewardValue = storm::utility::zero<ValueType>();
                    if (rewardModel.get().hasStateActionRewards()) {
                        for (auto const& stateActionReward : rewardModel.get().getStateActionRewards()) {
                            if (stateActionReward.getActionIndex() == choice.getActionIndex() &&
                                this->evaluator->asBool(stateActionReward.getStatePredicateExpression())) {
                                stateActionRewardValue += ValueType(this->evaluator->asRational(stateActionReward.getRewardValueExpression()));
                            }
                        }
                    }
                    choice.addReward(stateActionRewardValue);
                }

                // Now, check whether there is one more command combination to consider.
                // Advance the rightmost iterator that can still move; reset exhausted ones
                // (standard odometer increment over the command lists).
                bool movedIterator = false;
                for (int_fast64_t j = iteratorList.size() - 1; !movedIterator && j >= 0; --j) {
                    ++iteratorList[j];
                    if (iteratorList[j] != activeCommandList[j].end()) {
                        movedIterator = true;
                    } else {
                        // Reset the iterator to the beginning of the list.
                        iteratorList[j] = activeCommandList[j].begin();
                    }
                }

                done = !movedIterator;
            }
        }
    }

    return result;
}

template<typename ValueType, typename StateType>
void PrismNextStateGenerator<ValueType, StateType>::generateSynchronizedDistribution(
storm::storage::BitVector const& state, ValueType const& probability, uint64_t position,
Expand Down
9 changes: 9 additions & 0 deletions src/storm/generator/PrismNextStateGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,15 @@ class PrismNextStateGenerator : public NextStateGenerator<ValueType, StateType>
void addSynchronousChoices(std::vector<Choice<ValueType>>& choices, CompressedState const& state, StateToIdCallback stateToIdCallback,
CommandFilter const& commandFilter = CommandFilter::All);

/*!
 * Generates self-loop choices (one per enabled action) for the given state. Necessary for POMDPs,
 * where terminal states must keep all enabled actions available as self-loops instead of being cut off.
 *
 * @param state The state for which to generate the self-loop choices.
 * @param stateToIdCallback Callback used to obtain the id of the (unchanged) target state.
 * @param commandFilter Restricts which commands are considered when determining enabled actions.
 * @return The choices representing self-loops for all enabled actions of the state.
 */
std::vector<Choice<ValueType>> getSelfLoopsForAllActions(CompressedState const& state, StateToIdCallback stateToIdCallback,
CommandFilter const& commandFilter = CommandFilter::All);

/*!
* Extend the Json struct with additional information about the state.
*/
Expand Down

0 comments on commit 6432d9b

Please sign in to comment.