Skip to content

Commit

Permalink
Fixed a bug that resulted in incorrect type narrowing for sequence pa…
Browse files Browse the repository at this point in the history
…tterns when the subject expression contains a tuple with an unbounded component. This addresses #7117. (#7156)
  • Loading branch information
erictraut authored Jan 29, 2024
1 parent c8c8cea commit 6457531
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 78 deletions.
185 changes: 109 additions & 76 deletions packages/pyright-internal/src/analyzer/patternMatching.ts
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ function narrowTypeBasedOnSequencePattern(
}

// If this is a supertype of Sequence, we can narrow it to a Sequence type.
if (entry.isPotentialNoMatch) {
if (entry.isPotentialNoMatch && !entry.isTuple) {
const sequenceType = evaluator.getTypingType(pattern, 'Sequence');
if (sequenceType && isInstantiableClass(sequenceType)) {
let typeArgType = evaluator.stripLiteralValue(combineTypes(narrowedEntryTypes));
Expand Down Expand Up @@ -1230,7 +1230,6 @@ function getSequencePatternInfo(
const patternEntryCount = pattern.entries.length;
const patternStarEntryIndex = pattern.starEntryIndex;
const sequenceInfo: SequencePatternInfo[] = [];
const minPatternEntryCount = patternStarEntryIndex === undefined ? patternEntryCount : patternEntryCount - 1;

doForEachSubtype(type, (subtype) => {
const concreteSubtype = evaluator.makeTopLevelTypeVarsConcrete(subtype);
Expand Down Expand Up @@ -1274,91 +1273,125 @@ function getSequencePatternInfo(
const specializedSequence = partiallySpecializeType(mroClassToSpecialize, concreteSubtype) as ClassType;

if (isTupleClass(specializedSequence)) {
if (specializedSequence.tupleTypeArguments) {
if (isUnboundedTupleClass(specializedSequence)) {
sequenceInfo.push({
subtype,
entryTypes: [combineTypes(specializedSequence.tupleTypeArguments.map((t) => t.type))],
isIndeterminateLength: true,
isTuple: true,
isDefiniteNoMatch: false,
});
pushedEntry = true;
} else {
const tupleIndeterminateIndex = specializedSequence.tupleTypeArguments.findIndex(
const typeArgs = specializedSequence.tupleTypeArguments ?? [
{ type: UnknownType.create(), isUnbounded: true },
];

const tupleIndeterminateIndex = typeArgs.findIndex(
(t) => t.isUnbounded || isUnpackedVariadicTypeVar(t.type)
);

// If the tuple contains an indeterminate entry, expand or remove that
// entry to match the length of the pattern if possible.
if (tupleIndeterminateIndex >= 0) {
while (typeArgs.length < patternEntryCount) {
typeArgs.splice(tupleIndeterminateIndex, 0, typeArgs[tupleIndeterminateIndex]);
}

if (typeArgs.length > patternEntryCount) {
typeArgs.splice(tupleIndeterminateIndex, 1);
}
}

// If the pattern contains a star entry and there are too many entries
// in the tuple, we can collapse some of them into the star entry.
if (
patternStarEntryIndex !== undefined &&
typeArgs.length >= 2 &&
typeArgs.length > patternEntryCount
) {
const entriesToCombine = typeArgs.length - patternEntryCount + 1;
const removedEntries = typeArgs.splice(patternStarEntryIndex, entriesToCombine);
typeArgs.splice(patternStarEntryIndex, 0, {
type: combineTypes(removedEntries.map((t) => t.type)),
isUnbounded: removedEntries.every(
(t) => t.isUnbounded || isUnpackedVariadicTypeVar(t.type)
);
let minTupleLength = specializedSequence.tupleTypeArguments.length;
if (tupleIndeterminateIndex >= 0) {
minTupleLength -= 1;
}
),
});
}

if (minTupleLength >= minPatternEntryCount) {
let isDefiniteNoMatch = false;
const leftLength = Math.min(
patternStarEntryIndex !== undefined ? patternStarEntryIndex : patternEntryCount,
tupleIndeterminateIndex >= 0
? tupleIndeterminateIndex
: specializedSequence.tupleTypeArguments.length
);
if (typeArgs.length === patternEntryCount) {
let isDefiniteNoMatch = false;
let isPotentialNoMatch = tupleIndeterminateIndex >= 0;
if (patternStarEntryIndex !== undefined && patternEntryCount === 1) {
isPotentialNoMatch = false;
}

for (let i = 0; i < leftLength; i++) {
const leftPattern = pattern.entries[i];
const leftType = specializedSequence.tupleTypeArguments[i].type;
const narrowedType = narrowTypeBasedOnPattern(
evaluator,
leftType,
leftPattern,
/* isPositiveTest */ true
);
for (let i = 0; i < patternEntryCount; i++) {
const subPattern = pattern.entries[i];
const typeArg = typeArgs[i].type;
const narrowedType = narrowTypeBasedOnPattern(
evaluator,
typeArg,
subPattern,
/* isPositiveTest */ true
);

if (isNever(narrowedType)) {
isDefiniteNoMatch = true;
}
}
if (isNever(narrowedType)) {
isDefiniteNoMatch = true;
}
}

if (patternStarEntryIndex !== undefined || tupleIndeterminateIndex >= 0) {
const rightLength = Math.min(
patternStarEntryIndex !== undefined
? patternEntryCount - patternStarEntryIndex - 1
: patternEntryCount,
tupleIndeterminateIndex >= 0
? specializedSequence.tupleTypeArguments.length - tupleIndeterminateIndex
: specializedSequence.tupleTypeArguments.length
);
sequenceInfo.push({
subtype,
entryTypes: isDefiniteNoMatch ? [] : typeArgs.map((t) => t.type),
isIndeterminateLength: false,
isTuple: true,
isDefiniteNoMatch,
isPotentialNoMatch,
});
pushedEntry = true;
}

for (let i = 0; i < rightLength; i++) {
const rightPattern = pattern.entries[patternEntryCount - i - 1];
const rightType =
specializedSequence.tupleTypeArguments[
specializedSequence.tupleTypeArguments.length - i - 1
].type;
const narrowedType = narrowTypeBasedOnPattern(
evaluator,
rightType,
rightPattern,
/* isPositiveTest */ true
);
// If the pattern contains a star entry and the pattern associated with
// the star entry is unbounded, we can remove it completely under the
// assumption that the star pattern will capture nothing.
if (patternStarEntryIndex !== undefined) {
let tryMatchStarSequence = false;

if (typeArgs.length === patternEntryCount - 1) {
tryMatchStarSequence = true;
typeArgs.splice(patternStarEntryIndex, 0, {
type: AnyType.create(),
isUnbounded: true,
});
} else if (
typeArgs.length === patternEntryCount &&
typeArgs[patternStarEntryIndex].isUnbounded
) {
tryMatchStarSequence = true;
}

if (isNever(narrowedType)) {
isDefiniteNoMatch = true;
}
}
if (tryMatchStarSequence) {
let isDefiniteNoMatch = false;

for (let i = 0; i < patternEntryCount; i++) {
if (i === patternStarEntryIndex) {
continue;
}

if (patternStarEntryIndex !== undefined || minTupleLength === minPatternEntryCount) {
sequenceInfo.push({
subtype,
entryTypes: isDefiniteNoMatch
? []
: specializedSequence.tupleTypeArguments.map((t) => t.type),
isIndeterminateLength: false,
isTuple: true,
isDefiniteNoMatch,
});
pushedEntry = true;
const subPattern = pattern.entries[i];
const typeArg = typeArgs[i].type;
const narrowedType = narrowTypeBasedOnPattern(
evaluator,
typeArg,
subPattern,
/* isPositiveTest */ true
);

if (isNever(narrowedType)) {
isDefiniteNoMatch = true;
}
}

sequenceInfo.push({
subtype,
entryTypes: isDefiniteNoMatch ? [] : typeArgs.map((t) => t.type),
isIndeterminateLength: false,
isTuple: true,
isDefiniteNoMatch,
});
pushedEntry = true;
}
}
} else {
Expand Down
25 changes: 23 additions & 2 deletions packages/pyright-internal/src/tests/samples/matchSequence1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from enum import Enum
from typing import Any, Generic, Iterator, List, Literal, Protocol, Reversible, Sequence, Tuple, TypeVar, Union
from typing_extensions import Unpack

def test_unknown(value_to_match):
match value_to_match:
Expand Down Expand Up @@ -200,7 +201,7 @@ def test_union(value_to_match: Union[Tuple[complex, complex], Tuple[int, str, fl

case d1, *d2, d3 if value_to_match[0] == 0:
reveal_type(d1, expected_text="complex | int | str | float | Any")
reveal_type(d2, expected_text="list[str | float] | list[str] | list[float] | list[Any]")
reveal_type(d2, expected_text="list[Any] | list[str | float] | list[str] | list[float]")
reveal_type(d3, expected_text="complex | str | float | Any")
reveal_type(value_to_match, expected_text="Tuple[complex, complex] | Tuple[int, str, float, complex] | List[str] | Tuple[float, ...] | Sequence[Any]")

Expand Down Expand Up @@ -366,7 +367,7 @@ def test_negative_narrowing1(subj: tuple[Literal[0]] | tuple[Literal[1]]):
match subj:
case (1,*a) | (*a):
reveal_type(subj, expected_text="tuple[Literal[1]] | tuple[Literal[0]]")
reveal_type(a, expected_text="list[int]")
reveal_type(a, expected_text="list[Any] | list[int]")

case b:
reveal_type(subj, expected_text="Never")
Expand Down Expand Up @@ -458,3 +459,23 @@ def test_tuple_with_subpattern(
case (MyEnum.C, b):
reveal_type(subj, expected_text="tuple[Literal[MyEnum.C], str]")
reveal_type(b, expected_text="str")


def test_unbounded_tuple(
subj: tuple[int] | tuple[str, str] | tuple[int, Unpack[tuple[str, ...]], complex]
):
match subj:
case (x,):
reveal_type(subj, expected_text="tuple[int]")
reveal_type(x, expected_text="int")

case (x, y):
reveal_type(subj, expected_text="tuple[str, str] | tuple[int, complex]")
reveal_type(x, expected_text="str | int")
reveal_type(y, expected_text="str | complex")

case (x, y, z):
reveal_type(subj, expected_text="tuple[int, str, complex]")
reveal_type(x, expected_text="int")
reveal_type(y, expected_text="str")
reveal_type(z, expected_text="complex")

0 comments on commit 6457531

Please sign in to comment.