[HloValueSemanticsAnalysis] Fix activation gradient classification for MoE.

The idea is that activation-activation einsums are (almost) never activation gradients.
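To make the rule concrete, here is a minimal, self-contained C++ sketch of the new classification helper. The names are hypothetical stand-ins, not the XLA API, and the depth == 0 test is my assumption (the actual condition sits in lines collapsed out of the hunks below); the point is that the non-weight-gradient case now returns a caller-supplied fallback label instead of a hard-coded ActivationGradient.

    #include <iostream>

    // Hypothetical stand-in for XLA's HloValueSemanticLabel.
    enum class Label { kActivation, kActivationGradient, kWeightGradient };

    // Sketch of MaybeCreateGradientSemantics: einsum depth 0 is assumed to
    // mark a weight gradient; anything else gets the call site's fallback.
    Label MaybeCreateGradientLabel(int einsum_depth, Label fallback) {
      if (einsum_depth == 0) return Label::kWeightGradient;
      return fallback;
    }

    int main() {
      // Activation-ActivationGradient einsum: falls back to
      // ActivationGradient, as before this change.
      Label a = MaybeCreateGradientLabel(/*einsum_depth=*/2,
                                         Label::kActivationGradient);
      // Activation-Activation einsum (e.g. MoE routing): now falls back to
      // Activation instead of ActivationGradient.
      Label b = MaybeCreateGradientLabel(/*einsum_depth=*/2,
                                         Label::kActivation);
      std::cout << (a == Label::kActivationGradient) << " "
                << (b == Label::kActivation) << "\n";  // prints: 1 1
    }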

PiperOrigin-RevId: 587143854
jinliangwei authored and tensorflower-gardener committed Dec 1, 2023
1 parent aef6b65 commit f01a271
Showing 2 changed files with 30 additions and 25 deletions.
third_party/xla/xla/service/hlo_value_semantics_analysis.cc (27 additions, 23 deletions)
@@ -920,7 +920,8 @@ HloValueSemanticsPropagation::ComputeSemanticsFromStaticAndOther(
                                instruction->opcode() == HloOpcode::kConvolution;
   if (is_dot_or_convolution &&
       other_semantics.label() == HloValueSemanticLabel::kActivationGradient) {
-    return CreateGradientSemantics(instruction);
+    return MaybeCreateGradientSemantics(
+        instruction, HloValueSemanticLabel::kActivationGradient);
   }
   return CopySemantics(other_semantics);
 }
@@ -939,8 +940,9 @@ HloValueSemanticsPropagation::ComputeSemanticsFromRandomAndOther(
 }
 
 StatusOr<HloValueSemantics>
-HloValueSemanticsPropagation::CreateGradientSemantics(
-    HloInstruction* gradient_candidate) const {
+HloValueSemanticsPropagation::MaybeCreateGradientSemantics(
+    HloInstruction* gradient_candidate,
+    HloValueSemanticLabel fallback_label) const {
   const EinsumDepthMap& einsum_depth_map = analysis_->GetEinsumDepthMap();
   auto depth_iter = einsum_depth_map.find(gradient_candidate);
   CHECK(depth_iter != einsum_depth_map.end());
@@ -956,8 +958,7 @@ HloValueSemanticsPropagation::CreateGradientSemantics(
     return HloValueSemantics(HloValueSemanticLabel::kWeightGradient,
                              {gradient_candidate, {}});
   }
-  return HloValueSemantics(HloValueSemanticLabel::kActivationGradient,
-                           {gradient_candidate, {}});
+  return HloValueSemantics(fallback_label, {gradient_candidate, {}});
 }
 
 StatusOr<HloValueSemantics>
@@ -972,6 +973,9 @@ HloValueSemanticsPropagation::ComputeSemanticsFromWeightAndOther(
                                instruction->opcode() == HloOpcode::kConvolution;
   if (other_semantics.label() == HloValueSemanticLabel::kWeight) {
     if (!is_dot_or_convolution) {
+      if (weight_semantics.origin() == other_semantics.origin()) {
+        return CopySemantics(other_semantics);
+      }
       return CopySemanticsWithNewOrigin(other_semantics, instruction);
     }
     return HloValueSemantics(HloValueSemanticLabel::kActivation,
@@ -988,7 +992,8 @@ HloValueSemanticsPropagation::ComputeSemanticsFromWeightAndOther(
     // operand.
     if (OriginDependsOn(other_semantics, weight_semantics.origin(),
                         /*recursive=*/true)) {
-      return CreateGradientSemantics(instruction);
+      return MaybeCreateGradientSemantics(
+          instruction, HloValueSemanticLabel::kActivationGradient);
     }
     return CopySemanticsWithNewOrigin(other_semantics, instruction);
   }
@@ -997,7 +1002,8 @@ HloValueSemanticsPropagation::ComputeSemanticsFromWeightAndOther(
     // which produce an Activation. The ActivationGradient to this Activation
     // could be used in an einsum with one of the Weights to compute
     // the WeightGradient for the other Weight.
-    return CreateGradientSemantics(instruction);
+    return MaybeCreateGradientSemantics(
+        instruction, HloValueSemanticLabel::kActivationGradient);
   }
   CHECK(other_semantics.label() == HloValueSemanticLabel::kWeightGradient);
   return CopySemantics(other_semantics);
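(A worked instance of the comment above, my own example rather than part of the change: for a two-layer network with h = x * W1 and y = h * W2, backprop gives dL/dh = (dL/dy) * W2^T, an einsum between a Weight and an ActivationGradient whose result is itself an ActivationGradient; that result then feeds dL/dW1 = x^T * (dL/dh), the WeightGradient for the other weight.)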
@@ -1015,14 +1021,16 @@ HloValueSemanticsPropagation::ComputeSemanticsFromActivationAndOther(
   bool is_dot_or_convolution = instruction->opcode() == HloOpcode::kDot ||
                                instruction->opcode() == HloOpcode::kConvolution;
   if (!is_dot_or_convolution) {
+    if (activation_semantics.origin() == other_semantics.origin()) {
+      return CopySemantics(other_semantics);
+    }
     return CopySemanticsWithNewOrigin(other_semantics, instruction);
   }
   if (other_semantics.label() == HloValueSemanticLabel::kActivation) {
     // Like said above, since loss is classified as Activation, an einsum
-    // between an Activation X and an Activation Y could be WeightGradient or
-    // even ActivationGradient when either X or Y is the loss. This case is
-    // different from other Activation einsums because there must a dependency
-    // between X and Y.
+    // between an Activation X and an Activation Y could be WeightGradient if
+    // either X or Y is the loss. This case is different from other Activation
+    // einsums because there must be a dependency between X and Y.
     bool other_depends_on_activation = OriginDependsOn(
         other_semantics, activation_semantics.origin(), /*recursive=*/true);
     bool activation_depends_on_other =
@@ -1032,14 +1040,19 @@ HloValueSemanticsPropagation::ComputeSemanticsFromActivationAndOther(
     // If there is no dependency between the two Activations, the output must
     // be an Activation.
     if (other_depends_on_activation || activation_depends_on_other) {
-      return CreateGradientSemantics(instruction);
+      // We check if the einsum is actually a weight gradient. If it is not,
+      // fall back to activation, since we expect the loss to be computed from
+      // an activation-weight einsum.
+      return MaybeCreateGradientSemantics(instruction,
+                                          HloValueSemanticLabel::kActivation);
     }
     return CopySemanticsWithNewOrigin(other_semantics, instruction);
   }
   if (other_semantics.label() == HloValueSemanticLabel::kActivationGradient) {
     // An Activation-ActivationGradient einsum could be computing
     // WeightGradient or ActivationGradient.
-    return CreateGradientSemantics(instruction);
+    return MaybeCreateGradientSemantics(
+        instruction, HloValueSemanticLabel::kActivationGradient);
   }
   CHECK(other_semantics.label() == HloValueSemanticLabel::kWeightGradient)
       << "instruction: " << instruction->ToString()
@@ -1407,16 +1420,7 @@ Status HloValueSemanticsPropagation::HandleDynamicSlice(
   const HloInstruction* dynamic_slice_operand = dynamic_slice->operand(0);
   const HloValueSemantics* operand_semantics =
       analysis_->GetSemantics(dynamic_slice_operand);
-  const HloValueSemantics* semantics = nullptr;
-  if (operand_semantics->label() == HloValueSemanticLabel::kStatic ||
-      operand_semantics->label() == HloValueSemanticLabel::kRandom ||
-      operand_semantics->label() == HloValueSemanticLabel::kWeight) {
-    semantics = analysis_->NewHloValueSemantics(operand_semantics->label(),
-                                                {dynamic_slice, {}});
-  } else {
-    HloValueSemantics semantics_value = CopySemantics(*operand_semantics);
-    semantics = AddSemantics(semantics_value);
-  }
+  const HloValueSemantics* semantics = AddSemantics(*operand_semantics);
   ShapeTree<const HloValueSemantics*> semantics_shape_tree(
       dynamic_slice->shape(), semantics);
   analysis_->SetHloValueSemantics(dynamic_slice, semantics_shape_tree);
third_party/xla/xla/service/hlo_value_semantics_analysis.h (3 additions, 2 deletions)
@@ -346,8 +346,9 @@ class HloValueSemanticsPropagation : public DfsHloVisitorWithDefault {
   bool OriginDependsOn(const HloValueSemantics& semantics,
                        const HloPosition& origin_dependence,
                        bool recursive = false) const;
-  StatusOr<HloValueSemantics> CreateGradientSemantics(
-      HloInstruction* gradient_candidate) const;
+  StatusOr<HloValueSemantics> MaybeCreateGradientSemantics(
+      HloInstruction* gradient_candidate,
+      HloValueSemanticLabel fallback_label) const;
   StatusOr<HloValueSemantics> ComputeSemanticsFromStaticAndOther(
       const HloValueSemantics& static_semantics,
       const HloValueSemantics& other_semantics,
