@@ -7,6 +7,7 @@ package tunnel
7
7
import (
8
8
"context"
9
9
"crypto/tls"
10
+ "crypto/x509"
10
11
"encoding/json"
11
12
"errors"
12
13
"fmt"
@@ -24,6 +25,20 @@ import (
24
25
"github.com/mmatczuk/go-http-tunnel/proto"
25
26
)
26
27
28
+ // A set of listeners to manage subscribers
29
+ type SubscriptionListener interface {
30
+ // Invoked if AutoSubscribe is false and must return true if the client is allowed to subscribe or not.
31
+ // If the tlsConfig is configured to require client certificate validation, chain will contain the first
32
+ // verified chain, else the presented peer certificate.
33
+ CanSubscribe (id id.ID , chain []* x509.Certificate ) bool
34
+ // Invoked when the client has been subscribed.
35
+ // If the tlsConfig is configured to require client certificate validation, chain will contain the first
36
+ // verified chain, else the presented peer certificate.
37
+ Subscribed (id id.ID , tlsConn * tls.Conn , chain []* x509.Certificate )
38
+ // Invoked before the client is unsubscribed.
39
+ Unsubscribed (id id.ID )
40
+ }
41
+
27
42
// ServerConfig defines configuration for the Server.
28
43
type ServerConfig struct {
29
44
// Addr is TCP address to listen for client connections. If empty ":0"
@@ -41,6 +56,8 @@ type ServerConfig struct {
41
56
Logger log.Logger
42
57
// Addr is TCP address to listen for TLS SNI connections
43
58
SNIAddr string
59
+ // Optional listener to manage subscribers
60
+ SubscriptionListener SubscriptionListener
44
61
}
45
62
46
63
// Server is responsible for proxying public connections to the client over a
@@ -238,6 +255,7 @@ func (s *Server) handleClient(conn net.Conn) {
238
255
ok bool
239
256
240
257
inConnPool bool
258
+ certs []* x509.Certificate
241
259
)
242
260
243
261
tlsConn , ok := conn .(* tls.Conn )
@@ -262,14 +280,26 @@ func (s *Server) handleClient(conn net.Conn) {
262
280
263
281
logger = logger .With ("identifier" , identifier )
264
282
283
+ certs = tlsConn .ConnectionState ().PeerCertificates
284
+ if tlsConn .ConnectionState ().VerifiedChains != nil && len (tlsConn .ConnectionState ().VerifiedChains ) > 0 {
285
+ certs = tlsConn .ConnectionState ().VerifiedChains [0 ]
286
+ }
265
287
if s .config .AutoSubscribe {
266
288
s .Subscribe (identifier )
289
+ if s .config .SubscriptionListener != nil {
290
+ s .config .SubscriptionListener .Subscribed (identifier , tlsConn , certs )
291
+ }
267
292
} else if ! s .IsSubscribed (identifier ) {
268
- logger .Log (
269
- "level" , 2 ,
270
- "msg" , "unknown client" ,
271
- )
272
- goto reject
293
+ if s .config .SubscriptionListener != nil && s .config .SubscriptionListener .CanSubscribe (identifier , certs ) {
294
+ s .Subscribe (identifier )
295
+ s .config .SubscriptionListener .Subscribed (identifier , tlsConn , certs )
296
+ } else {
297
+ logger .Log (
298
+ "level" , 2 ,
299
+ "msg" , "unknown client" ,
300
+ )
301
+ goto reject
302
+ }
273
303
}
274
304
275
305
if err = conn .SetDeadline (time.Time {}); err != nil {
@@ -486,6 +516,9 @@ rollback:
486
516
// Unsubscribe removes client from registry, disconnects client if already
487
517
// connected and returns it's RegistryItem.
488
518
func (s * Server ) Unsubscribe (identifier id.ID ) * RegistryItem {
519
+ if s .config .SubscriptionListener != nil {
520
+ s .config .SubscriptionListener .Unsubscribed (identifier )
521
+ }
489
522
s .connPool .DeleteConn (identifier )
490
523
return s .registry .Unsubscribe (identifier )
491
524
}
@@ -561,6 +594,50 @@ func (s *Server) listen(l net.Listener, identifier id.ID) {
561
594
}
562
595
}
563
596
597
+ func (s * Server ) Upgrade (identifier id.ID , conn net.Conn , requestBytes []byte ) error {
598
+
599
+ var err error
600
+
601
+ msg := & proto.ControlMessage {
602
+ Action : proto .ActionProxy ,
603
+ ForwardedProto : "https" ,
604
+ }
605
+
606
+ tlsConn , ok := conn .(* tls.Conn )
607
+ if ok {
608
+ msg .ForwardedHost = tlsConn .ConnectionState ().ServerName
609
+ err = keepAlive (tlsConn .NetConn ())
610
+
611
+ } else {
612
+ msg .ForwardedHost = conn .RemoteAddr ().String ()
613
+ err = keepAlive (conn )
614
+ }
615
+
616
+ if err != nil {
617
+ s .logger .Log (
618
+ "level" , 1 ,
619
+ "msg" , "TCP keepalive for tunneled connection failed" ,
620
+ "identifier" , identifier ,
621
+ "ctrlMsg" , msg ,
622
+ "err" , err ,
623
+ )
624
+ }
625
+
626
+ go func () {
627
+ if err := s .proxyConnUpgraded (identifier , conn , msg , requestBytes ); err != nil {
628
+ s .logger .Log (
629
+ "level" , 0 ,
630
+ "msg" , "proxy error" ,
631
+ "identifier" , identifier ,
632
+ "ctrlMsg" , msg ,
633
+ "err" , err ,
634
+ )
635
+ }
636
+ }()
637
+
638
+ return nil
639
+ }
640
+
564
641
// ServeHTTP proxies http connection to the client.
565
642
func (s * Server ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
566
643
resp , err := s .RoundTrip (r )
@@ -639,6 +716,74 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
639
716
return s .proxyHTTP (identifier , outr , msg )
640
717
}
641
718
719
+ func (s * Server ) proxyConnUpgraded (identifier id.ID , conn net.Conn , msg * proto.ControlMessage , requestBytes []byte ) error {
720
+ s .logger .Log (
721
+ "level" , 2 ,
722
+ "action" , "proxy conn" ,
723
+ "identifier" , identifier ,
724
+ "ctrlMsg" , msg ,
725
+ )
726
+
727
+ defer conn .Close ()
728
+
729
+ pr , pw := io .Pipe ()
730
+ defer pr .Close ()
731
+ defer pw .Close ()
732
+
733
+ continueChan := make (chan int )
734
+
735
+ go func () {
736
+ pw .Write (requestBytes )
737
+ continueChan <- 1
738
+ }()
739
+
740
+ req , err := s .connectRequest (identifier , msg , pr )
741
+ if err != nil {
742
+ return err
743
+ }
744
+
745
+ ctx , cancel := context .WithCancel (context .Background ())
746
+ req = req .WithContext (ctx )
747
+
748
+ done := make (chan struct {})
749
+ go func () {
750
+ <- continueChan
751
+ transfer (pw , conn , log .NewContext (s .logger ).With (
752
+ "dir" , "user to client" ,
753
+ "dst" , identifier ,
754
+ "src" , conn .RemoteAddr (),
755
+ ))
756
+ cancel ()
757
+ close (done )
758
+ }()
759
+
760
+ resp , err := s .httpClient .Do (req )
761
+ if err != nil {
762
+ return fmt .Errorf ("io error: %s" , err )
763
+ }
764
+ defer resp .Body .Close ()
765
+
766
+ transfer (conn , resp .Body , log .NewContext (s .logger ).With (
767
+ "dir" , "client to user" ,
768
+ "dst" , conn .RemoteAddr (),
769
+ "src" , identifier ,
770
+ ))
771
+
772
+ select {
773
+ case <- done :
774
+ case <- time .After (DefaultTimeout ):
775
+ }
776
+
777
+ s .logger .Log (
778
+ "level" , 2 ,
779
+ "action" , "proxy conn done" ,
780
+ "identifier" , identifier ,
781
+ "ctrlMsg" , msg ,
782
+ )
783
+
784
+ return nil
785
+ }
786
+
642
787
func (s * Server ) proxyConn (identifier id.ID , conn net.Conn , msg * proto.ControlMessage ) error {
643
788
s .logger .Log (
644
789
"level" , 2 ,
0 commit comments