diff --git a/patches/patched_pytorch_v2.2.1_rc3.patch b/patches/patched_pytorch_v2.2.1_rc3.patch index 30648b1..747f3b9 100644 --- a/patches/patched_pytorch_v2.2.1_rc3.patch +++ b/patches/patched_pytorch_v2.2.1_rc3.patch @@ -1,11 +1,11 @@ diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp -index af0e5af3be8..9896f16a84e 100644 +index af0e5af..9896f16 100644 --- a/aten/src/ATen/FunctionalInverses.cpp +++ b/aten/src/ATen/FunctionalInverses.cpp @@ -151,6 +151,12 @@ Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& return at::sum_to(mutated_view, base.sym_sizes(),/*always_return_non_view=*/!reapply_views); } - + +Tensor FunctionalInverses::expand_as_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views,const Tensor& other) { + return at::sum_to(mutated_view, base.sym_sizes(),/*always_return_non_view=*/!reapply_views); +} @@ -15,8 +15,175 @@ index af0e5af3be8..9896f16a84e 100644 Tensor FunctionalInverses::permute_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef dims) { return at::functionalization::permute_copy_inverse(mutated_view, dims, reapply_views); } +diff --git a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp +index b8004ec..45869fe 100644 +--- a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp ++++ b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp +@@ -137,6 +137,32 @@ uint64_t CUDAGeneratorImpl::get_offset() const { + return philox_offset_per_thread_; + } + ++uint64_t CUDAGeneratorImpl::get_sharding_spec(uint64_t local_shape[MAX_DIMS], ++ uint64_t global_offset[MAX_DIMS], ++ uint64_t global_shape[MAX_DIMS], ++ uint64_t global_strides[MAX_DIMS]) const { ++ at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::get_sharding_spec"); ++ memcpy(local_shape, this->local_shape_, this->tensor_dim_ * sizeof(uint64_t)); ++ memcpy(global_offset, this->global_offset_, this->tensor_dim_ * sizeof(uint64_t)); ++ memcpy(global_shape, this->global_shape_, this->tensor_dim_ * sizeof(uint64_t)); ++ memcpy(global_strides, this->global_strides_, this->tensor_dim_ * sizeof(uint64_t)); ++ return this->tensor_dim_; ++} ++ ++void CUDAGeneratorImpl::set_sharding_spec(uint64_t tensor_dim, ++ const uint64_t local_shape[MAX_DIMS], ++ const uint64_t global_offset[MAX_DIMS], ++ const uint64_t global_shape[MAX_DIMS], ++ const uint64_t global_strides[MAX_DIMS]) { ++ at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::set_sharding_spec"); ++ this->tensor_dim_ = tensor_dim; ++ memcpy(this->local_shape_, local_shape, tensor_dim * sizeof(uint64_t)); ++ memcpy(this->global_offset_, global_offset, tensor_dim * sizeof(uint64_t)); ++ memcpy(this->global_shape_, global_shape, tensor_dim * sizeof(uint64_t)); ++ memcpy(this->global_strides_, global_strides, tensor_dim * sizeof(uint64_t)); ++ no_reset_rnn_state_.clear(); ++} ++ + #define CAPTURE_DEFAULT_GENS_MSG \ + "In regions captured by CUDA graphs, you may only use the default CUDA RNG " \ + "generator on the device that's current when capture begins. " \ +@@ -175,14 +201,23 @@ c10::intrusive_ptr CUDAGeneratorImpl::get_state() const { + // The RNG state comprises the seed, and an offset used for Philox. 
+ static const size_t seed_size = sizeof(uint64_t);
+ static const size_t offset_size = sizeof(int64_t);
+- static const size_t total_size = seed_size + offset_size;
++ const size_t local_shape_size = sizeof(uint64_t) * this->tensor_dim_;
++ size_t total_size = seed_size + offset_size + local_shape_size * 4;
+
+ auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
+ auto rng_state = state_tensor.data_ptr<uint8_t>();
+ auto current_seed = this->current_seed();
+ auto offset = static_cast<int64_t>(this->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic<int64_t>
++ auto local_shape = this->local_shape_;
++ auto global_offset = this->global_offset_;
++ auto global_shape = this->global_shape_;
++ auto global_strides = this->global_strides_;
+ memcpy(rng_state, &current_seed, seed_size);
+ memcpy(rng_state + seed_size, &offset, offset_size);
++ memcpy(rng_state + seed_size + offset_size, local_shape, local_shape_size);
++ memcpy(rng_state + seed_size + offset_size + local_shape_size, global_offset, local_shape_size);
++ memcpy(rng_state + seed_size + offset_size + 2 * local_shape_size, global_shape, local_shape_size);
++ memcpy(rng_state + seed_size + offset_size + 3 * local_shape_size, global_strides, local_shape_size);
+
+ return state_tensor.getIntrusivePtr();
+ }
+@@ -196,27 +231,47 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
+ void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
+ static const size_t seed_size = sizeof(uint64_t);
+ static const size_t offset_size = sizeof(int64_t);
+- static const size_t total_size = seed_size + offset_size;
+
+ detail::check_rng_state(new_state);
+
+ bool no_philox_seed = false;
+ auto new_state_size = new_state.numel();
+- if (new_state_size == total_size - offset_size) {
++ if (new_state_size % (4 * seed_size) == seed_size) {
+ no_philox_seed = true;
+ } else {
+- TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
++ TORCH_CHECK(new_state_size % (4 * seed_size) == 2 * seed_size, "RNG state is wrong size");
+ }
+
+ uint64_t input_seed;
+ auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
+ memcpy(&input_seed, new_rng_state, seed_size);
+ this->set_current_seed(input_seed);
++
+ int64_t philox_offset = 0;
+ if (!no_philox_seed) {
+ memcpy(&philox_offset, new_rng_state + seed_size, offset_size);
+ }
+ this->set_philox_offset_per_thread(static_cast<uint64_t>(philox_offset));
++
++ size_t ptr_offset = offset_size;
++ if (!no_philox_seed) {
++ ptr_offset += seed_size;
++ }
++
++ uint64_t tensor_dim = (new_state_size - ptr_offset) / (4 * seed_size);
++
++ TORCH_CHECK(tensor_dim <= MAX_DIMS, "tensor has too many (", tensor_dim, " > ", MAX_DIMS, ") dims");
++
++ uint64_t local_shape[MAX_DIMS];
++ uint64_t global_offset[MAX_DIMS];
++ uint64_t global_shape[MAX_DIMS];
++ uint64_t global_strides[MAX_DIMS];
++
++ memcpy(local_shape, new_rng_state + ptr_offset, tensor_dim * seed_size);
++ memcpy(global_offset, new_rng_state + ptr_offset + tensor_dim * seed_size, tensor_dim * seed_size);
++ memcpy(global_shape, new_rng_state + ptr_offset + 2 * tensor_dim * seed_size, tensor_dim * seed_size);
++ memcpy(global_strides, new_rng_state + ptr_offset + 3 * tensor_dim * seed_size, tensor_dim * seed_size);
++ this->set_sharding_spec(tensor_dim, local_shape, global_offset, global_shape, global_strides);
+ }
+
+ /**
+@@ -351,6 +406,7 @@ CUDAGeneratorImpl* CUDAGeneratorImpl::clone_impl() const {
+ auto gen = new 
CUDAGeneratorImpl(this->device().index()); + gen->set_current_seed(this->seed_); + gen->set_philox_offset_per_thread(this->philox_offset_per_thread_); ++ gen->set_sharding_spec(this->tensor_dim_, this->local_shape_, this->global_offset_, this->global_shape_, this->global_strides_); + return gen; + } + +diff --git a/aten/src/ATen/cuda/CUDAGeneratorImpl.h b/aten/src/ATen/cuda/CUDAGeneratorImpl.h +index 2fe8a6f..874ef15 100644 +--- a/aten/src/ATen/cuda/CUDAGeneratorImpl.h ++++ b/aten/src/ATen/cuda/CUDAGeneratorImpl.h +@@ -87,6 +87,13 @@ namespace at { + * + */ + ++// aten/src/ATen/cuda/detail/OffsetCalculator.cuh ++#if defined(USE_ROCM) ++constexpr int MAX_DIMS = 16; ++#else ++constexpr int MAX_DIMS = 25; ++#endif ++ + struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl { + // Constructors + CUDAGeneratorImpl(DeviceIndex device_index = -1); +@@ -106,6 +113,15 @@ struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl { + void capture_prologue(int64_t* seed_extragraph, int64_t* offset_extragraph); + uint64_t capture_epilogue(); + PhiloxCudaState philox_cuda_state(uint64_t increment); ++ uint64_t get_sharding_spec(uint64_t local_shape[MAX_DIMS], ++ uint64_t global_offset[MAX_DIMS], ++ uint64_t global_shape[MAX_DIMS], ++ uint64_t global_strides[MAX_DIMS]) const; ++ void set_sharding_spec(uint64_t tensor_dim, ++ const uint64_t local_shape[MAX_DIMS], ++ const uint64_t global_offset[MAX_DIMS], ++ const uint64_t global_shape[MAX_DIMS], ++ const uint64_t global_strides[MAX_DIMS]); + + bool reset_rnn_state() { + return !no_reset_rnn_state_.test_and_set(); +@@ -124,6 +140,11 @@ private: + int64_t* seed_extragraph_{}; + int64_t* offset_extragraph_{}; + uint32_t offset_intragraph_ = 0; ++ uint64_t tensor_dim_ = 0; ++ uint64_t local_shape_[MAX_DIMS]; ++ uint64_t global_offset_[MAX_DIMS]; ++ uint64_t global_shape_[MAX_DIMS]; ++ uint64_t global_strides_[MAX_DIMS]; + bool graph_expects_this_gen_ = false; + std::atomic_flag no_reset_rnn_state_; + }; diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp -index 1b179a505e9..b1beaa67ae7 100644 +index 1b179a5..b1beaa6 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -296,7 +296,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { @@ -29,7 +196,7 @@ index 1b179a505e9..b1beaa67ae7 100644 OP_DECOMPOSE2(var, dim); OP_DECOMPOSE(var_mean); diff --git a/aten/src/ATen/native/Onehot.cpp b/aten/src/ATen/native/Onehot.cpp -index 41b7a696186..26fd0979c39 100644 +index 41b7a69..26fd097 100644 --- a/aten/src/ATen/native/Onehot.cpp +++ b/aten/src/ATen/native/Onehot.cpp @@ -5,7 +5,9 @@ @@ -43,7 +210,7 @@ index 41b7a696186..26fd0979c39 100644 #include #endif @@ -14,6 +16,17 @@ namespace at { namespace native { - + Tensor one_hot(const Tensor &self, int64_t num_classes) { TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor."); + // using meta bit test to catch Fake Tensor as well until __torch_function__ @@ -58,16 +225,16 @@ index 41b7a696186..26fd0979c39 100644 + } + auto shape = self.sizes().vec(); - + // empty tensor could be converted to one hot representation, diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp -index 7a47490c674..a2c54db9424 100644 +index 7a47490..a2c54db 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -2228,26 +2228,21 @@ bool cpu_equal(const Tensor& self, const 
Tensor& other) { return result.load(); } - + -static Tensor value_selecting_reduction_backward(const Tensor& grad, int64_t dim, const Tensor& indices, at::IntArrayRef sizes, bool keepdim) { - return at::native::value_selecting_reduction_backward_symint(grad, dim, indices, c10::fromIntArrayRefSlow(sizes), keepdim); -} @@ -88,14 +255,14 @@ index 7a47490c674..a2c54db9424 100644 } return grad_in.scatter_(dim, indices_, grad_out); }; - + - if (!keepdim && !sizes.empty()) { + if (!keepdim && !src.sizes().empty()) { auto grad_ = grad.unsqueeze(dim); auto indices_ = indices.unsqueeze(dim); return inplace_scatter_if_not_tensor_subclass(grad_, indices_); diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp -index 0a018fbc8db..a5e4643ae53 100644 +index 0a018fb..a5e4643 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -109,6 +109,7 @@ @@ -108,7 +275,7 @@ index 0a018fbc8db..a5e4643ae53 100644 #include @@ -1143,7 +1144,21 @@ Tensor expand(const Tensor& self, c10::IntArrayRef size, bool /*unused*/) { } - + Tensor expand_as(const Tensor& self, const Tensor& other) { - return self.expand_symint(other.sym_sizes()); + IntArrayRef size = other.sizes(); @@ -127,10 +294,432 @@ index 0a018fbc8db..a5e4643ae53 100644 + namedinference::propagate_names_for_expand(result, self); + return result; } - + Tensor sum_to_size_symint(const Tensor& self, SymIntArrayRef size) { +diff --git a/aten/src/ATen/native/cuda/DistributionTemplates.h b/aten/src/ATen/native/cuda/DistributionTemplates.h +index 5f38e36..aa95680 100644 +--- a/aten/src/ATen/native/cuda/DistributionTemplates.h ++++ b/aten/src/ATen/native/cuda/DistributionTemplates.h +@@ -62,32 +62,47 @@ std::tuple calc_execution_policy(int64_t total_elements) { + } + + // grid stride loop kernel for distributions +-template ++template + C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound) + __global__ void distribution_elementwise_grid_stride_kernel(int numel, + PhiloxCudaState philox_args, + const dist_t dist_func, +- const transform_t transform_func) { +- auto seeds = at::cuda::philox::unpack(philox_args); +- int idx = blockIdx.x * blockDim.x + threadIdx.x; ++ const transform_t transform_func, ++ const virtual_idx_t virtual_idx_func, ++ bool is_sharded=false) { ++ auto [seed, global_offset] = at::cuda::philox::unpack(philox_args); ++ uint64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + curandStatePhilox4_32_10_t state; +- curand_init(std::get<0>(seeds), +- idx, +- std::get<1>(seeds), +- &state); +- + int rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) * + blockDim.x * gridDim.x * unroll_factor; +- for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) { +- auto rand = dist_func(&state); +- #pragma unroll +- for (int ii = 0; ii < unroll_factor; ii++) { +- int li = linear_index + blockDim.x * gridDim.x * ii; +- if (li < numel) { +- transform_func(li, static_cast((&rand.x)[ii])); ++ if (is_sharded) { ++ for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) { ++ #pragma unroll ++ for (int ii = 0; ii < unroll_factor; ii++) { ++ int li = linear_index + blockDim.x * gridDim.x * ii; ++ if (li < numel) { ++ auto [virtual_idx, virtual_offset, single_thread_n] = virtual_idx_func(li); ++ virtual_offset += global_offset; ++ curand_init(seed, virtual_idx, 4 * (virtual_offset / 4), &state); ++ auto rand = dist_func(&state); ++ transform_func(li, 
static_cast((&rand.x)[virtual_offset % unroll_factor])); ++ } ++ } ++ __syncthreads(); ++ } ++ } else { ++ curand_init(seed, idx, global_offset, &state); ++ for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) { ++ auto rand = dist_func(&state); ++ #pragma unroll ++ for (int ii = 0; ii < unroll_factor; ii++) { ++ int li = linear_index + blockDim.x * gridDim.x * ii; ++ if (li < numel) { ++ transform_func(li, static_cast((&rand.x)[ii])); ++ } + } ++ __syncthreads(); + } +- __syncthreads(); + } + } + +@@ -127,11 +142,17 @@ void distribution_nullary_kernel(at::TensorIteratorBase& iter, + auto counter_offset = std::get<0>(execution_policy); + auto grid = std::get<1>(execution_policy); + auto block = std::get<2>(execution_policy); ++ uint64_t tensor_dim = 0; ++ uint64_t local_shape[MAX_DIMS]; ++ uint64_t global_offset[MAX_DIMS]; ++ uint64_t global_shape[MAX_DIMS]; ++ uint64_t global_strides[MAX_DIMS]; + PhiloxCudaState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(counter_offset); ++ tensor_dim = gen->get_sharding_spec(local_shape, global_offset, global_shape, global_strides); + } + + if (!iter.can_use_32bit_indexing()) { +@@ -144,6 +165,40 @@ void distribution_nullary_kernel(at::TensorIteratorBase& iter, + + char* out_data = (char*)iter.data_ptr(0); + ++ uint64_t global_numel = numel; ++ uint64_t single_thread_n = grid.x * block.x; ++ bool is_sharded = false; ++ if (tensor_dim > 0) { ++ global_numel = 1; ++ is_sharded = true; ++ for (int i = 0; i < (int)tensor_dim; ++i) { ++ global_numel *= global_shape[i]; ++ if (local_shape[i] == 0) ++ is_sharded = false; ++ } ++ auto single_exec_policy = calc_execution_policy(global_numel); ++ single_thread_n = std::get<1>(single_exec_policy).x * std::get<2>(single_exec_policy).x; ++ } ++ TORCH_CHECK(single_thread_n > 0, "single_thread_n is 0!!!"); ++ ++ auto virtual_idx_func = [=]__device__(uint64_t local_entry_linear_idx) { ++ if (tensor_dim == 0) // not a dtensor ++ return std::make_tuple(local_entry_linear_idx % single_thread_n, ++ local_entry_linear_idx / single_thread_n, ++ single_thread_n); ++ uint64_t tmp_idx = local_entry_linear_idx; ++ uint64_t global_entry_linear_idx = 0; ++ for (int i = tensor_dim - 1; i >= 0; --i) { ++ uint64_t global_idx_at_i = global_offset[i] + tmp_idx % local_shape[i]; ++ tmp_idx /= local_shape[i]; ++ global_entry_linear_idx += global_idx_at_i * global_strides[i]; ++ } ++ uint64_t virtual_thread_idx = global_entry_linear_idx % single_thread_n; ++ uint64_t virtual_offset = global_entry_linear_idx / single_thread_n; ++ virtual_offset *= curand4_engine_calls / unroll_factor; ++ return std::make_tuple(virtual_thread_idx, virtual_offset, single_thread_n); ++ }; ++ + auto stream = at::cuda::getCurrentCUDAStream(); + if (iter.is_trivial_1d()) { + auto strides = iter.get_inner_strides(); +@@ -155,7 +210,9 @@ void distribution_nullary_kernel(at::TensorIteratorBase& iter, + [=]__device__(int idx, accscalar_t rand) { + scalar_t* out = (scalar_t*)&out_data[stride0 * idx]; + *out = transform_func(rand); +- } ++ }, ++ virtual_idx_func, ++ is_sharded + ); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { +@@ -168,7 +225,9 @@ void distribution_nullary_kernel(at::TensorIteratorBase& iter, + auto offsets = offset_calc.get(idx); + scalar_t* out = (scalar_t*)&out_data[offsets[0]]; + *out = transform_func(rand); +- } ++ }, ++ virtual_idx_func, ++ is_sharded + ); + 
C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +diff --git a/aten/src/ATen/native/cuda/Dropout.cu b/aten/src/ATen/native/cuda/Dropout.cu +index 67ea3e4..938a90a 100644 +--- a/aten/src/ATen/native/cuda/Dropout.cu ++++ b/aten/src/ATen/native/cuda/Dropout.cu +@@ -56,13 +56,10 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo a, + using LoadT = memory::aligned_vector; + using MaskLoadT = memory::aligned_vector; + +- auto seeds = at::cuda::philox::unpack(philox_args); ++ auto [seed, global_offset] = at::cuda::philox::unpack(philox_args); + IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; + curandStatePhilox4_32_10_t state; +- curand_init(std::get<0>(seeds), +- idx, +- std::get<1>(seeds), +- &state); ++ curand_init(seed, idx, global_offset, &state); + + // Helps align the total number of times curand_uniform4 is called by each thread for the same totalElements + // in the vec=2 and vec=4 cases. +@@ -128,7 +125,8 @@ template < + typename IndexType, + int ADims, + int BDims = ADims, +- typename mask_t> ++ typename mask_t, ++ typename virtual_idx_t> + #if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) + C10_LAUNCH_BOUNDS_2(256, 4) + #endif +@@ -137,48 +135,75 @@ fused_dropout_kernel(cuda::detail::TensorInfo a, + cuda::detail::TensorInfo b, + cuda::detail::TensorInfo c, + IndexType totalElements, accscalar_t p, +- PhiloxCudaState philox_args) { +- auto seeds = at::cuda::philox::unpack(philox_args); ++ PhiloxCudaState philox_args, ++ const virtual_idx_t virtual_idx_func, ++ bool is_sharded=false, ++ int global_vec_size=1) { ++ auto [seed, global_offset] = at::cuda::philox::unpack(philox_args); + IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; +- curandStatePhilox4_32_10_t state; +- curand_init(std::get<0>(seeds), +- idx, +- std::get<1>(seeds), +- &state); +- accscalar_t scale = 1.0 / p; +- + IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * + blockDim.x * gridDim.x * UNROLL; +- for (IndexType linearIndex = idx; +- linearIndex < rounded_size; +- linearIndex += gridDim.x * blockDim.x*UNROLL) { +-//curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything +- float4 rand = curand_uniform4(&state); +- scalar_t src[UNROLL]; +- rand.x = rand.x < p; +- rand.y = rand.y < p; +- rand.z = rand.z < p; +- rand.w = rand.w < p; +- for (int ii = 0; ii < UNROLL; ii++) { +- IndexType li = linearIndex + blockDim.x * gridDim.x * ii; +- if (li < totalElements) { +- // Convert `linearIndex` into an offset of `a` +- const IndexType aOffset = +- cuda::detail::IndexToOffset::get(li, a); +- src[ii] = a.data[aOffset]; +- } +- } +- for (int ii = 0; ii < UNROLL; ii++) { +- IndexType li = linearIndex + blockDim.x * gridDim.x * ii; +- if (li < totalElements) { +- // Convert `linearIndex` into an offset of `b` +- const IndexType bOffset = +- cuda::detail::IndexToOffset::get(li, b); +- b.data[bOffset] = src[ii]*(&rand.x)[ii]*scale; +- c.data[bOffset] = (mask_t)(&rand.x)[ii]; +- } +- } +- __syncthreads(); ++ accscalar_t scale = 1.0 / p; ++ curandStatePhilox4_32_10_t state; ++ if (is_sharded) { ++ for (IndexType linearIndex = idx; ++ linearIndex < totalElements; ++ linearIndex += gridDim.x * blockDim.x) { ++ //curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything ++ auto [global_idx, single_thread_n] = virtual_idx_func(linearIndex); ++ IndexType virtual_idx = (global_idx / global_vec_size) % single_thread_n; ++ IndexType virtual_offset = 
global_vec_size * ((global_idx / global_vec_size) / single_thread_n) + global_idx % global_vec_size; ++ virtual_offset += global_offset; ++ curand_init(seed, virtual_idx, 4 * (virtual_offset / 4), &state); ++ float4 rand = curand_uniform4(&state); ++ rand.x = rand.x < p; ++ rand.y = rand.y < p; ++ rand.z = rand.z < p; ++ rand.w = rand.w < p; ++ // Convert `linearIndex` into an offset of `a` ++ const IndexType aOffset = ++ cuda::detail::IndexToOffset::get(linearIndex, a); ++ // Convert `linearIndex` into an offset of `b` ++ const IndexType bOffset = ++ cuda::detail::IndexToOffset::get(linearIndex, b); ++ scalar_t src = a.data[aOffset]; ++ b.data[bOffset] = src*(&rand.x)[virtual_offset % 4]*scale; ++ c.data[bOffset] = (mask_t)(&rand.x)[virtual_offset % 4]; ++ __syncthreads(); ++ } ++ } else { ++ curand_init(seed, idx, global_offset, &state); ++ for (IndexType linearIndex = idx; ++ linearIndex < rounded_size; ++ linearIndex += gridDim.x * blockDim.x*UNROLL) { ++ //curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything ++ float4 rand = curand_uniform4(&state); ++ scalar_t src[UNROLL]; ++ rand.x = rand.x < p; ++ rand.y = rand.y < p; ++ rand.z = rand.z < p; ++ rand.w = rand.w < p; ++ for (int ii = 0; ii < UNROLL; ii++) { ++ IndexType li = linearIndex + blockDim.x * gridDim.x * ii; ++ if (li < totalElements) { ++ // Convert `linearIndex` into an offset of `a` ++ const IndexType aOffset = ++ cuda::detail::IndexToOffset::get(li, a); ++ src[ii] = a.data[aOffset]; ++ } ++ } ++ for (int ii = 0; ii < UNROLL; ii++) { ++ IndexType li = linearIndex + blockDim.x * gridDim.x * ii; ++ if (li < totalElements) { ++ // Convert `linearIndex` into an offset of `b` ++ const IndexType bOffset = ++ cuda::detail::IndexToOffset::get(li, b); ++ b.data[bOffset] = src[ii]*(&rand.x)[ii]*scale; ++ c.data[bOffset] = (mask_t)(&rand.x)[ii]; ++ } ++ } ++ __syncthreads(); ++ } + } + } + +@@ -217,7 +242,7 @@ int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) { + return can_vectorize ? 
vec_size : 1; + } + +-template ++template + inline void launcher( + const Tensor& self, + Tensor& ret, +@@ -226,7 +251,10 @@ inline void launcher( + const int64_t nelem, + const PhiloxCudaState rng_engine_inputs, + dim3 grid, +- dim3 dim_block) { ++ dim3 dim_block, ++ const virtual_idx_t virtual_idx_func, ++ bool is_sharded=false, ++ int global_vec_size=1) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, +@@ -248,7 +276,7 @@ inline void launcher( + + int vec_size = get_vector_size(self, ret, mask); + +- if (vec_size > 1) { ++ if (vec_size > 1 && !is_sharded) { + switch (vec_size) { + case 4: + fused_dropout_kernel_vec< +@@ -293,7 +321,10 @@ inline void launcher( + mask_info, + nelem, + pa, +- rng_engine_inputs); ++ rng_engine_inputs, ++ virtual_idx_func, ++ is_sharded, ++ global_vec_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + break; + default: +@@ -309,7 +340,10 @@ inline void launcher( + mask_info, + nelem, + pa, +- rng_engine_inputs); ++ rng_engine_inputs, ++ virtual_idx_func, ++ is_sharded, ++ global_vec_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + fused_dropout_kernel +@@ -322,7 +356,10 @@ inline void launcher( + mask_info, + nelem, + pa, +- rng_engine_inputs); ++ rng_engine_inputs, ++ virtual_idx_func, ++ is_sharded, ++ global_vec_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + } +@@ -350,17 +387,58 @@ dropout_cuda(CUDAGeneratorImpl* gen, const Tensor& self, double p){ + //number of times random will be generated per thread, to offset philox counter in thc random state + int64_t counter_offset = ((nelem - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL; + PhiloxCudaState rng_engine_inputs; ++ uint64_t tensor_dim = 0; ++ uint64_t local_shape[MAX_DIMS]; ++ uint64_t global_offset[MAX_DIMS]; ++ uint64_t global_shape[MAX_DIMS]; ++ uint64_t global_strides[MAX_DIMS]; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_cuda_state(counter_offset); ++ tensor_dim = gen->get_sharding_spec(local_shape, global_offset, global_shape, global_strides); ++ } ++ uint64_t global_nelem = nelem; ++ uint64_t single_thread_n = grid.x * dim_block.x; ++ bool is_sharded = false; ++ int global_vec_size = -1; ++ if (tensor_dim > 0) { ++ global_nelem = 1; ++ is_sharded = true; ++ for (int i = 0; i < (int)tensor_dim; ++i) { ++ global_nelem *= global_shape[i]; ++ if (local_shape[i] == 0) ++ is_sharded = false; ++ } ++ dim3 single_grid((global_nelem + block_size - 1) / block_size); ++ single_grid.x = std::min( ++ (unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, ++ single_grid.x); ++ single_thread_n = single_grid.x * dim_block.x; ++ global_vec_size = 4; ++ while (global_vec_size > 1 && global_nelem % global_vec_size != 0) ++ global_vec_size /= 2; + } ++ TORCH_CHECK(single_thread_n > 0, "single_thread_n is 0!!!"); ++ ++ auto virtual_idx_func = [=]__device__(uint64_t local_entry_linear_idx) { ++ if (!is_sharded) // not a dtensor ++ return std::make_tuple(local_entry_linear_idx, single_thread_n); ++ uint64_t tmp_idx = local_entry_linear_idx; ++ uint64_t global_entry_linear_idx = 0; ++ for (int i = tensor_dim - 1; i >= 0; --i) { ++ uint64_t global_idx_at_i = global_offset[i] + tmp_idx % local_shape[i]; ++ tmp_idx /= local_shape[i]; ++ global_entry_linear_idx += global_idx_at_i * global_strides[i]; ++ } ++ return std::make_tuple(global_entry_linear_idx, single_thread_n); ++ }; + if (cuda::detail::canUse32BitIndexMath(self)){ + launcher( +- self, ret, mask, p, 
nelem, rng_engine_inputs, grid, dim_block); ++ self, ret, mask, p, nelem, rng_engine_inputs, grid, dim_block, virtual_idx_func, is_sharded, global_vec_size); + } else { + launcher( +- self, ret, mask, p, nelem, rng_engine_inputs, grid, dim_block); ++ self, ret, mask, p, nelem, rng_engine_inputs, grid, dim_block, virtual_idx_func, is_sharded, global_vec_size); + } + return std::tuple(ret, mask); + } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml -index 35a1049e209..604f53ac734 100644 +index 35a1049..604f53a 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2595,6 +2595,8 @@ @@ -139,13 +728,13 @@ index 35a1049e209..604f53ac734 100644 device_guard: False + dispatch: + CompositeExplicitAutograd: expand_as - + # decomposes to eye.m - func: eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -3759,12 +3761,10 @@ - func: max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) device_check: NoCheck # TensorIterator - + -- func: value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor +- func: value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, Tensor sizes, bool keepdim) -> Tensor variants: function @@ -153,13 +742,13 @@ index 35a1049e209..604f53ac734 100644 device_guard: False - dispatch: - CompositeImplicitAutograd: value_selecting_reduction_backward_symint - + - func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor variants: function, method @@ -14225,6 +14225,13 @@ tags: view_copy autogen: expand_copy.out - + +- func: expand_as_copy(Tensor self, Tensor other) -> Tensor + variants: function + dispatch: @@ -171,13 +760,13 @@ index 35a1049e209..604f53ac734 100644 variants: function dispatch: diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py -index a83efe539e4..e190c5b97d5 100644 +index a83efe5..e190c5b 100644 --- a/test/distributed/_tensor/test_dtensor.py +++ b/test/distributed/_tensor/test_dtensor.py @@ -109,6 +109,16 @@ class DTensorTest(DTensorTestBase): value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.5) self.assertEqual(meta_dtensor.to_local(), value_tensor) - + + @with_comms + def test_dtensor_local_tensor_storage(self): + device_mesh = self.build_device_mesh() @@ -192,26 +781,26 @@ index a83efe539e4..e190c5b97d5 100644 def test_modules_w_meta_dtensor(self): model = DummyMLP("meta") diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml -index 2c6886a36cc..6d651249354 100644 +index 2c6886a..6d65124 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -892,7 +892,7 @@ self: non_differentiable - + - name: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) - self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + self: value_selecting_reduction_backward(grad, dim, indices, self, keepdim) values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) - + - name: le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
@@ -1084,7 +1084,7 @@ result: linalg_matrix_exp_differential(self_p, self_t, /*adjoint*/ false) - + - name: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) - self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + self: value_selecting_reduction_backward(grad, dim, indices, self, keepdim) values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) - + - name: max(Tensor self) -> Tensor @@ -1132,15 +1132,15 @@ # The backward implementation is correct in the sense that it returns the @@ -220,53 +809,53 @@ index 2c6886a36cc..6d651249354 100644 - self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + self: value_selecting_reduction_backward(grad, dim, indices, self, keepdim) values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) - + - name: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) - self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + self: value_selecting_reduction_backward(grad, dim, indices, self, keepdim) values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) - + - name: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) - self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + self: value_selecting_reduction_backward(grad, dim, indices, self, keepdim) values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) - + - name: min(Tensor self) -> Tensor @@ -1171,7 +1171,7 @@ result: at::mm(self_t, mat2_p) + at::mm(self_p, mat2_t) - + - name: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) - self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + self: value_selecting_reduction_backward(grad, dim, indices, self, keepdim) values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) - + - name: mul.Tensor(Tensor self, Tensor other) -> Tensor @@ -1526,12 +1526,12 @@ output_differentiability: [True, False, False, False] # LU is an auxiliary tensor not exposed to the user - + - name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) - self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) + self: value_selecting_reduction_backward(grad, dim, indices, self, true) output_differentiability: [True, False] values: gather_with_keepdimed_indices(self_t, dim, indices, true) - + - name: sort.stable(Tensor self, *, bool? 
stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) - self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) + self: value_selecting_reduction_backward(grad, dim, indices, self, true) output_differentiability: [True, False] values: gather_with_keepdimed_indices(self_t, dim, indices, true) - + @@ -1692,7 +1692,7 @@ result: auto_element_wise - + - name: topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) - self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) + self: value_selecting_reduction_backward(grad, dim, indices, self, true) output_differentiability: [True, False] values: gather(self_t, dim, indices) - + diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py -index ee1075cbed9..fea1c399012 100644 +index ee1075c..fea1c39 100644 --- a/tools/autograd/gen_inplace_or_view_type.py +++ b/tools/autograd/gen_inplace_or_view_type.py @@ -315,6 +315,7 @@ def get_view_info(f: NativeFunction) -> Optional[str]: @@ -286,7 +875,7 @@ index ee1075cbed9..fea1c399012 100644 # [NOTE] [Nested Arg Types] # This is temporary. Nested tensors will be migrating to use SymInts and diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py -index 54ad1cdf9b4..47605d96aed 100644 +index 54ad1cd..47605d9 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -24,9 +24,7 @@ class DistributedVariable(VariableTracker): @@ -297,22 +886,22 @@ index 54ad1cdf9b4..47605d96aed 100644 - - return inspect.isfunction(value) and value is DTensor.from_local + return inspect.isfunction(value) and value.__name__ == "from_local" - - + + def is_constant_pg_functions(value): @@ -57,17 +55,17 @@ class PlacementClassVariable(DistributedVariable): if not DistributedVariable.is_available(): return False - + - from torch.distributed._tensor.placement_types import Placement + if not isinstance(value, type): + return False + return value.__name__ in ("Placement", "Replicate", "Shard", "_Partial" "Partial", "InterleavedShard") - + - return type(value) is type and issubclass(value, Placement) + def as_python_constant(self): + return self.value - + def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": @@ -327,18 +916,18 @@ index 54ad1cdf9b4..47605d96aed 100644 @@ -90,9 +88,7 @@ class PlacementVariable(DistributedVariable): if not DistributedVariable.is_available(): return False - + - from torch.distributed._tensor.placement_types import Placement - - return isinstance(value, Placement) + return type(value).__name__ in ("Placement", "Replicate", "Shard", "_Partial" "Partial", "InterleavedShard") - + def as_python_constant(self): return self.value @@ -106,15 +102,30 @@ class PlacementVariable(DistributedVariable): ) -> "VariableTracker": from . 
import ConstantVariable - + - allowed_methods = ["__init__", "__setattr__"] - # placement types dynamo tracking allows only __init__ - # and __setattr__ methods, the latter is for case like `Shard(dim)` @@ -374,7 +963,7 @@ index 54ad1cdf9b4..47605d96aed 100644 except AttributeError: method = None @@ -123,7 +134,9 @@ class PlacementVariable(DistributedVariable): - + args = [x.as_python_constant() for x in args] kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} - method(self.value, *args, **kwargs) @@ -382,17 +971,17 @@ index 54ad1cdf9b4..47605d96aed 100644 + if name in return_constant_functions: + return ConstantVariable(out) return self - + return super().call_method(tx, name, args, kwargs) @@ -140,9 +153,7 @@ class DeviceMeshVariable(DistributedVariable): if not DistributedVariable.is_available(): return False - + - from torch.distributed.device_mesh import DeviceMesh - - return istype(value, DeviceMesh) + return type(value).__name__ == "DeviceMesh" - + def as_python_constant(self): return self.value @@ -150,6 +161,9 @@ class DeviceMeshVariable(DistributedVariable): @@ -403,10 +992,10 @@ index 54ad1cdf9b4..47605d96aed 100644 + return ConstantVariable.create(self.value.device_type) + return super().var_getattr(tx, name) - - + + @@ -198,9 +212,7 @@ class ProcessGroupVariable(DistributedVariable): - + def var_getattr(self, tx, name): if name in ["rank", "size"]: - return variables.LambdaVariable( @@ -415,15 +1004,15 @@ index 54ad1cdf9b4..47605d96aed 100644 + return variables.LambdaVariable(lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)) # TODO should this just raise unimplemented? return super().var_getattr(tx, name) - + diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py -index e5cf6f66730..755e28f331b 100644 +index e5cf6f6..755e28f 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -266,6 +266,64 @@ class NewGlobalVariable(VariableTracker): def __init__(self, **kwargs): super().__init__(**kwargs) - + +class BoundArgumentsVariable(VariableTracker): + """ + This class is used to hack python code about `inspect` package, and not well-designed. 
@@ -482,7 +1071,7 @@ index e5cf6f66730..755e28f331b 100644 + return variables.ConstantVariable.create(None) + return super().call_method(tx, name, args, kwargs) + - + class InspectSignatureVariable(VariableTracker): """represents inspect.signature(...)""" @@ -279,23 +337,52 @@ class InspectSignatureVariable(VariableTracker): @@ -496,7 +1085,7 @@ index e5cf6f66730..755e28f331b 100644 + self.python_signature = inspect.signature(self.inspected.fn) + else: + unimplemented("unsupported callable") - + def var_getattr(self, tx, name: str) -> "VariableTracker": if name == "parameters": + paramters = self.python_signature.parameters @@ -510,7 +1099,7 @@ index e5cf6f66730..755e28f331b 100644 user_cls=dict, ) return super().var_getattr(tx, name) - + + def call_method(self, tx, name, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]) -> VariableTracker: + if name == "bind": + # NOTE: InspectSignatureVariable only record the inspected user_method or function @@ -518,13 +1107,13 @@ index e5cf6f66730..755e28f331b 100644 + return BoundArgumentsVariable.create(self.python_signature.bind(*args, **kwargs)) + return super().call_method(tx, name, args, kwargs) + - + class InspectParameterVariable(VariableTracker): """This is not implemented, if used will graph break.""" + def __init__(self, value, **kwargs): + super().__init__(**kwargs) + self.value = value - + - pass + @staticmethod + def create(value, **kwargs): @@ -538,15 +1127,15 @@ index e5cf6f66730..755e28f331b 100644 + if name in ["kind", "name", "default"]: + return variables.ConstantVariable.create(getattr(self.value, name)) + return super().var_getattr(tx, name) - - + + def produce_trampoline_autograd_fwd(fn_cls): diff --git a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py -index 16eef07af02..ce82a2675d4 100644 +index 16eef07..ce82a26 100644 --- a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py +++ b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py @@ -102,9 +102,10 @@ def aot_dispatch_base_graph( - + # TODO: should factor this into a separate function for export that always only returns just the graph. if aot_config.is_export: - assert ( @@ -558,15 +1147,15 @@ index 16eef07af02..ce82a2675d4 100644 + # ), "aot_export_module does not support tensor subclass inputs for now." 
return fw_module return fw_module, list(updated_flat_args_subclasses_desugared), maybe_subclass_meta - + diff --git a/torch/_functorch/_aot_autograd/subclass_utils.py b/torch/_functorch/_aot_autograd/subclass_utils.py -index 0514c1c4d56..4d813fe64b5 100644 +index 0514c1c..4d813fe 100644 --- a/torch/_functorch/_aot_autograd/subclass_utils.py +++ b/torch/_functorch/_aot_autograd/subclass_utils.py @@ -16,6 +16,27 @@ from .utils import strict_zip - + zip = strict_zip - + +def is_dtensor_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool: + args_flattened = pytree.arg_tree_leaves(*args) + # NOTE: hack: separately check DTensor dispatch @@ -588,25 +1177,25 @@ index 0514c1c4d56..4d813fe64b5 100644 + ) + return any_subclass_args or any_subclass_outputs + - + def requires_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool: args_flattened = pytree.arg_tree_leaves(*args) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py -index 837fe2ab4b6..b38b2c2bedc 100644 +index 837fe2a..b38b2c2 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -511,6 +511,8 @@ def create_aot_dispatcher_function( )(*fake_flat_args) - + req_subclass_dispatch = requires_subclass_dispatch(fake_flat_args, fw_metadata) + from ._aot_autograd.subclass_utils import is_dtensor_subclass_dispatch + dtensor_dispatch = is_dtensor_subclass_dispatch(fake_flat_args, fw_metadata) - + if needs_autograd and not any(x.requires_grad for x in fw_metadata.output_info): # We realized that none of the outputs require grad, @@ -568,7 +570,8 @@ Found a graph input that requires gradients, and received a mutation. This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue. - + fw_metadata={str(fw_metadata)}""") - if req_subclass_dispatch: + # NOTE: hack: make DTensor dispatch succeed! @@ -615,11 +1204,11 @@ index 837fe2ab4b6..b38b2c2bedc 100644 aot_export is not currently supported with traceable tensor subclass. 
If you need this feature, please comment on """) diff --git a/torch/_guards.py b/torch/_guards.py -index 69912b15313..4f00d53b88e 100644 +index 69912b1..4f00d53 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -817,8 +817,16 @@ def detect_fake_mode(inputs: Any = None): - + flat_inputs = pytree.tree_leaves(inputs) for i, flat_input in enumerate(flat_inputs): + from torch.utils._python_dispatch import is_traceable_wrapper_subclass @@ -632,11 +1221,11 @@ index 69912b15313..4f00d53b88e 100644 + inner_tensor = getattr(flat_input, attr) + if isinstance(inner_tensor, FakeTensor): + fake_modes.append((inner_tensor.fake_mode, "fake inner tensor input", i)) - + if fake_modes: fake_mode, desc1, i1 = fake_modes[0] diff --git a/torch/_tensor.py b/torch/_tensor.py -index 3aa0cee639d..dd76e76e841 100644 +index 3aa0cee..dd76e76 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -107,6 +107,7 @@ class Tensor(torch._C.TensorBase): @@ -648,7 +1237,7 @@ index 3aa0cee639d..dd76e76e841 100644 new_tensor = self.clone() if type(new_tensor) is not type(self): diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp -index ba0e913896d..0335434fbe5 100644 +index ba0e913..0335434 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -656,9 +656,9 @@ static PyObject* THPVariable_make_wrapper_subclass( @@ -662,7 +1251,7 @@ index ba0e913896d..0335434fbe5 100644 + ParsedArgs<15> parsed_args{}; auto r = parser.parse(args, kwargs, parsed_args); PyObject* cls = r.pyobject(0); - + @@ -726,8 +726,15 @@ static PyObject* THPVariable_make_wrapper_subclass( size_bytes, /*allocator=*/c10::GetAllocator(c10::kMeta), @@ -678,7 +1267,7 @@ index ba0e913896d..0335434fbe5 100644 + } else { + storage.set_data_ptr_noswap(at::DataPtr{nullptr, r.device(7)}); + } - + auto keys = c10::DispatchKeySet({options.computeDispatchKey()}); if (auto mb_extra_keys = r.toDispatchKeySetOptional(13)) { @@ -2210,4 +2217,4 @@ bool THPVariable_initModule(PyObject* module) { @@ -689,13 +1278,13 @@ index ba0e913896d..0335434fbe5 100644 +} \ No newline at end of file diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py -index a0e02292cfe..f76fded484e 100644 +index a0e0229..f76fded 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -128,6 +128,62 @@ def wait_tensor(tensor): return torch.ops.c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] - - + + +def send(self: torch.Tensor, dst: int, group: RANK_TYPES, tag: str = ""): + """ + Sends the tensor to the destination process. @@ -756,9 +1345,9 @@ index a0e02292cfe..f76fded484e 100644 """ Broadcasts the tensor to all processes in the given process group. 
@@ -542,6 +598,23 @@ def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size): - + return [mk_out_tensor(t) for t in self] - + + +def _send_meta(self, *args): + return torch.empty_like(self) @@ -780,7 +1369,7 @@ index a0e02292cfe..f76fded484e 100644 def _broadcast_meta(self, *args): return torch.empty_like(self) @@ -619,6 +692,10 @@ def _reduce_scatter_tensor_coalesced_native_meta(inputs, reduce_op, group_size, - + def _register_ops(): ops_defs = [ + "send(Tensor self, int dst, str tag, int[] ranks, int group_size) -> Tensor", @@ -791,13 +1380,13 @@ index a0e02292cfe..f76fded484e 100644 "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", diff --git a/torch/distributed/_functional_collectives_impl.py b/torch/distributed/_functional_collectives_impl.py -index f14ad5b067e..04445656e75 100644 +index f14ad5b..0444565 100644 --- a/torch/distributed/_functional_collectives_impl.py +++ b/torch/distributed/_functional_collectives_impl.py @@ -138,6 +138,37 @@ def _str_to_reduce_op(reduceOp: str) -> dist.ReduceOp: raise ValueError(f"Invalid reduce operation {reduceOp}") return cast(dist.ReduceOp, op) - + +def _send(self, dst, tag, ranks, group_size): + group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size) + assert group is not None @@ -829,11 +1418,11 @@ index f14ad5b067e..04445656e75 100644 + _register_tensor_work(self, work) + return self + - + """ Kernel implementations (for eager runtime only) - should never be traced by torch.compile diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py -index 068bc8b9af8..5a577046244 100644 +index 068bc8b..5a57704 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -233,6 +233,7 @@ class DTensor(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__ @@ -842,5 +1431,5 @@ index 068bc8b9af8..5a577046244 100644 requires_grad=requires_grad, + data_ptr=local_tensor.data_ptr(), ) - + tensor_meta = TensorMeta(shape, stride, dtype)