Skip to content

Commit

Permalink
Merge pull request #2532 from fesily/automatic-infer-function-param-type
Browse files Browse the repository at this point in the history
add infer function param type
  • Loading branch information
sumneko authored Feb 26, 2024
2 parents f388b95 + 87c83c3 commit 73be83c
Show file tree
Hide file tree
Showing 27 changed files with 179 additions and 94 deletions.
6 changes: 6 additions & 0 deletions locale/en-us/setting.lua
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,12 @@ When checking the type of union type, ignore the `nil` in it.
When this setting is `false`, the `number|nil` type cannot be assigned to the `number` type. It can be with `true`.
]]
config.type.inferParamType =
[[
When a parameter type is not annotated, it is inferred from the function's call sites.
When this setting is `false`, the type of the parameter is `any` when it is not annotated.
]]
config.doc.privateName =
'Treat specific field names as private, e.g. `m_*` means `XXX.m_id` and `XXX.m_type` are private, witch can only be accessed in the class where the definition is located.'
config.doc.protectedName =
Expand Down
6 changes: 6 additions & 0 deletions locale/pt-br/setting.lua
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,12 @@ When checking the type of union type, ignore the `nil` in it.
When this setting is `false`, the `number|nil` type cannot be assigned to the `number` type. It can be with `true`.
]]
config.type.inferParamType = -- TODO: need translate!
[[
When the parameter type is not annotated, the parameter type is inferred from the function's incoming parameters.
When this setting is `false`, the type of the parameter is `any` when it is not annotated.
]]
config.doc.privateName = -- TODO: need translate!
'Treat specific field names as private, e.g. `m_*` means `XXX.m_id` and `XXX.m_type` are private, witch can only be accessed in the class where the definition is located.'
config.doc.protectedName = -- TODO: need translate!
Expand Down
6 changes: 6 additions & 0 deletions locale/zh-cn/setting.lua
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,12 @@ config.type.weakNilCheck =
此设置为 `false` 时,`numer|nil` 类型无法赋给 `number` 类型;为 `true` 是则可以。
]]
config.type.inferParamType =
[[
未注释参数类型时,参数类型由函数传入参数推断。
如果设置为 "false",则在未注释时,参数类型为 "any"。
]]
config.doc.privateName =
'将特定名称的字段视为私有,例如 `m_*` 意味着 `XXX.m_id` 与 `XXX.m_type` 是私有字段,只能在定义所在的类中访问。'
config.doc.protectedName =
Expand Down
6 changes: 6 additions & 0 deletions locale/zh-tw/setting.lua
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,12 @@ When checking the type of union type, ignore the `nil` in it.
When this setting is `false`, the `number|nil` type cannot be assigned to the `number` type. It can be with `true`.
]]
config.type.inferParamType = -- TODO: need translate!
[[
未注释参数类型时,参数类型由函数传入参数推断。
如果设置为 "false",则在未注释时,参数类型为 "any"。
]]
config.doc.privateName = -- TODO: need translate!
'Treat specific field names as private, e.g. `m_*` means `XXX.m_id` and `XXX.m_type` are private, witch can only be accessed in the class where the definition is located.'
config.doc.protectedName = -- TODO: need translate!
Expand Down
4 changes: 2 additions & 2 deletions script/client.lua
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ local function searchPatchInfo(cfg, rawKey)
}
end

---@param uri uri
---@param uri? uri
---@param cfg table
---@param change config.change
---@return json.patch?
Expand Down Expand Up @@ -330,7 +330,7 @@ local function makeConfigPatch(uri, cfg, change)
return nil
end

