Skip to content

Commit

Permalink
[xla-auto-sharding] Add SolveGreedy() heuristic to make local greedy …
Browse files Browse the repository at this point in the history
…choices.

Notes:
- Only represent a sharding strategy by the node sharding configs (not the induced edge sharding indices, since they're redundant)
- Add `GetEdgeStrategy()` method for recovering the induced edge sharding strategies
- Factor out logic for computing objective value of a given sharding strategy as `ComputeShardingStrategyCost()`
- Add `SolveGreedy()` heuristic that outputs a sharding based on local greedy decisions for each node
PiperOrigin-RevId: 703566314
  • Loading branch information
Matthew Fahrbach authored and tensorflower-gardener committed Dec 6, 2024
1 parent 2da8686 commit bcd2b16
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -965,12 +965,21 @@ absl::StatusOr<AutoShardingSolverOutput> FormulateAndSolveMIPFromSolverRequest(

namespace {

// Computes the edge resharding index from ths terminal node sharding indices.
EdgeStrategyIdx GetEdgeStrategy(
const AutoShardingSolverRequest& request,
const std::vector<NodeStrategyIdx>& node_strategies, const EdgeIdx edge) {
int u = request.edges(edge).first();
int v = request.edges(edge).second();
int64_t num_v_strategies = request.computation_costs(v).costs_size();
return node_strategies[u] * num_v_strategies + node_strategies[v];
}

// Checks if the node-sharding strategy has a finite cost and satisfies the
// peak-memory constraint.
std::optional<AutoShardingViolationCode> ShardingStrategyHasViolation(
const AutoShardingSolverRequest& request,
const std::vector<NodeStrategyIdx>& node_strategies,
const std::vector<EdgeStrategyIdx>& edge_strategies) {
const std::vector<NodeStrategyIdx>& node_strategies) {
const int num_nodes = request.num_nodes();
const int num_edges = request.edges_size();
// Check for infinite coefficients in the objective function.
Expand All @@ -987,7 +996,7 @@ std::optional<AutoShardingViolationCode> ShardingStrategyHasViolation(
}
}
for (EdgeIdx e = 0; e < num_edges; ++e) {
EdgeStrategyIdx strategy = edge_strategies[e];
EdgeStrategyIdx strategy = GetEdgeStrategy(request, node_strategies, e);
if (request.resharding_costs(e).costs(strategy) >= kInfinityCost) {
return AutoShardingViolationCode::kInfiniteCostViolationCode;
}
Expand All @@ -1005,97 +1014,74 @@ std::optional<AutoShardingViolationCode> ShardingStrategyHasViolation(
return std::nullopt;
}

// Assigns all nodes to their first sharding configuration. If the assignment is
// infeasible, the output cost is negative and encodes the violation code.
AutoShardingSolverOutput SolveTrivial(
const AutoShardingSolverRequest& request) {
const int num_nodes = request.num_nodes();
const int num_edges = request.edges_size();
std::vector<NodeStrategyIdx> node_strategies(num_nodes, -1);
std::vector<EdgeStrategyIdx> edge_strategies(num_edges, -1);
// Computes the objective value of the sharding strategy. If the objective value
// is infinite or the sharding is infeasible (e.g., violates the peak-memory
// constraint), then a negated `AutoShardingViolationCode` value is returned.
double ComputeShardingStrategyCost(
const AutoShardingSolverRequest& request,
const std::vector<NodeStrategyIdx>& node_strategies) {
double cost = 0.0;

for (NodeIdx v = 0; v < num_nodes; ++v) {
NodeStrategyIdx strategy = 0;
node_strategies[v] = strategy;
for (NodeIdx v = 0; v < request.num_nodes(); ++v) {
NodeStrategyIdx strategy = node_strategies[v];
cost += request.computation_costs(v).costs(strategy) +
request.communication_costs(v).costs(strategy);
}
for (EdgeIdx e = 0; e < num_edges; ++e) {
// If e = (i, j), this is the resharding cost of (i, 0) --> (j, 0).
EdgeStrategyIdx strategy = 0;
edge_strategies[e] = strategy;
for (EdgeIdx e = 0; e < request.edges_size(); ++e) {
EdgeStrategyIdx strategy = GetEdgeStrategy(request, node_strategies, e);
cost += request.resharding_costs(e).costs(strategy);
}
std::optional<AutoShardingViolationCode> violation_code =
ShardingStrategyHasViolation(request, node_strategies, edge_strategies);
ShardingStrategyHasViolation(request, node_strategies);
if (violation_code.has_value()) {
cost = -1 * (*violation_code);
}
return cost;
}

// Assigns all nodes to their first sharding configuration. If the assignment is
// infeasible, the output cost is negative and encodes the violation code.
AutoShardingSolverOutput SolveTrivial(
const AutoShardingSolverRequest& request) {
std::vector<NodeStrategyIdx> node_strategies(request.num_nodes(), 0);

AutoShardingSolverOutput output;
output.s_val = node_strategies;
output.cost = cost;
output.cost = ComputeShardingStrategyCost(request, node_strategies);
return output;
}

AutoShardingSolverOutput SolveRandom(const AutoShardingSolverRequest& request,
const int num_trials) {
std::mt19937_64 rng(0);
const int num_nodes = request.num_nodes();
const int num_edges = request.edges_size();

std::vector<NodeStrategyIdx> node_strategies(num_nodes, -1);
std::vector<EdgeStrategyIdx> edge_strategies(num_edges, -1);
double cost = 0.0;

std::vector<NodeStrategyIdx> best_node_strategies(num_nodes, -1);
std::vector<EdgeStrategyIdx> best_edge_strategies(num_edges, -1);
double best_cost = -std::numeric_limits<double>::infinity();

for (int trial = 0; trial < num_trials; ++trial) {
cost = 0.0;
std::vector<NodeStrategyIdx> node_strategies(num_nodes, -1);
for (NodeIdx v = 0; v < num_nodes; ++v) {
int num_configurations = request.computation_costs(v).costs_size();
std::uniform_int_distribution<> dist(0, num_configurations - 1);
int num_strategies = request.computation_costs(v).costs_size();
std::uniform_int_distribution<> dist(0, num_strategies - 1);
NodeStrategyIdx strategy = dist(rng);
node_strategies[v] = strategy;
cost += request.computation_costs(v).costs(strategy) +
request.communication_costs(v).costs(strategy);
}
for (EdgeIdx e = 0; e < num_edges; ++e) {
int u = request.edges(e).first();
int v = request.edges(e).second();
int64_t num_v_strategies = request.computation_costs(v).costs_size();
EdgeStrategyIdx strategy =
node_strategies[u] * num_v_strategies + node_strategies[v];
edge_strategies[e] = strategy;
cost += request.resharding_costs(e).costs(strategy);
}
std::optional<AutoShardingViolationCode> violation_code =
ShardingStrategyHasViolation(request, node_strategies, edge_strategies);
if (violation_code.has_value()) {
cost = -1 * (*violation_code);
}
double cost = ComputeShardingStrategyCost(request, node_strategies);

bool have_feasible_solution = (best_cost >= 0.0);
bool candidate_is_feasible = !violation_code.has_value();
bool candidate_is_feasible = (cost >= 0.0);
if (have_feasible_solution && !candidate_is_feasible) {
continue;
} else if (have_feasible_solution && candidate_is_feasible) {
if (cost < best_cost) {
best_node_strategies = node_strategies;
best_edge_strategies = edge_strategies;
best_cost = cost;
}
} else if (!have_feasible_solution && candidate_is_feasible) {
best_node_strategies = node_strategies;
best_edge_strategies = edge_strategies;
best_cost = cost;
} else { // Don't have feasible solution and candidate is also infeasible.
if (cost > best_cost) {
best_node_strategies = node_strategies;
best_edge_strategies = edge_strategies;
best_cost = cost; // Track encoded reason for infeasibility.
}
}
Expand All @@ -1107,6 +1093,44 @@ AutoShardingSolverOutput SolveRandom(const AutoShardingSolverRequest& request,
return output;
}

// Greedily selects the node sharding strategies. Valid modes:
// - "node_cost"
// - "node_memory"
AutoShardingSolverOutput SolveGreedy(const AutoShardingSolverRequest& request,
const std::string& mode) {
const int num_nodes = request.num_nodes();
std::vector<NodeStrategyIdx> node_strategies(num_nodes, -1);

for (NodeIdx v = 0; v < num_nodes; ++v) {
int num_strategies = request.computation_costs(v).costs_size();
NodeStrategyIdx best_strategy = -1;
double best_cost = -std::numeric_limits<double>::infinity();
for (NodeStrategyIdx strategy = 0; strategy < num_strategies; ++strategy) {
double cost = 0.0;
if (mode == "node-cost") {
cost = request.computation_costs(v).costs(strategy) +
request.communication_costs(v).costs(strategy);
} else if (mode == "node-memory") {
cost = request.memory_costs(v).costs(strategy);
} else {
CHECK(false) << absl::Substitute(
"SolveGreedy mode $0 is not implemented.", mode);
}
if (best_strategy == -1 || cost < best_cost) {
best_strategy = strategy;
best_cost = cost;
}
}
CHECK_NE(best_strategy, -1);
node_strategies[v] = best_strategy;
}

AutoShardingSolverOutput output;
output.s_val = node_strategies;
output.cost = ComputeShardingStrategyCost(request, node_strategies);
return output;
}

} // namespace

absl::StatusOr<AutoShardingSolverOutput> RunHeuristicSolver(
Expand All @@ -1121,6 +1145,10 @@ absl::StatusOr<AutoShardingSolverOutput> RunHeuristicSolver(
output = SolveTrivial(request);
} else if (algorithm == "random") {
output = SolveRandom(request, 10000);
} else if (algorithm == "greedy-node-cost") {
output = SolveGreedy(request, "node-cost");
} else if (algorithm == "greedy-node-memory") {
output = SolveGreedy(request, "node-memory");
} else {
CHECK(false) << absl::Substitute("Algorithm $0 is not implemented.",
algorithm);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ absl::StatusOr<AutoShardingSolverOutput> FormulateAndSolveMIPFromSolverRequest(
// Runs a heuristic specified by one of the following values of `algorithm`:
// - "trivial"
// - "random"
// - "greedy-node-cost"
// - "greedy-node-memory"
absl::StatusOr<AutoShardingSolverOutput> RunHeuristicSolver(
const AutoShardingSolverRequest& request, const std::string& algorithm);

Expand Down

0 comments on commit bcd2b16

Please sign in to comment.