Skip to content

Commit

Permalink
Added check for classes that have mutually-incompatible base classes …
Browse files Browse the repository at this point in the history
…due to generic type argument mismatches. This addresses #5748. (#5762)

Co-authored-by: Eric Traut <[email protected]>
  • Loading branch information
erictraut and msfterictraut authored Aug 18, 2023
1 parent b9f4615 commit bd54869
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 2 deletions.
93 changes: 93 additions & 0 deletions packages/pyright-internal/src/analyzer/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ import {
ClassMember,
ClassMemberLookupFlags,
applySolvedTypeVars,
buildTypeVarContextFromSpecializedClass,
convertToInstance,
derivesFromAnyOrUnknown,
derivesFromClassRecursive,
Expand Down Expand Up @@ -382,6 +383,8 @@ export class Checker extends ParseTreeWalker {
this._validateSlotsClassVarConflict(classTypeResult.classType);
}

this._validateMultipleInheritanceBaseClasses(classTypeResult.classType, node.name);

this._validateMultipleInheritanceCompatibility(classTypeResult.classType, node.name);

this._validateConstructorConsistency(classTypeResult.classType);
Expand Down Expand Up @@ -5145,6 +5148,96 @@ export class Checker extends ParseTreeWalker {
}
}

// Verifies that classes that have more than one base class do not have
// have conflicting type arguments.
private _validateMultipleInheritanceBaseClasses(classType: ClassType, errorNode: ParseNode) {
// Skip this check if the class has only one base class or one or more
// of the base classes are Any.
const filteredBaseClasses: ClassType[] = [];
for (const baseClass of classType.details.baseClasses) {
if (!isClass(baseClass)) {
return;
}

if (!ClassType.isBuiltIn(baseClass, ['Generic', 'Protocol', 'object'])) {
filteredBaseClasses.push(baseClass);
}
}

if (filteredBaseClasses.length < 2) {
return;
}

const diagAddendum = new DiagnosticAddendum();

for (const baseClass of filteredBaseClasses) {
const typeVarContext = buildTypeVarContextFromSpecializedClass(baseClass);

for (const baseClassMroClass of baseClass.details.mro) {
// There's no need to check for conflicts if this class isn't generic.
if (isClass(baseClassMroClass) && baseClassMroClass.details.typeParameters.length > 0) {
const specializedBaseClassMroClass = applySolvedTypeVars(
baseClassMroClass,
typeVarContext
) as ClassType;

// Find the corresponding class in the derived class's MRO list.
const matchingMroClass = classType.details.mro.find(
(mroClass) =>
isClass(mroClass) && ClassType.isSameGenericClass(mroClass, specializedBaseClassMroClass)
);

if (matchingMroClass && isInstantiableClass(matchingMroClass)) {
const matchingMroObject = ClassType.cloneAsInstance(matchingMroClass);
const baseClassMroObject = ClassType.cloneAsInstance(specializedBaseClassMroClass);

// If the types match exactly, we can shortcut the remainder of the MRO chain.
// if (isTypeSame(matchingMroObject, baseClassMroObject)) {
// break;
// }

if (!this._evaluator.assignType(matchingMroObject, baseClassMroObject)) {
const diag = new DiagnosticAddendum();
const baseClassObject = convertToInstance(baseClass);

if (isTypeSame(baseClassObject, baseClassMroObject)) {
diag.addMessage(
Localizer.DiagnosticAddendum.baseClassIncompatible().format({
baseClass: this._evaluator.printType(baseClassObject),
type: this._evaluator.printType(matchingMroObject),
})
);
} else {
diag.addMessage(
Localizer.DiagnosticAddendum.baseClassIncompatibleSubclass().format({
baseClass: this._evaluator.printType(baseClassObject),
subclass: this._evaluator.printType(baseClassMroObject),
type: this._evaluator.printType(matchingMroObject),
})
);
}

diagAddendum.addAddendum(diag);

// Break out of the inner loop so we don't report any redundant errors for this base class.
break;
}
}
}
}
}

if (!diagAddendum.isEmpty()) {
this._evaluator.addDiagnostic(
this._fileInfo.diagnosticRuleSet.reportGeneralTypeIssues,
DiagnosticRule.reportGeneralTypeIssues,
Localizer.Diagnostic.baseClassIncompatible().format({ type: classType.details.name }) +
diagAddendum.getString(),
errorNode
);
}
}

