Skip to content

Commit

Permalink
refactor: WireGuard endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 21, 2024
1 parent b910428 commit 9443c94
Show file tree
Hide file tree
Showing 75 changed files with 1,649 additions and 674 deletions.
28 changes: 28 additions & 0 deletions adapter/endpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package adapter

import (
"context"

"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)

type Endpoint interface {
Lifecycle
Type() string
Tag() string
Outbound
}

type EndpointRegistry interface {
option.EndpointOptionsRegistry
Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, endpointType string, options any) (Endpoint, error)
}

type EndpointManager interface {
Lifecycle
Endpoints() []Endpoint
Get(tag string) (Endpoint, bool)
Remove(tag string) error
Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, endpointType string, options any) error
}
43 changes: 43 additions & 0 deletions adapter/endpoint/adapter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package endpoint

import "github.com/sagernet/sing-box/option"

type Adapter struct {
endpointType string
endpointTag string
network []string
dependencies []string
}

func NewAdapter(endpointType string, endpointTag string, network []string, dependencies []string) Adapter {
return Adapter{
endpointType: endpointType,
endpointTag: endpointTag,
network: network,
dependencies: dependencies,
}
}

func NewAdapterWithDialerOptions(endpointType string, endpointTag string, network []string, dialOptions option.DialerOptions) Adapter {
var dependencies []string
if dialOptions.Detour != "" {
dependencies = []string{dialOptions.Detour}
}
return NewAdapter(endpointType, endpointTag, network, dependencies)
}

func (a *Adapter) Type() string {
return a.endpointType
}

func (a *Adapter) Tag() string {
return a.endpointTag
}

func (a *Adapter) Network() []string {
return a.network
}

func (a *Adapter) Dependencies() []string {
return a.dependencies
}
147 changes: 147 additions & 0 deletions adapter/endpoint/manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package endpoint

import (
"context"
"os"
"sync"

"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)

var _ adapter.EndpointManager = (*Manager)(nil)

type Manager struct {
logger log.ContextLogger
registry adapter.EndpointRegistry
access sync.Mutex
started bool
stage adapter.StartStage
endpoints []adapter.Endpoint
endpointByTag map[string]adapter.Endpoint
}

func NewManager(logger log.ContextLogger, registry adapter.EndpointRegistry) *Manager {
return &Manager{
logger: logger,
registry: registry,
endpointByTag: make(map[string]adapter.Endpoint),
}
}

func (m *Manager) Start(stage adapter.StartStage) error {
m.access.Lock()
defer m.access.Unlock()
if m.started && m.stage >= stage {
panic("already started")
}
m.started = true
m.stage = stage
if stage == adapter.StartStateStart {
// started with outbound manager
return nil
}
for _, endpoint := range m.endpoints {
err := adapter.LegacyStart(endpoint, stage)
if err != nil {
return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
}
}
return nil
}

func (m *Manager) Close() error {
m.access.Lock()
defer m.access.Unlock()
if !m.started {
return nil
}
m.started = false
endpoints := m.endpoints
m.endpoints = nil
monitor := taskmonitor.New(m.logger, C.StopTimeout)
var err error
for _, endpoint := range endpoints {
monitor.Start("close endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
err = E.Append(err, endpoint.Close(), func(err error) error {
return E.Cause(err, "close endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
})
monitor.Finish()
}
return nil
}

func (m *Manager) Endpoints() []adapter.Endpoint {
m.access.Lock()
defer m.access.Unlock()
return m.endpoints
}

func (m *Manager) Get(tag string) (adapter.Endpoint, bool) {
m.access.Lock()
defer m.access.Unlock()
endpoint, found := m.endpointByTag[tag]
return endpoint, found
}

func (m *Manager) Remove(tag string) error {
m.access.Lock()
endpoint, found := m.endpointByTag[tag]
if !found {
m.access.Unlock()
return os.ErrInvalid
}
delete(m.endpointByTag, tag)
index := common.Index(m.endpoints, func(it adapter.Endpoint) bool {
return it == endpoint
})
if index == -1 {
panic("invalid endpoint index")
}
m.endpoints = append(m.endpoints[:index], m.endpoints[index+1:]...)
started := m.started
m.access.Unlock()
if started {
return endpoint.Close()
}
return nil
}

func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) error {
endpoint, err := m.registry.Create(ctx, router, logger, tag, outboundType, options)
if err != nil {
return err
}
m.access.Lock()
defer m.access.Unlock()
if m.started {
for _, stage := range adapter.ListStartStages {
err = adapter.LegacyStart(endpoint, stage)
if err != nil {
return E.Cause(err, stage, " endpoint/", endpoint.Type(), "[", endpoint.Tag(), "]")
}
}
}
if existsEndpoint, loaded := m.endpointByTag[tag]; loaded {
if m.started {
err = existsEndpoint.Close()
if err != nil {
return E.Cause(err, "close endpoint/", existsEndpoint.Type(), "[", existsEndpoint.Tag(), "]")
}
}
existsIndex := common.Index(m.endpoints, func(it adapter.Endpoint) bool {
return it == existsEndpoint
})
if existsIndex == -1 {
panic("invalid endpoint index")
}
m.endpoints = append(m.endpoints[:existsIndex], m.endpoints[existsIndex+1:]...)
}
m.endpoints = append(m.endpoints, endpoint)
m.endpointByTag[tag] = endpoint
return nil
}
72 changes: 72 additions & 0 deletions adapter/endpoint/registry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package endpoint

import (
"context"
"sync"

"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)

type ConstructorFunc[T any] func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options T) (adapter.Endpoint, error)

func Register[Options any](registry *Registry, outboundType string, constructor ConstructorFunc[Options]) {
registry.register(outboundType, func() any {
return new(Options)
}, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, rawOptions any) (adapter.Endpoint, error) {
var options *Options
if rawOptions != nil {
options = rawOptions.(*Options)
}
return constructor(ctx, router, logger, tag, common.PtrValueOrDefault(options))
})
}

var _ adapter.EndpointRegistry = (*Registry)(nil)

type (
optionsConstructorFunc func() any
constructorFunc func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options any) (adapter.Endpoint, error)
)

