From 9443c94bfa5aa707f20503318c2bebe58583fea6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 21 Nov 2024 18:10:41 +0800 Subject: [PATCH] refactor: WireGuard endpoint --- adapter/endpoint.go | 28 ++ adapter/endpoint/adapter.go | 43 +++ adapter/endpoint/manager.go | 147 +++++++++++ adapter/endpoint/registry.go | 72 +++++ adapter/inbound.go | 2 +- adapter/inbound/manager.go | 11 +- adapter/lifecycle_legacy.go | 3 + adapter/outbound/adapter.go | 18 +- adapter/outbound/manager.go | 34 ++- box.go | 51 +++- cmd/sing-box/cmd.go | 2 +- common/dialer/default.go | 2 +- experimental/deprecated/constants.go | 10 + experimental/libbox/config.go | 4 +- experimental/libbox/service.go | 2 +- go.mod | 4 +- go.sum | 8 +- include/registry.go | 9 + include/wireguard.go | 5 + include/wireguard_stub.go | 9 +- option/endpoint.go | 47 ++++ option/inbound.go | 2 +- option/options.go | 1 + option/outbound.go | 2 +- option/wireguard.go | 38 ++- protocol/block/outbound.go | 2 +- protocol/direct/inbound.go | 5 +- protocol/direct/outbound.go | 2 +- protocol/dns/outbound.go | 2 +- protocol/group/selector.go | 2 +- protocol/group/urltest.go | 2 +- protocol/http/inbound.go | 5 +- protocol/http/outbound.go | 2 +- protocol/hysteria/inbound.go | 5 +- protocol/hysteria/outbound.go | 2 +- protocol/hysteria2/inbound.go | 5 +- protocol/hysteria2/outbound.go | 2 +- protocol/mixed/inbound.go | 5 +- protocol/naive/inbound.go | 5 +- protocol/redirect/redirect.go | 5 +- protocol/redirect/tproxy.go | 5 +- protocol/shadowsocks/inbound.go | 5 +- protocol/shadowsocks/inbound_multi.go | 5 +- protocol/shadowsocks/inbound_relay.go | 5 +- protocol/shadowsocks/outbound.go | 2 +- protocol/shadowtls/inbound.go | 5 +- protocol/shadowtls/outbound.go | 2 +- protocol/socks/inbound.go | 5 +- protocol/socks/outbound.go | 2 +- protocol/ssh/outbound.go | 2 +- protocol/tor/outbound.go | 2 +- protocol/trojan/inbound.go | 5 +- protocol/trojan/outbound.go | 2 +- protocol/tuic/inbound.go | 5 +- protocol/tuic/outbound.go | 2 +- protocol/tun/inbound.go | 182 ++++++------- protocol/vless/inbound.go | 5 +- protocol/vless/outbound.go | 2 +- protocol/vmess/inbound.go | 5 +- protocol/vmess/outbound.go | 2 +- protocol/wireguard/endpoint.go | 211 +++++++++++++++ protocol/wireguard/outbound.go | 242 ++++++----------- route/network.go | 37 ++- test/box_test.go | 2 +- test/wireguard_test.go | 4 +- transport/wireguard/client_bind.go | 36 ++- transport/wireguard/device.go | 37 ++- transport/wireguard/device_stack.go | 123 ++++----- .../{gonet.go => device_stack_gonet.go} | 0 transport/wireguard/device_stack_stub.go | 10 +- transport/wireguard/device_system.go | 148 +++++------ transport/wireguard/device_system_stack.go | 182 +++++++++++++ transport/wireguard/endpoint.go | 248 +++++++++++++++++- transport/wireguard/endpoint_options.go | 40 +++ transport/wireguard/resolve.go | 148 ----------- 75 files changed, 1649 insertions(+), 674 deletions(-) create mode 100644 adapter/endpoint.go create mode 100644 adapter/endpoint/adapter.go create mode 100644 adapter/endpoint/manager.go create mode 100644 adapter/endpoint/registry.go create mode 100644 option/endpoint.go create mode 100644 protocol/wireguard/endpoint.go rename transport/wireguard/{gonet.go => device_stack_gonet.go} (100%) create mode 100644 transport/wireguard/device_system_stack.go create mode 100644 transport/wireguard/endpoint_options.go delete mode 100644 transport/wireguard/resolve.go diff --git a/adapter/endpoint.go b/adapter/endpoint.go new file mode 100644 index 0000000000..f09f08ce75 --- /dev/null +++ b/adapter/endpoint.go @@ -0,0 +1,28 @@ +package adapter + +import ( + "context" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" +) + +type Endpoint interface { + Lifecycle + Type() string + Tag() string + Outbound +} + +type EndpointRegistry interface { + option.EndpointOptionsRegistry + Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, endpointType string, options any) (Endpoint, error) +} + +type EndpointManager interface { + Lifecycle + Endpoints() []Endpoint + Get(tag string) (Endpoint, bool) + Remove(tag string) error + Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, endpointType string, options any) error +} diff --git a/adapter/endpoint/adapter.go b/adapter/endpoint/adapter.go new file mode 100644 index 0000000000..c75e4d839c --- /dev/null +++ b/adapter/endpoint/adapter.go @@ -0,0 +1,43 @@ +package endpoint + +import "github.com/sagernet/sing-box/option" + +type Adapter struct { + endpointType string + endpointTag string + network []string + dependencies []string +} + +func NewAdapter(endpointType string, endpointTag string, network []string, dependencies []string) Adapter { + return Adapter{ + endpointType: endpointType, + endpointTag: endpointTag, + network: network, + dependencies: dependencies, + } +} + +func NewAdapterWithDialerOptions(endpointType string, endpointTag string, network []string, dialOptions option.DialerOptions) Adapter { + var dependencies []string + if dialOptions.Detour != "" { + dependencies = []string{dialOptions.Detour} + } + return NewAdapter(endpointType, endpointTag, network, dependencies) +} + +func (a *Adapter) Type() string { + return a.endpointType +} + +func (a *Adapter) Tag() string { + return a.endpointTag +} + +func (a *Adapter) Network() []string { + return a.network +} + +func (a *Adapter) Dependencies() []string { + return a.dependencies +} diff --git a/adapter/endpoint/manager.go b/adapter/endpoint/manager.go new file mode 100644 index 0000000000..5a633beed3 --- /dev/null +++ b/adapter/endpoint/manager.go @@ -0,0 +1,147 @@ +package endpoint + +import ( + "context" + "os" + "sync" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/taskmonitor" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +var _ adapter.EndpointManager = (*Manager)(nil) + +type Manager struct { + logger log.ContextLogger + registry adapter.EndpointRegistry + access sync.Mutex + started bool + stage adapter.StartStage + endpoints []adapter.Endpoint + endpointByTag map[string]adapter.Endpoint +} + +func NewManager(logger log.ContextLogger, registry adapter.EndpointRegistry) *Manager { + return &Manager{ + logger: logger, + registry: registry, + endpointByTag: make(map[string]adapter.Endpoint), + } +} + +func (m *Manager) Start(stage adapter.StartStage) error { + m.access.Lock() + defer m.access.Unlock() + if m.started && m.stage >= stage { + panic("already started") + } + m.started = true + m.stage = stage + if stage == adapter.StartStateStart { + // started with outbound manager + return nil + } + for _, endpoint := range m.endpoints { + err := adapter.LegacyStart(endpoint, stage) + if err != nil { + return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]") + } + } + return nil +} + +func (m *Manager) Close() error { + m.access.Lock() + defer m.access.Unlock() + if !m.started { + return nil + } + m.started = false + endpoints := m.endpoints + m.endpoints = nil + monitor := taskmonitor.New(m.logger, C.StopTimeout) + var err error + for _, endpoint := range endpoints { + monitor.Start("close endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]") + err = E.Append(err, endpoint.Close(), func(err error) error { + return E.Cause(err, "close endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]") + }) + monitor.Finish() + } + return nil +} + +func (m *Manager) Endpoints() []adapter.Endpoint { + m.access.Lock() + defer m.access.Unlock() + return m.endpoints +} + +func (m *Manager) Get(tag string) (adapter.Endpoint, bool) { + m.access.Lock() + defer m.access.Unlock() + endpoint, found := m.endpointByTag[tag] + return endpoint, found +} + +func (m *Manager) Remove(tag string) error { + m.access.Lock() + endpoint, found := m.endpointByTag[tag] + if !found { + m.access.Unlock() + return os.ErrInvalid + } + delete(m.endpointByTag, tag) + index := common.Index(m.endpoints, func(it adapter.Endpoint) bool { + return it == endpoint + }) + if index == -1 { + panic("invalid endpoint index") + } + m.endpoints = append(m.endpoints[:index], m.endpoints[index+1:]...) + started := m.started + m.access.Unlock() + if started { + return endpoint.Close() + } + return nil +} + +func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) error { + endpoint, err := m.registry.Create(ctx, router, logger, tag, outboundType, options) + if err != nil { + return err + } + m.access.Lock() + defer m.access.Unlock() + if m.started { + for _, stage := range adapter.ListStartStages { + err = adapter.LegacyStart(endpoint, stage) + if err != nil { + return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]") + } + } + } + if existsEndpoint, loaded := m.endpointByTag[tag]; loaded { + if m.started { + err = existsEndpoint.Close() + if err != nil { + return E.Cause(err, "close endpoint/", existsEndpoint.Type(), "[", existsEndpoint.Tag(), "]") + } + } + existsIndex := common.Index(m.endpoints, func(it adapter.Endpoint) bool { + return it == existsEndpoint + }) + if existsIndex == -1 { + panic("invalid endpoint index") + } + m.endpoints = append(m.endpoints[:existsIndex], m.endpoints[existsIndex+1:]...) + } + m.endpoints = append(m.endpoints, endpoint) + m.endpointByTag[tag] = endpoint + return nil +} diff --git a/adapter/endpoint/registry.go b/adapter/endpoint/registry.go new file mode 100644 index 0000000000..92cb9025de --- /dev/null +++ b/adapter/endpoint/registry.go @@ -0,0 +1,72 @@ +package endpoint + +import ( + "context" + "sync" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +type ConstructorFunc[T any] func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options T) (adapter.Endpoint, error) + +func Register[Options any](registry *Registry, outboundType string, constructor ConstructorFunc[Options]) { + registry.register(outboundType, func() any { + return new(Options) + }, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, rawOptions any) (adapter.Endpoint, error) { + var options *Options + if rawOptions != nil { + options = rawOptions.(*Options) + } + return constructor(ctx, router, logger, tag, common.PtrValueOrDefault(options)) + }) +} + +var _ adapter.EndpointRegistry = (*Registry)(nil) + +type ( + optionsConstructorFunc func() any + constructorFunc func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options any) (adapter.Endpoint, error) +) + +type Registry struct { + access sync.Mutex + optionsType map[string]optionsConstructorFunc + constructor map[string]constructorFunc +} + +func NewRegistry() *Registry { + return &Registry{ + optionsType: make(map[string]optionsConstructorFunc), + constructor: make(map[string]constructorFunc), + } +} + +func (m *Registry) CreateOptions(outboundType string) (any, bool) { + m.access.Lock() + defer m.access.Unlock() + optionsConstructor, loaded := m.optionsType[outboundType] + if !loaded { + return nil, false + } + return optionsConstructor(), true +} + +func (m *Registry) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) (adapter.Endpoint, error) { + m.access.Lock() + defer m.access.Unlock() + constructor, loaded := m.constructor[outboundType] + if !loaded { + return nil, E.New("outbound type not found: " + outboundType) + } + return constructor(ctx, router, logger, tag, options) +} + +func (m *Registry) register(outboundType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) { + m.access.Lock() + defer m.access.Unlock() + m.optionsType[outboundType] = optionsConstructor + m.constructor[outboundType] = constructor +} diff --git a/adapter/inbound.go b/adapter/inbound.go index d3cc95b7d9..b4eaca3d4d 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -13,7 +13,7 @@ import ( ) type Inbound interface { - Service + Lifecycle Type() string Tag() string } diff --git a/adapter/inbound/manager.go b/adapter/inbound/manager.go index d2be0f365e..c690b2c913 100644 --- a/adapter/inbound/manager.go +++ b/adapter/inbound/manager.go @@ -18,6 +18,7 @@ var _ adapter.InboundManager = (*Manager)(nil) type Manager struct { logger log.ContextLogger registry adapter.InboundRegistry + endpoint adapter.EndpointManager access sync.Mutex started bool stage adapter.StartStage @@ -25,10 +26,11 @@ type Manager struct { inboundByTag map[string]adapter.Inbound } -func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry) *Manager { +func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry, endpoint adapter.EndpointManager) *Manager { return &Manager{ logger: logger, registry: registry, + endpoint: endpoint, inboundByTag: make(map[string]adapter.Inbound), } } @@ -79,9 +81,12 @@ func (m *Manager) Inbounds() []adapter.Inbound { func (m *Manager) Get(tag string) (adapter.Inbound, bool) { m.access.Lock() - defer m.access.Unlock() inbound, found := m.inboundByTag[tag] - return inbound, found + m.access.Unlock() + if found { + return inbound, true + } + return m.endpoint.Get(tag) } func (m *Manager) Remove(tag string) error { diff --git a/adapter/lifecycle_legacy.go b/adapter/lifecycle_legacy.go index 0c8c75daed..94a5cf8c8d 100644 --- a/adapter/lifecycle_legacy.go +++ b/adapter/lifecycle_legacy.go @@ -1,6 +1,9 @@ package adapter func LegacyStart(starter any, stage StartStage) error { + if lifecycle, isLifecycle := starter.(Lifecycle); isLifecycle { + return lifecycle.Start(stage) + } switch stage { case StartStateInitialize: if preStarter, isPreStarter := starter.(interface { diff --git a/adapter/outbound/adapter.go b/adapter/outbound/adapter.go index 481bb6197d..cd71527af5 100644 --- a/adapter/outbound/adapter.go +++ b/adapter/outbound/adapter.go @@ -5,35 +5,35 @@ import ( ) type Adapter struct { - protocol string + outboundType string + outboundTag string network []string - tag string dependencies []string } -func NewAdapter(protocol string, network []string, tag string, dependencies []string) Adapter { +func NewAdapter(outboundType string, outboundTag string, network []string, dependencies []string) Adapter { return Adapter{ - protocol: protocol, + outboundType: outboundType, + outboundTag: outboundTag, network: network, - tag: tag, dependencies: dependencies, } } -func NewAdapterWithDialerOptions(protocol string, network []string, tag string, dialOptions option.DialerOptions) Adapter { +func NewAdapterWithDialerOptions(outboundType string, outboundTag string, network []string, dialOptions option.DialerOptions) Adapter { var dependencies []string if dialOptions.Detour != "" { dependencies = []string{dialOptions.Detour} } - return NewAdapter(protocol, network, tag, dependencies) + return NewAdapter(outboundType, outboundTag, network, dependencies) } func (a *Adapter) Type() string { - return a.protocol + return a.outboundType } func (a *Adapter) Tag() string { - return a.tag + return a.outboundTag } func (a *Adapter) Network() []string { diff --git a/adapter/outbound/manager.go b/adapter/outbound/manager.go index 84a105c59e..f68b42b942 100644 --- a/adapter/outbound/manager.go +++ b/adapter/outbound/manager.go @@ -21,6 +21,7 @@ var _ adapter.OutboundManager = (*Manager)(nil) type Manager struct { logger log.ContextLogger registry adapter.OutboundRegistry + endpoint adapter.EndpointManager defaultTag string access sync.Mutex started bool @@ -32,10 +33,11 @@ type Manager struct { defaultOutboundFallback adapter.Outbound } -func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, defaultTag string) *Manager { +func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, endpoint adapter.EndpointManager, defaultTag string) *Manager { return &Manager{ logger: logger, registry: registry, + endpoint: endpoint, defaultTag: defaultTag, outboundByTag: make(map[string]adapter.Outbound), dependByTag: make(map[string][]string), @@ -56,7 +58,14 @@ func (m *Manager) Start(stage adapter.StartStage) error { outbounds := m.outbounds m.access.Unlock() if stage == adapter.StartStateStart { - return m.startOutbounds(outbounds) + if m.defaultOutbound == nil { + if len(outbounds) > 0 { + m.defaultOutbound = outbounds[0] + } else if len(m.endpoint.Endpoints()) > 0 { + m.defaultOutbound = m.endpoint.Endpoints()[0] + } + } + return m.startOutbounds(append(outbounds, common.Map(m.endpoint.Endpoints(), func(it adapter.Endpoint) adapter.Outbound { return it })...)) } else { for _, outbound := range outbounds { err := adapter.LegacyStart(outbound, stage) @@ -87,7 +96,14 @@ func (m *Manager) startOutbounds(outbounds []adapter.Outbound) error { } started[outboundTag] = true canContinue = true - if starter, isStarter := outboundToStart.(interface { + if starter, isStarter := outboundToStart.(adapter.Lifecycle); isStarter { + monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]") + err := starter.Start(adapter.StartStateStart) + monitor.Finish() + if err != nil { + return E.Cause(err, "start outbound/", outboundToStart.Type(), "[", outboundTag, "]") + } + } else if starter, isStarter := outboundToStart.(interface { Start() error }); isStarter { monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]") @@ -160,9 +176,12 @@ func (m *Manager) Outbounds() []adapter.Outbound { func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) { m.access.Lock() - defer m.access.Unlock() outbound, found := m.outboundByTag[tag] - return outbound, found + m.access.Unlock() + if found { + return outbound, true + } + return m.endpoint.Get(tag) } func (m *Manager) Default() adapter.Outbound { @@ -195,6 +214,9 @@ func (m *Manager) Remove(tag string) error { if len(m.outbounds) > 0 { m.defaultOutbound = m.outbounds[0] m.logger.Info("updated default outbound to ", m.defaultOutbound.Tag()) + } else if len(m.endpoint.Endpoints()) > 0 { + m.defaultOutbound = m.endpoint.Endpoints()[0] + m.logger.Info("updated default outbound to ", m.defaultOutbound.Tag()) } else { m.defaultOutbound = nil } @@ -259,7 +281,7 @@ func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log. for _, dependency := range dependencies { m.dependByTag[dependency] = append(m.dependByTag[dependency], tag) } - if tag == m.defaultTag || (m.defaultTag == "" && m.defaultOutbound == nil) { + if tag == m.defaultTag || (m.started && m.defaultTag == "" && m.defaultOutbound == nil) { m.defaultOutbound = outbound if m.started { m.logger.Info("updated default outbound to ", outbound.Tag()) diff --git a/box.go b/box.go index 02fdff33d8..b47122c714 100644 --- a/box.go +++ b/box.go @@ -9,6 +9,7 @@ import ( "time" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/endpoint" "github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/common/dialer" @@ -36,6 +37,7 @@ type Box struct { logFactory log.Factory logger log.ContextLogger network *route.NetworkManager + endpoint *endpoint.Manager inbound *inbound.Manager outbound *outbound.Manager connection *route.ConnectionManager @@ -54,6 +56,7 @@ func Context( ctx context.Context, inboundRegistry adapter.InboundRegistry, outboundRegistry adapter.OutboundRegistry, + endpointRegistry adapter.EndpointRegistry, ) context.Context { if service.FromContext[option.InboundOptionsRegistry](ctx) == nil || service.FromContext[adapter.InboundRegistry](ctx) == nil { @@ -65,6 +68,11 @@ func Context( ctx = service.ContextWith[option.OutboundOptionsRegistry](ctx, outboundRegistry) ctx = service.ContextWith[adapter.OutboundRegistry](ctx, outboundRegistry) } + if service.FromContext[option.EndpointOptionsRegistry](ctx) == nil || + service.FromContext[adapter.EndpointRegistry](ctx) == nil { + ctx = service.ContextWith[option.EndpointOptionsRegistry](ctx, endpointRegistry) + ctx = service.ContextWith[adapter.EndpointRegistry](ctx, endpointRegistry) + } return ctx } @@ -76,12 +84,16 @@ func New(options Options) (*Box, error) { } ctx = service.ContextWithDefaultRegistry(ctx) + endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx) inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx) + outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx) + + if endpointRegistry == nil { + return nil, E.New("missing endpoint registry in context") + } if inboundRegistry == nil { return nil, E.New("missing inbound registry in context") } - - outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx) if outboundRegistry == nil { return nil, E.New("missing outbound registry in context") } @@ -119,8 +131,10 @@ func New(options Options) (*Box, error) { } routeOptions := common.PtrValueOrDefault(options.Route) - inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry) - outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, routeOptions.Final) + endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry) + inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager) + outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final) + service.MustRegister[adapter.EndpointManager](ctx, endpointManager) service.MustRegister[adapter.InboundManager](ctx, inboundManager) service.MustRegister[adapter.OutboundManager](ctx, outboundManager) @@ -135,6 +149,24 @@ func New(options Options) (*Box, error) { if err != nil { return nil, E.Cause(err, "initialize router") } + for i, endpointOptions := range options.Endpoints { + var tag string + if endpointOptions.Tag != "" { + tag = endpointOptions.Tag + } else { + tag = F.ToString(i) + } + err = endpointManager.Create(ctx, + router, + logFactory.NewLogger(F.ToString("endpoint/", endpointOptions.Type, "[", tag, "]")), + tag, + endpointOptions.Type, + endpointOptions.Options, + ) + if err != nil { + return nil, E.Cause(err, "initialize inbound[", i, "]") + } + } for i, inboundOptions := range options.Inbounds { var tag string if inboundOptions.Tag != "" { @@ -241,6 +273,7 @@ func New(options Options) (*Box, error) { } return &Box{ network: networkManager, + endpoint: endpointManager, inbound: inboundManager, outbound: outboundManager, connection: connectionManager, @@ -303,7 +336,7 @@ func (s *Box) preStart() error { if err != nil { return err } - err = adapter.Start(adapter.StartStateInitialize, s.network, s.router, s.outbound, s.inbound) + err = adapter.Start(adapter.StartStateInitialize, s.network, s.router, s.outbound, s.inbound, s.endpoint) if err != nil { return err } @@ -327,7 +360,11 @@ func (s *Box) start() error { if err != nil { return err } - err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.router, s.inbound) + err = adapter.Start(adapter.StartStateStart, s.endpoint) + if err != nil { + return err + } + err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.router, s.inbound, s.endpoint) if err != nil { return err } @@ -335,7 +372,7 @@ func (s *Box) start() error { if err != nil { return err } - err = adapter.Start(adapter.StartStateStarted, s.network, s.router, s.outbound, s.inbound) + err = adapter.Start(adapter.StartStateStarted, s.network, s.router, s.outbound, s.inbound, s.endpoint) if err != nil { return err } diff --git a/cmd/sing-box/cmd.go b/cmd/sing-box/cmd.go index dc7a830965..d55235b855 100644 --- a/cmd/sing-box/cmd.go +++ b/cmd/sing-box/cmd.go @@ -69,5 +69,5 @@ func preRun(cmd *cobra.Command, args []string) { configPaths = append(configPaths, "config.json") } globalCtx = service.ContextWith(globalCtx, deprecated.NewStderrManager(log.StdLogger())) - globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry()) + globalCtx = box.Context(globalCtx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry()) } diff --git a/common/dialer/default.go b/common/dialer/default.go index 9c4c865b25..c6b7f53823 100644 --- a/common/dialer/default.go +++ b/common/dialer/default.go @@ -279,7 +279,7 @@ func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destina } func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) { - return trackPacketConn(d.listenSerialInterfacePacket(context.Background(), d.udpListener, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay)) + return trackPacketConn(d.udpListener.ListenPacket(context.Background(), network, address)) } func trackConn(conn net.Conn, err error) (net.Conn, error) { diff --git a/experimental/deprecated/constants.go b/experimental/deprecated/constants.go index a62b89f84b..f5b000a7ac 100644 --- a/experimental/deprecated/constants.go +++ b/experimental/deprecated/constants.go @@ -109,6 +109,15 @@ var OptionDestinationOverrideFields = Note{ MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-destination-override-fields-to-route-options", } +var OptionWireGuardOutbound = Note{ + Name: "wireguard-outbound", + Description: "legacy wireguard outbound", + DeprecatedVersion: "1.11.0", + ScheduledVersion: "1.13.0", + EnvName: "WIREGUARD_OUTBOUND", + MigrationLink: "https://sing-box.sagernet.org/migration/#migrate-wireguard-outbound-to-endpoint", +} + var Options = []Note{ OptionBadMatchSource, OptionGEOIP, @@ -117,4 +126,5 @@ var Options = []Note{ OptionSpecialOutbounds, OptionInboundOptions, OptionDestinationOverrideFields, + OptionWireGuardOutbound, } diff --git a/experimental/libbox/config.go b/experimental/libbox/config.go index cce45c6040..f8c30b3b3f 100644 --- a/experimental/libbox/config.go +++ b/experimental/libbox/config.go @@ -30,7 +30,7 @@ func parseConfig(ctx context.Context, configContent string) (option.Options, err } func CheckConfig(configContent string) error { - ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry()) + ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry()) options, err := parseConfig(ctx, configContent) if err != nil { return err @@ -131,7 +131,7 @@ func (s *platformInterfaceStub) SendNotification(notification *platform.Notifica } func FormatConfig(configContent string) (string, error) { - options, err := parseConfig(box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry()), configContent) + options, err := parseConfig(box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry()), configContent) if err != nil { return "", err } diff --git a/experimental/libbox/service.go b/experimental/libbox/service.go index 0fdf721a37..30e5750711 100644 --- a/experimental/libbox/service.go +++ b/experimental/libbox/service.go @@ -44,7 +44,7 @@ type BoxService struct { } func NewService(configContent string, platformInterface PlatformInterface) (*BoxService, error) { - ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry()) + ctx := box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry()) ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID) service.MustRegister[deprecated.Manager](ctx, new(deprecatedManager)) options, err := parseConfig(ctx, configContent) diff --git a/go.mod b/go.mod index 6cfb070cbc..cc08a0e4d4 100644 --- a/go.mod +++ b/go.mod @@ -32,11 +32,11 @@ require ( github.com/sagernet/sing-shadowsocks v0.2.7 github.com/sagernet/sing-shadowsocks2 v0.2.0 github.com/sagernet/sing-shadowtls v0.2.0-alpha.2 - github.com/sagernet/sing-tun v0.6.0-alpha.9 + github.com/sagernet/sing-tun v0.6.0-alpha.10 github.com/sagernet/sing-vmess v0.1.12 github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 github.com/sagernet/utls v1.6.7 - github.com/sagernet/wireguard-go v0.0.0-20231215174105-89dec3b2f3e8 + github.com/sagernet/wireguard-go v0.0.1-beta.2 github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 github.com/spf13/cobra v1.8.1 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index a97d4c7cc7..71fdd0e01d 100644 --- a/go.sum +++ b/go.sum @@ -124,16 +124,16 @@ github.com/sagernet/sing-shadowsocks2 v0.2.0 h1:wpZNs6wKnR7mh1wV9OHwOyUr21VkS3wK github.com/sagernet/sing-shadowsocks2 v0.2.0/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ= github.com/sagernet/sing-shadowtls v0.2.0-alpha.2 h1:RPrpgAdkP5td0vLfS5ldvYosFjSsZtRPxiyLV6jyKg0= github.com/sagernet/sing-shadowtls v0.2.0-alpha.2/go.mod h1:0j5XlzKxaWRIEjc1uiSKmVoWb0k+L9QgZVb876+thZA= -github.com/sagernet/sing-tun v0.6.0-alpha.9 h1:Qf667035KnlydZ+ftj3U4HH+oddi3RdyKzBiCcnSgaI= -github.com/sagernet/sing-tun v0.6.0-alpha.9/go.mod h1:TgvxE2YD7O9c/unHju0nWAGBGsVppWIuju13vlmdllM= +github.com/sagernet/sing-tun v0.6.0-alpha.10 h1:kJOMUR6VKHkTrtJ+kPJVsCqrJYmW0nTRJLYv+Or7lNA= +github.com/sagernet/sing-tun v0.6.0-alpha.10/go.mod h1:UmZpZ06gItrbOFLhyeZsilHKQDa5h4NSQy8LalkTkXQ= github.com/sagernet/sing-vmess v0.1.12 h1:2gFD8JJb+eTFMoa8FIVMnknEi+vCSfaiTXTfEYAYAPg= github.com/sagernet/sing-vmess v0.1.12/go.mod h1:luTSsfyBGAc9VhtCqwjR+dt1QgqBhuYBCONB/POhF8I= github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 h1:DImB4lELfQhplLTxeq2z31Fpv8CQqqrUwTbrIRumZqQ= github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7/go.mod h1:FP9X2xjT/Az1EsG/orYYoC+5MojWnuI7hrffz8fGwwo= github.com/sagernet/utls v1.6.7 h1:Ep3+aJ8FUGGta+II2IEVNUc3EDhaRCZINWkj/LloIA8= github.com/sagernet/utls v1.6.7/go.mod h1:Uua1TKO/FFuAhLr9rkaVnnrTmmiItzDjv1BUb2+ERwM= -github.com/sagernet/wireguard-go v0.0.0-20231215174105-89dec3b2f3e8 h1:R0OMYAScomNAVpTfbHFpxqJpvwuhxSRi+g6z7gZhABs= -github.com/sagernet/wireguard-go v0.0.0-20231215174105-89dec3b2f3e8/go.mod h1:K4J7/npM+VAMUeUmTa2JaA02JmyheP0GpRBOUvn3ecc= +github.com/sagernet/wireguard-go v0.0.1-beta.2 h1:afmDgfCL2Esc+2EYtdcJFepTWHX9+kZnosC0A84VJ9s= +github.com/sagernet/wireguard-go v0.0.1-beta.2/go.mod h1:8xfewtQJZ1g3HeMQbLpJxTjyTiE3FL+Joq5LQoKLFEw= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= diff --git a/include/registry.go b/include/registry.go index 03fb33f245..e71ffb0c82 100644 --- a/include/registry.go +++ b/include/registry.go @@ -4,6 +4,7 @@ import ( "context" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/endpoint" "github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/outbound" C "github.com/sagernet/sing-box/constant" @@ -82,6 +83,14 @@ func OutboundRegistry() *outbound.Registry { return registry } +func EndpointRegistry() *endpoint.Registry { + registry := endpoint.NewRegistry() + + registerWireGuardEndpoint(registry) + + return registry +} + func registerStubForRemovedInbounds(registry *inbound.Registry) { inbound.Register[option.ShadowsocksInboundOptions](registry, C.TypeShadowsocksR, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (adapter.Inbound, error) { return nil, E.New("ShadowsocksR is deprecated and removed in sing-box 1.6.0") diff --git a/include/wireguard.go b/include/wireguard.go index dfc3a242a5..f2ce9e2341 100644 --- a/include/wireguard.go +++ b/include/wireguard.go @@ -3,6 +3,7 @@ package include import ( + "github.com/sagernet/sing-box/adapter/endpoint" "github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/protocol/wireguard" ) @@ -10,3 +11,7 @@ import ( func registerWireGuardOutbound(registry *outbound.Registry) { wireguard.RegisterOutbound(registry) } + +func registerWireGuardEndpoint(registry *endpoint.Registry) { + wireguard.RegisterEndpoint(registry) +} diff --git a/include/wireguard_stub.go b/include/wireguard_stub.go index a9e84522bb..247546e26e 100644 --- a/include/wireguard_stub.go +++ b/include/wireguard_stub.go @@ -6,6 +6,7 @@ import ( "context" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/endpoint" "github.com/sagernet/sing-box/adapter/outbound" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" @@ -14,7 +15,13 @@ import ( ) func registerWireGuardOutbound(registry *outbound.Registry) { - outbound.Register[option.WireGuardOutboundOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (adapter.Outbound, error) { + outbound.Register[option.LegacyWireGuardOutboundOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.LegacyWireGuardOutboundOptions) (adapter.Outbound, error) { + return nil, E.New(`WireGuard is not included in this build, rebuild with -tags with_wireguard`) + }) +} + +func registerWireGuardEndpoint(registry *endpoint.Registry) { + endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) { return nil, E.New(`WireGuard is not included in this build, rebuild with -tags with_wireguard`) }) } diff --git a/option/endpoint.go b/option/endpoint.go new file mode 100644 index 0000000000..909fb89618 --- /dev/null +++ b/option/endpoint.go @@ -0,0 +1,47 @@ +package option + +import ( + "context" + + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json/badjson" + "github.com/sagernet/sing/service" +) + +type EndpointOptionsRegistry interface { + CreateOptions(endpointType string) (any, bool) +} + +type _Endpoint struct { + Type string `json:"type"` + Tag string `json:"tag,omitempty"` + Options any `json:"-"` +} + +type Endpoint _Endpoint + +func (h *Endpoint) MarshalJSONContext(ctx context.Context) ([]byte, error) { + return badjson.MarshallObjectsContext(ctx, (*_Endpoint)(h), h.Options) +} + +func (h *Endpoint) UnmarshalJSONContext(ctx context.Context, content []byte) error { + err := json.UnmarshalContext(ctx, content, (*_Endpoint)(h)) + if err != nil { + return err + } + registry := service.FromContext[EndpointOptionsRegistry](ctx) + if registry == nil { + return E.New("missing Endpoint fields registry in context") + } + options, loaded := registry.CreateOptions(h.Type) + if !loaded { + return E.New("unknown inbound type: ", h.Type) + } + err = badjson.UnmarshallExcludedContext(ctx, content, (*_Endpoint)(h), options) + if err != nil { + return err + } + h.Options = options + return nil +} diff --git a/option/inbound.go b/option/inbound.go index 2cc1598946..1cf16ff6ec 100644 --- a/option/inbound.go +++ b/option/inbound.go @@ -28,7 +28,7 @@ func (h *Inbound) MarshalJSONContext(ctx context.Context) ([]byte, error) { } func (h *Inbound) UnmarshalJSONContext(ctx context.Context, content []byte) error { - err := json.Unmarshal(content, (*_Inbound)(h)) + err := json.UnmarshalContext(ctx, content, (*_Inbound)(h)) if err != nil { return err } diff --git a/option/options.go b/option/options.go index 13a16c08b5..94c9771928 100644 --- a/option/options.go +++ b/option/options.go @@ -13,6 +13,7 @@ type _Options struct { Log *LogOptions `json:"log,omitempty"` DNS *DNSOptions `json:"dns,omitempty"` NTP *NTPOptions `json:"ntp,omitempty"` + Endpoints []Endpoint `json:"endpoints,omitempty"` Inbounds []Inbound `json:"inbounds,omitempty"` Outbounds []Outbound `json:"outbounds,omitempty"` Route *RouteOptions `json:"route,omitempty"` diff --git a/option/outbound.go b/option/outbound.go index 34ef904a80..833a2d2030 100644 --- a/option/outbound.go +++ b/option/outbound.go @@ -30,7 +30,7 @@ func (h *Outbound) MarshalJSONContext(ctx context.Context) ([]byte, error) { } func (h *Outbound) UnmarshalJSONContext(ctx context.Context, content []byte) error { - err := json.Unmarshal(content, (*_Outbound)(h)) + err := json.UnmarshalContext(ctx, content, (*_Outbound)(h)) if err != nil { return err } diff --git a/option/wireguard.go b/option/wireguard.go index ebdf159fde..b70fbd55d3 100644 --- a/option/wireguard.go +++ b/option/wireguard.go @@ -6,14 +6,38 @@ import ( "github.com/sagernet/sing/common/json/badoption" ) -type WireGuardOutboundOptions struct { +type WireGuardEndpointOptions struct { + System bool `json:"system,omitempty"` + Name string `json:"name,omitempty"` + MTU uint32 `json:"mtu,omitempty"` + GSO bool `json:"gso,omitempty"` + Address badoption.Listable[netip.Prefix] `json:"address"` + PrivateKey string `json:"private_key"` + ListenPort uint16 `json:"listen_port,omitempty"` + Peers []WireGuardPeer `json:"peers,omitempty"` + UDPTimeout UDPTimeoutCompat `json:"udp_timeout,omitempty"` + Workers int `json:"workers,omitempty"` + DialerOptions +} + +type WireGuardPeer struct { + Address string `json:"address,omitempty"` + Port uint16 `json:"port,omitempty"` + PublicKey string `json:"public_key,omitempty"` + PreSharedKey string `json:"pre_shared_key,omitempty"` + AllowedIPs badoption.Listable[netip.Prefix] `json:"allowed_ips,omitempty"` + PersistentKeepaliveInterval badoption.Duration `json:"persistent_keepalive_interval,omitempty"` + Reserved []uint8 `json:"reserved,omitempty"` +} + +type LegacyWireGuardOutboundOptions struct { DialerOptions SystemInterface bool `json:"system_interface,omitempty"` GSO bool `json:"gso,omitempty"` InterfaceName string `json:"interface_name,omitempty"` LocalAddress badoption.Listable[netip.Prefix] `json:"local_address"` PrivateKey string `json:"private_key"` - Peers []WireGuardPeer `json:"peers,omitempty"` + Peers []LegacyWireGuardPeer `json:"peers,omitempty"` ServerOptions PeerPublicKey string `json:"peer_public_key"` PreSharedKey string `json:"pre_shared_key,omitempty"` @@ -23,10 +47,10 @@ type WireGuardOutboundOptions struct { Network NetworkList `json:"network,omitempty"` } -type WireGuardPeer struct { +type LegacyWireGuardPeer struct { ServerOptions - PublicKey string `json:"public_key,omitempty"` - PreSharedKey string `json:"pre_shared_key,omitempty"` - AllowedIPs badoption.Listable[string] `json:"allowed_ips,omitempty"` - Reserved []uint8 `json:"reserved,omitempty"` + PublicKey string `json:"public_key,omitempty"` + PreSharedKey string `json:"pre_shared_key,omitempty"` + AllowedIPs badoption.Listable[netip.Prefix] `json:"allowed_ips,omitempty"` + Reserved []uint8 `json:"reserved,omitempty"` } diff --git a/protocol/block/outbound.go b/protocol/block/outbound.go index 75bc7797e0..fe1ccda760 100644 --- a/protocol/block/outbound.go +++ b/protocol/block/outbound.go @@ -26,7 +26,7 @@ type Outbound struct { func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, _ option.StubOptions) (adapter.Outbound, error) { return &Outbound{ - Adapter: outbound.NewAdapter(C.TypeBlock, []string{N.NetworkTCP, N.NetworkUDP}, tag, nil), + Adapter: outbound.NewAdapter(C.TypeBlock, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil), logger: logger, }, nil } diff --git a/protocol/direct/inbound.go b/protocol/direct/inbound.go index 6db60d7890..c9594f9a7c 100644 --- a/protocol/direct/inbound.go +++ b/protocol/direct/inbound.go @@ -68,7 +68,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo return inbound, nil } -func (i *Inbound) Start() error { +func (i *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } return i.listener.Start() } diff --git a/protocol/direct/outbound.go b/protocol/direct/outbound.go index 5ae0dac647..4962ea7025 100644 --- a/protocol/direct/outbound.go +++ b/protocol/direct/outbound.go @@ -52,7 +52,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return nil, err } outbound := &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.DialerOptions), + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions), logger: logger, domainStrategy: dns.DomainStrategy(options.DomainStrategy), fallbackDelay: time.Duration(options.FallbackDelay), diff --git a/protocol/dns/outbound.go b/protocol/dns/outbound.go index 7ce9fde2f2..3c493f80e3 100644 --- a/protocol/dns/outbound.go +++ b/protocol/dns/outbound.go @@ -28,7 +28,7 @@ type Outbound struct { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.StubOptions) (adapter.Outbound, error) { return &Outbound{ - Adapter: outbound.NewAdapter(C.TypeDNS, []string{N.NetworkTCP, N.NetworkUDP}, tag, nil), + Adapter: outbound.NewAdapter(C.TypeDNS, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil), router: router, logger: logger, }, nil diff --git a/protocol/group/selector.go b/protocol/group/selector.go index 08db74c09e..0bb3cd6644 100644 --- a/protocol/group/selector.go +++ b/protocol/group/selector.go @@ -38,7 +38,7 @@ type Selector struct { func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SelectorOutboundOptions) (adapter.Outbound, error) { outbound := &Selector{ - Adapter: outbound.NewAdapter(C.TypeSelector, nil, tag, options.Outbounds), + Adapter: outbound.NewAdapter(C.TypeSelector, tag, nil, options.Outbounds), ctx: ctx, outboundManager: service.FromContext[adapter.OutboundManager](ctx), logger: logger, diff --git a/protocol/group/urltest.go b/protocol/group/urltest.go index f1a84b5044..fcada7dc3b 100644 --- a/protocol/group/urltest.go +++ b/protocol/group/urltest.go @@ -49,7 +49,7 @@ type URLTest struct { func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (adapter.Outbound, error) { outbound := &URLTest{ - Adapter: outbound.NewAdapter(C.TypeURLTest, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.Outbounds), + Adapter: outbound.NewAdapter(C.TypeURLTest, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds), ctx: ctx, router: router, outboundManager: service.FromContext[adapter.OutboundManager](ctx), diff --git a/protocol/http/inbound.go b/protocol/http/inbound.go index bb8fe8c7dd..a7a463f76b 100644 --- a/protocol/http/inbound.go +++ b/protocol/http/inbound.go @@ -61,7 +61,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo return inbound, nil } -func (h *Inbound) Start() error { +func (h *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } if h.tlsConfig != nil { err := h.tlsConfig.Start() if err != nil { diff --git a/protocol/http/outbound.go b/protocol/http/outbound.go index 81fd024669..c58f307138 100644 --- a/protocol/http/outbound.go +++ b/protocol/http/outbound.go @@ -39,7 +39,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return nil, err } return &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHTTP, []string{N.NetworkTCP}, tag, options.DialerOptions), + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHTTP, tag, []string{N.NetworkTCP}, options.DialerOptions), logger: logger, client: sHTTP.NewClient(sHTTP.Options{ Dialer: detour, diff --git a/protocol/hysteria/inbound.go b/protocol/hysteria/inbound.go index 2b6d8a4d78..447cb217c0 100644 --- a/protocol/hysteria/inbound.go +++ b/protocol/hysteria/inbound.go @@ -160,7 +160,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) } -func (h *Inbound) Start() error { +func (h *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } if h.tlsConfig != nil { err := h.tlsConfig.Start() if err != nil { diff --git a/protocol/hysteria/outbound.go b/protocol/hysteria/outbound.go index e4c8775fa9..2da10f1ef4 100644 --- a/protocol/hysteria/outbound.go +++ b/protocol/hysteria/outbound.go @@ -95,7 +95,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return nil, err } return &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria, networkList, tag, options.DialerOptions), + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria, tag, networkList, options.DialerOptions), logger: logger, client: client, }, nil diff --git a/protocol/hysteria2/inbound.go b/protocol/hysteria2/inbound.go index 03cd8d2d52..a3260e55ee 100644 --- a/protocol/hysteria2/inbound.go +++ b/protocol/hysteria2/inbound.go @@ -171,7 +171,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) } -func (h *Inbound) Start() error { +func (h *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } if h.tlsConfig != nil { err := h.tlsConfig.Start() if err != nil { diff --git a/protocol/hysteria2/outbound.go b/protocol/hysteria2/outbound.go index 4cabb4751c..068cc7f7fe 100644 --- a/protocol/hysteria2/outbound.go +++ b/protocol/hysteria2/outbound.go @@ -81,7 +81,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return nil, err } return &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria2, networkList, tag, options.DialerOptions), + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria2, tag, networkList, options.DialerOptions), logger: logger, client: client, }, nil diff --git a/protocol/mixed/inbound.go b/protocol/mixed/inbound.go index 4f48144086..ad7bed7de2 100644 --- a/protocol/mixed/inbound.go +++ b/protocol/mixed/inbound.go @@ -54,7 +54,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo return inbound, nil } -func (h *Inbound) Start() error { +func (h *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } return h.listener.Start() } diff --git a/protocol/naive/inbound.go b/protocol/naive/inbound.go index 1a561aeaf2..18acd2accd 100644 --- a/protocol/naive/inbound.go +++ b/protocol/naive/inbound.go @@ -78,7 +78,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo return inbound, nil } -func (n *Inbound) Start() error { +func (n *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } var tlsConfig *tls.STDConfig if n.tlsConfig != nil { err := n.tlsConfig.Start() diff --git a/protocol/redirect/redirect.go b/protocol/redirect/redirect.go index 23bfad3eb4..0950d102b9 100644 --- a/protocol/redirect/redirect.go +++ b/protocol/redirect/redirect.go @@ -42,7 +42,10 @@ func NewRedirect(ctx context.Context, router adapter.Router, logger log.ContextL return redirect, nil } -func (h *Redirect) Start() error { +func (h *Redirect) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } return h.listener.Start() } diff --git a/protocol/redirect/tproxy.go b/protocol/redirect/tproxy.go index bdff4dff6a..b651130d66 100644 --- a/protocol/redirect/tproxy.go +++ b/protocol/redirect/tproxy.go @@ -61,7 +61,10 @@ func NewTProxy(ctx context.Context, router adapter.Router, logger log.ContextLog return tproxy, nil } -func (t *TProxy) Start() error { +func (t *TProxy) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } err := t.listener.Start() if err != nil { return err diff --git a/protocol/shadowsocks/inbound.go b/protocol/shadowsocks/inbound.go index 84ad43fcb0..8332a93cc0 100644 --- a/protocol/shadowsocks/inbound.go +++ b/protocol/shadowsocks/inbound.go @@ -93,7 +93,10 @@ func newInbound(ctx context.Context, router adapter.Router, logger log.ContextLo return inbound, err } -func (h *Inbound) Start() error { +func (h *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } return h.listener.Start() } diff --git a/protocol/shadowsocks/inbound_multi.go b/protocol/shadowsocks/inbound_multi.go index a76075efe0..ec55713430 100644 --- a/protocol/shadowsocks/inbound_multi.go +++ b/protocol/shadowsocks/inbound_multi.go @@ -101,7 +101,10 @@ func newMultiInbound(ctx context.Context, router adapter.Router, logger log.Cont return inbound, err } -func (h *MultiInbound) Start() error { +func (h *MultiInbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } return h.listener.Start() } diff --git a/protocol/shadowsocks/inbound_relay.go b/protocol/shadowsocks/inbound_relay.go index f7ec2b7703..bb20de3f58 100644 --- a/protocol/shadowsocks/inbound_relay.go +++ b/protocol/shadowsocks/inbound_relay.go @@ -86,7 +86,10 @@ func newRelayInbound(ctx context.Context, router adapter.Router, logger log.Cont return inbound, err } -func (h *RelayInbound) Start() error { +func (h *RelayInbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } return h.listener.Start() } diff --git a/protocol/shadowsocks/outbound.go b/protocol/shadowsocks/outbound.go index 8771fa8e99..7e7277ef95 100644 --- a/protocol/shadowsocks/outbound.go +++ b/protocol/shadowsocks/outbound.go @@ -49,7 +49,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return nil, err } outbound := &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowsocks, options.Network.Build(), tag, options.DialerOptions), + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowsocks, tag, options.Network.Build(), options.DialerOptions), logger: logger, dialer: outboundDialer, method: method, diff --git a/protocol/shadowtls/inbound.go b/protocol/shadowtls/inbound.go index 1be422decc..ce4238fd49 100644 --- a/protocol/shadowtls/inbound.go +++ b/protocol/shadowtls/inbound.go @@ -90,7 +90,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo return inbound, nil } -func (h *Inbound) Start() error { +func (h *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } return h.listener.Start() } diff --git a/protocol/shadowtls/outbound.go b/protocol/shadowtls/outbound.go index e979dba270..2b480729e5 100644 --- a/protocol/shadowtls/outbound.go +++ b/protocol/shadowtls/outbound.go @@ -29,7 +29,7 @@ type Outbound struct { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowTLSOutboundOptions) (adapter.Outbound, error) { outbound := &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowTLS, []string{N.NetworkTCP}, tag, options.DialerOptions), + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeShadowTLS, tag, []string{N.NetworkTCP}, options.DialerOptions), } if options.TLS == nil || !options.TLS.Enabled { return nil, C.ErrTLSRequired diff --git a/protocol/socks/inbound.go b/protocol/socks/inbound.go index fddacb21fb..115db30308 100644 --- a/protocol/socks/inbound.go +++ b/protocol/socks/inbound.go @@ -50,7 +50,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo return inbound, nil } -func (h *Inbound) Start() error { +func (h *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } return h.listener.Start() } diff --git a/protocol/socks/outbound.go b/protocol/socks/outbound.go index 70a5a5eda0..0632f0825f 100644 --- a/protocol/socks/outbound.go +++ b/protocol/socks/outbound.go @@ -50,7 +50,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return nil, err } outbound := &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, options.Network.Build(), tag, options.DialerOptions), + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, tag, options.Network.Build(), options.DialerOptions), router: router, logger: logger, client: socks.NewClient(outboundDialer, options.ServerOptions.Build(), version, options.Username, options.Password), diff --git a/protocol/ssh/outbound.go b/protocol/ssh/outbound.go index 1dfc1f6d68..eb9970b5fb 100644 --- a/protocol/ssh/outbound.go +++ b/protocol/ssh/outbound.go @@ -54,7 +54,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return nil, err } outbound := &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSSH, []string{N.NetworkTCP}, tag, options.DialerOptions), + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSSH, tag, []string{N.NetworkTCP}, options.DialerOptions), ctx: ctx, logger: logger, dialer: outboundDialer, diff --git a/protocol/tor/outbound.go b/protocol/tor/outbound.go index 3d2170115d..58824b53fb 100644 --- a/protocol/tor/outbound.go +++ b/protocol/tor/outbound.go @@ -80,7 +80,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return nil, err } return &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTor, []string{N.NetworkTCP}, tag, options.DialerOptions), + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTor, tag, []string{N.NetworkTCP}, options.DialerOptions), ctx: ctx, logger: logger, proxy: NewProxyListener(ctx, logger, outboundDialer), diff --git a/protocol/trojan/inbound.go b/protocol/trojan/inbound.go index 2dc1bb9421..107667c580 100644 --- a/protocol/trojan/inbound.go +++ b/protocol/trojan/inbound.go @@ -110,7 +110,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo return inbound, nil } -func (h *Inbound) Start() error { +func (h *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } if h.tlsConfig != nil { err := h.tlsConfig.Start() if err != nil { diff --git a/protocol/trojan/outbound.go b/protocol/trojan/outbound.go index 68b0069043..82889bc188 100644 --- a/protocol/trojan/outbound.go +++ b/protocol/trojan/outbound.go @@ -43,7 +43,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return nil, err } outbound := &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTrojan, options.Network.Build(), tag, options.DialerOptions), + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTrojan, tag, options.Network.Build(), options.DialerOptions), logger: logger, dialer: outboundDialer, serverAddr: options.ServerOptions.Build(), diff --git a/protocol/tuic/inbound.go b/protocol/tuic/inbound.go index 496079c194..a21b72ea5b 100644 --- a/protocol/tuic/inbound.go +++ b/protocol/tuic/inbound.go @@ -142,7 +142,10 @@ func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) } -func (h *Inbound) Start() error { +func (h *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } if h.tlsConfig != nil { err := h.tlsConfig.Start() if err != nil { diff --git a/protocol/tuic/outbound.go b/protocol/tuic/outbound.go index 177f21fc95..49b01f96e9 100644 --- a/protocol/tuic/outbound.go +++ b/protocol/tuic/outbound.go @@ -80,7 +80,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return nil, err } return &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTUIC, options.Network.Build(), tag, options.DialerOptions), + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTUIC, tag, options.Network.Build(), options.DialerOptions), logger: logger, client: client, udpStream: options.UDPOverStream, diff --git a/protocol/tun/inbound.go b/protocol/tun/inbound.go index 302afb578d..b3ada90550 100644 --- a/protocol/tun/inbound.go +++ b/protocol/tun/inbound.go @@ -300,104 +300,104 @@ func (t *Inbound) Tag() string { return t.tag } -func (t *Inbound) Start() error { - if C.IsAndroid && t.platformInterface == nil { - t.tunOptions.BuildAndroidRules(t.networkManager.PackageManager()) - } - if t.tunOptions.Name == "" { - t.tunOptions.Name = tun.CalculateInterfaceName("") - } - var ( - tunInterface tun.Tun - err error - ) - monitor := taskmonitor.New(t.logger, C.StartTimeout) - monitor.Start("open tun interface") - if t.platformInterface != nil { - tunInterface, err = t.platformInterface.OpenTun(&t.tunOptions, t.platformOptions) - } else { - tunInterface, err = tun.New(t.tunOptions) - } - monitor.Finish() - if err != nil { - return E.Cause(err, "configure tun interface") - } - t.logger.Trace("creating stack") - t.tunIf = tunInterface - var ( - forwarderBindInterface bool - includeAllNetworks bool - ) - if t.platformInterface != nil { - forwarderBindInterface = true - includeAllNetworks = t.platformInterface.IncludeAllNetworks() - } - tunStack, err := tun.NewStack(t.stack, tun.StackOptions{ - Context: t.ctx, - Tun: tunInterface, - TunOptions: t.tunOptions, - UDPTimeout: t.udpTimeout, - Handler: t, - Logger: t.logger, - ForwarderBindInterface: forwarderBindInterface, - InterfaceFinder: t.networkManager.InterfaceFinder(), - IncludeAllNetworks: includeAllNetworks, - }) - if err != nil { - return err - } - t.tunStack = tunStack - t.logger.Info("started at ", t.tunOptions.Name) - return nil -} - -func (t *Inbound) PostStart() error { - monitor := taskmonitor.New(t.logger, C.StartTimeout) - monitor.Start("starting tun stack") - err := t.tunStack.Start() - monitor.Finish() - if err != nil { - return E.Cause(err, "starting tun stack") - } - monitor.Start("starting tun interface") - err = t.tunIf.Start() - monitor.Finish() - if err != nil { - return E.Cause(err, "starting TUN interface") - } - if t.autoRedirect != nil { - t.routeAddressSet = common.FlatMap(t.routeRuleSet, adapter.RuleSet.ExtractIPSet) - for _, routeRuleSet := range t.routeRuleSet { - ipSets := routeRuleSet.ExtractIPSet() - if len(ipSets) == 0 { - t.logger.Warn("route_address_set: no destination IP CIDR rules found in rule-set: ", routeRuleSet.Name()) - } - t.routeAddressSet = append(t.routeAddressSet, ipSets...) +func (t *Inbound) Start(stage adapter.StartStage) error { + switch stage { + case adapter.StartStateStart: + if C.IsAndroid && t.platformInterface == nil { + t.tunOptions.BuildAndroidRules(t.networkManager.PackageManager()) } - t.routeExcludeAddressSet = common.FlatMap(t.routeExcludeRuleSet, adapter.RuleSet.ExtractIPSet) - for _, routeExcludeRuleSet := range t.routeExcludeRuleSet { - ipSets := routeExcludeRuleSet.ExtractIPSet() - if len(ipSets) == 0 { - t.logger.Warn("route_address_set: no destination IP CIDR rules found in rule-set: ", routeExcludeRuleSet.Name()) - } - t.routeExcludeAddressSet = append(t.routeExcludeAddressSet, ipSets...) + if t.tunOptions.Name == "" { + t.tunOptions.Name = tun.CalculateInterfaceName("") + } + var ( + tunInterface tun.Tun + err error + ) + monitor := taskmonitor.New(t.logger, C.StartTimeout) + monitor.Start("open tun interface") + if t.platformInterface != nil { + tunInterface, err = t.platformInterface.OpenTun(&t.tunOptions, t.platformOptions) + } else { + tunInterface, err = tun.New(t.tunOptions) } - monitor.Start("initialize auto-redirect") - err := t.autoRedirect.Start() monitor.Finish() if err != nil { - return E.Cause(err, "auto-redirect") + return E.Cause(err, "configure tun interface") } - for _, routeRuleSet := range t.routeRuleSet { - t.routeRuleSetCallback = append(t.routeRuleSetCallback, routeRuleSet.RegisterCallback(t.updateRouteAddressSet)) - routeRuleSet.DecRef() + t.logger.Trace("creating stack") + t.tunIf = tunInterface + var ( + forwarderBindInterface bool + includeAllNetworks bool + ) + if t.platformInterface != nil { + forwarderBindInterface = true + includeAllNetworks = t.platformInterface.IncludeAllNetworks() } - for _, routeExcludeRuleSet := range t.routeExcludeRuleSet { - t.routeExcludeRuleSetCallback = append(t.routeExcludeRuleSetCallback, routeExcludeRuleSet.RegisterCallback(t.updateRouteAddressSet)) - routeExcludeRuleSet.DecRef() + tunStack, err := tun.NewStack(t.stack, tun.StackOptions{ + Context: t.ctx, + Tun: tunInterface, + TunOptions: t.tunOptions, + UDPTimeout: t.udpTimeout, + Handler: t, + Logger: t.logger, + ForwarderBindInterface: forwarderBindInterface, + InterfaceFinder: t.networkManager.InterfaceFinder(), + IncludeAllNetworks: includeAllNetworks, + }) + if err != nil { + return err + } + t.tunStack = tunStack + t.logger.Info("started at ", t.tunOptions.Name) + case adapter.StartStatePostStart: + monitor := taskmonitor.New(t.logger, C.StartTimeout) + monitor.Start("starting tun stack") + err := t.tunStack.Start() + monitor.Finish() + if err != nil { + return E.Cause(err, "starting tun stack") + } + monitor.Start("starting tun interface") + err = t.tunIf.Start() + monitor.Finish() + if err != nil { + return E.Cause(err, "starting TUN interface") + } + if t.autoRedirect != nil { + t.routeAddressSet = common.FlatMap(t.routeRuleSet, adapter.RuleSet.ExtractIPSet) + for _, routeRuleSet := range t.routeRuleSet { + ipSets := routeRuleSet.ExtractIPSet() + if len(ipSets) == 0 { + t.logger.Warn("route_address_set: no destination IP CIDR rules found in rule-set: ", routeRuleSet.Name()) + } + t.routeAddressSet = append(t.routeAddressSet, ipSets...) + } + t.routeExcludeAddressSet = common.FlatMap(t.routeExcludeRuleSet, adapter.RuleSet.ExtractIPSet) + for _, routeExcludeRuleSet := range t.routeExcludeRuleSet { + ipSets := routeExcludeRuleSet.ExtractIPSet() + if len(ipSets) == 0 { + t.logger.Warn("route_address_set: no destination IP CIDR rules found in rule-set: ", routeExcludeRuleSet.Name()) + } + t.routeExcludeAddressSet = append(t.routeExcludeAddressSet, ipSets...) + } + monitor.Start("initialize auto-redirect") + err := t.autoRedirect.Start() + monitor.Finish() + if err != nil { + return E.Cause(err, "auto-redirect") + } + for _, routeRuleSet := range t.routeRuleSet { + t.routeRuleSetCallback = append(t.routeRuleSetCallback, routeRuleSet.RegisterCallback(t.updateRouteAddressSet)) + routeRuleSet.DecRef() + } + for _, routeExcludeRuleSet := range t.routeExcludeRuleSet { + t.routeExcludeRuleSetCallback = append(t.routeExcludeRuleSetCallback, routeExcludeRuleSet.RegisterCallback(t.updateRouteAddressSet)) + routeExcludeRuleSet.DecRef() + } + t.routeAddressSet = nil + t.routeExcludeAddressSet = nil } - t.routeAddressSet = nil - t.routeExcludeAddressSet = nil } return nil } diff --git a/protocol/vless/inbound.go b/protocol/vless/inbound.go index f5dfeabb6b..b0aa959e06 100644 --- a/protocol/vless/inbound.go +++ b/protocol/vless/inbound.go @@ -89,7 +89,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo return inbound, nil } -func (h *Inbound) Start() error { +func (h *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } if h.tlsConfig != nil { err := h.tlsConfig.Start() if err != nil { diff --git a/protocol/vless/outbound.go b/protocol/vless/outbound.go index de655230e0..1d832a654d 100644 --- a/protocol/vless/outbound.go +++ b/protocol/vless/outbound.go @@ -46,7 +46,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return nil, err } outbound := &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVLESS, options.Network.Build(), tag, options.DialerOptions), + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVLESS, tag, options.Network.Build(), options.DialerOptions), logger: logger, dialer: outboundDialer, serverAddr: options.ServerOptions.Build(), diff --git a/protocol/vmess/inbound.go b/protocol/vmess/inbound.go index 88ad222764..9f1009a301 100644 --- a/protocol/vmess/inbound.go +++ b/protocol/vmess/inbound.go @@ -99,7 +99,10 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo return inbound, nil } -func (h *Inbound) Start() error { +func (h *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } err := h.service.Start() if err != nil { return err diff --git a/protocol/vmess/outbound.go b/protocol/vmess/outbound.go index 1e84639f47..d41b30d964 100644 --- a/protocol/vmess/outbound.go +++ b/protocol/vmess/outbound.go @@ -46,7 +46,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return nil, err } outbound := &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVMess, options.Network.Build(), tag, options.DialerOptions), + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeVMess, tag, options.Network.Build(), options.DialerOptions), logger: logger, dialer: outboundDialer, serverAddr: options.ServerOptions.Build(), diff --git a/protocol/wireguard/endpoint.go b/protocol/wireguard/endpoint.go new file mode 100644 index 0000000000..74cf2dc5df --- /dev/null +++ b/protocol/wireguard/endpoint.go @@ -0,0 +1,211 @@ +package wireguard + +import ( + "context" + "net" + "net/netip" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/endpoint" + "github.com/sagernet/sing-box/common/dialer" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/wireguard" + "github.com/sagernet/sing-dns" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" +) + +func RegisterEndpoint(registry *endpoint.Registry) { + endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, NewEndpoint) +} + +var ( + _ adapter.Endpoint = (*Endpoint)(nil) + _ adapter.InterfaceUpdateListener = (*Endpoint)(nil) +) + +type Endpoint struct { + endpoint.Adapter + ctx context.Context + router adapter.Router + logger logger.ContextLogger + localAddresses []netip.Prefix + endpoint *wireguard.Endpoint +} + +func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) { + ep := &Endpoint{ + Adapter: endpoint.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions), + ctx: ctx, + router: router, + logger: logger, + localAddresses: options.Address, + } + if options.Detour == "" { + options.IsWireGuardListener = true + } else if options.GSO { + return nil, E.New("gso is conflict with detour") + } + outboundDialer, err := dialer.New(ctx, options.DialerOptions) + if err != nil { + return nil, err + } + wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{ + Context: ctx, + Logger: logger, + System: options.System, + Handler: ep, + UDPTimeout: time.Duration(options.UDPTimeout), + Dialer: outboundDialer, + CreateDialer: func(interfaceName string) N.Dialer { + return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{ + BindInterface: interfaceName, + })) + }, + Name: options.Name, + MTU: options.MTU, + GSO: options.GSO, + Address: options.Address, + PrivateKey: options.PrivateKey, + ListenPort: options.ListenPort, + ResolvePeer: func(domain string) (netip.Addr, error) { + endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy)) + if lookupErr != nil { + return netip.Addr{}, lookupErr + } + return endpointAddresses[0], nil + }, + Peers: common.Map(options.Peers, func(it option.WireGuardPeer) wireguard.PeerOptions { + return wireguard.PeerOptions{ + Endpoint: M.ParseSocksaddrHostPort(it.Address, it.Port), + PublicKey: it.PublicKey, + PreSharedKey: it.PreSharedKey, + AllowedIPs: it.AllowedIPs, + PersistentKeepaliveInterval: time.Duration(it.PersistentKeepaliveInterval), + Reserved: it.Reserved, + } + }), + Workers: options.Workers, + }) + if err != nil { + return nil, err + } + ep.endpoint = wgEndpoint + return ep, nil +} + +func (w *Endpoint) Start(stage adapter.StartStage) error { + switch stage { + case adapter.StartStateStart: + return w.endpoint.Start(false) + case adapter.StartStatePostStart: + return w.endpoint.Start(true) + } + return nil +} + +func (w *Endpoint) Close() error { + return w.endpoint.Close() +} + +func (w *Endpoint) InterfaceUpdated() { + w.endpoint.BindUpdate() + return +} + +func (w *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error { + return w.router.PreMatch(adapter.InboundContext{ + Inbound: w.Tag(), + InboundType: w.Type(), + Network: network, + Source: source, + Destination: destination, + }) +} + +func (w *Endpoint) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { + var metadata adapter.InboundContext + metadata.Inbound = w.Tag() + metadata.InboundType = w.Type() + metadata.Source = source + for _, localPrefix := range w.localAddresses { + if localPrefix.Contains(destination.Addr) { + metadata.OriginDestination = destination + if destination.Addr.Is4() { + destination.Addr = netip.AddrFrom4([4]uint8{127, 0, 0, 1}) + } else { + destination.Addr = netip.IPv6Loopback() + } + break + } + } + metadata.Destination = destination + w.logger.InfoContext(ctx, "inbound connection from ", source) + w.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) + w.router.RouteConnectionEx(ctx, conn, metadata, onClose) +} + +func (w *Endpoint) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { + var metadata adapter.InboundContext + metadata.Inbound = w.Tag() + metadata.InboundType = w.Type() + metadata.Source = source + metadata.Destination = destination + for _, localPrefix := range w.localAddresses { + if localPrefix.Contains(destination.Addr) { + metadata.OriginDestination = destination + if destination.Addr.Is4() { + metadata.Destination.Addr = netip.AddrFrom4([4]uint8{127, 0, 0, 1}) + } else { + metadata.Destination.Addr = netip.IPv6Loopback() + } + conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination) + } + } + w.logger.InfoContext(ctx, "inbound packet connection from ", source) + w.logger.InfoContext(ctx, "inbound packet connection to ", destination) + w.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) +} + +func (w *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + switch network { + case N.NetworkTCP: + w.logger.InfoContext(ctx, "outbound connection to ", destination) + case N.NetworkUDP: + w.logger.InfoContext(ctx, "outbound packet connection to ", destination) + } + if destination.IsFqdn() { + destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn) + if err != nil { + return nil, err + } + return N.DialSerial(ctx, w.endpoint, network, destination, destinationAddresses) + } else if !destination.Addr.IsValid() { + return nil, E.New("invalid destination: ", destination) + } + return w.endpoint.DialContext(ctx, network, destination) +} + +func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + w.logger.InfoContext(ctx, "outbound packet connection to ", destination) + if destination.IsFqdn() { + destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn) + if err != nil { + return nil, err + } + packetConn, _, err := N.ListenSerial(ctx, w.endpoint, destination, destinationAddresses) + if err != nil { + return nil, err + } + return packetConn, err + } + return w.endpoint.ListenPacket(ctx, destination) +} diff --git a/protocol/wireguard/outbound.go b/protocol/wireguard/outbound.go index 7b2f8a6cc7..2da8108183 100644 --- a/protocol/wireguard/outbound.go +++ b/protocol/wireguard/outbound.go @@ -2,231 +2,153 @@ package wireguard import ( "context" - "encoding/base64" - "encoding/hex" - "fmt" "net" "net/netip" - "strings" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/experimental/deprecated" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/wireguard" - "github.com/sagernet/sing-tun" + dns "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/service" - "github.com/sagernet/sing/service/pause" - "github.com/sagernet/wireguard-go/conn" - "github.com/sagernet/wireguard-go/device" ) func RegisterOutbound(registry *outbound.Registry) { - outbound.Register[option.WireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound) + outbound.Register[option.LegacyWireGuardOutboundOptions](registry, C.TypeWireGuard, NewOutbound) } -var _ adapter.InterfaceUpdateListener = (*Outbound)(nil) +var ( + _ adapter.Endpoint = (*Endpoint)(nil) + _ adapter.InterfaceUpdateListener = (*Endpoint)(nil) +) type Outbound struct { outbound.Adapter - ctx context.Context - router adapter.Router - logger logger.ContextLogger - workers int - peers []wireguard.PeerConfig - useStdNetBind bool - listener N.Dialer - ipcConf string - - pauseManager pause.Manager - pauseCallback *list.Element[pause.Callback] - bind conn.Bind - device *device.Device - tunDevice wireguard.Device + ctx context.Context + router adapter.Router + logger logger.ContextLogger + localAddresses []netip.Prefix + endpoint *wireguard.Endpoint } -func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (adapter.Outbound, error) { +func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.LegacyWireGuardOutboundOptions) (adapter.Outbound, error) { + deprecated.Report(ctx, deprecated.OptionWireGuardOutbound) outbound := &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, options.Network.Build(), tag, options.DialerOptions), - ctx: ctx, - router: router, - logger: logger, - workers: options.Workers, - pauseManager: service.FromContext[pause.Manager](ctx), - } - peers, err := wireguard.ParsePeers(options) - if err != nil { - return nil, err + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions), + ctx: ctx, + router: router, + logger: logger, + localAddresses: options.LocalAddress, } - outbound.peers = peers - if len(options.LocalAddress) == 0 { - return nil, E.New("missing local address") - } - if options.GSO { - if options.GSO && options.Detour != "" { - return nil, E.New("gso is conflict with detour") - } + if options.Detour == "" { options.IsWireGuardListener = true - outbound.useStdNetBind = true + } else if options.GSO { + return nil, E.New("gso is conflict with detour") } - listener, err := dialer.New(ctx, options.DialerOptions) + outboundDialer, err := dialer.New(ctx, options.DialerOptions) if err != nil { return nil, err } - outbound.listener = listener - var privateKey string - { - bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey) - if err != nil { - return nil, E.Cause(err, "decode private key") - } - privateKey = hex.EncodeToString(bytes) - } - outbound.ipcConf = "private_key=" + privateKey - mtu := options.MTU - if mtu == 0 { - mtu = 1408 - } - var wireTunDevice wireguard.Device - if !options.SystemInterface && tun.WithGVisor { - wireTunDevice, err = wireguard.NewStackDevice(options.LocalAddress, mtu) - } else { - wireTunDevice, err = wireguard.NewSystemDevice(service.FromContext[adapter.NetworkManager](ctx), options.InterfaceName, options.LocalAddress, mtu, options.GSO) - } - if err != nil { - return nil, E.Cause(err, "create WireGuard device") - } - outbound.tunDevice = wireTunDevice - return outbound, nil -} - -func (w *Outbound) Start() error { - if common.Any(w.peers, func(peer wireguard.PeerConfig) bool { - return !peer.Endpoint.IsValid() - }) { - // wait for all outbounds to be started and continue in PortStart - return nil - } - return w.start() -} - -func (w *Outbound) PostStart() error { - if common.All(w.peers, func(peer wireguard.PeerConfig) bool { - return peer.Endpoint.IsValid() - }) { - return nil - } - return w.start() -} - -func (w *Outbound) start() error { - err := wireguard.ResolvePeers(w.ctx, w.router, w.peers) - if err != nil { - return err - } - var bind conn.Bind - if w.useStdNetBind { - bind = conn.NewStdNetBind(w.listener.(dialer.WireGuardListener)) - } else { - var ( - isConnect bool - connectAddr netip.AddrPort - reserved [3]uint8 - ) - peerLen := len(w.peers) - if peerLen == 1 { - isConnect = true - connectAddr = w.peers[0].Endpoint - reserved = w.peers[0].Reserved - } - bind = wireguard.NewClientBind(w.ctx, w.logger, w.listener, isConnect, connectAddr, reserved) - } - err = w.tunDevice.Start() - if err != nil { - return err - } - wgDevice := device.NewDevice(w.tunDevice, bind, &device.Logger{ - Verbosef: func(format string, args ...interface{}) { - w.logger.Debug(fmt.Sprintf(strings.ToLower(format), args...)) + wgEndpoint, err := wireguard.NewEndpoint(wireguard.EndpointOptions{ + Context: ctx, + Logger: logger, + System: options.SystemInterface, + Dialer: outboundDialer, + CreateDialer: func(interfaceName string) N.Dialer { + return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{ + BindInterface: interfaceName, + })) }, - Errorf: func(format string, args ...interface{}) { - w.logger.Error(fmt.Sprintf(strings.ToLower(format), args...)) + Name: options.InterfaceName, + MTU: options.MTU, + GSO: options.GSO, + Address: options.LocalAddress, + PrivateKey: options.PrivateKey, + ResolvePeer: func(domain string) (netip.Addr, error) { + endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy)) + if lookupErr != nil { + return netip.Addr{}, lookupErr + } + return endpointAddresses[0], nil }, - }, w.workers) - ipcConf := w.ipcConf - for _, peer := range w.peers { - ipcConf += peer.GenerateIpcLines() - } - err = wgDevice.IpcSet(ipcConf) + Peers: common.Map(options.Peers, func(it option.LegacyWireGuardPeer) wireguard.PeerOptions { + return wireguard.PeerOptions{ + Endpoint: it.ServerOptions.Build(), + PublicKey: it.PublicKey, + PreSharedKey: it.PreSharedKey, + AllowedIPs: it.AllowedIPs, + // PersistentKeepaliveInterval: time.Duration(it.PersistentKeepaliveInterval), + Reserved: it.Reserved, + } + }), + Workers: options.Workers, + }) if err != nil { - return E.Cause(err, "setup wireguard: \n", ipcConf) + return nil, err } - w.device = wgDevice - w.pauseCallback = w.pauseManager.RegisterCallback(w.onPauseUpdated) - return nil + outbound.endpoint = wgEndpoint + return outbound, nil } -func (w *Outbound) Close() error { - if w.device != nil { - w.device.Close() - } - if w.pauseCallback != nil { - w.pauseManager.UnregisterCallback(w.pauseCallback) +func (o *Outbound) Start(stage adapter.StartStage) error { + switch stage { + case adapter.StartStateStart: + return o.endpoint.Start(false) + case adapter.StartStatePostStart: + return o.endpoint.Start(true) } return nil } -func (w *Outbound) InterfaceUpdated() { - w.device.BindUpdate() - return +func (o *Outbound) Close() error { + return o.endpoint.Close() } -func (w *Outbound) onPauseUpdated(event int) { - switch event { - case pause.EventDevicePaused: - w.device.Down() - case pause.EventDeviceWake: - w.device.Up() - } +func (o *Outbound) InterfaceUpdated() { + o.endpoint.BindUpdate() + return } -func (w *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { +func (o *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { switch network { case N.NetworkTCP: - w.logger.InfoContext(ctx, "outbound connection to ", destination) + o.logger.InfoContext(ctx, "outbound connection to ", destination) case N.NetworkUDP: - w.logger.InfoContext(ctx, "outbound packet connection to ", destination) + o.logger.InfoContext(ctx, "outbound packet connection to ", destination) } if destination.IsFqdn() { - destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn) + destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn) if err != nil { return nil, err } - return N.DialSerial(ctx, w.tunDevice, network, destination, destinationAddresses) + return N.DialSerial(ctx, o.endpoint, network, destination, destinationAddresses) + } else if !destination.Addr.IsValid() { + return nil, E.New("invalid destination: ", destination) } - return w.tunDevice.DialContext(ctx, network, destination) + return o.endpoint.DialContext(ctx, network, destination) } -func (w *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - w.logger.InfoContext(ctx, "outbound packet connection to ", destination) +func (o *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + o.logger.InfoContext(ctx, "outbound packet connection to ", destination) if destination.IsFqdn() { - destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn) + destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn) if err != nil { return nil, err } - packetConn, _, err := N.ListenSerial(ctx, w.tunDevice, destination, destinationAddresses) + packetConn, _, err := N.ListenSerial(ctx, o.endpoint, destination, destinationAddresses) if err != nil { return nil, err } return packetConn, err } - return w.tunDevice.ListenPacket(ctx, destination) + return o.endpoint.ListenPacket(ctx, destination) } diff --git a/route/network.go b/route/network.go index afa757a0a5..27cea08d62 100644 --- a/route/network.go +++ b/route/network.go @@ -41,17 +41,17 @@ type NetworkManager struct { autoDetectInterface bool defaultOptions adapter.NetworkOptions autoRedirectOutputMark uint32 - - networkMonitor tun.NetworkUpdateMonitor - interfaceMonitor tun.DefaultInterfaceMonitor - packageManager tun.PackageManager - powerListener winpowrprof.EventListener - pauseManager pause.Manager - platformInterface platform.Interface - inboundManager adapter.InboundManager - outboundManager adapter.OutboundManager - wifiState adapter.WIFIState - started bool + networkMonitor tun.NetworkUpdateMonitor + interfaceMonitor tun.DefaultInterfaceMonitor + packageManager tun.PackageManager + powerListener winpowrprof.EventListener + pauseManager pause.Manager + platformInterface platform.Interface + endpoint adapter.EndpointManager + inbound adapter.InboundManager + outbound adapter.OutboundManager + wifiState adapter.WIFIState + started bool } func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOptions option.RouteOptions) (*NetworkManager, error) { @@ -69,7 +69,9 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOp }, pauseManager: service.FromContext[pause.Manager](ctx), platformInterface: service.FromContext[platform.Interface](ctx), - outboundManager: service.FromContext[adapter.OutboundManager](ctx), + endpoint: service.FromContext[adapter.EndpointManager](ctx), + inbound: service.FromContext[adapter.InboundManager](ctx), + outbound: service.FromContext[adapter.OutboundManager](ctx), } if C.NetworkStrategy(routeOptions.DefaultNetworkStrategy) != C.NetworkStrategyDefault { if routeOptions.DefaultInterface != "" { @@ -358,14 +360,21 @@ func (r *NetworkManager) WIFIState() adapter.WIFIState { func (r *NetworkManager) ResetNetwork() { conntrack.Close() - for _, inbound := range r.inboundManager.Inbounds() { + for _, endpoint := range r.endpoint.Endpoints() { + listener, isListener := endpoint.(adapter.InterfaceUpdateListener) + if isListener { + listener.InterfaceUpdated() + } + } + + for _, inbound := range r.inbound.Inbounds() { listener, isListener := inbound.(adapter.InterfaceUpdateListener) if isListener { listener.InterfaceUpdated() } } - for _, outbound := range r.outboundManager.Outbounds() { + for _, outbound := range r.outbound.Outbounds() { listener, isListener := outbound.(adapter.InterfaceUpdateListener) if isListener { listener.InterfaceUpdated() diff --git a/test/box_test.go b/test/box_test.go index 0801b51816..08b50b64c2 100644 --- a/test/box_test.go +++ b/test/box_test.go @@ -32,7 +32,7 @@ func TestMain(m *testing.M) { var globalCtx context.Context func init() { - globalCtx = box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry()) + globalCtx = box.Context(context.Background(), include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry()) } func startInstance(t *testing.T, options option.Options) *box.Box { diff --git a/test/wireguard_test.go b/test/wireguard_test.go index 860c29df4c..d6c7bcc3f0 100644 --- a/test/wireguard_test.go +++ b/test/wireguard_test.go @@ -37,12 +37,12 @@ func _TestWireGuard(t *testing.T) { Outbounds: []option.Outbound{ { Type: C.TypeWireGuard, - Options: &option.WireGuardOutboundOptions{ + Options: &option.WireGuardEndpointOptions{ ServerOptions: option.ServerOptions{ Server: "127.0.0.1", ServerPort: serverPort, }, - LocalAddress: []netip.Prefix{netip.MustParsePrefix("10.0.0.2/32")}, + Address: []netip.Prefix{netip.MustParsePrefix("10.0.0.2/32")}, PrivateKey: "qGnwlkZljMxeECW8fbwAWdvgntnbK7B8UmMFl3zM0mk=", PeerPublicKey: "QsdcBm+oJw2oNv0cIFXLIq1E850lgTBonup4qnKEQBg=", }, diff --git a/transport/wireguard/client_bind.go b/transport/wireguard/client_bind.go index 20e7c0790c..e74e909d24 100644 --- a/transport/wireguard/client_bind.go +++ b/transport/wireguard/client_bind.go @@ -128,7 +128,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) select { case <-c.done: default: - c.logger.Error(context.Background(), E.Cause(err, "read packet")) + c.logger.Error(E.Cause(err, "read packet")) err = nil } return @@ -138,7 +138,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) b := packets[0] common.ClearArray(b[1:4]) } - eps[0] = Endpoint(M.AddrPortFromNet(addr)) + eps[0] = remoteEndpoint(M.AddrPortFromNet(addr)) count = 1 return } @@ -169,7 +169,7 @@ func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error { time.Sleep(time.Second) return err } - destination := netip.AddrPort(ep.(Endpoint)) + destination := netip.AddrPort(ep.(remoteEndpoint)) for _, b := range bufs { if len(b) > 3 { reserved, loaded := c.reservedForEndpoint[destination] @@ -192,7 +192,7 @@ func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) { if err != nil { return nil, err } - return Endpoint(ap), nil + return remoteEndpoint(ap), nil } func (c *ClientBind) BatchSize() int { @@ -229,3 +229,31 @@ func (w *wireConn) Close() error { close(w.done) return nil } + +var _ conn.Endpoint = (*remoteEndpoint)(nil) + +type remoteEndpoint netip.AddrPort + +func (e remoteEndpoint) ClearSrc() { +} + +func (e remoteEndpoint) SrcToString() string { + return "" +} + +func (e remoteEndpoint) DstToString() string { + return (netip.AddrPort)(e).String() +} + +func (e remoteEndpoint) DstToBytes() []byte { + b, _ := (netip.AddrPort)(e).MarshalBinary() + return b +} + +func (e remoteEndpoint) DstIP() netip.Addr { + return (netip.AddrPort)(e).Addr() +} + +func (e remoteEndpoint) SrcIP() netip.Addr { + return netip.Addr{} +} diff --git a/transport/wireguard/device.go b/transport/wireguard/device.go index 14e04bf56c..d5d3b78151 100644 --- a/transport/wireguard/device.go +++ b/transport/wireguard/device.go @@ -1,13 +1,44 @@ package wireguard import ( + "context" + "net/netip" + "time" + + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/logger" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/wireguard-go/tun" + "github.com/sagernet/wireguard-go/device" + wgTun "github.com/sagernet/wireguard-go/tun" ) type Device interface { - tun.Device + wgTun.Device N.Dialer Start() error - // NewEndpoint() (stack.LinkEndpoint, error) + SetDevice(device *device.Device) +} + +type DeviceOptions struct { + Context context.Context + Logger logger.ContextLogger + System bool + Handler tun.Handler + UDPTimeout time.Duration + CreateDialer func(interfaceName string) N.Dialer + Name string + MTU uint32 + GSO bool + Address []netip.Prefix + AllowedAddress []netip.Prefix +} + +func NewDevice(options DeviceOptions) (Device, error) { + if !options.System { + return newStackDevice(options) + } else if options.Handler == nil { + return newSystemDevice(options) + } else { + return newSystemStackDevice(options) + } } diff --git a/transport/wireguard/device_stack.go b/transport/wireguard/device_stack.go index 61286e6a9a..f9440f02fa 100644 --- a/transport/wireguard/device_stack.go +++ b/transport/wireguard/device_stack.go @@ -5,7 +5,6 @@ package wireguard import ( "context" "net" - "net/netip" "os" "github.com/sagernet/gvisor/pkg/buffer" @@ -15,52 +14,41 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/network/ipv4" "github.com/sagernet/gvisor/pkg/tcpip/network/ipv6" "github.com/sagernet/gvisor/pkg/tcpip/stack" - "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" "github.com/sagernet/sing-tun" - "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/wireguard-go/device" wgTun "github.com/sagernet/wireguard-go/tun" ) -var _ Device = (*StackDevice)(nil) - -const defaultNIC tcpip.NICID = 1 - -type StackDevice struct { - stack *stack.Stack - mtu uint32 - events chan wgTun.Event - outbound chan *stack.PacketBuffer - packetOutbound chan *buf.Buffer - done chan struct{} - dispatcher stack.NetworkDispatcher - addr4 tcpip.Address - addr6 tcpip.Address -} - -func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) { - ipStack := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, - HandleLocal: true, - }) - tunDevice := &StackDevice{ - stack: ipStack, - mtu: mtu, - events: make(chan wgTun.Event, 1), - outbound: make(chan *stack.PacketBuffer, 256), - packetOutbound: make(chan *buf.Buffer, 256), - done: make(chan struct{}), +var _ Device = (*stackDevice)(nil) + +type stackDevice struct { + stack *stack.Stack + mtu uint32 + events chan wgTun.Event + outbound chan *stack.PacketBuffer + done chan struct{} + dispatcher stack.NetworkDispatcher + addr4 tcpip.Address + addr6 tcpip.Address +} + +func newStackDevice(options DeviceOptions) (*stackDevice, error) { + tunDevice := &stackDevice{ + mtu: options.MTU, + events: make(chan wgTun.Event, 1), + outbound: make(chan *stack.PacketBuffer, 256), + done: make(chan struct{}), } - err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice)) + ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice)) if err != nil { - return nil, E.New(err.String()) + return nil, err } - for _, prefix := range localAddresses { + for _, prefix := range options.Address { addr := tun.AddressFromAddr(prefix.Addr()) protoAddr := tcpip.ProtocolAddress{ AddressWithPrefix: tcpip.AddressWithPrefix{ @@ -75,32 +63,27 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er tunDevice.addr6 = addr protoAddr.Protocol = ipv6.ProtocolNumber } - err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{}) - if err != nil { - return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String()) + gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{}) + if gErr != nil { + return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String()) } } - sOpt := tcpip.TCPSACKEnabled(true) - ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt) - cOpt := tcpip.CongestionControlOption("cubic") - ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt) - ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC}) - ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC}) + tunDevice.stack = ipStack + if options.Handler != nil { + ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket) + } return tunDevice, nil } -func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) { - return (*wireEndpoint)(w), nil -} - -func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { +func (w *stackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { addr := tcpip.FullAddress{ - NIC: defaultNIC, + NIC: tun.DefaultNIC, Port: destination.Port, Addr: tun.AddressFromAddr(destination.Addr), } bind := tcpip.FullAddress{ - NIC: defaultNIC, + NIC: tun.DefaultNIC, } var networkProtocol tcpip.NetworkProtocolNumber if destination.IsIPv4() { @@ -128,9 +111,9 @@ func (w *StackDevice) DialContext(ctx context.Context, network string, destinati } } -func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { +func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { bind := tcpip.FullAddress{ - NIC: defaultNIC, + NIC: tun.DefaultNIC, } var networkProtocol tcpip.NetworkProtocolNumber if destination.IsIPv4() { @@ -147,24 +130,19 @@ func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) return udpConn, nil } -func (w *StackDevice) Inet4Address() netip.Addr { - return tun.AddrFromAddress(w.addr4) -} - -func (w *StackDevice) Inet6Address() netip.Addr { - return tun.AddrFromAddress(w.addr6) +func (w *stackDevice) SetDevice(device *device.Device) { } -func (w *StackDevice) Start() error { +func (w *stackDevice) Start() error { w.events <- wgTun.EventUp return nil } -func (w *StackDevice) File() *os.File { +func (w *stackDevice) File() *os.File { return nil } -func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) { +func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) { select { case packetBuffer, ok := <-w.outbound: if !ok { @@ -180,17 +158,12 @@ func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, e sizes[0] = n count = 1 return - case packet := <-w.packetOutbound: - defer packet.Release() - sizes[0] = copy(bufs[0][offset:], packet.Bytes()) - count = 1 - return case <-w.done: return 0, os.ErrClosed } } -func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) { +func (w *stackDevice) Write(bufs [][]byte, offset int) (count int, err error) { for _, b := range bufs { b = b[offset:] if len(b) == 0 { @@ -213,23 +186,23 @@ func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) { return } -func (w *StackDevice) Flush() error { +func (w *stackDevice) Flush() error { return nil } -func (w *StackDevice) MTU() (int, error) { +func (w *stackDevice) MTU() (int, error) { return int(w.mtu), nil } -func (w *StackDevice) Name() (string, error) { +func (w *stackDevice) Name() (string, error) { return "sing-box", nil } -func (w *StackDevice) Events() <-chan wgTun.Event { +func (w *stackDevice) Events() <-chan wgTun.Event { return w.events } -func (w *StackDevice) Close() error { +func (w *stackDevice) Close() error { close(w.done) close(w.events) w.stack.Close() @@ -240,13 +213,13 @@ func (w *StackDevice) Close() error { return nil } -func (w *StackDevice) BatchSize() int { +func (w *stackDevice) BatchSize() int { return 1 } var _ stack.LinkEndpoint = (*wireEndpoint)(nil) -type wireEndpoint StackDevice +type wireEndpoint stackDevice func (ep *wireEndpoint) MTU() uint32 { return ep.mtu diff --git a/transport/wireguard/gonet.go b/transport/wireguard/device_stack_gonet.go similarity index 100% rename from transport/wireguard/gonet.go rename to transport/wireguard/device_stack_gonet.go diff --git a/transport/wireguard/device_stack_stub.go b/transport/wireguard/device_stack_stub.go index b383ab3825..ea413559e6 100644 --- a/transport/wireguard/device_stack_stub.go +++ b/transport/wireguard/device_stack_stub.go @@ -2,12 +2,12 @@ package wireguard -import ( - "net/netip" +import "github.com/sagernet/sing-tun" - "github.com/sagernet/sing-tun" -) +func newStackDevice(options DeviceOptions) (Device, error) { + return nil, tun.ErrGVisorNotIncluded +} -func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (Device, error) { +func newSystemStackDevice(options DeviceOptions) (Device, error) { return nil, tun.ErrGVisorNotIncluded } diff --git a/transport/wireguard/device_system.go b/transport/wireguard/device_system.go index 8a54a75ef9..53fc6f53bc 100644 --- a/transport/wireguard/device_system.go +++ b/transport/wireguard/device_system.go @@ -6,96 +6,88 @@ import ( "net" "net/netip" "os" + "runtime" "sync" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/dialer" - "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" + "github.com/sagernet/wireguard-go/device" wgTun "github.com/sagernet/wireguard-go/tun" ) -var _ Device = (*SystemDevice)(nil) - -type SystemDevice struct { - dialer N.Dialer - device tun.Tun - batchDevice tun.LinuxTUN - name string - mtu uint32 - inet4Addresses []netip.Prefix - inet6Addresses []netip.Prefix - gso bool - events chan wgTun.Event - closeOnce sync.Once +var _ Device = (*systemDevice)(nil) + +type systemDevice struct { + options DeviceOptions + dialer N.Dialer + device tun.Tun + batchDevice tun.LinuxTUN + events chan wgTun.Event + closeOnce sync.Once } -func NewSystemDevice(networkManager adapter.NetworkManager, interfaceName string, localPrefixes []netip.Prefix, mtu uint32, gso bool) (*SystemDevice, error) { - var inet4Addresses []netip.Prefix - var inet6Addresses []netip.Prefix - for _, prefixes := range localPrefixes { - if prefixes.Addr().Is4() { - inet4Addresses = append(inet4Addresses, prefixes) - } else { - inet6Addresses = append(inet6Addresses, prefixes) - } - } - if interfaceName == "" { - interfaceName = tun.CalculateInterfaceName("wg") +func newSystemDevice(options DeviceOptions) (*systemDevice, error) { + if options.Name == "" { + options.Name = tun.CalculateInterfaceName("wg") } - - return &SystemDevice{ - dialer: common.Must1(dialer.NewDefault(networkManager, option.DialerOptions{ - BindInterface: interfaceName, - })), - name: interfaceName, - mtu: mtu, - inet4Addresses: inet4Addresses, - inet6Addresses: inet6Addresses, - gso: gso, - events: make(chan wgTun.Event, 1), + return &systemDevice{ + options: options, + dialer: options.CreateDialer(options.Name), + events: make(chan wgTun.Event, 1), }, nil } -func (w *SystemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { +func (w *systemDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { return w.dialer.DialContext(ctx, network, destination) } -func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { +func (w *systemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { return w.dialer.ListenPacket(ctx, destination) } -func (w *SystemDevice) Inet4Address() netip.Addr { - if len(w.inet4Addresses) == 0 { - return netip.Addr{} +func (w *systemDevice) SetDevice(device *device.Device) { +} + +func (w *systemDevice) Start() error { + networkManager := service.FromContext[adapter.NetworkManager](w.options.Context) + tunOptions := tun.Options{ + Name: w.options.Name, + Inet4Address: common.Filter(w.options.Address, func(it netip.Prefix) bool { + return it.Addr().Is4() + }), + Inet6Address: common.Filter(w.options.Address, func(it netip.Prefix) bool { + return it.Addr().Is6() + }), + MTU: w.options.MTU, + GSO: w.options.GSO, + InterfaceScope: true, + Inet4RouteAddress: common.Filter(w.options.AllowedAddress, func(it netip.Prefix) bool { + return it.Addr().Is4() + }), + Inet6RouteAddress: common.Filter(w.options.AllowedAddress, func(it netip.Prefix) bool { return it.Addr().Is6() }), + InterfaceMonitor: networkManager.InterfaceMonitor(), + InterfaceFinder: networkManager.InterfaceFinder(), } - return w.inet4Addresses[0].Addr() -} - -func (w *SystemDevice) Inet6Address() netip.Addr { - if len(w.inet6Addresses) == 0 { - return netip.Addr{} + // works with Linux, macOS with IFSCOPE routes, not tested on Windows + if runtime.GOOS == "darwin" { + tunOptions.AutoRoute = true } - return w.inet6Addresses[0].Addr() -} - -func (w *SystemDevice) Start() error { - tunInterface, err := tun.New(tun.Options{ - Name: w.name, - Inet4Address: w.inet4Addresses, - Inet6Address: w.inet6Addresses, - MTU: w.mtu, - GSO: w.gso, - }) + tunInterface, err := tun.New(tunOptions) + if err != nil { + return err + } + err = tunInterface.Start() if err != nil { return err } + w.options.Logger.Info("started at ", w.options.Name) w.device = tunInterface - if w.gso { + if w.options.GSO { batchTUN, isBatchTUN := tunInterface.(tun.LinuxTUN) if !isBatchTUN { tunInterface.Close() @@ -107,15 +99,15 @@ func (w *SystemDevice) Start() error { return nil } -func (w *SystemDevice) File() *os.File { +func (w *systemDevice) File() *os.File { return nil } -func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) { +func (w *systemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) { if w.batchDevice != nil { - count, err = w.batchDevice.BatchRead(bufs, offset, sizes) + count, err = w.batchDevice.BatchRead(bufs, offset-tun.PacketOffset, sizes) } else { - sizes[0], err = w.device.Read(bufs[0][offset:]) + sizes[0], err = w.device.Read(bufs[0][offset-tun.PacketOffset:]) if err == nil { count = 1 } else if errors.Is(err, tun.ErrTooManySegments) { @@ -125,12 +117,16 @@ func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, return } -func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) { +func (w *systemDevice) Write(bufs [][]byte, offset int) (count int, err error) { if w.batchDevice != nil { return 0, w.batchDevice.BatchWrite(bufs, offset) } else { - for _, b := range bufs { - _, err = w.device.Write(b[offset:]) + for _, packet := range bufs { + if tun.PacketOffset > 0 { + common.ClearArray(packet[offset-tun.PacketOffset : offset]) + tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:])) + } + _, err = w.device.Write(packet[offset-tun.PacketOffset:]) if err != nil { return } @@ -140,28 +136,28 @@ func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) { return } -func (w *SystemDevice) Flush() error { +func (w *systemDevice) Flush() error { return nil } -func (w *SystemDevice) MTU() (int, error) { - return int(w.mtu), nil +func (w *systemDevice) MTU() (int, error) { + return int(w.options.MTU), nil } -func (w *SystemDevice) Name() (string, error) { - return w.name, nil +func (w *systemDevice) Name() (string, error) { + return w.options.Name, nil } -func (w *SystemDevice) Events() <-chan wgTun.Event { +func (w *systemDevice) Events() <-chan wgTun.Event { return w.events } -func (w *SystemDevice) Close() error { +func (w *systemDevice) Close() error { close(w.events) return w.device.Close() } -func (w *SystemDevice) BatchSize() int { +func (w *systemDevice) BatchSize() int { if w.batchDevice != nil { return w.batchDevice.BatchSize() } diff --git a/transport/wireguard/device_system_stack.go b/transport/wireguard/device_system_stack.go new file mode 100644 index 0000000000..0396690952 --- /dev/null +++ b/transport/wireguard/device_system_stack.go @@ -0,0 +1,182 @@ +//go:build with_gvisor + +package wireguard + +import ( + "net/netip" + + "github.com/sagernet/gvisor/pkg/buffer" + "github.com/sagernet/gvisor/pkg/tcpip" + "github.com/sagernet/gvisor/pkg/tcpip/header" + "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" + "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common" + "github.com/sagernet/wireguard-go/device" +) + +var _ Device = (*systemStackDevice)(nil) + +type systemStackDevice struct { + *systemDevice + stack *stack.Stack + endpoint *deviceEndpoint + writeBufs [][]byte +} + +func newSystemStackDevice(options DeviceOptions) (*systemStackDevice, error) { + system, err := newSystemDevice(options) + if err != nil { + return nil, err + } + endpoint := &deviceEndpoint{ + mtu: options.MTU, + done: make(chan struct{}), + } + ipStack, err := tun.NewGVisorStack(endpoint) + if err != nil { + return nil, err + } + ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket) + return &systemStackDevice{ + systemDevice: system, + stack: ipStack, + endpoint: endpoint, + }, nil +} + +func (w *systemStackDevice) SetDevice(device *device.Device) { + w.endpoint.device = device +} + +func (w *systemStackDevice) Write(bufs [][]byte, offset int) (count int, err error) { + if w.batchDevice != nil { + w.writeBufs = w.writeBufs[:0] + for _, packet := range bufs { + if !w.writeStack(packet[offset:]) { + w.writeBufs = append(w.writeBufs, packet) + } + } + if len(w.writeBufs) > 0 { + return 0, w.batchDevice.BatchWrite(bufs, offset) + } + } else { + for _, packet := range bufs { + if !w.writeStack(packet[offset:]) { + if tun.PacketOffset > 0 { + common.ClearArray(packet[offset-tun.PacketOffset : offset]) + tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:])) + } + _, err = w.device.Write(packet[offset-tun.PacketOffset:]) + } + if err != nil { + return + } + } + } + // WireGuard will not read count + return +} + +func (w *systemStackDevice) Close() error { + close(w.endpoint.done) + w.stack.Close() + for _, endpoint := range w.stack.CleanupEndpoints() { + endpoint.Abort() + } + w.stack.Wait() + return w.systemDevice.Close() +} + +func (w *systemStackDevice) writeStack(packet []byte) bool { + var ( + networkProtocol tcpip.NetworkProtocolNumber + destination netip.Addr + ) + switch header.IPVersion(packet) { + case header.IPv4Version: + networkProtocol = header.IPv4ProtocolNumber + destination = netip.AddrFrom4(header.IPv4(packet).DestinationAddress().As4()) + case header.IPv6Version: + networkProtocol = header.IPv6ProtocolNumber + destination = netip.AddrFrom16(header.IPv6(packet).DestinationAddress().As16()) + } + for _, prefix := range w.options.Address { + if prefix.Contains(destination) { + return false + } + } + packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(packet), + }) + w.endpoint.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer) + packetBuffer.DecRef() + return true +} + +type deviceEndpoint struct { + mtu uint32 + done chan struct{} + device *device.Device + dispatcher stack.NetworkDispatcher +} + +func (ep *deviceEndpoint) MTU() uint32 { + return ep.mtu +} + +func (ep *deviceEndpoint) SetMTU(mtu uint32) { +} + +func (ep *deviceEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +func (ep *deviceEndpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +func (ep *deviceEndpoint) SetLinkAddress(addr tcpip.LinkAddress) { +} + +func (ep *deviceEndpoint) Capabilities() stack.LinkEndpointCapabilities { + return stack.CapabilityRXChecksumOffload +} + +func (ep *deviceEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + ep.dispatcher = dispatcher +} + +func (ep *deviceEndpoint) IsAttached() bool { + return ep.dispatcher != nil +} + +func (ep *deviceEndpoint) Wait() { +} + +func (ep *deviceEndpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareNone +} + +func (ep *deviceEndpoint) AddHeader(buffer *stack.PacketBuffer) { +} + +func (ep *deviceEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool { + return true +} + +func (ep *deviceEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) { + for _, packetBuffer := range list.AsSlice() { + destination := packetBuffer.Network().DestinationAddress() + ep.device.InputPacket(destination.AsSlice(), packetBuffer.AsSlices()) + } + return list.Len(), nil +} + +func (ep *deviceEndpoint) Close() { +} + +func (ep *deviceEndpoint) SetOnCloseAction(f func()) { +} diff --git a/transport/wireguard/endpoint.go b/transport/wireguard/endpoint.go index 3c3ec7db5c..b2839c2a2d 100644 --- a/transport/wireguard/endpoint.go +++ b/transport/wireguard/endpoint.go @@ -1,35 +1,255 @@ package wireguard import ( + "context" + "encoding/base64" + "encoding/hex" + "fmt" + "net" "net/netip" + "os" + "strings" + "time" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/x/list" + "github.com/sagernet/sing/service" + "github.com/sagernet/sing/service/pause" "github.com/sagernet/wireguard-go/conn" + "github.com/sagernet/wireguard-go/device" + + "go4.org/netipx" ) -var _ conn.Endpoint = (*Endpoint)(nil) +type Endpoint struct { + options EndpointOptions + peers []peerConfig + ipcConf string + allowedAddress []netip.Prefix + tunDevice Device + device *device.Device + pauseManager pause.Manager + pauseCallback *list.Element[pause.Callback] +} + +func NewEndpoint(options EndpointOptions) (*Endpoint, error) { + if options.PrivateKey == "" { + return nil, E.New("missing private key") + } + privateKeyBytes, err := base64.StdEncoding.DecodeString(options.PrivateKey) + if err != nil { + return nil, E.Cause(err, "decode private key") + } + privateKey := hex.EncodeToString(privateKeyBytes) + ipcConf := "private_key=" + privateKey + if options.ListenPort != 0 { + ipcConf += "\nlisten_port=" + F.ToString(options.ListenPort) + } + var peers []peerConfig + for peerIndex, rawPeer := range options.Peers { + peer := peerConfig{ + allowedIPs: rawPeer.AllowedIPs, + keepalive: rawPeer.PersistentKeepaliveInterval, + } + if !rawPeer.Endpoint.IsValid() { + return nil, E.New("invalid endpoint for peer ", peerIndex, ": ", rawPeer.Endpoint) + } else if rawPeer.Endpoint.Addr.IsValid() { + peer.endpoint = rawPeer.Endpoint.AddrPort() + } else { + peer.destination = rawPeer.Endpoint + } + publicKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey) + if err != nil { + return nil, E.Cause(err, "decode public key for peer ", peerIndex) + } + peer.publicKeyHex = hex.EncodeToString(publicKeyBytes) + if rawPeer.PreSharedKey != "" { + preSharedKeyBytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey) + if err != nil { + return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex) + } + peer.preSharedKeyHex = hex.EncodeToString(preSharedKeyBytes) + } + if len(rawPeer.AllowedIPs) == 0 { + return nil, E.New("missing allowed ips for peer ", peerIndex) + } + if len(rawPeer.Reserved) > 0 { + if len(rawPeer.Reserved) != 3 { + return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.reserved)) + } + copy(peer.reserved[:], rawPeer.Reserved[:]) + } + peers = append(peers, peer) + } + var allowedPrefixBuilder netipx.IPSetBuilder + for _, peer := range options.Peers { + for _, prefix := range peer.AllowedIPs { + allowedPrefixBuilder.AddPrefix(prefix) + } + } + allowedIPSet, err := allowedPrefixBuilder.IPSet() + if err != nil { + return nil, err + } + allowedAddresses := allowedIPSet.Prefixes() + if options.MTU == 0 { + options.MTU = 1408 + } + deviceOptions := DeviceOptions{ + Context: options.Context, + Logger: options.Logger, + System: options.System, + Handler: options.Handler, + UDPTimeout: options.UDPTimeout, + CreateDialer: options.CreateDialer, + Name: options.Name, + MTU: options.MTU, + GSO: options.GSO, + Address: options.Address, + AllowedAddress: allowedAddresses, + } + tunDevice, err := NewDevice(deviceOptions) + if err != nil { + return nil, E.Cause(err, "create WireGuard device") + } + return &Endpoint{ + options: options, + peers: peers, + ipcConf: ipcConf, + allowedAddress: allowedAddresses, + tunDevice: tunDevice, + }, nil +} -type Endpoint netip.AddrPort +func (e *Endpoint) Start(resolve bool) error { + if common.Any(e.peers, func(peer peerConfig) bool { + return !peer.endpoint.IsValid() + }) { + if !resolve { + return nil + } + for peerIndex, peer := range e.peers { + if peer.endpoint.IsValid() { + continue + } + destinationAddress, err := e.options.ResolvePeer(peer.destination.Fqdn) + if err != nil { + return E.Cause(err, "resolve endpoint domain for peer[", peerIndex, "]: ", peer.destination) + } + e.peers[peerIndex].endpoint = netip.AddrPortFrom(destinationAddress, peer.destination.Port) + } + } else if resolve { + return nil + } + var bind conn.Bind + wgListener, isWgListener := e.options.Dialer.(conn.Listener) + if isWgListener { + bind = conn.NewStdNetBind(wgListener) + } else { + var ( + isConnect bool + connectAddr netip.AddrPort + reserved [3]uint8 + ) + peerLen := len(e.peers) + if peerLen == 1 { + isConnect = true + connectAddr = e.peers[0].endpoint + reserved = e.peers[0].reserved + } + bind = NewClientBind(e.options.Context, e.options.Logger, e.options.Dialer, isConnect, connectAddr, reserved) + } + err := e.tunDevice.Start() + if err != nil { + return err + } + logger := &device.Logger{ + Verbosef: func(format string, args ...interface{}) { + e.options.Logger.Debug(fmt.Sprintf(strings.ToLower(format), args...)) + }, + Errorf: func(format string, args ...interface{}) { + e.options.Logger.Error(fmt.Sprintf(strings.ToLower(format), args...)) + }, + } + wgDevice := device.NewDevice(e.options.Context, e.tunDevice, bind, logger, e.options.Workers) + e.tunDevice.SetDevice(wgDevice) + ipcConf := e.ipcConf + for _, peer := range e.peers { + ipcConf += peer.GenerateIpcLines() + } + err = wgDevice.IpcSet(ipcConf) + if err != nil { + return E.Cause(err, "setup wireguard: \n", ipcConf) + } + e.device = wgDevice + e.pauseManager = service.FromContext[pause.Manager](e.options.Context) + if e.pauseManager != nil { + e.pauseCallback = e.pauseManager.RegisterCallback(e.onPauseUpdated) + } + return nil +} + +func (e *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if !destination.Addr.IsValid() { + return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination") + } + return e.tunDevice.DialContext(ctx, network, destination) +} -func (e Endpoint) ClearSrc() { +func (e *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + if !destination.Addr.IsValid() { + return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination") + } + return e.tunDevice.ListenPacket(ctx, destination) } -func (e Endpoint) SrcToString() string { - return "" +func (e *Endpoint) BindUpdate() error { + return e.device.BindUpdate() } -func (e Endpoint) DstToString() string { - return (netip.AddrPort)(e).String() +func (e *Endpoint) Close() error { + if e.device != nil { + e.device.Close() + } + if e.pauseCallback != nil { + e.pauseManager.UnregisterCallback(e.pauseCallback) + } + return nil } -func (e Endpoint) DstToBytes() []byte { - b, _ := (netip.AddrPort)(e).MarshalBinary() - return b +func (e *Endpoint) onPauseUpdated(event int) { + switch event { + case pause.EventDevicePaused: + e.device.Down() + case pause.EventDeviceWake: + e.device.Up() + } } -func (e Endpoint) DstIP() netip.Addr { - return (netip.AddrPort)(e).Addr() +type peerConfig struct { + destination M.Socksaddr + endpoint netip.AddrPort + publicKeyHex string + preSharedKeyHex string + allowedIPs []netip.Prefix + keepalive time.Duration + reserved [3]uint8 } -func (e Endpoint) SrcIP() netip.Addr { - return netip.Addr{} +func (c peerConfig) GenerateIpcLines() string { + ipcLines := "\npublic_key=" + c.publicKeyHex + ipcLines += "\nendpoint=" + c.endpoint.String() + if c.preSharedKeyHex != "" { + ipcLines += "\npreshared_key=" + c.preSharedKeyHex + } + for _, allowedIP := range c.allowedIPs { + ipcLines += "\nallowed_ip=" + allowedIP.String() + } + if c.keepalive > 0 { + ipcLines += "\npersistent_keepalive_interval=" + F.ToString(int(c.keepalive.Seconds())) + } + return ipcLines } diff --git a/transport/wireguard/endpoint_options.go b/transport/wireguard/endpoint_options.go new file mode 100644 index 0000000000..6c75866263 --- /dev/null +++ b/transport/wireguard/endpoint_options.go @@ -0,0 +1,40 @@ +package wireguard + +import ( + "context" + "net/netip" + "time" + + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type EndpointOptions struct { + Context context.Context + Logger logger.ContextLogger + System bool + Handler tun.Handler + UDPTimeout time.Duration + Dialer N.Dialer + CreateDialer func(interfaceName string) N.Dialer + Name string + MTU uint32 + GSO bool + Address []netip.Prefix + PrivateKey string + ListenPort uint16 + ResolvePeer func(domain string) (netip.Addr, error) + Peers []PeerOptions + Workers int +} + +type PeerOptions struct { + Endpoint M.Socksaddr + PublicKey string + PreSharedKey string + AllowedIPs []netip.Prefix + PersistentKeepaliveInterval time.Duration + Reserved []uint8 +} diff --git a/transport/wireguard/resolve.go b/transport/wireguard/resolve.go deleted file mode 100644 index d7a1d19c03..0000000000 --- a/transport/wireguard/resolve.go +++ /dev/null @@ -1,148 +0,0 @@ -package wireguard - -import ( - "context" - "encoding/base64" - "encoding/hex" - "net/netip" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-dns" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" -) - -type PeerConfig struct { - destination M.Socksaddr - domainStrategy dns.DomainStrategy - Endpoint netip.AddrPort - PublicKey string - PreSharedKey string - AllowedIPs []string - Reserved [3]uint8 -} - -func (c PeerConfig) GenerateIpcLines() string { - ipcLines := "\npublic_key=" + c.PublicKey - ipcLines += "\nendpoint=" + c.Endpoint.String() - if c.PreSharedKey != "" { - ipcLines += "\npreshared_key=" + c.PreSharedKey - } - for _, allowedIP := range c.AllowedIPs { - ipcLines += "\nallowed_ip=" + allowedIP - } - return ipcLines -} - -func ParsePeers(options option.WireGuardOutboundOptions) ([]PeerConfig, error) { - var peers []PeerConfig - if len(options.Peers) > 0 { - for peerIndex, rawPeer := range options.Peers { - peer := PeerConfig{ - AllowedIPs: rawPeer.AllowedIPs, - } - destination := rawPeer.ServerOptions.Build() - if destination.IsFqdn() { - peer.destination = destination - peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy) - } else { - peer.Endpoint = destination.AddrPort() - } - { - bytes, err := base64.StdEncoding.DecodeString(rawPeer.PublicKey) - if err != nil { - return nil, E.Cause(err, "decode public key for peer ", peerIndex) - } - peer.PublicKey = hex.EncodeToString(bytes) - } - if rawPeer.PreSharedKey != "" { - bytes, err := base64.StdEncoding.DecodeString(rawPeer.PreSharedKey) - if err != nil { - return nil, E.Cause(err, "decode pre shared key for peer ", peerIndex) - } - peer.PreSharedKey = hex.EncodeToString(bytes) - } - if len(rawPeer.AllowedIPs) == 0 { - return nil, E.New("missing allowed_ips for peer ", peerIndex) - } - if len(rawPeer.Reserved) > 0 { - if len(rawPeer.Reserved) != 3 { - return nil, E.New("invalid reserved value for peer ", peerIndex, ", required 3 bytes, got ", len(peer.Reserved)) - } - copy(peer.Reserved[:], options.Reserved) - } - peers = append(peers, peer) - } - } else { - peer := PeerConfig{} - var ( - addressHas4 bool - addressHas6 bool - ) - for _, localAddress := range options.LocalAddress { - if localAddress.Addr().Is4() { - addressHas4 = true - } else { - addressHas6 = true - } - } - if addressHas4 { - peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv4Unspecified(), 0).String()) - } - if addressHas6 { - peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(netip.IPv6Unspecified(), 0).String()) - } - destination := options.ServerOptions.Build() - if destination.IsFqdn() { - peer.destination = destination - peer.domainStrategy = dns.DomainStrategy(options.DomainStrategy) - } else { - peer.Endpoint = destination.AddrPort() - } - { - bytes, err := base64.StdEncoding.DecodeString(options.PeerPublicKey) - if err != nil { - return nil, E.Cause(err, "decode peer public key") - } - peer.PublicKey = hex.EncodeToString(bytes) - } - if options.PreSharedKey != "" { - bytes, err := base64.StdEncoding.DecodeString(options.PreSharedKey) - if err != nil { - return nil, E.Cause(err, "decode pre shared key") - } - peer.PreSharedKey = hex.EncodeToString(bytes) - } - if len(options.Reserved) > 0 { - if len(options.Reserved) != 3 { - return nil, E.New("invalid reserved value, required 3 bytes, got ", len(peer.Reserved)) - } - copy(peer.Reserved[:], options.Reserved) - } - peers = append(peers, peer) - } - return peers, nil -} - -func ResolvePeers(ctx context.Context, router adapter.Router, peers []PeerConfig) error { - for peerIndex, peer := range peers { - if peer.Endpoint.IsValid() { - continue - } - destinationAddresses, err := router.Lookup(ctx, peer.destination.Fqdn, peer.domainStrategy) - if err != nil { - if len(peers) == 1 { - return E.Cause(err, "resolve endpoint domain") - } else { - return E.Cause(err, "resolve endpoint domain for peer ", peerIndex) - } - } - if len(destinationAddresses) == 0 { - return E.New("no addresses found for endpoint domain: ", peer.destination.Fqdn) - } - peers[peerIndex].Endpoint = netip.AddrPortFrom(destinationAddresses[0], peer.destination.Port) - - } - return nil -}