Commit a4f1db02 authored by Richer Maximilien's avatar Richer Maximilien Committed by Loïck Bonniot

[net] Check remote server certificate hash

- Add hash check to client connect()
- Update net.Connect calls
- Implement remote cert hash check
parent 663dc14c
Pipeline #943 passed with stage
......@@ -69,7 +69,7 @@ func (m *CreateManager) sendRequest() (*api.ErrorCode, error) {
return nil, err
}
conn, err := net.Connect(viper.GetString("platform_addrport"), cert, key, ca)
conn, err := net.Connect(viper.GetString("platform_addrport"), cert, key, ca, nil)
if err != nil {
return nil, err
}
......
......@@ -21,7 +21,7 @@ func FetchContract(passphrase, uuid, path string) error {
return err
}
conn, err := net.Connect(viper.GetString("platform_addrport"), cert, key, ca)
conn, err := net.Connect(viper.GetString("platform_addrport"), cert, key, ca, nil)
if err != nil {
return err
}
......
......@@ -83,7 +83,7 @@ func NewSignatureManager(passphrase string, c *contract.JSON) (*SignatureManager
m.cServer = m.GetServer()
go func() { log.Fatalln(net.Listen("0.0.0.0:"+strconv.Itoa(viper.GetInt("local_port")), m.cServer)) }()
conn, err := net.Connect(viper.GetString("platform_addrport"), m.auth.Cert, m.auth.Key, m.auth.CA)
conn, err := net.Connect(viper.GetString("platform_addrport"), m.auth.Cert, m.auth.Key, m.auth.CA, nil)
if err != nil {
return nil, err
}
......@@ -146,7 +146,8 @@ func (m *SignatureManager) addPeer(user *pAPI.User) (ready bool, err error) {
addrPort := user.Ip + ":" + strconv.Itoa(int(user.Port))
m.OnSignerStatusUpdate(user.Email, StatusConnecting, addrPort)
conn, err := net.Connect(addrPort, m.auth.Cert, m.auth.Key, m.auth.CA)
// This is an certificate authentificated TLS connection
conn, err := net.Connect(addrPort, m.auth.Cert, m.auth.Key, m.auth.CA, user.KeyHash)
if err != nil {
m.OnSignerStatusUpdate(user.Email, StatusError, err.Error())
return false, err
......
......@@ -34,7 +34,7 @@ func connect() (pb.PlatformClient, error) {
return nil, err
}
conn, err := net.Connect(viper.GetString("platform_addrport"), nil, nil, ca)
conn, err := net.Connect(viper.GetString("platform_addrport"), nil, nil, ca, nil)
if err != nil {
return nil, err
}
......
......@@ -68,7 +68,7 @@ func (w *Window) PrintQuantumInformation() {
return
}
quantum := float64(w.quantumField.Value()*1000)
quantum := float64(w.quantumField.Value() * 1000)
beginning := w.scene.Events[0].Date.UnixNano()
totalDuration := w.scene.Events[len(w.scene.Events)-1].Date.UnixNano() - beginning
......
......@@ -59,7 +59,7 @@ func clientTest(t *testing.T) api.PlatformClient {
cert, _ := auth.PEMToCertificate(certData)
key, _ := auth.EncryptedPEMToPrivateKey(keyData, "password")
conn, err := net.Connect("localhost:9090", cert, key, ca)
conn, err := net.Connect("localhost:9090", cert, key, ca, nil)
if err != nil {
t.Fatal("Unable to connect:", err)
}
......@@ -70,7 +70,7 @@ func clientTest(t *testing.T) api.PlatformClient {
func TestAddContractBadAuth(t *testing.T) {
caData, _ := ioutil.ReadFile(filepath.Join("..", "testdata", "dfssp_rootCA.pem"))
ca, _ := auth.PEMToCertificate(caData)
conn, err := net.Connect("localhost:9090", nil, nil, ca)
conn, err := net.Connect("localhost:9090", nil, nil, ca, nil)
if err != nil {
t.Fatal("Unable to connect:", err)
}
......
......@@ -16,7 +16,7 @@ const (
)
func clientTest(t *testing.T, hostPort string) api.PlatformClient {
conn, err := net.Connect(hostPort, nil, nil, rootCA)
conn, err := net.Connect(hostPort, nil, nil, rootCA, nil)
if err != nil {
t.Fatal("Unable to connect: ", err)
}
......
......@@ -273,7 +273,7 @@ func ExampleAuth() {
}
fmt.Println("User successfully inserted")
conn, err := net.Connect("localhost:9090", nil, nil, rootCA)
conn, err := net.Connect("localhost:9090", nil, nil, rootCA, nil)
if err != nil {
fmt.Println("Unable to connect: ", err)
}
......
package net
import (
"bytes"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
......@@ -8,19 +9,26 @@ import (
"net"
"time"
"dfss/auth"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// DefaultTimeout should be used when a non-critical timeout is used in the application.
const DefaultTimeout = 30 * time.Second
var DefaultTimeout = 30 * time.Second
// Connect to a peer.
//
// Given parameters cert/key/ca are PEM-encoded array of bytes.
// Closing must be defered after call.
func Connect(addrPort string, cert *x509.Certificate, key *rsa.PrivateKey, ca *x509.Certificate) (*grpc.ClientConn, error) {
//
// The cert and key parameters can be set as nil for an unauthentified connection.
// If they are not, they will be provided to the remote server for authentification.
//
// serverCertHash will be matched against the remote server certificate.
// If nil, Connect will consider that the remote server is the root ca.
func Connect(addrPort string, cert *x509.Certificate, key *rsa.PrivateKey, ca *x509.Certificate, serverCertHash []byte) (*grpc.ClientConn, error) {
var certificates = make([]tls.Certificate, 1)
......@@ -42,10 +50,14 @@ func Connect(addrPort string, cert *x509.Certificate, key *rsa.PrivateKey, ca *x
InsecureSkipVerify: true, // Don't panic, it's normal and safe. See tlsCreds structure.
}
if serverCertHash == nil {
serverCertHash = auth.GetCertificateHash(ca)
}
// let's do the dialing !
return grpc.Dial(
addrPort,
grpc.WithTransportCredentials(&tlsCreds{config: conf}),
grpc.WithTransportCredentials(&tlsCreds{config: conf, serverCertHash: serverCertHash}),
grpc.WithTimeout(DefaultTimeout),
)
}
......@@ -58,7 +70,8 @@ func Connect(addrPort string, cert *x509.Certificate, key *rsa.PrivateKey, ca *x
//
// See crypto/tls/handshake_client.go and google.golang.org/grpc/credentials/credentials.go
type tlsCreds struct {
config tls.Config
config tls.Config
serverCertHash []byte
}
func (c *tlsCreds) Info() credentials.ProtocolInfo {
......@@ -117,6 +130,14 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D
return nil, nil, err
}
if c.serverCertHash != nil {
// Additional check for the server cert hash
if !bytes.Equal(auth.GetCertificateHash(serverCert), c.serverCertHash) {
_ = rawConn.Close()
return nil, nil, errors.New("credentials: Bad remote certificate hash")
}
}
return conn, nil, nil
}
......
......@@ -8,7 +8,6 @@ import (
"dfss/auth"
pb "dfss/net/fixtures"
"golang.org/x/net/context"
)
......@@ -75,10 +74,10 @@ func (s *testServer) Auth(ctx context.Context, in *pb.Empty) (*pb.IsAuth, error)
}
func startTestServer(c chan bool) {
ca, _ := auth.PEMToCertificate([]byte(caFixture))
key, _ := auth.PEMToPrivateKey([]byte(serverKeyFixture))
DefaultTimeout = 5 * time.Second
server := NewServer(ca, key, ca)
pb.RegisterTestServer(server, &testServer{})
go func() {
......@@ -104,8 +103,7 @@ func TestServerOnly(t *testing.T) {
time.Sleep(100 * time.Millisecond)
}
func TestServerClientAuth(t *testing.T) {
func TestServerClientAuthNoCertificateHash(t *testing.T) {
// Start server
c := make(chan bool)
go startTestServer(c)
......@@ -115,7 +113,7 @@ func TestServerClientAuth(t *testing.T) {
cert, _ := auth.PEMToCertificate([]byte(clientCertFixture))
key, _ := auth.PEMToPrivateKey([]byte(clientKeyFixture))
conn, err := Connect("localhost:9000", cert, key, ca)
conn, err := Connect("localhost:9000", cert, key, ca, nil)
if err != nil {
t.Fatal("Unable to connect:", err)
......@@ -129,15 +127,39 @@ func TestServerClientAuth(t *testing.T) {
time.Sleep(100 * time.Millisecond)
}
func TestServerClientNonAuth(t *testing.T) {
func TestServerClientBadAuthWrongCertificateHash(t *testing.T) {
// Start server
c := make(chan bool)
go startTestServer(c)
time.Sleep(2 * time.Second)
ca, _ := auth.PEMToCertificate([]byte(caFixture))
cert, _ := auth.PEMToCertificate([]byte(clientCertFixture))
key, _ := auth.PEMToPrivateKey([]byte(clientKeyFixture))
// The connection won't fail immediately, because grpc is connecting in the background
conn, _ := Connect("localhost:9000", cert, key, ca, auth.GetCertificateHash(cert))
// We have to ask for a specific query in order to test the certificate hash
client := pb.NewTestClient(conn)
_, err := client.Ping(context.Background(), &pb.Hop{Id: 1})
if err == nil {
t.Fatal("Successfully connected with bad hash")
}
c <- true
time.Sleep(100 * time.Millisecond)
}
func TestServerClientNonAuth(t *testing.T) {
// Start server
c := make(chan bool)
go startTestServer(c)
time.Sleep(2 * time.Second)
ca, _ := auth.PEMToCertificate([]byte(caFixture))
conn, err := Connect("localhost:9000", nil, nil, ca)
conn, err := Connect("localhost:9000", nil, nil, ca, nil)
if err != nil {
t.Fatal("Unable to connect:", err)
......@@ -172,7 +194,6 @@ func sharedServerClientTest(t *testing.T, client pb.TestClient, expectedAuth boo
// EXAMPLE
func Example() {
// Load certs and private keys
ca, _ := auth.PEMToCertificate([]byte(caFixture))
cert, _ := auth.PEMToCertificate([]byte(clientCertFixture))
......@@ -191,7 +212,7 @@ func Example() {
// Start an authentified client
// The second and third arguments can be empty for non-auth connection
conn, err := Connect("localhost:9000", cert, ckey, ca)
conn, err := Connect("localhost:9000", cert, ckey, ca, auth.GetCertificateHash(ca))
if err != nil {
panic("Unable to connect")
}
......@@ -207,7 +228,7 @@ func Example() {
fmt.Println((*r).Id)
// Start a non-authentified client
conn, err = Connect("localhost:9000", nil, nil, ca)
conn, err = Connect("localhost:9000", nil, nil, ca, nil)
if err != nil {
panic("Unable to connect")
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment