继上篇文章稍微撕一下HashMap的源码之后,我们基本了解了HashMap
的工作机制和实现原理,这样我们就可以自己照着它的思路来自己实现一个基于Hash
算法的Map
结构容器
当然了,JDK对HashMap
有着相当多的优化手段,这里我们自己写的Map容器更像是原始的HashTable
(其实比HashTable
也差远了,不过基本思路到了就好),更别说是JDK1.8以后的带红黑树的HashMap
了,而且并发下也是不安全的。目前只实现了put
、get
和size
三个方法,后续再慢慢填坑扩容等等。。
1.0版本
成员变量
由于我们知道哈希表通常使用数组+链表的形式,因此我们定义存储结构为一个Node数组
每一个Node有四个成员变量,分别为
- 哈希值
- key
- value
- next指针
public class MyHashTable<K, V> {
static class Node<K, V> {
int hash;
K key;
V value;
Node<K, V> next;
public Node(int hash, K key, V value, Node<K, V> next) {
this.hash = hash;
this.key = key;
this.value = value;
this.next = next;
}
}
private final Node<K, V>[] nodes;
}
之后再定义一个size和初始长度(先不写扩容了),初始长度和HashMap的思路一样,为2的幂(想知道为什么的话可以看文章开头的链接的文章),以及构造方法(直接饿汉式构造)
private int size;
private int length = 16;
public MyHashTable() {
nodes = new Node[length];
size = 0;
}
hash方法
hash
方法是HashMap
中最核心的方法,对于这个方法我们也没啥好拓展的,参照HashMap
或者HashTable
就好了
其中关于算出hashcode
之后,进行无符号右移16位的原因是:可以减少碰撞,也称之为扰动函数。通过混合原始哈希码的高位和低位,以此来加大低位的随机性。而且混合后的低位掺杂了高位的部分特征,这样高位的信息也被变相保留下来。具体细节可以看这篇知乎链接的解释
public int hash(Object obj) {
if (obj == null) return 0;
int h = obj.hashCode();
return h ^ (h >>> 16);
}
put方法
put方法的逻辑为:
- 计算
key
的hash
并得出下标index
- 如果数组的
index
位置为null
的话,则直接插入 - 否则判断
index
位置的元素的hash
是否与插入元素的hash
相等,再判断是否equals
,是的话则覆盖当前位置 - 否则继续遍历链表,并判断是否冲突,直到链表尾结点,如果冲突则覆盖,判断方式和上面一样
- 否则将待插入的
node
节点插入链表尾部 - 插入完成之后元素计数器值加一
public void put(K k, V v) {
int hash = hash(k);
int index = hash & (length - 1);
if (nodes[index] == null) nodes[index] = new Node<>(hash, k, v, null);
else {
if (nodes[index].hash == hash && k.equals(nodes[index].key)) nodes[index] = new Node<>(hash, k, v, null);
else {
Node<K, V> node = nodes[index];
while (node.next != null) {
if (node.next.hash == hash && node.next.key.equals(k)) node.next = new Node<>(hash, k, v, null);
node = node.next;
}
node.next = new Node<>(hash, k, v, null);
}
}
size++;
}
get方法
get方法的逻辑为:
- 计算
key
的hash
并得出数组索引index
- 遍历
index
处的链表,判断是否有hash
等于key
的hash
以及equals
成立的节点node
,有则返回node
的value
- 直到遍历完链表还没有得到元素,则返回
null
public V get(Object obj) {
int hash = hash(obj);
int index = hash & (length - 1);
for (Node<K, V> node = nodes[index]; node != null; node = node.next) {
if (node.hash == hash && node.key.equals(obj)) return node.value;
}
return null;
}
size方法
size方法就简单了,因为我们设置了数组元素个数计数器,我们返回这个即可
public int size() {
return size;
}
最终源码与测试
经过上面的定义,我们实现的源码为
public class MyHashTable<K, V> {
static class Node<K, V> {
int hash;
K key;
V value;
Node<K, V> next;
public Node(int hash, K key, V value, Node<K, V> next) {
this.hash = hash;
this.key = key;
this.value = value;
this.next = next;
}
}
private final Node<K, V>[] nodes;
private int size;
private int length = 16;
public MyHashTable() {
nodes = new Node[length];
size = 0;
}
public void put(K k, V v) {
int hash = hash(k);
int index = hash & (length - 1);
if (nodes[index] == null) nodes[index] = new Node<>(hash, k, v, null);
else {
if (nodes[index].hash == hash && k.equals(nodes[index].key)) nodes[index] = new Node<>(hash, k, v, null);
else {
Node<K, V> node = nodes[index];
while (node.next != null) {
if (node.next.hash == hash && node.next.key.equals(k)) node.next = new Node<>(hash, k, v, null);
node = node.next;
}
node.next = new Node<>(hash, k, v, null);
}
}
size++;
}
public V get(Object obj) {
int hash = hash(obj);
int index = hash & (length - 1);
for (Node<K, V> node = nodes[index]; node != null; node = node.next) {
if (node.hash == hash && node.key.equals(obj)) return (V) node.value;
}
return null;
}
public int hash(Object obj) {
if (obj == null) return 0;
int h = obj.hashCode();
return h ^ (h >>> 16);
}
public int size() {
return size;
}
}
出现问题并排查
接下来我们测试一下各个功能是否好用
public static void main(String[] args) {
MyHashTable<Integer, Integer> intMap = new MyHashTable<>();
MyHashTable<String, Integer> strMap = new MyHashTable<>();
intMap.put(1,10);
intMap.put(2,20);
intMap.put(17, 170);
intMap.put(1, 1);
System.out.println(intMap.size());
System.out.println(intMap.get(1));
System.out.println(intMap.get(2));
System.out.println(intMap.get(17));
strMap.put("akb", 46);
strMap.put("tsh", 48);
strMap.put("akb", 48);
System.out.println(strMap.size());
System.out.println(strMap.get("akb"));
System.out.println(strMap.get("tsh"));
}
输出:
4
1
20
null
3
48
48
发现本应该被拉链法拉到1后面的17突然找不到了,变成了null
之后调换一下put 17和put 1的位置后
public static void main(String[] args) {
MyHashTable<Integer, Integer> intMap = new MyHashTable<>();
MyHashTable<String, Integer> strMap = new MyHashTable<>();
intMap.put(1,10);
intMap.put(1, 1);
intMap.put(17, 170);
System.out.println(intMap.get(1));
System.out.println(intMap.get(17));
}
1
170
又可以正常找到170了,说明之前一次成功的拉链上去了。
查看put方法之后发现了倪端:
public void put(K k, V v) {
int hash = hash(k);
int index = hash & (length - 1);
if (nodes[index] == null) nodes[index] = new Node<>(hash, k, v, null);
else {
if (nodes[index].hash == hash && k.equals(nodes[index].key)) nodes[index] = new Node<>(hash, k, v, null);
else {
Node<K, V> node = nodes[index];
while (node.next != null) {
if (node.next.hash == hash && node.next.key.equals(k)) node.next = new Node<>(hash, k, v, null);
node = node.next;
}
node.next = new Node<>(hash, k, v, null);
}
}
size++;
}
这里发现判断如果待插入的Key的hash值与数组当前位置元素相同并且equals的话,直接new一个新的node并覆盖上去,这样就会丢失指针,解决方法有两种
- new node的时候传入next参数为之前的node的next
- 不采用new node的方式,而采用修改node中成员变量value的方式来达成替换元素的作用
1.1 版本
针对以下的测试用例:
MyHashTable<Integer, Integer> intMap = new MyHashTable<>();
MyHashTable<String, Integer> strMap = new MyHashTable<>();
intMap.put(1,10);
intMap.put(2,20);
intMap.put(1, 1);
intMap.put(17, 170);
System.out.println(intMap.size());
System.out.println(intMap.get(1));
System.out.println(intMap.get(2));
System.out.println(intMap.get(17));
strMap.put("akb", 46);
strMap.put("tsh", 48);
strMap.put("akb", 48);
System.out.println(strMap.size());
System.out.println(strMap.get("akb"));
System.out.println(strMap.get("tsh"));
第一种方法
修改put方法为:
public void put(K k, V v) {
int hash = hash(k);
int index = hash & (length - 1);
if (nodes[index] == null) nodes[index] = new Node<>(hash, k, v, null);
else {
if (nodes[index].hash == hash && k.equals(nodes[index].key)) nodes[index] = new Node<>(hash, k, v, nodes[index].next);
else {
Node<K, V> node = nodes[index];
while (node.next != null) {
if (node.next.hash == hash && node.next.key.equals(k)) node.next = new Node<>(hash, k, v, node.next);
node = node.next;
}
node.next = new Node<>(hash, k, v, null);
}
}
size++;
}
输出结果为:
4
1
20
170
3
48
48
第二种方法
修改put方法为:
public void put(K k, V v) {
int hash = hash(k);
int index = hash & (length - 1);
if (nodes[index] == null) nodes[index] = new Node<>(hash, k, v, null);
else {
if (nodes[index].hash == hash && k.equals(nodes[index].key)) nodes[index].value = v;
else {
Node<K, V> node = nodes[index];
while (node.next != null) {
if (node.next.hash == hash && node.next.key.equals(k)) node.next.value = v;
node = node.next;
}
node.next = new Node<>(hash, k, v, null);
}
}
size++;
}
4
1
20
170
3
48
48
均通过了测试,大功告成
最后附上最终代码
public class MyHashTable<K, V> {
static class Node<K, V> {
int hash;
K key;
V value;
Node<K, V> next;
public Node(int hash, K key, V value, Node<K, V> next) {
this.hash = hash;
this.key = key;
this.value = value;
this.next = next;
}
}
private final Node<K, V>[] nodes;
private int size;
private int length = 16;
public MyHashTable() {
nodes = new Node[length];
size = 0;
}
public void put(K k, V v) {
int hash = hash(k);
int index = hash & (length - 1);
if (nodes[index] == null) nodes[index] = new Node<>(hash, k, v, null);
else {
if (nodes[index].hash == hash && k.equals(nodes[index].key)) nodes[index].value = v;
else {
Node<K, V> node = nodes[index];
while (node.next != null) {
if (node.next.hash == hash && node.next.key.equals(k)) node.next.value = v;
node = node.next;
}
node.next = new Node<>(hash, k, v, null);
}
}
size++;
}
public V get(Object obj) {
int hash = hash(obj);
int index = hash & (length - 1);
for (Node<K, V> node = nodes[index]; node != null; node = node.next) {
if (node.hash == hash && node.key.equals(obj)) return (V) node.value;
}
return null;
}
public int hash(Object obj) {
if (obj == null) return 0;
int h = obj.hashCode();
return h ^ (h >>> 16);
}
public int size() {
return size;
}
}