Skip to content

Commit e1ba409

Browse files
committed
use strconv.UnquoteChar
1 parent a6c7d79 commit e1ba409

File tree

1 file changed

+55
-75
lines changed

1 file changed

+55
-75
lines changed

gopls/internal/golang/highlight.go

+55-75
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,9 @@ func highlightPath(info *types.Info, path []ast.Node, pos token.Pos) (map[posRan
8181
// Treat each corresponding ("%v", arg) pair as a highlight class.
8282
for _, node := range path {
8383
if call, ok := node.(*ast.CallExpr); ok {
84-
idx := formatStringIndex(info, call)
85-
if idx >= 0 && idx < len(call.Args) {
86-
// We only care about literal format strings, so fmt.Sprint("a"+"b%s", "bar") won't be highlighted.
87-
if lit, ok := call.Args[idx].(*ast.BasicLit); ok && lit.Kind == token.STRING {
88-
highlightPrintf(call, idx, pos, lit, result)
89-
}
84+
idx, lit := formatLitAndIndex(info, call)
85+
if idx != -1 {
86+
highlightPrintf(call, idx, pos, lit, result)
9087
}
9188
}
9289
}
@@ -152,29 +149,33 @@ func highlightPath(info *types.Info, path []ast.Node, pos token.Pos) (map[posRan
152149
return result, nil
153150
}
154151

// formatLitAndIndex returns the index of the format string (the last
// non-variadic parameter) within the given printf-like call
// expression, together with the string literal found at that index.
//
// It returns (-1, nil) if the format string cannot be determined:
// the callee's type is missing or is not a function signature, the
// function is not variadic or has no fixed parameters, the call
// supplies too few arguments, or the format argument is not a plain
// string literal.
func formatLitAndIndex(info *types.Info, call *ast.CallExpr) (int, *ast.BasicLit) {
	typ := info.Types[call.Fun].Type
	if typ == nil {
		return -1, nil // missing type
	}
	sig, ok := typ.(*types.Signature)
	if !ok {
		return -1, nil // ill-typed
	}
	if !sig.Variadic() {
		// Skip checking non-variadic functions.
		return -1, nil
	}
	idx := sig.Params().Len() - 2
	if idx < 0 || idx >= len(call.Args) {
		// Skip variadic functions without fixed arguments, and
		// ill-typed calls that omit the format argument: indexing
		// call.Args below would otherwise panic.
		return -1, nil
	}
	// We only care about literal format strings, so fmt.Sprint("a"+"b%s", "bar") won't be highlighted.
	if lit, ok := call.Args[idx].(*ast.BasicLit); ok && lit.Kind == token.STRING {
		return idx, lit
	}
	return -1, nil
}
179180

180181
// highlightPrintf highlights operations in a format string and their corresponding
@@ -207,19 +208,27 @@ func highlightPrintf(call *ast.CallExpr, idx int, cursorPos token.Pos, lit *ast.
207208
succeededArg := 0
208209
visited := make(map[posRange]int, 0)
209210

210-
formatPos := call.Args[idx].Pos()
211211
// highlightPair highlights the operation and its potential argument pair if the cursor is within either range.
212212
highlightPair := func(rang fmtstr.Range, argIndex int) {
213-
var (
214-
rangeStart = formatPos + token.Pos(offsetInStringLiteral(lit.Value, format, rang.Start))
215-
rangeEnd = formatPos + token.Pos(offsetInStringLiteral(lit.Value, format, rang.End-1)+1)
216-
arg ast.Expr // may not exist
217-
)
213+
rangeStart, err := posInStringLiteral(lit, rang.Start)
214+
if err != nil {
215+
return
216+
}
217+
rangeEnd, err := posInStringLiteral(lit, rang.End)
218+
if err != nil {
219+
return
220+
}
218221
visited[posRange{rangeStart, rangeEnd}] = argIndex
222+
223+
var arg ast.Expr
219224
if argIndex < len(call.Args) {
220225
arg = call.Args[argIndex]
221226
}
222-
if rangeStart <= cursorPos && cursorPos < rangeEnd || arg != nil && arg.Pos() <= cursorPos && cursorPos < arg.End() {
227+
228+
// cursorPos must be strictly less than the end position; otherwise two
229+
// adjacent operations, such as those in (%[2]*d), would both be highlighted when the cursor is on "*" (the end of [2]*).
230+
if rangeStart <= cursorPos && cursorPos < rangeEnd ||
231+
arg != nil && arg.Pos() <= cursorPos && cursorPos < arg.End() {
223232
highlightRange(result, rangeStart, rangeEnd, protocol.Write)
224233
if arg != nil {
225234
succeededArg = argIndex
@@ -268,65 +277,36 @@ func highlightPrintf(call *ast.CallExpr, idx int, cursorPos token.Pos, lit *ast.
268277
}
269278
}
270279

// posInStringLiteral returns the position within a string literal
// corresponding to the specified byte offset within the logical
// string that it denotes.
//
// lit.Value must be a validly quoted Go literal, either interpreted
// ("...") or raw (`...`); otherwise the error from strconv.Unquote is
// returned. An offset outside [0, len(value)] yields an error. An
// offset that falls inside a multi-byte character maps to the position
// just past that character's source text.
func posInStringLiteral(lit *ast.BasicLit, offset int) (token.Pos, error) {
	raw := lit.Value

	value, err := strconv.Unquote(raw)
	if err != nil {
		return 0, err
	}
	if !(0 <= offset && offset <= len(value)) {
		return 0, fmt.Errorf("invalid offset")
	}

	// remove quotes
	quote := raw[0] // '"' or '`'
	raw = raw[1 : len(raw)-1]

	var (
		i   = 0                // byte index within logical value
		pos = lit.ValuePos + 1 // position within literal
	)
	for raw != "" && i < offset {
		var (
			sz int // length of this character in the literal's source bytes
			n  int // number of bytes it contributes to the logical value
		)
		if quote == '`' {
			// Raw strings have no escapes, so a backslash is an
			// ordinary byte and must not go through UnquoteChar.
			_, sz = utf8.DecodeRuneInString(raw)
			n = sz
			if raw[0] == '\r' {
				n = 0 // strconv.Unquote drops carriage returns from raw strings
			}
		} else {
			r, multibyte, rest, err := strconv.UnquoteChar(raw, quote)
			if err != nil {
				return 0, err // unreachable: Unquote succeeded above
			}
			sz = len(raw) - len(rest)
			// Byte escapes such as \xff or \377 denote a single
			// byte even though the returned rune value is >= 0x80.
			n = 1
			if multibyte {
				n = utf8.RuneLen(r)
			}
		}
		pos += token.Pos(sz)
		raw = raw[sz:]
		i += n
	}
	return pos, nil
}
331311

332312
type posRange struct {

0 commit comments

Comments
 (0)