type Registry struct {
access sync.Mutex
optionsType map[string]optionsConstructorFunc
constructor map[string]constructorFunc
}

func NewRegistry() *Registry {
return &Registry{
optionsType: make(map[string]optionsConstructorFunc),
constructor: make(map[string]constructorFunc),
}
}

func (m *Registry) CreateOptions(outboundType string) (any, bool) {
m.access.Lock()
defer m.access.Unlock()
optionsConstructor, loaded := m.optionsType[outboundType]
if !loaded {
return nil, false
}
return optionsConstructor(), true
}

func (m *Registry) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) (adapter.Endpoint, error) {
m.access.Lock()
defer m.access.Unlock()
constructor, loaded := m.constructor[outboundType]
if !loaded {
return nil, E.New("outbound type not found: " + outboundType)
}
return constructor(ctx, router, logger, tag, options)
}

func (m *Registry) register(outboundType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) {
m.access.Lock()
defer m.access.Unlock()
m.optionsType[outboundType] = optionsConstructor
m.constructor[outboundType] = constructor
}
2 changes: 1 addition & 1 deletion adapter/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

type Inbound interface {
Service
Lifecycle
Type() string
Tag() string
}
Expand Down
11 changes: 8 additions & 3 deletions adapter/inbound/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,19 @@ var _ adapter.InboundManager = (*Manager)(nil)
type Manager struct {
logger log.ContextLogger
registry adapter.InboundRegistry
endpoint adapter.EndpointManager
access sync.Mutex
started bool
stage adapter.StartStage
inbounds []adapter.Inbound
inboundByTag map[string]adapter.Inbound
}

func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry) *Manager {
func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry, endpoint adapter.EndpointManager) *Manager {
return &Manager{
logger: logger,
registry: registry,
endpoint: endpoint,
inboundByTag: make(map[string]adapter.Inbound),
}
}
Expand Down Expand Up @@ -79,9 +81,12 @@ func (m *Manager) Inbounds() []adapter.Inbound {

func (m *Manager) Get(tag string) (adapter.Inbound, bool) {
m.access.Lock()
defer m.access.Unlock()
inbound, found := m.inboundByTag[tag]
return inbound, found
m.access.Unlock()
if found {
return inbound, true
}
return m.endpoint.Get(tag)
}

func (m *Manager) Remove(tag string) error {
Expand Down
3 changes: 3 additions & 0 deletions adapter/lifecycle_legacy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package adapter

func LegacyStart(starter any, stage StartStage) error {
if lifecycle, isLifecycle := starter.(Lifecycle); isLifecycle {
return lifecycle.Start(stage)
}
switch stage {
case StartStateInitialize:
if preStarter, isPreStarter := starter.(interface {
Expand Down
18 changes: 9 additions & 9 deletions adapter/outbound/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,35 @@ import (
)

type Adapter struct {
protocol string
outboundType string
outboundTag string
network []string
tag string
dependencies []string
}

func NewAdapter(protocol string, network []string, tag string, dependencies []string) Adapter {
func NewAdapter(outboundType string, outboundTag string, network []string, dependencies []string) Adapter {
return Adapter{
protocol: protocol,
outboundType: outboundType,
outboundTag: outboundTag,
network: network,
tag: tag,
dependencies: dependencies,
}
}

func NewAdapterWithDialerOptions(protocol string, network []string, tag string, dialOptions option.DialerOptions) Adapter {
func NewAdapterWithDialerOptions(outboundType string, outboundTag string, network []string, dialOptions option.DialerOptions) Adapter {
var dependencies []string
if dialOptions.Detour != "" {
dependencies = []string{dialOptions.Detour}
}
return NewAdapter(protocol, network, tag, dependencies)
return NewAdapter(outboundType, outboundTag, network, dependencies)
}

func (a *Adapter) Type() string {
return a.protocol
return a.outboundType
}

func (a *Adapter) Tag() string {
return a.tag
return a.outboundTag
}

func (a *Adapter) Network() []string {
Expand Down
Loading

0 comments on commit 9443c94

Please sign in to comment.