Skip to content

Commit

Permalink
Add support for more finegrained restriction on who is allowed to ope…
Browse files Browse the repository at this point in the history
…n tunnels
  • Loading branch information
hons82 committed Oct 27, 2021
1 parent db96428 commit 0a071b5
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 62 deletions.
5 changes: 2 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ check: .check-fmt .check-vet .check-lint .check-ineffassign .check-static .check

.PHONY: .check-vendor
.check-vendor:
@dep ensure -no-vendor -dry-run
@go mod vendor

.PHONY: test
test:
Expand All @@ -61,12 +61,11 @@ test:
.PHONY: get-deps
get-deps:
@echo "==> Installing dependencies..."
@dep ensure
@go mod init

.PHONY: get-tools
get-tools:
@echo "==> Installing tools..."
@go get -u github.com/golang/dep/cmd/dep
@go get -u golang.org/x/lint/golint
@go get -u github.com/golang/mock/gomock

Expand Down
2 changes: 1 addition & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (c *Client) dial() (net.Conn, error) {
if err == nil {
c.logger.Log(
"level", 1,
"msg", fmt.Sprintf("Setting up keep alive using config: %v", c.config.KeepAlive.String()),
"msg", fmt.Sprintf("setting up keep alive using config: %v", c.config.KeepAlive.String()),
)
err = c.config.KeepAlive.Set(conn)
}
Expand Down
51 changes: 25 additions & 26 deletions cmd/tunneld/tunneld.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@ import (
"net"
"net/http"
"os"
"strings"
"time"

"golang.org/x/net/http2"

"github.com/bep/debounce"
tunnel "github.com/hons82/go-http-tunnel"
"github.com/hons82/go-http-tunnel/connection"
"github.com/hons82/go-http-tunnel/id"
"github.com/hons82/go-http-tunnel/fileutil"
"github.com/hons82/go-http-tunnel/log"
)

Expand Down Expand Up @@ -71,16 +70,15 @@ func main() {
}

if !autoSubscribe {
for _, c := range strings.Split(opts.clients, ",") {
if c == "" {
fatal("empty client id")
}
identifier := id.ID{}
err := identifier.UnmarshalText([]byte(c))
if err != nil {
fatal("invalid identifier %q: %s", c, err)
clients, err := fileutil.ReadPropertiesFile(opts.clients)
if err != nil {
fatal("failed to load clients: %s", err)
}

for host, value := range clients {
if err := server.RegisterTunnel(host, value); err != nil {
fatal("failed to load tunnel: %s with error %s", host, err)
}
server.Subscribe(identifier)
}
}

Expand All @@ -103,6 +101,9 @@ func main() {
// start HTTP
if opts.httpAddr != "" {
go func() {
s := &http.Server{
Addr: opts.httpAddr,
}
if opts.httpsAddr != "" {
logger.Log(
"level", 1,
Expand All @@ -114,29 +115,27 @@ func main() {
if err != nil {
fatal("failed to get https port: %s", err)
}
fatal("failed to start HTTP: %s",
http.ListenAndServe(opts.httpAddr, http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
}
u := r.URL
u.Host = net.JoinHostPort(host, tlsPort)
u.Scheme = "https"
http.Redirect(w, r, u.String(), http.StatusMovedPermanently)
},
)),
s.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
}
u := r.URL
u.Host = net.JoinHostPort(host, tlsPort)
u.Scheme = "https"
http.Redirect(w, r, u.String(), http.StatusMovedPermanently)
},
)
} else {
logger.Log(
"level", 1,
"action", "start http",
"addr", opts.httpAddr,
)

fatal("failed to start HTTP: %s", http.ListenAndServe(opts.httpAddr, server))
s.Handler = server
}
fatal("failed to start HTTP: %s", s.ListenAndServe())
}()
}

Expand Down
45 changes: 45 additions & 0 deletions fileutil/file.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package fileutil

import (
"bufio"
"log"
"os"
"strings"
)

type AppConfigProperties map[string]string

func ReadPropertiesFile(filename string) (AppConfigProperties, error) {
config := AppConfigProperties{}

if len(filename) == 0 {
return config, nil
}
file, err := os.Open(filename)
if err != nil {
log.Fatal(err)
return nil, err
}
defer file.Close()

scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if equal := strings.Index(line, "="); equal >= 0 {
if key := strings.TrimSpace(line[:equal]); len(key) > 0 {
value := ""
if len(line) > equal {
value = strings.TrimSpace(line[equal+1:])
}
config[key] = value
}
}
}

if err := scanner.Err(); err != nil {
log.Fatal(err)
return nil, err
}

return config, nil
}
25 changes: 20 additions & 5 deletions id/ptls.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,49 @@ package id

