Skip to content

Commit c9ace60

Browse files
committed
fix: don't move cursor to nil node
1 parent 6876fc4 commit c9ace60

File tree

1 file changed

+150
-146
lines changed

1 file changed

+150
-146
lines changed

lua/treewalker/swap.lua

Lines changed: 150 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -11,189 +11,193 @@ local M = {}
1111

1212
---@return boolean
1313
local function is_on_target_node()
14-
local node = vim.treesitter.get_node()
15-
if not node then return false end
16-
17-
-- Special case for markdown - use heading utility
18-
if util.is_markdown_file() then
19-
return markdown_heading.is_heading(vim.fn.line("."))
20-
end
21-
22-
-- For other languages, use the standard Treesitter-based approach
23-
if not nodes.is_jump_target(node) then return false end
24-
if vim.fn.line('.') - 1 ~= node:range() then return false end
25-
return true
14+
local node = vim.treesitter.get_node()
15+
if not node then return false end
16+
17+
-- Special case for markdown - use heading utility
18+
if util.is_markdown_file() then
19+
return markdown_heading.is_heading(vim.fn.line("."))
20+
end
21+
22+
-- For other languages, use the standard Treesitter-based approach
23+
if not nodes.is_jump_target(node) then return false end
24+
if vim.fn.line('.') - 1 ~= node:range() then return false end
25+
return true
2626
end
2727

2828
---@return boolean
2929
local function is_supported_ft()
30-
local unsupported_filetypes = {
31-
["text"] = true,
32-
["txt"] = true,
33-
}
30+
local unsupported_filetypes = {
31+
["text"] = true,
32+
["txt"] = true,
33+
}
3434

35-
local bufnr = vim.api.nvim_get_current_buf()
36-
local ft = vim.bo[bufnr].filetype
35+
local bufnr = vim.api.nvim_get_current_buf()
36+
local ft = vim.bo[bufnr].filetype
3737

38-
return not unsupported_filetypes[ft]
38+
return not unsupported_filetypes[ft]
3939
end
4040

4141
function M.swap_down()
42-
vim.cmd("normal! ^")
43-
if not is_supported_ft() then return end
44-
if not is_on_target_node() then return end
45-
if util.is_markdown_file() then
46-
return markdown_swap.swap_down_markdown()
47-
end
48-
local current = nodes.get_current()
49-
50-
local target = targets.down()
51-
if not target then return end
52-
53-
current = nodes.get_highest_coincident(current)
54-
55-
local current_augments = augment.get_node_augments(current)
56-
local current_all = { current, unpack(current_augments) }
57-
local current_srow = nodes.get_srow(current)
58-
local current_erow = nodes.get_erow(current)
59-
local current_all_rows = nodes.whole_range(current_all)
60-
61-
local target_augments = augment.get_node_augments(target)
62-
local target_all = { target, unpack(target_augments) }
63-
local target_srow = nodes.get_srow(target)
64-
local target_erow = nodes.get_erow(target)
65-
local target_scol = nodes.get_scol(target)
66-
local target_all_rows = nodes.whole_range(target_all)
67-
operations.swap_rows(current_all_rows, target_all_rows)
68-
69-
-- Place cursor
70-
local node_length_diff = (current_erow - current_srow) - (target_erow - target_srow)
71-
local x = target_srow - node_length_diff
72-
local y = target_scol
73-
vim.fn.cursor(x, y)
42+
vim.cmd("normal! ^")
43+
if not is_supported_ft() then return end
44+
if not is_on_target_node() then return end
45+
if util.is_markdown_file() then
46+
return markdown_swap.swap_down_markdown()
47+
end
48+
local current = nodes.get_current()
49+
50+
local target = targets.down()
51+
if not target then return end
52+
53+
current = nodes.get_highest_coincident(current)
54+
55+
local current_augments = augment.get_node_augments(current)
56+
local current_all = { current, unpack(current_augments) }
57+
local current_srow = nodes.get_srow(current)
58+
local current_erow = nodes.get_erow(current)
59+
local current_all_rows = nodes.whole_range(current_all)
60+
61+
local target_augments = augment.get_node_augments(target)
62+
local target_all = { target, unpack(target_augments) }
63+
local target_srow = nodes.get_srow(target)
64+
local target_erow = nodes.get_erow(target)
65+
local target_scol = nodes.get_scol(target)
66+
local target_all_rows = nodes.whole_range(target_all)
67+
operations.swap_rows(current_all_rows, target_all_rows)
68+
69+
-- Place cursor
70+
local node_length_diff = (current_erow - current_srow) - (target_erow - target_srow)
71+
local x = target_srow - node_length_diff
72+
local y = target_scol
73+
vim.fn.cursor(x, y)
7474
end
7575

