diff --git a/Raft.Test/TestServer.fs b/Raft.Test/TestServer.fs index 4e6526d..5cfd256 100644 --- a/Raft.Test/TestServer.fs +++ b/Raft.Test/TestServer.fs @@ -69,18 +69,23 @@ module TestServer = // We sent a message to every other server; process them. for i in 1..4 do - network.InboundMessages.[i].Count |> shouldEqual 1 - let message = network.InboundMessages.[i].[0] - network.InboundMessages.[i].Clear () - cluster.SendMessageDirectly (i * 1) message + let server = i * 1 + (network.AllInboundMessages server).Length |> shouldEqual 1 + let message = network.InboundMessage server 0 + network.DropMessage server 0 + cluster.SendMessageDirectly server message - network.InboundMessages.[0].Count |> shouldEqual i + (network.AllInboundMessages 0).Length |> shouldEqual i for i in 1..4 do - cluster.SendMessageDirectly 0 network.InboundMessages.[0].[i - 1] + network.InboundMessage 0 (i - 1) + |> cluster.SendMessageDirectly 0 + + network.DropMessage 0 (i - 1) + // (the messages we've already processed) - network.InboundMessages.[0].Count |> shouldEqual 4 - network.InboundMessages.[0].Clear () + (network.AllInboundMessages 0).Length |> shouldEqual 4 + (network.UndeliveredMessages 0).Length |> shouldEqual 0 cluster.Servers.[0].State |> shouldEqual ServerStatus.Leader @@ -159,22 +164,24 @@ module TestServer = cluster.Servers.[1].Sync () // Those two each sent a message to every other server. - network.InboundMessages.[0].Count |> shouldEqual 1 - network.InboundMessages.[1].Count |> shouldEqual 1 + (network.AllInboundMessages 0).Length |> shouldEqual 1 + (network.AllInboundMessages 1).Length |> shouldEqual 1 for i in 2..4 do - network.InboundMessages.[i].Count |> shouldEqual 2 + let server = i * 1 + (network.AllInboundMessages server).Length |> shouldEqual 2 - while network.InboundMessages |> Seq.concat |> Seq.isEmpty |> not do + while network.AllUndeliveredMessages () |> Seq.concat |> Seq.isEmpty |> not do let allOrderings' = - network.InboundMessages |> List.ofArray |> List.map List.ofSeq |> allOrderings + network.AllUndeliveredMessages () |> List.map List.ofSeq |> allOrderings - network.InboundMessages |> Array.iter (fun arr -> arr.Clear ()) // Process the messages! let ordering = randomChoice rand allOrderings' - for serverConsuming, message in ordering do - cluster.SendMessageDirectly (serverConsuming * 1) message + for serverConsuming, (messageId, message) in ordering do + let serverConsuming = serverConsuming * 1 + cluster.SendMessageDirectly serverConsuming message + network.DropMessage serverConsuming messageId (cluster.Servers.[0].State = Leader && cluster.Servers.[1].State = Leader) |> shouldEqual false @@ -202,9 +209,9 @@ module TestServer = let apply (History history) (cluster : Cluster<'a>) (network : Network<'a>) : unit = for pile, entry in history do - let messages = network.InboundMessages.[pile / 1] + let messages = network.AllInboundMessages pile - if entry < messages.Count then + if entry < messages.Length then cluster.SendMessageDirectly pile messages.[entry] [] @@ -217,11 +224,11 @@ module TestServer = cluster.Servers.[1].Sync () // Those two each sent a message to every other server. - network.InboundMessages.[0].Count |> shouldEqual 1 - network.InboundMessages.[1].Count |> shouldEqual 1 + (network.AllInboundMessages 0).Length |> shouldEqual 1 + (network.AllInboundMessages 1).Length |> shouldEqual 1 for i in 2..4 do - network.InboundMessages.[i].Count |> shouldEqual 2 + (network.AllInboundMessages (i * 1)).Length |> shouldEqual 2 let property (history : History) = apply history cluster network diff --git a/Raft/Server.fs b/Raft/Server.fs index 71a4eea..23213ea 100644 --- a/Raft/Server.fs +++ b/Raft/Server.fs @@ -1,5 +1,7 @@ namespace Raft +open System.Collections.Generic + /// Server state which need not survive a server crash. type VolatileState = { @@ -480,18 +482,37 @@ type Cluster<'a> = type Network<'a> = internal { - /// InboundMessages.[i] is the collection of messages sent to - /// server `i` and waiting for you to allow them through. - InboundMessages : ResizeArray>[] + /// CompleteMessageHistory.[i] is the collection of all messages + /// ever sent to server `i`. + CompleteMessageHistory : ResizeArray>[] + MessagesDelivered : HashSet[] + } + + static member Make (clusterSize : int) = + { + CompleteMessageHistory = Array.init clusterSize (fun _ -> ResizeArray ()) + MessagesDelivered = Array.init clusterSize (fun _ -> HashSet ()) } member this.AllInboundMessages (i : int) : Message<'a> list = - this.InboundMessages.[i / 1] |> List.ofSeq + this.CompleteMessageHistory.[i / 1] |> List.ofSeq member this.InboundMessage (i : int) (id : int) : Message<'a> = - this.InboundMessages.[i / 1].[id] + this.CompleteMessageHistory.[i / 1].[id] - member this.Size = this.InboundMessages.Length + member this.DropMessage (i : int) (id : int) = + this.MessagesDelivered.[i / 1].Add id |> ignore + + member this.UndeliveredMessages (i : int) : (int * Message<'a>) list = + this.CompleteMessageHistory.[i / 1] + |> Seq.indexed + |> Seq.filter (fun (count, _) -> this.MessagesDelivered.[i / 1].Contains count |> not) + |> List.ofSeq + + member this.AllUndeliveredMessages () : ((int * Message<'a>) list) list = + List.init this.CompleteMessageHistory.Length (fun i -> this.UndeliveredMessages (i * 1)) + + member this.ClusterSize = this.CompleteMessageHistory.Length [] module InMemoryCluster = @@ -500,15 +521,10 @@ module InMemoryCluster = let make<'a> (count : int) : Cluster<'a> * Network<'a> = let servers = Array.zeroCreate> count - let network = - { - InboundMessages = - fun _ -> ResizeArray> () - |> Array.init count - } + let network = Network.Make count let messageChannelHold (serverId : int) (message : Message<'a>) : unit = - let arr = network.InboundMessages.[serverId / 1] + let arr = network.CompleteMessageHistory.[serverId / 1] lock arr (fun () -> arr.Add message) for s in 0 .. servers.Length - 1 do diff --git a/RaftExplorer/Program.fs b/RaftExplorer/Program.fs index 4983186..72ef9f3 100644 --- a/RaftExplorer/Program.fs +++ b/RaftExplorer/Program.fs @@ -8,65 +8,68 @@ module Program = let printNetworkState<'a> (network : Network<'a>) : unit = let mutable wroteAnything = false - for i in 0 .. network.Size - 1 do - for count, message in Seq.indexed (network.AllInboundMessages (i * 1)) do - printfn "Server %i, message %i: %O" i count message + for i in 0 .. network.ClusterSize - 1 do + for messageId, message in network.UndeliveredMessages (i * 1) do + printfn "Server %i, message %i: %O" i messageId message wroteAnything <- true if not wroteAnything then printfn "" - let rec getMessage (clusterSize : int) = - printf "Enter : " - let s = Console.ReadLine () - + let getMessage (clusterSize : int) (s : string) : (int * int) option = match s.Split ',' with | [| serverId ; messageId |] -> + let serverId = serverId.Trim () + let messageId = messageId.Trim () + match Int32.TryParse serverId with | true, serverId -> match Int32.TryParse messageId with | true, messageId -> if serverId >= clusterSize || serverId < 0 then printf "Server ID must be between 0 and %i inclusive. " (clusterSize - 1) - getMessage clusterSize + None else - serverId * 1, messageId + Some (serverId * 1, messageId) | false, _ -> printf "Non-integer input '%s' for message ID. " messageId - getMessage clusterSize + None | false, _ -> printf "Non-integer input '%s' for server ID. " serverId - getMessage clusterSize + None | _ -> printfn "Invalid input." - getMessage clusterSize - - let rec getTimeout (clusterSize : int) = - printf "Enter server ID: " - let serverId = Console.ReadLine () + None + let rec getTimeout (clusterSize : int) (serverId : string) = match Int32.TryParse serverId with | true, serverId -> if serverId >= clusterSize || serverId < 0 then printf "Server ID must be between 0 and %i inclusive. " (clusterSize - 1) - getTimeout clusterSize + None else - serverId * 1 + Some (serverId * 1) | false, _ -> printf "Unrecognised input. " - getTimeout clusterSize + None type UserAction = | Timeout of int | NetworkMessage of int * int let rec getAction (clusterSize : int) = - printf "Enter action. Trigger [t]imeout, or allow [m]essage: " + printf "Enter action. Trigger [t]imeout , or allow [m]essage : " let s = Console.ReadLine().ToUpperInvariant () - match s with - | "T" -> getTimeout clusterSize |> Timeout - | "M" -> getMessage clusterSize |> NetworkMessage + match s.[0] with + | 'T' -> + match getTimeout clusterSize s.[1..] with + | Some t -> t |> Timeout + | None -> getAction clusterSize + | 'M' -> + match getMessage clusterSize s.[1..] with + | Some m -> m |> NetworkMessage + | None -> getAction clusterSize | _ -> printf "Unrecognised input. " getAction clusterSize @@ -85,5 +88,6 @@ module Program = | Timeout serverId -> cluster.Timeout serverId | NetworkMessage (serverId, messageId) -> network.InboundMessage serverId messageId |> cluster.SendMessage serverId + network.DropMessage serverId messageId 0