Skip to content

Commit f8fc938

Browse files
committed
use strconv.UnquoteChar
1 parent a6c7d79 commit f8fc938

File tree

1 file changed

+54
-75
lines changed

1 file changed

+54
-75
lines changed

gopls/internal/golang/highlight.go

+54-75
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,9 @@ func highlightPath(info *types.Info, path []ast.Node, pos token.Pos) (map[posRan
8181
// Treat each corresponding ("%v", arg) pair as a highlight class.
8282
for _, node := range path {
8383
if call, ok := node.(*ast.CallExpr); ok {
84-
idx := formatStringIndex(info, call)
85-
if idx >= 0 && idx < len(call.Args) {
86-
// We only care about literal format strings, so fmt.Sprint("a"+"b%s", "bar") won't be highlighted.
87-
if lit, ok := call.Args[idx].(*ast.BasicLit); ok && lit.Kind == token.STRING {
88-
highlightPrintf(call, idx, pos, lit, result)
89-
}
84+
idx, lit := formatLitAndIndex(info, call)
85+
if idx != -1 {
86+
highlightPrintf(call, idx, pos, lit, result)
9087
}
9188
}
9289
}
@@ -152,29 +149,32 @@ func highlightPath(info *types.Info, path []ast.Node, pos token.Pos) (map[posRan
152149
return result, nil
153150
}
154151

155-
// formatStringIndex returns the index of the format string (the last
152+
// formatLitAndIndex returns the BasicLit and index of the BasicLit (the last
156153
// non-variadic parameter) within the given printf-like call
157-
// expression, or -1 if unknown.
158-
func formatStringIndex(info *types.Info, call *ast.CallExpr) int {
154+
// expression, returns -1 as index if unknown.
155+
func formatLitAndIndex(info *types.Info, call *ast.CallExpr) (int, *ast.BasicLit) {
159156
typ := info.Types[call.Fun].Type
160157
if typ == nil {
161-
return -1 // missing type
158+
return -1, nil // missing type
162159
}
163160
sig, ok := typ.(*types.Signature)
164161
if !ok {
165-
return -1 // ill-typed
162+
return -1, nil // ill-typed
166163
}
167164
if !sig.Variadic() {
168165
// Skip checking non-variadic functions.
169-
return -1
166+
return -1, nil
170167
}
171168
idx := sig.Params().Len() - 2
172169
if idx < 0 {
173170
// Skip checking variadic functions without
174171
// fixed arguments.
175-
return -1
172+
return -1, nil
173+
}
174+
if lit, ok := call.Args[idx].(*ast.BasicLit); ok && lit.Kind == token.STRING {
175+
return idx, lit
176176
}
177-
return idx
177+
return -1, nil
178178
}
179179

180180
// highlightPrintf highlights operations in a format string and their corresponding
@@ -207,19 +207,27 @@ func highlightPrintf(call *ast.CallExpr, idx int, cursorPos token.Pos, lit *ast.
207207
succeededArg := 0
208208
visited := make(map[posRange]int, 0)
209209

210-
formatPos := call.Args[idx].Pos()
211210
// highlightPair highlights the operation and its potential argument pair if the cursor is within either range.
212211
highlightPair := func(rang fmtstr.Range, argIndex int) {
213-
var (
214-
rangeStart = formatPos + token.Pos(offsetInStringLiteral(lit.Value, format, rang.Start))
215-
rangeEnd = formatPos + token.Pos(offsetInStringLiteral(lit.Value, format, rang.End-1)+1)
216-
arg ast.Expr // may not exist
217-
)
212+
rangeStart, err := posInStringLiteral(lit, rang.Start)
213+
if err != nil {
214+
return
215+
}
216+
rangeEnd, err := posInStringLiteral(lit, rang.End)
217+
if err != nil {
218+
return
219+
}
218220
visited[posRange{rangeStart, rangeEnd}] = argIndex
221+
222+
var arg ast.Expr
219223
if argIndex < len(call.Args) {
220224
arg = call.Args[argIndex]
221225
}
222-
if rangeStart <= cursorPos && cursorPos < rangeEnd || arg != nil && arg.Pos() <= cursorPos && cursorPos < arg.End() {
226+
227+
// cursorPos can't equal to end position, otherwise the two
228+
// neighborhood such as (%[2]*d) are both highlighted if cursor in "*" (ending of [2]*).
229+
if rangeStart <= cursorPos && cursorPos < rangeEnd ||
230+
arg != nil && arg.Pos() <= cursorPos && cursorPos < arg.End() {
223231
highlightRange(result, rangeStart, rangeEnd, protocol.Write)
224232
if arg != nil {
225233
succeededArg = argIndex
@@ -268,65 +276,36 @@ func highlightPrintf(call *ast.CallExpr, idx int, cursorPos token.Pos, lit *ast.
268276
}
269277
}
270278

271-
// offsetInStringLiteral maps an offset in the unquoted string to
272-
// relative to the literal string.
273-
func offsetInStringLiteral(literal string, unquoted string, logicalOffset int) int {
274-
literalIdx := 1 // Skip the initial quote char.
275-
logIdx := 0
276-
277-
// Advance by one unquoted rune and the corresponding literal string.
278-
advanceRune := func() {
279-
r, size := utf8.DecodeRuneInString(unquoted[logIdx:])
280-
if r == utf8.RuneError && size <= 1 {
281-
// Malformed UTF-8 or end of string,
282-
// move one byte in both strings to avoid infinite loops.
283-
logIdx++
284-
literalIdx++
285-
return
286-
}
287-
logIdx += size
279+
// posInStringLiteral returns the position within a string literal
280+
// corresponding to the specified byte offset within the logical
281+
// string that it denotes.
282+
func posInStringLiteral(lit *ast.BasicLit, offset int) (token.Pos, error) {
283+
raw := lit.Value
288284

289-
if literalIdx >= len(literal)-1 {
290-
return
291-
}
292-
293-
if literal[literalIdx] == '\\' {
294-
remain := literal[literalIdx:]
295-
escLen := 0
296-
if len(remain) < 2 {
297-
escLen = 1 // just the '\'
298-
}
299-
switch remain[1] {
300-
case 'x':
301-
escLen = 4
302-
case 'u':
303-
escLen = 6
304-
case 'U':
305-
escLen = 10
306-
case 'a', 'b', 'f', 'n', 'r', 't', 'v', '\\', '"':
307-
escLen = 2
308-
case '0', '1', '2', '3', '4', '5', '6', '7':
309-
escLen = 4
310-
default:
311-
return
312-
}
313-
literalIdx += escLen
314-
} else {
315-
// non-escaped character
316-
literalIdx++
317-
}
285+
value, err := strconv.Unquote(raw)
286+
if err != nil {
287+
return 0, err
318288
}
319-
320-
for logIdx < len(unquoted) && (logIdx < logicalOffset) && literalIdx < len(literal)-1 {
321-
advanceRune()
289+
if !(0 <= offset && offset <= len(value)) {
290+
return 0, fmt.Errorf("invalid offset")
322291
}
323292

324-
// Clamp it to ensure we don't exceed array bounds.
325-
if literalIdx >= len(literal)-1 {
326-
literalIdx = len(literal) - 1
327-
}
293+
// remove quotes
294+
quote := raw[0] // '"' or '`'
295+
raw = raw[1 : len(raw)-1]
328296

329-
return literalIdx
297+
var (
298+
i = 0 // byte index within logical value
299+
pos = lit.ValuePos + 1 // position within literal
300+
)
301+
for raw != "" && i < offset {
302+
r, _, rest, _ := strconv.UnquoteChar(raw, quote) // can't fail
303+
sz := len(raw) - len(rest) // length of literal char in raw bytes
304+
pos += token.Pos(sz)
305+
raw = raw[sz:]
306+
i += utf8.RuneLen(r)
307+
}
308+
return pos, nil
330309
}
331310

332311
type posRange struct {

0 commit comments

Comments
 (0)