-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
gopls/internal/analysis/hostport: report net.Dial("%s:%d") addresses
This change defines an analyzer that reports calls to net.Dial, net.DialTimeout, or net.Dialer.Dial with an address produced by a direct call to fmt.Sprintf, or via an intermediate local variable declared using the form: addr := fmt.Sprintf("%s:%d", host, port) ... net.Dial("tcp", addr) In other words, it uses the more precise approach suggested in dominikh/go-tools#358, not the blunter instrument of golang/go#28308. Formatting addresses this way doesn't work with IPv6. The diagnostic carries a fix to use net.JoinHostPort instead. The analyzer turns up a fairly small number of diagnostics across the corpus; however it is precise and cheap to run (since it requires a direct import of net). + test, relnote, doc We plan to add this to cmd/vet after go1.24 is released. Updates golang/go#28308 Updates dominikh/go-tools#358 Change-Id: I72e27253b75ed4702762a65c1b069e7920103bb7 Reviewed-on: https://go-review.googlesource.com/c/tools/+/554495 Auto-Submit: Alan Donovan <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Robert Findley <[email protected]>
- Loading branch information
Showing
9 changed files
with
350 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
// Copyright 2024 The Go Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
// Package hostport defines an analyzer for calls to net.Dial with | ||
// addresses of the form "%s:%d" or "%s:%s", which work only with IPv4. | ||
package hostport | ||
|
||
import ( | ||
"fmt" | ||
"go/ast" | ||
"go/constant" | ||
"go/types" | ||
|
||
"golang.org/x/tools/go/analysis" | ||
"golang.org/x/tools/go/analysis/passes/inspect" | ||
"golang.org/x/tools/go/ast/inspector" | ||
"golang.org/x/tools/go/types/typeutil" | ||
"golang.org/x/tools/gopls/internal/util/safetoken" | ||
"golang.org/x/tools/internal/analysisinternal" | ||
"golang.org/x/tools/internal/astutil/cursor" | ||
) | ||
|
||
const Doc = `check format of addresses passed to net.Dial | ||
This analyzer flags code that produce network address strings using | ||
fmt.Sprintf, as in this example: | ||
addr := fmt.Sprintf("%s:%d", host, 12345) // "will not work with IPv6" | ||
... | ||
conn, err := net.Dial("tcp", addr) // "when passed to dial here" | ||
The analyzer suggests a fix to use the correct approach, a call to | ||
net.JoinHostPort: | ||
addr := net.JoinHostPort(host, "12345") | ||
... | ||
conn, err := net.Dial("tcp", addr) | ||
A similar diagnostic and fix are produced for a format string of "%s:%s". | ||
` | ||
|
||
var Analyzer = &analysis.Analyzer{ | ||
Name: "hostport", | ||
Doc: Doc, | ||
URL: "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/hostport", | ||
Requires: []*analysis.Analyzer{inspect.Analyzer}, | ||
Run: run, | ||
} | ||
|
||
func run(pass *analysis.Pass) (any, error) { | ||
// Fast path: if the package doesn't import net and fmt, skip | ||
// the traversal. | ||
if !analysisinternal.Imports(pass.Pkg, "net") || | ||
!analysisinternal.Imports(pass.Pkg, "fmt") { | ||
return nil, nil | ||
} | ||
|
||
info := pass.TypesInfo | ||
|
||
// checkAddr reports a diagnostic (and returns true) if e | ||
// is a call of the form fmt.Sprintf("%d:%d", ...). | ||
// The diagnostic includes a fix. | ||
// | ||
// dialCall is non-nil if the Dial call is non-local | ||
// but within the same file. | ||
checkAddr := func(e ast.Expr, dialCall *ast.CallExpr) { | ||
if call, ok := e.(*ast.CallExpr); ok { | ||
obj := typeutil.Callee(info, call) | ||
if analysisinternal.IsFunctionNamed(obj, "fmt", "Sprintf") { | ||
// Examine format string. | ||
formatArg := call.Args[0] | ||
if tv := info.Types[formatArg]; tv.Value != nil { | ||
numericPort := false | ||
format := constant.StringVal(tv.Value) | ||
switch format { | ||
case "%s:%d": | ||
// Have: fmt.Sprintf("%s:%d", host, port) | ||
numericPort = true | ||
|
||
case "%s:%s": | ||
// Have: fmt.Sprintf("%s:%s", host, portStr) | ||
// Keep port string as is. | ||
|
||
default: | ||
return | ||
} | ||
|
||
// Use granular edits to preserve original formatting. | ||
edits := []analysis.TextEdit{ | ||
{ | ||
// Replace fmt.Sprintf with net.JoinHostPort. | ||
Pos: call.Fun.Pos(), | ||
End: call.Fun.End(), | ||
NewText: []byte("net.JoinHostPort"), | ||
}, | ||
{ | ||
// Delete format string. | ||
Pos: formatArg.Pos(), | ||
End: call.Args[1].Pos(), | ||
}, | ||
} | ||
|
||
// Turn numeric port into a string. | ||
if numericPort { | ||
// port => fmt.Sprintf("%d", port) | ||
// 123 => "123" | ||
port := call.Args[2] | ||
newPort := fmt.Sprintf(`fmt.Sprintf("%%d", %s)`, port) | ||
if port := info.Types[port].Value; port != nil { | ||
if i, ok := constant.Int64Val(port); ok { | ||
newPort = fmt.Sprintf(`"%d"`, i) // numeric constant | ||
} | ||
} | ||
|
||
edits = append(edits, analysis.TextEdit{ | ||
Pos: port.Pos(), | ||
End: port.End(), | ||
NewText: []byte(newPort), | ||
}) | ||
} | ||
|
||
// Refer to Dial call, if not adjacent. | ||
suffix := "" | ||
if dialCall != nil { | ||
suffix = fmt.Sprintf(" (passed to net.Dial at L%d)", | ||
safetoken.StartPosition(pass.Fset, dialCall.Pos()).Line) | ||
} | ||
|
||
pass.Report(analysis.Diagnostic{ | ||
// Highlight the format string. | ||
Pos: formatArg.Pos(), | ||
End: formatArg.End(), | ||
Message: fmt.Sprintf("address format %q does not work with IPv6%s", format, suffix), | ||
SuggestedFixes: []analysis.SuggestedFix{{ | ||
Message: "Replace fmt.Sprintf with net.JoinHostPort", | ||
TextEdits: edits, | ||
}}, | ||
}) | ||
} | ||
} | ||
} | ||
} | ||
|
||
// Check address argument of each call to net.Dial et al. | ||
inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) | ||
for curCall := range cursor.Root(inspect).Preorder((*ast.CallExpr)(nil)) { | ||
call := curCall.Node().(*ast.CallExpr) | ||
|
||
obj := typeutil.Callee(info, call) | ||
if analysisinternal.IsFunctionNamed(obj, "net", "Dial", "DialTimeout") || | ||
analysisinternal.IsMethodNamed(obj, "net", "Dialer", "Dial") { | ||
|
||
switch address := call.Args[1].(type) { | ||
case *ast.CallExpr: | ||
// net.Dial("tcp", fmt.Sprintf("%s:%d", ...)) | ||
checkAddr(address, nil) | ||
|
||
case *ast.Ident: | ||
// addr := fmt.Sprintf("%s:%d", ...) | ||
// ... | ||
// net.Dial("tcp", addr) | ||
|
||
// Search for decl of addrVar within common ancestor of addrVar and Dial call. | ||
if addrVar, ok := info.Uses[address].(*types.Var); ok { | ||
pos := addrVar.Pos() | ||
// TODO(adonovan): use Cursor.Ancestors iterator when available. | ||
for _, curAncestor := range curCall.Stack(nil) { | ||
if curIdent, ok := curAncestor.FindPos(pos, pos); ok { | ||
// curIdent is the declaring ast.Ident of addr. | ||
switch parent := curIdent.Parent().Node().(type) { | ||
case *ast.AssignStmt: | ||
if len(parent.Rhs) == 1 { | ||
// Have: addr := fmt.Sprintf("%s:%d", ...) | ||
checkAddr(parent.Rhs[0], call) | ||
} | ||
|
||
case *ast.ValueSpec: | ||
if len(parent.Values) == 1 { | ||
// Have: var addr = fmt.Sprintf("%s:%d", ...) | ||
checkAddr(parent.Values[0], call) | ||
} | ||
} | ||
break | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
return nil, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
// Copyright 2024 The Go Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package hostport_test | ||
|
||
import ( | ||
"testing" | ||
|
||
"golang.org/x/tools/go/analysis/analysistest" | ||
"golang.org/x/tools/gopls/internal/analysis/hostport" | ||
) | ||
|
||
func Test(t *testing.T) { | ||
testdata := analysistest.TestData() | ||
analysistest.RunWithSuggestedFixes(t, testdata, hostport.Analyzer, "a") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
// Copyright 2024 The Go Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
//go:build ignore | ||
|
||
package main | ||
|
||
import ( | ||
"golang.org/x/tools/go/analysis/singlechecker" | ||
"golang.org/x/tools/gopls/internal/analysis/hostport" | ||
) | ||
|
||
func main() { singlechecker.Main(hostport.Analyzer) } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package a | ||
|
||
import ( | ||
"fmt" | ||
"net" | ||
) | ||
|
||
func direct(host string, port int, portStr string) { | ||
// Dial, directly called with result of Sprintf. | ||
net.Dial("tcp", fmt.Sprintf("%s:%d", host, port)) // want `address format "%s:%d" does not work with IPv6` | ||
|
||
net.Dial("tcp", fmt.Sprintf("%s:%s", host, portStr)) // want `address format "%s:%s" does not work with IPv6` | ||
} | ||
|
||
// port is a constant: | ||
var addr4 = fmt.Sprintf("%s:%d", "localhost", 123) // want `address format "%s:%d" does not work with IPv6 \(passed to net.Dial at L39\)` | ||
|
||
func indirect(host string, port int) { | ||
// Dial, addr is immediately preceding. | ||
{ | ||
addr1 := fmt.Sprintf("%s:%d", host, port) // want `address format "%s:%d" does not work with IPv6.*at L22` | ||
net.Dial("tcp", addr1) | ||
} | ||
|
||
// DialTimeout, addr is in ancestor block. | ||
addr2 := fmt.Sprintf("%s:%d", host, port) // want `address format "%s:%d" does not work with IPv6.*at L28` | ||
{ | ||
net.DialTimeout("tcp", addr2, 0) | ||
} | ||
|
||
// Dialer.Dial, addr is declared with var. | ||
var dialer net.Dialer | ||
{ | ||
var addr3 = fmt.Sprintf("%s:%d", host, port) // want `address format "%s:%d" does not work with IPv6.*at L35` | ||
dialer.Dial("tcp", addr3) | ||
} | ||
|
||
// Dialer.Dial again, addr is declared at package level. | ||
dialer.Dial("tcp", addr4) | ||
} |
40 changes: 40 additions & 0 deletions
40
gopls/internal/analysis/hostport/testdata/src/a/a.go.golden
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package a | ||
|
||
import ( | ||
"fmt" | ||
"net" | ||
) | ||
|
||
func direct(host string, port int, portStr string) { | ||
// Dial, directly called with result of Sprintf. | ||
net.Dial("tcp", net.JoinHostPort(host, fmt.Sprintf("%d", port))) // want `address format "%s:%d" does not work with IPv6` | ||
|
||
net.Dial("tcp", net.JoinHostPort(host, portStr)) // want `address format "%s:%s" does not work with IPv6` | ||
} | ||
|
||
// port is a constant: | ||
var addr4 = net.JoinHostPort("localhost", "123") // want `address format "%s:%d" does not work with IPv6 \(passed to net.Dial at L39\)` | ||
|
||
func indirect(host string, port int) { | ||
// Dial, addr is immediately preceding. | ||
{ | ||
addr1 := net.JoinHostPort(host, fmt.Sprintf("%d", port)) // want `address format "%s:%d" does not work with IPv6.*at L22` | ||
net.Dial("tcp", addr1) | ||
} | ||
|
||
// DialTimeout, addr is in ancestor block. | ||
addr2 := net.JoinHostPort(host, fmt.Sprintf("%d", port)) // want `address format "%s:%d" does not work with IPv6.*at L28` | ||
{ | ||
net.DialTimeout("tcp", addr2, 0) | ||
} | ||
|
||
// Dialer.Dial, addr is declared with var. | ||
var dialer net.Dialer | ||
{ | ||
var addr3 = net.JoinHostPort(host, fmt.Sprintf("%d", port)) // want `address format "%s:%d" does not work with IPv6.*at L35` | ||
dialer.Dial("tcp", addr3) | ||
} | ||
|
||
// Dialer.Dial again, addr is declared at package level. | ||
dialer.Dial("tcp", addr4) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,6 +49,7 @@ import ( | |
"golang.org/x/tools/gopls/internal/analysis/deprecated" | ||
"golang.org/x/tools/gopls/internal/analysis/embeddirective" | ||
"golang.org/x/tools/gopls/internal/analysis/fillreturns" | ||
"golang.org/x/tools/gopls/internal/analysis/hostport" | ||
"golang.org/x/tools/gopls/internal/analysis/infertypeargs" | ||
"golang.org/x/tools/gopls/internal/analysis/modernize" | ||
"golang.org/x/tools/gopls/internal/analysis/nonewvars" | ||
|
@@ -158,6 +159,7 @@ func init() { | |
{analyzer: sortslice.Analyzer, enabled: true}, | ||
{analyzer: embeddirective.Analyzer, enabled: true}, | ||
{analyzer: waitgroup.Analyzer, enabled: true}, // to appear in cmd/[email protected] | ||
{analyzer: hostport.Analyzer, enabled: true}, // to appear in cmd/[email protected] | ||
{analyzer: modernize.Analyzer, enabled: true, severity: protocol.SeverityInformation}, | ||
|
||
// disabled due to high false positives | ||
|