7676
function M.swap_up()
77-
vim.cmd("normal! ^")
78-
if not is_supported_ft() then return end
79-
if not is_on_target_node() then return end
80-
if util.is_markdown_file() then
81-
return markdown_swap.swap_up_markdown()
82-
end
83-
local current = nodes.get_current()
84-
local target = targets.up()
85-
if not target then return end
86-
87-
current = nodes.get_highest_coincident(current)
88-
89-
local current_augments = augment.get_node_augments(current)
90-
local current_all = { current, unpack(current_augments) }
91-
local current_srow = nodes.get_srow(current)
92-
local current_all_rows = nodes.whole_range(current_all)
93-
94-
local target_srow = nodes.get_srow(target)
95-
local target_scol = nodes.get_scol(target)
96-
local target_augments = augment.get_node_augments(target)
97-
local target_all = { target, unpack(target_augments) }
98-
local target_all_rows = nodes.whole_range(target_all)
99-
100-
local target_augment_rows = nodes.whole_range(target_augments)
101-
local target_augment_srow = target_augment_rows[1]
102-
local target_augment_length = #target_augments == 0 and 0 or (target_srow - target_augment_srow - 1)
103-
104-
local current_augment_rows = nodes.whole_range(current_augments)
105-
local current_augment_srow = current_augment_rows[1]
106-
local current_augment_length = #current_augments == 0 and 0 or (current_srow - current_augment_srow - 1)
107-
108-
-- Do the swap
109-
operations.swap_rows(target_all_rows, current_all_rows)
110-
111-
-- Place cursor
112-
local x = target_srow + current_augment_length - target_augment_length
113-
local y = target_scol
114-
vim.fn.cursor(x, y)
77+
vim.cmd("normal! ^")
78+
if not is_supported_ft() then return end
79+
if not is_on_target_node() then return end
80+
if util.is_markdown_file() then
81+
return markdown_swap.swap_up_markdown()
82+
end
83+
local current = nodes.get_current()
84+
local target = targets.up()
85+
if not target then return end
86+
87+
current = nodes.get_highest_coincident(current)
88+
89+
local current_augments = augment.get_node_augments(current)
90+
local current_all = { current, unpack(current_augments) }
91+
local current_srow = nodes.get_srow(current)
92+
local current_all_rows = nodes.whole_range(current_all)
93+
94+
local target_srow = nodes.get_srow(target)
95+
local target_scol = nodes.get_scol(target)
96+
local target_augments = augment.get_node_augments(target)
97+
local target_all = { target, unpack(target_augments) }
98+
local target_all_rows = nodes.whole_range(target_all)
99+
100+
local target_augment_rows = nodes.whole_range(target_augments)
101+
local target_augment_srow = target_augment_rows[1]
102+
local target_augment_length = #target_augments == 0 and 0 or (target_srow - target_augment_srow - 1)
103+
104+
local current_augment_rows = nodes.whole_range(current_augments)
105+
local current_augment_srow = current_augment_rows[1]
106+
local current_augment_length = #current_augments == 0 and 0 or (current_srow - current_augment_srow - 1)
107+
108+
-- Do the swap
109+
operations.swap_rows(target_all_rows, current_all_rows)
110+
111+
-- Place cursor
112+
local x = target_srow + current_augment_length - target_augment_length
113+
local y = target_scol
114+
vim.fn.cursor(x, y)
115115
end
116116

117117
function M.swap_right()
118-
if not is_supported_ft() then return end
119-
if util.is_markdown_file() then return end
120-
local current = nodes.get_current()
121-
current = strategies.get_highest_string_node(current) or current
122-
current = nodes.get_highest_coincident(current)
118+
if not is_supported_ft() then return end
119+
if util.is_markdown_file() then return end
120+
local current = nodes.get_current()
121+
current = strategies.get_highest_string_node(current) or current
122+
current = nodes.get_highest_coincident(current)
123123

124-
local target = nodes.next_sib(current)
124+
local target = nodes.next_sib(current)
125125

126-
if not target then
127-
M.reorder(current, nodes.prev_sib)
128-
end
126+
if not target then
127+
M.reorder(current, nodes.prev_sib)
128+
end
129129

