client.go 4.79 KB
Newer Older
Richer Maximilien's avatar
Richer Maximilien committed
1
2
3
package net

import (
4
	"bytes"
5
	"crypto/rsa"
Richer Maximilien's avatar
Richer Maximilien committed
6
7
	"crypto/tls"
	"crypto/x509"
8
9
	"errors"
	"net"
10
	"strings"
11
	"time"
Richer Maximilien's avatar
Richer Maximilien committed
12

13
	"dfss/auth"
14
	"golang.org/x/net/context"
Richer Maximilien's avatar
Richer Maximilien committed
15
16
17
18
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
)

19
// DefaultTimeout is the timeout to use wherever the application needs a
// non-critical timeout. Connect uses it as the dialing timeout.
var DefaultTimeout = 10 * time.Second
21

Richer Maximilien's avatar
Richer Maximilien committed
22
// Connect to a peer.
Richer Maximilien's avatar
Richer Maximilien committed
23
//
Richer Maximilien's avatar
Richer Maximilien committed
24
25
// Given parameters cert/key/ca are PEM-encoded array of bytes.
// Closing must be defered after call.
26
27
28
29
30
31
32
//
// The cert and key parameters can be set as nil for an unauthentified connection.
// If they are not, they will be provided to the remote server for authentification.
//
// serverCertHash will be matched against the remote server certificate.
// If nil, Connect will consider that the remote server is the root ca.
func Connect(addrPort string, cert *x509.Certificate, key *rsa.PrivateKey, ca *x509.Certificate, serverCertHash []byte) (*grpc.ClientConn, error) {
33
34
35

	var certificates = make([]tls.Certificate, 1)

36
37
38
39
	if key != nil && cert != nil {
		peerCert := tls.Certificate{
			Certificate: [][]byte{cert.Raw},
			PrivateKey:  key,
40
41
		}
		certificates = append(certificates, peerCert)
Richer Maximilien's avatar
Richer Maximilien committed
42
	}
43

Richer Maximilien's avatar
Richer Maximilien committed
44
	caCertPool := x509.NewCertPool()
45
	caCertPool.AddCert(ca)
Richer Maximilien's avatar
Richer Maximilien committed
46
47

	// configure transport authentificator
48
49
50
51
52
	conf := tls.Config{
		Certificates:       certificates,
		RootCAs:            caCertPool,
		InsecureSkipVerify: true, // Don't panic, it's normal and safe. See tlsCreds structure.
	}
Richer Maximilien's avatar
Richer Maximilien committed
53

54
55
56
57
	if serverCertHash == nil {
		serverCertHash = auth.GetCertificateHash(ca)
	}

Richer Maximilien's avatar
Richer Maximilien committed
58
	// let's do the dialing !
59
60
	return grpc.Dial(
		addrPort,
61
		grpc.WithTransportCredentials(&tlsCreds{config: conf, serverCertHash: serverCertHash}),
62
		grpc.WithTimeout(DefaultTimeout),
63
64
65
66
67
68
69
70
71
72
73
	)
}

// tlsCreds reimplements the default grpc TLS authenticator with no hostname verification.
// It is required because we need to connect to clients with their IP, and there is no IP SANs in our certificates.
//
// We need to enable the "InsecureSkipVerify" option to perform this, that's why it's important to check the server
// certificate during the authentication process (see ClientHandshake).
//
// See crypto/tls/handshake_client.go and google.golang.org/grpc/credentials/credentials.go
type tlsCreds struct {
	// config is the TLS configuration used for the client handshake; its
	// RootCAs are also used to verify the server certificate afterwards.
	config tls.Config
	// serverCertHash is the expected hash of the server certificate.
	// When nil, the hash check is skipped.
	serverCertHash []byte
}

// Info describes the security protocol provided by these credentials:
// the protocol name is "tls" and the advertised version is "1.2".
func (c *tlsCreds) Info() credentials.ProtocolInfo {
	var info credentials.ProtocolInfo
	info.SecurityProtocol = "tls"
	info.SecurityVersion = "1.2"
	return info
}

// GetRequestMetadata never attaches any metadata to a request: it always
// returns a nil map and a nil error, whatever ctx and uri are.
func (c *tlsCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
	var metadata map[string]string // intentionally left nil
	return metadata, nil
}

