BUG: Fixed DataFrameGroupBy.transform with numba returning the wrong order with non increasing indexes #57069 (#58030)

* Fix #57069: DataFrameGroupBy.transform with numba returning the wrong order with non monotonically increasing indexes

Fixed a bug where DataFrameGroupBy.transform with engine='numba' returned results in the wrong order unless the index was monotonically increasing.
Fixed the test "pandas/tests/groupby/transform/test_numba.py::test_index_data_correctly_passed" to expect the result in the correct order.
Added the test "pandas/tests/groupby/transform/test_numba.py::test_index_order_consistency_preserved" to cover DataFrameGroupBy.transform with engine='numba' on a decreasing index.
Updated whatsnew to reflect the changes.

* Apply suggestions from code review

Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com>

* Fixed pre-commit requirements

---------

Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com>
andremcorreia and mroeschke committed Mar 28, 2024
1 parent b86eb99 commit c468028
Showing 3 changed files with 19 additions and 2 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
@@ -327,6 +327,7 @@ Bug fixes
 - Fixed bug in :meth:`DataFrame.cumsum` which was raising ``IndexError`` if dtype is ``timedelta64[ns]`` (:issue:`57956`)
 - Fixed bug in :meth:`DataFrame.join` inconsistently setting result index name (:issue:`55815`)
 - Fixed bug in :meth:`DataFrame.to_string` that raised ``StopIteration`` with nested DataFrames. (:issue:`16098`)
+- Fixed bug in :meth:`DataFrame.transform` that was returning the wrong order unless the index was monotonically increasing. (:issue:`57069`)
 - Fixed bug in :meth:`DataFrame.update` bool dtype being converted to object (:issue:`55509`)
 - Fixed bug in :meth:`DataFrameGroupBy.apply` that was returning a completely empty DataFrame when all return values of ``func`` were ``None`` instead of returning an empty DataFrame with the original columns and dtypes. (:issue:`57775`)
 - Fixed bug in :meth:`Series.diff` allowing non-integer values for the ``periods`` argument. (:issue:`56607`)
3 changes: 2 additions & 1 deletion pandas/core/groupby/groupby.py
@@ -1439,6 +1439,7 @@ def _transform_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
         data and indices into a Numba jitted function.
         """
         data = self._obj_with_exclusions
+        index_sorting = self._grouper.result_ilocs
         df = data if data.ndim == 2 else data.to_frame()

         starts, ends, sorted_index, sorted_data = self._numba_prep(df)
@@ -1456,7 +1457,7 @@ def _transform_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
         )
         # result values needs to be resorted to their original positions since we
         # evaluated the data sorted by group
-        result = result.take(np.argsort(sorted_index), axis=0)
+        result = result.take(np.argsort(index_sorting), axis=0)
         index = data.index
         if data.ndim == 1:
             result_kwargs = {"name": data.name}
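The reordering step in this hunk can be illustrated without pandas or numba. The sketch below is a simplified model, not the pandas implementation: it assumes the grouped data is evaluated in group-sorted order, that `sorted_index` holds the index labels in that order, and that `index_sorting` is the positional indexer that produced that order. The helpers `argsort` and `take` are hypothetical stand-ins for `np.argsort` and `.take`, and the data is borrowed from the new test.

```python
def argsort(seq):
    """Positions that would sort seq (stable); stand-in for np.argsort."""
    return sorted(range(len(seq)), key=seq.__getitem__)


def take(seq, indexer):
    """Pick elements of seq at the given positions; stand-in for .take."""
    return [seq[i] for i in indexer]


# Same data as the new test: a decreasing index and two interleaved groups.
vals = [0.0, 1.0, 2.0, 3.0]
group = [0, 1, 0, 1]
labels = [3, 2, 1, 0]  # index labels, i.e. range(3, -1, -1)

# Assumption: rows are evaluated in group-sorted order.
index_sorting = argsort(group)               # positional indexer
sorted_vals = take(vals, index_sorting)      # data in group order
sorted_labels = take(labels, index_sorting)  # index labels in group order

# Identity transform, as in the test's f(values, index).
result = sorted_vals

# Old behaviour (modelled): undo the sort using the index labels.
old_result = take(result, argsort(sorted_labels))

# New behaviour (modelled): undo the sort using the positional indexer.
new_result = take(result, argsort(index_sorting))

print(old_result)
print(new_result)
```

In this model the argsort of a permutation of positions is its inverse, so taking with it restores the original row order regardless of what the index labels are. Argsorting the labels only coincides with that inverse when the labels increase with position, which would explain why the bug was invisible on a default increasing index.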
17 changes: 16 additions & 1 deletion pandas/tests/groupby/transform/test_numba.py
@@ -181,10 +181,25 @@ def f(values, index):

     df = DataFrame({"group": ["A", "A", "B"], "v": [4, 5, 6]}, index=[-1, -2, -3])
     result = df.groupby("group").transform(f, engine="numba")
-    expected = DataFrame([-4.0, -3.0, -2.0], columns=["v"], index=[-1, -2, -3])
+    expected = DataFrame([-2.0, -3.0, -4.0], columns=["v"], index=[-1, -2, -3])
     tm.assert_frame_equal(result, expected)


+def test_index_order_consistency_preserved():
+    # GH 57069
+    pytest.importorskip("numba")
+
+    def f(values, index):
+        return values
+
+    df = DataFrame(
+        {"vals": [0.0, 1.0, 2.0, 3.0], "group": [0, 1, 0, 1]}, index=range(3, -1, -1)
+    )
+    result = df.groupby("group")["vals"].transform(f, engine="numba")
+    expected = Series([0.0, 1.0, 2.0, 3.0], index=range(3, -1, -1), name="vals")
+    tm.assert_series_equal(result, expected)
+
+
 def test_engine_kwargs_not_cached():
     # If the user passes a different set of engine_kwargs don't return the same
     # jitted function
