BUG: Fix multiindex factorize extension dtypes #62964
Open
davidjcastrejon wants to merge 10 commits into pandas-dev:main from davidjcastrejon:fix-multiindex-factorize-extension-dtypes-62337
+252
−0
Commits (10)
ac9cd15 TST: Add tests for MultiIndex.factorize method with extension dtypes (davidjcastrejon)
d6c267d BUG: Preserve extension dtypes in MultiIndex.factorize() #62337 (davidjcastrejon)
33f13c4 BUG: Preserve extension dtypes in MultiIndex reconstruction (davidjcastrejon)
c1fef65 BUG: Preserve extension dtypes in MultiIndex.factorize and improve re… (davidjcastrejon)
404f943 DOC: Add MultiIndex.factorize extension dtype fix to whatsnew v3.0.0 (davidjcastrejon)
647201b BUG: Ensure extension dtypes are preserved in IndexOpsMixin when proc… (davidjcastrejon)
ccab4f2 Removed implementation from base.py (davidjcastrejon)
632b2d6 BUG: Override MultiIndex.factorize to preserve extension dtypes (davidjcastrejon)
7e3927a TYP: Fix mypy error in MultiIndex.factorize return type (davidjcastrejon)
62bf28d TYP: Fix mypy build errors in core modules (davidjcastrejon)
```diff
@@ -3979,6 +3979,121 @@ def truncate(self, before=None, after=None) -> MultiIndex:
             verify_integrity=False,
         )
 
+    def factorize(
+        self,
+        sort: bool = False,
+        use_na_sentinel: bool = True,
+    ) -> tuple[npt.NDArray[np.intp], MultiIndex]:
+        """
+        Encode the object as an enumerated type or categorical variable.
+
+        This method preserves extension dtypes (e.g., Int64, boolean, string)
+        in MultiIndex levels during factorization. See GH#62337.
+
+        Parameters
+        ----------
+        sort : bool, default False
+            Sort uniques and shuffle codes to maintain the relationship.
+        use_na_sentinel : bool, default True
+            If True, the sentinel -1 will be used for NaN values. If False,
+            NaN values will be encoded as non-negative integers and will not drop the
+            NaN from the uniques of the values.
+
+        Returns
+        -------
+        codes : np.ndarray
+            An integer ndarray that's an indexer into uniques.
+        uniques : MultiIndex
+            The unique values with extension dtypes preserved when present.
+
+        See Also
+        --------
+        Index.factorize : Encode the object as an enumerated type.
+
+        Examples
+        --------
+        >>> mi = pd.MultiIndex.from_arrays(
+        ...     [pd.array([1, 2, 1], dtype="Int64"), ["a", "b", "a"]]
+        ... )
+        >>> codes, uniques = mi.factorize()
+        >>> codes
+        array([0, 1, 0])
+        >>> uniques.dtypes
+        level_0     Int64
+        level_1    object
+        dtype: object
+        """
+        # Check if any level has extension dtypes
+        has_extension_dtypes = any(
+            isinstance(level.dtype, ExtensionDtype) for level in self.levels
+        )
+
+        if not has_extension_dtypes:
+            # Use parent implementation for performance when no extension dtypes
+            codes, uniques = super().factorize(
+                sort=sort, use_na_sentinel=use_na_sentinel
+            )
+
+            assert isinstance(uniques, MultiIndex)
+            return codes, uniques
+
+        # Custom implementation for extension dtypes (GH#62337)
+        return self._factorize_with_extension_dtypes(
+            sort=sort, use_na_sentinel=use_na_sentinel
+        )
+
+    def _factorize_with_extension_dtypes(
```

Member (review comment on the line above): Is this only required for a MultiIndex? A base Index doesn't have the same requirement?

```diff
+        self, sort: bool, use_na_sentinel: bool
+    ) -> tuple[npt.NDArray[np.intp], MultiIndex]:
+        """
+        Factorize MultiIndex while preserving extension dtypes.
+
+        This method uses the base factorize on _values but then reconstructs
+        the MultiIndex with proper extension dtypes preserved.
+        """
+        # Factorize using base algorithm on _values
+        codes, uniques_array = algos.factorize(
+            self._values, sort=sort, use_na_sentinel=use_na_sentinel
+        )
+
+        # Handle empty case
+        if len(uniques_array) == 0:
+            # Create empty levels with preserved dtypes
+            empty_levels = []
+            for original_level in self.levels:
+                # Create empty level with same dtype
+                empty_level = original_level[:0]  # Slice to get empty with same dtype
+                empty_levels.append(empty_level)
+
+            # Create empty MultiIndex with preserved level dtypes
+            result_mi = type(self)(
+                levels=empty_levels,
+                codes=[[] for _ in range(len(empty_levels))],
+            )
+            return codes, result_mi
+
+        # Create MultiIndex from unique tuples
+        result_mi = type(self).from_tuples(uniques_array)
+
+        # Restore extension dtypes
+        new_levels = []
+        for i, original_level in enumerate(self.levels):
+            if isinstance(original_level.dtype, ExtensionDtype):
+                # Preserve extension dtype by casting result level
+                try:
+                    new_level = result_mi.levels[i].astype(original_level.dtype)
+                    new_levels.append(new_level)
+                except (TypeError, ValueError):
+                    # If casting fails, keep the inferred level
+                    new_levels.append(result_mi.levels[i])
+            else:
+                # Keep inferred dtype for regular levels
+                new_levels.append(result_mi.levels[i])
+
+        # Reconstruct with preserved dtypes
+        result_mi = result_mi.set_levels(new_levels)
+        return codes, result_mi
+
     def equals(self, other: object) -> bool:
         """
         Determines if two MultiIndex objects have the same labeling information
```
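The restore step in `_factorize_with_extension_dtypes` rebuilds each level from the unique tuples and re-casts it, falling back to the inferred level when the cast raises. A minimal pure-Python sketch of that pattern follows. It is an illustration only, not pandas code: `restore_level_types` and its `casters` argument are hypothetical names, and plain Python callables such as `float` and `str` stand in for extension dtypes.

```python
from typing import Any, Callable, Sequence


# Hypothetical helper for illustration; not part of pandas or of this PR.
def restore_level_types(
    unique_rows: Sequence[tuple],
    casters: Sequence[Callable[[Any], Any]],
) -> list:
    """Split unique row tuples into per-level columns and re-cast each one.

    A level whose cast raises TypeError or ValueError keeps its inferred
    values, mirroring the try/except fallback in the diff.
    """
    if unique_rows:
        columns = [list(col) for col in zip(*unique_rows)]
    else:
        # Empty input: one empty column per level.
        columns = [[] for _ in casters]

    restored = []
    for column, caster in zip(columns, casters):
        try:
            restored.append([caster(value) for value in column])
        except (TypeError, ValueError):
            # Cast failed: keep the values as they were inferred.
            restored.append(column)
    return restored


rows = [(1, "a"), (2, "b")]
levels = restore_level_types(rows, [float, str])
# The second caster fails on "a", so that level is returned unchanged.
fallback = restore_level_types(rows, [float, int])
```

In this sketch a failed cast is silent: the caller gets back a level of a different type with no signal that the requested type was not applied. Whether that behaviour is acceptable in the real method is a review question that the sketch does not settle.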
```diff
@@ -0,0 +1,134 @@
+"""
+Tests for MultiIndex.factorize method
+"""
+
+import numpy as np
+import pytest
+
+import pandas as pd
+import pandas._testing as tm
+
+
+class TestMultiIndexFactorize:
+    def test_factorize_extension_dtype_int32(self):
+        # GH#62337: factorize should preserve Int32 extension dtype
+        df = pd.DataFrame({"col": pd.Series([1, None, 2], dtype="Int32")})
+        mi = pd.MultiIndex.from_frame(df)
+
+        codes, uniques = mi.factorize()
+
+        result_dtype = uniques.to_frame().iloc[:, 0].dtype
+        expected_dtype = pd.Int32Dtype()
+        assert result_dtype == expected_dtype
+
+        # Verify codes are correct
+        expected_codes = np.array([0, 1, 2], dtype=np.intp)
+        tm.assert_numpy_array_equal(codes, expected_codes)
+
+    @pytest.mark.parametrize("dtype", ["Int32", "Int64", "string", "boolean"])
+    def test_factorize_extension_dtypes(self, dtype):
+        # GH#62337: factorize should preserve various extension dtypes
+        if dtype == "boolean":
+            values = [True, None, False]
+        elif dtype == "string":
+            values = ["a", None, "b"]
+        else:  # Int32, Int64
+            values = [1, None, 2]
+
+        df = pd.DataFrame({"col": pd.Series(values, dtype=dtype)})
+        mi = pd.MultiIndex.from_frame(df)
+
+        codes, uniques = mi.factorize()
+        result_dtype = uniques.to_frame().iloc[:, 0].dtype
+
+        assert str(result_dtype) == dtype
+
+    def test_factorize_multiple_extension_dtypes(self):
+        # GH#62337: factorize with multiple columns having extension dtypes
+        df = pd.DataFrame(
+            {
+                "int_col": pd.Series([1, 2, 1], dtype="Int64"),
+                "str_col": pd.Series(["a", "b", "a"], dtype="string"),
+            }
+        )
+        mi = pd.MultiIndex.from_frame(df)
+
+        codes, uniques = mi.factorize()
+
+        result_frame = uniques.to_frame()
+        assert result_frame.iloc[:, 0].dtype == pd.Int64Dtype()
+        assert result_frame.iloc[:, 1].dtype == pd.StringDtype()
+
+        # Should have 2 unique combinations: (1,'a') and (2,'b')
+        assert len(uniques) == 2
+
+    def test_factorize_preserves_names(self):
+        # GH#62337: factorize should preserve MultiIndex names when extension
+        # dtypes are involved
+        df = pd.DataFrame(
+            {
+                "level_1": pd.Series([1, 2], dtype="Int32"),
+                "level_2": pd.Series(["a", "b"], dtype="string"),
+            }
+        )
+        mi = pd.MultiIndex.from_frame(df)
+
+        codes, uniques = mi.factorize()
+
+        # The main fix is extension dtype preservation, names behavior follows
+        # existing patterns
+        # Just verify that factorize runs without errors and dtypes are preserved
+        result_frame = uniques.to_frame()
+        assert result_frame.iloc[:, 0].dtype == pd.Int32Dtype()
+        assert result_frame.iloc[:, 1].dtype == pd.StringDtype()
+
+    def test_factorize_extension_dtype_with_sort(self):
+        # GH#62337: factorize with sort=True should preserve extension dtypes
+        df = pd.DataFrame({"col": pd.Series([2, None, 1], dtype="Int32")})
+        mi = pd.MultiIndex.from_frame(df)
+
+        codes, uniques = mi.factorize(sort=True)
+
+        result_dtype = uniques.to_frame().iloc[:, 0].dtype
+        assert result_dtype == pd.Int32Dtype()
+
+    def test_factorize_empty_extension_dtype(self):
+        # GH#62337: factorize on empty MultiIndex with extension dtype
+        df = pd.DataFrame({"col": pd.Series([], dtype="Int32")})
+        mi = pd.MultiIndex.from_frame(df)
+
+        codes, uniques = mi.factorize()
+
+        assert len(codes) == 0
+        assert len(uniques) == 0
+        assert uniques.to_frame().iloc[:, 0].dtype == pd.Int32Dtype()
+
+    def test_factorize_regular_dtypes_unchanged(self):
+        # Ensure regular dtypes still work as before
+        df = pd.DataFrame({"int_col": [1, 2, 1], "float_col": [1.1, 2.2, 1.1]})
+        mi = pd.MultiIndex.from_frame(df)
+
+        codes, uniques = mi.factorize()
+
+        result_frame = uniques.to_frame()
+        assert result_frame.iloc[:, 0].dtype == np.dtype("int64")
+        assert result_frame.iloc[:, 1].dtype == np.dtype("float64")
+
+        # Should have 2 unique combinations
+        assert len(uniques) == 2
+
+    def test_factorize_mixed_extension_regular_dtypes(self):
+        # Mix of extension and regular dtypes
+        df = pd.DataFrame(
+            {
+                "ext_col": pd.Series([1, 2, 1], dtype="Int64"),
+                "reg_col": [1.1, 2.2, 1.1],  # regular float64
+            }
+        )
+        mi = pd.MultiIndex.from_frame(df)
+
+        codes, uniques = mi.factorize()
+
+        result_frame = uniques.to_frame()
+        assert result_frame.iloc[:, 0].dtype == pd.Int64Dtype()
+        assert result_frame.iloc[:, 1].dtype == np.dtype("float64")
```
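Most of the tests here assert on dtypes and lengths. The docstring's contract is that `codes` is an indexer into `uniques`, and that `sort=True` reorders the uniques while remapping the codes to match. A pure-Python sketch of that contract over row tuples may make the relationship concrete. `factorize_rows` is a hypothetical stand-in, not pandas' algorithm, and it ignores missing-value handling entirely.

```python
# Hypothetical stand-in for illustration; not pandas' implementation and
# with no NA handling.
def factorize_rows(rows, sort=False):
    """Return (codes, uniques) such that uniques[codes[i]] == rows[i]."""
    position = {}  # row tuple -> index in uniques, in first-seen order
    uniques = []
    codes = []
    for row in rows:
        if row not in position:
            position[row] = len(uniques)
            uniques.append(row)
        codes.append(position[row])

    if sort:
        # Sort the uniques, then remap every code to the new positions.
        order = sorted(range(len(uniques)), key=lambda k: uniques[k])
        remap = {old: new for new, old in enumerate(order)}
        uniques = [uniques[k] for k in order]
        codes = [remap[c] for c in codes]
    return codes, uniques


rows = [(2, "b"), (1, "a"), (2, "b")]
codes, uniques = factorize_rows(rows)
sorted_codes, sorted_uniques = factorize_rows(rows, sort=True)

# Round trip: indexing uniques by codes reproduces the input in both modes.
round_trip = [uniques[c] for c in codes]
sorted_round_trip = [sorted_uniques[c] for c in sorted_codes]
```

An equivalent round-trip assertion against the real `MultiIndex.factorize` result would check that the codes and the rebuilt uniques still agree after the levels are re-cast.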