// RequireTransportSecurity always reports true: these credentials are only
// meant to be used over a secured transport.
func (c *tlsCreds) RequireTransportSecurity() bool {
	const secured = true
	return secured
}

// ClientHandshake runs the TLS client handshake over rawConn and then
// authenticates the server by hand, since the TLS configuration is expected
// to skip the built-in verification.
//
// The addr parameter is not used. When timeout is non-zero, the handshake is
// abandoned with an error once it elapses; a zero timeout waits indefinitely.
//
// After the handshake, the server leaf certificate must chain to
// c.config.RootCAs (certificates sent after the leaf are used as
// intermediates), and, if c.serverCertHash is not nil, its hash must be equal
// to c.serverCertHash.
//
// On success the TLS connection is returned with a nil AuthInfo. On any
// failure rawConn is closed and only the error is returned.
func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ credentials.AuthInfo, err error) {
	// Establish a secure connection WITHOUT certificate verification
	conn := tls.Client(rawConn, &c.config)
	if timeout == 0 {
		err = conn.Handshake()
	} else {
		// Buffered for both senders, so that neither the timer callback nor
		// the handshake goroutine can block once we stopped listening.
		errChannel := make(chan error, 2)
		timer := time.AfterFunc(timeout, func() {
			errChannel <- errors.New("credentials: Dial timed out")
		})
		go func() { errChannel <- conn.Handshake() }()
		err = <-errChannel
		timer.Stop() // no need to keep the timer pending once we have a result
	}

	if err != nil { // Error during handshake
		_ = rawConn.Close() // best effort, the handshake error is the one to report
		return nil, nil, err
	}

	// Successful handshake, BUT we have to authenticate the server NOW
	state := conn.ConnectionState()
	if len(state.PeerCertificates) == 0 {
		_ = rawConn.Close()
		return nil, nil, errors.New("credentials: no certificate provided by the remote server")
	}
	serverCert := state.PeerCertificates[0]

	opts := x509.VerifyOptions{
		Roots:         c.config.RootCAs,
		CurrentTime:   time.Now(),
		Intermediates: x509.NewCertPool(),
	}
	// Same as crypto/tls: every certificate after the leaf is a candidate intermediate.
	for _, intermediate := range state.PeerCertificates[1:] {
		opts.Intermediates.AddCert(intermediate)
	}

	if _, err = serverCert.Verify(opts); err != nil {
		_ = rawConn.Close()
		return nil, nil, err
	}

	if c.serverCertHash != nil {
		// Additional check for the server cert hash
		if !bytes.Equal(auth.GetCertificateHash(serverCert), c.serverCertHash) {
			_ = rawConn.Close()
			return nil, nil, errors.New("credentials: Bad remote certificate hash")
		}
	}

	return conn, nil, nil
}

func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	return nil, nil, errors.New("Server side handshake not implemented")
Richer Maximilien's avatar
Richer Maximilien committed
147
}
148
149
150
151
152
153
154
155
156

// ExternalInterfaceAddr returns the IPv4 addresses of the system's network
// interfaces, as strings without their network mask.
//
// IPv6 addresses are skipped. Loopback addresses (127.0.0.0/8), if any, are
// put at the end of the list, after every other address.
//
// The returned slice is never nil when the error is nil.
func ExternalInterfaceAddr() ([]string, error) {
	addrs, err := net.InterfaceAddrs()
	if err != nil {
		return nil, err
	}

	extAddrs := make([]string, 0, len(addrs))
	var localhostAddrs []string

	for _, a := range addrs {
		// Addresses are usually rendered in CIDR form ("192.168.1.2/24");
		// strip the mask when there is one, keep the address as is otherwise.
		host := a.String()
		if i := strings.IndexByte(host, '/'); i >= 0 {
			host = host[:i]
		}

		ip := net.ParseIP(host)
		if ip == nil || ip.To4() == nil {
			// IPv6 (possibly with a zone) or not an IP address: ignore it.
			continue
		}

		if ip.IsLoopback() {
			// Loopback addresses are moved to the end of the result.
			localhostAddrs = append(localhostAddrs, host)
		} else {
			extAddrs = append(extAddrs, host)
		}
	}
	return append(extAddrs, localhostAddrs...), nil
}