Skip to content

Commit 7d3b736

Browse files
committed
use strconv.UnquoteChar
1 parent a6c7d79 commit 7d3b736

File tree

1 file changed

+56
-75
lines changed

1 file changed

+56
-75
lines changed

gopls/internal/golang/highlight.go

+56-75
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,9 @@ func highlightPath(info *types.Info, path []ast.Node, pos token.Pos) (map[posRan
8181
// Treat each corresponding ("%v", arg) pair as a highlight class.
8282
for _, node := range path {
8383
if call, ok := node.(*ast.CallExpr); ok {
84-
idx := formatStringIndex(info, call)
85-
if idx >= 0 && idx < len(call.Args) {
86-
// We only care about literal format strings, so fmt.Sprint("a"+"b%s", "bar") won't be highlighted.
87-
if lit, ok := call.Args[idx].(*ast.BasicLit); ok && lit.Kind == token.STRING {
88-
highlightPrintf(call, idx, pos, lit, result)
89-
}
84+
idx, lit := formatLitAndIndex(info, call)
85+
if idx != -1 {
86+
highlightPrintf(call, idx, pos, lit, result)
9087
}
9188
}
9289
}
@@ -152,29 +149,34 @@ func highlightPath(info *types.Info, path []ast.Node, pos token.Pos) (map[posRan
152149
return result, nil
153150
}
154151

// formatLitAndIndex returns the index of the format string argument (the
// one corresponding to the last non-variadic parameter) within the given
// printf-like call expression, together with the string literal found at
// that index.
//
// It returns (-1, nil) if the callee's type is missing or is not a
// signature, if the signature is not variadic or has no fixed
// parameters, if the call supplies too few arguments to reach the format
// parameter, or if the argument at that index is not a string literal.
func formatLitAndIndex(info *types.Info, call *ast.CallExpr) (int, *ast.BasicLit) {
	typ := info.Types[call.Fun].Type
	if typ == nil {
		return -1, nil // missing type
	}
	sig, ok := typ.(*types.Signature)
	if !ok {
		return -1, nil // ill-typed
	}
	if !sig.Variadic() {
		// Skip checking non-variadic functions.
		return -1, nil
	}
	idx := sig.Params().Len() - 2
	if idx < 0 || idx >= len(call.Args) {
		// Skip variadic functions without fixed arguments, and
		// calls that do not supply the format argument at all
		// (indexing call.Args would otherwise panic).
		return -1, nil
	}

	// We only care about literal format strings, so
	// fmt.Sprintf("a"+"b%s", "bar") won't be highlighted.
	lit, ok := call.Args[idx].(*ast.BasicLit)
	if !ok || lit.Kind != token.STRING {
		return -1, nil
	}
	return idx, lit
}
179181

180182
// highlightPrintf highlights operations in a format string and their corresponding
@@ -207,19 +209,27 @@ func highlightPrintf(call *ast.CallExpr, idx int, cursorPos token.Pos, lit *ast.
207209
succeededArg := 0
208210
visited := make(map[posRange]int, 0)
209211

210-
formatPos := call.Args[idx].Pos()
211212
// highlightPair highlights the operation and its potential argument pair if the cursor is within either range.
212213
highlightPair := func(rang fmtstr.Range, argIndex int) {
213-
var (
214-
rangeStart = formatPos + token.Pos(offsetInStringLiteral(lit.Value, format, rang.Start))
215-
rangeEnd = formatPos + token.Pos(offsetInStringLiteral(lit.Value, format, rang.End-1)+1)
216-
arg ast.Expr // may not exist
217-
)
214+
rangeStart, err := posInStringLiteral(lit, rang.Start)
215+
if err != nil {
216+
return
217+
}
218+
rangeEnd, err := posInStringLiteral(lit, rang.End)
219+
if err != nil {
220+
return
221+
}
218222
visited[posRange{rangeStart, rangeEnd}] = argIndex
223+
224+
var arg ast.Expr
219225
if argIndex < len(call.Args) {
220226
arg = call.Args[argIndex]
221227
}
222-
if rangeStart <= cursorPos && cursorPos < rangeEnd || arg != nil && arg.Pos() <= cursorPos && cursorPos < arg.End() {
228+
229+
// cursorPos must be strictly before the end position; otherwise two
230+
// adjacent operations, such as those in (%[2]*d), would both be highlighted when the cursor is on "*" (the end of [2]*).
231+
if rangeStart <= cursorPos && cursorPos < rangeEnd ||
232+
arg != nil && arg.Pos() <= cursorPos && cursorPos < arg.End() {
223233
highlightRange(result, rangeStart, rangeEnd, protocol.Write)
224234
if arg != nil {
225235
succeededArg = argIndex
@@ -268,65 +278,36 @@ func highlightPrintf(call *ast.CallExpr, idx int, cursorPos token.Pos, lit *ast.
268278
}
269279
}
270280

// posInStringLiteral returns the position within a string literal
// corresponding to the specified byte offset within the logical
// string that it denotes.
//
// It returns an error if the literal cannot be unquoted or if offset
// is outside the range [0, len(value)] of the unquoted value.
func posInStringLiteral(lit *ast.BasicLit, offset int) (token.Pos, error) {
	raw := lit.Value

	value, err := strconv.Unquote(raw)
	if err != nil {
		return 0, err
	}
	if !(0 <= offset && offset <= len(value)) {
		return 0, fmt.Errorf("invalid offset")
	}

	// remove quotes
	quote := raw[0] // '"' or '`'
	raw = raw[1 : len(raw)-1]

	var (
		i   = 0                // byte index within logical value
		pos = lit.ValuePos + 1 // position within literal
	)

	if quote == '`' {
		// Raw string literals have no escape sequences: each raw
		// byte denotes itself, except that strconv.Unquote drops
		// carriage returns, so they occupy no logical byte.
		for raw != "" && i < offset {
			if raw[0] != '\r' {
				i++
			}
			pos++
			raw = raw[1:]
		}
		return pos, nil
	}

	for raw != "" && i < offset {
		r, multibyte, rest, err := strconv.UnquoteChar(raw, quote)
		if err != nil {
			// Not expected, since Unquote succeeded above.
			return 0, err
		}
		pos += token.Pos(len(raw) - len(rest)) // length of literal char in raw bytes
		raw = rest
		if multibyte {
			i += utf8.RuneLen(r)
		} else {
			// ASCII characters and byte escapes such as \xff or
			// \377 each denote exactly one logical byte, even
			// when the value is >= 0x80.
			i++
		}
	}
	return pos, nil
}
331312

332313
type posRange struct {

0 commit comments

Comments
 (0)