Skip to content

Commit

Permalink
control: Refactor interface finder
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 12, 2024
1 parent 0998999 commit cc7e630
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 70 deletions.
41 changes: 20 additions & 21 deletions common/cond.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,18 @@ func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int {
return -1
}

func Equal[S ~[]E, E comparable](s1, s2 S) bool {
if len(s1) != len(s2) {
return false
}
for i := range s1 {
if s1[i] != s2[i] {
return false
}
}
return true
}

//go:norace
func Dup[T any](obj T) T {
pointer := uintptr(unsafe.Pointer(&obj))
Expand Down Expand Up @@ -268,6 +280,14 @@ func Reverse[T any](arr []T) []T {
return arr
}

func ReverseMap[K comparable, V comparable](m map[K]V) map[V]K {
ret := make(map[V]K, len(m))
for k, v := range m {
ret[v] = k
}
return ret
}

func Done(ctx context.Context) bool {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -362,24 +382,3 @@ func Close(closers ...any) error {
}
return retErr
}

// Deprecated: wtf is this?
type Starter interface {
Start() error
}

// Deprecated: wtf is this?
func Start(starters ...any) error {
for _, rawStarter := range starters {
if rawStarter == nil {
continue
}
if starter, isStarter := rawStarter.(Starter); isStarter {
err := starter.Start()
if err != nil {
return err
}
}
}
return nil
}
4 changes: 2 additions & 2 deletions common/control/bind_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ import (

func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
return Raw(conn, func(fd uintptr) error {
var err error
if interfaceIndex == -1 {
if finder == nil {
return os.ErrInvalid
}
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
iif, err := finder.ByName(interfaceName)
if err != nil {
return err
}
interfaceIndex = iif.Index
}
switch network {
case "tcp6", "udp6":
Expand Down
44 changes: 40 additions & 4 deletions common/control/bind_finder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,57 @@ package control
import (
"net"
"net/netip"
"unsafe"

"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
)

type InterfaceFinder interface {
Update() error
Interfaces() []Interface
InterfaceIndexByName(name string) (int, error)
InterfaceNameByIndex(index int) (string, error)
InterfaceByAddr(addr netip.Addr) (*Interface, error)
ByName(name string) (*Interface, error)
ByIndex(index int) (*Interface, error)
ByAddr(addr netip.Addr) (*Interface, error)
}

type Interface struct {
Index int
MTU int
Name string
Addresses []netip.Prefix
HardwareAddr net.HardwareAddr
Flags net.Flags
Addresses []netip.Prefix
}

func (i Interface) Equals(other Interface) bool {
return i.Index == other.Index &&
i.MTU == other.MTU &&
i.Name == other.Name &&
common.Equal(i.HardwareAddr, other.HardwareAddr) &&
i.Flags == other.Flags &&
common.Equal(i.Addresses, other.Addresses)
}

func (i Interface) NetInterface() net.Interface {
return *(*net.Interface)(unsafe.Pointer(&i))
}

func InterfaceFromNet(iif net.Interface) (Interface, error) {
ifAddrs, err := iif.Addrs()
if err != nil {
return Interface{}, err
}
return InterfaceFromNetAddrs(iif, common.Map(ifAddrs, M.PrefixFromNet)), nil
}

func InterfaceFromNetAddrs(iif net.Interface, addresses []netip.Prefix) Interface {
return Interface{
Index: iif.Index,
MTU: iif.MTU,
Name: iif.Name,
HardwareAddr: iif.HardwareAddr,
Flags: iif.Flags,
Addresses: addresses,
}
}
62 changes: 24 additions & 38 deletions common/control/bind_finder_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@ package control
import (
"net"
"net/netip"
_ "unsafe"

"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)

var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil)
Expand All @@ -27,18 +24,12 @@ func (f *DefaultInterfaceFinder) Update() error {
}
interfaces := make([]Interface, 0, len(netIfs))
for _, netIf := range netIfs {
ifAddrs, err := netIf.Addrs()
var iif Interface
iif, err = InterfaceFromNet(netIf)
if err != nil {
return err
}
interfaces = append(interfaces, Interface{
Index: netIf.Index,
MTU: netIf.MTU,
Name: netIf.Name,
Addresses: common.Map(ifAddrs, M.PrefixFromNet),
HardwareAddr: netIf.HardwareAddr,
Flags: netIf.Flags,
})
interfaces = append(interfaces, iif)
}
f.interfaces = interfaces
return nil
Expand All @@ -52,46 +43,41 @@ func (f *DefaultInterfaceFinder) Interfaces() []Interface {
return f.interfaces
}

func (f *DefaultInterfaceFinder) InterfaceIndexByName(name string) (int, error) {
func (f *DefaultInterfaceFinder) ByName(name string) (*Interface, error) {
for _, netInterface := range f.interfaces {
if netInterface.Name == name {
return netInterface.Index, nil
return &netInterface, nil
}
}
netInterface, err := net.InterfaceByName(name)
if err != nil {
return 0, err
_, err := net.InterfaceByName(name)
if err == nil {
err = f.Update()
if err != nil {
return nil, err
}
return f.ByName(name)
}
f.Update()
return netInterface.Index, nil
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
}

func (f *DefaultInterfaceFinder) InterfaceNameByIndex(index int) (string, error) {
func (f *DefaultInterfaceFinder) ByIndex(index int) (*Interface, error) {
for _, netInterface := range f.interfaces {
if netInterface.Index == index {
return netInterface.Name, nil
return &netInterface, nil
}
}
netInterface, err := net.InterfaceByIndex(index)
if err != nil {
return "", err
_, err := net.InterfaceByIndex(index)
if err == nil {
err = f.Update()
if err != nil {
return nil, err
}
return f.ByIndex(index)
}
f.Update()
return netInterface.Name, nil
return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")}
}

func (f *DefaultInterfaceFinder) InterfaceByAddr(addr netip.Addr) (*Interface, error) {
for _, netInterface := range f.interfaces {
for _, prefix := range netInterface.Addresses {
if prefix.Contains(addr) {
return &netInterface, nil
}
}
}
err := f.Update()
if err != nil {
return nil, err
}
func (f *DefaultInterfaceFinder) ByAddr(addr netip.Addr) (*Interface, error) {
for _, netInterface := range f.interfaces {
for _, prefix := range netInterface.Addresses {
if prefix.Contains(addr) {
Expand Down
4 changes: 2 additions & 2 deletions common/control/bind_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ func bindToInterface(conn syscall.RawConn, network string, address string, finde
if interfaceName == "" {
return os.ErrInvalid
}
var err error
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
iif, err := finder.ByName(interfaceName)
if err != nil {
return err
}
interfaceIndex = iif.Index
}
err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex)
if err == nil {
Expand Down
6 changes: 3 additions & 3 deletions common/control/bind_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@ import (

func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error {
return Raw(conn, func(fd uintptr) error {
var err error
if interfaceIndex == -1 {
if finder == nil {
return os.ErrInvalid
}
interfaceIndex, err = finder.InterfaceIndexByName(interfaceName)
iif, err := finder.ByName(interfaceName)
if err != nil {
return err
}
interfaceIndex = iif.Index
}
handle := syscall.Handle(fd)
if M.ParseSocksaddr(address).AddrString() == "" {
err = bind4(handle, interfaceIndex)
err := bind4(handle, interfaceIndex)
if err != nil {
return err
}
Expand Down

0 comments on commit cc7e630

Please sign in to comment.