Consistent hashing for fun
I think consistent hashing is pretty fascinating. It lets you define a ring of machines that shard out data by a hash value. Imagine that your hash space is 0 -> Int.Max and you have 2 machines: one machine gets all values that hash from 0 -> Int.Max/2 and the other from Int.Max/2 -> Int.Max. Clever. This is one of the major algorithms behind distributed systems like Cassandra and DynamoDB.
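As a toy sketch of that two-machine split (the machine names here are made up for illustration):

```scala
// hash space is [0, Int.MaxValue]; two machines split it at the midpoint
val mid = Int.MaxValue / 2

def owner(hash: Int): String =
  if (hash < mid) "machine-1" else "machine-2"

// owner(0) == "machine-1"; owner(Int.MaxValue) == "machine-2"
```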
For a good visualization, check out this blog post.
The fun stuff happens when you want to add replication and fault tolerance to your hashing. Now you need to have replicas and manage what happens when machines join and leave. When a machine joins, you need to re-partition the space evenly and re-distribute the values that were previously held.
Something similar happens when a node leaves: you need to make sure that whatever it was responsible for in its primary space AND the things it was responsible for as a secondary replica are re-distributed amongst the remaining nodes.
But the beauty of consistent hashing is that the replication basically happens for free! And so does redistribution!
Since my new feature is all in Scala, I figured I’d write something up to see how this might play out in Scala.
For the impatient, the full source is here.
First I started with some data types:
case class HashValue(value: String)

case class HashKey(key: Int) extends Ordered[HashKey] {
  override def compare(that: HashKey): Int = key.compare(that.key)
}

object HashKey {
  // Math.abs(Int.MinValue) is still negative, so mask off the sign bit instead
  def safe(key: Int) = HashKey(key & Int.MaxValue)
}

case class HashRange(minHash: HashKey, maxHash: HashKey) extends Ordered[HashRange] {
  override def compare(that: HashRange): Int = minHash.compare(that.minHash)
}
I chose to map the key into positive space since it made things slightly easier. In reality you’d want to use MD5 or some other real hashing function, but I relied on hashCode here.
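One subtlety worth noting here: `Math.abs` doesn’t actually guarantee a non-negative result, because `Int.MinValue` has no positive counterpart in two’s complement. Masking off the sign bit avoids that edge case (it isn’t the same value as the absolute value, but for hashing we only care that it’s non-negative):

```scala
// Int.MinValue is its own "absolute value" in two's complement
println(Math.abs(Int.MinValue)) // -2147483648, still negative!

// masking the sign bit always yields a non-negative key
val masked = Int.MinValue & Int.MaxValue
println(masked) // 0
```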
And then a machine to hold values:
import scala.collection.immutable.TreeMap

class Machine[TValue](val id: String) {
  private var map: TreeMap[HashKey, TValue] = TreeMap.empty

  def add(key: HashKey, value: TValue): Unit = {
    map = map + (key -> value)
  }

  def get(hashKey: HashKey): Option[TValue] = {
    map.get(hashKey)
  }

  def getValuesInHashRange(hashRange: HashRange): Seq[(HashKey, TValue)] = {
    map.range(hashRange.minHash, hashRange.maxHash).toSeq
  }

  // keep only keys inside the given ranges; remove and return everything else
  def keepOnly(hashRanges: Seq[HashRange]): Seq[(HashKey, TValue)] = {
    val kept: TreeMap[HashKey, TValue] =
      hashRanges
        .map(range => map.range(range.minHash, range.maxHash))
        .fold(map.empty) { (tree1, tree2) => tree1 ++ tree2 }

    val dropped = map.filter { case (k, _) => !kept.contains(k) }

    map = kept

    dropped.toSeq
  }
}
A machine keeps a sorted tree map of hash values. This lets me quickly get things within ranges. For example, when we re-partition a machine, it’s no longer responsible for the entire range set it was before, but it may still be responsible for parts of it. So we want to be able to tell a machine: keep ranges 0-5 and 12-20, but drop everything else. The tree map lets me do this really nicely.
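The heavy lifting here is `TreeMap.range`, which returns the submap of keys in `[from, until)` — inclusive on the lower bound, exclusive on the upper. A standalone sketch with plain Int keys:

```scala
import scala.collection.immutable.TreeMap

val map = TreeMap(1 -> "a", 10 -> "b", 15 -> "c")

// keys in [12, 20): 10 falls below the range, 15 is inside it
val sub = map.range(12, 20)
// sub == TreeMap(15 -> "c")
```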
Now for the fun part, the actual consistent hashing stuff.
Given a set of machines, we need to define how the circular partitions are laid out:
private def getPartitions(machines: Seq[Machine[TValue]]): Seq[(HashRange, Machine[TValue])] = {
  val replicatedRanges: Seq[HashRange] =
    Stream.continually(defineRanges(machines.size)).flatten

  val infiniteMachines: Stream[Machine[TValue]] =
    Stream.continually(machines.flatMap(List.fill(replicas)(_))).flatten

  replicatedRanges
    .zip(infiniteMachines)
    .take(machines.size * replicas)
    .toList
}
What we want to make sure is that each node sits on multiple ranges; this gives us the replication factor. To do that I’ve duplicated the machines in the list by the replication factor and made sure all the lists cycle around indefinitely. So while the replicas are not evenly distributed around the ring (they are clustered), they do provide fault tolerance.
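Here’s what that duplication looks like in isolation, with string stand-ins for machines and a replication factor of 2:

```scala
// each machine is repeated `replicas` times, then the whole list cycles forever
val machines = Seq("A", "B")
val replicas = 2

val infiniteMachines: Stream[String] =
  Stream.continually(machines.flatMap(List.fill(replicas)(_))).flatten

// taking machines.size * replicas assignments shows the clustering: A, A, B, B
val assigned = infiniteMachines.take(machines.size * replicas).toList
// assigned == List("A", "A", "B", "B")
```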
Let’s look at what it takes to put a value into the ring:
private def put(hashkey: HashKey, value: TValue): Unit = {
  getReplicas(hashkey).foreach(_.add(hashkey, value))
}

private def getReplicas(hashKey: HashKey): Seq[Machine[TValue]] = {
  partitions
    .filter { case (range, _) => hashKey >= range.minHash && hashKey < range.maxHash }
    .map(_._2)
}
We need to make sure that for each replica in the ring that sits on a matching hash range, we insert the value into that machine. That’s pretty easy, though we could improve this later with better lookups.
Let’s look at a get:
def get(hashKey: TKey): Option[TValue] = {
  val key = HashKey.safe(hashKey.hashCode())

  getReplicas(key)
    .map(_.get(key))
    .collectFirst { case Some(x) => x }
}
Also similar: go through all the replicas and find the first one to return a value.
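The `collectFirst` call is what implements the “first replica that answers” logic. In isolation (the lookup results here are made up):

```scala
// simulate replica lookups: the first replica misses, the second one hits
val lookups: Seq[Option[String]] = Seq(None, Some("hit"), Some("later"))

val result = lookups.collectFirst { case Some(x) => x }
// result == Some("hit")
```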
Now let’s look at how to add a machine into the ring:
def addMachine(): Machine[TValue] = {
  id += 1

  val newMachine = new Machine[TValue]("machine-" + id)

  val oldMachines = partitions.map(_._2).distinct

  partitions = getPartitions(Seq(newMachine) ++ oldMachines)

  redistribute(partitions)

  newMachine
}
So we first create a new list of machines and then ask how to re-partition the ring. Then the keys in the ring need to redistribute themselves so that only the nodes responsible for certain ranges contain those keys:
def redistribute(newPartitions: Seq[(HashRange, Machine[TValue])]) = {
  newPartitions
    .groupBy { case (range, machine) => machine }
    .flatMap { case (machine, ranges) => machine.keepOnly(ranges.map(_._1)) }
    .foreach { case (k, v) => put(k, v) }
}
Redistributing isn’t that complicated either. We group all the entries in the ring by the machine they sit on, then for each machine we tell it to keep only the values in the ranges it now owns. The machine’s keepOnly function takes a list of ranges and will remove and return anything not in those ranges. We can then aggregate all the things that are “emitted” by the machines and re-insert them into the right locations.
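The groupBy-then-flatMap shape is easier to see on plain data (the range and machine names here are placeholders):

```scala
// pretend partitions: (range, machine) pairs; a machine can own several ranges
val partitions = Seq(("r1", "A"), ("r2", "B"), ("r3", "A"))

// invert to machine -> ranges, mirroring the groupBy in redistribute
val byMachine: Map[String, Seq[String]] =
  partitions
    .groupBy { case (_, machine) => machine }
    .map { case (machine, pairs) => machine -> pairs.map(_._1) }

// byMachine("A") == Seq("r1", "r3"); byMachine("B") == Seq("r2")
```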
Removing a machine is really similar:
def removeMachine(machine: Machine[TValue]): Unit = {
  val remainingMachines =
    partitions
      .filter { case (_, m) => !m.eq(machine) }
      .map(_._2)

  partitions = getPartitions(remainingMachines.distinct)

  redistribute(partitions)
}
And that’s all there is to it! Now we have a fast, simple consistent hasher.