-
-
Notifications
You must be signed in to change notification settings - Fork 191
/
Copy pathtokens.lua
85 lines (67 loc) · 2.05 KB
/
tokens.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
--Taken from https://github.com/jackMort/ChatGPT.nvim/blob/main/lua/chatgpt/flows/chat/tokens.lua
local api = vim.api
local M = {}
---Calculate an approximate number of tokens in a message.
---
---This is a whitespace-based estimate, not a model tokenizer. A token is any
---maximal run of characters containing neither a space (" ") nor a newline
---("\n"). Empty messages, and messages starting with "# ", count as 0.
---Non-string input (e.g. nil) also counts as 0 instead of raising.
---@param message string The text to calculate the number of tokens for
---@return number The number of tokens in the message
function M.calculate(message)
  -- Nothing countable in non-string content; avoid erroring on :gmatch below.
  if type(message) ~= "string" then
    return 0
  end
  -- NOTE(review): "# "-prefixed messages are deliberately not counted
  -- (presumably markdown headers) — confirm intent.
  if message == "" or message:sub(1, 2) == "# " then
    return 0
  end

  -- Count separator-delimited runs directly instead of rebuilding each token
  -- character by character (which was quadratic in token length).
  local tokens = 0
  for _ in message:gmatch("[^ \n]+") do
    tokens = tokens + 1
  end
  return tokens
end
---Sum the approximate token counts of a list of messages.
---Only each message's `content` field is counted; messages whose `content`
---is not a string (e.g. nil) contribute 0 instead of raising.
---@param messages table Array of message tables to calculate the number of tokens for.
---@return number The number of tokens in the messages.
function M.get_tokens(messages)
  local tokens = 0
  for _, message in ipairs(messages) do
    local content = message.content
    -- Skip messages with no textual content rather than erroring inside
    -- M.calculate.
    if type(content) == "string" then
      tokens = tokens + M.calculate(content)
    end
  end
  return tokens
end
---Display the token count as end-of-line virtual text on a role header.
---
---Clears every extmark in `ns_id` for the buffer. It then parses the buffer
---from `start_row - 1` onward, runs the markdown "tokens" query, and attaches
---`token_str` to the last node captured as `role`. It returns silently when the
---query or parse tree is unavailable, or when no `role` capture is found.
---@param token_str string Text to show (highlight group CodeCompanionChatTokens)
---@param ns_id number Extmark namespace; it is cleared on every call
---@param parser table Tree-sitter parser for the buffer (presumably a vim.treesitter LanguageTree)
---@param start_row number Row to start from; `start_row - 1` is passed to the parser (presumably 1-based — confirm)
---@param bufnr? number Buffer handle; defaults to the current buffer
---@return nil
function M.display(token_str, ns_id, parser, start_row, bufnr)
  bufnr = bufnr or api.nvim_get_current_buf()
  api.nvim_buf_clear_namespace(bufnr, ns_id, 0, -1)

  -- query.get returns nil when no markdown "tokens" query file is found;
  -- bail out instead of indexing nil below.
  local query = vim.treesitter.query.get("markdown", "tokens")
  if not query then
    return
  end

  local tree = parser:parse({ start_row - 1, -1 })[1]
  if not tree then
    return
  end
  local root = tree:root()

  -- Keep the last "role" capture seen at or after the start row.
  local header
  for id, node in query:iter_captures(root, bufnr, start_row - 1, -1) do
    if query.captures[id] == "role" then
      header = node
    end
  end

  if header then
    local _, _, end_row, _ = header:range()
    -- NOTE(review): places the mark on end_row - 1, which assumes the captured
    -- node's (0-based) end row is one past the header line — confirm.
    api.nvim_buf_set_extmark(bufnr, ns_id, end_row - 1, 0, {
      virt_text = { { token_str, "CodeCompanionChatTokens" } },
      virt_text_pos = "eol",
      priority = 110,
      hl_mode = "combine",
    })
  end
end
return M