// Validates that any methods and variables in multiple base classes are
// compatible with each other.
private _validateMultipleInheritanceCompatibility(classType: ClassType, errorNode: ParseNode) {
Expand Down
12 changes: 11 additions & 1 deletion packages/pyright-internal/src/localization/localize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,11 @@ export namespace Localizer {
export const awaitNotInAsync = () => getRawString('Diagnostic.awaitNotInAsync');
export const backticksIllegal = () => getRawString('Diagnostic.backticksIllegal');
export const baseClassCircular = () => getRawString('Diagnostic.baseClassCircular');
export const baseClassInvalid = () => getRawString('Diagnostic.baseClassInvalid');
export const baseClassFinal = () =>
new ParameterizedString<{ type: string }>(getRawString('Diagnostic.baseClassFinal'));
export const baseClassIncompatible = () =>
new ParameterizedString<{ type: string }>(getRawString('Diagnostic.baseClassIncompatible'));
export const baseClassInvalid = () => getRawString('Diagnostic.baseClassInvalid');
export const baseClassMethodTypeIncompatible = () =>
new ParameterizedString<{ classType: string; name: string }>(
getRawString('Diagnostic.baseClassMethodTypeIncompatible')
Expand Down Expand Up @@ -1116,6 +1118,14 @@ export namespace Localizer {
new ParameterizedString<{ types: string }>(getRawString('DiagnosticAddendum.argumentTypes'));
export const assignToNone = () => getRawString('DiagnosticAddendum.assignToNone');
export const asyncHelp = () => getRawString('DiagnosticAddendum.asyncHelp');
export const baseClassIncompatible = () =>
new ParameterizedString<{ baseClass: string; type: string }>(
getRawString('DiagnosticAddendum.baseClassIncompatible')
);
export const baseClassIncompatibleSubclass = () =>
new ParameterizedString<{ baseClass: string; subclass: string; type: string }>(
getRawString('DiagnosticAddendum.baseClassIncompatibleSubclass')
);
export const baseClassOverriddenType = () =>
new ParameterizedString<{ baseClass: string; type: string }>(
getRawString('DiagnosticAddendum.baseClassOverriddenType')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"awaitNotInAsync": "\"await\" allowed only within async function",
"backticksIllegal": "Expressions surrounded by backticks are not supported in Python 3.x; use repr instead",
"baseClassCircular": "Class cannot derive from itself",
"baseClassIncompatible": "Base classes of {type} are mutually incompatible",
"baseClassFinal": "Base class \"{type}\" is marked final and cannot be subclassed",
"baseClassInvalid": "Argument to class must be a base class",
"baseClassMethodTypeIncompatible": "Base classes for class \"{classType}\" define method \"{name}\" in incompatible way",
Expand Down Expand Up @@ -578,6 +579,8 @@
"argumentTypes": "Argument types: ({types})",
"assignToNone": "Type cannot be assigned to type \"None\"",
"asyncHelp": "Did you mean \"async with\"?",
"baseClassIncompatible": "Base class \"{baseClass}\" is incompatible with type \"{type}\"",
"baseClassIncompatibleSubclass": "Base class \"{baseClass}\" derives from \"{subclass}\" which is incompatible with type \"{type}\"",
"baseClassOverriddenType": "Base class \"{baseClass}\" provides type \"{type}\", which is overridden",
"baseClassOverridesType": "Base class \"{baseClass}\" overrides with type \"{type}\"",
"conditionalRequiresBool": "Method __bool__ for type \"{operandType}\" returns type \"{boolReturnType}\" rather than \"bool\"",
Expand Down
45 changes: 45 additions & 0 deletions packages/pyright-internal/src/tests/samples/classes11.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# This sample tests the detection of mutually-incompatible base classes
# in classes that use multiple inheritance.

from typing import Collection, Mapping, Sequence, TypeVar


# This should generate an error.
class A(Mapping[str, int], Collection[int]):
...


# This should generate an error.
class B(Mapping[str, int], Sequence[int]):
...


# This should generate an error.
class C(Sequence[int], Mapping[str, int]):
...


class D(Sequence[float], Mapping[float, int]):
...


class E(Sequence[float], Mapping[int, int]):
...


# This should generate an error.
class F(Mapping[int, int], Sequence[float]):
...


T = TypeVar("T")
S = TypeVar("S")


class G(Mapping[T, S], Collection[T]):
...


# This should generate an error.
class H(Mapping[T, S], Collection[S]):
...
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/tests/samples/classes7.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class BaseClass(Generic[T]):
pass


IntBaseClass = BaseClass[int]
IntBaseClass = BaseClass[float]


# This should generate an error because the same
Expand Down
6 changes: 6 additions & 0 deletions packages/pyright-internal/src/tests/typeEvaluator3.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,12 @@ test('Classes10', () => {
TestUtils.validateResults(analysisResults, 0);
});

test('Classes11', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['classes11.py']);

TestUtils.validateResults(analysisResults, 5);
});

test('Methods1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['methods1.py']);

Expand Down

0 comments on commit bd54869

Please sign in to comment.