diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 76b4c2365365c4..abd33f114832fd 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -965,12 +965,21 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( namespace { +// Computes the edge resharding index from ths terminal node sharding indices. +EdgeStrategyIdx GetEdgeStrategy( + const AutoShardingSolverRequest& request, + const std::vector& node_strategies, const EdgeIdx edge) { + int u = request.edges(edge).first(); + int v = request.edges(edge).second(); + int64_t num_v_strategies = request.computation_costs(v).costs_size(); + return node_strategies[u] * num_v_strategies + node_strategies[v]; +} + // Checks if the node-sharding strategy has a finite cost and satisfies the // peak-memory constraint. std::optional ShardingStrategyHasViolation( const AutoShardingSolverRequest& request, - const std::vector& node_strategies, - const std::vector& edge_strategies) { + const std::vector& node_strategies) { const int num_nodes = request.num_nodes(); const int num_edges = request.edges_size(); // Check for infinite coefficients in the objective function. @@ -987,7 +996,7 @@ std::optional ShardingStrategyHasViolation( } } for (EdgeIdx e = 0; e < num_edges; ++e) { - EdgeStrategyIdx strategy = edge_strategies[e]; + EdgeStrategyIdx strategy = GetEdgeStrategy(request, node_strategies, e); if (request.resharding_costs(e).costs(strategy) >= kInfinityCost) { return AutoShardingViolationCode::kInfiniteCostViolationCode; } @@ -1005,37 +1014,39 @@ std::optional ShardingStrategyHasViolation( return std::nullopt; } -// Assigns all nodes to their first sharding configuration. If the assignment is -// infeasible, the output cost is negative and encodes the violation code. -AutoShardingSolverOutput SolveTrivial( - const AutoShardingSolverRequest& request) { - const int num_nodes = request.num_nodes(); - const int num_edges = request.edges_size(); - std::vector node_strategies(num_nodes, -1); - std::vector edge_strategies(num_edges, -1); +// Computes the objective value of the sharding strategy. If the objective value +// is infinite or the sharding is infeasible (e.g., violates the peak-memory +// constraint), then a negated `AutoShardingViolationCode` value is returned. +double ComputeShardingStrategyCost( + const AutoShardingSolverRequest& request, + const std::vector& node_strategies) { double cost = 0.0; - - for (NodeIdx v = 0; v < num_nodes; ++v) { - NodeStrategyIdx strategy = 0; - node_strategies[v] = strategy; + for (NodeIdx v = 0; v < request.num_nodes(); ++v) { + NodeStrategyIdx strategy = node_strategies[v]; cost += request.computation_costs(v).costs(strategy) + request.communication_costs(v).costs(strategy); } - for (EdgeIdx e = 0; e < num_edges; ++e) { - // If e = (i, j), this is the resharding cost of (i, 0) --> (j, 0). - EdgeStrategyIdx strategy = 0; - edge_strategies[e] = strategy; + for (EdgeIdx e = 0; e < request.edges_size(); ++e) { + EdgeStrategyIdx strategy = GetEdgeStrategy(request, node_strategies, e); cost += request.resharding_costs(e).costs(strategy); } std::optional violation_code = - ShardingStrategyHasViolation(request, node_strategies, edge_strategies); + ShardingStrategyHasViolation(request, node_strategies); if (violation_code.has_value()) { cost = -1 * (*violation_code); } + return cost; +} + +// Assigns all nodes to their first sharding configuration. If the assignment is +// infeasible, the output cost is negative and encodes the violation code. +AutoShardingSolverOutput SolveTrivial( + const AutoShardingSolverRequest& request) { + std::vector node_strategies(request.num_nodes(), 0); AutoShardingSolverOutput output; output.s_val = node_strategies; - output.cost = cost; + output.cost = ComputeShardingStrategyCost(request, node_strategies); return output; } @@ -1043,59 +1054,34 @@ AutoShardingSolverOutput SolveRandom(const AutoShardingSolverRequest& request, const int num_trials) { std::mt19937_64 rng(0); const int num_nodes = request.num_nodes(); - const int num_edges = request.edges_size(); - - std::vector node_strategies(num_nodes, -1); - std::vector edge_strategies(num_edges, -1); - double cost = 0.0; - std::vector best_node_strategies(num_nodes, -1); - std::vector best_edge_strategies(num_edges, -1); double best_cost = -std::numeric_limits::infinity(); for (int trial = 0; trial < num_trials; ++trial) { - cost = 0.0; + std::vector node_strategies(num_nodes, -1); for (NodeIdx v = 0; v < num_nodes; ++v) { - int num_configurations = request.computation_costs(v).costs_size(); - std::uniform_int_distribution<> dist(0, num_configurations - 1); + int num_strategies = request.computation_costs(v).costs_size(); + std::uniform_int_distribution<> dist(0, num_strategies - 1); NodeStrategyIdx strategy = dist(rng); node_strategies[v] = strategy; - cost += request.computation_costs(v).costs(strategy) + - request.communication_costs(v).costs(strategy); - } - for (EdgeIdx e = 0; e < num_edges; ++e) { - int u = request.edges(e).first(); - int v = request.edges(e).second(); - int64_t num_v_strategies = request.computation_costs(v).costs_size(); - EdgeStrategyIdx strategy = - node_strategies[u] * num_v_strategies + node_strategies[v]; - edge_strategies[e] = strategy; - cost += request.resharding_costs(e).costs(strategy); - } - std::optional violation_code = - ShardingStrategyHasViolation(request, node_strategies, edge_strategies); - if (violation_code.has_value()) { - cost = -1 * (*violation_code); } + double cost = ComputeShardingStrategyCost(request, node_strategies); bool have_feasible_solution = (best_cost >= 0.0); - bool candidate_is_feasible = !violation_code.has_value(); + bool candidate_is_feasible = (cost >= 0.0); if (have_feasible_solution && !candidate_is_feasible) { continue; } else if (have_feasible_solution && candidate_is_feasible) { if (cost < best_cost) { best_node_strategies = node_strategies; - best_edge_strategies = edge_strategies; best_cost = cost; } } else if (!have_feasible_solution && candidate_is_feasible) { best_node_strategies = node_strategies; - best_edge_strategies = edge_strategies; best_cost = cost; } else { // Don't have feasible solution and candidate is also infeasible. if (cost > best_cost) { best_node_strategies = node_strategies; - best_edge_strategies = edge_strategies; best_cost = cost; // Track encoded reason for infeasibility. } } @@ -1107,6 +1093,44 @@ AutoShardingSolverOutput SolveRandom(const AutoShardingSolverRequest& request, return output; } +// Greedily selects the node sharding strategies. Valid modes: +// - "node_cost" +// - "node_memory" +AutoShardingSolverOutput SolveGreedy(const AutoShardingSolverRequest& request, + const std::string& mode) { + const int num_nodes = request.num_nodes(); + std::vector node_strategies(num_nodes, -1); + + for (NodeIdx v = 0; v < num_nodes; ++v) { + int num_strategies = request.computation_costs(v).costs_size(); + NodeStrategyIdx best_strategy = -1; + double best_cost = -std::numeric_limits::infinity(); + for (NodeStrategyIdx strategy = 0; strategy < num_strategies; ++strategy) { + double cost = 0.0; + if (mode == "node-cost") { + cost = request.computation_costs(v).costs(strategy) + + request.communication_costs(v).costs(strategy); + } else if (mode == "node-memory") { + cost = request.memory_costs(v).costs(strategy); + } else { + CHECK(false) << absl::Substitute( + "SolveGreedy mode $0 is not implemented.", mode); + } + if (best_strategy == -1 || cost < best_cost) { + best_strategy = strategy; + best_cost = cost; + } + } + CHECK_NE(best_strategy, -1); + node_strategies[v] = best_strategy; + } + + AutoShardingSolverOutput output; + output.s_val = node_strategies; + output.cost = ComputeShardingStrategyCost(request, node_strategies); + return output; +} + } // namespace absl::StatusOr RunHeuristicSolver( @@ -1121,6 +1145,10 @@ absl::StatusOr RunHeuristicSolver( output = SolveTrivial(request); } else if (algorithm == "random") { output = SolveRandom(request, 10000); + } else if (algorithm == "greedy-node-cost") { + output = SolveGreedy(request, "node-cost"); + } else if (algorithm == "greedy-node-memory") { + output = SolveGreedy(request, "node-memory"); } else { CHECK(false) << absl::Substitute("Algorithm $0 is not implemented.", algorithm); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index a20256c82fcf19..7852e1abfb91f7 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -49,6 +49,8 @@ absl::StatusOr FormulateAndSolveMIPFromSolverRequest( // Runs a heuristic specified by one of the following values of `algorithm`: // - "trivial" // - "random" +// - "greedy-node-cost" +// - "greedy-node-memory" absl::StatusOr RunHeuristicSolver( const AutoShardingSolverRequest& request, const std::string& algorithm);