import (
"crypto/tls"
"crypto/x509"
"fmt"
)

type IDInfo struct {
Client string
}

var emptyID [32]byte
var emptyIDInfo = &IDInfo{}

// PeerID is modified https://github.com/andrew-d/ptls/blob/b89c7dcc94630a77f225a48befd3710144c7c10e/ptls.go#L81
func PeerID(conn *tls.Conn) (ID, error) {
func PeerID(conn *tls.Conn) (ID, IDInfo, error) {
// Try a TLS connection over the given connection. We explicitly perform
// the handshake, since we want to maintain the invariant that, if this
// function returns successfully, then the connection should be valid
// and verified.
if err := conn.Handshake(); err != nil {
return emptyID, err
return emptyID, *emptyIDInfo, err
}

cs := conn.ConnectionState()

// We should have exactly one peer certificate.
certs := cs.PeerCertificates
if cl := len(certs); cl != 1 {
return emptyID, ImproperCertsNumberError{cl}
return emptyID, *emptyIDInfo, ImproperCertsNumberError{cl}
}

// Get remote cert's ID.
remoteCert := certs[0]
remoteID := New(remoteCert.Raw)
remoteID := New(remoteID(*remoteCert))
remoteIDInfo := &IDInfo{
Client: remoteCert.Issuer.SerialNumber,
}
return remoteID, *remoteIDInfo, nil
}

return remoteID, nil
func remoteID(c x509.Certificate) []byte {
if c.Issuer.SerialNumber != "" {
return []byte(c.Issuer.SerialNumber)
}
return c.Raw
}

// ImproperCertsNumberError is returned from Server/Client whenever the remote
Expand Down
48 changes: 41 additions & 7 deletions registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
// RegistryItem holds information about hosts and listeners associated with a
// client.
type RegistryItem struct {
*id.IDInfo
Hosts []*HostAuth
Listeners []net.Listener
}
Expand All @@ -27,6 +28,7 @@ type HostAuth struct {
}

type hostInfo struct {
*id.IDInfo
identifier id.ID
auth *Auth
}
Expand Down Expand Up @@ -91,6 +93,15 @@ func (r *registry) Subscriber(hostPort string) (id.ID, *Auth, bool) {
return h.identifier, h.auth, ok
}

func (r *registry) HasTunnel(hostPort string, identifier id.ID) bool {
r.mu.RLock()
defer r.mu.RUnlock()

h, ok := r.hosts[trimPort(hostPort)]

return ok && h.identifier.Equals(identifier)
}

// Unsubscribe removes client from registry and returns it's RegistryItem.
func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem {
r.mu.Lock()
Expand Down Expand Up @@ -141,7 +152,7 @@ func (r *registry) set(i *RegistryItem, identifier id.ID) error {
if h.Auth != nil && h.Auth.User == "" {
return fmt.Errorf("missing auth user")
}
if _, ok := r.hosts[trimPort(h.Host)]; ok {
if hi, ok := r.hosts[trimPort(h.Host)]; ok && !hi.identifier.Equals(identifier) {
return fmt.Errorf("host %q is occupied", h.Host)
}
}
Expand All @@ -159,6 +170,35 @@ func (r *registry) set(i *RegistryItem, identifier id.ID) error {
return nil
}

func (r *registry) RegisterTunnel(host string, client string) error {
identifier := id.New([]byte(client))

r.logger.Log(
"level", 2,
"action", "add tunnel",
"host", host,
"identifier", identifier,
)

r.Subscribe(identifier)

r.mu.Lock()
defer r.mu.Unlock()

if _, ok := r.hosts[trimPort(host)]; ok {
return fmt.Errorf("host %q is occupied", host)
}

r.hosts[trimPort(host)] = &hostInfo{
identifier: identifier,
IDInfo: &id.IDInfo{
Client: client,
},
}

return nil
}

func (r *registry) clear(identifier id.ID) *RegistryItem {
r.logger.Log(
"level", 2,
Expand All @@ -174,12 +214,6 @@ func (r *registry) clear(identifier id.ID) *RegistryItem {
return nil
}

if i.Hosts != nil {
for _, h := range i.Hosts {
delete(r.hosts, trimPort(h.Host))
}
}

r.items[identifier] = voidRegistryItem

return i
Expand Down
Loading

0 comments on commit 0a071b5

Please sign in to comment.