Skip to content

Commit 94c143d

Browse files
committed
Add support for more finegrained restriction on who is allowed to open tunnels
1 parent db96428 commit 94c143d

File tree

6 files changed

+170
-43
lines changed

6 files changed

+170
-43
lines changed

Makefile

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ check: .check-fmt .check-vet .check-lint .check-ineffassign .check-static .check
5151

5252
.PHONY: .check-vendor
5353
.check-vendor:
54-
@dep ensure -no-vendor -dry-run
54+
@go mod vendor
5555

5656
.PHONY: test
5757
test:
@@ -61,12 +61,11 @@ test:
6161
.PHONY: get-deps
6262
get-deps:
6363
@echo "==> Installing dependencies..."
64-
@dep ensure
64+
@go mod init
6565

6666
.PHONY: get-tools
6767
get-tools:
6868
@echo "==> Installing tools..."
69-
@go get -u github.com/golang/dep/cmd/dep
7069
@go get -u golang.org/x/lint/golint
7170
@go get -u github.com/golang/mock/gomock
7271

cmd/tunneld/tunneld.go

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,14 @@ import (
1212
"net"
1313
"net/http"
1414
"os"
15-
"strings"
1615
"time"
1716

1817
"golang.org/x/net/http2"
1918

2019
"github.com/bep/debounce"
2120
tunnel "github.com/hons82/go-http-tunnel"
2221
"github.com/hons82/go-http-tunnel/connection"
23-
"github.com/hons82/go-http-tunnel/id"
22+
"github.com/hons82/go-http-tunnel/fileutil"
2423
"github.com/hons82/go-http-tunnel/log"
2524
)
2625

@@ -71,16 +70,15 @@ func main() {
7170
}
7271

