Skip to content

Commit 65e2caf

Browse files
tengu-altjoao-r-reis
authored andcommitted
Refactor HostInfo creation and ConnectAddress() method
HostInfo struct creation was refactored to create via constructor to make sure the connectAddress is valid. Panic in case of invalid connect address inside of ConnectAddress() method was removed. patch by Oleksandr Luzhniy; reviewed by João Reis, James Hartig, for CASSGO-45
1 parent f3d13d4 commit 65e2caf

File tree

4 files changed

+44
-23
lines changed

4 files changed

+44
-23
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3535

3636
- Standardized spelling of datacenter (CASSGO-35)
3737

38+
- Refactor HostInfo creation and ConnectAddress() method (CASSGO-45)
39+
3840
### Fixed
3941
- Cassandra version unmarshal fix (CASSGO-49)
4042

conn.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1690,7 +1690,11 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) {
16901690
}
16911691

16921692
for _, row := range rows {
1693-
host, err := c.session.hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port})
1693+
h, err := newHostInfo(c.host.ConnectAddress(), c.session.cfg.Port)
1694+
if err != nil {
1695+
goto cont
1696+
}
1697+
host, err := c.session.hostInfoFromMap(row, h)
16941698
if err != nil {
16951699
goto cont
16961700
}

control.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,11 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
146146

147147
// Check if host is a literal IP address
148148
if ip := net.ParseIP(host); ip != nil {
149-
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
149+
h, err := newHostInfo(ip, port)
150+
if err != nil {
151+
return nil, err
152+
}
153+
hosts = append(hosts, h)
150154
return hosts, nil
151155
}
152156

@@ -172,7 +176,12 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
172176
}
173177

174178
for _, ip := range ips {
175-
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
179+
h, err := newHostInfo(ip, port)
180+
if err != nil {
181+
return nil, err
182+
}
183+
184+
hosts = append(hosts, h)
176185
}
177186

178187
return hosts, nil

host_source.go

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,18 @@ type HostInfo struct {
181181
tokens []string
182182
}
183183

184+
func newHostInfo(addr net.IP, port int) (*HostInfo, error) {
185+
if !validIpAddr(addr) {
186+
return nil, errors.New("invalid host address")
187+
}
188+
host := &HostInfo{}
189+
host.hostname = addr.String()
190+
host.port = port
191+
192+
host.connectAddress = addr
193+
return host, nil
194+
}
195+
184196
func (h *HostInfo) Equal(host *HostInfo) bool {
185197
if h == host {
186198
// prevent rlock reentry
@@ -213,14 +225,12 @@ func (h *HostInfo) connectAddressLocked() (net.IP, string) {
213225
} else if validIpAddr(h.rpcAddress) {
214226
return h.rpcAddress, "rpc_adress"
215227
} else if validIpAddr(h.preferredIP) {
216-
// where does perferred_ip get set?
217228
return h.preferredIP, "preferred_ip"
218229
} else if validIpAddr(h.broadcastAddress) {
219230
return h.broadcastAddress, "broadcast_address"
220-
} else if validIpAddr(h.peer) {
221-
return h.peer, "peer"
222231
}
223-
return net.IPv4zero, "invalid"
232+
return h.peer, "peer"
233+
224234
}
225235

226236
// nodeToNodeAddress returns address broadcasted between node to nodes.
@@ -240,24 +250,13 @@ func (h *HostInfo) nodeToNodeAddress() net.IP {
240250
}
241251

242252
// Returns the address that should be used to connect to the host.
243-
// If you wish to override this, use an AddressTranslator or
244-
// use a HostFilter to SetConnectAddress()
253+
// If you wish to override this, use an AddressTranslator
245254
func (h *HostInfo) ConnectAddress() net.IP {
246255
h.mu.RLock()
247256
defer h.mu.RUnlock()
248257

249-
if addr, _ := h.connectAddressLocked(); validIpAddr(addr) {
250-
return addr
251-
}
252-
panic(fmt.Sprintf("no valid connect address for host: %v. Is your cluster configured correctly?", h))
253-
}
254-
255-
func (h *HostInfo) SetConnectAddress(address net.IP) *HostInfo {
256-
// TODO(zariel): should this not be exported?
257-
h.mu.Lock()
258-
defer h.mu.Unlock()
259-
h.connectAddress = address
260-
return h
258+
addr, _ := h.connectAddressLocked()
259+
return addr
261260
}
262261

263262
func (h *HostInfo) BroadcastAddress() net.IP {
@@ -491,6 +490,10 @@ func checkSystemSchema(control *controlConn) (bool, error) {
491490
return true, nil
492491
}
493492

493+
func (s *Session) newHostInfoFromMap(addr net.IP, port int, row map[string]interface{}) (*HostInfo, error) {
494+
return s.hostInfoFromMap(row, &HostInfo{connectAddress: addr, port: port})
495+
}
496+
494497
// Given a map that represents a row from either system.local or system.peers
495498
// return as much information as we can in *HostInfo
496499
func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*HostInfo, error) {
@@ -606,6 +609,9 @@ func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*
606609
}
607610

608611
ip, port := s.cfg.translateAddressPort(host.ConnectAddress(), host.port)
612+
if !validIpAddr(ip) {
613+
return nil, fmt.Errorf("invalid host address (before translation: %v:%v, after translation: %v:%v)", host.ConnectAddress(), host.port, ip.String(), port)
614+
}
609615
host.connectAddress = ip
610616
host.port = port
611617

@@ -623,7 +629,7 @@ func (s *Session) hostInfoFromIter(iter *Iter, connectAddress net.IP, defaultPor
623629
return nil, errors.New("query returned 0 rows")
624630
}
625631

626-
host, err := s.hostInfoFromMap(rows[0], &HostInfo{connectAddress: connectAddress, port: defaultPort})
632+
host, err := s.newHostInfoFromMap(connectAddress, defaultPort, rows[0])
627633
if err != nil {
628634
return nil, err
629635
}
@@ -674,7 +680,7 @@ func (r *ringDescriber) getClusterPeerInfo(localHost *HostInfo) ([]*HostInfo, er
674680

675681
for _, row := range rows {
676682
// extract all available info about the peer
677-
host, err := r.session.hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port})
683+
host, err := r.session.newHostInfoFromMap(nil, r.session.cfg.Port, row)
678684
if err != nil {
679685
return nil, err
680686
} else if !isValidPeer(host) {

0 commit comments

Comments
 (0)