diff --git a/packages/pyright-internal/src/analyzer/patternMatching.ts b/packages/pyright-internal/src/analyzer/patternMatching.ts index 507d502bf52a..4664ff77b9db 100644 --- a/packages/pyright-internal/src/analyzer/patternMatching.ts +++ b/packages/pyright-internal/src/analyzer/patternMatching.ts @@ -340,7 +340,7 @@ function narrowTypeBasedOnSequencePattern( } // If this is a supertype of Sequence, we can narrow it to a Sequence type. - if (entry.isPotentialNoMatch) { + if (entry.isPotentialNoMatch && !entry.isTuple) { const sequenceType = evaluator.getTypingType(pattern, 'Sequence'); if (sequenceType && isInstantiableClass(sequenceType)) { let typeArgType = evaluator.stripLiteralValue(combineTypes(narrowedEntryTypes)); @@ -1230,7 +1230,6 @@ function getSequencePatternInfo( const patternEntryCount = pattern.entries.length; const patternStarEntryIndex = pattern.starEntryIndex; const sequenceInfo: SequencePatternInfo[] = []; - const minPatternEntryCount = patternStarEntryIndex === undefined ? patternEntryCount : patternEntryCount - 1; doForEachSubtype(type, (subtype) => { const concreteSubtype = evaluator.makeTopLevelTypeVarsConcrete(subtype); @@ -1274,91 +1273,125 @@ function getSequencePatternInfo( const specializedSequence = partiallySpecializeType(mroClassToSpecialize, concreteSubtype) as ClassType; if (isTupleClass(specializedSequence)) { - if (specializedSequence.tupleTypeArguments) { - if (isUnboundedTupleClass(specializedSequence)) { - sequenceInfo.push({ - subtype, - entryTypes: [combineTypes(specializedSequence.tupleTypeArguments.map((t) => t.type))], - isIndeterminateLength: true, - isTuple: true, - isDefiniteNoMatch: false, - }); - pushedEntry = true; - } else { - const tupleIndeterminateIndex = specializedSequence.tupleTypeArguments.findIndex( + const typeArgs = specializedSequence.tupleTypeArguments ?? [ + { type: UnknownType.create(), isUnbounded: true }, + ]; + + const tupleIndeterminateIndex = typeArgs.findIndex( + (t) => t.isUnbounded || isUnpackedVariadicTypeVar(t.type) + ); + + // If the tuple contains an indeterminate entry, expand or remove that + // entry to match the length of the pattern if possible. + if (tupleIndeterminateIndex >= 0) { + while (typeArgs.length < patternEntryCount) { + typeArgs.splice(tupleIndeterminateIndex, 0, typeArgs[tupleIndeterminateIndex]); + } + + if (typeArgs.length > patternEntryCount) { + typeArgs.splice(tupleIndeterminateIndex, 1); + } + } + + // If the pattern contains a star entry and there are too many entries + // in the tuple, we can collapse some of them into the star entry. + if ( + patternStarEntryIndex !== undefined && + typeArgs.length >= 2 && + typeArgs.length > patternEntryCount + ) { + const entriesToCombine = typeArgs.length - patternEntryCount + 1; + const removedEntries = typeArgs.splice(patternStarEntryIndex, entriesToCombine); + typeArgs.splice(patternStarEntryIndex, 0, { + type: combineTypes(removedEntries.map((t) => t.type)), + isUnbounded: removedEntries.every( (t) => t.isUnbounded || isUnpackedVariadicTypeVar(t.type) - ); - let minTupleLength = specializedSequence.tupleTypeArguments.length; - if (tupleIndeterminateIndex >= 0) { - minTupleLength -= 1; - } + ), + }); + } - if (minTupleLength >= minPatternEntryCount) { - let isDefiniteNoMatch = false; - const leftLength = Math.min( - patternStarEntryIndex !== undefined ? patternStarEntryIndex : patternEntryCount, - tupleIndeterminateIndex >= 0 - ? tupleIndeterminateIndex - : specializedSequence.tupleTypeArguments.length - ); + if (typeArgs.length === patternEntryCount) { + let isDefiniteNoMatch = false; + let isPotentialNoMatch = tupleIndeterminateIndex >= 0; + if (patternStarEntryIndex !== undefined && patternEntryCount === 1) { + isPotentialNoMatch = false; + } - for (let i = 0; i < leftLength; i++) { - const leftPattern = pattern.entries[i]; - const leftType = specializedSequence.tupleTypeArguments[i].type; - const narrowedType = narrowTypeBasedOnPattern( - evaluator, - leftType, - leftPattern, - /* isPositiveTest */ true - ); + for (let i = 0; i < patternEntryCount; i++) { + const subPattern = pattern.entries[i]; + const typeArg = typeArgs[i].type; + const narrowedType = narrowTypeBasedOnPattern( + evaluator, + typeArg, + subPattern, + /* isPositiveTest */ true + ); - if (isNever(narrowedType)) { - isDefiniteNoMatch = true; - } - } + if (isNever(narrowedType)) { + isDefiniteNoMatch = true; + } + } - if (patternStarEntryIndex !== undefined || tupleIndeterminateIndex >= 0) { - const rightLength = Math.min( - patternStarEntryIndex !== undefined - ? patternEntryCount - patternStarEntryIndex - 1 - : patternEntryCount, - tupleIndeterminateIndex >= 0 - ? specializedSequence.tupleTypeArguments.length - tupleIndeterminateIndex - : specializedSequence.tupleTypeArguments.length - ); + sequenceInfo.push({ + subtype, + entryTypes: isDefiniteNoMatch ? [] : typeArgs.map((t) => t.type), + isIndeterminateLength: false, + isTuple: true, + isDefiniteNoMatch, + isPotentialNoMatch, + }); + pushedEntry = true; + } - for (let i = 0; i < rightLength; i++) { - const rightPattern = pattern.entries[patternEntryCount - i - 1]; - const rightType = - specializedSequence.tupleTypeArguments[ - specializedSequence.tupleTypeArguments.length - i - 1 - ].type; - const narrowedType = narrowTypeBasedOnPattern( - evaluator, - rightType, - rightPattern, - /* isPositiveTest */ true - ); + // If the pattern contains a star entry and the pattern associated with + // the star entry is unbounded, we can remove it completely under the + // assumption that the star pattern will capture nothing. + if (patternStarEntryIndex !== undefined) { + let tryMatchStarSequence = false; + + if (typeArgs.length === patternEntryCount - 1) { + tryMatchStarSequence = true; + typeArgs.splice(patternStarEntryIndex, 0, { + type: AnyType.create(), + isUnbounded: true, + }); + } else if ( + typeArgs.length === patternEntryCount && + typeArgs[patternStarEntryIndex].isUnbounded + ) { + tryMatchStarSequence = true; + } - if (isNever(narrowedType)) { - isDefiniteNoMatch = true; - } - } + if (tryMatchStarSequence) { + let isDefiniteNoMatch = false; + + for (let i = 0; i < patternEntryCount; i++) { + if (i === patternStarEntryIndex) { + continue; } - if (patternStarEntryIndex !== undefined || minTupleLength === minPatternEntryCount) { - sequenceInfo.push({ - subtype, - entryTypes: isDefiniteNoMatch - ? [] - : specializedSequence.tupleTypeArguments.map((t) => t.type), - isIndeterminateLength: false, - isTuple: true, - isDefiniteNoMatch, - }); - pushedEntry = true; + const subPattern = pattern.entries[i]; + const typeArg = typeArgs[i].type; + const narrowedType = narrowTypeBasedOnPattern( + evaluator, + typeArg, + subPattern, + /* isPositiveTest */ true + ); + + if (isNever(narrowedType)) { + isDefiniteNoMatch = true; } } + + sequenceInfo.push({ + subtype, + entryTypes: isDefiniteNoMatch ? [] : typeArgs.map((t) => t.type), + isIndeterminateLength: false, + isTuple: true, + isDefiniteNoMatch, + }); + pushedEntry = true; } } } else { diff --git a/packages/pyright-internal/src/tests/samples/matchSequence1.py b/packages/pyright-internal/src/tests/samples/matchSequence1.py index 6f3221402666..a9efb3be8b11 100644 --- a/packages/pyright-internal/src/tests/samples/matchSequence1.py +++ b/packages/pyright-internal/src/tests/samples/matchSequence1.py @@ -3,6 +3,7 @@ from enum import Enum from typing import Any, Generic, Iterator, List, Literal, Protocol, Reversible, Sequence, Tuple, TypeVar, Union +from typing_extensions import Unpack def test_unknown(value_to_match): match value_to_match: @@ -200,7 +201,7 @@ def test_union(value_to_match: Union[Tuple[complex, complex], Tuple[int, str, fl case d1, *d2, d3 if value_to_match[0] == 0: reveal_type(d1, expected_text="complex | int | str | float | Any") - reveal_type(d2, expected_text="list[str | float] | list[str] | list[float] | list[Any]") + reveal_type(d2, expected_text="list[Any] | list[str | float] | list[str] | list[float]") reveal_type(d3, expected_text="complex | str | float | Any") reveal_type(value_to_match, expected_text="Tuple[complex, complex] | Tuple[int, str, float, complex] | List[str] | Tuple[float, ...] | Sequence[Any]") @@ -366,7 +367,7 @@ def test_negative_narrowing1(subj: tuple[Literal[0]] | tuple[Literal[1]]): match subj: case (1,*a) | (*a): reveal_type(subj, expected_text="tuple[Literal[1]] | tuple[Literal[0]]") - reveal_type(a, expected_text="list[int]") + reveal_type(a, expected_text="list[Any] | list[int]") case b: reveal_type(subj, expected_text="Never") @@ -458,3 +459,23 @@ def test_tuple_with_subpattern( case (MyEnum.C, b): reveal_type(subj, expected_text="tuple[Literal[MyEnum.C], str]") reveal_type(b, expected_text="str") + + +def test_unbounded_tuple( + subj: tuple[int] | tuple[str, str] | tuple[int, Unpack[tuple[str, ...]], complex] +): + match subj: + case (x,): + reveal_type(subj, expected_text="tuple[int]") + reveal_type(x, expected_text="int") + + case (x, y): + reveal_type(subj, expected_text="tuple[str, str] | tuple[int, complex]") + reveal_type(x, expected_text="str | int") + reveal_type(y, expected_text="str | complex") + + case (x, y, z): + reveal_type(subj, expected_text="tuple[int, str, complex]") + reveal_type(x, expected_text="int") + reveal_type(y, expected_text="str") + reveal_type(z, expected_text="complex")