7372
if !autoSubscribe {
74-
for _, c := range strings.Split(opts.clients, ",") {
75-
if c == "" {
76-
fatal("empty client id")
77-
}
78-
identifier := id.ID{}
79-
err := identifier.UnmarshalText([]byte(c))
80-
if err != nil {
81-
fatal("invalid identifier %q: %s", c, err)
73+
clients, err := fileutil.ReadPropertiesFile(opts.clients)
74+
if err != nil {
75+
fatal("failed to load clients: %s", err)
76+
}
77+
78+
for host, value := range clients {
79+
if err := server.RegisterTunnel(host, value); err != nil {
80+
fatal("failed to load tunnel: %s with error %s", host, err)
8281
}
83-
server.Subscribe(identifier)
8482
}
8583
}
8684

fileutil/file.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package fileutil
2+
3+
import (
4+
"bufio"
5+
"log"
6+
"os"
7+
"strings"
8+
)
9+
10+
type AppConfigProperties map[string]string
11+
12+
func ReadPropertiesFile(filename string) (AppConfigProperties, error) {
13+
config := AppConfigProperties{}
14+
15+
if len(filename) == 0 {
16+
return config, nil
17+
}
18+
file, err := os.Open(filename)
19+
if err != nil {
20+
log.Fatal(err)
21+
return nil, err
22+
}
23+
defer file.Close()
24+
25+
scanner := bufio.NewScanner(file)
26+
for scanner.Scan() {
27+
line := scanner.Text()
28+
if equal := strings.Index(line, "="); equal >= 0 {
29+
if key := strings.TrimSpace(line[:equal]); len(key) > 0 {
30+
value := ""
31+
if len(line) > equal {
32+
value = strings.TrimSpace(line[equal+1:])
33+
}
34+
config[key] = value
35+
}
36+
}
37+
}
38+
39+
if err := scanner.Err(); err != nil {
40+
log.Fatal(err)
41+
return nil, err
42+
}
43+
44+
return config, nil
45+
}

id/ptls.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,38 @@ import (
99
"fmt"
1010
)
1111

12+
type IDInfo struct {
13+
Client string
14+
}
15+
1216
var emptyID [32]byte
17+
var emptyIDInfo = &IDInfo{}
1318

1419
// PeerID is modified https://github.com/andrew-d/ptls/blob/b89c7dcc94630a77f225a48befd3710144c7c10e/ptls.go#L81
15-
func PeerID(conn *tls.Conn) (ID, error) {
20+
func PeerID(conn *tls.Conn) (ID, *IDInfo, error) {
1621
// Try a TLS connection over the given connection. We explicitly perform
1722
// the handshake, since we want to maintain the invariant that, if this
1823
// function returns successfully, then the connection should be valid
1924
// and verified.
2025
if err := conn.Handshake(); err != nil {
21-
return emptyID, err
26+
return emptyID, emptyIDInfo, err
2227
}
2328

2429
cs := conn.ConnectionState()
2530

2631
// We should have exactly one peer certificate.
2732
certs := cs.PeerCertificates
2833
if cl := len(certs); cl != 1 {
29-
return emptyID, ImproperCertsNumberError{cl}
34+
return emptyID, emptyIDInfo, ImproperCertsNumberError{cl}
3035
}
3136

3237
// Get remote cert's ID.
3338
remoteCert := certs[0]
34-
remoteID := New(remoteCert.Raw)
35-
36-
return remoteID, nil
39+
remoteID := New([]byte(remoteCert.Issuer.SerialNumber))
40+
remoteIDInfo := &IDInfo{
41+
Client: remoteCert.Issuer.SerialNumber,
42+
}
43+
return remoteID, remoteIDInfo, nil
3744
}
3845

3946
// ImproperCertsNumberError is returned from Server/Client whenever the remote

registry.go

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
// RegistryItem holds information about hosts and listeners associated with a
1717
// client.
1818
type RegistryItem struct {
19+
*id.IDInfo
1920
Hosts []*HostAuth
2021
Listeners []net.Listener
2122
}
@@ -27,6 +28,7 @@ type HostAuth struct {
2728
}
2829

2930
type hostInfo struct {
31+
*id.IDInfo
3032
identifier id.ID
3133
auth *Auth
3234
}
@@ -91,6 +93,15 @@ func (r *registry) Subscriber(hostPort string) (id.ID, *Auth, bool) {
9193
return h.identifier, h.auth, ok
9294
}
9395

96+
func (r *registry) HasTunnel(hostPort string, identifier id.ID) bool {
97+
r.mu.RLock()
98+
defer r.mu.RUnlock()
99+
100+
h, ok := r.hosts[trimPort(hostPort)]
101+
102+
return ok && h.identifier.Equals(identifier)
103+
}
104+
94105
// Unsubscribe removes client from registry and returns it's RegistryItem.
95106
func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem {
96107
r.mu.Lock()
@@ -141,7 +152,7 @@ func (r *registry) set(i *RegistryItem, identifier id.ID) error {
141152
if h.Auth != nil && h.Auth.User == "" {
142153
return fmt.Errorf("missing auth user")
143154
}
144-
if _, ok := r.hosts[trimPort(h.Host)]; ok {
155+
if hi, ok := r.hosts[trimPort(h.Host)]; ok && !hi.identifier.Equals(identifier) {
145156
return fmt.Errorf("host %q is occupied", h.Host)
146157
}
147158
}
@@ -159,6 +170,35 @@ func (r *registry) set(i *RegistryItem, identifier id.ID) error {
159170
return nil
160171
}
161172

173+
func (r *registry) RegisterTunnel(host string, client string) error {
174+
identifier := id.New([]byte(client))
175+
176+
r.logger.Log(
177+
"level", 2,
178+
"action", "add tunnel",
179+
"host", host,
180+
"identifier", identifier,
181+
)
182+
183+
r.Subscribe(identifier)
184+
185+
r.mu.Lock()
186+
defer r.mu.Unlock()
187+
188+
if _, ok := r.hosts[trimPort(host)]; ok {
189+
return fmt.Errorf("host %q is occupied", host)
190+
}
191+
192+
r.hosts[trimPort(host)] = &hostInfo{
193+
identifier: identifier,
194+
IDInfo: &id.IDInfo{
195+
Client: client,
196+
},
197+
}
198+
199+
return nil
200+
}
201+
162202
func (r *registry) clear(identifier id.ID) *RegistryItem {
163203
r.logger.Log(
164204
"level", 2,
@@ -174,12 +214,6 @@ func (r *registry) clear(identifier id.ID) *RegistryItem {
174214
return nil
175215
}
176216

177-
if i.Hosts != nil {
178-
for _, h := range i.Hosts {
179-
delete(r.hosts, trimPort(h.Host))
180-
}
181-
}
182-
183217
r.items[identifier] = voidRegistryItem
184218

185219
return i

server.go

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -163,20 +163,28 @@ func listener(config *ServerConfig) (net.Listener, error) {
163163

164164
// disconnected clears resources used by client, it's invoked by connection pool when client goes away.
165165
func (s *Server) disconnected(identifier id.ID) {
166-
s.debounce.disconnectedIDs = append(s.debounce.disconnectedIDs, identifier)
166+
if s.debounce.Execute != nil {
167+
s.debounce.disconnectedIDs = append(s.debounce.disconnectedIDs, identifier)
167168

168-
s.debounce.Execute(func() {
169-
for _, id := range s.debounce.disconnectedIDs {
170-
s.logger.Log(
171-
"level", 1,
172-
"action", "disconnected",
173-
"identifier", id,
174-
)
175-
}
176-
s.debounce.disconnectedIDs = nil
177-
})
169+
s.debounce.Execute(func() {
170+
for _, id := range s.debounce.disconnectedIDs {
171+
s.logger.Log(
172+
"level", 1,
173+
"action", "disconnected",
174+
"identifier", id,
175+
)
176+
}
177+
s.debounce.disconnectedIDs = nil
178+
})
179+
} else {
180+
s.logger.Log(
181+
"level", 1,
182+
"action", "disconnected",
183+
"identifier", identifier,
184+
)
185+
}
178186

179-
i := s.registry.clear(identifier)
187+
i := s.unsubscribe(identifier)
180188
if i == nil {
181189
return
182190
}
@@ -191,6 +199,13 @@ func (s *Server) disconnected(identifier id.ID) {
191199
}
192200
}
193201

202+
func (s *Server) unsubscribe(identifier id.ID) *RegistryItem {
203+
if s.config.AutoSubscribe {
204+
return s.Unsubscribe(identifier)
205+
}
206+
return s.registry.clear(identifier)
207+
}
208+
194209
// Start starts accepting connections form clients. For accepting http traffic
195210
// from end users server must be run as handler on http server.
196211
func (s *Server) Start() {
@@ -251,6 +266,7 @@ func (s *Server) handleClient(conn net.Conn) {
251266

252267
var (
253268
identifier id.ID
269+
IDInfo *id.IDInfo
254270
req *http.Request
255271
resp *http.Response
256272
tunnels map[string]*proto.Tunnel
@@ -273,7 +289,7 @@ func (s *Server) handleClient(conn net.Conn) {
273289
goto reject
274290
}
275291

276-
identifier, err = id.PeerID(tlsConn)
292+
identifier, IDInfo, err = id.PeerID(tlsConn)
277293
if err != nil {
278294
logger.Log(
279295
"level", 2,
@@ -379,7 +395,16 @@ func (s *Server) handleClient(conn net.Conn) {
379395
goto reject
380396
}
381397

382-
if err = s.addTunnels(tunnels, identifier); err != nil {
398+
if err = s.hasTunnels(tunnels, identifier); err != nil {
399+
logger.Log(
400+
"level", 2,
401+
"msg", "tunnel check failed",
402+
"err", err,
403+
)
404+
goto reject
405+
}
406+
407+
if err = s.addTunnels(tunnels, identifier, *IDInfo); err != nil {
383408
logger.Log(
384409
"level", 2,
385410
"msg", "handshake failed",
@@ -443,10 +468,25 @@ func (s *Server) notifyError(serverError error, identifier id.ID) {
443468
s.httpClient.Do(req.WithContext(ctx))
444469
}
445470

471+
func (s *Server) hasTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID) error {
472+
var err error
473+
for name, t := range tunnels {
474+
// Check the current tunnel
475+
// AutoSubscribe --> Tunnel not yet registered (means that it isn't already opened)
476+
// !AutoSubscribe -> Tunnel has to be already registered, and therefore allowed to be opened
477+
if s.config.AutoSubscribe == s.HasTunnel(t.Host, identifier) {
478+
err = fmt.Errorf("tunnel %s not allowed for %s", name, identifier)
479+
break
480+
}
481+
}
482+
return err
483+
}
484+
446485
// addTunnels invokes addHost or addListener based on data from proto.Tunnel. If
447486
// a tunnel cannot be added whole batch is reverted.
448-
func (s *Server) addTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID) error {
487+
func (s *Server) addTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID, IDInfo id.IDInfo) error {
449488
i := &RegistryItem{
489+
IDInfo: &IDInfo,
450490
Hosts: []*HostAuth{},
451491
Listeners: []net.Listener{},
452492
}
@@ -847,6 +887,7 @@ type ListenerInfo struct {
847887
// ClientInfo info about the client
848888
type ClientInfo struct {
849889
ID string
890+
IDInfo id.IDInfo
850891
Listeners []*ListenerInfo
851892
Hosts []string
852893
}
@@ -857,7 +898,10 @@ func (s *Server) GetClientInfo() []*ClientInfo {
857898
defer s.registry.mu.Unlock()
858899
ret := []*ClientInfo{}
859900
for k, v := range s.registry.items {
860-
c := &ClientInfo{ID: k.String()}
901+
c := &ClientInfo{
902+
ID: k.String(),
903+
IDInfo: *v.IDInfo,
904+
}
861905
ret = append(ret, c)
862906
if v == voidRegistryItem {
863907
s.logger.Log(

0 commit comments

Comments
 (0)