Commit eacf1637 authored by Loïck Bonniot's avatar Loïck Bonniot

Merge branch '411_ttp_mutex' into 'master'

Add mutex on ttp

Prevent the problems due to concurrent access to the ttp

See merge request !82
parents d05936eb 0f175c74
Pipeline #2127 passed with stages
...@@ -18,28 +18,37 @@ import ( ...@@ -18,28 +18,37 @@ import (
"github.com/spf13/viper" "github.com/spf13/viper"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc" "google.golang.org/grpc"
"gopkg.in/mgo.v2/bson"
) )
// InternalError : constant string used to return a generic error message through gRPC in case of an internal error. // InternalError : constant string used to return a generic error message through gRPC in case of an internal error.
const InternalError string = "Internal server error" const InternalError string = "Internal server error"
var mutex sync.Mutex
type ttpServer struct { type ttpServer struct {
DB *mgdb.MongoManager DB *mgdb.MongoManager
globalMut *sync.Mutex
mutMap map[bson.ObjectId]*sync.Mutex
} }
// Alert route for the TTP. // Alert route for the TTP.
func (server *ttpServer) Alert(ctx context.Context, in *tAPI.AlertRequest) (*tAPI.TTPResponse, error) { func (server *ttpServer) Alert(ctx context.Context, in *tAPI.AlertRequest) (*tAPI.TTPResponse, error) {
mutex.Lock()
defer mutex.Unlock()
valid, signatureUUID, signers, senderIndex := entities.IsRequestValid(ctx, in.Promises) valid, signatureUUID, signers, senderIndex := entities.IsRequestValid(ctx, in.Promises)
if !valid { if !valid {
dAPI.DLog("invalid request from " + net.GetCN(&ctx)) dAPI.DLog("invalid request from " + net.GetCN(&ctx))
return nil, errors.New(InternalError) return nil, errors.New(InternalError)
} }
server.globalMut.Lock()
_, ok := server.mutMap[signatureUUID]
if !ok {
server.mutMap[signatureUUID] = &sync.Mutex{}
}
server.mutMap[signatureUUID].Lock()
server.globalMut.Unlock()
defer server.mutMap[signatureUUID].Unlock()
dAPI.DLog("resolve index is: " + fmt.Sprint(in.Index)) dAPI.DLog("resolve index is: " + fmt.Sprint(in.Index))
valid = int(in.Index) < len(in.Promises[0].Context.Sequence) valid = int(in.Index) < len(in.Promises[0].Context.Sequence)
if !valid { if !valid {
...@@ -91,7 +100,7 @@ func (server *ttpServer) Alert(ctx context.Context, in *tAPI.AlertRequest) (*tAP ...@@ -91,7 +100,7 @@ func (server *ttpServer) Alert(ctx context.Context, in *tAPI.AlertRequest) (*tAP
// Try to generate the contract now // Try to generate the contract now
message, err = server.handleContractGenerationTry(manager) message, err = server.handleContractGenerationTry(manager)
// We manually update the database // We manually update the database
ok, err := server.DB.Get("signatures").UpdateByID(*(manager.Archives)) ok, err = server.DB.Get("signatures").UpdateByID(*(manager.Archives))
if !ok { if !ok {
dAPI.DLog("error during 'UpdateByID' l.81" + fmt.Sprint(err.Error())) dAPI.DLog("error during 'UpdateByID' l.81" + fmt.Sprint(err.Error()))
return nil, errors.New(InternalError) return nil, errors.New(InternalError)
...@@ -235,9 +244,14 @@ func GetServer() *grpc.Server { ...@@ -235,9 +244,14 @@ func GetServer() *grpc.Server {
os.Exit(2) os.Exit(2)
} }
mutmap := make(map[bson.ObjectId]*sync.Mutex)
server := &ttpServer{ server := &ttpServer{
DB: dbManager, DB: dbManager,
globalMut: &sync.Mutex{},
mutMap: mutmap,
} }
netServer := net.NewServer(cert, key, ca) netServer := net.NewServer(cert, key, ca)
tAPI.RegisterTTPServer(netServer, server) tAPI.RegisterTTPServer(netServer, server)
return netServer return netServer
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment