Auto Reorder #6

Open
wants to merge 3,335 commits into
base: main
Changes from 1 commit
38cc864
[XLA:Python] Factors the ":logging" library out from ":xla_extension".
wrengr Mar 28, 2024
c2fd51c
[XLA:Python] Add python function to convert `xla::LiteralProto` into …
wrengr Mar 28, 2024
4f8384e
[XLA:Python] Adding `xla::PrimitiveType` <-> `numpy.dtype` conversion…
wrengr Mar 28, 2024
bcdb690
Integrate StableHLO at openxla/stablehlo@271e8634
ghpvnist Mar 28, 2024
e396c88
Delete populateRankSpecialization*Patterns functions
Mar 28, 2024
078a31b
Correctly handle output streaming case where the MoveToHost annotatio…
tensorflower-gardener Mar 29, 2024
d835533
Set release_base for all release platforms
jakeharmon8 Mar 29, 2024
de04cff
[XLA] Respect min_rank for reduce scatter version of MatchReduceScatter.
Mar 29, 2024
e806de5
Integrate LLVM at llvm/llvm-project@aa2c14de1adc
gribozavr Mar 29, 2024
3a12f75
Automated Code Change
tensorflower-gardener Mar 29, 2024
e686499
[xla:gpu][NFC] Use absl::Span more consistenly
tyb0807 Mar 29, 2024
35a5635
[xla][gpu] Extracting triton codegen requirements for hlo instructions
tensorflower-gardener Mar 29, 2024
593d762
[xla:gpu] Create fake buffer allocations for embedded thunk
tyb0807 Mar 29, 2024
c3d52e9
[xla:gpu][NFC] Add AddressComputationThunk test with GEMM operands sh…
tyb0807 Mar 29, 2024
b1c051c
[xla:gpu][NFC] Use meaningful constexpr
tyb0807 Mar 29, 2024
24c0b39
Deduplicate inferred mesh shapes when try_multiple_mesh_shapes=true.
tensorflower-gardener Mar 29, 2024
947a640
Delete the redundant compilation_cache_test
tensorflower-gardener Mar 29, 2024
c4b031b
Restore GOOGLE_CUDA guard in scoped_annotation.h
tensorflower-gardener Mar 29, 2024
0b476d5
Enhanced zeta readability based on the article
Mar 29, 2024
b4a4647
Rollback https://github.com/openxla/xla/commit/0ab2be0b5a575da3206d2c…
reedwm Mar 29, 2024
8bea463
Modify the matrix class to keep track of both memory and communicatio…
tensorflower-gardener Mar 29, 2024
33a8900
[xla:gpu][NFC] Simplify `collect_slice_info`
tyb0807 Mar 29, 2024
24fc9c0
Change include order in `ml_dtypes.cc` to prevent errors.
ddunl Mar 29, 2024
d55afd1
Integrate LLVM at llvm/llvm-project@80aa52d8c5a8
gribozavr Mar 29, 2024
83f3fd6
PR #7849: [XLA:CPU] Add support for cross-process collectives using mpi.
inailuig Mar 29, 2024
7b079ca
Use `bytes` proto field type for string values of `xla::PjRtValueType…
hyeontaek Mar 29, 2024
db8a37d
Move `tsl/python` to `xla/tsl/python`
ddunl Mar 29, 2024
9b34ebe
Reverts c36f237324de53045aa27d87021d2f9db13cf11a
Mar 29, 2024
087bceb
[xla:gpu] Add a version of `HloPredicateIsOp` for `HloInstructionAdap…
tyb0807 Mar 29, 2024
3c8260f
[PJRT C API] Add a PJRT extension to register custom partitioner.
Mar 29, 2024
287cf2e
In general, avoid the suffix on StatusOr.
ghpvnist Mar 29, 2024
3db823f
[xla:gpu] AddressComputationFusionRewriter should run before other fu…
tyb0807 Mar 30, 2024
2c62120
[xla:gpu] Unify GEMM emission for (Dynamic)AddressComputationFusion e…
tyb0807 Mar 30, 2024
e7788f3
[xla:gpu][NFC] Remove unused constexprs
tyb0807 Mar 30, 2024
35c9bb1
[xla:gpu] Unify static and dynamic slice cases for AddressComputation…
tyb0807 Mar 30, 2024
3ff51cd
[xla:gpu][NFC] Make lambdas static functions for better reusability
tyb0807 Mar 30, 2024
f155104
[xla:gpu][NFC] Explicitly rewrite AddressComputationFusion in custom …
tyb0807 Mar 30, 2024
5745c4f
[xla:gpu][NFC] Use the same helpers to get slices for GEMM and generi…
tyb0807 Mar 31, 2024
9572376
Enable SparseCore threads in TpuLayoutAssignment
tensorflower-gardener Mar 31, 2024
5d0e144
[xla:gpu] Generic custom call emission for DynamicAddressComputationF…
tyb0807 Mar 31, 2024
adde9f0
[xla:gpu] DUS support for generic custom call emission in DynamicAddr…
tyb0807 Mar 31, 2024
91281d5
[PJRT:CPU] Fix thread-pool stack sizes to 2MiB.
hawkinsp Apr 1, 2024
b3f212a
[PJRT:CPU] Replace references to pjrt/tfrt_cpu_pjrt_client with pjrt/…
hawkinsp Apr 1, 2024
68b27ae
[xla:gpu] No need for dynamic/static mode in AddressComputationFusion…
ezhulenev Apr 1, 2024
50de394
[xla:gpu][NFC] No need for custom HloModuleConfigs in address_computa…
ezhulenev Apr 1, 2024
f7c1695
Automated Code Change
tensorflower-gardener Apr 1, 2024
06d0760
Add support for parameter streaming with while loop by:
tensorflower-gardener Apr 1, 2024
1d21af3
Hoist async copies when start_after is -1
tensorflower-gardener Apr 1, 2024
7476c6d
Integrate LLVM at llvm/llvm-project@0f6ed4c394fd
alinas Apr 1, 2024
1fa455d
Integrate LLVM at llvm/llvm-project@c09b6fac12b0
slackito Apr 1, 2024
decdd8c
Remove unnecessary Copybara transforms
ddunl Apr 1, 2024
a1240f3
Speed up DetermineHloInstructionIsReplicated.
tensorflower-gardener Apr 2, 2024
99e7acf
Temporary fix for TensorFlow VS2019 breakage.
wecing Apr 2, 2024
3c85e49
Support RngBitGenerator HloInstruction with single output in the HLO …
ZixuanJiang Apr 2, 2024
7ff97ee
Automated Code Change
tensorflower-gardener Apr 2, 2024
ab0ce49
Automated Code Change
tensorflower-gardener Apr 2, 2024
6f894bf
Automated Code Change
tensorflower-gardener Apr 2, 2024
ca2e872
Automated Code Change
tensorflower-gardener Apr 2, 2024
b381cca
Automated Code Change
tensorflower-gardener Apr 2, 2024
3fdfffb
Switch `xla_cpu` dialect to MLIR properties.
chsigg Apr 2, 2024
c602266
Internal cleanup of BUILD/.bzl files
tensorflower-gardener Apr 2, 2024
13bde94
Internal cleanup of BUILD/.bzl files
tensorflower-gardener Apr 2, 2024
f899622
PR #10965: [GPU] Make xla_gpu_enable_nccl_per_stream_comms false by d…
trevor-m Apr 2, 2024
3c9405f
Automated Code Change
tensorflower-gardener Apr 2, 2024
7b0124b
PR #11094: [ROCm] enable rocm for se_gpu_pjrt_compiler_aot_test
i-chaochen Apr 2, 2024
92874b9
Add missing parameter to xla_computation_to_mlir_module interface
Adam-Banas Apr 2, 2024
baf3ad4
Re-enable HoistLayoutConversion pattern and mixed-precision MMA for A…
Moerafaat Apr 2, 2024
a134d67
Add machine attributes to topology
tensorflower-gardener Apr 2, 2024
f3f0122
Eliminate unused Compiler::AssignBuffers method and overloads.
klucke Apr 2, 2024
c28739c
Add absl::SourceLocations to xla::ResourceExhaustedError.
klucke Apr 2, 2024
f4274db
Reverts baf3ad40513f12841cdef913bfb4f4db35469caa
tensorflower-gardener Apr 2, 2024
ee3c488
Migrate the SerDes logic for basic IFRT API object types from IFRT Pr…
hyeontaek Apr 2, 2024
7395a37
[xla:ffi] Compute correct byte size of a DeviceMemoryBase for FFI buf…
ezhulenev Apr 2, 2024
5eeadb3
[XLA] Ignore channel id for all-reduce in spmd programs
vsytch Apr 2, 2024
68a57ec
Pass the correct operands for branch computations when invoking Einsu…
tensorflower-gardener Apr 2, 2024
0742379
1. More comprehensive after-all handling. Specifically, the pass now …
tensorflower-gardener Apr 2, 2024
f9ae540
Roll back LiteralBase::Hash change due to performance regression.
tensorflower-gardener Apr 2, 2024
7df6101
This CL introduces 'PluginProgram' in IFRT and exposes this in python…
tensorflower-gardener Apr 2, 2024
08b6f26
PR #11141: Add warmup iterations introduced in #9757
olupton Apr 2, 2024
ce93261
PR #11122: [ROCm] Fix RCCL hang on rocm5.7
draganmladjenovic Apr 2, 2024
0eddb55
Ensure that the module's buffer donor config and the input output ali…
tensorflower-gardener Apr 2, 2024
094ad43
[HloValueSemanticsAnalysis] Deduplicate some functions.
jinliangwei Apr 2, 2024
d3e9dde
[xla:gpu] NFC: Remove AddressComputationFusion emitter
ezhulenev Apr 2, 2024
914025c
[XLA:Runtime] Moved the nccl_api target to runtime folder.
tensorflower-gardener Apr 2, 2024
492c2d7
[xla:gpu][NFC] Rename DynamicAddressComputationFusion back to Address…
ezhulenev Apr 2, 2024
b62b77c
Add absl::SourceLocation to xla::Cancelled, xla::NotFound, xla::Unava…
klucke Apr 2, 2024
a2b4e43
Reverts 9b34ebe93fb8aa435e20f91c4af7de17bbc806cc
Apr 2, 2024
04233b9
A couple minor bug fixes:
tensorflower-gardener Apr 3, 2024
207d793
xla_compile_lib: Also support modules in binary proto format.
pizzud Apr 3, 2024
3f67da4
[JAX] Update JAX CI dockerfiles to use NumPy 2.0.0rc1, SciPy 1.13.0rc…
hawkinsp Apr 3, 2024
4e8e23f
Make RemoveCollective virtual and return the new custom call instruction
tensorflower-gardener Apr 3, 2024
f8e10a9
Remove unused includes and dependencies
Adam-Banas Apr 3, 2024
7c09621
[XLA:GPU] Forbid fusing concatenations in dots when the non-contracti…
bchetioui Apr 3, 2024
aee555d
Fixed internal issue.
Apr 3, 2024
4c8a74b
PR #11172: GpuTimer: use delay kernel to improve accuracy
olupton Apr 3, 2024
4de516e
Automated Code Change
tensorflower-gardener Apr 3, 2024
2ff8727
[XLA:GPU] Extract SymbolicTiledHloInstruction into a separate file.
olegshyshkov Apr 3, 2024
828b65d
Use absl::Status errors rather than the tsl equivalents.
klucke Apr 3, 2024
601eaf6
Add source locations to FAILED_PRECONDITION errors.
klucke Apr 3, 2024
c2b4f49
[ifrt_proxy] Added a separate thread pool for host callbacks
superbobry Apr 3, 2024
956f0ed
Only use `config-cuda-only` tag under `if_google` wrapper
ddunl Apr 3, 2024
5fe2779
[XLA:LatencyHidingScheduler] Schedule while and its tuple back to back.
tensorflower-gardener Apr 3, 2024
7d9d486
Add attribute interface for IFRT IR sharding.
ICGog Apr 3, 2024
940093e
Add SourceLocation information to xla::Internal errors.
klucke Apr 3, 2024
eb453e4
[XLA:Python] Making `absl::StatusOr`-casting explicit.
wrengr Apr 3, 2024
0a6049d
[JAX] Rebuild CUDA 12.1 image with newer ml_dtypes and numpy.
hawkinsp Apr 3, 2024
4386e92
[XLA] Make shape util fuzzer happy
vsytch Apr 3, 2024
34509f1
Add IFRT pass that verifies if all !ifrt.arrays have sharding specified.
ICGog Apr 3, 2024
aecdde0
Remove the constraint that tokens cannot be passed as entry parameter…
yueshengys Apr 3, 2024
7b4b275
[PJRT C API] Plumb plugin attributes from plugin to JAX python.
Apr 3, 2024
baf7544
Integrate LLVM at llvm/llvm-project@9df19ce40281
tensorflower-gardener Apr 3, 2024
55cdde9
Minimize number of Copybara transforms that operate on `tensorflow/th…
ddunl Apr 4, 2024
ff468cb
Implement convolution via indexing maps.
sergeykozub Apr 4, 2024
31da0a2
Internal CI configuration to run tests on H100
akuegel Apr 4, 2024
fbbb6c8
Integrate Triton up to [e902d3b6](https://github.com/openai/triton/co…
tensorflower-gardener Apr 4, 2024
27c0d0f
Automated Code Change
tensorflower-gardener Apr 4, 2024
bc2e787
PR #10763: [XLA:GPU] Fix cuDNN FMHA fwd scale not passed into cuDNN
Cjkkkk Apr 4, 2024
b72a639
Automated Code Change
tensorflower-gardener Apr 4, 2024
4821859
PR #11139: [GPU] Enable cuDNN integer math mode only with v9.1+.
sergachev Apr 4, 2024
8cee2ac
[xla:ffi] Unit tests for CPU type-safe custom call API
Adam-Banas Apr 4, 2024
3325ca7
Internal cleanup of BUILD/.bzl files
tensorflower-gardener Apr 4, 2024
fef33a9
Integrate LLVM at llvm/llvm-project@c511c90680ee
tensorflower-gardener Apr 4, 2024
b879cfa
Reverts 99e7acf3535eed1982be5cec9ec0537b1c2bb84f
akuegel Apr 4, 2024
22257bf
Bump the operands+outputs threshold to allow larger fusions.
thomasjoerg Apr 4, 2024
1345784
Better disable mechanism for tensor cores for 8-bit-or-less dot with …
Moerafaat Apr 4, 2024
89fd078
Disable a test that is failing on H100
akuegel Apr 4, 2024
8ea3a17
Reverts f4274db0dda92116ce96d84f8b42fee4e9377307
Moerafaat Apr 4, 2024
3be7909
Fix msan error introduced in https://github.com/openxla/xla/commit/7b…
Apr 4, 2024
570a4d8
[pjrt] NFC: Rename HostBufferSemantics::kZeroCopy to kImmutableZeroCopy
ezhulenev Apr 4, 2024
19422af
[XLA:GPU] Add TiledHloInstruction.
olegshyshkov Apr 4, 2024
b1de4e9
[xla] Add a method to express that we want to schedule a node as earl…
bixia1 Apr 4, 2024
2421769
Added a virtual function (`CanPropagateShardingToOperands`) to `Custo…
tensorflower-gardener Apr 4, 2024
2a29934
PR #10503: Fix log1p inaccuracies on complex inputs with large absolu…
pearu Apr 4, 2024
c2cc020
Add SourceLocation information to xla::Unimplemented errors.
klucke Apr 4, 2024
3d6326c
presubmits: Add a presubmit for CHECK and related macros.
pizzud Apr 4, 2024
a0cf7ec
[XLA:Runtime] Moved the nccl_clique target to runtime folder.
sgerrard Apr 4, 2024
774befd
[xla][gpu] Change the point-to-point pipeliner to produce an intermed…
bixia1 Apr 4, 2024
136750b
Fix for a bug where the while loop fusible sinking crashes when all o…
tensorflower-gardener Apr 4, 2024
ee8f8b8
Import nanobind caster for std::string to avoid casting error.
pschuh Apr 4, 2024
df104c0
[xla:ffi] Add support for annotating FFI results with type tags to di…
ezhulenev Apr 4, 2024
4d135db
[xla:ffi] Add auto-binding for FFI results
ezhulenev Apr 4, 2024
0960128
Crash on HLOs with nested tuples in conditionals.
tensorflower-gardener Apr 4, 2024
b0c6c26
Integrate StableHLO at openxla/stablehlo@1bdf7c26
abhigunj Apr 4, 2024
c3dec1d
[xla:gpu] Pass custom-call results as xla:ffi results to handlers
ezhulenev Apr 5, 2024
be5c637
Add GetDefaultLayout to PjRtTopologyDescription. This is needed to su…
pschuh Apr 5, 2024
fd923c9
Integrate LLVM at llvm/llvm-project@e0e615efac52
tensorflower-gardener Apr 5, 2024
75d64d4
PR #11164: [XLA:GPU] bump up minimum PTX ISA to be 8.1 for CUDA >= 12.1
Cjkkkk Apr 5, 2024
b2886a5
Add support for sparse dot (wgmma.sp) to NVGPU triton dialect.
sergeykozub Apr 5, 2024
4a5ccff
Introduce a new dot operation with a sparse operand (2:4) and its low…
sergeykozub Apr 5, 2024
39ba25c
Reverts 8ea3a17f41213ebff0a1e2b11fab31ef3cd96d92
Moerafaat Apr 5, 2024
e0c0495
Update triton compiler passes to support sparse dot operation.
sergeykozub Apr 5, 2024
8826494
[XLA:GPU] Move code to compute block id to tile offset indexing map t…
olegshyshkov Apr 5, 2024
aa08925
PR #10649: [ROCm] Triton in XLA for ROCm - ir_emitter_triton related …
zoranjovanovic-ns Apr 5, 2024
6d79348
Adds utilities for extracting chosen node & edge strategies.
tensorflower-gardener Apr 5, 2024
03ace28
Remove unused proto imports
hyeontaek Apr 5, 2024
3f57bde
Rolling back for now.
majnemer Apr 5, 2024
5925407
Fix Copybara reversibility issues
ddunl Apr 5, 2024
91c9c6a
Lazily instantiates memory constraints.
tensorflower-gardener Apr 5, 2024
58650f0
Automated Code Change
klucke Apr 5, 2024
2449ba5
Fix test to load autotuning results from cache instead of actually co…
Moerafaat Apr 5, 2024
c2354f9
Add a fallback when GetDefaultLayout is unimplemented for that backend.
Apr 5, 2024
cfdc914
Revert TrivialDce pass
Apr 5, 2024
8b2a01a
Make common host memory spaces sharable across backends.
Apr 5, 2024
3c320e0
Use absl::Status rather than tsl::Status
klucke Apr 5, 2024
7edf554
Defines a solver 'output' class, to replace the earlier (unnamed) tuple.
tensorflower-gardener Apr 5, 2024
2ab1bae
Pass MLIR bytecode across XLA Extension boundary for JAX when convert…
GleasonK Apr 5, 2024
d8fe022
Integrate LLVM at llvm/llvm-project@8487e05967aa
tensorflower-gardener Apr 5, 2024
5769a5c
Add GetDefaultLayoutForDevice to IFRT.
Apr 5, 2024
7e0eb94
Add SourceLocation information to xla::InvalidArgument.
klucke Apr 5, 2024
de4c964
[IFRT] Cache the hash of `DeviceList`
hyeontaek Apr 5, 2024
af88e08
Allow both nb::tuple and nb::list for fastpath_data.
pschuh Apr 5, 2024
78a4bcb
Add a method to get default layout in PyClient.
Apr 5, 2024
8e3cf89
Add unbounded dynamism test for NotOp.
ghpvnist Apr 6, 2024
2fd63e3
Add unbounded dynamism test for MinOp.
ghpvnist Apr 6, 2024
139d4fc
Add unbounded dynamism test for RemOp.
ghpvnist Apr 6, 2024
c1a90a3
Add unbounded dynamism test for ReducePrecisionOp.
ghpvnist Apr 6, 2024
28a8a4b
Add unbounded dynamism test for MapOp.
ghpvnist Apr 6, 2024
e1f9272
Add unbounded dynamism test for XorOp.
ghpvnist Apr 6, 2024
38b39b3
Add unbounded dynamism test for ComplexOp.
ghpvnist Apr 6, 2024
91394b4
[IFRT] Add fast pointer equality test for `DeviceList` internal state
hyeontaek Apr 6, 2024
dda9726
IfrtServingExecutable support host callback execution
deqiangc Apr 6, 2024
17a4028
Add unbounded dynamism test for ShiftLeftOp.
ghpvnist Apr 6, 2024
04e2731
Automated Code Change
tensorflower-gardener Apr 6, 2024
ed9b5bb
Add unbounded dynamism test for ShiftRightArithmeticOp.
ghpvnist Apr 6, 2024
2df33c5
Support shape transpose in `hlo_sharding_util::ReshapeSharding`.
ZixuanJiang Apr 6, 2024
a5788c4
Ensure that the module we consume has no unused computations. This ca…
tensorflower-gardener Apr 6, 2024
42fa291
[NFC] Switch `mhlo` dialect to MLIR properties.
chsigg Apr 6, 2024
ea157cf
Add support for recv and recv-done HLO ops in auto-sharding
tensorflower-gardener Apr 6, 2024
119ed17
Automated Code Change
tensorflower-gardener Apr 6, 2024
26635be
Use StreamExecutor to create stream rather than manually constructing.
klucke Apr 6, 2024
f76be00
Remove deprecated code from JAX lowering and compilation
yashk2810 Apr 7, 2024
dbbf0dd
In the Auto Sharding solver output, populates the times where peak me…
tensorflower-gardener Apr 7, 2024
59d9644
Ignores the previous run's peak times if we're in deterministic mode.
tensorflower-gardener Apr 7, 2024
b8e6958
Unpack per-channel hybrid quantized MHLO ops to float ops
doyeonkim0 Apr 8, 2024
b5bcce7
[xla] hlo_computation: drop instruction_indices_
cota Apr 8, 2024
6a19221
[tsl:concurrency] Add LLVM-style type casting to AsyncValuePtr<T>
ezhulenev Apr 8, 2024
3f881e9
[tsl:concurrency] Add LLVM-style type casting to AsyncValueRef<T>
ezhulenev Apr 8, 2024
990d54a
Add unbounded dynamism test for ShiftRightLogicalOp.
ghpvnist Apr 8, 2024
ed6603c
add constrain:communicate op order ,try solve hung
zjjott Apr 8, 2024
2b4a3d8
Add unbounded dynamism test for SelectAndScatterOp.
ghpvnist Apr 8, 2024
ae6218c
Reverts 2df33c54de828f0144c8a43333daa31e5bd8b265
laurentes Apr 8, 2024
d8fe29f
Removes useless friend declaration.
Apr 8, 2024
370a681
Fix a bug in algebraic simplifier that incorrectly rewrites a broadca…
tensorflower-gardener Apr 8, 2024
32a4c7b
fix sf
zjjott Apr 8, 2024
b8c6a2a
Add unbounded dynamism test for OptimizationBarrierOp.
ghpvnist Apr 8, 2024
55bb55a
Add unbounded dynamism test for WhileOp.
ghpvnist Apr 8, 2024
ba124f3
Add unbounded dynamism test for TupleOp.
ghpvnist Apr 8, 2024
208496a
Fix clang-14 build issue (reference to local binding declared in encl…
sergeykozub Apr 8, 2024
41fbb25
[XLA:GPU][Coalescing] Add coalescing for ops with runtime vars.
pifon2a Apr 8, 2024
298ad29
[XLA:GPU][IndexAnalysis] Check if the scatter is canonicalized when c…
pifon2a Apr 8, 2024
8e24878
Integrate LLVM at llvm/llvm-project@8ee6ab7f69ca
krasimirgg Apr 8, 2024
64df3fc
[PJRT C API] Add a build rule for building the PJRT CPU plugin.
Apr 8, 2024
4c4fbf5
Prepare move of CUDA compilation functionality
beckerhe Apr 8, 2024
b09587c
Add unbounded dynamism test for GetTupleElementOp.
ghpvnist Apr 8, 2024
a8d2772
[tsl:concurrency] Add helper functioms to block/await on AsyncValueRe…
ezhulenev Apr 8, 2024
ecf5f62
Fix HLO cost analysis for simple nested fusions.
tensorflower-gardener Apr 8, 2024
2303754
[tsl:concurrency] NFC: Use port::Aligned(Malloc|Free) instead of cust…
ezhulenev Apr 8, 2024
1ca099e
[tsl:concurrency] NFC: Add tests for various types of AndThen callbac…
ezhulenev Apr 8, 2024
335d390
[tsl:concurrency] Specify casting rules for ErrorAsyncValue
ezhulenev Apr 8, 2024
0bde421
[tsl:concurrency] NFC: Fix warnings in tsl/concurrency folder
ezhulenev Apr 8, 2024
6caa54b
Add metadata matcher for unit testing.
tensorflower-gardener Apr 8, 2024
f97959d
Add unbounded dynamism test for CholeskyOp.
ghpvnist Apr 8, 2024
b3a7256
Customize Reserved GPU HBM Size by flag in Pathways.
tensorflower-gardener Apr 8, 2024
1448c02
HloTopKInstruction doesn't inherit from HloDimensionsInstruction,
tensorflower-gardener Apr 8, 2024
8f1fced
Add unbounded dynamism tests for RngNormalOp and RngUniformOp.
ghpvnist Apr 8, 2024
283059e
Remove Stream::SetPriority functions from Stream.
klucke Apr 8, 2024
e16c318
Add unbounded dynamism test for RngBitGeneratorOp.
ghpvnist Apr 8, 2024
94fdc81
Move StreamImplementation creation into Stream::Initialize, and elimi…
klucke Apr 8, 2024
1ff4f0d
[tsl:concurrency] Specify Isa/DynCast/Cast semantics for indirect asy…
ezhulenev Apr 8, 2024
0a534f1
Enforce that CLs satisfy openxla/xla's buildifier checks
ddunl Apr 8, 2024
9949693
Reverts 0a534f166b4edf6a41578263e7784ef892ba41b3
ddunl Apr 9, 2024
f253a74
Add unbounded dynamism test for TriangularSolveOp.
ghpvnist Apr 9, 2024
b53d9c8
Add unbounded dynamism test for ReverseOp.
ghpvnist Apr 9, 2024
3e89666
Add unbounded dynamism test for SortOp.
ghpvnist Apr 9, 2024
79eccb4
Add unbounded dynamism test for DynamicSliceOp.
ghpvnist Apr 9, 2024
1acf05e
Automated Code Change
tensorflower-gardener Apr 9, 2024
556df5a
try to fix communicate order issue; work on some scene
zjjott Apr 11, 2024
ecfe74a
AutoReorder: add hint, so that solve will be faster
zjjott Apr 11, 2024
0fe4d0e
add constrain to limit two communicate fuse; add all-gather cost esti…
zjjott Apr 16, 2024
7866099
add reduce-scatter cost estimate.
zjjott Apr 18, 2024
947a478
add cuda error debug info.add all2all test
zjjott Apr 22, 2024
0c4f31d
Merge branch 'github_1acf05e' into feature/auto_reorder
zjjott Apr 22, 2024
db4a5d5
after merge. some fix; communication op cost is uncorrect[WIP]
zjjott Apr 25, 2024
c3b71f7
fix uncorrect comm op cost
zjjott Apr 25, 2024
78996c6
fix allgather cost
zjjott Apr 26, 2024
b43d42d
add allgather/reducescatter scaleradio
zjjott May 7, 2024
5628d13
Support flash-attention custom call (#8)
ApsarasX May 8, 2024
a4fc087
fix log
zjjott May 9, 2024
0356daa
migrate convert xplane; PGLE using analytical as fallback estimator
zjjott May 15, 2024
8ada939
support export to mps and json; [WIP] convert xplant to offline sqlite
zjjott May 21, 2024
[xla:gpu] DUS support for generic custom call emission in DynamicAddressComputationFusion emitter

PiperOrigin-RevId: 620691743
tyb0807 authored and copybara-github committed Mar 31, 2024
commit adde9f0f1ca58bbf84310298c0bdc4bb51b04428
137 changes: 113 additions & 24 deletions xla/service/gpu/fusions/address_computation_fusion_test.cc
@@ -1089,7 +1089,7 @@ void Callback_Memcpy(se::gpu::GpuStreamHandle stream, void** buffers,
 const char* /*opaque*/, size_t /*opaque_len*/) {
 void* src = buffers[0];
 void* dst = buffers[1];
-auto err = gpuMemcpyAsync(dst, src, /*count=*/sizeof(float) * 128,
+auto err = gpuMemcpyAsync(dst, src, /*count=*/sizeof(float) * 3 * 128,
 gpuMemcpyDeviceToDevice, stream);
 ASSERT_EQ(err, gpuSuccess);
 }
@@ -1100,9 +1100,9 @@ TEST_F(AddressComputationFusionTest, CustomCallLegacyAPI) {
 XlaBuilder b(TestName());
 CustomCall(&b, "Callback_Memcpy",
 /*operands=*/
-{Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {256}), {0},
-{128}, {1})},
-ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"");
+{Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {512}), {128},
+{4 * 128}, {1})},
+ShapeUtil::MakeShape(F32, {3 * 128}), /*opaque=*/"");
 ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};

 TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
@@ -2570,20 +2570,25 @@ static absl::Status SubBuffers2(se::Stream* stream, ffi::BufferBase src0,
 ffi::BufferBase src1, ffi::BufferBase src2,
 ffi::BufferBase src3, ffi::BufferBase src4,
 ffi::BufferBase src5, ffi::BufferBase src6,
-ffi::BufferBase src7, ffi::BufferBase dst0,
-ffi::BufferBase dst1, ffi::BufferBase dst2,
-ffi::BufferBase dst3, ffi::BufferBase dst4) {
+ffi::BufferBase dst0, ffi::BufferBase dst1,
+ffi::BufferBase dst2, ffi::BufferBase dst3,
+ffi::BufferBase dst4, ffi::BufferBase dst5,
+ffi::BufferBase dst6) {
 // src0: param 0 at tuple index {0}, shape f32[128]
 // src1: param 0 at tuple index {1}, shape f32[256]
 // src2: param 1 at tuple index {0}, shape f32[1024]
 // src3: param 1 at tuple index {1}, shape f32[8]
 // src4: param 2, shape f32[4,8]
+// src5: param 3 at tuple index {0, 0}, shape f32[3,128]
+// src6: param 3 at tuple index {0, 1}, shape f32[5,128]
 //
 // dst0: result at tuple index {0}, shape f32[8]
 // dst1: result at tuple index {1, 0}, shape f32[128]
 // dst2: result at tuple index {1, 1}, shape f32[256]
 // dst3: result at tuple index {2}, shape f32[1024]
 // dst4: result at tuple index {3}, shape f32[4,8]
+// dst5: result at tuple index {4, 0}, shape f32[5,128]
+// dst6: result at tuple index {4, 1}, shape f32[3,128]

 TF_RETURN_IF_ERROR(
 stream->MemcpyD2D(&dst0.data, src3.data, 8 * sizeof(float)));
@@ -2595,6 +2600,10 @@ static absl::Status SubBuffers2(se::Stream* stream, ffi::BufferBase src0,
 stream->MemcpyD2D(&dst3.data, src2.data, 1024 * sizeof(float)));
 TF_RETURN_IF_ERROR(
 stream->MemcpyD2D(&dst4.data, src4.data, 4 * 8 * sizeof(float)));
+TF_RETURN_IF_ERROR(
+stream->MemcpyD2D(&dst5.data, src6.data, 5 * 128 * sizeof(float)));
+TF_RETURN_IF_ERROR(
+stream->MemcpyD2D(&dst6.data, src5.data, 3 * 128 * sizeof(float)));
 return absl::OkStatus();
 }

@@ -2608,20 +2617,64 @@ XLA_FFI_DEFINE_HANDLER(kSubBuffers2, SubBuffers2,
 .Arg<ffi::BufferBase>() // src4
 .Arg<ffi::BufferBase>() // src5
 .Arg<ffi::BufferBase>() // src6
-.Arg<ffi::BufferBase>() // src7
 .Arg<ffi::BufferBase>() // dst0
 .Arg<ffi::BufferBase>() // dst1
 .Arg<ffi::BufferBase>() // dst2
 .Arg<ffi::BufferBase>() // dst3
 .Arg<ffi::BufferBase>() // dst4
+.Arg<ffi::BufferBase>() // dst5
+.Arg<ffi::BufferBase>() // dst6
 );
 XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$subbuffers2",
 PLATFORM, kSubBuffers2);

-TEST_F(AddressComputationFusionTest, Test) {
+TEST_F(AddressComputationFusionTest, CustomCallDUS) {
 XlaBuilder b(TestName());
-CustomCall(
-&b, "Callback_Void", /*operands=*/
+auto custom_call =
+CustomCall(&b, "Callback_Memcpy",
+/*operands=*/
+{Slice(Broadcast(ConstantR0WithType(&b, F32, 42.0), {10, 128}),
+{2, 0}, {5, 128}, {1, 1})},
+ShapeUtil::MakeShape(F32, {3, 128}), /*opaque=*/"");
+
+DynamicUpdateSlice(
+Broadcast(ConstantR0WithType(&b, F32, 92.0), {10, 128}), custom_call,
+{ConstantR0WithType(&b, S32, 4), ConstantR0WithType(&b, S32, 0)});
+
+ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};
+
+TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
+xla::HloModuleConfig hlo_config(
+xla::ProgramShape(computation.proto().host_program_shape()),
+/*ignore_layouts=*/false);
+DebugOptions debug_options = GetDebugOptionsForTest();
+debug_options.set_xla_gpu_enable_address_computation_fusion(false);
+hlo_config.set_debug_options(debug_options);
+TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto(
+computation.proto(), hlo_config));
+
+debug_options.set_xla_gpu_enable_address_computation_fusion(true);
+hlo_config.set_debug_options(debug_options);
+TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto(
+computation.proto(), hlo_config));
+
+AddressComputationFusionRewriter pass(PLATFORM);
+TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get()));
+EXPECT_TRUE(changed);
+
+EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt),
+error_spec,
+/*run_hlo_passes=*/false));
+}
+
+TEST_F(AddressComputationFusionTest, CustomCallDUSTuple) {
+XlaBuilder b(TestName());
+auto big_buffer1 =
+Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10, 128}), "p0");
+auto big_buffer2 =
+Parameter(&b, 1, ShapeUtil::MakeShape(F32, {10, 256}), "p1");
+auto custom_call = CustomCall(
+&b, "__xla_test$$subbuffers2", /*operands=*/
 {
 Tuple(&b,
 {
@@ -2630,24 +2683,60 @@ TEST_F(AddressComputationFusionTest, Test) {
 }),
 Tuple(&b,
 {
-Slice(Broadcast(ConstantR0WithType(&b, F32, 3), {1024}),
-{512}, {512 + 256}, {1}),
+Broadcast(ConstantR0WithType(&b, F32, 3), {1024}),
 Broadcast(ConstantR0WithType(&b, F32, 4), {8}),
 }),
 Slice(Broadcast(ConstantR0WithType(&b, F32, 5), {8, 8}), {0, 0},
 {4, 8}, {1, 1}),
-Tuple(&b,
-{
-Tuple(&b,
-{
-Broadcast(ConstantR0WithType(&b, F32, 6), {32}),
-Broadcast(ConstantR0WithType(&b, F32, 7), {64}),
-}),
-}),
+Tuple(
+&b,
+{
+Tuple(
+&b,
+{
+Broadcast(ConstantR0WithType(&b, F32, 6), {3, 128}),
+DynamicSlice(Broadcast(ConstantR0WithType(&b, F32, 7),
+{8, 128}),
+{ConstantR0WithType(&b, S32, 2),
+ConstantR0WithType(&b, S32, 0)},
+{5, 128}),
+}),
+}),
 },
-ShapeUtil::MakeNil(),
-// ShapeUtil::MakeShape(F32, {128}),
-/*opaque=*/"");
+ShapeUtil::MakeTupleShape({
+ShapeUtil::MakeShape(F32, {8}),
+ShapeUtil::MakeTupleShape({
+ShapeUtil::MakeShape(F32, {128}),
+ShapeUtil::MakeShape(F32, {256}),
+}),
+ShapeUtil::MakeShape(F32, {1024}),
+ShapeUtil::MakeShape(F32, {4, 8}),
+ShapeUtil::MakeTupleShape({
+ShapeUtil::MakeShape(F32, {5, 128}),
+ShapeUtil::MakeShape(F32, {3, 128}),
+}),
+}),
+/*opaque=*/"",
+/*has_side_effect=*/false,
+/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
+/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
+/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
+auto tuple_gte = GetTupleElement(custom_call, 4);
+auto dus1 = DynamicUpdateSlice(
+big_buffer1, GetTupleElement(tuple_gte, 0),
+{ConstantR0WithType(&b, S32, 2), ConstantR0WithType(&b, S32, 0)});
+auto dus2 = DynamicUpdateSlice(
+big_buffer1, GetTupleElement(tuple_gte, 1),
+{ConstantR0WithType(&b, S32, 7), ConstantR0WithType(&b, S32, 0)});
+auto dus3 = DynamicUpdateSlice(
+big_buffer2,
+xla::internal::XlaBuilderFriend::BuildBitcast(
+&b, GetTupleElement(custom_call, 2),
+ShapeUtil::MakeShape(F32, {4, 256})),
+{Parameter(&b, 2, ShapeUtil::MakeShape(S32, {}), "start0"),
+Parameter(&b, 3, ShapeUtil::MakeShape(S32, {}), "start1")});
+Tuple(&b, {dus1, dus2, dus3});
+
 ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3};

 TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
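
The shapes and byte counts in the tests above follow from simple slice arithmetic: a `Slice` over `[start, limit)` with stride 1 selects `limit - start` elements per dimension, and XLA's `DynamicUpdateSlice` clamps each start index so that the update fits inside the operand. The following is a minimal standalone sketch of that arithmetic, not XLA code; `SliceExtent` and `ClampDusStart` are hypothetical helper names introduced here only for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper (not an XLA API): number of elements selected along one
// dimension by a slice over [start, limit) with the given positive stride.
int64_t SliceExtent(int64_t start, int64_t limit, int64_t stride) {
  assert(stride > 0 && start <= limit);
  return (limit - start + stride - 1) / stride;
}

// Hypothetical helper (not an XLA API): start index clamped into
// [0, operand_dim - update_dim], so the update never runs past the operand.
int64_t ClampDusStart(int64_t start, int64_t operand_dim, int64_t update_dim) {
  assert(update_dim <= operand_dim);
  return std::clamp<int64_t>(start, 0, operand_dim - update_dim);
}
```

For `CustomCallLegacyAPI`, `SliceExtent(128, 4 * 128, 1)` is `3 * 128`, which matches the `sizeof(float) * 3 * 128` byte count passed to `gpuMemcpyAsync`. In `CustomCallDUS`, a 3-row update at row 4 of a 10-row operand is already in range, so `ClampDusStart(4, 10, 3)` leaves the start index unchanged.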
20 changes: 19 additions & 1 deletion xla/service/gpu/fusions/custom.cc
@@ -207,8 +207,26 @@ absl::StatusOr<BufferAllocation::Slice> GetResultSlice(
 const HloInstruction& fusion_instr, const HloInstruction& start_instr,
 std::vector<HloInstruction*>& slice_instrs, const ShapeIndex& shape_idx,
 unsigned arg_idx) {
+auto* start = const_cast<HloInstruction*>(&start_instr);
+// Walk through ShapeIndex to find the real "user" (i.e. not get-tuple-element
+// user). Otherwise one sliced element will mark all buffers of all other
+// elements "sliced" too.
+if (start->shape().IsTuple()) {
+for (auto idx : shape_idx) {
+std::vector<HloGetTupleElementInstruction*> gte_users(
+start->shape().tuple_shapes_size(), nullptr);
+for (auto* user : start->users())
+if (auto* gte = DynCast<HloGetTupleElementInstruction>(user))
+gte_users[gte->tuple_index()] = gte;
+
+start = static_cast<HloInstruction*>(gte_users[idx]);
+if (start == nullptr)
+return GetAllocationSlice(buffer_assignment, &fusion_instr, shape_idx);
+}
+}
+
 auto slice_adaptor = HloFindIf(
-{HloInstructionAdaptor(start_instr)}, adaptor,
+{HloInstructionAdaptor(*start)}, adaptor,
 [](auto node) { return node.opcode() == HloOpcode::kDynamicUpdateSlice; },
 /*visit_operands=*/false);
 if (slice_adaptor.has_value()) {
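
The walk added to `GetResultSlice` in this hunk can be illustrated in isolation. The sketch below is a toy model, assuming a simplified `Node` type in place of `HloInstruction`; `Node` and `WalkShapeIndex` are hypothetical names, not part of XLA. It follows each index of the shape index through the matching get-tuple-element user and returns `nullptr` when a user is missing, which is the case where the real code falls back to `GetAllocationSlice` for the whole fusion.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-in for an HLO instruction; NOT the XLA class.
struct Node {
  bool is_gte = false;          // true if this node is a get-tuple-element
  std::size_t tuple_index = 0;  // meaningful only when is_gte is true
  std::size_t tuple_size = 0;   // number of tuple elements this node produces
  std::vector<Node*> users;     // nodes that consume this node's result
};

// Follows shape_idx from start through get-tuple-element users.
// Returns the node reached, or nullptr if some index has no GTE user.
Node* WalkShapeIndex(Node* start, const std::vector<std::size_t>& shape_idx) {
  for (std::size_t idx : shape_idx) {
    assert(idx < start->tuple_size);
    Node* next = nullptr;
    for (Node* user : start->users) {
      if (user->is_gte && user->tuple_index == idx) next = user;
    }
    if (next == nullptr) return nullptr;  // no GTE user for this element
    start = next;
  }
  return start;
}
```

Starting the search for a dynamic-update-slice user from the node returned here, rather than from the tuple-producing instruction itself, is what keeps one sliced tuple element from marking its sibling elements as sliced.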