client.go 3.25 KB
Newer Older
Richer Maximilien's avatar
Richer Maximilien committed
1
2
3
package net

import (
4
	"crypto/rsa"
Richer Maximilien's avatar
Richer Maximilien committed
5
6
	"crypto/tls"
	"crypto/x509"
7
8
9
	"errors"
	"net"
	"time"
Richer Maximilien's avatar
Richer Maximilien committed
10

11
	"golang.org/x/net/context"
Richer Maximilien's avatar
Richer Maximilien committed
12
13
14
15
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
)

// DefaultTimeout is the timeout applied to operations whose latency is not
// critical to the application (for example, dialing a peer in Connect).
const DefaultTimeout time.Duration = 30 * time.Second

Richer Maximilien's avatar
Richer Maximilien committed
19
// Connect to a peer.
Richer Maximilien's avatar
Richer Maximilien committed
20
//
Richer Maximilien's avatar
Richer Maximilien committed
21
22
// Given parameters cert/key/ca are PEM-encoded array of bytes.
// Closing must be defered after call.
23
func Connect(addrPort string, cert *x509.Certificate, key *rsa.PrivateKey, ca *x509.Certificate) (*grpc.ClientConn, error) {
24
25
26

	var certificates = make([]tls.Certificate, 1)

27
28
29
30
	if key != nil && cert != nil {
		peerCert := tls.Certificate{
			Certificate: [][]byte{cert.Raw},
			PrivateKey:  key,
31
32
		}
		certificates = append(certificates, peerCert)
Richer Maximilien's avatar
Richer Maximilien committed
33
	}
34

Richer Maximilien's avatar
Richer Maximilien committed
35
	caCertPool := x509.NewCertPool()
36
	caCertPool.AddCert(ca)
Richer Maximilien's avatar
Richer Maximilien committed
37
38

	// configure transport authentificator
39
40
41
42
43
	conf := tls.Config{
		Certificates:       certificates,
		RootCAs:            caCertPool,
		InsecureSkipVerify: true, // Don't panic, it's normal and safe. See tlsCreds structure.
	}
Richer Maximilien's avatar
Richer Maximilien committed
44
45

	// let's do the dialing !
46
47
48
	return grpc.Dial(
		addrPort,
		grpc.WithTransportCredentials(&tlsCreds{config: conf}),
49
		grpc.WithTimeout(DefaultTimeout),
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
	)
}

// tlsCreds is a client-side gRPC TLS authenticator modeled on the default gRPC
// TLS credentials, except that it does not verify the server's hostname.
// Peers are dialed by IP address, and our certificates carry no IP SANs, so a
// hostname check would always fail.
//
// Skipping the hostname check requires setting InsecureSkipVerify. For that
// reason, ClientHandshake must verify the server certificate chain against
// config.RootCAs itself once the handshake completes.
//
// See crypto/tls/handshake_client.go and google.golang.org/grpc/credentials/credentials.go.
type tlsCreds struct {
	// config is the TLS configuration passed to tls.Client for each handshake.
	config tls.Config
}

// Info describes the security protocol these credentials provide:
// TLS, advertised as version "1.2".
func (c *tlsCreds) Info() credentials.ProtocolInfo {
	var info credentials.ProtocolInfo
	info.SecurityProtocol = "tls"
	info.SecurityVersion = "1.2"
	return info
}

// GetRequestMetadata adds no per-request metadata. Authentication happens
// entirely at the transport layer.
func (c *tlsCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
	var metadata map[string]string
	return metadata, nil
}

// RequireTransportSecurity reports that these credentials must be used over a
// secure transport.
func (c *tlsCreds) RequireTransportSecurity() bool {
	const secureOnly = true
	return secureOnly
}

// ClientHandshake runs the TLS client handshake over rawConn and then
// authenticates the server itself. This manual step is needed because the
// tls.Config in use has InsecureSkipVerify set.
//
// The server's leaf certificate is verified against c.config.RootCAs. Any
// additional certificates the server sent are used as intermediates. No
// hostname is checked, and addr is not used.
//
// A non-zero timeout bounds the handshake. When the timeout expires first, the
// error "credentials: Dial timed out" is returned. On any error rawConn is
// closed and a nil connection is returned.
func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ credentials.AuthInfo, err error) {
	// Establish a secure connection WITHOUT certificate verification; the
	// server is authenticated manually below.
	conn := tls.Client(rawConn, &c.config)
	if timeout == 0 {
		err = conn.Handshake()
	} else {
		// Buffered so the handshake goroutine can always deliver its result
		// and exit, even after we stopped waiting for it.
		errc := make(chan error, 1)
		go func() { errc <- conn.Handshake() }()

		timer := time.NewTimer(timeout)
		select {
		case err = <-errc:
		case <-timer.C:
			err = errors.New("credentials: Dial timed out")
		}
		timer.Stop()
	}

	if err != nil { // Error during handshake
		// Closing rawConn also unblocks a handshake still running after a timeout.
		_ = rawConn.Close()
		return nil, nil, err
	}

	// Successful handshake, BUT the server must be authenticated NOW.
	state := conn.ConnectionState()
	if len(state.PeerCertificates) == 0 {
		_ = rawConn.Close()
		return nil, nil, errors.New("credentials: server presented no certificate")
	}

	// Mirror crypto/tls: the first certificate is the leaf and the rest are
	// intermediates that may be needed to build a chain to a root.
	intermediates := x509.NewCertPool()
	for _, ic := range state.PeerCertificates[1:] {
		intermediates.AddCert(ic)
	}
	opts := x509.VerifyOptions{
		Roots:         c.config.RootCAs,
		Intermediates: intermediates,
		CurrentTime:   time.Now(),
	}
	if _, err = state.PeerCertificates[0].Verify(opts); err != nil {
		_ = rawConn.Close()
		return nil, nil, err
	}

	return conn, nil, nil
}

func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	return nil, nil, errors.New("Server side handshake not implemented")
Richer Maximilien's avatar
Richer Maximilien committed
125
}