forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TensorIterator.cpp
1475 lines (1344 loc) · 53.3 KB
/
TensorIterator.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <ATen/native/TensorIterator.h>
#include <array>
#include <ATen/ExpandUtils.h>
#include <ATen/Parallel.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/Resize.h>
#include <ATen/TensorOperators.h>
#include <c10/util/irange.h>
namespace at {
// Shorthand aliases for TensorIteratorBase's nested types, so the
// definitions below can name them without the class qualifier.
using DimMask = TensorIteratorBase::DimMask;
using PtrVector = TensorIteratorBase::PtrVector;
using loop2d_t = TensorIteratorBase::loop2d_t;
using StrideVector = TensorIteratorBase::StrideVector;
/// Construction
// Registers `output` as the next output operand, holding its own (owned)
// Tensor handle. All outputs must be added before the first input.
TensorIteratorConfig& TensorIteratorConfig::add_output(const Tensor& output) {
  TORCH_INTERNAL_ASSERT(num_inputs_ == 0);
  tensors_.emplace_back(c10::MaybeOwned<Tensor>::owned(c10::in_place, output));
  ++num_outputs_;
  return *this;
}
// Registers `input` as the next input operand, holding its own (owned)
// Tensor handle.
TensorIteratorConfig& TensorIteratorConfig::add_input(const Tensor& input) {
  tensors_.emplace_back(c10::MaybeOwned<Tensor>::owned(c10::in_place, input));
  ++num_inputs_;
  return *this;
}
// Registers `output` as the next output operand without taking ownership
// (MaybeOwned::borrowed); presumably `output` must outlive the config and
// the iterator built from it -- confirm against MaybeOwned's contract.
// All outputs must be added before the first input.
TensorIteratorConfig& TensorIteratorConfig::add_borrowed_output(const Tensor& output) {
  TORCH_INTERNAL_ASSERT(num_inputs_ == 0);
  tensors_.emplace_back(c10::MaybeOwned<Tensor>::borrowed(output));
  ++num_outputs_;
  return *this;
}
// Registers `input` as the next input operand without taking ownership
// (MaybeOwned::borrowed); presumably `input` must outlive the config and
// the iterator built from it -- confirm against MaybeOwned's contract.
TensorIteratorConfig& TensorIteratorConfig::add_borrowed_input(const Tensor& input) {
  tensors_.emplace_back(c10::MaybeOwned<Tensor>::borrowed(input));
  ++num_inputs_;
  return *this;
}
// Fixes the dtype and device used for outputs that carry no type
// information (see compute_types). Requires check_all_same_dtype_ to
// already be false, i.e. check_all_same_dtype(false) was called first.
TensorIteratorConfig& TensorIteratorConfig::declare_static_dtype_and_device(ScalarType dtype, Device device) {
  // The message names this function; it previously referred to a
  // non-existent `declare_static_dtype(...)`.
  TORCH_CHECK(!check_all_same_dtype_, "check_all_same_dtype(false) must be called before declare_static_dtype_and_device(...)");
  static_dtype_and_device_ = c10::make_optional(std::make_pair(dtype, device));
  return *this;
}
// Records a fixed iteration shape instead of inferring it from the operands.
// Requires resize_outputs_ to already be false.
TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef shape) {
  // WARNING:
  // This will bypass all shape checking in the TensorIterator. Kernels which call this method
  // are expected to check shapes before calling `add_input` or `add_output`.
  // The statement is terminated explicitly rather than relying on the macro
  // expanding to a braced `if`.
  TORCH_CHECK(!resize_outputs_, "resize_outputs() must be called before declare_static_shape(...)");
  static_shape_ = c10::make_optional(DimVector(shape));
  return *this;
}
// Records a fixed iteration shape (see the single-argument overload) and
// then forces each dimension listed in `squash_dims` to size 1.
// For a zero-dimensional static shape, `squash_dims` is ignored entirely.
// Throws if a squash dim lies outside [0, rank).
TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef shape, IntArrayRef squash_dims) {
  declare_static_shape(shape);
  auto& dims = *static_shape_;
  if (dims.empty()) {
    return *this;
  }
  const auto rank = static_cast<int64_t>(dims.size());
  for (const int64_t squash_dim : squash_dims) {
    TORCH_CHECK(squash_dim >= 0 && squash_dim < rank,
                "squash_dim ", squash_dim, " must be in [0, ", dims.size(), ").");
    dims[squash_dim] = 1;
  }
  return *this;
}
// NOTE: [Computing output strides]
// We use the following algorithm to compute output strides
// If correctly sized output is provided, we respect its strides and don't change them
// Otherwise, if provided output is of incorrect size or no output is provided,
// we try to recover permutation that was applied to the inputs
// by sorting the strides of the inputs. Precedence is given to the inputs in the order they were added,
// and to permutations involving non-broadcasted dimensions
// 1. we loop over inputs starting from the first
// 2. for all inputs strides of broadcasted dimensions are set to 0, and 0 compares equal to anything. If one
// of the dimensions being compared has a stride of 0, we move on to the next tensor to determine if
// these dimensions need to be swapped.
// 3. strides of dimensions equal to 1 participate in sorting
// 4. if 2 strides are equal and neither is 0, we try to break the tie by looking at the corresponding dimensions
// of the tensor. Dimensions were permuted if, when iterating from the end, dimensions corresponding to the
// same strides are increasing. If dimensions are non-increasing, we move on to the next input to break the tie.
//
// Instead of applying rule 4 for tie breaking, we could move on to the next tensor directly. This would result in possibly
// losing the correct permutation of the first tensor if there are permuted trivial dimensions, but could potentially
// improve traversal order of the second tensor. We chose the former option to better propagate channels last layout
// for example for a tensor with the sizes N1H1
// These rules result in the intuitive behavior that in most cases recovers permutation of either the first argument (if all
// arguments are of the same size) or the argument that is not broadcasted, regardless of its position.
// As a bonus, it also results in reasonably well-behaved traversal order of the inputs and outputs - in the kernels
// output is traversed linearly, and since it closely follows input layouts, inputs are traversed linearly as well
//
// Examples:
// full size tensor + broadcasted tensor with 0 or 1 non-trivial dimensions => strides of output are same
// as strides of full size input regardless of the order
// 2 tensors of same size but different strides => output strides are the same as first argument
//
// We also have fast path for memory-dense inputs with the same strides (or, trivially, single memory-dense input)
// that outputs a tensor with the same strides as inputs. The only difference in result with the algorithm described
// above is for strides for trivial (1) dimensions, where in ambiguous cases for performance reasons we default to
// contiguous strides.
// Example: tensor with sizes NC11 and strides C1CC will produce output with strides C111 (note differences are only
// in the strides of trivial dimensions, so physical layout is unaffected but permutation information is lost)
// We might change this behavior in future once performance considerations are resolved
void TensorIteratorBase::reorder_dimensions() {
// Sort the dimensions based on strides in ascending order with reduced dims
// at the front. NOTE: that this inverts the order of C-contiguous tensors.
// strides[0] is the fastest moving dimension instead of strides[ndim - 1].
// See NOTE: [Computing output strides] and inline comments for more detailed description
perm_.resize(ndim());
if (ndim() == 1) {
perm_[0] = 0;
return;
}
// initialize perm with n-1, n-2, ..., 1, 0
std::iota(perm_.rbegin(), perm_.rend(), 0);
// returns 1 if the dim0 should come after dim1, -1 if dim0 should come
// before dim1, and 0 if the comparison is ambiguous.
auto should_swap = [&](size_t dim0, size_t dim1) {
for (int arg = 0; arg < ntensors(); arg++) {
// ignore undefined or incorrectly sized tensors
if (operands_[arg].stride_bytes.empty() || operands_[arg].will_resize) {
continue;
}
int64_t stride0 = operands_[arg].stride_bytes[dim0];
int64_t stride1 = operands_[arg].stride_bytes[dim1];
if (is_reduction_ && operands_[arg].is_output) {
// move reduced dimensions to the front
// strides of reduced dimensions are always set to 0 by review_reduce_result
if ((stride0 == 0) != (stride1 == 0)) {
return stride1 == 0 ? 1 : -1;
}
}
//move on to the next input if one of the dimensions is broadcasted
if (stride0 == 0 || stride1 == 0) {
continue;
// it is important to return here only with strict comparisons, for equal strides we try to break the tie later
// by comparing corresponding dimensions or if that does not work, moving on to the next tensor
} else if (stride0 < stride1) {
return -1;
} else if (stride0 > stride1) {
return 1;
} else { //equal strides, use dimensions themselves as the tie-breaker.
//at this point, with zero strides out of the way, we are guaranteed that operand dimensions are equal to shape_
auto t_dim0 = shape_[dim0];
auto t_dim1 = shape_[dim1];
//return only if dimensions should be swapped, otherwise move on to the next tensor
if (t_dim0 > t_dim1) {
return 1;
}
}
}
return 0;
};
// insertion sort with support for ambiguous comparisons
for (int i = 1; i < ndim(); i++) {
int dim1 = i;
for (int dim0 = i - 1; dim0 >= 0; dim0--) {
int comparison = should_swap(perm_[dim0], perm_[dim1]);
if (comparison > 0) {
std::swap(perm_[dim0], perm_[dim1]);
dim1 = dim0;
} else if (comparison < 0) {
break;
}
}
}
// perform re-ordering of shape and strides
permute_dimensions(perm_);
}
// Computes the common dtype of all input operands via type promotion
// (output operands are ignored), stores it in common_dtype_ and returns it.
// See the [Common Dtype Computation] note
ScalarType TensorIteratorBase::compute_common_dtype() {
  at::native::ResultTypeState promotion_state = {};
  for (const auto& operand : operands_) {
    if (!operand.is_output) {
      promotion_state = at::native::update_result_type_state(*operand.tensor, promotion_state);
    }
  }
  common_dtype_ = at::native::result_type(promotion_state);
  TORCH_INTERNAL_ASSERT(common_dtype_ != ScalarType::Undefined);
  return common_dtype_;
}
// Returns the TensorOptions to use for `op`: those of its original tensor
// when one was recorded (compute_types saves it there before swapping in a
// dtype-converted temporary), otherwise the operand's own options.
TensorOptions original_options(const OperandInfo& op) {
  return op.original_tensor->defined() ? op.original_tensor->options() : op.options();
}
// Implements the the behavior of the following flags:
// - check_all_same_dtype_
// - check_all_same_device_
// - enforce_safe_casting_to_output_
// - promote_inputs_to_common_dtype_
// - cast_common_dtype_to_outputs_
//
// See their descriptions in TensorIterator.h for details.
// NOTE: Checks for more specific behaviors (e.g. the first and second
// inputs must share a dtype, but the third must have the long dtype)
// should be implemented directly and outside of TensorIterator.
void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
// Reviews operands (1/2)
// - validates that all input tensors are defined
// - computes common device
// - determines if there are undefined outputs
// - determines if there are different dtypes and attempts
// to quickly acquire a common dtype
Device common_device = kCPU;
common_dtype_ = ScalarType::Undefined;
// NB: despite output_dtype's generic sounding name, it only is
// used in a nontrivial way if check_all_same_dtype is true
ScalarType output_dtype = ScalarType::Undefined;
bool has_different_input_dtypes = false;
bool has_different_output_dtypes = false;
bool has_undefined_outputs = false;
for (auto& op : operands_) {
// Validates that all inputs have type information, and that
// if an output is missing type information that we can infer
// the device it should be allocated on.
if (!op.is_type_defined()) {
TORCH_INTERNAL_ASSERT(op.is_output, "Found type undefined input tensor!");
if (config.static_dtype_and_device_.has_value()) {
op.target_dtype = config.static_dtype_and_device_->first;
op.device = config.static_dtype_and_device_->second;
} else {
TORCH_INTERNAL_ASSERT(config.check_all_same_device_);
has_undefined_outputs = true;
continue;
}
}
// Validates input tensors are defined
if (!op.tensor->defined()) {
TORCH_INTERNAL_ASSERT(op.is_output, "Found undefined input tensor!");
continue;
}
TORCH_INTERNAL_ASSERT(op.target_dtype == op.current_dtype)
// Acquires the first non-CPU device (if any) as the common device
if (common_device == kCPU && !op.tensor->is_cpu()) {
common_device = op.tensor->device();
}
if (!op.is_output) {
// Determines if there are varying input dtypes
// NOTE: the common dtype is set to the first defined input dtype observed
if (op.target_dtype != common_dtype_) {
if (common_dtype_ == ScalarType::Undefined) {
common_dtype_ = op.target_dtype;
} else {
has_different_input_dtypes = true;
}
}
} else { // op.is_output
// Determines if there are varying output dtypes
// NOTE: the output dtype is set to the first defined output dtype observed
if (op.target_dtype != output_dtype) {
if (output_dtype == ScalarType::Undefined) {
output_dtype = op.target_dtype;
} else {
has_different_output_dtypes = true;
}
}
}
}
// Checks that either the computation type is computable or unneeded
TORCH_INTERNAL_ASSERT(!(has_different_input_dtypes && !config.promote_inputs_to_common_dtype_ &&
(has_undefined_outputs || config.enforce_safe_casting_to_output_ ||
config.cast_common_dtype_to_outputs_)));
// Checks that all inputs and defined outputs are the same dtype, if requested
if (config.check_all_same_dtype_ &&
(has_different_input_dtypes || has_different_output_dtypes ||
(common_dtype_ != output_dtype && output_dtype != ScalarType::Undefined))) {
// Throws an informative error message
for (auto& op : operands_) {
if (!op.tensor->defined()) {
continue;
}
TORCH_CHECK(op.target_dtype == common_dtype_,
"Found dtype ", op.target_dtype, " but expected ", common_dtype_);
}
}
// Short-circuits if no additional work required
if (!has_undefined_outputs && !config.check_all_same_device_ &&
!config.promote_inputs_to_common_dtype_ && !config.cast_common_dtype_to_outputs_ &&
!config.enforce_safe_casting_to_output_) {
// Invalidates common_dtype_ if it could not be inferred
common_dtype_ = has_different_input_dtypes ? ScalarType::Undefined : common_dtype_;
return;
}
// Computes a common dtype, if needed
if (has_different_input_dtypes && config.promote_inputs_to_common_dtype_) {
common_dtype_ = compute_common_dtype();
}
// Promotes common dtype to the default float scalar type, if needed
if (config.promote_integer_inputs_to_float_ &&
c10::isIntegralType(common_dtype_, /*includeBool=*/true)) {
common_dtype_ = c10::typeMetaToScalarType(c10::get_default_dtype());
}
// Reviews operands (2/2)
// - sets metadata for undefined outputs
// - checks that all tensors are on the same device, if requested
// - checks that the common dtype can safely cast to each output, if requested
// - creates temporaries for CPU operations, if needed and requested
int max_cpu_scalars_on_non_cpu = config.allow_cpu_scalars_ ? 1 : 0;
int current_cpu_scalars_on_non_cpu = 0;
for (auto& op : operands_) {
if (!op.is_type_defined()) {
op.target_dtype = common_dtype_;
op.device = common_device;
continue;
}
// Skips undefined tensors
if (!op.tensor->defined()) {
continue;
}
// Checks all tensors are on the same device, if requested
if (config.check_all_same_device_) {
// Handles CPU scalars on CUDA kernels that support them
if (!common_device.is_cpu() &&
config.allow_cpu_scalars_ && !op.is_output && op.tensor->dim() == 0 &&
op.tensor->is_cpu()) {
TORCH_CHECK(current_cpu_scalars_on_non_cpu < max_cpu_scalars_on_non_cpu,
"Trying to pass too many CPU scalars to non-CPU kernel!");
++current_cpu_scalars_on_non_cpu;
} else if (op.device != common_device) {
TORCH_CHECK(false,
"Expected all tensors to be on the same device, but "
"found at least two devices, ", common_device, " and ", op.device, "!");
}
}
// Checks safe casting, if requested
if (config.enforce_safe_casting_to_output_ && op.is_output && op.current_dtype != common_dtype_) {
TORCH_CHECK(canCast(common_dtype_, op.current_dtype),
"result type ", common_dtype_, " can't be cast to the "
"desired output type ", op.current_dtype);
}
// Creates temporaries for CPU operations, if needed and requested
// TODO: reuse temporaries when possible (e.g. for inplace operations)
if (common_device == kCPU) {
// Casts to outputs by creating temporaries of the correct dtype (if needed)
// NB: we skip this on is_meta_, because the temporary allocation here is
// unnecessary if we aren't going to actually do the compute
if (config.cast_common_dtype_to_outputs_ && op.is_output && op.current_dtype != common_dtype_ && !is_meta_) {
TORCH_INTERNAL_ASSERT(op.tensor->defined());
// Marker [Output original_tensor is set]
op.original_tensor = op.tensor;
// NB: do NOT use set_output here, as the temporary is NOT a true output;
// op.tensor is the true output and it was pre-provided for us.
// TODO: The logic for cast_outputs will need to be handled by the
// structured kernels implementation. What probably should happen
// is that we pass in the inferred dtype into the out kernel, and
// then after calling the out kernel, do the conversion (which
// is cast_outputs here), but integrating this with existing
// TensorIterator will take a little doing
op.tensor = c10::MaybeOwned<Tensor>::owned(
at::empty_like(*op.tensor,
op.tensor->options().dtype(common_dtype_),
LEGACY_CONTIGUOUS_MEMORY_FORMAT));
if (!names_.empty()) {
namedinference::propagate_names(*op.tensor, names_);
}
op.current_dtype = common_dtype_;
op.target_dtype = common_dtype_;
}
// Promotes inputs by creating temporaries of the correct dtype
if (config.promote_inputs_to_common_dtype_ && !op.is_output && op.current_dtype != common_dtype_) {
op.original_tensor = op.tensor;
op.tensor = c10::MaybeOwned<Tensor>::owned(op.tensor->to(common_dtype_));
op.current_dtype = common_dtype_;
op.target_dtype = common_dtype_;
}
}
common_device_ = common_device;
}
}
// Returns byte strides for a densely packed layout of the current shape_ in
// which dimension 0 varies fastest: stride[0] == element_size and
// stride[d] == stride[d - 1] * shape_[d - 1].
StrideVector TensorIteratorBase::compatible_stride(int element_size) const {
  StrideVector result;
  int64_t step = element_size;
  for (int d = 0; d < ndim(); ++d) {
    result.push_back(step);
    step *= shape_[d];
  }
  return result;
}
// Maps `input`, expressed in the reordered dimension order produced by
// reorder_dimensions, back to the tensor's original dimension order.
// Invalid after coalesce_dimensions (asserted), and `input` must have one
// entry per element of perm_ (asserted).
DimVector TensorIteratorBase::invert_perm(IntArrayRef input) const {
  TORCH_INTERNAL_ASSERT(!has_coalesced_dimensions_);
  TORCH_INTERNAL_ASSERT(input.size() == perm_.size());
  // Assuming perm_ is a permutation of [0, size) (as reorder_dimensions
  // builds it), every slot of res is written below.
  auto res = DimVector(input.size());
  // Bound the loop by the size that was actually checked above rather than
  // ndim(), so a perm_/shape_ size mismatch can never index past a buffer.
  for (size_t dim = 0; dim < perm_.size(); dim++) {
    res[perm_[dim]] = input[dim];
  }
  return res;
}
// For each output operand: if it is undefined or marked will_resize, give it
// dense byte strides for the reordered shape and create/resize it through
// set_output in the original dimension order (plain contiguous when perm_ is
// the full reversal); otherwise still pass it to set_output with its current
// sizes so guards and names are propagated.
void TensorIteratorBase::allocate_or_resize_outputs() {
  for (int i = 0; i < num_outputs_; i++) {
    auto& op = operands_[i];
    if (!op.tensor->defined() || op.will_resize) {
      TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand ", i);
      int element_size = elementSize(op.target_dtype);
      op.stride_bytes = compatible_stride(element_size);
      // check if permutation is just an inverted order
      // (uses `dim`, not `i`, so the outer operand index is not shadowed)
      bool inverted = true;
      for (int dim = 0; dim < ndim(); dim++) {
        if (perm_[dim] != ndim() - dim - 1) {
          inverted = false;
          break;
        }
      }
      auto tensor_shape = invert_perm(shape_);
      if (inverted) {
        // can just return contiguous output
        // it is faster because it avoids allocating 0 size tensor and
        // resizing and restriding it
        set_output(i, tensor_shape, {}, original_options(op), names_);
      } else {
        auto tensor_stride = invert_perm(op.stride_bytes);
        for (int dim = 0; dim < ndim(); dim++) {
          tensor_stride[dim] /= element_size;
        }
        set_output(i, tensor_shape, tensor_stride, original_options(op), names_);
      }
      op.current_dtype = op.target_dtype;
    } else if (op.tensor->defined()) {
      // Even if we don't resize, we still need to tell set_output about
      // the output, so that we properly set guard and propagate names
      set_output(i, op.tensor->sizes(), {}, original_options(op), names_);
    }
  }
}
// Unifies dimension names (right-aligned) across defined operands into
// names_. Does nothing unless at least one defined operand has names. When
// outputs are resized, output operands are skipped because their names will
// be overwritten anyway; a tensor that is also an input is still seen via
// its input entry.
void TensorIteratorBase::compute_names(const TensorIteratorConfig& config) {
  bool any_named = false;
  for (const auto& op : operands_) {
    if (op.tensor->defined() && op.tensor->has_names()) {
      any_named = true;
      break;
    }
  }
  if (!any_named) {
    return;
  }

  for (auto& op : operands_) {
    const bool skip = !op.tensor->defined() || (config.resize_outputs_ && op.is_output);
    if (skip) {
      continue;
    }
    if (!names_.empty()) {
      names_ = NameVector(unify_from_right(names_, op.tensor->names()));
    } else {
      names_ = op.tensor->names();
    }
  }
}
// Merges adjacent dimensions that every operand can traverse as one, shrinking
// shape_ and each operand's stride_bytes in place. Two neighbours merge when
// either has size 1, or when shape[lo] * stride[lo] == stride[hi] holds for
// every operand. Sets has_coalesced_dimensions_, after which invert_perm may
// no longer be used.
void TensorIteratorBase::coalesce_dimensions() {
  if (ndim() <= 1) {
    return;
  }

  auto mergeable = [&](int lo, int hi) {
    const auto size_lo = shape_[lo];
    const auto size_hi = shape_[hi];
    if (size_lo == 1 || size_hi == 1) {
      return true;
    }
    for (int t = 0; t < ntensors(); t++) {
      const auto& strides = operands_[t].stride_bytes;
      if (size_lo * strides[lo] != strides[hi]) {
        return false;
      }
    }
    return true;
  };

  // Overwrites every operand's stride at `dst` with its stride at `src`.
  auto copy_stride = [&](int dst, int src) {
    for (int t = 0; t < ntensors(); t++) {
      auto& strides = operands_[t].stride_bytes;
      strides[dst] = strides[src];
    }
  };

  int out = 0;
  for (int dim = 1; dim < ndim(); dim++) {
    if (!mergeable(out, dim)) {
      // Start a new output dimension and compact `dim` into it.
      out++;
      if (out != dim) {
        copy_stride(out, dim);
        shape_[out] = shape_[dim];
      }
      continue;
    }
    // A size-1 accumulator adopts the stride of the dimension it absorbs.
    if (shape_[out] == 1) {
      copy_stride(out, dim);
    }
    shape_[out] *= shape_[dim];
  }

  shape_.resize(out + 1);
  for (int t = 0; t < ntensors(); t++) {
    operands_[t].stride_bytes.resize(ndim());
  }
  has_coalesced_dimensions_ = true;
}
// Total number of iteration points: the product of all sizes in shape_
// (1 when shape_ is empty).
int64_t TensorIteratorBase::numel() const {
  int64_t count = 1;
  for (const int64_t extent : shape_) {
    count *= extent;
  }
  return count;
}
// Returns each operand's byte stride along `dim`, one entry per operand
// (all zeros when the iterator has no dimensions).
StrideVector TensorIteratorBase::get_dim_strides(int dim) const {
  const bool zero_dim = ndim() == 0;
  StrideVector strides;
  for (const auto& op : operands_) {
    strides.push_back(zero_dim ? 0 : op.stride_bytes[dim]);
  }
  return strides;
}
// Offsets each operand's base pointer by the multi-dimensional position
// `counter` (one index per dimension) using that operand's byte strides.
SmallVector<char*, 4> TensorIteratorBase::get_data_ptrs(ArrayRef<char*> base, IntArrayRef counter) const {
  SmallVector<char*, 4> result(base);
  for (int d = 0; d < ndim(); ++d) {
    const int64_t offset = counter[d];
    for (int t = 0; t < ntensors(); ++t) {
      result[t] += offset * operands_[t].stride_bytes[d];
    }
  }
  return result;
}
// Collects every operand's data pointer as a byte pointer, in operand order.
SmallVector<char*, 4> TensorIteratorBase::get_base_ptrs() const {
  SmallVector<char*, 4> bases;
  for (int t = 0; t < ntensors(); ++t) {
    bases.push_back(static_cast<char*>(data_ptr(t)));
  }
  return bases;
}
bool TensorIteratorBase::is_dim_reduced(int dim) const {
for (auto& op : operands_) {
if (op.is_output && op.stride_bytes[dim] == 0 && shape_[dim] > 1) {
return true;
}
}
return false;
}
// Reorders shape_ and every non-empty operand stride vector so that new
// dimension i is old dimension perm[i]. `perm` must have ndim() entries.
void TensorIteratorBase::permute_dimensions(IntArrayRef perm) {
  TORCH_INTERNAL_ASSERT(perm.size() == static_cast<unsigned>(ndim()));

  auto apply_perm = [perm](IntArrayRef data) {
    DimVector permuted(data.size(), 0);
    for (size_t pos = 0; pos < perm.size(); ++pos) {
      permuted[pos] = data[perm[pos]];
    }
    return permuted;
  };

  shape_ = apply_perm(shape_);
  for (auto& op : operands_) {
    // Undefined operands have no strides to permute.
    if (!op.stride_bytes.empty()) {
      op.stride_bytes = apply_perm(op.stride_bytes);
    }
  }
}
int64_t TensorIteratorBase::num_output_elements() const {
int64_t elem = 1;
for (int dim = 0; dim < ndim(); dim++) {
if (operands_[0].stride_bytes[dim] != 0 || shape_[dim] == 0) {
elem *= shape_[dim];
}
}
return elem;
}
int TensorIteratorBase::num_reduce_dims() const {
int count = 0;
for (int dim = 0; dim < ndim(); dim++) {
if (operands_[0].stride_bytes[dim] == 0) {
count++;
}
}
return count;
}
// Runs `loop` over all numel() iteration points. Work is split across threads
// with at::parallel_for unless there are fewer than grain_size points or only
// one thread is available, in which case it runs serially. Empty iterations
// return immediately.
void TensorIteratorBase::for_each(loop2d_t loop, int64_t grain_size) {
  const int64_t total = this->numel();
  if (total == 0) {
    return;
  }
  const bool serial = total < grain_size || at::get_num_threads() == 1;
  if (serial) {
    serial_for_each(loop, {0, total});
    return;
  }
  at::parallel_for(0, total, grain_size, [&](int64_t begin, int64_t end) {
    serial_for_each(loop, {begin, end});
  });
}
// Flattens all byte strides dimension-major: every operand's stride for
// dimension 0, then every operand's stride for dimension 1, and so on.
StrideVector TensorIteratorBase::get_strides() const {
  StrideVector flat;
  for (int d = 0; d < ndim(); ++d) {
    for (int t = 0; t < ntensors(); ++t) {
      flat.push_back(operands_[t].stride_bytes[d]);
    }
  }
  return flat;
}
// Runs `loop` on the calling thread over the linear index sub-range `range`.
// With at most one dimension the loop gets a single 1-D call; otherwise a
// DimCounter walks the range in the largest 2-D steps it allows.
void TensorIteratorBase::serial_for_each(loop2d_t loop, Range range) const {
  if (range.size() == 0) {
    return;
  }

  auto strides = get_strides();
  // Pad to at least two rows of per-operand strides; presumably so the 2-D
  // loop can always read an outer-stride row even for 0-/1-D iteration.
  const auto min_strides = 2U * ntensors();
  while (strides.size() < min_strides) {
    strides.push_back(0);
  }

  auto base_ptrs = get_base_ptrs();
  if (ndim() > 1) {
    auto counter = DimCounter(shape_, range);
    while (!counter.is_done()) {
      auto ptrs = get_data_ptrs(base_ptrs, counter.values);
      auto step = counter.max_2d_step();
      loop(ptrs.data(), strides.data(), step[0], step[1]);
      counter.increment(step);
    }
    return;
  }

  auto ptrs = get_data_ptrs(base_ptrs, {range.begin});
  loop(ptrs.data(), strides.data(), range.size(), 1);
}
// True when the iteration space has exactly one dimension.
// TODO: check for casting once it's supported
bool TensorIteratorBase::is_trivial_1d() const {
  const bool one_dimensional = ndim() == 1;
  return one_dimensional;
}
// True when there is exactly one element, or when the iteration is
// one-dimensional and that dimension is contiguous for every operand
// (as reported by has_contiguous_first_dim()).
bool TensorIteratorBase::is_contiguous() const {
  return numel() == 1 || (ndim() == 1 && has_contiguous_first_dim());
}
// True if operand `arg` reads the same element at every iteration point:
// each dimension either has zero stride or size 1.
bool TensorIteratorBase::is_scalar(int arg) const {
  const auto& strides = operands_[arg].stride_bytes;
  for (int d = 0; d < ndim(); ++d) {
    const bool moves = strides[d] != 0;
    const bool spans = shape_[d] != 1;
    if (moves && spans) {
      return false;
    }
  }
  return true;
}
// True if operand `arg` is a scalar (see is_scalar) located on the CPU.
bool TensorIteratorBase::is_cpu_scalar(int arg) const {
  if (!is_scalar(arg)) {
    return false;
  }
  return device(arg).is_cpu();
}
// For each output that was computed into a temporary of a different dtype
// (original_tensor defined with a dtype other than current_dtype), copies the
// result back into the original tensor and makes it the operand's tensor
// again.
void TensorIteratorBase::cast_outputs() {
  for (auto& op : operands_) {
    const bool computed_in_other_dtype = op.is_output && op.original_tensor->defined() &&
        op.original_tensor->scalar_type() != op.current_dtype;
    if (!computed_in_other_dtype) {
      continue;
    }
    // TODO: Now that set_output resizes both the original_tensor
    // and tensor, this condition should no longer ever be true
    if (op.original_tensor->sizes() != op.tensor->sizes()) {
      op.original_tensor->resize_as_(*op.tensor).as_strided_(op.tensor->sizes(), op.tensor->strides());
    }
    op.original_tensor->copy_(*op.tensor);
    op.tensor = op.original_tensor;
  }
}
void* TensorIteratorBase::data_ptr(int arg) const {
return operands_[arg].data;
}
// Erases operand `arg`; later operands shift down by one index.
void TensorIteratorBase::remove_operand(int arg) {
  auto position = operands_.begin() + arg;
  operands_.erase(position);
}
void TensorIteratorBase::unsafe_replace_operand(int arg, void* data) {
operands_[arg].data = data;
}
// Restricts dimension `dim` to `size` elements starting at `start`. It
// advances every operand's data pointer by `start` steps along `dim`, records
// the offset in view_offsets_, and re-coalesces when the dimension collapses
// to size 1 (except for reductions).
void TensorIteratorBase::narrow(int dim, int64_t start, int64_t size) {
  // Reject negative dims too, since they would index shape_/view_offsets_
  // out of bounds.
  TORCH_INTERNAL_ASSERT(dim >= 0 && dim < ndim() && size >= 1);
  shape_[dim] = size;
  view_offsets_[dim] += start;
  for (auto& op : operands_) {
    op.data = static_cast<char*>(op.data) + op.stride_bytes[dim] * start;
  }
  if (size == 1 && !is_reduction_) {
    coalesce_dimensions();
  }
}
void TensorIteratorBase::select_all_keeping_dim(int start_dim, IntArrayRef indices) {
TORCH_INTERNAL_ASSERT(start_dim <= ndim());
for (int i = start_dim; i < ndim(); ++i) {
for (auto& op : operands_) {
op.data = ((char*)op.data) + op.stride_bytes[i] * indices[i - start_dim];
}
shape_[i] = 1;
}
}
// Helper to construct a binary op that promotes integer inputs to float.
// Builds this iterator over output `out` and inputs `a`, `b` with:
// memory-overlap checking, CPU scalars allowed, inputs promoted to a common
// dtype, an integral (or bool) common dtype further promoted to the default
// float dtype (see compute_types), the common dtype cast to the output, and
// safe casting to the output enforced. Operands are held as owned handles.
// NOTE(review): setter order is kept as written; promote_integer_inputs_to_float
// follows promote_inputs_to_common_dtype -- confirm whether the config setters
// are order-sensitive before reordering.
void TensorIteratorBase::build_binary_float_op(const Tensor& out, const Tensor& a, const Tensor& b) {
build(TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(out)
.add_input(a)
.add_input(b)
.allow_cpu_scalars(true)
.promote_inputs_to_common_dtype(true)
.cast_common_dtype_to_outputs(true)
.enforce_safe_casting_to_output(true)
.promote_integer_inputs_to_float(true));
}
// This cannot be a function because TensorIteratorConfig is not
// copyable or movable, so it can't be returned from the function.
// Expands to a TensorIteratorConfig preset for binary ops (memory-overlap
// check, CPU scalars allowed, common-dtype promotion and casting to outputs,
// safe casting enforced); callers chain add_*output/add_*input onto it.
// NOTE: the last line deliberately has no trailing backslash; one there would
// splice the next source line into the macro body.
#define BINARY_OP_CONFIG()                  \
  TensorIteratorConfig()                    \
    .set_check_mem_overlap(true)            \
    .allow_cpu_scalars(true)                \
    .promote_inputs_to_common_dtype(true)   \
    .cast_common_dtype_to_outputs(true)     \
    .enforce_safe_casting_to_output(true)
