diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc index cc407ea2bea75b..7e1bd0e87e5af4 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc @@ -920,7 +920,8 @@ HloValueSemanticsPropagation::ComputeSemanticsFromStaticAndOther( instruction->opcode() == HloOpcode::kConvolution; if (is_dot_or_convolution && other_semantics.label() == HloValueSemanticLabel::kActivationGradient) { - return CreateGradientSemantics(instruction); + return MaybeCreateGradientSemantics( + instruction, HloValueSemanticLabel::kActivationGradient); } return CopySemantics(other_semantics); } @@ -939,8 +940,9 @@ HloValueSemanticsPropagation::ComputeSemanticsFromRandomAndOther( } StatusOr -HloValueSemanticsPropagation::CreateGradientSemantics( - HloInstruction* gradient_candidate) const { +HloValueSemanticsPropagation::MaybeCreateGradientSemantics( + HloInstruction* gradient_candidate, + HloValueSemanticLabel fallback_label) const { const EinsumDepthMap& einsum_depth_map = analysis_->GetEinsumDepthMap(); auto depth_iter = einsum_depth_map.find(gradient_candidate); CHECK(depth_iter != einsum_depth_map.end()); @@ -956,8 +958,7 @@ HloValueSemanticsPropagation::CreateGradientSemantics( return HloValueSemantics(HloValueSemanticLabel::kWeightGradient, {gradient_candidate, {}}); } - return HloValueSemantics(HloValueSemanticLabel::kActivationGradient, - {gradient_candidate, {}}); + return HloValueSemantics(fallback_label, {gradient_candidate, {}}); } StatusOr @@ -972,6 +973,9 @@ HloValueSemanticsPropagation::ComputeSemanticsFromWeightAndOther( instruction->opcode() == HloOpcode::kConvolution; if (other_semantics.label() == HloValueSemanticLabel::kWeight) { if (!is_dot_or_convolution) { + if (weight_semantics.origin() == other_semantics.origin()) { + return CopySemantics(other_semantics); + } return CopySemanticsWithNewOrigin(other_semantics, instruction); } return HloValueSemantics(HloValueSemanticLabel::kActivation, @@ -988,7 +992,8 @@ HloValueSemanticsPropagation::ComputeSemanticsFromWeightAndOther( // operand. if (OriginDependsOn(other_semantics, weight_semantics.origin(), /*recursive=*/true)) { - return CreateGradientSemantics(instruction); + return MaybeCreateGradientSemantics( + instruction, HloValueSemanticLabel::kActivationGradient); } return CopySemanticsWithNewOrigin(other_semantics, instruction); } @@ -997,7 +1002,8 @@ HloValueSemanticsPropagation::ComputeSemanticsFromWeightAndOther( // which produce an Activation. The ActivationGradient to this Activation // could be used in an einsum with one of the Weights to compute // the WeightGradient for the other Weight. - return CreateGradientSemantics(instruction); + return MaybeCreateGradientSemantics( + instruction, HloValueSemanticLabel::kActivationGradient); } CHECK(other_semantics.label() == HloValueSemanticLabel::kWeightGradient); return CopySemantics(other_semantics); @@ -1015,14 +1021,16 @@ HloValueSemanticsPropagation::ComputeSemanticsFromActivationAndOther( bool is_dot_or_convolution = instruction->opcode() == HloOpcode::kDot || instruction->opcode() == HloOpcode::kConvolution; if (!is_dot_or_convolution) { + if (activation_semantics.origin() == other_semantics.origin()) { + return CopySemantics(other_semantics); + } return CopySemanticsWithNewOrigin(other_semantics, instruction); } if (other_semantics.label() == HloValueSemanticLabel::kActivation) { // Like said above, since loss is classified as Activation, an einsum - // between an Activation X and an Activation Y could be WeightGradient or - // even ActivationGradient when either X or Y is the loss. This case is - // different from other Activation einsums because there must a dependency - // between X and Y. + // between an Activation X and an Activation Y could be WeightGradient if + // either X or Y is the loss. This case is different from other Activation + // einsums because there must a dependency between X and Y. bool other_depends_on_activation = OriginDependsOn( other_semantics, activation_semantics.origin(), /*recursive=*/true); bool activation_depends_on_other = @@ -1032,14 +1040,19 @@ HloValueSemanticsPropagation::ComputeSemanticsFromActivationAndOther( // If there is no dependency between the two Activations, the output must // be an Activation. if (other_depends_on_activation || activation_depends_on_other) { - return CreateGradientSemantics(instruction); + // We check if the einsum is actually weight gradient. If it is not, fall + // back to activation, since we expect the loss to be computed from an + // activation-weight einsum. + return MaybeCreateGradientSemantics(instruction, + HloValueSemanticLabel::kActivation); } return CopySemanticsWithNewOrigin(other_semantics, instruction); } if (other_semantics.label() == HloValueSemanticLabel::kActivationGradient) { // An Activation-ActivationGradient einsum could be computing // WeightGradient or ActivationGradient. - return CreateGradientSemantics(instruction); + return MaybeCreateGradientSemantics( + instruction, HloValueSemanticLabel::kActivationGradient); } CHECK(other_semantics.label() == HloValueSemanticLabel::kWeightGradient) << "instruction: " << instruction->ToString() @@ -1407,16 +1420,7 @@ Status HloValueSemanticsPropagation::HandleDynamicSlice( const HloInstruction* dynamic_slice_operand = dynamic_slice->operand(0); const HloValueSemantics* operand_semantics = analysis_->GetSemantics(dynamic_slice_operand); - const HloValueSemantics* semantics = nullptr; - if (operand_semantics->label() == HloValueSemanticLabel::kStatic || - operand_semantics->label() == HloValueSemanticLabel::kRandom || - operand_semantics->label() == HloValueSemanticLabel::kWeight) { - semantics = analysis_->NewHloValueSemantics(operand_semantics->label(), - {dynamic_slice, {}}); - } else { - HloValueSemantics semantics_value = CopySemantics(*operand_semantics); - semantics = AddSemantics(semantics_value); - } + const HloValueSemantics* semantics = AddSemantics(*operand_semantics); ShapeTree semantics_shape_tree( dynamic_slice->shape(), semantics); analysis_->SetHloValueSemantics(dynamic_slice, semantics_shape_tree); diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.h b/third_party/xla/xla/service/hlo_value_semantics_analysis.h index fa4d14ad829898..634b13f21ed65c 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.h +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.h @@ -346,8 +346,9 @@ class HloValueSemanticsPropagation : public DfsHloVisitorWithDefault { bool OriginDependsOn(const HloValueSemantics& semantics, const HloPosition& origin_dependence, bool recursive = false) const; - StatusOr CreateGradientSemantics( - HloInstruction* gradient_candidate) const; + StatusOr MaybeCreateGradientSemantics( + HloInstruction* gradient_candidate, + HloValueSemanticLabel fallback_label) const; StatusOr ComputeSemanticsFromStaticAndOther( const HloValueSemantics& static_semantics, const HloValueSemantics& other_semantics,