6
6
from matplotlib .backend_bases import MouseEvent
7
7
from matplotlib .cbook import CallbackRegistry
8
8
from matplotlib .lines import Line2D
9
+ from matplotlib .transforms import IdentityTransform , blended_transform_factory
9
10
from matplotlib .widgets import AxesWidget
10
11
11
12
__all__ = [
16
17
17
18
18
19
class DraggableLine (AxesWidget ):
19
- def __init__ (self , ax , x , y , grab_range = 10 , useblit = False , ** kwargs ) -> None :
20
+ def __init__ (
21
+ self ,
22
+ ax ,
23
+ x ,
24
+ y ,
25
+ grab_range = 10 ,
26
+ useblit = False ,
27
+ grab_range_transform = None ,
28
+ ** kwargs ,
29
+ ) -> None :
20
30
"""
21
31
Parameters
22
32
----------
@@ -29,6 +39,9 @@ def __init__(self, ax, x, y, grab_range=10, useblit=False, **kwargs) -> None:
29
39
Whether to use blitting for faster drawing (if supported by the
30
40
backend). See the tutorial :doc:`/tutorials/advanced/blitting`
31
41
for details.
42
+ grab_range_transform : matplotlib.transform.Transform, optional
43
+ The transform to use for the handle positions when calculating
44
+ if a handle has been grabbed.
32
45
**kwargs :
33
46
Passed on to Line2D for styling
34
47
"""
@@ -43,6 +56,8 @@ def __init__(self, ax, x, y, grab_range=10, useblit=False, **kwargs) -> None:
43
56
marker = kwargs .pop ("marker" , "o" )
44
57
color = kwargs .pop ("color" , "k" )
45
58
transform = kwargs .pop ("transform" , self .ax .transData )
59
+ self ._grab_range_transform = grab_range_transform or self .ax .transLimits
60
+
46
61
self ._handles = Line2D (
47
62
[x [0 ], center_x , x [1 ]],
48
63
[y [0 ], center_y , y [1 ]],
@@ -108,12 +123,11 @@ def _on_press(self, event: MouseEvent):
108
123
if not self .canvas .widgetlock .available (self ):
109
124
return
110
125
# figure out if any handles are being grabbed
111
- # maybe possible to do this with a pick event?
112
126
113
127
x , y = self ._handles .get_data ()
114
- # this is taken pretty much directly from the implementation
115
- # in matplotlib.widget.ToolHandles.closest
116
- pts = self .ax . transLimits .transform (np .column_stack ([x , y ]))
128
+ # this is a modified version of
129
+ # matplotlib.widget.ToolHandles.closest
130
+ pts = self ._grab_range_transform .transform (np .column_stack ([x , y ]))
117
131
diff = pts - self .ax .transLimits .transform ((event .xdata , event .ydata ))
118
132
dist = np .hypot (* diff .T )
119
133
idx = np .argmin (dist )
@@ -227,6 +241,9 @@ def __init__(self, ax, x, grab_range=0.1, useblit=False, **kwargs) -> None:
227
241
grab_range = grab_range ,
228
242
useblit = useblit ,
229
243
transform = ax .get_xaxis_transform (),
244
+ grab_range_transform = blended_transform_factory (
245
+ ax .transLimits , IdentityTransform ()
246
+ ),
230
247
** kwargs ,
231
248
)
232
249
self ._y_lock = True
@@ -292,6 +309,9 @@ def __init__(self, ax, y, grab_range=0.1, useblit=False, **kwargs) -> None:
292
309
grab_range = grab_range ,
293
310
useblit = useblit ,
294
311
transform = ax .get_yaxis_transform (),
312
+ grab_range_transform = blended_transform_factory (
313
+ IdentityTransform (), ax .transLimits
314
+ ),
295
315
** kwargs ,
296
316
)
297
317
self ._x_lock = True
0 commit comments