// Builds this iterator for a standard binary op with output `out` and inputs
// `a`, `b` using the BINARY_OP_CONFIG() preset.
// Operands are held as owned Tensor handles (add_output/add_input).
void TensorIteratorBase::build_binary_op(const Tensor& out, const Tensor& a, const Tensor& b) {
build(BINARY_OP_CONFIG()
.add_output(out)
.add_input(a)
.add_input(b));
}
// Same as build_binary_op, but registers `out`, `a` and `b` as borrowed
// operands (MaybeOwned::borrowed) instead of owned handles; presumably they
// must outlive this iterator -- confirm at call sites.
void TensorIteratorBase::build_borrowing_binary_op(const Tensor& out, const Tensor& a, const Tensor& b) {
build(BINARY_OP_CONFIG()
.add_borrowed_output(out)
.add_borrowed_input(a)
.add_borrowed_input(b));
}
// Builds this iterator for a unary op with output `out` and input `a` that
// computes in floating point. The input is promoted to the common dtype, and
// an integral (or bool) common dtype is promoted further to the default float
// dtype (see compute_types). The result is cast to the output with safe
// casting enforced, and memory overlap is checked.
void TensorIteratorBase::build_unary_float_op(const Tensor& out, const Tensor& a) {
build(TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(out)
.add_input(a)
.promote_inputs_to_common_dtype(true)
.cast_common_dtype_to_outputs(true)
.enforce_safe_casting_to_output(true)
.promote_integer_inputs_to_float(true));
}
// Builds this iterator for a plain unary op with output `out` and input `a`.
// There is no dtype promotion or casting to the output. Instead, input and
// output must share one dtype (check_all_same_dtype(true)), and memory
// overlap is checked.
void TensorIteratorBase::build_unary_op(const Tensor& out, const Tensor& a) {
build(TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(out)
.add_input(a)
.cast_common_dtype_to_outputs(false)
.enforce_safe_casting_to_output(false)
.check_all_same_dtype(true));
}
// Returns a new iterator configured by build_binary_op for out = f(a, b).
TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a, const Tensor& b) {
  TensorIterator result;
  result.build_binary_op(out, a, b);
  return result;
}
// Returns a new iterator configured by build_binary_float_op for
// out = f(a, b) with integer inputs promoted to float.
TensorIterator TensorIterator::binary_float_op(Tensor& out, const Tensor& a, const Tensor& b) {
  TensorIterator result;
  result.build_binary_float_op(out, a, b);
  return result;
}
// Returns an iterator for a comparison/logical op out = cmp(a, b). The inputs
// are promoted to a common dtype, and CPU scalars and memory-overlap checking
// are enabled.
//
// Note [special-case bool outputs]
// We explicitly don't call `cast_common_dtype_to_outputs` when the output tensor
// has `bool` dtype. This is a performance optimization: the functional
// version of all comparison/logical ops uses a bool output tensor, and we'd like to
// avoid creating a temporary copy of the output.
// However, note that all kernels using this TensorIterator will need to special-case when
// the output tensor has bool dtype, and provide a lambda of type (scalar_t, scalar_t -> bool).
TensorIterator TensorIterator::comparison_op(Tensor& out, const Tensor& a,
    const Tensor& b) {
  TensorIteratorConfig config;
  config.set_check_mem_overlap(true)
      .add_output(out)
      .add_input(a)
      .add_input(b)
      .allow_cpu_scalars(true)
      .promote_inputs_to_common_dtype(true);
  if (out.scalar_type() != kBool) {
    config.cast_common_dtype_to_outputs(true);
  }
  return config.build();
}
// Factory: returns a TensorIterator configured by build_unary_op(out, a).
TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a) {
  TensorIterator it;
  it.build_unary_op(out, a);
  return it;
}
// Factory: returns a TensorIterator configured by build_unary_float_op(out, a).
TensorIterator TensorIterator::unary_float_op(Tensor& out, const Tensor& a) {
  TensorIterator it;
  it.build_unary_float_op(out, a);
  return it;
}
// Builds an iterator with a single output and no inputs.
// Memory-overlap checking is on and the same-dtype check is off. Output
// resizing is disabled as a workaround (see FIXME).
TensorIterator TensorIterator::nullary_op(Tensor& out) {
  return TensorIteratorConfig()
      .set_check_mem_overlap(true)
      .check_all_same_dtype(false)
      .add_output(out)
      // FIXME: workaround for bug: https://github.com/pytorch/pytorch/issues/20342
      .resize_outputs(false)
      .build();
}
// Builds a reduction iterator reducing `a` into the already-defined `out`.
// Memory-overlap checking and output resizing are both disabled. The inputs
// are promoted to a common dtype, but that dtype is not cast to the output.
TensorIterator TensorIterator::reduce_op(Tensor& out, const Tensor& a) {
  TORCH_INTERNAL_ASSERT(out.defined());
  return TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .add_output(out)
      .add_input(a)
      .resize_outputs(false)
      .is_reduction(true)
      // TODO: not supporting casting to outputs is only really necessary for arg{min,max}
      .promote_inputs_to_common_dtype(true)
      .build();
}
// Builds a reduction iterator with two outputs (`out1`, `out2`) fed from `a`.
// Both outputs must be defined, and the input and both outputs must share a
// device. The two outputs must agree in dim, sizes and strides; each mismatch
// raises a TORCH_CHECK error naming the offending values.
// Memory-overlap checking, output resizing and the same-dtype check are all
// disabled.
TensorIterator TensorIterator::reduce_op(Tensor& out1, Tensor& out2, const Tensor& a) {
  TORCH_INTERNAL_ASSERT(out1.defined());
  TORCH_INTERNAL_ASSERT(out2.defined());
  // The trailing space after "is on" keeps the device from being glued to the text.
  TORCH_CHECK(a.device() == out1.device() && out1.device() == out2.device(),
      "reduce_op(): expected input and both outputs to be on same device, but input is on ", a.device(),
      ", output1 is on ", out1.device(), " and output2 is on ", out2.device());
  TORCH_CHECK(out1.dim() == out2.dim(), "reduce_op(): expected both outputs to have same number of dims, but output1 has ", out1.dim(),
      " and output2 has ", out2.dim());
  TORCH_CHECK(out1.sizes() == out2.sizes(), "reduce_op(): expected both outputs to have same sizes, but output1 has ", out1.sizes(),
      " and output2 has ", out2.sizes());
  TORCH_CHECK(out1.strides() == out2.strides(), "reduce_op(): expected both outputs to have same strides, but output1 has ", out1.strides(),
      " and output2 has ", out2.strides());
  return TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .add_output(out1)
      .add_output(out2)
      .add_input(a)
      .resize_outputs(false)
      .is_reduction(true)
      .check_all_same_dtype(false)
      .build();
}
// Moves every tensor registered on `config` into operands_ (in registration
// order) and records how many of them are outputs.
// The entries of config.tensors_ are moved from, so `config` is left holding
// moved-from tensors.
void TensorIteratorBase::populate_operands(TensorIteratorConfig& config) {
  for (auto& registered : config.tensors_) {
    // A single meta tensor among the arguments turns the whole iteration into
    // a meta computation: output information is computed, but no work is done.
    // This matches our multiple dispatch semantics.
    const bool registered_is_meta = registered->is_meta();
    if (registered_is_meta) {
      is_meta_ = true;
    }
    operands_.emplace_back(std::move(registered));
  }
  num_outputs_ = config.num_outputs_;
}
// Flags the first num_outputs_ operands as outputs. A defined output that is
// the same tensor (per is_same) as any input is additionally flagged
// read-write.
void TensorIteratorBase::mark_outputs() {
  // TODO: merge this into populate_operands
  for (int out_idx = 0; out_idx < num_outputs_; ++out_idx) {
    auto& out_op = operands_[out_idx];
    out_op.is_output = true;

    const auto& out_tensor = out_op.tensor;
    if (!out_tensor->defined()) {
      continue;
    }

    // Inputs occupy the operand slots after the outputs.
    for (int in_idx = num_outputs_; in_idx < ntensors(); ++in_idx) {
      const auto& in_tensor = operands_[in_idx].tensor;
      if (out_tensor->is_same(*in_tensor)) {
        out_op.is_read_write = true;
      }
    }
  }
}
// Decides, for each defined output whose sizes differ from shape_, whether it
// will be resized or whether the mismatch is an error.
// Outputs cannot be broadcast. There is an exception for write-only outputs
// when config.resize_outputs_ is set, which supports the legacy behavior of
// `out=` arguments being resized. A static shape skips this step entirely.
void TensorIteratorBase::mark_resize_outputs(const TensorIteratorConfig& config) {
  if (config.static_shape_.has_value()) {
    return;
  }
  for (int idx = 0; idx < num_outputs_; ++idx) {
    const auto& out_tensor = operands_[idx].tensor;
    // Undefined outputs and outputs that already match need no action.
    if (!out_tensor->defined() || out_tensor->sizes().equals(shape_)) {
      continue;
    }
    const bool may_resize = config.resize_outputs_ && !operands_[idx].is_read_write;
    if (may_resize) {
      operands_[idx].will_resize = true;
      continue;
    }
    // A reduction's output is legitimately smaller than shape_, which is the
    // input's size, so only non-reductions fail here.
    TORCH_CHECK(is_reduction_, "output with shape ", out_tensor->sizes(), " doesn't match the broadcast shape ",
        shape_);
  }
}
// When config.check_mem_overlap_ is set, checks each defined output in two
// ways: assert_no_internal_overlap on the output itself, and
// assert_no_partial_overlap between it and every input.
void TensorIteratorBase::compute_mem_overlaps(const TensorIteratorConfig& config) {
  if (!config.check_mem_overlap_) {
    return;
  }
  for (int out_idx = 0; out_idx < num_outputs_; ++out_idx) {
    const auto& out_tensor = operands_[out_idx].tensor;
    if (out_tensor->defined()) {
      assert_no_internal_overlap(*out_tensor);
      for (int in_idx = num_outputs_; in_idx < ntensors(); ++in_idx) {
        assert_no_partial_overlap(*out_tensor, *operands_[in_idx].tensor);
      }
    }
  }
}
void TensorIteratorBase::compute_shape(const TensorIteratorConfig& config) {
if (config.static_shape_.has_value()) {
shape_ = *config.static_shape_;
return;
}
all_ops_same_shape_ = true;
bool has_scalars = false;
bool has_tensors = false;
for (auto& op : operands_) {
if (!op.tensor->defined()) continue;