-
-
Notifications
You must be signed in to change notification settings - Fork 190
/
Copy pathadapters.lua
116 lines (103 loc) · 2.78 KB
/
adapters.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
local Path = require("plenary.path")
local M = {}
---Write a new expiry timestamp to the cache file, marking when the model cache
---should next be checked.
---@param file string Path of the file that stores the expiry timestamp
---@param cache_for? number Seconds from now until expiry (defaults to 1800)
---@return number expires_at The Unix timestamp that was written to `file`
function M.refresh_cache(file, cache_for)
  local expires_at = os.time() + (cache_for or 1800)
  -- "w" truncates the file so only the newest timestamp is kept
  Path.new(file):write(expires_at, "w")
  return expires_at
end
---Return when the model cache expires.
---If the cache file cannot be read, or its contents do not parse as a number
---(e.g. an empty or corrupted file), a fresh expiry is written via
---`M.refresh_cache` instead of failing on every subsequent call.
---@param file string Path of the file that stores the expiry timestamp
---@param cache_for? number Seconds until expiry used when the file has to be (re)written (defaults to 1800)
---@return number expires Unix timestamp at which the cache expires
function M.cache_expires(file, cache_for)
  cache_for = cache_for or 1800
  -- Reading a missing file raises an error, so treat any failure as "no cache yet"
  local ok, contents = pcall(function()
    return Path.new(file):read()
  end)
  local expires = ok and tonumber(contents) or nil
  if not expires then
    -- Missing, unreadable, empty or non-numeric contents: repair the file so the
    -- next call can read a valid timestamp.
    expires = tonumber(M.refresh_cache(file, cache_for))
  end
  assert(expires, "Could not get the cache expiry time")
  return expires
end
---Check if the cache has expired
---@param file string Path of the file that stores the expiry timestamp
---@param cache_for? number Seconds until expiry used if the cache file has to be
---(re)created; forwarded to `M.cache_expires` (which defaults to 1800)
---@return boolean expired True when the current time is past the stored expiry
function M.cache_expired(file, cache_for)
  return os.time() > M.cache_expires(file, cache_for)
end
---Extend the default adapters with user-supplied ones.
---For every key of `new_tbl` that already exists in `base_tbl` and whose value is
---a table, the entry in `base_tbl` is replaced by the new adapter. When the new
---adapter has a `schema`, it is deep-merged (new values win) on top of the
---default adapter's schema, so schema entries the user did not override are kept.
---Keys absent from `base_tbl` and non-table values in `new_tbl` are ignored.
---`base_tbl` is modified in place, and the new adapter's `schema` field is
---replaced by the merged result.
---@param base_tbl table
---@param new_tbl table
---@return nil
function M.extend(base_tbl, new_tbl)
  for name, adapter in pairs(new_tbl) do
    local default = base_tbl[name]
    if default and type(adapter) == "table" then
      if adapter.schema then
        -- Read the default schema BEFORE replacing the entry; merging after the
        -- assignment would merge the new schema with itself and drop the defaults.
        local base_schema = {}
        if type(default) == "table" and type(default.schema) == "table" then
          base_schema = default.schema
        end
        adapter.schema = vim.tbl_deep_extend("force", base_schema, adapter.schema)
      end
      base_tbl[name] = adapter
    end
  end
end
---Collect the positions of every message whose `role` field equals `role`.
---@param role string Role to search for
---@param messages table List of message tables
---@return table|nil indexes Ascending list of matching indexes; nothing is returned when there is no match
function M.get_msg_index(role, messages)
  local indexes = {}
  local total = #messages
  for idx = 1, total do
    if messages[idx].role == role then
      indexes[#indexes + 1] = idx
    end
  end
  -- Only hand back a table when at least one message matched
  if indexes[1] ~= nil then
    return indexes
  end
end
---Build a new list containing only the messages whose `role` matches, keeping
---their original order. The message tables themselves are not copied.
---@param messages table List of message tables
---@param role string Role to keep
---@return table plucked Matching messages (empty when none match)
function M.pluck_messages(messages, role)
  local plucked = {}
  for _, msg in ipairs(messages) do
    if msg.role == role then
      plucked[#plucked + 1] = msg
    end
  end
  return plucked
end
---Collapse runs of adjacent messages that share a role into a single message.
---Contents within a run are joined with a blank line ("\n\n"). The output holds
---fresh tables with only `role` and `content`; the input messages are not mutated.
---@param messages table List of message tables with `role` and `content` fields
---@return table merged
function M.merge_messages(messages)
  local merged = {}
  vim.iter(messages):each(function(msg)
    local prev = merged[#merged]
    if prev == nil or prev.role ~= msg.role then
      -- Start a new run with a copy so the caller's message is left untouched
      merged[#merged + 1] = { role = msg.role, content = msg.content }
    else
      prev.content = prev.content .. "\n\n" .. msg.content
    end
  end)
  return merged
end
---Strip any prefix before the JSON payload of a streamed chunk so it can be
---decoded. Streaming endpoints typically send lines such as `data: { "id": 12345}`;
---everything before the first `{` is dropped, and a chunk with no `{` is returned
---unchanged. A table argument is assumed to be a whole response object, and its
---`body` field is returned as-is.
---@param data string | { body: string }
---@return string
function M.clean_streamed_data(data)
  if type(data) ~= "table" then
    -- Plain-text search: "{" is the only thing we look for
    local json_start = string.find(data, "{", 1, true) or 1
    return string.sub(data, json_start)
  end
  return data.body
end
return M