---@param uri uri
---@param uri? uri
---@param path string
---@param changes config.change[]
---@return string?
Expand Down
1 change: 1 addition & 0 deletions script/config/template.lua
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ local template = {
['Lua.type.castNumberToInteger'] = Type.Boolean >> true,
['Lua.type.weakUnionCheck'] = Type.Boolean >> false,
['Lua.type.weakNilCheck'] = Type.Boolean >> false,
['Lua.type.inferParamType'] = Type.Boolean >> false,
['Lua.doc.privateName'] = Type.Array(Type.String),
['Lua.doc.protectedName'] = Type.Array(Type.String),
['Lua.doc.packageName'] = Type.Array(Type.String),
Expand Down
5 changes: 4 additions & 1 deletion script/core/command/autoRequire.lua
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ end

---@async
return function (data)
---@type uri
local uri = data.uri
local target = data.target
local name = data.name
Expand All @@ -158,5 +159,7 @@ return function (data)
end

local offset, fmt = findInsertRow(uri)
applyAutoRequire(uri, offset, name, requireName, fmt)
if offset and fmt then
applyAutoRequire(uri, offset, name, requireName, fmt)
end
end
52 changes: 37 additions & 15 deletions script/core/completion/completion.lua
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ end

local function findParent(state, position)
local text = state.lua
if not text then
return
end
local offset = guide.positionToOffset(state, position)
for i = offset, 1, -1 do
local char = text:sub(i, i)
Expand Down Expand Up @@ -675,6 +678,7 @@ local function checkGlobal(state, word, startPos, position, parent, oop, results
end

---@async
---@param parent parser.object
local function checkField(state, word, start, position, parent, oop, results)
if parent.tag == '_ENV' or parent.special == '_G' then
local globals = vm.getGlobalSets(state.uri, 'variable')
Expand Down Expand Up @@ -955,8 +959,7 @@ local function checkFunctionArgByDocParam(state, word, startPos, results)
end
end

local function isAfterLocal(state, startPos)
local text = state.lua
local function isAfterLocal(state, text, startPos)
local offset = guide.positionToOffset(state, startPos)
local pos = lookBackward.skipSpace(text, offset)
local word = lookBackward.findWord(text, pos)
Expand All @@ -965,6 +968,8 @@ end

local function collectRequireNames(mode, myUri, literal, source, smark, position, results)
local collect = {}
local source_start = source and smark and (source.start + #smark) or position
local source_finish = source and smark and (source.finish - #smark) or position
if mode == 'require' then
for uri in files.eachFile(myUri) do
if myUri == uri then
Expand All @@ -978,8 +983,8 @@ local function collectRequireNames(mode, myUri, literal, source, smark, position
if not collect[info.name] then
collect[info.name] = {
textEdit = {
start = smark and (source.start + #smark) or position,
finish = smark and (source.finish - #smark) or position,
start = source_start,
finish = source_finish,
newText = smark and info.name or util.viewString(info.name),
},
path = relative,
Expand All @@ -1006,8 +1011,8 @@ local function collectRequireNames(mode, myUri, literal, source, smark, position
if not collect[open] then
collect[open] = {
textEdit = {
start = smark and (source.start + #smark) or position,
finish = smark and (source.finish - #smark) or position,
start = source_start,
finish = source_finish,
newText = smark and open or util.viewString(open),
},
path = path,
Expand All @@ -1034,8 +1039,8 @@ local function collectRequireNames(mode, myUri, literal, source, smark, position
if not collect[path] then
collect[path] = {
textEdit = {
start = smark and (source.start + #smark) or position,
finish = smark and (source.finish - #smark) or position,
start = source_start,
finish = source_finish,
newText = smark and path or util.viewString(path),
}
}
Expand Down Expand Up @@ -1097,6 +1102,9 @@ end

local function checkLenPlusOne(state, position, results)
local text = state.lua
if not text then
return
end
guide.eachSourceContain(state.ast, position, function (source)
if source.type == 'getindex'
or source.type == 'setindex' then
Expand Down Expand Up @@ -1392,6 +1400,9 @@ end

local function checkEqualEnum(state, position, results)
local text = state.lua
if not text then
return
end
local start = lookBackward.findTargetSymbol(text, guide.positionToOffset(state, position), '=')
if not start then
return
Expand Down Expand Up @@ -1493,6 +1504,9 @@ local function tryWord(state, position, triggerCharacter, results)
return
end
local text = state.lua
if not text then
return
end
local offset = guide.positionToOffset(state, position)
local finish = lookBackward.skipSpace(text, offset)
local word, start = lookBackward.findWord(text, offset)
Expand All @@ -1518,7 +1532,7 @@ local function tryWord(state, position, triggerCharacter, results)
checkProvideLocal(state, word, startPos, results)
checkFunctionArgByDocParam(state, word, startPos, results)
else
local afterLocal = isAfterLocal(state, startPos)
local afterLocal = isAfterLocal(state, text, startPos)
local stop = checkKeyWord(state, startPos, position, word, hasSpace, afterLocal, results)
if stop then
return
Expand All @@ -1530,8 +1544,10 @@ local function tryWord(state, position, triggerCharacter, results)
checkLocal(state, word, startPos, results)
checkTableField(state, word, startPos, results)
local env = guide.getENV(state.ast, startPos)
checkGlobal(state, word, startPos, position, env, false, results)
checkModule(state, word, startPos, results)
if env then
checkGlobal(state, word, startPos, position, env, false, results)
checkModule(state, word, startPos, results)
end
end
end
end
Expand Down Expand Up @@ -1592,6 +1608,9 @@ end

local function checkTableLiteralField(state, position, tbl, fields, results)
local text = state.lua
if not text then
return
end
local mark = {}
for _, field in ipairs(tbl) do
if field.type == 'tablefield'
Expand All @@ -1610,9 +1629,11 @@ local function checkTableLiteralField(state, position, tbl, fields, results)
local left = lookBackward.findWord(text, guide.positionToOffset(state, position))
if not left then
local pos = lookBackward.findAnyOffset(text, guide.positionToOffset(state, position))
local char = text:sub(pos, pos)
if char == '{' or char == ',' or char == ';' then
left = ''
if pos then
local char = text:sub(pos, pos)
if char == '{' or char == ',' or char == ';' then
left = ''
end
end
end
if left then
Expand Down Expand Up @@ -1801,6 +1822,7 @@ local function getluaDocByContain(state, position)
return result
end

---@return parser.state.err?, parser.object?
local function getluaDocByErr(state, start, position)
local targetError
for _, err in ipairs(state.errs) do
Expand Down Expand Up @@ -2008,7 +2030,7 @@ local function tryluaDocByErr(state, position, err, docState, results)
for _, doc in ipairs(vm.getDocSets(state.uri)) do
if doc.type == 'doc.class'
and not used[doc.class[1]]
and doc.class[1] ~= docState.class[1] then
and docState and doc.class[1] ~= docState.class[1] then
used[doc.class[1]] = true
results[#results+1] = {
label = doc.class[1],
Expand Down
13 changes: 1 addition & 12 deletions script/core/diagnostics/undefined-doc-name.lua
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,6 @@ return function (uri, callback)
return
end

local function hasNameOfGeneric(name, source)
if not source.typeGeneric then
return false
end
if not source.typeGeneric[name] then
return false
end
return true
end

guide.eachSource(state.ast.docs, function (source)
if source.type ~= 'doc.extends.name'
and source.type ~= 'doc.type.name' then
Expand All @@ -35,8 +25,7 @@ return function (uri, callback)
if name == '...' or name == '_' or name == 'self' then
return
end
if #vm.getDocSets(uri, name) > 0
or hasNameOfGeneric(name, source) then
if #vm.getDocSets(uri, name) > 0 then
return
end
callback {
Expand Down
4 changes: 2 additions & 2 deletions script/core/highlight.lua
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ local function checkInIf(state, source, text, position)
local endA = endB - #'end' + 1
if position >= source.finish - #'end'
and position <= source.finish
and text:sub(endA, endB) == 'end' then
and text and text:sub(endA, endB) == 'end' then
return true
end
-- 检查每个子模块
Expand All @@ -83,7 +83,7 @@ local function makeIf(state, source, text, callback)
-- end
local endB = guide.positionToOffset(state, source.finish)
local endA = endB - #'end' + 1
if text:sub(endA, endB) == 'end' then
if text and text:sub(endA, endB) == 'end' then
callback(source.finish - #'end', source.finish)
end
-- 每个子模块
Expand Down
6 changes: 6 additions & 0 deletions script/fs-utility.lua
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ function dfs:__div(filename)
return new
end

---@package
function dfs:_open(index)
local paths = split(self.path, '[/\\]')
local current = self.files
Expand All @@ -147,6 +148,7 @@ function dfs:_open(index)
return current
end

---@package
function dfs:_filename()
return self.path:match '[^/\\]+$'
end
Expand Down Expand Up @@ -291,6 +293,7 @@ local function fsIsDirectory(path, option)
if path.type == 'dummy' then
return path:isDirectory()
end
---@cast path -dummyfs
local status = fs.symlink_status(path):type()
return status == 'directory'
end
Expand Down Expand Up @@ -347,6 +350,7 @@ local function fsSave(path, text, option)
return false
end
if path.type == 'dummy' then
---@cast path -fs.path
local dir = path:_open(-2)
if not dir then
option.err[#option.err+1] = '无法打开:' .. path:string()
Expand Down Expand Up @@ -385,6 +389,7 @@ local function fsLoad(path, option)
return nil
end
else
---@cast path -dummyfs
local text, err = m.loadFile(path)
if text then
return text
Expand All @@ -407,6 +412,7 @@ local function fsCopy(source, target, option)
end
return fsSave(target, sourceText, option)
else
---@cast source -dummyfs
if target.type == 'dummy' then
local sourceText, err = m.loadFile(source)
if not sourceText then
Expand Down
3 changes: 2 additions & 1 deletion script/gc.lua
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
local util = require 'utility'

---@class gc
---@field _list table
---@field package _list table
local mt = {}
mt.__index = mt
mt.type = 'gc'
mt._removed = false

---@package
mt._max = 10

local function destroyGCObject(obj)
Expand Down
Loading

0 comments on commit 73be83c

Please sign in to comment.