130-
if not current or not target then return end
130+
if not current or not target then return end
131131

132-
-- set a mark to track where the target started, so we may later go there after the swap
133-
local ns_id = vim.api.nvim_create_namespace("treewalker#swap_right")
134-
local ext_id = vim.api.nvim_buf_set_extmark(
135-
0,
136-
ns_id,
137-
nodes.get_srow(target) - 1,
138-
nodes.get_scol(target) - 1,
139-
{}
140-
)
132+
-- set a mark to track where the target started, so we may later go there after the swap
133+
local ns_id = vim.api.nvim_create_namespace("treewalker#swap_right")
134+
local ext_id = vim.api.nvim_buf_set_extmark(
135+
0,
136+
ns_id,
137+
nodes.get_srow(target) - 1,
138+
nodes.get_scol(target) - 1,
139+
{}
140+
)
141141

142-
operations.swap_nodes(current, target)
142+
operations.swap_nodes(current, target)
143143

144-
local ext = vim.api.nvim_buf_get_extmark_by_id(0, ns_id, ext_id, {})
145-
local new_current = nodes.get_at_rowcol(ext[1] + 1, ext[2] - 1)
144+
local ext = vim.api.nvim_buf_get_extmark_by_id(0, ns_id, ext_id, {})
145+
local new_current = nodes.get_at_rowcol(ext[1] + 1, ext[2] - 1)
146146

147-
if not new_current then return end
147+
if not new_current then return end
148148

149-
vim.fn.cursor(
150-
nodes.get_srow(new_current),
151-
nodes.get_scol(new_current)
152-
)
149+
vim.fn.cursor(
150+
nodes.get_srow(new_current),
151+
nodes.get_scol(new_current)
152+
)
153153

154-
-- cleanup
155-
vim.api.nvim_buf_clear_namespace(0, ns_id, 0, -1)
154+
-- cleanup
155+
vim.api.nvim_buf_clear_namespace(0, ns_id, 0, -1)
156156
end
157157

158158
function M.swap_left()
159-
if not is_supported_ft() then return end
160-
if util.is_markdown_file() then return end
161-
local current = nodes.get_current()
162-
current = strategies.get_highest_string_node(current) or current
163-
current = nodes.get_highest_coincident(current)
159+
if not is_supported_ft() then return end
160+
if util.is_markdown_file() then return end
161+
local current = nodes.get_current()
162+
current = strategies.get_highest_string_node(current) or current
163+
current = nodes.get_highest_coincident(current)
164164

165-
local target = nodes.prev_sib(current)
165+
local target = nodes.prev_sib(current)
166166

167-
if not target then
168-
M.reorder(current, nodes.next_sib)
169-
end
167+
if not target then
168+
M.reorder(current, nodes.next_sib)
169+
end
170170

171-
if not current or not target then return end
171+
if not current or not target then return end
172172

173-
operations.swap_nodes(target, current)
173+
operations.swap_nodes(target, current)
174174

175-
-- Place cursor
176-
vim.fn.cursor(
177-
nodes.get_srow(target),
178-
nodes.get_scol(target)
179-
)
175+
-- Place cursor
176+
vim.fn.cursor(
177+
nodes.get_srow(target),
178+
nodes.get_scol(target)
179+
)
180180
end
181181

182182
---@param node TSNode
183183
---@param fn function
184184
function M.reorder(node, fn)
185-
if not node then return nil end
186-
---@param iter TSNode
187-
local iter = node
188-
while fn(iter) do
189-
operations.swap_nodes(fn(iter), iter)
190-
iter = fn(iter)
191-
end
192-
-- place cursor on iter
193-
vim.fn.cursor(
194-
nodes.get_srow(iter),
195-
nodes.get_scol(iter)
196-
)
185+
if not node then return end
186+
if not fn then return end
187+
188+
---@param iter TSNode
189+
local iter = fn(node)
190+
while iter do
191+
operations.swap_nodes(fn(iter), iter)
192+
iter = fn(iter)
193+
end
194+
195+
-- We don't know which way we're swapping
196+
local lastnode = nodes.prev_sib(iter) or nodes.next_sib(iter)
197+
vim.fn.cursor(
198+
nodes.get_srow(lastnode),
199+
nodes.get_scol(lastnode)
200+
)
197201
end
198202

199203
return M

0 commit comments

Comments
 (0)