Skip to content

Commit

Permalink
Add support for more finegrained restriction on who is allowed to ope…
Browse files Browse the repository at this point in the history
…n tunnels
  • Loading branch information
hons82 committed Oct 21, 2021
1 parent db96428 commit 94c143d
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 43 deletions.
5 changes: 2 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ check: .check-fmt .check-vet .check-lint .check-ineffassign .check-static .check

.PHONY: .check-vendor
.check-vendor:
@dep ensure -no-vendor -dry-run
@go mod vendor

.PHONY: test
test:
Expand All @@ -61,12 +61,11 @@ test:
.PHONY: get-deps
get-deps:
@echo "==> Installing dependencies..."
@dep ensure
@go mod init

.PHONY: get-tools
get-tools:
@echo "==> Installing tools..."
@go get -u github.com/golang/dep/cmd/dep
@go get -u golang.org/x/lint/golint
@go get -u github.com/golang/mock/gomock

Expand Down
20 changes: 9 additions & 11 deletions cmd/tunneld/tunneld.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@ import (
"net"
"net/http"
"os"
"strings"
"time"

"golang.org/x/net/http2"

"github.com/bep/debounce"
tunnel "github.com/hons82/go-http-tunnel"
"github.com/hons82/go-http-tunnel/connection"
"github.com/hons82/go-http-tunnel/id"
"github.com/hons82/go-http-tunnel/fileutil"
"github.com/hons82/go-http-tunnel/log"
)

Expand Down Expand Up @@ -71,16 +70,15 @@ func main() {
}

