Skip to content

Commit 9da0263

Browse files
Ivano Culminemmatczuk
authored andcommitted
Added support for subscription listeners, dynamic subscription authorization and to upgrade a http connection for websocket support
1 parent 77db4b5 commit 9da0263

File tree

1 file changed

+150
-5
lines changed

1 file changed

+150
-5
lines changed

server.go

Lines changed: 150 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package tunnel
77
import (
88
"context"
99
"crypto/tls"
10+
"crypto/x509"
1011
"encoding/json"
1112
"errors"
1213
"fmt"
@@ -24,6 +25,20 @@ import (
2425
"github.com/mmatczuk/go-http-tunnel/proto"
2526
)
2627

28+
// A set of listeners to manage subscribers
29+
type SubscriptionListener interface {
30+
// Invoked if AutoSubscribe is false and must return true if the client is allowed to subscribe or not.
31+
// If the tlsConfig is configured to require client certificate validation, chain will contain the first
32+
// verified chain, else the presented peer certificate.
33+
CanSubscribe(id id.ID, chain []*x509.Certificate) bool
34+
// Invoked when the client has been subscribed.
35+
// If the tlsConfig is configured to require client certificate validation, chain will contain the first
36+
// verified chain, else the presented peer certificate.
37+
Subscribed(id id.ID, tlsConn *tls.Conn, chain []*x509.Certificate)
38+
// Invoked before the client is unsubscribed.
39+
Unsubscribed(id id.ID)
40+
}
41+
2742
// ServerConfig defines configuration for the Server.
2843
type ServerConfig struct {
2944
// Addr is TCP address to listen for client connections. If empty ":0"
@@ -41,6 +56,8 @@ type ServerConfig struct {
4156
Logger log.Logger
4257
// Addr is TCP address to listen for TLS SNI connections
4358
SNIAddr string
59+
// Optional listener to manage subscribers
60+
SubscriptionListener SubscriptionListener
4461
}
4562

4663
// Server is responsible for proxying public connections to the client over a
@@ -238,6 +255,7 @@ func (s *Server) handleClient(conn net.Conn) {
238255
ok bool
239256

240257
inConnPool bool
258+
certs []*x509.Certificate
241259
)
242260

243261
tlsConn, ok := conn.(*tls.Conn)
@@ -262,14 +280,26 @@ func (s *Server) handleClient(conn net.Conn) {
262280

263281
logger = logger.With("identifier", identifier)
264282

283+
certs = tlsConn.ConnectionState().PeerCertificates
284+
if tlsConn.ConnectionState().VerifiedChains != nil && len(tlsConn.ConnectionState().VerifiedChains) > 0 {
285+
certs = tlsConn.ConnectionState().VerifiedChains[0]
286+
}
265287
if s.config.AutoSubscribe {
266288
s.Subscribe(identifier)
289+
if s.config.SubscriptionListener != nil {
290+
s.config.SubscriptionListener.Subscribed(identifier, tlsConn, certs)
291+
}
267292
} else if !s.IsSubscribed(identifier) {
268-
logger.Log(
269-
"level", 2,
270-
"msg", "unknown client",
271-
)
272-
goto reject
293+
if s.config.SubscriptionListener != nil && s.config.SubscriptionListener.CanSubscribe(identifier, certs) {
294+
s.Subscribe(identifier)
295+
s.config.SubscriptionListener.Subscribed(identifier, tlsConn, certs)
296+
} else {
297+
logger.Log(
298+
"level", 2,
299+
"msg", "unknown client",
300+
)
301+
goto reject
302+
}
273303
}
274304

275305
if err = conn.SetDeadline(time.Time{}); err != nil {
@@ -486,6 +516,9 @@ rollback:
486516
// Unsubscribe removes client from registry, disconnects client if already
487517
// connected and returns it's RegistryItem.
488518
func (s *Server) Unsubscribe(identifier id.ID) *RegistryItem {
519+
if s.config.SubscriptionListener != nil {
520+
s.config.SubscriptionListener.Unsubscribed(identifier)
521+
}
489522
s.connPool.DeleteConn(identifier)
490523
return s.registry.Unsubscribe(identifier)
491524
}
@@ -561,6 +594,50 @@ func (s *Server) listen(l net.Listener, identifier id.ID) {
561594
}
562595
}
563596

597+
func (s *Server) Upgrade(identifier id.ID, conn net.Conn, requestBytes []byte) error {
598+
599+
var err error
600+
601+
msg := &proto.ControlMessage{
602+
Action: proto.ActionProxy,
603+
ForwardedProto: "https",
604+
}
605+
606+
tlsConn, ok := conn.(*tls.Conn)
607+
if ok {
608+
msg.ForwardedHost = tlsConn.ConnectionState().ServerName
609+
err = keepAlive(tlsConn.NetConn())
610+
611+
} else {
612+
msg.ForwardedHost = conn.RemoteAddr().String()
613+
err = keepAlive(conn)
614+
}
615+
616+
if err != nil {
617+
s.logger.Log(
618+
"level", 1,
619+
"msg", "TCP keepalive for tunneled connection failed",
620+
"identifier", identifier,
621+
"ctrlMsg", msg,
622+
"err", err,
623+
)
624+
}
625+
626+
go func() {
627+
if err := s.proxyConnUpgraded(identifier, conn, msg, requestBytes); err != nil {
628+
s.logger.Log(
629+
"level", 0,
630+
"msg", "proxy error",
631+
"identifier", identifier,
632+
"ctrlMsg", msg,
633+
"err", err,
634+
)
635+
}
636+
}()
637+
638+
return nil
639+
}
640+
564641
// ServeHTTP proxies http connection to the client.
565642
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
566643
resp, err := s.RoundTrip(r)
@@ -639,6 +716,74 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) {
639716
return s.proxyHTTP(identifier, outr, msg)
640717
}
641718

719+
func (s *Server) proxyConnUpgraded(identifier id.ID, conn net.Conn, msg *proto.ControlMessage, requestBytes []byte) error {
720+
s.logger.Log(
721+
"level", 2,
722+
"action", "proxy conn",
723+
"identifier", identifier,
724+
"ctrlMsg", msg,
725+
)
726+
727+
defer conn.Close()
728+
729+
pr, pw := io.Pipe()
730+
defer pr.Close()
731+
defer pw.Close()
732+
733+
continueChan := make(chan int)
734+
735+
go func() {
736+
pw.Write(requestBytes)
737+
continueChan <- 1
738+
}()
739+
740+
req, err := s.connectRequest(identifier, msg, pr)
741+
if err != nil {
742+
return err
743+
}
744+
745+
ctx, cancel := context.WithCancel(context.Background())
746+
req = req.WithContext(ctx)
747+
748+
done := make(chan struct{})
749+
go func() {
750+
<-continueChan
751+
transfer(pw, conn, log.NewContext(s.logger).With(
752+
"dir", "user to client",
753+
"dst", identifier,
754+
"src", conn.RemoteAddr(),
755+
))
756+
cancel()
757+
close(done)
758+
}()
759+
760+
resp, err := s.httpClient.Do(req)
761+
if err != nil {
762+
return fmt.Errorf("io error: %s", err)
763+
}
764+
defer resp.Body.Close()
765+
766+
transfer(conn, resp.Body, log.NewContext(s.logger).With(
767+
"dir", "client to user",
768+
"dst", conn.RemoteAddr(),
769+
"src", identifier,
770+
))
771+
772+
select {
773+
case <-done:
774+
case <-time.After(DefaultTimeout):
775+
}
776+
777+
s.logger.Log(
778+
"level", 2,
779+
"action", "proxy conn done",
780+
"identifier", identifier,
781+
"ctrlMsg", msg,
782+
)
783+
784+
return nil
785+
}
786+
642787
func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMessage) error {
643788
s.logger.Log(
644789
"level", 2,

0 commit comments

Comments
 (0)