Skip to content

Commit

Permalink
remove declarations of input/output handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Oct 31, 2023
1 parent ee158b4 commit da90a7e
Showing 1 changed file with 0 additions and 26 deletions.
26 changes: 0 additions & 26 deletions torch_xla/csrc/xla_sharding_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,32 +58,6 @@ class ShardingUtil {
bool unroll_windowed_einsum = false,
bool bidirectional_windowed_einsum = false);

// Reshuffles arguments (sharded or replicated) on the devices. The
// size of the arguments vector must match that of the sharding_specs.
// The the returned arguments will be in 1:1 correspondence with the `devices`
// vector, so the `i`th result will belong on the `i`th device.
// TODO(yeounoh) avoiding pre-loading of the unpartitioned input arguments
// might improve the performance and save the bandwidth.
static std::vector<std::vector<runtime::ComputationClient::DataPtr>>
InputHandler(std::vector<runtime::ComputationClient::DataPtr> arguments,
std::vector<std::string> devices);

// Processes replicated execution results, where `sharded_results` contains
// `PjRtData` handles and spans the number of devices (outer) and the number
// of arguments (innner). This requires `sharding_specs` of the same size as
// the number of arguments. `sharding_specs` can contain `nullptr` if the
// corresponding result argument is not sharded. The replicated execution
// `replicated_output=true` leaves the results in replicated states, which is
// aligned with the default exepctation of XLA compiler. However, we override
// the compiler's default behavior and allow the execution to return sharded
// results and wrap sharded arguments into `PjRtShardedData`. This returns a
// vector of size that is equal to the number of arguments.
static std::vector<runtime::ComputationClient::DataPtr> OutputHandler(
std::vector<std::vector<runtime::ComputationClient::DataPtr>>
sharded_results,
std::vector<XLATensor::ShardingSpecPtr> sharding_specs,
bool replicated_output = false);

// Returns the shape of the resulting shards of `tensor` after applying
// `sharding`. This assumes the shards will be padded to ensure they all
// have the same shape.
Expand Down

0 comments on commit da90a7e

Please sign in to comment.