@@ -7,6 +7,7 @@ package tunnel
7
7
import (
8
8
"context"
9
9
"crypto/tls"
10
+ "crypto/x509"
10
11
"encoding/json"
11
12
"errors"
12
13
"fmt"
@@ -25,6 +26,20 @@ import (
25
26
"github.com/inconshreveable/go-vhost"
26
27
)
27
28
29
+ // A set of listeners to manage subscribers
30
+ type SubscriptionListener interface {
31
+ // Invoked if AutoSubscribe is false and must return true if the client is allowed to subscribe or not.
32
+ // If the tlsConfig is configured to require client certificate validation, chain will contain the first
33
+ // verified chain, else the presented peer certificate.
34
+ CanSubscribe (id id.ID , chain []* x509.Certificate ) bool
35
+ // Invoked when the client has been subscribed.
36
+ // If the tlsConfig is configured to require client certificate validation, chain will contain the first
37
+ // verified chain, else the presented peer certificate.
38
+ Subscribed (id id.ID , tlsConn * tls.Conn , chain []* x509.Certificate )
39
+ // Invoked before the client is unsubscribed.
40
+ Unsubscribed (id id.ID )
41
+ }
42
+
28
43
// ServerConfig defines configuration for the Server.
29
44
type ServerConfig struct {
30
45
// Addr is TCP address to listen for client connections. If empty ":0" is used.
@@ -43,6 +58,8 @@ type ServerConfig struct {
43
58
KeepAlive connection.KeepAliveConfig
44
59
// How long should a disconnected message been hold before sending it to the log
45
60
Debounce Debounced
61
+ // Optional listener to manage subscribers
62
+ SubscriptionListener SubscriptionListener
46
63
}
47
64
48
65
// Server is responsible for proxying public connections to the client over a
@@ -274,6 +291,7 @@ func (s *Server) handleClient(conn net.Conn) {
274
291
ok bool
275
292
276
293
inConnPool bool
294
+ certs []* x509.Certificate
277
295
278
296
remainingIDs []id.ID
279
297
found bool
@@ -301,14 +319,26 @@ func (s *Server) handleClient(conn net.Conn) {
301
319
302
320
logger = logger .With ("identifier" , identifier )
303
321
322
+ certs = tlsConn .ConnectionState ().PeerCertificates
323
+ if tlsConn .ConnectionState ().VerifiedChains != nil && len (tlsConn .ConnectionState ().VerifiedChains ) > 0 {
324
+ certs = tlsConn .ConnectionState ().VerifiedChains [0 ]
325
+ }
304
326
if s .config .AutoSubscribe {
305
327
s .Subscribe (identifier )
328
+ if s .config .SubscriptionListener != nil {
329
+ s .config .SubscriptionListener .Subscribed (identifier , tlsConn , certs )
330
+ }
306
331
} else if ! s .IsSubscribed (identifier ) {
307
- logger .Log (
308
- "level" , 2 ,
309
- "msg" , "unknown client" ,
310
- )
311
- goto reject
332
+ if s .config .SubscriptionListener != nil && s .config .SubscriptionListener .CanSubscribe (identifier , certs ) {
333
+ s .Subscribe (identifier )
334
+ s .config .SubscriptionListener .Subscribed (identifier , tlsConn , certs )
335
+ } else {
336
+ logger .Log (
337
+ "level" , 2 ,
338
+ "msg" , "unknown client" ,
339
+ )
340
+ goto reject
341
+ }
312
342
}
313
343
314
344
if err = conn .SetDeadline (time.Time {}); err != nil {
@@ -555,9 +585,12 @@ rollback:
555
585
return err
556
586
}
557
587
558
- // Disconnect removes client from registry, disconnects client if already
588
+ // Unsubscribe removes client from registry, disconnects client if already
559
589
// connected and returns it's RegistryItem.
560
- func (s * Server ) Disconnect (identifier id.ID ) * RegistryItem {
590
+ func (s * Server ) Unsubscribe (identifier id.ID ) * RegistryItem {
591
+ if s .config .SubscriptionListener != nil {
592
+ s .config .SubscriptionListener .Unsubscribed (identifier )
593
+ }
561
594
s .connPool .DeleteConn (identifier )
562
595
return s .registry .Unsubscribe (identifier )
563
596
}
@@ -639,6 +672,50 @@ func (s *Server) listen(l net.Listener, identifier id.ID) {
639
672
}
640
673
}
641
674
675
+ func (s * Server ) Upgrade (identifier id.ID , conn net.Conn , requestBytes []byte ) error {
676
+
677
+ var err error
678
+
679
+ msg := & proto.ControlMessage {
680
+ Action : proto .ActionProxy ,
681
+ ForwardedProto : "https" ,
682
+ }
683
+
684
+ tlsConn , ok := conn .(* tls.Conn )
685
+ if ok {
686
+ msg .ForwardedHost = tlsConn .ConnectionState ().ServerName
687
+ err = s .config .KeepAlive .Set (tlsConn .NetConn ())
688
+
689
+ } else {
690
+ msg .ForwardedHost = conn .RemoteAddr ().String ()
691
+ err = s .config .KeepAlive .Set (conn )
692
+ }
693
+
694
+ if err != nil {
695
+ s .logger .Log (
696
+ "level" , 1 ,
697
+ "msg" , "TCP keepalive for tunneled connection failed" ,
698
+ "identifier" , identifier ,
699
+ "ctrlMsg" , msg ,
700
+ "err" , err ,
701
+ )
702
+ }
703
+
704
+ go func () {
705
+ if err := s .proxyConnUpgraded (identifier , conn , msg , requestBytes ); err != nil {
706
+ s .logger .Log (
707
+ "level" , 0 ,
708
+ "msg" , "proxy error" ,
709
+ "identifier" , identifier ,
710
+ "ctrlMsg" , msg ,
711
+ "err" , err ,
712
+ )
713
+ }
714
+ }()
715
+
716
+ return nil
717
+ }
718
+
642
719
// ServeHTTP proxies http connection to the client.
643
720
func (s * Server ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
644
721
resp , err := s .RoundTrip (r )
@@ -724,6 +801,74 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
724
801
return s .proxyHTTP (identifier , outr , msg )
725
802
}
726
803
804
+ func (s * Server ) proxyConnUpgraded (identifier id.ID , conn net.Conn , msg * proto.ControlMessage , requestBytes []byte ) error {
805
+ s .logger .Log (
806
+ "level" , 2 ,
807
+ "action" , "proxy conn" ,
808
+ "identifier" , identifier ,
809
+ "ctrlMsg" , msg ,
810
+ )
811
+
812
+ defer conn .Close ()
813
+
814
+ pr , pw := io .Pipe ()
815
+ defer pr .Close ()
816
+ defer pw .Close ()
817
+
818
+ continueChan := make (chan int )
819
+
820
+ go func () {
821
+ pw .Write (requestBytes )
822
+ continueChan <- 1
823
+ }()
824
+
825
+ req , err := s .connectRequest (identifier , msg , pr )
826
+ if err != nil {
827
+ return err
828
+ }
829
+
830
+ ctx , cancel := context .WithCancel (context .Background ())
831
+ req = req .WithContext (ctx )
832
+
833
+ done := make (chan struct {})
834
+ go func () {
835
+ <- continueChan
836
+ transfer (pw , conn , log .NewContext (s .logger ).With (
837
+ "dir" , "user to client" ,
838
+ "dst" , identifier ,
839
+ "src" , conn .RemoteAddr (),
840
+ ))
841
+ cancel ()
842
+ close (done )
843
+ }()
844
+
845
+ resp , err := s .httpClient .Do (req )
846
+ if err != nil {
847
+ return fmt .Errorf ("io error: %s" , err )
848
+ }
849
+ defer resp .Body .Close ()
850
+
851
+ transfer (conn , resp .Body , log .NewContext (s .logger ).With (
852
+ "dir" , "client to user" ,
853
+ "dst" , conn .RemoteAddr (),
854
+ "src" , identifier ,
855
+ ))
856
+
857
+ select {
858
+ case <- done :
859
+ case <- time .After (DefaultTimeout ):
860
+ }
861
+
862
+ s .logger .Log (
863
+ "level" , 2 ,
864
+ "action" , "proxy conn done" ,
865
+ "identifier" , identifier ,
866
+ "ctrlMsg" , msg ,
867
+ )
868
+
869
+ return nil
870
+ }
871
+
727
872
func (s * Server ) proxyConn (identifier id.ID , conn net.Conn , msg * proto.ControlMessage ) error {
728
873
s .logger .Log (
729
874
"level" , 2 ,
0 commit comments