if !autoSubscribe {
for _, c := range strings.Split(opts.clients, ",") {
if c == "" {
fatal("empty client id")
}
identifier := id.ID{}
err := identifier.UnmarshalText([]byte(c))
if err != nil {
fatal("invalid identifier %q: %s", c, err)
clients, err := fileutil.ReadPropertiesFile(opts.clients)
if err != nil {
fatal("failed to load clients: %s", err)
}

for host, value := range clients {
if err := server.RegisterTunnel(host, value); err != nil {
fatal("failed to load tunnel: %s with error %s", host, err)
}
server.Subscribe(identifier)
}
}

Expand Down
45 changes: 45 additions & 0 deletions fileutil/file.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package fileutil

import (
"bufio"
"log"
"os"
"strings"
)

type AppConfigProperties map[string]string

func ReadPropertiesFile(filename string) (AppConfigProperties, error) {
config := AppConfigProperties{}

if len(filename) == 0 {
return config, nil
}
file, err := os.Open(filename)
if err != nil {
log.Fatal(err)
return nil, err
}
defer file.Close()

scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if equal := strings.Index(line, "="); equal >= 0 {
if key := strings.TrimSpace(line[:equal]); len(key) > 0 {
value := ""
if len(line) > equal {
value = strings.TrimSpace(line[equal+1:])
}
config[key] = value
}
}
}

if err := scanner.Err(); err != nil {
log.Fatal(err)
return nil, err
}

return config, nil
}
19 changes: 13 additions & 6 deletions id/ptls.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,38 @@ import (
"fmt"
)

type IDInfo struct {
Client string
}

var emptyID [32]byte
var emptyIDInfo = &IDInfo{}

// PeerID is modified https://github.com/andrew-d/ptls/blob/b89c7dcc94630a77f225a48befd3710144c7c10e/ptls.go#L81
func PeerID(conn *tls.Conn) (ID, error) {
func PeerID(conn *tls.Conn) (ID, *IDInfo, error) {
// Try a TLS connection over the given connection. We explicitly perform
// the handshake, since we want to maintain the invariant that, if this
// function returns successfully, then the connection should be valid
// and verified.
if err := conn.Handshake(); err != nil {
return emptyID, err
return emptyID, emptyIDInfo, err
}

cs := conn.ConnectionState()

// We should have exactly one peer certificate.
certs := cs.PeerCertificates
if cl := len(certs); cl != 1 {
return emptyID, ImproperCertsNumberError{cl}
return emptyID, emptyIDInfo, ImproperCertsNumberError{cl}
}

// Get remote cert's ID.
remoteCert := certs[0]
remoteID := New(remoteCert.Raw)

return remoteID, nil
remoteID := New([]byte(remoteCert.Issuer.SerialNumber))
remoteIDInfo := &IDInfo{
Client: remoteCert.Issuer.SerialNumber,
}
return remoteID, remoteIDInfo, nil
}

// ImproperCertsNumberError is returned from Server/Client whenever the remote
Expand Down
48 changes: 41 additions & 7 deletions registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
// RegistryItem holds information about hosts and listeners associated with a
// client.
type RegistryItem struct {
*id.IDInfo
Hosts []*HostAuth
Listeners []net.Listener
}
Expand All @@ -27,6 +28,7 @@ type HostAuth struct {
}

type hostInfo struct {
*id.IDInfo
identifier id.ID
auth *Auth
}
Expand Down Expand Up @@ -91,6 +93,15 @@ func (r *registry) Subscriber(hostPort string) (id.ID, *Auth, bool) {
return h.identifier, h.auth, ok
}

func (r *registry) HasTunnel(hostPort string, identifier id.ID) bool {
r.mu.RLock()
defer r.mu.RUnlock()

h, ok := r.hosts[trimPort(hostPort)]

return ok && h.identifier.Equals(identifier)
}

// Unsubscribe removes client from registry and returns it's RegistryItem.
func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem {
r.mu.Lock()
Expand Down Expand Up @@ -141,7 +152,7 @@ func (r *registry) set(i *RegistryItem, identifier id.ID) error {
if h.Auth != nil && h.Auth.User == "" {
return fmt.Errorf("missing auth user")
}
if _, ok := r.hosts[trimPort(h.Host)]; ok {
if hi, ok := r.hosts[trimPort(h.Host)]; ok && !hi.identifier.Equals(identifier) {
return fmt.Errorf("host %q is occupied", h.Host)
}
}
Expand All @@ -159,6 +170,35 @@ func (r *registry) set(i *RegistryItem, identifier id.ID) error {
return nil
}

func (r *registry) RegisterTunnel(host string, client string) error {
identifier := id.New([]byte(client))

r.logger.Log(
"level", 2,
"action", "add tunnel",
"host", host,
"identifier", identifier,
)

r.Subscribe(identifier)

r.mu.Lock()
defer r.mu.Unlock()

if _, ok := r.hosts[trimPort(host)]; ok {
return fmt.Errorf("host %q is occupied", host)
}

r.hosts[trimPort(host)] = &hostInfo{
identifier: identifier,
IDInfo: &id.IDInfo{
Client: client,
},
}

return nil
}

func (r *registry) clear(identifier id.ID) *RegistryItem {
r.logger.Log(
"level", 2,
Expand All @@ -174,12 +214,6 @@ func (r *registry) clear(identifier id.ID) *RegistryItem {
return nil
}

if i.Hosts != nil {
for _, h := range i.Hosts {
delete(r.hosts, trimPort(h.Host))
}
}

r.items[identifier] = voidRegistryItem

return i
Expand Down
76 changes: 60 additions & 16 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,20 +163,28 @@ func listener(config *ServerConfig) (net.Listener, error) {

// disconnected clears resources used by client, it's invoked by connection pool when client goes away.
func (s *Server) disconnected(identifier id.ID) {
s.debounce.disconnectedIDs = append(s.debounce.disconnectedIDs, identifier)
if s.debounce.Execute != nil {
s.debounce.disconnectedIDs = append(s.debounce.disconnectedIDs, identifier)

s.debounce.Execute(func() {
for _, id := range s.debounce.disconnectedIDs {
s.logger.Log(
"level", 1,
"action", "disconnected",
"identifier", id,
)
}
s.debounce.disconnectedIDs = nil
})
s.debounce.Execute(func() {
for _, id := range s.debounce.disconnectedIDs {
s.logger.Log(
"level", 1,
"action", "disconnected",
"identifier", id,
)
}
s.debounce.disconnectedIDs = nil
})
} else {
s.logger.Log(
"level", 1,
"action", "disconnected",
"identifier", identifier,
)
}

i := s.registry.clear(identifier)
i := s.unsubscribe(identifier)
if i == nil {
return
}
Expand All @@ -191,6 +199,13 @@ func (s *Server) disconnected(identifier id.ID) {
}
}

func (s *Server) unsubscribe(identifier id.ID) *RegistryItem {
if s.config.AutoSubscribe {
return s.Unsubscribe(identifier)
}
return s.registry.clear(identifier)
}

// Start starts accepting connections form clients. For accepting http traffic
// from end users server must be run as handler on http server.
func (s *Server) Start() {
Expand Down Expand Up @@ -251,6 +266,7 @@ func (s *Server) handleClient(conn net.Conn) {

var (
identifier id.ID
IDInfo *id.IDInfo
req *http.Request
resp *http.Response
tunnels map[string]*proto.Tunnel
Expand All @@ -273,7 +289,7 @@ func (s *Server) handleClient(conn net.Conn) {
goto reject
}

identifier, err = id.PeerID(tlsConn)
identifier, IDInfo, err = id.PeerID(tlsConn)
if err != nil {
logger.Log(
"level", 2,
Expand Down Expand Up @@ -379,7 +395,16 @@ func (s *Server) handleClient(conn net.Conn) {
goto reject
}

if err = s.addTunnels(tunnels, identifier); err != nil {
if err = s.hasTunnels(tunnels, identifier); err != nil {
logger.Log(
"level", 2,
"msg", "tunnel check failed",
"err", err,
)
goto reject
}

if err = s.addTunnels(tunnels, identifier, *IDInfo); err != nil {
logger.Log(
"level", 2,
"msg", "handshake failed",
Expand Down Expand Up @@ -443,10 +468,25 @@ func (s *Server) notifyError(serverError error, identifier id.ID) {
s.httpClient.Do(req.WithContext(ctx))
}

func (s *Server) hasTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID) error {
var err error
for name, t := range tunnels {
// Check the current tunnel
// AutoSubscribe --> Tunnel not yet registered (means that it isn't already opened)
// !AutoSubscribe -> Tunnel has to be already registered, and therefore allowed to be opened
if s.config.AutoSubscribe == s.HasTunnel(t.Host, identifier) {
err = fmt.Errorf("tunnel %s not allowed for %s", name, identifier)
break
}
}
return err
}

// addTunnels invokes addHost or addListener based on data from proto.Tunnel. If
// a tunnel cannot be added whole batch is reverted.
func (s *Server) addTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID) error {
func (s *Server) addTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID, IDInfo id.IDInfo) error {
i := &RegistryItem{
IDInfo: &IDInfo,
Hosts: []*HostAuth{},
Listeners: []net.Listener{},
}
Expand Down Expand Up @@ -847,6 +887,7 @@ type ListenerInfo struct {
// ClientInfo info about the client
type ClientInfo struct {
ID string
IDInfo id.IDInfo
Listeners []*ListenerInfo
Hosts []string
}
Expand All @@ -857,7 +898,10 @@ func (s *Server) GetClientInfo() []*ClientInfo {
defer s.registry.mu.Unlock()
ret := []*ClientInfo{}
for k, v := range s.registry.items {
c := &ClientInfo{ID: k.String()}
c := &ClientInfo{
ID: k.String(),
IDInfo: *v.IDInfo,
}
ret = append(ret, c)
if v == voidRegistryItem {
s.logger.Log(
Expand Down

0 comments on commit 94c143d

Please sign in to comment.