@@ -1589,6 +1589,7 @@ def _clean_interp_method(method, **kwargs):
1589
1589
1590
1590
1591
1591
def interpolate_1d (xvalues , yvalues , method = 'linear' , limit = None ,
1592
+ limit_direction = 'forward' ,
1592
1593
fill_value = None , bounds_error = False , order = None , ** kwargs ):
1593
1594
"""
1594
1595
Logic for the 1-d interpolation. The result should be 1-d, inputs
@@ -1602,9 +1603,15 @@ def interpolate_1d(xvalues, yvalues, method='linear', limit=None,
1602
1603
invalid = isnull (yvalues )
1603
1604
valid = ~ invalid
1604
1605
1605
- valid_y = yvalues [valid ]
1606
- valid_x = xvalues [valid ]
1607
- new_x = xvalues [invalid ]
1606
+ if not valid .any ():
1607
+ # have to call np.asarray(xvalues) since xvalues could be an Index
1608
+ # which can't be mutated
1609
+ result = np .empty_like (np .asarray (xvalues ), dtype = np .float64 )
1610
+ result .fill (np .nan )
1611
+ return result
1612
+
1613
+ if valid .all ():
1614
+ return yvalues
1608
1615
1609
1616
if method == 'time' :
1610
1617
if not getattr (xvalues , 'is_all_dates' , None ):
@@ -1614,66 +1621,82 @@ def interpolate_1d(xvalues, yvalues, method='linear', limit=None,
1614
1621
'DatetimeIndex' )
1615
1622
method = 'values'
1616
1623
1617
- def _interp_limit (invalid , limit ):
1618
- """mask off values that won't be filled since they exceed the limit"" "
1624
+ def _interp_limit (invalid , fw_limit , bw_limit ):
1625
+ "Get idx of NaN values that won't be filled b/c they exceed the forward/backward limits. "
1619
1626
all_nans = np .where (invalid )[0 ]
1620
1627
if all_nans .size == 0 : # no nans anyway
1621
1628
return []
1622
- violate = [invalid [x :x + limit + 1 ] for x in all_nans ]
1623
- violate = np .array ([x .all () & (x .size > limit ) for x in violate ])
1624
- return all_nans [violate ] + limit
1629
+ violate = [invalid [max (0 , x - bw_limit ):x + fw_limit + 1 ] for x in all_nans ]
1630
+ violate = np .array ([x .all () & (x .size > bw_limit + fw_limit ) for x in violate ])
1631
+ return all_nans [violate ] + fw_limit - bw_limit
1632
+
1633
+ valid_limit_directions = ['forward' , 'backward' , 'both' ]
1634
+ limit_direction = limit_direction .lower ()
1635
+ if limit_direction not in valid_limit_directions :
1636
+ msg = 'Invalid limit_direction: expecting one of %r, got %r.' % (
1637
+ valid_limit_directions , limit_direction )
1638
+ raise ValueError (msg )
1625
1639
1626
- xvalues = getattr (xvalues , 'values' , xvalues )
1627
- yvalues = getattr (yvalues , 'values' , yvalues )
1640
+ from pandas import Series
1641
+ ys = Series (yvalues )
1642
+ start_nans = set (range (ys .first_valid_index ()))
1643
+ end_nans = set (range (1 + ys .last_valid_index (), len (valid )))
1644
+
1645
+ # This is a list of the indexes in the series whose yvalue is currently NaN,
1646
+ # but whose interpolated yvalue will be overwritten with NaN after computing
1647
+ # the interpolation. For each index in this list, one of these conditions is
1648
+ # true of the corresponding NaN in the yvalues:
1649
+ #
1650
+ # a) It is one of a chain of NaNs at the beginning of the series, and either
1651
+ # limit is not specified or limit_direction is 'forward'.
1652
+ # b) It is one of a chain of NaNs at the end of the series, and limit is
1653
+ # specified and limit_direction is 'backward'.
1654
+ # c) Limit is nonzero and it is further than limit from the nearest non-NaN
1655
+ # value (with respect to the limit_direction setting).
1656
+ #
1657
+ # The default behavior is to fill forward with no limit, ignoring NaNs at
1658
+ # the beginning (see issues #9218 and #10420)
1659
+ violate_limit = sorted (start_nans )
1628
1660
1629
1661
if limit :
1630
- violate_limit = _interp_limit (invalid , limit )
1631
- if valid .any ():
1632
- firstIndex = valid .argmax ()
1633
- valid = valid [firstIndex :]
1634
- invalid = invalid [firstIndex :]
1635
- result = yvalues .copy ()
1636
- if valid .all ():
1637
- return yvalues
1638
- else :
1639
- # have to call np.array(xvalues) since xvalues could be an Index
1640
- # which cant be mutated
1641
- result = np .empty_like (np .array (xvalues ), dtype = np .float64 )
1642
- result .fill (np .nan )
1643
- return result
1662
+ if limit_direction == 'forward' :
1663
+ violate_limit = sorted (start_nans | set (_interp_limit (invalid , limit , 0 )))
1664
+ if limit_direction == 'backward' :
1665
+ violate_limit = sorted (end_nans | set (_interp_limit (invalid , 0 , limit )))
1666
+ if limit_direction == 'both' :
1667
+ violate_limit = _interp_limit (invalid , limit , limit )
1668
+
1669
+ xvalues = getattr (xvalues , 'values' , xvalues )
1670
+ yvalues = getattr (yvalues , 'values' , yvalues )
1671
+ result = yvalues .copy ()
1644
1672
1645
1673
if method in ['linear' , 'time' , 'index' , 'values' ]:
1646
1674
if method in ('values' , 'index' ):
1647
1675
inds = np .asarray (xvalues )
1648
1676
# hack for DatetimeIndex, #1646
1649
1677
if issubclass (inds .dtype .type , np .datetime64 ):
1650
1678
inds = inds .view (np .int64 )
1651
-
1652
1679
if inds .dtype == np .object_ :
1653
1680
inds = lib .maybe_convert_objects (inds )
1654
1681
else :
1655
1682
inds = xvalues
1656
-
1657
- inds = inds [firstIndex :]
1658
-
1659
- result [firstIndex :][invalid ] = np .interp (inds [invalid ], inds [valid ],
1660
- yvalues [firstIndex :][valid ])
1661
-
1662
- if limit :
1663
- result [violate_limit ] = np .nan
1683
+ result [invalid ] = np .interp (inds [invalid ], inds [valid ], yvalues [valid ])
1684
+ result [violate_limit ] = np .nan
1664
1685
return result
1665
1686
1666
1687
sp_methods = ['nearest' , 'zero' , 'slinear' , 'quadratic' , 'cubic' ,
1667
1688
'barycentric' , 'krogh' , 'spline' , 'polynomial' ,
1668
1689
'piecewise_polynomial' , 'pchip' ]
1669
1690
if method in sp_methods :
1670
- new_x = new_x [firstIndex :]
1671
-
1672
- result [firstIndex :][invalid ] = _interpolate_scipy_wrapper (
1673
- valid_x , valid_y , new_x , method = method , fill_value = fill_value ,
1691
+ inds = np .asarray (xvalues )
1692
+ # hack for DatetimeIndex, #1646
1693
+ if issubclass (inds .dtype .type , np .datetime64 ):
1694
+ inds = inds .view (np .int64 )
1695
+ result [invalid ] = _interpolate_scipy_wrapper (
1696
+ inds [valid ], yvalues [valid ], inds [invalid ], method = method ,
1697
+ fill_value = fill_value ,
1674
1698
bounds_error = bounds_error , order = order , ** kwargs )
1675
- if limit :
1676
- result [violate_limit ] = np .nan
1699
+ result [violate_limit ] = np .nan
1677
1700
return result
1678
1701
1679
1702
0 commit comments